1 module dkh.tree;
2 
3 import std.traits;
4 import dkh.array;
5 
struct Tree(T, alias _opTT, T _eT, bool hasLazy = false,
L = bool, alias _opTL = "a", alias _opLL = "a", L _eL = false) {
    import std.functional : binaryFun;
    import std.conv : to;

    /// Binary operation combining two T values (string lambda or callable).
    alias opTT = binaryFun!_opTT;
    /// Value returned when a range sum covers no elements.
    static immutable T eT = _eT;
    static if (hasLazy) {
        /// Applies a lazy update value (of type L) to an aggregated T value.
        alias opTL = binaryFun!_opTL;
        /// Composes an already pending lazy value with a newly applied one.
        alias opLL = binaryFun!_opLL;
        /// Lazy value meaning "nothing pending"; push() skips nodes holding it.
        static immutable L eL = _eL;
    }

    alias NP = Node*;
    /// Weighted balanced tree node.
    /// Leaves (length == 1) hold one element in v; an internal node holds
    /// opTT(ch[0].v, ch[1].v) and its total leaf count in length.
    static struct Node {
        NP[2] ch;     /// children (left, right); null for leaves
        uint length;  /// number of leaves (elements) in this subtree
        T v;          /// the element (leaf) or the aggregate of the subtree
        /// lazy value already applied to v but not yet forwarded to ch
        static if (hasLazy) L lz = eL;
        /// Creates a leaf holding v.
        this(in T v) {
            length = 1;
            this.v = v;
        }
        /// Creates an internal node over subtrees l and r.
        this(NP l, NP r) {
            ch = [l, r];
            update();
        }
        /// Recomputes length and v from the children.
        /// In lazy mode it asserts that no lazy value is pending here.
        void update() {
            static if (hasLazy) assert(lz == eL);
            length = ch[0].length + ch[1].length;
            v = opTT(ch[0].v, ch[1].v);
        }
        static if (hasLazy) {
            /// Applies x to this node's aggregate and records it as pending.
            void lzAdd(in L x) {
                v = opTL(v, x);
                lz = opLL(lz, x);
            }
        }
        /// Forwards a pending lazy value to both children (no-op without
        /// lazy support). In lazy mode it dereferences ch, so call it on
        /// internal nodes only.
        void push() {
            static if (hasLazy) {
                if (lz == eL) return;
                ch[0].lzAdd(lz);
                ch[1].lzAdd(lz);
                lz = eL;
            }
        }
        /// Single rotation; returns the new subtree root.
        /// type = 0: ((a, b), c) -> (a, (b, c)); type = 1 is the mirror image.
        NP rot(uint type) {
            push();
            auto m = ch[type];
            m.push();
            ch[type] = m.ch[1-type];
            m.ch[1-type] = &this;
            update(); m.update();
            return m;
        }
        /// If one child holds more than 5/2 times as many elements as the
        /// other, rotates toward the light side (first rotating the heavy
        /// child when its grandchildren require it, i.e. a double rotation).
        /// Returns: the root of the resulting subtree.
        NP bal() {
            push();
            foreach (f; 0..2) {
                if (ch[f].length*2 > ch[1-f].length*5) {
                    ch[f].push();
                    if (ch[f].ch[1-f].length*2 > ch[1-f].length*5 ||
                        ch[f].ch[f].length*5 < (ch[f].ch[1-f].length+ch[1-f].length)*2) {
                        ch[f] = ch[f].rot(1-f);
                        update();
                    }
                    return rot(f);
                }
            }
            return &this;
        }
        /// Inserts v so that it becomes element k (0 <= k <= length).
        /// Returns: the new subtree root.
        NP insert(uint k, in T v) {
            assert(0 <= k && k <= length);
            if (length == 1) {
                if (k == 0) {
                    return new Node(new Node(v), &this);
                } else {
                    return new Node(&this, new Node(v));
                }
            }
            push();
            if (k < ch[0].length) {
                ch[0] = ch[0].insert(k, v);
            } else {
                ch[1] = ch[1].insert(k-ch[0].length, v);
            }
            update();
            return bal();
        }
        /// Removes element k (0 <= k < length).
        /// Returns: the new subtree root, or null if the subtree became empty.
        NP removeAt(uint k) {
            assert(0 <= k && k < length);
            if (length == 1) {
                return null;
            }
            push();
            if (k < ch[0].length) {
                ch[0] = ch[0].removeAt(k);
                // this node disappears; the sibling takes its place
                if (ch[0] is null) return ch[1];
            } else {
                ch[1] = ch[1].removeAt(k-ch[0].length);
                if (ch[1] is null) return ch[0];
            }
            update();
            return bal();
        }
        /// Returns element k (0 <= k < length).
        const(T) at(uint k) {
            assert(0 <= k && k < length);
            if (length == 1) return v;
            push();
            if (k < ch[0].length) return ch[0].at(k);
            return ch[1].at(k-ch[0].length);
        }
        /// Overwrites element k with x and refreshes the aggregates above it.
        void atAssign(uint k, in T x) {
            assert(0 <= k && k < length);
            if (length == 1) {
                v = x;
                return;
            }
            push();
            if (k < ch[0].length) ch[0].atAssign(k, x);
            else ch[1].atAssign(k-ch[0].length, x);
            update();
        }
        /// Aggregate of the elements whose local indices lie in [a, b),
        /// clamped to this subtree; eT if the intersection is empty.
        /// a and b may be negative or exceed length.
        const(T) sum(int a, int b) {
            if (b <= 0 || length.to!int <= a) return eT;
            if (a <= 0 && length.to!int <= b) return v;
            push();
            // Shift into the right child's coordinates with signed arithmetic;
            // `a - ch[0].length` alone would be computed as uint and wrap.
            immutable int mid = ch[0].length.to!int;
            return opTT(ch[0].sum(a, b), ch[1].sum(a - mid, b - mid));
        }
        static if (hasLazy) {
            /// Applies lazy value x to the elements with local indices in
            /// [a, b), clamped to this subtree.
            void add(int a, int b, L x) {
                if (b <= 0 || length.to!int <= a) return;
                if (a <= 0 && length.to!int <= b) {
                    lzAdd(x);
                    return;
                }
                push();
                immutable int mid = ch[0].length.to!int;
                ch[0].add(a, b, x);
                ch[1].add(a - mid, b - mid, x);
                update();
            }
        }
        /// Debug check: asserts consistent lengths and that every child holds
        /// at least 2/5 as many elements as its sibling.
        void check() {
            if (length == 1) return;
            assert(length == ch[0].length + ch[1].length);
            ch[0].check();
            ch[1].check();
            assert(ch[0].length*5 >= ch[1].length*2);
            assert(ch[1].length*5 >= ch[0].length*2);
        }
        /// Debug print to stdout: a leaf as "(v)", an internal node as
        /// "(" ~ left ~ v ~ right ~ ")". Works for any T `write` can format.
        void pr() {
            import std.stdio;
            if (length == 1) {
                write("(", v, ")");
                return;
            }
            write("(");
            ch[0].pr();
            write(v);
            ch[1].pr();
            write(")");
        }
    }
    /// Concatenates subtrees l and r (either may be null), rebalancing on the
    /// way down the heavier side. If buf is given it is reused as the new
    /// junction node instead of allocating one; in lazy mode it must have no
    /// pending lazy value (update() asserts this).
    /// Returns: root of the concatenation.
    static NP merge(NP l, NP r, NP buf = null) {
        if (!l) return r;
        if (!r) return l;
        if (l.length*2 > r.length*5) {
            l.push();
            l.ch[1] = merge(l.ch[1], r, buf);
            l.update();
            return l.bal();
        } else if (l.length*5 < r.length*2) {
            r.push();
            r.ch[0] = merge(l, r.ch[0], buf);
            r.update();
            return r.bal();
        }
        if (buf is null) buf = new Node();
        buf.ch = [l, r];
        buf.update();
        return buf;
    }
    /// Splits subtree n into its first k elements and the rest.
    /// n's own node is recycled as the merge() junction buffer.
    /// Returns: [left part, right part]; either may be null.
    static NP[2] split(NP n, uint k) {
        if (!n) return [null, null];
        if (n.length == 1) {
            if (k == 0) return [null, n];
            else return [n, null];
        }
        NP[2] p;
        n.push();
        if (k < n.ch[0].length) {
            p = split(n.ch[0], k);
            p[1] = merge(p[1], n.ch[1], n);
        } else {
            p = split(n.ch[1], k - n.ch[0].length);
            p[0] = merge(n.ch[0], p[0], n);
        }
        return p;
    }
    /// Root node; null when the tree is empty.
    Node* tr;
    /// Creates a one-element tree.
    this(T v) { tr = new Node(v); }
    /// Wraps an existing root (may be null).
    this(Node* tr) { this.tr = tr; }
    /// Builds a tree holding the elements of v in order.
    this(in T[] v) {
        if (v.length == 0) return;
        if (v.length == 1) {
            tr = new Node(v[0]);
            return;
        }
        auto ltr = Tree(v[0..$/2]);
        auto rtr = Tree(v[$/2..$]);
        this = ltr.merge(rtr);
    }
    /// Number of elements.
    @property size_t length() const { return (!tr ? 0 : tr.length); }
    alias opDollar = length;

    /// Inserts v so that it becomes element k (0 <= k <= length).
    void insert(size_t k, in T v) {
        assert(0 <= k && k <= length);
        if (tr is null) {
            tr = new Node(v);
            return;
        }
        tr = tr.insert(k.to!int, v);
    }
    /// Removes element k (0 <= k < length).
    void removeAt(size_t k) {
        assert(0 <= k && k < length);
        tr = tr.removeAt(k.to!int);
    }
    /// Cuts out elements [a, b) and returns them as a new tree; this tree
    /// keeps the remaining elements in order.
    Tree trim(size_t a, size_t b) {
        auto v = split(tr, b.to!uint);
        auto u = split(v[0], a.to!uint);
        tr = merge(u[0], v[1]);
        return Tree(u[1]);
    }
    /// Keeps the first k elements here and returns the rest as a new tree.
    Tree split(size_t k) {
        auto u = split(tr, k.to!uint);
        tr = u[0];
        return Tree(u[1]);
    }
    /// Appends the elements of r (its root is linked in, not copied).
    /// Returns: this tree, for chaining.
    ref Tree merge(Tree r) {
        tr = merge(tr, r.tr);
        return this;
    }
    static if (hasLazy) {
        /// `t[a .. b] += x`: applies lazy value x to every element in [a, b).
        void opIndexOpAssign(string op : "+")(in L x, size_t[2] rng) {
            if (!tr) return;
            tr.add(rng[0].to!int, rng[1].to!int, x);
        }
    }
    /// Returns element k (0 <= k < length).
    const(T) opIndex(size_t k) {
        assert(0 <= k && k < length);
        return tr.at(k.to!int);
    }
    /// `t[k] = x`: overwrites element k (0 <= k < length).
    void opIndexAssign(in T x, size_t k) {
        // same bounds check as opIndex; an empty tree would otherwise
        // dereference a null root
        assert(0 <= k && k < length);
        tr.atAssign(k.to!int, x);
    }
    /// Half-open slice [start, end) of a tree, produced by `t[a .. b]`.
    struct Range {
        Tree* eng;
        size_t start, end;
        /// opTT-aggregate of the slice; eT for an empty tree or slice.
        @property T sum() {
            if (!eng.tr) return eT;
            return eng.tr.sum(start.to!int, end.to!int);
        }
    }
    /// Bounds-checked `a .. b` inside `t[...]`.
    size_t[2] opSlice(size_t dim)(size_t start, size_t end) {
        assert(0 <= start && start <= end && end <= length());
        return [start, end];
    }
    /// `t[a .. b]`: a Range view over [a, b).
    Range opIndex(size_t[2] rng) {
        return Range(&this, rng[0].to!uint, rng[1].to!uint);
    }
    /// Formats as "Tree([e0, e1, ...])".
    string toString() {
        //todo: more optimize
        import std.range : iota;
        import std.algorithm : map;
        string s;
        s ~= "Tree(";
        s ~= iota(length).map!(i => this[i]).to!string;
        s ~= ")";
        return s;
    }
    /// Debug check of sizes and balance invariants.
    void check() {
        if (tr) tr.check();
    }
    /// Debug print of the tree shape to stdout.
    void pr() {
        if (tr) tr.pr();
    }
}
294 
/// Tree without lazy propagation: elements of type T combined with `op`,
/// with `e` reported for empty ranges.
template SimpleTree(T, alias op, T e) {
    alias SimpleTree = Tree!(T, op, e);
}

/// Tree with lazy range updates of type L: `opTL` applies an update to an
/// aggregate, `opLL` composes updates, and `eL` marks "no pending update".
template LazyTree(T, L, alias opTT, alias opTL, alias opLL, T eT, L eL) {
    alias LazyTree = Tree!(T, opTT, eT, true, L, opTL, opLL, eL);
}
298 
299 import std.traits : isInstanceOf;
300 
/// Scans the index range [_a, _b) of `t` from left to right, folding the
/// elements with T.opTT starting from T.eT.
/// Returns: _a - 1 if pred already holds for T.eT; otherwise the index just
/// past the longest prefix whose fold keeps pred false, i.e. the element that
/// makes pred true, or _b if the whole range is absorbed.
/// NOTE(review): presumably pred must be monotone over growing prefixes for
/// the result to be meaningful — confirm with callers.
ptrdiff_t binSearchLeft(alias pred, T)(T t, ptrdiff_t _a, ptrdiff_t _b)
if(isInstanceOf!(Tree, T)) {
    import std.conv : to;
    import std.traits : Unqual;
    int a = _a.to!int, b = _b.to!int;
    Unqual!(typeof(T.eT)) x = T.eT;
    if (pred(x)) return a-1;
    // Nothing to absorb: same answer as any empty range.
    if (t.tr is null) return a;

    alias opTT = T.opTT;
    int pos = a;
    void f(T.Node* n, int l, int r, int offset) {
        immutable int end = offset + n.length.to!int;
        if (r <= offset || end <= l) return;
        if (l <= offset && end <= r && !pred(opTT(x, n.v))) {
            // the whole subtree can be absorbed without triggering pred
            x = opTT(x, n.v);
            pos = end;
            return;
        }
        if (n.length == 1) return;
        // children's aggregates are stale until a pending lazy value is pushed
        n.push();
        immutable int mid = offset + n.ch[0].length.to!int;
        f(n.ch[0], l, r, offset);
        // continue right only if everything up to mid was absorbed
        if (pos >= mid) f(n.ch[1], l, r, mid);
    }
    f(t.tr, a, b, 0);
    return pos;
}
328 
/// Scans the index range [a, b) of `t` from right to left, folding the
/// elements with T.opTT (new element on the left) starting from T.eT.
/// Returns: b if pred already holds for T.eT; otherwise the index of the
/// element whose inclusion makes pred true, or a - 1 if the whole range is
/// absorbed.
/// NOTE(review): presumably pred must be monotone over growing suffixes for
/// the result to be meaningful — confirm with callers.
ptrdiff_t binSearchRight(alias pred, T)(T t, ptrdiff_t a, ptrdiff_t b)
if(isInstanceOf!(Tree, T)) {
    import std.conv : to;
    import std.traits : Unqual;
    int lo = a.to!int, hi = b.to!int;
    Unqual!(typeof(T.eT)) x = T.eT;
    if (pred(x)) return hi;
    // Nothing to absorb: everything (i.e. nothing) was absorbed.
    if (t.tr is null) return lo - 1;

    alias opTT = T.opTT;
    int pos = hi - 1;
    void f(T.Node* n, int l, int r, int offset) {
        immutable int end = offset + n.length.to!int;
        if (r <= offset || end <= l) return;
        if (l <= offset && end <= r && !pred(opTT(n.v, x))) {
            // the whole subtree can be absorbed without triggering pred
            x = opTT(n.v, x);
            pos = offset - 1;
            return;
        }
        if (n.length == 1) return;
        // children's aggregates are stale until a pending lazy value is pushed
        n.push();
        immutable int mid = offset + n.ch[0].length.to!int;
        f(n.ch[1], l, r, mid);
        // continue left only if everything from mid onward was absorbed
        if (pos < mid) f(n.ch[0], l, r, offset);
    }
    f(t.tr, lo, hi, 0);
    return pos;
}
356 
unittest {
    import std.traits : AliasSeq;
    import std.random;
    import dkh.modint;
    alias Mint = ModInt!(10^^9 + 7);
    auto rndM = (){ return Mint(uniform(0, 10^^9 + 7)); };
    // Prepending 100 random values must keep the whole-range sum equal to
    // their plain running total.
    void check() {
        alias T = SimpleTree!(Mint, "a+b", Mint(0));
        T tree;
        Mint expected = 0;
        foreach (_; 0 .. 100) {
            auto value = rndM();
            expected += value;
            tree.insert(0, value);
        }
        auto whole = tree[0 .. $];
        assert(expected == whole.sum);
    }
    check();
}
376 
377 
unittest {
    import std.conv : to;
    import std.traits : AliasSeq;
    import std.algorithm : swap;
    import std.random;
    import dkh.modint, dkh.foundation;
    alias Mint = ModInt!(10^^9 + 7);
    auto rndM = (){ return Mint(uniform(0, 10^^9 + 7)); };
    // Randomized cross-check: t1 applies range additions element by element,
    // t2 uses lazy propagation (elements are [value, count] pairs); both must
    // agree on every random range query.
    void check() {
        alias T1 = SimpleTree!(Mint[2], (a, b) => [a[0]+b[0], a[1]+b[1]].fixed, [Mint(0), Mint(0)].fixed);
        alias T2 = LazyTree!(Mint[2], Mint,
            (a, b) => [a[0]+b[0], a[1]+b[1]].fixed, 
            (a, b) => [a[0]+a[1]*b, a[1]].fixed,
            (a, b) => a+b,
            [Mint(0), Mint(0)].fixed, Mint(0));
        T1 t1;
        T2 t2;
        foreach (ph; 0..1000) {
            assert(t1.length == t2.length);
            int L = t1.length.to!int;
            int ty = uniform(0, 3);
            if (ty == 0) {
                auto x = rndM();
                auto idx = uniform(0, L+1);
                t1.insert(idx, [x, Mint(1)]);
                t2.insert(idx, [x, Mint(1)]);
            } else if (ty == 1) {
                int l = uniform(0, L+1);
                int r = uniform(0, L+1);
                if (l > r) swap(l, r);
                // query the random sub-range, not only the whole tree
                assert(t1[l..r].sum == t2[l..r].sum);
            } else {
                int l = uniform(0, L+1);
                int r = uniform(0, L+1);
                if (l > r) swap(l, r);
                auto x = rndM();
                foreach (i; l..r) {
                    auto u = t1[i];
                    t1[i] = [u[0] + x, u[1]];
                }
                t2[l..r] += x;
            }
            t1.check();
            t2.check();
        }
    }
    check();
}
426 
unittest {
    import std.random;
    import std.algorithm;
    import std.conv;
    import std.container.rbtree;
    import std.stdio;

    import dkh.stopwatch;
    // Uses a max-tree as a sorted multiset and cross-checks it against a
    // RedBlackTree that allows duplicates, under random inserts and removals.
    StopWatch sw; sw.start;
    auto reference = redBlackTree!(true, int)([]);
    alias T = SimpleTree!(int, max, int.min);
    auto tr = T();
    foreach (ph; 0..10000) {
        if (uniform(0, 2) == 0) {
            int x = uniform(0, 100);
            reference.insert(x);
            // first position whose prefix maximum reaches x keeps tr sorted
            auto idx = tr.binSearchLeft!(y => x <= y)(0, tr.length);
            if (!uniform(0, 2)) {
                tr.insert(idx, x);
            } else {
                auto tail = tr.trim(idx, tr.length);
                tr.merge(T(x)).merge(tail);
            }
        } else {
            if (reference.length == 0) continue;
            int i = uniform(0, reference.length.to!int);
            auto it = reference[];
            foreach (_; 0..i) it.popFront();
            int x = tr[i];
            assert(it.front == x);
            reference.removeKey(x);
            if (!uniform(0, 2)) {
                tr.trim(i, i+1);
            } else {
                tr.removeAt(i);
            }
        }
        tr.check();
        assert(reference.length == tr.length);
    }
    writeln("Set TEST: ", sw.peek.toMsecs);
}