线段树

线段树是 OI 竞赛中最强大的数据结构之一,可以用来维护和、积以及最值等具有合并性质的信息。

一般线段树

P3372 【模板】线段树 1 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

P3373 【模板】线段树 2 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

以模板一为例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
struct SegmentTree
{
#define lc u << 1
#define rc u << 1 | 1
struct Tree
{
int l, r, sum, tag;
inline int len() { return r - l + 1; }
inline void addtag(int k) { sum += len() * k, tag += k; }
} tr[N << 2];

inline void pushup(int u) { tr[u].sum = tr[lc].sum + tr[rc].sum; }

void build(int u, int l, int r)
{
tr[u].l = l, tr[u].r = r;
if (l == r) return tr[u].sum = a[l], void(0);
int mid = l + r >> 1;
build(lc, l, mid), build(rc, mid + 1, r);
pushup(u);
}

void pushdown(int u)
{
int &k = tr[u].tag;
tr[lc].addtag(k), tr[rc].addtag(k);
k = 0;
}

void dot_modify(int u, int x, int k)
{
if (tr[u].l == tr[u].r)
return tr[u].addtag(k);
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
dot_modify(u << 1 | (x > mid), x, k);
pushup(u);
}

void range_modify(int u, int l, int r, int k)
{
if (l <= tr[u].l && tr[u].r <= r)
return tr[u].addtag(k);
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) range_modify(lc, l, r, k);
if (r > mid) range_modify(rc, l, r, k);
pushup(u);
}

int query(int u, int l, int r)
{
if (l <= tr[u].l && tr[u].r <= r)
return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1, res = 0;
if (l <= mid) res += query(lc, l, r);
if (r > mid) res += query(rc, l, r);
return res;
}
} SGT;

动态开点线段树

在遇到需要以权值为下标建立的线段树时,有时候过大的值域使得无法如上种方式建立线段树,这时候需要权值线段树和动态开点的技巧。

T125847 【模板】动态开点线段树 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

P3372 【模板】线段树 1 - 洛谷 | 计算机科学教育新生态

数组版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
struct SegmentTree
{
#define lc tr[u].ls
#define rc tr[u].rs

int tot = 0, root = 0;
struct Node
{
int ls, rs, len, sum, tag;
inline void addtag(int k) { sum += len * k, tag += k; }
} tr[N << 2];

inline int newnode(int L, int R)
{
tr[++ tot].len = R - L + 1;
return tot;
}

inline void pushup(int u) { tr[u].sum = tr[lc].sum + tr[rc].sum; }

void insert(int &u, int l, int r, int x, int k)
{
if (!u) u = newnode(l, r);
if (l == r) return tr[u].sum = k, void();
int mid = l + r >> 1;
if (x <= mid) insert(lc, l, mid, x, k);
else insert(rc, mid + 1, r, x, k);
pushup(u);
}

inline void pushdown(int u, int L, int R)
{
if (!tr[u].tag) return;
int mid = L + R >> 1;
if (!lc) lc = newnode(L, mid);
if (!rc) rc = newnode(mid + 1, R);
tr[lc].addtag(tr[u].tag), tr[rc].addtag(tr[u].tag);
tr[u].tag = 0;
}

void modify(int &u, int l, int r, int L, int R, int k)
{
if (!u) u = newnode(L, R);
if (L <= l && r <= R)
return tr[u].addtag(k);

pushdown(u, l, r);
int mid = l + r >> 1;
if (L <= mid) modify(lc, l, mid, L, R, k);
if (R > mid) modify(rc, mid + 1, r, L, R, k);
pushup(u);
}

int query(int u, int l, int r, int L, int R)
{
if (!u) return 0LL;
if (L <= l && r <= R) return tr[u].sum;

pushdown(u, l, r);
int mid = l + r >> 1, res = 0;
if (L <= mid) res += query(lc, l, mid, L, R);
if (R > mid) res += query(rc, mid + 1, r, L, R);
return res;
}
} SGT;

指针版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
struct SegmentTree
{
struct Node
{
int len, sum, tag;
Node *ls, *rs;
Node(int _len = 1) : len(_len), sum(0), tag(0), ls(NULL), rs(NULL) {}
inline void addtag(int k) { sum += k * len, tag += k; }
} *Root = NULL, EMP;

inline Node* get(Node *u) { return u ? u : &EMP; }

inline void pushup(Node *u) { u -> sum = get(u -> ls) -> sum + get(u -> rs) -> sum; } // !

void insert(Node *&u, int l, int r, int x, int k)
{
if (!u) u = new Node(r - l + 1);
if (l == r) return void(u -> sum = k);

int mid = l + r >> 1;
if (x <= mid) insert(u -> ls, l, mid, x, k);
else insert(u -> rs, mid + 1, r, x, k);
pushup(u);
}

void pushdown(Node *u)
{
if (!u || !u -> tag) return;
if (!u -> ls) u -> ls = new Node(u -> len >> 1);
if (!u -> rs) u -> rs = new Node(u -> len + 1 >> 1);

u -> ls -> addtag(u -> tag), u -> rs -> addtag(u -> tag);
u -> tag = 0;
}

void modify(Node *&u, int l, int r, int L, int R, int k)
{
if (!u) u = new Node(r - l + 1);
if (L <= l && r <= R) return u -> addtag(k);

pushdown(u);
int mid = l + r >> 1, res = 0;
if (L <= mid) modify(u -> ls, l, mid, L, R, k);
if (R > mid) modify(u -> rs, mid + 1, r, L, R, k);
pushup(u);
}

int query(Node *u, int l, int r, int L, int R)
{
if (!u) return 0LL;
if (L <= l && r <= R) return u -> sum;

pushdown(u);
int mid = l + r >> 1, res = 0;
if (L <= mid) res += query(u -> ls, l, mid, L, R);
if (R > mid) res += query(u -> rs, mid + 1, r, L, R);
return res;
}
} SGT;

合并线段树

雨天的尾巴 /【模板】线段树合并 - 洛谷 | 计算机科学教育新生态

对于动态开店的线段树,不能直接在两个线段树上按位相加,因为所对应的权值不一定相同,因此我们应当递归地合并线段树,根据子树情况维护。

简单来说分为三种情况,对于两颗线段树 A,BA, B 来说,现在分别遍历到了 x,yx, y​ 这两个节点,则有:

  • xxyy 均为 00,合并后也不存在,返回即可。

  • xxyy 其中一个非 00,此时只需要留下非 00 的点即可。

  • 剩下的情况按位相加即可。

数组版本

1
2
3
4
5
6
7
8
9
10
11
int merge(int x, int y, int l, int r)
{
if (!x || !y) return x | y;
if (l == r) return tr[x].sum += tr[y].sum, x;

int mid = l + r >> 1;
tr[x].ls = merge(tr[x].ls, tr[y].ls, l, mid);
tr[x].rs = merge(tr[x].rs, tr[y].rs, mid + 1, r);
delnode(y);
return pushup(x), x;
}

指针版本

1
2
3
4
5
6
7
8
9
10
11
*Node merge(Node *x, Node *y, int l, int r)
{
if (!x || !y) return x ? y : x;
if (l == r) return x -> sum += y -> sum, x;

int mid = l + r >> 1;
x -> ls = merge(x -> ls, y -> ls, l, mid);
x -> rs = merge(x -> rs, y -> rs, mid + 1, r);
delete(y);
return pushup(x), x;
}

分裂线段树

P5494 【模板】线段树分裂 - 洛谷 | 计算机科学教育新生态

支持将一颗线段数(例如权值线段树)前 kk​ 个权值分裂出来,类似线段树上二分,同样存在三种分讨。

1
2
3
4
5
6
7
8
9
10
void split(int x, int &y, int k)
{
if (!x) return;
y = newnode();
int cur = tr[tr[x].ls].sum;
if (cur < k) split(tr[x].rs, tr[y].rs, k - cur);
else if (cur == k) swap(tr[x].rs, tr[y].rs);
else swap(tr[x].rs, tr[y].rs), split(tr[x].ls, tr[y].ls, k);
pushup(x), pushup(y);
}
1
2
3
4
5
6
7
8
9
10
void split(Node *x, Node *&y, int k)
{
if (!x) return;
y = new Node();
int cur = x -> ls ? x -> ls -> sum : 0;
if (cur < k) split(x -> rs, y -> rs, k - cur);
else if (cur == k) swap(x -> rs, y -> rs);
else swap(x -> rs, y -> rs), split(x -> ls, y -> ls, k);
pushup(x), pushup(y);
}

可持久化线段树(主席树)

P3919 【模板】可持久化线段树 1(可持久化数组) - 洛谷 | 计算机科学教育新生态

一种支持查询历史版本的线段树类型,对于单点修改和单点查询有:

数组版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
struct SegmentTree
{
#define lc tr[u].ls
#define rc tr[u].rs

int tot = 0;
struct Node { int ls, rs, sum; } tr[N << 5];

inline int newnode() { return ++ tot; }

inline int clone(int x) { return tr[newnode()] = tr[x], tot; }

inline void pushup(int u) { tr[u].sum = tr[lc].sum + tr[rc].sum; }

void build(int &u, int l, int r)
{
if (!u) u = newnode();
if (l == r) return tr[u].sum = a[l], void();
int mid = l + r >> 1;
build(lc, l, mid), build(rc, mid + 1, r);
pushup(u);
}

void insert(int &u, int l, int r, int x, int k)
{
u = clone(u);
if (l == r) return tr[u].sum = k, void();
int mid = l + r >> 1;
if (x <= mid) insert(lc, l, mid, x, k);
else insert(rc, mid + 1, r, x, k);
pushup(u);
}

int query(int u, int l, int r, int x)
{
if (!u) return 0;
if (l == r) return tr[u].sum;
int mid = l + r >> 1;
if (x <= mid) return query(lc, l, mid, x);
else return query(rc, mid + 1, r, x);
}
} SGT;

指针版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
struct SegmentTree
{
struct Node
{
int sum;
Node *ls, *rs;
Node() : sum(0), ls(NULL), rs(NULL) {}
} *root[N], EMP;

inline Node* get(Node* u) { return u ? u : &EMP; }

inline void pushup(Node *u) { u -> sum = get(u -> ls) -> sum + get(u -> rs) -> sum; }

inline Node* clone(Node *u)
{
if (!u) return NULL;
Node *x = new Node();
return *x = *u, x;
}

void build(Node *&u, int l, int r)
{
if (!u) u = new Node();
if (l == r) return void(u -> sum = a[l]);

int mid = l + r >> 1;
build(u -> ls, l, mid), build(u -> rs, mid + 1, r);
pushup(u);
}

void insert(Node *&u, int l, int r, int x, int k)
{
u = clone(u);
if (l == r) return void(u -> sum = k);

int mid = l + r >> 1;
if (x <= mid) insert(u -> ls, l, mid, x, k);
else insert(u -> rs, mid + 1, r, x, k);
pushup(u);
}

int query(Node *u, int l, int r, int x)
{
if (!u) return 0LL;
if (l == r) return u -> sum;

int mid = l + r >> 1;
if (x <= mid) return query(u -> ls, l, mid, x);
else return query(u -> rs, mid + 1, r, x);
}
} SGT;

主函数中新建 root[] 存储每个历史版本的根。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
signed main()
{
read(n, m);
for (int i = 1; i <= n; i ++) read(a[i]);
SGT.build(root[0], 1, n);

int rt, opt, x, y;
for (int i = 1; i <= m; i ++)
{
read(rt, opt, x);
if (opt == 2)
{
cout << SGT.query(root[rt], 1, n, x) << '\n';
root[i] = root[rt];
}
else
{
read(y);
root[i] = root[rt];
SGT.insert(root[i], 1, n, x, y);
}
}

return 0;
}

经典应用

静态区间 rank 问题

P3834 【模板】可持久化线段树 2 - 洛谷 | 计算机科学教育新生态

类似前缀和的思想,用边界的两棵树相减求出答案。

1
2
3
4
5
6
7
int query(int x, int y, int l, int r, int k)
{
if (l == r) return l;
int mid = l + r >> 1, cur = tr[tr[y].ls].sum - tr[tr[x].ls].sum;
if (cur >= k) return query(tr[x].ls, tr[y].ls, l, mid, k);
else return query(tr[x].rs, tr[y].rs, mid + 1, r, k - cur);
}
1
2
3
4
5
6
7
int query(Node *x, Node *y, int L, int R, int k)
{
if (L == R) return L;
int mid = L + R >> 1, cnt = get(y -> ls) -> sum - get(x -> ls) -> sum;
if (cnt < k) return query(x -> rs, y -> rs, mid + 1, R, k - cnt);
else return query(x -> ls, y -> ls, L, mid, k);
}

主函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
signed main()
{
read(n, m);
for (int i = 1; i <= n; i ++)
read(a[i]), disc[++ idx] = a[i];
sort(disc + 1, disc + 1 + idx);
idx = unique(disc + 1, disc + 1 + idx) - disc - 1;
for (int i = 1; i <= n; i ++)
a[i] = lower_bound(disc + 1, disc + 1 + idx, a[i]) - disc;

SGT.build(root[0], 1, idx);
for (int i = 1; i <= n; i ++)
root[i] = root[i - 1], SGT.insert(root[i], 1, idx, a[i], 1);

while (m --)
{
int l, r, k; read(l, r, k);
cout << disc[SGT.query(root[l - 1], root[r], 1, idx, k)] << '\n';
}

return 0;
}