[TJOI2011] 树的序

大意

  • 有一个序列 aa ,元素各不相同,按照如下的方式建成二叉查找树:
    • 按顺序插入每个元素;
    • 若树是空的,直接成为根节点;
    • 若比当前节点小,插入左子树,否则插入右子树。
  • 给定一个序列,求生成的树一样的同种序列字典序最小。
  • 序列长度 n105n \le 10^5

题解

答案就是建成树后的先序遍历,难点在优化建树。

对于一个节点 uu ,序列中的元素 xx 会插入进来,当且仅当遍历时,若 uu 在遍历到的 vv 的左子树中,x<val(v)x \lt val(v) ,否则 x>val(v)x \gt val(v)

记上述的 xx 的最大值和最小值为 mx(u)mx(u)mn(u)mn(u)
则对于左右儿子,有:

mn(l(u))=mn(u),mx(l(u))=val(u)1mn(r(u))=val(u)+1,mx(r(u))=mx(u)mn(l(u)) = mn(u), mx(l(u)) = val(u) - 1 \newline mn(r(u)) = val(u) + 1, mx(r(u)) = mx(u) \newline

而每个节点的左右儿子,是接下来的第一个 mn(u)l(u)<val(u)<r(u)mx(u)mn(u) \le l(u) \lt val(u) \lt r(u) \le mx(u)

用线段树维护 pospos 序列,其中 pos(u)=ia(i)=upos(u) = i \Leftarrow a(i) = u ,支持获取区间最小值和单点修改。

从左往右考虑,若到 ii ,使用线段树获取最小的满足条件的 l(i)l(i)r(i)r(i) ,然后 pos(i)pos(i) \leftarrow \infty ,使线段树只获取后面的值。

代码

提交记录

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <bits/stdc++.h>

using namespace std;

const int MAXN = 1e5 + 10;

int n;
int a[MAXN], pos[MAXN];

struct SegmentTree
{
int minx[MAXN * 4];

void build(int x, int l, int r) {
if (l == r) {
minx[x] = pos[l];
return;
}
int mid = (l + r) / 2;
build(x * 2, l, mid), build(x * 2 + 1, mid + 1, r);
minx[x] = min(minx[x * 2], minx[x * 2 + 1]);
}

int query(int x, int l, int r, int ql, int qr) {
if (ql > qr) return 1e9;
if (ql <= l && r <= qr) return minx[x];
int res = 1e9, mid = (l + r) / 2;
if (ql <= mid) res = min(res, query(x * 2, l, mid, ql, qr));
if (qr >= mid + 1) res= min(res, query(x * 2 + 1, mid + 1, r, ql, qr));
return res;
}

void update(int x, int l, int r, int q, int val) {
if (l == r) {
assert(l == q);
minx[x] = val;
return;
}
int mid = (l + r) / 2;
if (q <= mid) update(x * 2, l, mid, q, val);
else update(x * 2 + 1, mid + 1, r, q, val);
minx[x] = min(minx[x * 2], minx[x * 2 + 1]);
}
} seg;

void read()
{
ios::sync_with_stdio(false);
cin.tie(nullptr), cout.tie(nullptr);
cin >> n;

for (int i = 1; i <= n; i++) {
cin >> a[i];
pos[a[i]] = i;
}
}

int l[MAXN], r[MAXN], minx[MAXN], maxx[MAXN];

void solve(int i)
{
l[i] = seg.query(1, 1, n, minx[i], a[i] - 1);
r[i] = seg.query(1, 1, n, a[i] + 1, maxx[i]);
if (l[i] < 1e9) {
minx[l[i]] = minx[i];
maxx[l[i]] = a[i] - 1;
}
if (r[i] < 1e9) {
minx[r[i]] = a[i] + 1;
maxx[r[i]] = maxx[i];
}

}

void output(int i)
{
cout << a[i] << " ";
if (l[i] < 1e9) output(l[i]);
if (r[i] < 1e9) output(r[i]);
}

int main()
{
read();
seg.build(1, 1, n);

minx[1] = 1, maxx[1] = n;
for (int i = 1; i <= n; i++) solve(i);

output(1);
cout << endl;
}