zkw 线段树

zkw 线段树,即非递归版本的线段树,本质上与普通线段树一样,但是在一些较为简单或者有特殊性质的操作上(例如单点 / 区间修改,查询区间和 / 最值)有着码量小,空间少,常数小等特点(但是常数仍然略大于 Fenwick Tree)。

算法流程

在先前的普通线段树上,我们的每一次操作都是从线段树的根节点开始向下递归的,这种写法直观并且扩展性强,而 zkw 线段树则是从叶子节点一步一步上跳的,最终回到根部节点,有效的规避了递归模块的同时维护了区间信息。

建树

因为遵循着从底向上的原则,我们要算出线段树的叶子节点的标号是从哪一个开始,如图:

img

这颗线段树的叶子节点编号为 [8,15][8, 15],不难发现 8=2log2n+118 = 2 ^ {\left \lceil \log_2{n + 1} \right \rceil - 1},其中 nn 为数组大小,并且 158=n15 - 8 = n,因此我们得知了线段树上叶子节点的所有下标:令 T2log2n+11T \gets 2 ^ {\left \lceil \log_2{n + 1} \right \rceil } - 1,叶子节点范围为[T,T+n][T, T + n]​。

[!CAUTION]

实际上 TT 应该为 2log2n+1+12 ^ {\left \lceil \log_2{n + 1} \right \rceil } + 1​,也就是说图中的叶子节点从 17 开始标号,这是由于后面的修改和查询的局限导致的。

在后面的代码中都将依照这种写法。

而非叶子节点和不同线段树一样,两个儿子节点可以直接得出,因此有如下代码:

1
2
3
4
5
6
7
void build(int L, int R)
{
T = 1 << (int)ceil(log2(R - L + 1 + 1));
for (int i = 1; i <= n; i ++) tr[i + T].sum = a[i];
for (int i = T - 1; i >= 1; i --) tr[i].sum = tr[i << 1].sum + tr[i << 1 | 1].sum;
// RMQ:tr[i].minn = min(tr[i << 1].minn, tr[i << 1 | 1].minn);
}

通过如上的转换,下标为 xx 在线段树中的叶子节点为 x+Tx + T

单点修改 / 查询

单点修改长得很像树状数组,其实本质也差不多,都是从下到上进行更新的,所以代码也很简洁:

1
2
3
4
5
6
7
8
inline void modify(int x, int k)    // 给第 x 个点加上 k
{
for (x += T; x; x >>= 1) tr[x].sum += k;
/* RMQ:
tr[x += T].minn += k;
for (x >>= 1; x; x >>= 1) tr[x].minn = min(tr[x << 1].minn, tr[x << 1 | 1].minn);
*/
}

单点查询?真的有必要么?

1
inline int query(int x) { return tr[x + T].sum; }

单点修改下的区间查询

由于不存在 lazytag,因此这部分的代码就是 zkw 线段树的标志性代码,考虑初始化双指针挂在查询的节点左右两侧上,即对于询问 (L,R)(L, R),令双指针 (Lpos,Rpos)(Lpos, Rpos)

LposL+T1,RposR+T+1Lpos \gets L + T - 1, Rpos \gets R + T + 1

[!NOTE]

这里很好的解释了前面提到的为什么 T=2log2n+1+1T = 2 ^ {\left \lceil \log_2{n + 1} \right \rceil } + 1

  • 当询问区间为 [1,y],y[1,n)[1, y], y \in [1, n) 时,LposTLpos \gets T,此时 LposLpos 就不在叶子节点那一层了,因此整体向右移一位。
  • 当询问区间为 [x,n],x(1,n][x, n], x \in (1, n] 时,Rposn+T+1Rpos \gets n + T + 1 超出了线段树的大小,因此整体翻倍~~(省事)~~。

接着每次都循环,让两个指针每次上跳一层,由于第一次都在叶子节点上,因此最终必然会在同一个节点相遇。那么我们跳到什么时候呢?应当在他们属于同一个父亲的两个子节点时停下,此时它们所经过的所有节点包含的区间就包含了询问中的区间。

考虑每跳到新的一层时两个指针干什么,因为向上的区间越来越完整,因此一旦包括的区间在查询内部就应该处理,而不是等到指针所代表的区间变大才加上(证明显然,因为这是倍增的逆过程)。那么什么时候包括了所有区间呢?

答案就是 LposLpos 处于左子节点时的所有右子节点区间和 RposRpos 处于右子节点时所有左子节点区间的并即为查询区间。

img

看这幅图,不难发现这种方式可以不重不漏的包含所有区间。

代码呼之欲出了,也很简洁:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
int query(int L, int R)
{
int res = 0;
for (L += T - 1, R += T + 1; L ^ R ^ 1; L >>= 1, R >>= 1, cur <<= 1)
{
if (~L & 1) res += tr[L ^ 1].sum;
if (R & 1) res += tr[R ^ 1].sum;
}
return res;
}

/* RMQ
int query(int L, int R)
{
int res = INF;
for (L += T - 1, R += T + 1; L ^ R ^ 1; L >>= 1, R >>= 1)
{
if (~L & 1) res = min(res, minn[L ^ 1]);
if (R & 1) res = min(res, minn[R ^ 1]);
}
return res;
}
*/

区间修改

到这里就变得复杂了,因为每次查询都是从下往上的,标记无法进行像递归线段树一样下传,我们唯一能做的就是将标记永久化了,并且这种永久化是尽量向上的,用以保障复杂度正确。

考虑什么时候添加标记,每当确认一个区间完全包围时添加即可。但是问题在于我们不知到左右指针分别包含了修改区间的多少个元素,我们也不好预处理,因此我们利用额外三个变量分别在上跳过程中记录 Lpos,RposLpos,Rpos 包含的个数和 这一层的中间的节点为根的满二叉树的叶子节点有多少个

考虑怎么上传永久化的标记,像无标记的区间查询一样如果有一整颗子树都可以被覆盖,那么直接在这个根节点上打上标记。容易证明整个区间都可以打上标记。

考虑怎么更新答案,由于标记永久化了,我们只更新这一次需要更新的就好了,在两者还在上跳时每一层都加上当前 Lpos/RposLpos / Rpos 已经覆盖的区间个数 ×k\times k。并且即使两者父节点相同后我们也有继续更新。

1
2
3
4
5
6
7
8
9
10
11
12
void modify(int L, int R, int k)
{
int lft = 0, rht = 0, cur = 1;
for (L += T - 1, R += T + 1; L ^ R ^ 1; L >>= 1, R >>= 1, cur <<= 1)
{
tr[L].sum += k * lft, tr[R].sum += k * rht;
if (~L & 1) tr[L ^ 1].tag += k, tr[L ^ 1].sum += k * cur, lft += cur;
if (R & 1) tr[R ^ 1].tag += k, tr[R ^ 1].sum += k * cur, rht += cur;
}
for (; L && R; L >>= 1, R >>= 1)
tr[L].sum += k * lft, tr[R].sum += k * rht;
}

区间查询

和修改一样,维护好三个变量即可,遇到了标记加上即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int query(int L, int R)
{
int lft = 0, rht = 0, cur = 1, res = 0;
for (L += T - 1, R += T + 1; L ^ R ^ 1; L >>= 1, R >>= 1, cur <<= 1)
{
if (tr[L].tag) res += tr[L].tag * lft;
if (tr[R].tag) res += tr[R].tag * rht;
if (~L & 1) res += tr[L ^ 1].sum, lft += cur;
if (R & 1) res += tr[R ^ 1].sum, rht += cur;
}
for (; L && R; L >>= 1, R >>= 1)
res += tr[L].tag * lft + tr[R].tag * rht;
return res;
}

Code

P3372 【模板】线段树 1 - 洛谷 | 计算机科学教育新生态

在同一时间段同语言同编译选项测试情况下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include <bits/stdc++.h>

using namespace std;

#define int long long
// #define x first
// #define y second
#define File(a) freopen(a".in", "r", stdin), freopen(a".out", "w", stdout)

inline void debug() { cerr << '\n'; }
template<typename Type, typename... Args>
inline void debug(const Type& x, const Args&... y) { cerr << x << ' ', debug(y...); }
#define DEBUG(a...) cerr << "[" << #a << "] = ", debug(a)

typedef long long LL;
typedef pair<int, int> PII;

const int N = 200010;
const int INF = 0x3f3f3f3f;

template<typename Type>
inline void read(Type& res)
{
res = 0;
int ch = getchar(), flag = 0;
while (!isdigit(ch)) flag |= ch == '-', ch = getchar();
while (isdigit(ch)) res = (res << 3) + (res << 1) + (ch ^ 48), ch = getchar();
res = flag ? -res : res;
}
template<typename Type, typename... Args>
inline void read(Type& res, Args&... y) { read(res), read(y...); }

int n, m, a[N];

struct ZKWSegment
{
struct Node { int sum, tag; } tr[N << 1];
int T;

void build(int L, int R)
{
T = 1 << (int)ceil(log2(R - L + 2));
for (int i = 1; i <= n; i ++) tr[T + i].sum = a[i];
for (int i = T - 1; i >= 1; i --) tr[i].sum = tr[i << 1].sum + tr[i << 1 | 1].sum;
}

void modify(int L, int R, int k)
{
int lft = 0, rht = 0, cur = 1;
for (L += T - 1, R += T + 1; L ^ R ^ 1; L >>= 1, R >>= 1, cur <<= 1)
{
tr[L].sum += k * lft, tr[R].sum += k * rht;
if (~L & 1) tr[L ^ 1].tag += k, tr[L ^ 1].sum += k * cur, lft += cur;
if (R & 1) tr[R ^ 1].tag += k, tr[R ^ 1].sum += k * cur, rht += cur;
}

for (; L && R; L >>= 1, R >>= 1)
tr[L].sum += k * lft, tr[R].sum += k * rht;
}

int query(int L, int R)
{
int lft = 0, rht = 0, cur = 1, res = 0;
for (L += T - 1, R += T + 1; L ^ R ^ 1; L >>= 1, R >>= 1, cur <<= 1)
{
if (tr[L].tag) res += tr[L].tag * lft;
if (tr[R].tag) res += tr[R].tag * rht;
if (~L & 1) res += tr[L ^ 1].sum, lft += cur;
if (R & 1) res += tr[R ^ 1].sum, rht += cur;
}

for (; L && R; L >>= 1, R >>= 1)
res += tr[L].tag * lft + tr[R].tag * rht;
return res;
}

} SGT;

signed main()
{
read(n, m);
for (int i = 1; i <= n; i ++) read(a[i]);
SGT.build(1, n);

for (int i = 1, opt, x, y, k; i <= m; i ++)
{
read(opt, x, y);
if (opt == 1)
{
read(k);
SGT.modify(x, y, k);
}
else cout << SGT.query(x, y) << '\n';
}

return 0;
}

Reference

线段树的扩展之浅谈zkw线段树 - 洛谷专栏

ChatGPT