1 module dkh.segtree.segex;
2 
3 import dkh.segtree.primitive;
4 import dkh.segtree.simpleseg;
5 import dkh.segtree.lazyseg;
6 
/**
 * Plain (non-lazy) segment tree engine answering range queries by
 * top-down recursion.
 *
 * The tree is stored 1-indexed in `d`: node `k` has children `2*k` and
 * `2*k+1`, and the leaves occupy `d[sz .. 2*sz]`.
 *
 * Params:
 *   T    = element / aggregate type
 *   opTT = binary operation combining two aggregates (left, right)
 *   eT   = value used for unused leaves and returned for empty ranges
 */
struct SimpleSegNaiveEngine(T, alias opTT, T eT) {
    alias DataType = T;
    alias LazyType = void;
    /// n: logical length, sz: number of leaves (power of two), lg: log2(sz).
    uint n, sz, lg;
    /// Node storage (2*sz entries, index 0 unused).
    T[] d;

    /// Number of logical elements.
    @property uint length() const { return n; }

    /// Builds a tree of `n` elements, every one equal to `eT`.
    this(uint n) {
        this.n = n;
        uint height = 0;
        while ((2^^height) < n) height++;
        lg = height;
        sz = 2^^height;
        d = new T[](2*sz);
        foreach (ref node; d) node = eT;
    }

    /// Builds a tree from `first`; an empty array leaves the engine unallocated.
    this(T[] first) {
        import std.conv : to;
        n = first.length.to!uint;
        if (n == 0) return;
        uint height = 0;
        while ((2^^height) < n) height++;
        lg = height;
        sz = 2^^height;
        d = new T[](2*sz);
        foreach (ref node; d) node = eT;
        // Copy the input into the leaf layer.
        foreach (idx, value; first) {
            d[sz+idx] = value;
        }
        // Recompute internal nodes from the deepest one up to the root.
        for (uint node = sz - 1; node >= 1; node--) {
            update(node);
        }
    }

    /// No-op: kept so the recursive query has the same shape as the lazy engines.
    private void push(uint k) {}

    /// Recomputes node `k` from its two children.
    void update(uint k) {
        d[k] = opTT(d[2*k], d[2*k+1]);
    }

    /// Returns the element at index `k`.
    T single(uint k) {
        return d[sz+k];
    }

    /// Assigns `x` to index `k` and refreshes every ancestor of that leaf.
    void singleSet(uint k, T x) {
        immutable uint leaf = k + sz;
        d[leaf] = x;
        for (uint level = 1; level <= lg; level++) {
            update(leaf >> level);
        }
    }

    /// Recursive helper: aggregate of [a, b) restricted to node `k`, which covers [l, r).
    T sum(uint a, uint b, uint l, uint r, uint k) {
        immutable bool disjoint = (r <= a) || (b <= l);
        if (disjoint) return eT;
        immutable bool covered = (a <= l) && (r <= b);
        if (covered) return d[k];
        push(k);
        immutable uint mid = (l+r)/2;
        auto lhs = sum(a, b, l, mid, 2*k);
        auto rhs = sum(a, b, mid, r, 2*k+1);
        return opTT(lhs, rhs);
    }

    /// Aggregate d[a] + d[a+1] + ... + d[b-1]; requires a <= b <= n.
    T sum(uint a, uint b) {
        assert(0 <= a && a <= b && b <= n);
        return sum(a, b, 0, sz, 1);
    }
}
69 
/**
 * Lazy-propagation segment tree engine that stores the elements in blocks of
 * `B` values and keeps a bottom-up lazy segment tree over the per-block
 * aggregates.
 *
 * Layout:
 *   - `blks[i].d` holds the raw elements of block `i`, NOT including the tag
 *     still stored at that block's leaf `s[sz+i].lz`;
 *   - `s[k].d` is the aggregate of node `k` with the node's own tag already
 *     applied (see `lzAdd`), `s[k].lz` is the tag that still has to be
 *     pushed below node `k` (for a leaf: into the block's elements).
 *
 * Params:
 *   T    = element / aggregate type
 *   L    = lazy tag type
 *   opTT = combines two aggregates (left, right)
 *   opTL = applies a tag to an element or aggregate
 *   opLL = composes two tags (already pending tag first, new tag second)
 *   eT   = neutral aggregate, also used to pad blocks
 *   eL   = value meaning "no pending tag"
 */
struct LazySegBlockEngine(T, L, alias opTT, alias opTL, alias opLL, T eT, L eL) {
    /// Number of elements stored per block.
    static immutable uint B = 16;
    import std.algorithm : min;
    import std.typecons : Tuple;
    alias DataType = T;
    alias LazyType = L;

    /// One block of `B` raw elements.
    static struct Block {
        T[B] d;
        /// Copies up to `B` leading elements of `first`; the remainder is set to `eT`.
        this(T[] first) {
            uint r = min(first.length, B);
            foreach (i; 0..r) {
                d[i] = first[i];
            }
            foreach (i; r..B) {
                d[i] = eT;
            }
        }
        /// Folds d[a .. b] with opTT, starting from eT.
        T sum(uint a, uint b) {
            T sm = eT;
            foreach (i; a..b) {
                sm = opTT(sm, d[i]);
            }
            return sm;
        }
        /// Applies tag `x` to every element of d[a .. b].
        void add(uint a, uint b, L x) {
            foreach (i; a..b) {
                d[i] = opTL(d[i], x);
            }
        }
    }

    /// N: number of elements, n: number of used blocks (N / B + 1),
    /// sz: number of tree leaves (power of two >= n), lg: log2(sz).
    uint N, n, sz, lg;
    /// Block storage; n+1 blocks are allocated.
    Block[] blks;
    alias S = Tuple!(T, "d", L, "lz");
    /// Tree over the blocks, 1-indexed, leaves at s[sz .. 2*sz].
    S[] s;

    /// Builds an engine of `N` elements, all equal to `eT`.
    this(uint N) {
        this.N = N;
        n = N / B + 1;
        blks = new Block[n+1];
        foreach (ref blk; blks) {
            foreach (ref x; blk.d) x = eT;
        }
        uint lg = 0;
        while ((2^^lg) < n) lg++;
        this.lg = lg;
        sz = 2^^lg;
        s = new S[](2*sz);
        foreach (ref node; s) node = S(eT, eL);
    }

    /// Builds an engine holding a copy of `first`.
    this(T[] first) {
        import std.conv : to;
        this.N = first.length.to!uint;
        n = first.length.to!uint / B + 1;
        blks = new Block[n+1];
        foreach (i; 0..n) {
            blks[i] = Block(first[i*B..min($, (i+1)*B)]);
        }
        blks[n] = Block([]);
        uint lg = 0;
        while ((2^^lg) < n) lg++;
        this.lg = lg;
        sz = 2^^lg;
        s = new S[](2*sz);
        foreach (ref node; s) node = S(eT, eL);
        // Leaves take the block aggregates, then internal nodes are built bottom-up.
        foreach (i; 0..n) {
            s[sz+i].d = blks[i].sum(0, B);
        }
        foreach_reverse (i; 1..sz) {
            update(i);
        }
    }

    /// Number of elements.
    @property size_t length() const { return N; }

    pragma(inline):
    /// Applies tag `x` to node `k`: composes it into the node's pending tag
    /// and updates the node's aggregate.
    private void lzAdd(uint k, in L x) {
        s[k].lz = opLL(s[k].lz, x);
        s[k].d = opTL(s[k].d, x);
    }
    /// Moves the pending tag of internal node `k` to its two children.
    private void push(uint k) {
        if (s[k].lz == eL) return;
        lzAdd(2*k, s[k].lz);
        lzAdd(2*k+1, s[k].lz);
        s[k].lz = eL;
    }
    /// Pushes every proper ancestor of block leaf `k`, root first.
    private void pushPath(uint k) {
        k += sz;
        foreach_reverse (i; 1..lg+1) {
            push(k>>i);
        }
    }
    /// Pushes every proper ancestor of block leaves `a` and `b`, root first.
    private void pushPath2(uint a, uint b) {
        a += sz; b += sz;
        foreach_reverse (i; 1..lg+1) {
            push(a>>i);
            if ((a>>i) != (b>>i)) push(b>>i);
        }
    }
    /// Recomputes the aggregate of internal node `k` from its children.
    private void update(uint k) {
        s[k].d = opTT(s[2*k].d, s[2*k+1].d);
    }
    /// Writes the pending tag of block `blk`'s leaf into the block's elements
    /// and clears it. Ancestors must already have been pushed.
    private void flushLeaf(uint blk) {
        if (s[blk+sz].lz != eL) blks[blk].add(0, B, s[blk+sz].lz);
        s[blk+sz].lz = eL;
    }

    /// Returns the element at index `k`.
    T single(uint k) {
        pushPath(k/B);
        // The raw block value still lacks the tag pending at the block's leaf.
        return opTL(blks[k/B].d[k%B], s[k/B+sz].lz);
    }
    /// Assigns `x` to index `k`.
    void singleSet(uint k, T x) {
        pushPath(k/B);
        flushLeaf(k/B);
        blks[k/B].d[k%B] = x;
        upPath(k/B);
    }
    /// Aggregate of the whole blocks [a, b) (block indices), bottom-up.
    /// Ancestors of the touched nodes must already have been pushed.
    T sumBody(uint a, uint b) {
        assert(0 <= a && a <= b && b <= n);
        T sml = eT, smr = eT;
        a += sz; b += sz;
        while (a < b) {
            if (a & 1) sml = opTT(sml, s[a++].d);
            if (b & 1) smr = opTT(s[--b].d, smr);
            a >>= 1; b >>= 1;
        }
        return opTT(sml, smr);
    }
    /// Aggregate of the elements [a, b); requires a <= b <= N.
    T sum(uint a, uint b) {
        assert(a <= b && b <= N);
        if (a == b) return eT;
        uint aB = a / B, aC = a % B;
        uint bB = b / B, bC = b % B;
        if (aB == bB) {
            pushPath(aB);
            return opTL(blks[aB].sum(aC, bC), s[aB+sz].lz);
        }
        pushPath2(aB, bB);
        auto left = opTL(blks[aB].sum(aC, B), s[aB+sz].lz);
        auto mid = sumBody(aB+1, bB);
        // When b is block-aligned, block bB contributes nothing. Its pending
        // tag must not be applied to eT: for actions with opTL(eT, x) != eT
        // (e.g. max/+ with eT = 0) that would leak the tag of an
        // out-of-range block into the answer.
        if (bC == 0) return opTT(left, mid);
        auto right = opTL(blks[bB].sum(0, bC), s[bB+sz].lz);
        return opTT(opTT(left, mid), right);
    }
    /// Recomputes the leaf of block `k` from its elements, then all its ancestors.
    void upPath(uint k) {
        k += sz;
        s[k].d = blks[k-sz].sum(0, B);
        foreach (i; 1..lg+1) {
            k >>= 1;
            update(k);
        }
    }
    /// Same as `upPath` for the two blocks `a` and `b` at once.
    void upPath2(uint a, uint b) {
        a += sz; b += sz;
        s[a].d = blks[a-sz].sum(0, B);
        s[b].d = blks[b-sz].sum(0, B);
        foreach (i; 1..lg+1) {
            a >>= 1; b >>= 1;
            update(a);
            if (a != b) update(b);
        }
    }
    /// Applies tag `x` to the whole blocks [a, b) (block indices), bottom-up.
    /// Ancestors of the touched nodes must already have been pushed.
    void addBody(uint a, uint b, L x) {
        assert(0 <= a && a <= b && b <= n);
        a += sz; b += sz;
        while (a < b) {
            if (a & 1) lzAdd(a++, x);
            if (b & 1) lzAdd(--b, x);
            a >>= 1; b >>= 1;
        }
    }
    /// Applies tag `x` to the elements [a, b); requires a <= b <= N.
    void add(uint a, uint b, L x) {
        assert(a <= b && b <= N);
        if (a == b) return;
        uint aB = a / B, aC = a % B;
        uint bB = b / B, bC = b % B;
        if (aB == bB) {
            // Range lies inside one block: materialise its tag, edit elements.
            pushPath(aB);
            flushLeaf(aB);
            blks[aB].add(aC, bC, x);
            upPath(aB);
            return;
        }
        // Partial first and last block are edited element-wise; the whole
        // blocks in between receive the tag lazily.
        pushPath2(aB, bB);
        flushLeaf(aB);
        flushLeaf(bB);
        blks[aB].add(aC, B, x);
        blks[bB].add(0, bC, x);
        addBody(aB+1, bB, x);
        upPath2(aB, bB);
    }
}
256 
/**
 * Lazy-propagation segment tree engine that performs range queries and
 * range updates by top-down recursion.
 *
 * Nodes are stored 1-indexed: node `k` has children `2*k` and `2*k+1`,
 * leaves occupy indices `sz .. 2*sz`. `d[k]` already includes the tag
 * `lz[k]`; `lz[k]` is what still has to be handed down to the children.
 *
 * Params:
 *   T    = element / aggregate type
 *   L    = lazy tag type
 *   opTT = combines two aggregates (left, right)
 *   opTL = applies a tag to an aggregate
 *   opLL = composes two tags (already pending tag first, new tag second)
 *   eT   = value for unused leaves and empty ranges
 *   eL   = value meaning "no pending tag"
 */
struct LazySegNaiveEngine(T, L, alias opTT, alias opTL, alias opLL, T eT, L eL) {
    alias DataType = T;
    alias LazyType = L;
    alias BinSearch = binSearchLazyNaive;
    import std.functional : binaryFun;
    /// n: logical length, sz: number of leaves (power of two), lg: log2(sz).
    uint n, sz, lg;
    T[] d; L[] lz;

    /// Number of logical elements.
    @property size_t length() const { return n; }

    /// Builds a tree of `n` elements, all equal to `eT`, with no pending tags.
    this(uint n) {
        this.n = n;
        uint height = 0;
        while ((2^^height) < n) height++;
        lg = height;
        sz = 2^^height;
        d = new T[](2*sz);
        foreach (ref value; d) value = eT;
        lz = new L[](2*sz);
        foreach (ref tag; lz) tag = eL;
    }

    /// Builds a tree from `first`; an empty array leaves the engine unallocated.
    this(T[] first) {
        import std.conv : to;
        n = first.length.to!uint;
        if (n == 0) return;
        uint height = 0;
        while ((2^^height) < n) height++;
        lg = height;
        sz = 2^^height;
        d = new T[](2*sz);
        foreach (ref value; d) value = eT;
        // Leaves first, then internal nodes from the deepest up to the root.
        foreach (idx, value; first) {
            d[sz+idx] = value;
        }
        for (uint node = sz - 1; node >= 1; node--) {
            update(node);
        }
        lz = new L[](2*sz);
        foreach (ref tag; lz) tag = eL;
    }

    /// Applies tag `x` to node `k` (aggregate and pending tag).
    private void lzAdd(uint k, L x) {
        d[k] = opTL(d[k], x);
        lz[k] = opLL(lz[k], x);
    }

    /// Hands the pending tag of node `k` down to both children.
    public void push(uint k) {
        if (lz[k] != eL) {
            lzAdd(2*k, lz[k]);
            lzAdd(2*k+1, lz[k]);
            lz[k] = eL;
        }
    }

    /// Recomputes node `k` from its two children.
    void update(uint k) {
        d[k] = opTT(d[2*k], d[2*k+1]);
    }

    /// Returns the element at index `k` after pushing all tags on its path.
    T single(uint k) {
        immutable uint leaf = k + sz;
        for (uint level = lg; level >= 1; level--) {
            push(leaf >> level);
        }
        return d[leaf];
    }

    /// Assigns `x` to index `k`.
    void singleSet(uint k, T x) {
        immutable uint leaf = k + sz;
        // Root-to-leaf: clear pending tags above the leaf.
        for (uint level = lg; level >= 1; level--) {
            push(leaf >> level);
        }
        d[leaf] = x;
        // Leaf-to-root: refresh every ancestor.
        for (uint level = 1; level <= lg; level++) {
            update(leaf >> level);
        }
    }

    /// Recursive helper: aggregate of [a, b) restricted to node `k`, which covers [l, r).
    T sum(uint a, uint b, uint l, uint r, uint k) {
        immutable bool disjoint = (r <= a) || (b <= l);
        if (disjoint) return eT;
        immutable bool covered = (a <= l) && (r <= b);
        if (covered) return d[k];
        push(k);
        immutable uint mid = (l+r)/2;
        auto lhs = sum(a, b, l, mid, 2*k);
        auto rhs = sum(a, b, mid, r, 2*k+1);
        return opTT(lhs, rhs);
    }

    /// Aggregate d[a] + d[a+1] + ... + d[b-1]; requires a <= b <= n.
    T sum(uint a, uint b) {
        assert(0 <= a && a <= b && b <= n);
        return sum(a, b, 0, sz, 1);
    }

    /// Recursive helper: applies tag `x` to [a, b) within node `k`, which covers [l, r).
    void add(uint a, uint b, L x, uint l, uint r, uint k) {
        if ((r <= a) || (b <= l)) return;
        if ((l < a) || (b < r)) {
            // Partial overlap: descend into both halves, then rebuild this node.
            push(k);
            immutable uint mid = (l+r)/2;
            add(a, b, x, l, mid, 2*k);
            add(a, b, x, mid, r, 2*k+1);
            update(k);
        } else {
            // Node fully inside the range: tag it and stop.
            lzAdd(k, x);
        }
    }

    /// Applies tag `x` to the elements [a, b); requires a <= b <= n.
    void add(uint a, uint b, L x) {
        assert(0 <= a && a <= b && b <= n);
        add(a, b, x, 0, sz, 1);
    }
}
357 
/**
 * Binary search on a top-down lazy segment tree engine.
 *
 * `TR` must be a template instance whose template arguments are laid out
 * like `LazySegNaiveEngine`'s: argument 2 is `opTT`, argument 5 is `eT`.
 * The function uses the members `d`, `sz` and `push` of `t`.
 *
 * Params:
 *   rev  = false: scan [a, b) from the left; true: scan it from the right
 *   pred = predicate on an aggregate (expected to be monotone along the scan)
 *   t    = the engine
 *   a    = left end of the range (inclusive)
 *   b    = right end of the range (exclusive)
 *
 * Returns:
 *   rev == false: `a-1` if `pred(eT)` holds, otherwise the end of the longest
 *   prefix [a, p) that could be accumulated without `pred` becoming true
 *   (`b` if the whole range was accumulated).
 *   rev == true: `b` if `pred(eT)` holds, otherwise `p-1` where [p, b) is the
 *   longest accumulated suffix (`a-1` if the whole range was accumulated).
 */
int binSearchLazyNaive(bool rev, alias pred, TR)(TR t, int a, int b) {
    import std.traits : TemplateArgsOf;
    alias args = TemplateArgsOf!TR;
    alias opTT = args[2];
    // Running aggregate, starts at the identity eT.
    auto x = args[5];
    with (t) {
        static if (!rev) {
            //left
            if (pred(x)) return a-1;
            int pos = a;
            void f(int a, int b, int l, int r, int k) {
                if (b <= l || r <= a) return;
                // Node fully inside the range and appending it keeps pred false:
                // take the whole node.
                if (a <= l && r <= b && !pred(opTT(x, d[k]))) {
                    x = opTT(x, d[k]);
                    pos = r;
                    return;
                }
                if (l+1 == r) return;
                push(k);
                int md = (l+r)/2;
                f(a, b, l, md, 2*k);
                // Only continue to the right half if the left half was not
                // where the scan stopped.
                if (pos >= md) f(a, b, md, r, 2*k+1);
            }
            f(a, b, 0, sz, 1);
            return pos;
        } else {
            //right
            if (pred(x)) return b;
            int pos = b-1;
            void f(int a, int b, int l, int r, int k) {
                if (b <= l || r <= a) return;
                // The accumulated suffix lies to the right of node k, so the
                // node is prepended: d[k] is the LEFT operand both in the
                // test and in the accumulation.
                if (a <= l && r <= b && !pred(opTT(d[k], x))) {
                    x = opTT(d[k], x);
                    pos = l-1;
                    return;
                }
                if (l+1 == r) return;
                push(k);
                int md = (l+r)/2;
                f(a, b, md, r, 2*k+1);
                // Only continue to the left half if the right half was not
                // where the scan stopped.
                if (pos < md) f(a, b, l, md, 2*k);
            }
            f(a, b, 0, sz, 1);
            return pos;
        }
    }
}
405 
// Cross-checks binSearchLeft / binSearchRight of the listed engines against
// the `Naive` / `NaiveSimple` engines (presumably the reference
// implementations from dkh.segtree.naive) on random data with bitwise OR.
unittest {
    import dkh.segtree.naive;
    // AliasSeq is declared in std.meta; std.traits does not export it.
    import std.meta : AliasSeq;
    alias SimpleEngines = AliasSeq!(SimpleSegEngine);
    alias LazyEngines = AliasSeq!(LazySegEngine, LazySegNaiveEngine);

    import std.random;
    
    // Lazy engines: compare engine T with Naive for every range and every mask x.
    void f(alias T)() {
        auto nav = LazySeg!(uint, uint,
            (a, b) => (a | b),
            (a, b) => (a | b),
            (a, b) => (a | b),
            0U, 0U, Naive)(100);
        auto seg = LazySeg!(uint, uint,
            (a, b) => (a | b),
            (a, b) => (a | b),
            (a, b) => (a | b),
            0U, 0U, T)(100);
        foreach (i; 0..100) {
            auto u = uniform!"[]"(0, 31);
            seg[i] = u;
            nav[i] = u;
        }
        foreach (i; 0..100) {
            foreach (j; i..101) {
                foreach (x; 0..32) {
                    assert(
                        nav.binSearchLeft!((a) => a & x)(i, j) ==
                        seg.binSearchLeft!((a) => a & x)(i, j));
                    // A predicate that holds on the identity yields i-1 / j.
                    assert(seg.binSearchLeft!((a) => true)(i, j) == i-1);
                    assert(
                        nav.binSearchRight!((a) => a & x)(i, j) ==
                        seg.binSearchRight!((a) => a & x)(i, j));
                    assert(seg.binSearchRight!((a) => true)(i, j) == j);
                }
            }
        }
    }
    // Simple (non-lazy) engines: same comparison against NaiveSimple.
    void g(alias T)() {
        auto nav = SimpleSeg!(uint,
            (a, b) => (a | b),
            0U, NaiveSimple)(100);
        auto seg = SimpleSeg!(uint,
            (a, b) => (a | b),
            0U, T)(100);
        foreach (i; 0..100) {
            auto u = uniform!"[]"(0, 31);
            seg[i] = u;
            nav[i] = u;
        }
        foreach (i; 0..100) {
            foreach (j; i..101) {
                foreach (x; 0..32) {
                    assert(
                        nav.binSearchLeft!((a) => a & x)(i, j) ==
                        seg.binSearchLeft!((a) => a & x)(i, j));
                    assert(seg.binSearchLeft!((a) => true)(i, j) == i-1);
                    assert(
                        nav.binSearchRight!((a) => a & x)(i, j) ==
                        seg.binSearchRight!((a) => a & x)(i, j));
                    assert(seg.binSearchRight!((a) => true)(i, j) == j);
                }
            }
        }
    }
    foreach (E; LazyEngines) {
        f!E();
    }
    foreach (E; SimpleEngines) {
        g!E();
    }
}
479 
// Basic functional test: construction, length, point assignment, range add
// and range query on every engine.
unittest {
    //some func test
    // AliasSeq is declared in std.meta; std.traits does not export it.
    import std.meta : AliasSeq;
    alias SimpleEngines = AliasSeq!(SimpleSegEngine, SimpleSegNaiveEngine);
    alias LazyEngines = AliasSeq!(LazySegEngine, LazySegBlockEngine, LazySegNaiveEngine);
    
    // Simple engines: construct with a length and check it is reported back.
    void checkSimple(alias Seg)() {
        alias S = SegTree!(Seg, int, (a, b) => a+b, 0);
        S seg;
        seg = S(10);
        assert(seg.length == 10);
    }
    // Lazy engines: range max with range add.
    void check(alias Seg)() {
        import std.algorithm : max;

        alias S = SegTree!(Seg, int, int,
            (a, b) => max(a, b), (a, b) => a+b, (a, b) => a+b, 0, 0); 
        S seg;
        seg = S([2, 1, 4]);
        
        //[2, 1, 4]
        seg[0] = 2; seg[1] = 1; seg[2] = 4;
        assert(seg[0..3].sum == 4);

        //[2, 1, 5]
        seg[2] = 5;
        assert(seg[0..2].sum == 2);
        assert(seg[0..3].sum == 5);

        //[12, 11, 5]
        seg[0..2] += 10;
        assert(seg[0..3].sum == 12);

        //n=10
        auto seg2 = SegTree!(Seg, int, int,
            (a, b) => max(a, b), (a, b) => a+b, (a, b) => a+b, 0, 0)(10);
        assert(seg2.length == 10);
    }

    foreach (E; SimpleEngines) {
        checkSimple!E();
    }
    foreach (E; LazyEngines) {
        check!E();
    }
}
528 
// Stress test: replays the same pseudo-random query sequence on every engine
// and compares the accumulated result with the one produced by the
// `Naive` / `NaiveSimple` engines. The monoid is 2x2 matrix multiplication
// over a prime modulus, so operand order matters.
unittest {
    //stress test
    import dkh.segtree.naive;
    // AliasSeq is declared in std.meta; std.traits does not export it.
    import std.meta : AliasSeq;
    alias SimpleEngines = AliasSeq!(SimpleSegEngine, SimpleSegNaiveEngine);
    alias LazyEngines = AliasSeq!(LazySegEngine, LazySegBlockEngine, LazySegNaiveEngine);

    import std.typecons, std.random, std.algorithm;
    import dkh.modint, dkh.matrix, dkh.numeric.primitive;
    static immutable uint MD = 10^^9 + 7;
    alias Mint = ModInt!MD;
    alias Mat = SMatrix!(Mint, 2, 2);

    // 2x2 identity matrix.
    static immutable Mat e = matrix!(2, 2, (i, j) => Mint(i == j ? 1 : 0))();

    Xorshift128 gen;

    // Draws random matrices until one satisfies m00*m11 != m01*m10
    // (non-zero determinant).
    Mat rndM() {
        Mat m;
        while (true) {
            m = matrix!(2, 2, (i, j) => Mint(uniform(0, MD, gen)))();
            if (m[0, 0] * m[1, 1] == m[0, 1] * m[1, 0]) continue;
            break;
        }
        return m;
    }

    // Non-lazy engines: M random queries (0 = range product, 1 = point set)
    // on N elements; returns the sum of all query answers.
    Mat checkSimple(alias Seg)(int N, int M, uint seed) {
        gen = Xorshift128(seed);
        Mat[] a = new Mat[N];
        a.each!((ref x) => x = rndM());
        alias Q = Tuple!(int, int, int, Mat);
        Q[] que = new Q[M];
        foreach (ref q; que) {
            q[0] = uniform(0, 2, gen);
            // Point operations need at least one element.
            if (N == 0) q[0] = 0;
            if (q[0] == 0) {
                q[1] = uniform(0, N+1, gen);
                q[2] = uniform(0, N+1, gen);
                if (q[1] > q[2]) swap(q[1], q[2]);
            } else {
                q[1] = uniform(0, N, gen);
            }
            q[3] = rndM();
        }
        static auto opTT(Mat a, Mat b) {
            return a*b;
        }

        auto s = SegTree!(Seg, Mat, opTT, e)(a);
        Mat res;
        foreach (q; que) {
            if (q[0] == 0) {
                //sum
                res += s[q[1]..q[2]].sum();
            } else if (q[0] == 1) {
                //set
                s[q[1]] = q[3];
            }
        }
        return res;
    }    
    // Lazy engines: M random queries (0 = range product, 1 = range assign,
    // 2 = point read, 3 = point set) on N elements; returns the sum of all
    // query answers. Elements carry (product, count) so that assigning a
    // matrix to a range of `count` elements yields pow(matrix, count).
    Mat check(alias Seg)(int N, int M, uint seed) {
        alias T = Tuple!(Mat, int);
        gen = Xorshift128(seed);
        T[] a = new T[N];
        a.each!((ref x) => x = T(rndM(), 1));
        alias Q = Tuple!(int, int, int, Mat);
        Q[] que = new Q[M];
        foreach (ref q; que) {
            q[0] = uniform(0, 4, gen);
            // Point operations need at least one element.
            if (N == 0) q[0] %= 2;
            if (q[0] < 2) {
                q[1] = uniform(0, N+1, gen);
                q[2] = uniform(0, N+1, gen);
                if (q[1] > q[2]) swap(q[1], q[2]);
            } else {
                q[1] = uniform(0, N, gen);
            }
            q[3] = rndM();
        }
        static auto opTT(T a, T b) {
            return T(a[0]*b[0], a[1]+b[1]);
        }
        static auto opTL(T a, Mat b) {
            // Mat() is the "no pending tag" value.
            if (b == Mat()) return a;
            return T(pow(b, a[1], e), a[1]);
        }
        static auto opLL(Mat a, Mat b) {
            // Assignment: the newer tag replaces the older one.
            return b;
        }

        auto s = SegTree!(Seg, T, Mat, opTT, opTL, opLL, T(e, 0), Mat())(a);
        Mat res;
        foreach (q; que) {
            if (q[0] == 0) {
                //sum
                res += s[q[1]..q[2]].sum()[0];
            } else if (q[0] == 1) {
                //set
                s[q[1]..q[2]] += q[3];
            } else if (q[0] == 2) {
                //single sum
                T w = s[q[1]];
                res += w[0];
            } else if (q[0] == 3) {
                //single set
                s[q[1]] = T(q[3], 1);
            }
        }
        return res;
    }

    import dkh.stopwatch;
    StopWatch sw; sw.start;

    // Reference answers for every size 0 .. n-1.
    int n = 40;
    Mat[] ansLazy = new Mat[n];
    foreach (i; 0..n) {
        ansLazy[i] = check!Naive(i, 500, 114514);
    }
    Mat[] ansSimple = new Mat[n];
    foreach (i; 0..n) {
        ansSimple[i] = checkSimple!NaiveSimple(i, 500, 114514);
    }
    
    foreach (E; SimpleEngines) {
        foreach (i; 0..n) {
            assert(checkSimple!E(i, 500, 114514) == ansSimple[i]);
        }
    }
    foreach (E; LazyEngines) {
        foreach (i; 0..n) {
            assert(check!E(i, 500, 114514) == ansLazy[i]);
        }
    }

    import std.stdio;
    writeln("SegTree Stress: ", sw.peek.toMsecs);
}