// zoj3018: tree of trees (Fenwick tree + balanced BST), 2D range sum.
#include <stdio.h>
#include <string.h>
// Capacity constant kept from the original solution (not referenced by the
// code in this file as seen).
const int maxn = 50000;
// Number of Fenwick slots in root[]; usable x indices are 1 .. maxt-1.
const int maxt = 20001;
// Node of an AA tree (a balanced binary search tree).  Equal keys are merged
// into one node: `num` is the weight stored at `key`, and `size` is the total
// weight of the subtree rooted here.  `level` is the AA-tree level; freshly
// created nodes start at level 1.
struct Node {
    Node *left, *right;  // children; an empty subtree is the shared sentinel
    int level, size, key, num;

    // Leaves every member uninitialised; only used to create the sentinel.
    Node() {}

    // A new leaf carrying weight _num at _key.
    Node(int _key, int _num, Node *_left, Node *_right)
        : left(_left), right(_right), level(1), size(_num), key(_key), num(_num) {}

    // Recompute the subtree weight from both children (which must be
    // dereferenceable) plus this node's own weight.
    inline void fix_size() {
        size = num + left->size + right->size;
    }
};
// Shared sentinel standing in for every empty subtree.  Its size, num and
// level are 0, and its children point back to itself.  Expressions such as
// t->right->right->level (used by split) therefore stay well-defined even
// when t->right is the sentinel.
Node *nil = NULL;

// (Re)initialise the sentinel.  It is allocated only on the first call.
// Later calls (one per test case, via init) just reset its fields instead of
// leaking a fresh Node each time.
void nil_init() {
    if (nil == NULL)
        nil = new Node();
    nil->left = nil->right = nil;  // self-links instead of null pointers
    nil->size = nil->num = 0;
    nil->level = 0;
    nil->key = 0;
}
// Despite its name, returns the total weight (sum of `num`) of every key
// that is <= x in the AA tree rooted at t.  It walks a single root-to-leaf
// path.  Whenever a node's key is <= x, the node's left subtree and its own
// weight are counted and the walk continues right; otherwise it goes left.
int less(Node*t, int x) {
    int total = 0;
    for (Node *cur = t; cur != nil; ) {
        if (cur->key <= x) {
            total += cur->left->size + cur->num;
            cur = cur->right;
        } else {
            cur = cur->left;
        }
    }
    return total;
}
// AA-tree skew.  If t's left child sits on t's level (a left horizontal
// link), rotate right so that child becomes the subtree root.  The subtree
// weights of both rotated nodes are recomputed, child-first.  Returns the
// root of the (possibly rotated) subtree.
Node*skew(Node*t) {
    if (t->left->level != t->level)
        return t;
    Node *pivot = t->left;
    t->left = pivot->right;
    pivot->right = t;
    t->fix_size();
    pivot->fix_size();
    return pivot;
}
// AA-tree split.  If t's right-right grandchild is on t's level (two
// consecutive right horizontal links), rotate left and promote the middle
// node one level.  Weights are recomputed child-first.  Returns the root of
// the (possibly rotated) subtree.
// Note: t->right->right is read unconditionally, so t->right must be
// dereferenceable through its right pointer.
Node*split(Node*t) {
    if (t->right->right->level != t->level)
        return t;
    Node *pivot = t->right;
    t->right = pivot->left;
    pivot->left = t;
    t->fix_size();
    pivot->fix_size();
    ++pivot->level;
    return pivot;
}
// Add weight n at key x in the AA tree rooted at t and return the new root.
// An existing key only has its weight (and this node's subtree weight)
// increased.  Otherwise the key is inserted into the proper child and the
// node is rebalanced with skew followed by split.  The trailing fix_size
// matters when neither rotation fired: a child's weight still changed.
Node*insert(Node*t, const int&x, const int&n) {
    if (t == nil)
        return new Node(x, n, nil, nil);

    if (x == t->key) {
        t->num += n;
        t->size += n;
        return t;
    }

    if (x < t->key)
        t->left = insert(t->left, x, n);
    else
        t->right = insert(t->right, x, n);

    t = split(skew(t));
    t->fix_size();
    return t;
}
// Fenwick array over x: root[i] is an AA tree (keyed by y) holding every
// point whose x lies in slot i's range (i - lowbit(i), i].
Node *root[maxt];

// true while reading insertion lines (after an 'I' line), false while
// reading query lines (after a 'Q' line).
bool add;

// Current input line.
char s[1000];
// Add weight n at point (x, y).  Every Fenwick slot whose range covers x
// gets y inserted into its AA tree.  Points with x outside 1 .. maxt-1 are
// ignored.  For x <= 0 the original loop indexed root[] out of bounds and,
// because lowbit(0) == 0, never terminated.
void insert(int x, int y, int n) {
    if (x < 1)
        return;
    for (; x < maxt; x += x & -x)
        root[x] = insert(root[x], y, n);
}
int count(int x, int y1, int y2) {
int res = 0;
while (x > 0) {
res += less(root[x], y2) - less(root[x], y1-1);
x -= (x&(-x));
}
return res;
}
void init() {
nil_init();
for (int i = 0; i < maxt; ++i) root[i] = nil;
}
int main() {
init();
while (gets(s)) {
if (s[0] == 'E')
init();
if (s[0] == 'I')
add = true;
else if (s[0] == 'Q')
add = false;
else if (s[0] >= '0' && s[0] <= '9') {
if (add) {
int x, y, n;
sscanf(s, "%d%d%d", &x, &y, &n);
insert(x, y, n);
}
else {
int x1, y1, x2, y2;
sscanf(s, "%d%d%d%d", &x1, &x2, &y1, &y2);
if (x1 > x2 || y1 > y2) {
printf("0\n");
continue;
}
if (x1 < 1) x1 = 1;
if (y1 < 1) y1 = 1;
if (x2 > 20000) x2 = 20000;
if (y2 > 20000) y2 = 20000;
printf("%d\n", count(x2, y1, y2) - count(x1-1, y1, y2));
}
}
}
return 0;
}
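/*
 * Scraped page UI text, not part of the program.  Keyboard shortcuts:
 *   copy code            Ctrl + C
 *   search code          Ctrl + F
 *   full-screen mode     F11
 *   toggle theme         Ctrl + Shift + D
 *   show shortcuts       ?
 *   increase font size   Ctrl + =
 *   decrease font size   Ctrl + -
 */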