二叉搜索树
www.luogu.com.cn
定义
二叉搜索树(Binary Search Tree)是一种形状如二叉树的数据结构,用于快速查找和增加删除操作,它有如下几个特殊性质:
- 空树是二叉搜索树。
- 若二叉搜索树的左子树不为空,则其左子树上所有点的附加权值均小于其根节点的值。
- 若二叉搜索树的右子树不为空,则其右子树上所有点的附加权值均大于其根节点的值。
- 二叉搜索树的左右子树均为二叉搜索树。
对于一个二叉搜索树,进行操作所花费的平均时间和这个二叉树的高度成正比。对于一个有 n 个结点的二叉搜索树中,这些操作的最优时间复杂度为 O(logn),最坏为 O(n)。
操作过程
定义
对于一个二叉搜索树的每一个节点,都要存储以下几样东西:
- 左子树节点和右子树节点:
l
和 r
- 这一个节点上存储的权值:
val
- 表示这个节点权值出现次数的计数器:
cnt
- 代表这个树上的大小(即子树和自己大小之和):
size
1 2 3 4 5 6
| struct BST { int l, r; int val; int cnt, size; }tr[N];
|
与此同时,由于操作中需要不断插入新的数,因此还需要两个变量分别储存树的下标和当前的下标,我们用 root
和 idx
表示:
创建新节点
创建一个新的节点就是将这个节点全部初始化,同时返回原来的下标值:
1 2 3 4 5 6
| int new_node(int k) { tr[idx].val = k; tr[idx].size = tr[idx].cnt = 1; return idx; }
|
上传信息
像线段树一样,二叉搜索树同样也有需要维护的信息:每个子树的本身大小,有二叉树的性质可得出:
tr[u].size = tr[tr[u].l].size + tr[tr[u].r].size + tr[u].cnt
1 2 3 4
| void pushup(int u) { tr[u].size = tr[tr[u].l].size + tr[tr[u].r].size + tr[u].cnt; }
|
初始化
为了防止有时候询问的二叉搜索树为空,我们可以先在树中加入两个哨兵:-INF 和 INF,在不断的插入中,他们会始终在树的左右两端,从而有效防止查询越界。
1 2 3 4 5 6
| void build() { new_node(-INF), new_node(INF); root = 1, tr[1].r = 2; pushup(root); }
|
插入
根据二叉搜索树的性质可得,对于每个节点 u 以及要插入的权值 k,可以分为四种情况:
- u=k 时,直接在该节点的计数变量上
cnt ++
。
- u<k 时,递归到该节点的右子树节点继续插入。
- u>k 时,递归到该节点的左子树节点继续插入。
- u=0 时,则说明没有这个节点,直接利用
idx
创建一个。
tips:在递归完之后要 pushup
一遍,从而维护每个子树的大小。
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| void insert(int &u, int k) { if (u == 0) u = new_node(k); else { if (tr[u].val == k) tr[u].cnt ++; else { if (tr[u].val > k) insert(tr[u].l, k); else insert(tr[u].r, k); } } pushup(u); }
|
删除
原理和插入相似,不多赘述,一共五种情况:
- u=k 时,直接在该节点的计数变量上
cnt --
。
- u 节点为叶子节点时,直接
u = 0
。
- u<k 时,递归到该节点的右子树节点继续删除。
- u>k 时,递归到该节点的左子树节点继续删除。
- u=0 时,则说明没有这个节点,直接
return
掉 。
tips:在递归完之后要 pushup
一遍,从而维护每个子树的大小。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| void del(int &u, int k) { if (u == 0) return ; if (tr[u].val == k) { if (tr[u].cnt > 1) tr[u].cnt --; else { if (tr[u].l || tr[u].r) { if (!tr[u].r) del(tr[u].r, k); else del(tr[u].l, k); } else u = 0; } } else if (tr[u].val > k) del(tr[u].l, k); else del(tr[u].r, k); pushup(u); }
|
找前驱
一个数 x 的前驱被定义为小于 x 的最大的数。
对于一个节点 u 的权值以及权值 k,它的查询分为以下三种情况:
- u=0 时,此时说明找到了最左端,应当
return -INF
。
- u≥k 时,说明 $k $的前驱还可能在当前节点的左子树,向左递归。
- 其余的情况则说明 u≤k,此时分为两种情况,一是前驱就是 u,二是前驱在右子树当中,因此在两者中取 max。
1 2 3 4 5 6
| int get_prev(int u, int k) { if (u == 0) return -INF; if (tr[u].val >= k) return get_prev(tr[u].l, k); return max(get_prev(tr[u].r, k), tr[u].val); }
|
找后驱
一个数 x 的后驱被定义为大于 x 的最小的数。
原理同上
1 2 3 4 5 6
| int get_next(int u, int k) { if (u == 0) return INF; if (tr[u].val <= k) return get_next(tr[u].r, k); return min(get_next(tr[u].l, k), tr[u].val); }
|
通过排名找权值
排名定义为比当前数小的数的个数 +1。若有多个相同的数,应输出最小的排名。
对于一个节点 u 以及 rank 值,可以分为以下四种情况:
- tr[u].val=0 时,说明这个权值无限大,返回 INF。
- tr[tr[u].l].size ≥rank,说明权值在左子树里面,向左递归。
- tr[tr[u].l].size + tr[u].cnt≥rank,由于上面的限制条件,说明该节点权值即为答案。
- 否则向右进行递归,并且注意右子树中的 $rank $ 值是相对的。
1 2 3 4 5 6 7
| int get_val_by_rank(int u, int rank) { if (u == 0) return INF; if (tr[tr[u].l].size >= rank) return get_val_by_rank(tr[u].l, rank); if (tr[tr[u].l].size + tr[u].cnt >= rank) return tr[u].val; return get_val_by_rank(tr[u].r, rank - tr[tr[u].l].size - tr[u].cnt); }
|
通过权值找排名
原理和上面反过来差不多
1 2 3 4 5 6 7
| int get_rank_by_val(int u, int k) { if (u == 0) return 0; if (tr[u].val == k) return tr[tr[u].l].size + 1; if (tr[u].val > k) return get_rank_by_val(tr[u].l, k); return tr[tr[u].l].size + tr[u].cnt + get_rank_by_val(tr[u].r, k); }
|
模板代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
|
#include <bits/stdc++.h> using namespace std; typedef long long ll; typedef unsigned long long ull;
const int N = 1e5 + 9; const int INF = 2147483647;
struct BST { int l, r; int val; int cnt, size; }tr[N]; int root, idx;
void pushup(int u) { tr[u].size = tr[tr[u].l].size + tr[tr[u].r].size + tr[u].cnt; }
int new_node(int k) { tr[++ idx].val = k; tr[idx].size = tr[idx].cnt = 1; return idx; }
void build() { new_node(-INF), new_node(INF); root = 1, tr[1].r = 2; pushup(root); }
void insert(int &u, int k) { if (u == 0) u = new_node(k); else { if (tr[u].val == k) tr[u].cnt ++; else { if (tr[u].val > k) insert(tr[u].l, k); else insert(tr[u].r, k); } } pushup(u); }
void del(int &u, int k) { if (u == 0) return ; if (tr[u].val == k) { if (tr[u].cnt > 1) tr[u].cnt --; else { if (tr[u].l || tr[u].r) { if (!tr[u].r) del(tr[u].r, k); else del(tr[u].l, k); } else u = 0; } } else if (tr[u].val > k) del(tr[u].l, k); else del(tr[u].r, k); pushup(u); }
int get_rank_by_val(int u, int k) { if (u == 0) return 0; if (tr[u].val == k) return tr[tr[u].l].size + 1; if (tr[u].val > k) return get_rank_by_val(tr[u].l, k); return tr[tr[u].l].size + tr[u].cnt + get_rank_by_val(tr[u].r, k); }
int get_val_by_rank(int u, int rank) { if (u == 0) return INF; if (tr[tr[u].l].size >= rank) return get_val_by_rank(tr[u].l, rank); if (tr[tr[u].l].size + tr[u].cnt >= rank) return tr[u].val; return get_val_by_rank(tr[u].r, rank - tr[tr[u].l].size - tr[u].cnt); }
int get_prev(int u, int k) { if (u == 0) return -INF; if (tr[u].val >= k) return get_prev(tr[u].l, k); return max(get_prev(tr[u].r, k), tr[u].val); }
int get_next(int u, int k) { if (u == 0) return INF; if (tr[u].val <= k) return get_next(tr[u].r, k); return min(get_next(tr[u].l, k), tr[u].val); }
signed main() { build(); int n; cin >> n; while (n --) { int op, k; cin >> op >> k; if (op == 5) insert(root, k); else if (op == 1) { insert(root, k); cout << get_rank_by_val(root, k) - 1<< '\n'; del(root, k); } else if (op == 2) cout << get_val_by_rank(root, k + 1) << '\n'; else if (op == 3) cout << get_prev(root, k) << '\n'; else cout << get_next(root, k) << '\n'; } return 0; }
|
平衡树
www.luogu.com.cn
www.luogu.com.cn
定义
平衡树简介
使用搜索树的目的之一是 缩短插入、删除、修改和查找(插入、删除、修改都包括查找操作)节点的时间。 关于查找效率,如果一棵树的高度为 h,在最坏的情况,查找一个关键字需要对比 h 次,查找时间复杂度(也为平均查找长度 ASL,Average Search Length)不超过 O(h)。一棵理想的二叉搜索树所有操作的时间可以缩短到 O(logn)(n 是节点总数)。 然而 O(h) 的时间复杂度仅为理想情况。在最坏情况下,搜索树有可能退化为链表。想象一棵每个结点只有右孩子的二叉搜索树,那么它的性质就和链表一样,所有操作(增删改查)的时间是 O(n)。 可以发现操作的复杂度与树的高度 h 有关。由此引出了平衡树,通过一定操作维持树的高度(平衡性)来降低操作的复杂度。
一个平衡树有如下几个定义:
- 它是一个二叉搜索树
- 对于任意一个节点的子树,每一个节点的左子树和右子树的高度差最多为 1。
旋转Treap
简介
Treap(树堆)是一种 弱平衡 的 二叉搜索树。它同时符合二叉搜索树和堆的性质,名字也因此为 tree(树)和 heap(堆)的组合。
其中对于堆的性质是:
- 子节点值(priority)比父节点大或小(取决于是小根堆还是大根堆)。
操作过程
旋转
旋转 Treap通过 左旋 和 右旋 的方式维护平衡树的平衡,通过调整堆的优先级从而达到平衡操作。
旋转操作的含义:
- 在不影响搜索树性质的前提下,把和旋转方向相反的子树变成根节点(如左旋,就是把右子树变成根节点)
- 不影响性质,并且在旋转过后,跟旋转方向相同的子节点变成了原来的根节点(如左旋,旋转完之后的左子节点是旋转前的根节点)
代码思路需要慢慢悟出来
1 2 3 4 5 6 7 8 9 10 11 12 13
| void zig(int &u) { int v = tr[u].l; tr[u].l = tr[v].r, tr[v].r = u, u = v; pushup(tr[u].r), pushup(u); }
void zag(int &u) { int v = tr[u].r; tr[u].r = tr[v].l, tr[v].l = u, u = v; pushup(tr[u].l), pushup(u); }
|
插入
和二叉搜索树差不多,只不过要在插入的过程中维护好堆的性质。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
| void insert(int &u, int k) { if (u == 0) u = new_node(k); else { if (tr[u].val == k) tr[u].cnt ++; else { if (tr[u].val > k) { insert(tr[u].l, k); if (tr[tr[u].l].key > tr[u].key) zig(u); } else { insert(tr[u].r, k); if (tr[tr[u].r].key < tr[u].key) zag(u); } } } pushup(u); }
|
删除
同样维护好堆的性质
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| void del(int &u, int k) { if (u == 0) return ; if (tr[u].val == k) { if (tr[u].cnt > 1) tr[u].cnt --; else { if (tr[u].l || tr[u].r) { if (!tr[u].r || tr[tr[u].l].key > tr[tr[u].r].key) { zig(u); del(tr[u].r, k); } else { zag(u); del(tr[u].l, k); } } else u = 0; } } else if (tr[u].val > k) del(tr[u].l, k); else del(tr[u].r, k); pushup(u); }
|
模板代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
|
#include <bits/stdc++.h> using namespace std; typedef long long ll; typedef unsigned long long ull;
const int N = 1e5 + 9; const int INF = 0x3f3f3f3f;
struct BST { int l, r; int val, key; int cnt, size; }tr[N]; int root, idx;
void pushup(int u) { tr[u].size = tr[tr[u].l].size + tr[tr[u].r].size + tr[u].cnt; }
int new_node(int k) { tr[++ idx].val = k; tr[idx].key = rand(); tr[idx].size = tr[idx].cnt = 1; return idx; }
void build() { new_node(-INF), new_node(INF); root = 1, tr[1].r = 2; pushup(root); }
void zig(int &u) { int v = tr[u].l; tr[u].l = tr[v].r, tr[v].r = u, u = v; pushup(tr[u].r), pushup(u); }
void zag(int &u) { int v = tr[u].r; tr[u].r = tr[v].l, tr[v].l = u, u = v; pushup(tr[u].l), pushup(u); }
void insert(int &u, int k) { if (u == 0) u = new_node(k); else { if (tr[u].val == k) tr[u].cnt ++; else { if (tr[u].val > k) { insert(tr[u].l, k); if (tr[tr[u].l].key > tr[u].key) zig(u); } else { insert(tr[u].r, k); if (tr[tr[u].r].key < tr[u].key) zag(u); } } } pushup(u); }
void del(int &u, int k) { if (u == 0) return ; if (tr[u].val == k) { if (tr[u].cnt > 1) tr[u].cnt --; else { if (tr[u].l || tr[u].r) { if (!tr[u].r || tr[tr[u].l].key > tr[tr[u].r].key) { zig(u); del(tr[u].r, k); } else { zag(u); del(tr[u].l, k); } } else u = 0; } } else if (tr[u].val > k) del(tr[u].l, k); else del(tr[u].r, k); pushup(u); }
int get_rank_by_val(int u, int k) { if (u == 0) return 0; if (tr[u].val == k) return tr[tr[u].l].size + 1; if (tr[u].val > k) return get_rank_by_val(tr[u].l, k); return tr[tr[u].l].size + tr[u].cnt + get_rank_by_val(tr[u].r, k); }
int get_val_by_rank(int u, int rank) { if (u == 0) return INF; if (tr[tr[u].l].size >= rank) return get_val_by_rank(tr[u].l, rank); if (tr[tr[u].l].size + tr[u].cnt >= rank) return tr[u].val; return get_val_by_rank(tr[u].r, rank - tr[tr[u].l].size - tr[u].cnt); }
int get_prev(int u, int k) { if (u == 0) return -INF; if (tr[u].val >= k) return get_prev(tr[u].l, k); return max(get_prev(tr[u].r, k), tr[u].val); }
int get_next(int u, int k) { if (u == 0) return INF; if (tr[u].val <= k) return get_next(tr[u].r, k); return min(get_next(tr[u].l, k), tr[u].val); }
signed main() { build(); int n; cin >> n; while (n --) { int op, k; cin >> op >> k; if (op == 1) insert(root, k); else if (op == 2) del(root, k); else if (op == 3) { insert(root, k); cout << get_rank_by_val(root, k) - 1<< '\n'; del(root, k); } else if (op == 4) cout << get_val_by_rank(root, k + 1) << '\n'; else if (op == 5) cout << get_prev(root, k) << '\n'; else cout << get_next(root, k) << '\n'; } return 0; }
|