1 /**
2 Unstable
3 */
4 
5 module dkh.tree;
6 
7 import std.traits;
8 import dkh.array;
9 
/**
Sequence container stored in a weight-balanced binary tree.

Elements live in the leaves; every internal node caches the number of leaves
below it (`length`) and the `opTT`-combination of their values (`v`).
When `hasLazy` is true each node additionally carries a pending lazy value
`lz` that is applied to `v` immediately and forwarded to the children on
`push()`.

Params:
    T       = element / aggregate type
    _opTT   = binary operation combining two aggregates (string or callable, passed to `binaryFun`)
    _eT     = value returned for an empty query range (expected to be the identity of `_opTT`)
    hasLazy = enables range updates through lazy propagation
    L       = lazy value type (only meaningful when `hasLazy`)
    _opTL   = applies a lazy value to an aggregate: `opTL(aggregate, lazy)`
    _opLL   = composes two lazy values: `opLL(oldLazy, newLazy)`
    _eL     = "nothing pending" lazy value; nodes whose `lz` equals it are not pushed
*/
struct Tree(T, alias _opTT, T _eT, bool hasLazy = false,
L = bool, alias _opTL = "a", alias _opLL = "a", L _eL = false) {
    import std.conv : to;
    import std.functional : binaryFun;
    alias opTT = binaryFun!_opTT;
    static immutable T eT = _eT;
    static if (hasLazy) {
        alias opTL = binaryFun!_opTL;
        alias opLL = binaryFun!_opLL;
        static immutable L eL = _eL;
    }
    
    alias NP = Node*;
    /// Weighted balanced tree node. A node with `length == 1` is a leaf
    /// holding a single element; any other node has two non-null children.
    static struct Node {
        NP[2] ch;       /// children (both null for a leaf)
        uint length;    /// number of leaves in this subtree
        T v;            /// element (leaf) or combined value of the subtree
        static if (hasLazy) L lz = eL; /// lazy value not yet forwarded to the children
        /// Creates a leaf holding `v`.
        this(in T v) {
            length = 1;
            this.v = v;
        }
        /// Creates an internal node over the two given subtrees.
        this(NP l, NP r) {
            ch = [l, r];
            update();
        }
        /// Recomputes `length` and `v` from the children.
        /// Must only be called when no lazy value is pending on this node.
        void update() {
            static if (hasLazy) assert(lz == eL);
            length = ch[0].length + ch[1].length;
            v = opTT(ch[0].v, ch[1].v);
        }
        static if (hasLazy) {
            /// Applies `x` to this subtree: `v` is updated right away and `x`
            /// is remembered in `lz` for the children.
            void lzAdd(in L x) {
                v = opTL(v, x);
                lz = opLL(lz, x);
            }
        }
        /// Forwards the pending lazy value to both children (no-op without
        /// `hasLazy`). Dereferences the children, so it is only called on
        /// internal nodes.
        void push() {
            static if (hasLazy) {
                if (lz == eL) return;
                ch[0].lzAdd(lz);
                ch[1].lzAdd(lz);
                lz = eL;
            }
        }
        /// Rotates child `type` up and returns it as the new subtree root.
        NP rot(uint type) {
            // ty = 0: ((a, b), c) -> (a, (b, c))
            push();
            auto m = ch[type];
            m.push();
            ch[type] = m.ch[1-type];
            m.ch[1-type] = &this;
            // `this` is now a child of `m`, so it has to be refreshed first.
            update(); m.update();
            return m;
        }
        /// Restores the weight balance at this node and returns the
        /// (possibly new) subtree root. A child counts as too heavy when
        /// `2 * its length > 5 * sibling length`.
        NP bal() {
            push();
            foreach (f; 0..2) {
                if (ch[f].length*2 > ch[1-f].length*5) {
                    ch[f].push();
                    // Rotate the heavy child first (double rotation) when a
                    // single rotation would leave the result unbalanced.
                    if (ch[f].ch[1-f].length*2 > ch[1-f].length*5 ||
                        ch[f].ch[f].length*5 < (ch[f].ch[1-f].length+ch[1-f].length)*2) {
                        ch[f] = ch[f].rot(1-f);
                        update();
                    }
                    return rot(f);
                }
            }
            return &this;
        }
        /// Inserts `v` so that it ends up at index `k` of this subtree and
        /// returns the new subtree root.
        NP insert(uint k, in T v) {
            assert(0 <= k && k <= length);
            if (length == 1) {
                // A leaf becomes one child of a freshly allocated parent.
                if (k == 0) {
                    return new Node(new Node(v), &this);
                } else {
                    return new Node(&this, new Node(v));
                }
            }
            push();
            if (k < ch[0].length) {
                ch[0] = ch[0].insert(k, v);
            } else {
                ch[1] = ch[1].insert(k-ch[0].length, v);
            }
            update();
            return bal();
        }
        /// Removes the element at index `k` and returns the new subtree
        /// root, or null when the subtree became empty.
        NP removeAt(uint k) {
            assert(0 <= k && k < length);
            if (length == 1) {
                return null;
            }
            push();
            if (k < ch[0].length) {
                ch[0] = ch[0].removeAt(k);
                // The child vanished: the sibling replaces this node.
                if (ch[0] is null) return ch[1];
            } else {
                ch[1] = ch[1].removeAt(k-ch[0].length);
                if (ch[1] is null) return ch[0];
            }
            update();
            return bal();
        }
        /// Returns the element at index `k` (pushes lazy values on the way down).
        const(T) at(uint k) {
            assert(0 <= k && k < length);
            if (length == 1) return v;
            push();
            if (k < ch[0].length) return ch[0].at(k);
            return ch[1].at(k-ch[0].length);
        }
        /// Overwrites the element at index `k` with `x` and refreshes the
        /// aggregates on the path back to this node.
        void atAssign(uint k, in T x) {
            assert(0 <= k && k < length);
            if (length == 1) {
                v = x;
                return;
            }
            push();
            if (k < ch[0].length) ch[0].atAssign(k, x);
            else ch[1].atAssign(k-ch[0].length, x);
            update();
        }        
        /// Combines the elements whose index lies in the half-open range
        /// `[a, b)`, with indices relative to this subtree. Returns `eT`
        /// when the range does not intersect the subtree.
        const(T) sum(int a, int b) {
            if (b <= 0 || length.to!int <= a) return eT;
            if (a <= 0 && length.to!int <= b) return v;
            push();
            // Signed arithmetic, consistent with `add`.
            immutable int leftLen = ch[0].length.to!int;
            return opTT(ch[0].sum(a, b), ch[1].sum(a - leftLen, b - leftLen));
        }
        static if (hasLazy) {
            /// Applies lazy value `x` to every element with index in `[a, b)`
            /// (indices relative to this subtree).
            void add(int a, int b, L x) {
                if (b <= 0 || length.to!int <= a) return;
                if (a <= 0 && length.to!int <= b) {
                    lzAdd(x);
                    return;
                }
                push();
                ch[0].add(a, b, x);
                ch[1].add(a - ch[0].length.to!int, b - ch[0].length.to!int, x);
                update();
            }
        }
        /// Debug helper: asserts the length and balance invariants of the
        /// whole subtree.
        void check() {
            if (length == 1) return;
            assert(length == ch[0].length + ch[1].length);
            ch[0].check();
            ch[1].check();
            assert(ch[0].length*5 >= ch[1].length*2);
            assert(ch[1].length*5 >= ch[0].length*2);
        }
        /// Debug helper: writes the subtree structure to stdout.
        void pr() {
            import std.stdio;
            if (length == 1) {
                // "%s" instead of "%d": same output for integers, and it
                // does not throw for element types that are not integral.
                writef("(%s)", v);
                return;
            }
            write("(");
            ch[0].pr();
            write(v);
            ch[1].pr();
            write(")");
        }
    }
    /// Concatenates subtrees `l` and `r` (either may be null) and returns the
    /// root of the result. When a new parent node is required, `buf` is
    /// reused for it if given; otherwise a node is allocated.
    static NP merge(NP l, NP r, NP buf = null) {
        if (!l) return r;
        if (!r) return l;
        if (l.length*2 > r.length*5) {
            // `l` is much larger: attach `r` somewhere inside l's right spine.
            l.push();
            l.ch[1] = merge(l.ch[1], r, buf);
            l.update();
            return l.bal();
        } else if (l.length*5 < r.length*2) {
            r.push();
            r.ch[0] = merge(l, r.ch[0], buf);
            r.update();
            return r.bal();
        }
        if (buf is null) buf = new Node();
        buf.ch = [l, r];
        buf.update();
        return buf;
    }
    /// Splits subtree `n` into the first `k` elements and the rest.
    /// Internal nodes taken apart on the way are reused as merge buffers.
    static NP[2] split(NP n, uint k) {
        if (!n) return [null, null];
        if (n.length == 1) {
            if (k == 0) return [null, n];
            else return [n, null];
        }
        NP[2] p;
        n.push();
        if (k < n.ch[0].length) {
            p = split(n.ch[0], k);
            p[1] = merge(p[1], n.ch[1], n);
        } else {
            p = split(n.ch[1], k - n.ch[0].length);
            p[0] = merge(n.ch[0], p[0], n);
        }
        return p;
    }
    Node* tr; /// root node, null for an empty tree
    /// Creates a tree holding the single element `v`.
    this(T v) { tr = new Node(v); }
    /// Wraps an existing root node (null gives an empty tree).
    this(Node* tr) { this.tr = tr; }
    /// Builds a tree holding the elements of `v` in order by recursively
    /// building and merging the two halves.
    this(in T[] v) {
        if (v.length == 0) return;
        if (v.length == 1) {
            tr = new Node(v[0]);
            return;
        }
        auto ltr = Tree(v[0..$/2]);
        auto rtr = Tree(v[$/2..$]);
        this = ltr.merge(rtr);
    }
    /// Number of elements.
    @property size_t length() const { return (!tr ? 0 : tr.length); }
    alias opDollar = length;
    
    /// Inserts `v` so that it becomes the element at index `k`.
    void insert(size_t k, in T v) {
        assert(0 <= k && k <= length);
        if (tr is null) {
            tr = new Node(v);
            return;
        }
        tr = tr.insert(k.to!int, v);
    }
    /// Removes the element at index `k`.
    void removeAt(size_t k) {
        assert(0 <= k && k < length);
        tr = tr.removeAt(k.to!int);
    }
    /// Cuts the elements `[a, b)` out of this tree and returns them as a new tree.
    Tree trim(size_t a, size_t b) {
        auto v = split(tr, b.to!uint);
        auto u = split(v[0], a.to!uint);
        tr = merge(u[0], v[1]);
        return Tree(u[1]);
    }
    /// Keeps the first `k` elements in this tree and returns the rest as a new tree.
    Tree split(size_t k) {
        auto u = split(tr, k.to!uint);
        tr = u[0];
        return Tree(u[1]);
    }
    /// Appends the elements of `r` to this tree. The nodes of `r` are taken
    /// over, so `r` should not be used afterwards.
    ref Tree merge(Tree r) {
        tr = merge(tr, r.tr);
        return this;
    }
    static if (hasLazy) {        
        /// `tree[a..b] += x`: applies lazy value `x` to the elements in `[a, b)`.
        void opIndexOpAssign(string op : "+")(in L x, size_t[2] rng) {
            if (!tr) return;
            tr.add(rng[0].to!uint, rng[1].to!uint, x);
        }
    }
    /// `tree[k]`: returns the element at index `k`.
    const(T) opIndex(size_t k) {
        assert(0 <= k && k < length);
        return tr.at(k.to!int);
    }
    /// `tree[k] = x`: overwrites the element at index `k`.
    void opIndexAssign(in T x, size_t k) {
        // Same bounds contract as opIndex; also rejects an empty tree,
        // where `tr` would be null.
        assert(0 <= k && k < length);
        tr.atAssign(k.to!int, x);
    }
    /// Result of `tree[a..b]`: refers to the tree by pointer and exposes the
    /// combined value of the range through `sum`.
    struct Range {
        Tree* eng;
        size_t start, end;
        /// Combined value of the elements in `[start, end)`; `eT` for an empty tree.
        @property T sum() {
            if (!eng.tr) return eT;
            return eng.tr.sum(start.to!uint, end.to!uint);
        }
    }
    /// Slice bounds for `tree[start..end]`.
    size_t[2] opSlice(size_t dim)(size_t start, size_t end) {
        assert(0 <= start && start <= end && end <= length());
        return [start, end];
    }
    /// `tree[a..b]`: returns a `Range` over the half-open index range.
    Range opIndex(size_t[2] rng) {
        return Range(&this, rng[0].to!uint, rng[1].to!uint);
    }
    /// Formats the tree as `Tree(<elements>)`, reading the elements one by one.
    string toString() {
        //todo: more optimize
        import std.range : iota;
        import std.algorithm : map;
        import std.conv : to;
        string s;
        s ~= "Tree(";
        s ~= iota(length).map!(i => this[i]).to!string;
        s ~= ")";
        return s;
    }
    /// Debug helper: asserts the structural invariants of the whole tree.
    void check() {
        if (tr) tr.check();
    }
    /// Debug helper: writes the tree structure to stdout.
    void pr() {
        if (tr) tr.pr();
    }
}
298 
/// Shorthand for a `Tree` instantiated with only an element type `T`, a
/// combining operation `op` and the value `e`; the remaining template
/// parameters keep their defaults (lazy propagation disabled).
alias SimpleTree(T, alias op, T e) = Tree!(T, op, e);
/// Shorthand for a `Tree` with lazy propagation enabled (`hasLazy = true`).
/// Note that the parameter order differs from `Tree`: the lazy type `L`
/// comes second and the two identity values `eT`, `eL` come last.
alias LazyTree(T, L, alias opTT, alias opTL, alias opLL, T eT, L eL) = 
    Tree!(T, opTT, eT, true, L, opTL, opLL, eL);
302 
303 import std.traits : isInstanceOf;
304 
/**
Scans the index range `[_a, _b)` of `t` from the left, accumulating element
values with `T.opTT` for as long as `pred` stays false on the accumulated
value.

Returns:
    `_a - 1` if `pred(T.eT)` already holds; `0` for an empty tree; otherwise
    the index of the first element whose inclusion makes `pred` true, or the
    end of the scanned range if `pred` never becomes true.

Pending lazy values are pushed down while descending, so the node aggregates
read by the search are up to date for lazy trees as well.
*/
ptrdiff_t binSearchLeft(alias pred, T)(T t, ptrdiff_t _a, ptrdiff_t _b)
if(isInstanceOf!(Tree, T)) {
    import std.conv : to;
    import std.traits : Unqual;
    int a = _a.to!int, b = _b.to!int;
    Unqual!(typeof(T.eT)) x = T.eT;
    if (pred(x)) return a-1;
    if (t.tr is null) return 0;
    
    alias opTT = T.opTT;
    int pos = a;
    // `offset` is the index of the first element of subtree `n`.
    void f(T.Node* n, int a, int b, int offset) {
        if (b <= offset || offset + n.length <= a) return;
        // Whole subtree inside the range and still acceptable: consume it.
        if (a <= offset && offset + n.length <= b && !pred(opTT(x, n.v))) {
            x = opTT(x, n.v);
            pos = offset + n.length;
            return;
        }
        if (n.length == 1) return;
        // Children may lag behind a pending lazy value; bring them up to
        // date before reading their aggregates (no-op for non-lazy trees).
        n.push();
        f(n.ch[0], a, b, offset);
        // Continue to the right only if the left part was fully consumed
        // (or lies entirely before the range).
        if (pos >= offset + n.ch[0].length) {
            f(n.ch[1], a, b, offset + n.ch[0].length);
        }
    }
    f(t.tr, a, b, 0);
    return pos;
}
332 
/**
Mirror image of `binSearchLeft`: scans the index range `[a, b)` of `t` from
the right, accumulating element values with `T.opTT` (new elements on the
left-hand side) for as long as `pred` stays false on the accumulated value.

Returns:
    `b` if `pred(T.eT)` already holds; `0` for an empty tree; otherwise the
    index of the first element (counting from the right) whose inclusion
    makes `pred` true, or one less than the start of the scanned range if
    `pred` never becomes true.

Pending lazy values are pushed down while descending.
*/
ptrdiff_t binSearchRight(alias pred, T)(T t, ptrdiff_t a, ptrdiff_t b)
if(isInstanceOf!(Tree, T)) {
    import std.conv : to;
    import std.traits : Unqual;
    immutable int lo = a.to!int, hi = b.to!int;
    Unqual!(typeof(T.eT)) x = T.eT;
    if (pred(x)) return hi;
    if (t.tr is null) return 0;

    alias opTT = T.opTT;
    int pos = hi-1;
    // `offset` is the index of the first element of subtree `n`.
    // Signed arithmetic throughout, because `pos` can reach -1.
    void f(T.Node* n, int l, int r, int offset) {
        immutable int end = offset + cast(int) n.length;
        if (r <= offset || end <= l) return;
        // Whole subtree inside the range and still acceptable: consume it.
        if (l <= offset && end <= r && !pred(opTT(n.v, x))) {
            x = opTT(n.v, x);
            pos = offset - 1;
            return;
        }
        if (n.length == 1) return;
        // Bring the children up to date before reading their aggregates
        // (no-op for non-lazy trees).
        n.push();
        immutable int mid = offset + cast(int) n.ch[0].length;
        f(n.ch[1], l, r, mid);
        // Continue to the left only if the right part was fully consumed
        // (or lies entirely after the range).
        if (pos < mid) {
            f(n.ch[0], l, r, offset);
        }
    }
    f(t.tr, lo, hi, 0);
    return pos;
}
360 
// Prepending random values one at a time must keep the whole-range sum
// equal to a running total maintained on the side.
unittest {
    import std.meta : AliasSeq;
    import std.random;
    import dkh.modint;
    alias Mint = ModInt!(10^^9 + 7);
    alias SumTree = SimpleTree!(Mint, "a+b", Mint(0));

    // Uniformly random residue.
    Mint randomMint() { return Mint(uniform(0, 10^^9 + 7)); }

    void run() {
        SumTree tree;
        Mint expected = 0;
        foreach (_; 0..100) {
            auto value = randomMint();
            tree.insert(0, value);
            expected += value;
        }
        assert(expected == tree[0..$].sum);
    }
    run();
}
380 
381 
// Randomized cross-check of a lazy tree against a non-lazy reference tree.
// Each element is a pair [value, 1]; adding `x` to a range raises the first
// component by `x` per element, which the lazy tree expresses as
// `a[0] + a[1]*x` on an aggregate.
unittest {
    import std.conv : to;
    import std.meta : AliasSeq;
    import std.algorithm : swap;
    import std.random;
    import dkh.modint, dkh.foundation;
    alias Mint = ModInt!(10^^9 + 7);
    auto rndM = (){ return Mint(uniform(0, 10^^9 + 7)); };
    void check() {
        alias T1 = SimpleTree!(Mint[2], (a, b) => [a[0]+b[0], a[1]+b[1]].fixed, [Mint(0), Mint(0)].fixed);
        alias T2 = LazyTree!(Mint[2], Mint,
            (a, b) => [a[0]+b[0], a[1]+b[1]].fixed, 
            (a, b) => [a[0]+a[1]*b, a[1]].fixed,
            (a, b) => a+b,
            [Mint(0), Mint(0)].fixed, Mint(0));
        T1 t1;
        T2 t2;
        foreach (ph; 0..1000) {
            assert(t1.length == t2.length);
            int L = t1.length.to!int;
            int ty = uniform(0, 3);
            if (ty == 0) {
                // Insert the same random element at the same random position.
                auto x = rndM();
                auto idx = uniform(0, L+1);
                t1.insert(idx, [x, Mint(1)]);
                t2.insert(idx, [x, Mint(1)]);
            } else if (ty == 1) {
                // Range query: compare the sums over the random range
                // [l, r) rather than only over the whole sequence.
                int l = uniform(0, L+1);
                int r = uniform(0, L+1);
                if (l > r) swap(l, r);
                assert(t1[l..r].sum == t2[l..r].sum);
            } else {
                // Range add: element by element on the reference tree,
                // one lazy update on the lazy tree.
                int l = uniform(0, L+1);
                int r = uniform(0, L+1);
                if (l > r) swap(l, r);
                auto x = rndM();
                foreach (i; l..r) {
                    t1[i] = [t1[i][0] + x, t1[i][1]];
                }
                t2[l..r] += x;
            }
            t1.check();
            t2.check();
        }
    }
    check();
}
430 
// Uses the tree as a sorted multiset (max aggregate + binSearchLeft to find
// the insertion point) and cross-checks it against std's red-black tree,
// randomly alternating between the direct and the trim/merge code paths.
// Prints the elapsed time at the end.
unittest {
    import std.random;
    import std.algorithm;
    import std.conv;
    import std.container.rbtree;
    import std.stdio;

    import dkh.stopwatch;
    StopWatch sw; sw.start;
    alias MaxTree = SimpleTree!(int, max, int.min);
    auto reference = redBlackTree!(true, int)([]);
    MaxTree seq;
    foreach (round; 0..10000) {
        immutable inserting = uniform(0, 2) == 0;
        if (inserting) {
            int value = uniform(0, 100);
            reference.insert(value);
            // First position whose prefix maximum reaches `value`.
            auto at = seq.binSearchLeft!(y => value <= y)(0, seq.length);
            if (uniform(0, 2) != 0) {
                // Insert via split + merge.
                auto tail = seq.trim(at, seq.length);
                seq.merge(MaxTree(value)).merge(tail);
            } else {
                seq.insert(at, value);
            }
        } else {
            if (reference.length == 0) continue;
            int pick = uniform(0, reference.length.to!int);
            // Walk the reference container to its `pick`-th element.
            auto walker = reference[];
            foreach (_; 0..pick) walker.popFront();
            int found = seq[pick];
            assert(walker.front == found);
            reference.removeKey(found);
            if (uniform(0, 2) != 0) {
                seq.removeAt(pick);
            } else {
                // Remove via trim; the cut-out piece is discarded.
                seq.trim(pick, pick+1);
            }
        }
        seq.check();
        assert(reference.length == seq.length);
    }
    writeln("Set TEST: ", sw.peek.toMsecs);
}