树链剖分

树链剖分的基本思想

通过将树分割成链的形式,从而把树形变为线性结构,减少处理难度。

树链剖分(树剖/链剖)有多种形式,如 重链剖分,长链剖分 和用于 Link/cut Tree 的剖分(有时被称作「实链剖分」),大多数情况下(没有特别说明时),「树链剖分」都指「重链剖分」。

树链有如下几个特征:

  1. 一棵树上的任意一条链的长度不超过 log2n\log_2 n,且一条链上的各个节点深度互不相同(相对于根节点而言)。
  2. 通过特殊的遍历方式,树链剖分可以保证一条链上的 DFS 序连续,从而更加方便地使用线段树或树状数组维护树上的区间信息。

重链剖分

重链剖分的基本定义

重链剖分,顾名思义,是一种通过子节点大小进行剖分的形式,我们给出以下定义:

  1. 重子节点:一个非叶子节点中子树最大的子节点,如果存在多个则取其一。

  2. 轻子节点:除了重子节点以外的所有子节点。

  3. 重边:连接任意两个重儿子的边。

  4. 轻边:除了重边的其他所有树边。

  5. 重链:若干条重边首尾相连而成的一条链。

我们将落单的叶子节点本身看做重子节点,不难发现,整棵树就被剖分成了一条条重链。

需要注意的是,每一条重链都以轻子节点为起点。

实现

树链剖分的处理通过两次 DFS 遍历完成。

对于第一次 DFS,我们求出如下数值:

  • 任意节点到达根的距离(即其深度):depth[]

  • 任意节点的父亲节点(根节点默认为 00):fa[]

  • 任意节点子树的大小(包括其本身):sz[]

  • 任意节点的重子节点(没有则为 00):hson[]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void dfs1(int ver, int pre, int deep)
{
depth[ver] = deep, fa[ver] = pre, sz[ver] = 1;
int maxn = -1;
for (int i = h[ver]; ~i; i = ne[i])
{
int j = e[i];
if (j == pre) continue;
dfs1(j, ver, deep + 1);
sz[ver] += sz[j];
if (maxn == -1 || maxn < sz[j]) maxn = sz[j], hson[ver] = j;
// 更新重子节点
}
}

对于第二次 DFS,我们求出如下数值:

  • 遍历时的各个节点的 dfs 序:dfn[]

  • 每个节点所属重链的最顶端节点:top[]

  • dfs 序对应的节点编号:id[],有 id(dfn(x))=xid(dfn(x)) = x

  • 每个节点在 dfs 序上的对应权值:val[]

1
2
3
4
5
6
7
8
9
10
11
12
13
void dfs2(int ver, int topf)
{
dfn[ver] = ++ timestamp, val[timestamp] = a[ver], top[ver] = topf;
if (!hson[ver]) return;
dfs2(hson[ver], topf); // 先遍历重子节点

for (int i = h[ver]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa[ver] || j == hson[ver]) continue;
dfs2(j, j); // 再遍历轻子节点
}
}

之所以要先遍历重子节点,是因为我们要保证重链上的 dfs 序连续,这样才可以进行区间操作,按照 dfndfn 排序后的序列即为剖分后的链。

重链剖分的性质

  1. 树上每个节点都属于且仅属于一条重链

  2. 所有的重链将整棵树 完全剖分

  3. 当我们向下经过一条 轻边 时,所在子树的大小至少会除以二,保证了复杂度的正确性。

常见应用

维护路径权值和

选取左右端点所在树中深度更大的节点,维护它到所在重链顶端的区间信息,之后不断上跳,知道它和另一端点在同一链上,维护两点之间的信息。使用线段树或者树状数组等数据结构,即可在 O(log2n)O(\log^2 n)​ 的时间内单次维护查询。

1
2
3
4
5
6
7
8
9
10
11
12
void modify_range(int x, int y, int k)
{
while (top[x] != top[y])
{
if (depth[top[x]] < depth[top[y]]) swap(x, y);
SGT.modify(1, dfn[top[x]], dfn[x], k);
x = fa[top[x]];
}

if (depth[x] > depth[y]) swap(x, y);
SGT.modify(1, dfn[x], dfn[y], k);
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
int query_range(int x, int y)
{
int res = 0;
while (top[x] != top[y])
{
if (depth[top[x]] < depth[top[y]]) swap(x, y);
res = (res + SGT.query(1, dfn[top[x]], dfn[x])) % mod;
x = fa[top[x]];
}

if (depth[x] > depth[y]) swap(x, y);
res = (res + SGT.query(1, dfn[x], dfn[y])) % mod;
return res;
}

维护子树信息

思路相似,但更加简单,经过 dfn 重新划分后,一颗子树的 dfn 序列一定在 [dfn[x],dfn[x]+sz[x]1][dfn[x], dfn[x] +sz[x] - 1] 之间,单次维护即可,时间复杂度 O(logn)O(\log n)

1
2
3
4
void modify_subtree(int x, int k)
{
SGT.modify(1, dfn[x], dfn[x] + sz[x] - 1, k);
}
1
2
3
4
int query_subtree(int x)
{
return SGT.query(1, dfn[x], dfn[x] + sz[x] - 1);
}

求 LCA

与倍增求法相似,但常数更小。

每次选取重链顶端节点深度更大的节点上跳,知道两者在同一重链上,此时深度较小者为两节点 LCA。

1
2
3
4
5
6
7
8
9
int lca(int a, int b)
{
while (top[a] != top[b])
{
if (depth[top[a]] > depth[top[b]]) a = fa[top[a]];
else b = fa[top[b]];
}
return depth[a] > depth[b] ? a : b;
}

例题 & Code

P3384 【模板】重链剖分/树链剖分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
// Problem: P3384 【模板】重链剖分/树链剖分
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P3384
// Memory Limit: 128 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>

using namespace std;

// #define int long long
#define DEBUG
#define lc u << 1
#define rc u << 1 | 1
#define File(a) freopen(a".in", "r", stdin); freopen(a".out", "w", stdout)

typedef long long LL;
typedef pair<int, int> PII;


const int N = 100010, M = N << 1;
const int INF = 0x3f3f3f3f;

int n, m, root, mod;
int h[N], e[M], ne[M], idx;
int a[N], val[N];
int depth[N], fa[N], sz[N], hson[N];
int dfn[N], timestamp;
int top[N], id[N];

struct Tree
{
struct Node
{
int l, r, sum, tag;
inline int len() {return r - l + 1; }
} tr[N << 2];

void pushup(int u)
{
tr[u].sum = (tr[lc].sum + tr[rc].sum) % mod;
}

void build(int u, int l, int r)
{
tr[u].l = l, tr[u].r = r;
if (l == r) return tr[u].sum = val[l], void(0);
int mid = l + r >> 1;
build(lc, l, mid), build(rc, mid + 1, r);
pushup(u);
}

void pushdown(int u)
{
if (!tr[u].tag) return;
tr[lc].sum = (tr[lc].sum + tr[u].tag * tr[lc].len()) % mod;
tr[rc].sum = (tr[rc].sum + tr[u].tag * tr[rc].len()) % mod;
tr[lc].tag += tr[u].tag, tr[rc].tag += tr[u].tag;
tr[u].tag = 0;
}

void modify(int u, int l, int r, int k)
{
if (l <= tr[u].l && tr[u].r <= r)
{
tr[u].sum = (tr[u].sum + tr[u].len() * k) % mod;
tr[u].tag += k;
return;
}

pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(lc, l, r, k);
if (r > mid) modify(rc, l, r, k);
pushup(u);
}

int query(int u, int l, int r)
{
if (l <= tr[u].l && tr[u].r <= r)
return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
int res = 0;
if (l <= mid) res = (res + query(lc, l, r)) % mod;
if (r > mid) res = (res + query(rc, l, r)) % mod;
return res;
}
} SGT;

inline void add(int a, int b)
{
e[++ idx] = b, ne[idx] = h[a], h[a] = idx;
}

void dfs1(int ver, int pre, int deep)
{
depth[ver] = deep, fa[ver] = pre, sz[ver] = 1;
int maxn = -1;
for (int i = h[ver]; ~i; i = ne[i])
{
int j = e[i];
if (j == pre) continue;
dfs1(j, ver, deep + 1);
sz[ver] += sz[j];

if (maxn == -1 || maxn < sz[j]) maxn = sz[j], hson[ver] = j;
}
}

void dfs2(int ver, int topf)
{
dfn[ver] = ++ timestamp, val[timestamp] = a[ver], top[ver] = topf;
if (!hson[ver]) return;
dfs2(hson[ver], topf);

for (int i = h[ver]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa[ver] || j == hson[ver]) continue;
dfs2(j, j);
}
}

void modify_range(int x, int y, int k)
{
while (top[x] != top[y])
{
if (depth[top[x]] < depth[top[y]]) swap(x, y);
SGT.modify(1, dfn[top[x]], dfn[x], k);
x = fa[top[x]];
}

if (depth[x] > depth[y]) swap(x, y);
SGT.modify(1, dfn[x], dfn[y], k);
}

int query_range(int x, int y)
{
int res = 0;
while (top[x] != top[y])
{
if (depth[top[x]] < depth[top[y]]) swap(x, y);
res = (res + SGT.query(1, dfn[top[x]], dfn[x])) % mod;
x = fa[top[x]];
}

if (depth[x] > depth[y]) swap(x, y);
res = (res + SGT.query(1, dfn[x], dfn[y])) % mod;
return res;
}

void modify_subtree(int x, int k)
{
SGT.modify(1, dfn[x], dfn[x] + sz[x] - 1, k);
}

int query_subtree(int x)
{
return SGT.query(1, dfn[x], dfn[x] + sz[x] - 1);
}

signed main()
{
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
memset(h, -1, sizeof h);
cin >> n >> m >> root >> mod;
for (int i = 1; i <= n; i ++) cin >> a[i];
for (int i = 1; i < n; i ++)
{
int u, v; cin >> u >> v;
add(u, v), add(v, u);
}

dfs1(root, 0, 1);
dfs2(root, root);
SGT.build(1, 1, n);

while (m --)
{
int opt; cin >> opt;
if (opt == 1)
{
int x, y, z; cin >> x >> y >> z;
modify_range(x, y, z);
}
else if (opt == 2)
{
int x, y; cin >> x >> y;
cout << query_range(x, y) % mod << '\n';
}
else if (opt == 3)
{
int x, y; cin >> x >> y;
modify_subtree(x, y);
}
else
{
int x; cin >> x;
cout << query_subtree(x) % mod << '\n';
}
}

return 0;
}

Reference

树链剖分 - OI Wiki

P3384 【模板】重链剖分/树链剖分 题解