- 前言
- 一、例题 p3369
- 二、思路及代码
- 1.思路
- 2.代码
splay 是 tarjan老爷子的又一发明,解决了平衡树中的很多问题
当然,splay的内容学起来也是很复杂,于是我试着用图的方式将这个算法展示出来
splay实现的数据结构是结构体所维护的二叉树,每个节点保存 父节点,儿子节点,节点权值,权值出现个数,和子树大小这几个信息
旋转过程主要由update(), rotate(), splay() 即 更新函数、旋转函数、splay函数三个函数实现,其分别用来:更新旋转后的子树规模,左旋以及右旋,强制更新至根节点
rotate:
可对照代码一起理解
void rotate(int x) {
int y = t[x].fa;
int z = t[y].fa;
int k = (t[y].child[1] == x); // 用来判断左旋还是右旋
t[z].child[(t[z].child[1] == y)] = x;
t[x].fa = z;
t[y].child[k] = t[x].child[k ^ 1];
t[t[x].child[k ^ 1]].fa = y;
t[x].child[k ^ 1] = y;
t[y].fa = x;
update(y);
update(x);
}
splay: 可简要地分为两种情况:
对照代码:
void splay(int x, int s) {
while (t[x].fa != s) { // s 为目标根节点
int y = t[x].fa, z = t[y].fa;
if (z != s)
(t[z].child[0] == y) ^ (t[y].child[0] == x) ? rotate(x) : rotate(y);
rotate(x);
}
if (s == 0) root = x;
}
这四个转化并不是特别显然,但用手画画还是容易理解的
然后是5个功能函数:find(), insert(), maxnext(), delete(), kth(),分别用来:查询排名,二分插入,前继后继查找,删除节点,查询第k大,我们来一一介绍
find(): 首先利用平衡树性质递归查找节点, 然后进行splay变换, 排名即为左子树大小 insert(): 二分递归查找应插入位置, 然后开点,最后还有splay变换
maxnext(): 首先对x进行find()查询排名, 这样x便处于根节点的位置, 其前继和后继节点也很明显 delete(): 删除是一个较为复杂的操作,具体步骤如下: 首先找到这个数的前驱,把他Splay到根节点 然后找到这个数后继,把他旋转到前驱的底下 比前驱大的数是后继,在右子树 比后继小的且比前驱大的有且仅有当前数 在后继的左子树上面, 因此直接把当前根节点的右儿子的左儿子删掉就可以
kth(): 查询排名第k的数, 其实和find()也很类似, 利用平衡树的性质, 二分查找即可
题目链接:洛谷 p3369
很经典的平衡树模板,直接套模板即可
借鉴了很多ybb的内容:https://www.cnblogs.com/cjyyb/p/7499020.html
2.代码代码如下:
#include
using namespace std;
const int maxn = 201000;
struct splaytree {
int fa, child[2], val, cnt, size;
} t[maxn]; // 父节点,左右子节点,权值,权值出现次数,子树大小
int root, cnt;
void update(int x) {
t[x].size = t[t[x].child[0]].size + t[t[x].child[1]].size + t[x].cnt;
}
void rotate(int x) {
int y = t[x].fa;
int z = t[y].fa;
int k = (t[y].child[1] == x); // 用来判断左旋还是右旋
t[z].child[(t[z].child[1] == y)] = x;
t[x].fa = z;
t[y].child[k] = t[x].child[k ^ 1];
t[t[x].child[k ^ 1]].fa = y;
t[x].child[k ^ 1] = y;
t[y].fa = x;
update(y);
update(x);
}
void splay(int x, int s) {
while (t[x].fa != s) { // s 为目标根节点
int y = t[x].fa, z = t[y].fa;
if (z != s)
(t[z].child[0] == y) ^ (t[y].child[0] == x) ? rotate(x) : rotate(y);
rotate(x);
}
if (s == 0) root = x;
}
void find(int x) {
int u = root;
if (!u) return;
while (t[u].child[x > t[u].val] && x != t[u].val)
u = t[u].child[x > t[u].val];
splay(u, 0);
}
void insert(int x) {
int u = root, fa = 0;
while (u && t[u].val != x) {
fa = u;
u = t[u].child[x > t[u].val];
}
if (u)
t[u].cnt++;
else {
u = ++cnt;
if (fa) t[fa].child[x > t[fa].val] = u;
t[u].child[0] = t[u].child[1] = 0;
t[cnt].fa = fa;
t[cnt].val = x;
t[cnt].cnt = 1;
t[cnt].size = 1;
}
splay(u, 0);
}
int maxnext(int x, int f) {
find(x); // 此时 x 是根节点
int u = root; // f: 0 前继, 1: 后继
if (t[u].val > x && f) return u;
if (t[u].val 1) {
t[del].cnt--;
splay(del, 0);
} else
t[maxnet].child[0] = 0;
}
int kth(int x) {
int u = root;
while (t[u].size t[y].size + t[u].cnt) {
x -= t[y].size + t[u].cnt;
u = t[u].child[1];
} else if (t[y].size >= x)
u = y;
else
return t[u].val;
}
}
int main() {
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
int n;
scanf("%d", &n);
insert(1e9);
insert(-1e9);
while (n--) {
int opt, x;
scanf("%d%d", &opt, &x);
if (opt == 1) insert(x);
if (opt == 2) Delete(x);
if (opt == 3) find(x), printf("%d\n", t[t[root].child[0]].size);
if (opt == 4) printf("%d\n", kth(x + 1));
if (opt == 5) printf("%d\n", t[maxnext(x, 0)].val);
if (opt == 6) printf("%d\n", t[maxnext(x, 1)].val);
}
return 0;
}