1 module dkh.datastructure.segtree;
2 
3 import std.conv : to;
4 import std.functional : binaryFun;
5 import std.traits : isInstanceOf;
6 
/**
Generic segment-tree front end.

Wraps an engine `E!Args` and exposes it through D's indexing operators:
`seg[k]` reads element k, `seg[k] = x` writes it, `seg[l .. r].sum` combines
elements l .. r-1, and, when the engine declares a non-void `LazyType`,
`seg[l .. r] += x` applies a range update.
*/
struct SegTree(alias E, Args...) {
    import std.traits : ReturnType;

    alias Engine = E!Args;
    alias T = Engine.DataType;   /// element type
    alias L = Engine.LazyType;   /// range-update type (void when unsupported)

    Engine eng;

    /// Creates a tree of `n` elements; the engine decides their initial value.
    this(size_t n) {
        eng = Engine(to!uint(n));
    }

    /// Creates a tree holding the values of `first`.
    this(T[] first) {
        eng = Engine(first);
    }

    /// Number of elements.
    @property size_t length() const {
        return eng.length();
    }

    /// `$` inside `seg[...]` is the number of elements.
    @property size_t opDollar() const {
        return eng.length();
    }

    /// View of the half-open range [start, end); `sum` folds it on demand.
    struct Range {
        Engine* eng;
        size_t start, end;
        @property const(T) sum() {
            immutable uint lo = to!uint(start);
            immutable uint hi = to!uint(end);
            return eng.sum(lo, hi);
        }
    }

    /// Returns element k.
    const(T) opIndex(size_t k) {
        assert(k >= 0 && k < eng.length());
        immutable uint idx = to!uint(k);
        return eng.single(idx);
    }

    /// Sets element k to x.
    void opIndexAssign(T x, size_t k) {
        assert(k >= 0 && k < eng.length());
        eng.singleSet(to!uint(k), x);
    }

    /// Turns `l .. r` inside `seg[...]` into a pair of bounds.
    size_t[2] opSlice(size_t dim : 0)(size_t start, size_t end) {
        assert(start >= 0 && start <= end && end <= eng.length());
        size_t[2] bounds = [start, end];
        return bounds;
    }

    /// `seg[l .. r]`: a Range over [l, r) that refers back to this tree's engine.
    Range opIndex(size_t[2] rng) {
        Range view;
        view.eng = &eng;
        view.start = to!uint(rng[0]);
        view.end = to!uint(rng[1]);
        return view;
    }

    static if (!is(L == void)) {
        /// `seg[l .. r] += x`: forwards the update x for [l, r) to the engine's `add`.
        void opIndexOpAssign(string op : "+")(L x, size_t[2] rng) {
            immutable uint lo = to!uint(rng[0]), hi = to!uint(rng[1]);
            eng.add(lo, hi, x);
        }
    }
}
49 
/**
Searches rightward from index `a` inside the window [a, b) of segment tree `t`.

Delegates to the engine's `BinSearch` with `rev = false`. Assuming `pred` is
monotone (once it holds for a range it keeps holding as the range grows), the
result is `a - 1` when `pred` already holds for the identity element, otherwise
the smallest i in [a, b) such that `pred` holds for the aggregate of elements
a .. i (inclusive), or `b` if there is no such i.

Throws: `std.conv.ConvOverflowException` if `a` or `b` does not fit in an `int`.
*/
ptrdiff_t binSearchLeft(alias pred, TR)(TR t, ptrdiff_t a, ptrdiff_t b)
if (isInstanceOf!(SegTree, TR))
{
    immutable int from = to!int(a);
    immutable int until = to!int(b);
    return TR.Engine.BinSearch!(false, pred)(t.eng, from, until);
}
54 
/**
Searches leftward from index `b` inside the window [a, b) of segment tree `t`.

Delegates to the engine's `BinSearch` with `rev = true`. Assuming `pred` is
monotone (once it holds for a range it keeps holding as the range grows), the
result is `b` when `pred` already holds for the identity element, otherwise
the largest i in [a, b) such that `pred` holds for the aggregate of elements
i .. b-1, or `a - 1` if there is no such i.

Throws: `std.conv.ConvOverflowException` if `a` or `b` does not fit in an `int`.
*/
ptrdiff_t binSearchRight(alias pred, TR)(TR t, ptrdiff_t a, ptrdiff_t b)
if (isInstanceOf!(SegTree, TR))
{
    immutable int from = to!int(a);
    immutable int until = to!int(b);
    return TR.Engine.BinSearch!(true, pred)(t.eng, from, until);
}
59 
/**
SegTree

For an array a of T, computes opTT(a[l..r]) efficiently.
opTT must be an associative binary function, and eT must be its identity element.
Engine selects the implementation backend (SimpleSegEngine by default).
 */
template SimpleSeg(T, alias opTT, T eT, alias Engine = SimpleSegEngine) {
    alias SimpleSeg = SegTree!(Engine, T, binaryFun!opTT, eT);
}
68 
///
unittest {
    import std.algorithm : max;
    // max over ints, i.e. a range-maximum query (RMQ)
    auto seg = SimpleSeg!(int, (a, b) => max(a, b), 0)(3);

    // contents: [2, 1, 4]
    foreach (i, v; [2, 1, 4]) seg[i] = v;
    assert(seg[0 .. 3].sum == 4); // max(2, 1, 4) == 4

    // contents: [2, 1, 5]
    seg[2] = 5;
    assert(seg[0 .. 2].sum == 2); // max(2, 1) == 2
    assert(seg[0 .. 3].sum == 5); // max(2, 1, 5) == 5

    // contents: [2, 11, 5]
    immutable bumped = seg[1] + 10;
    seg[1] = bumped;
    assert(seg[0 .. 3].sum == 11);
}
88 
/**
Engine backing SimpleSeg: a plain (non-lazy) segment tree.

Layout: `d` is a 1-indexed implicit binary tree. Node k has children 2k and
2k+1; element i is stored in leaf sz+i, where sz = 2^^lg is the smallest power
of two that is >= n. When n == 0 nothing is allocated (sz == 0, `d` is empty).
Every slot that does not hold an element starts out as eT.
*/
struct SimpleSegEngine(T, alias opTT, T eT) {
    alias DataType = T;
    alias LazyType = void;   // no range updates
    alias BinSearch = binSearchSimple;

    uint n, sz, lg;
    T[] d;

    /// Number of stored elements.
    @property uint length() const { return n; }

    /// Builds a tree of `n` elements, every one equal to eT.
    this(uint n) {
        import std.algorithm : each;
        this.n = n;
        if (n == 0) return;
        for (lg = 0; (2^^lg) < n; lg++) {}
        sz = 2^^lg;
        d = new T[](2 * sz);
        d.each!((ref slot) => slot = eT);
    }

    /// Builds a tree holding a copy of `first`.
    this(T[] first) {
        import std.conv : to;
        this(first.length.to!uint);   // sizes the tree and fills it with eT
        if (n == 0) return;
        foreach (i, ref leaf; d[sz .. sz + n]) leaf = first[i];
        // a parent's index is smaller than its children's, so build from the back
        foreach_reverse (k; 1 .. sz) update(k);
    }

    pragma(inline):
    /// Recomputes node k from its two children.
    void update(uint k) {
        immutable uint child = 2 * k;
        d[k] = opTT(d[child], d[child + 1]);
    }

    /// Returns element k.
    T single(uint k) {
        return d[sz + k];
    }

    /// Sets element k to x, then recomputes every ancestor of its leaf.
    void singleSet(uint k, T x) {
        uint node = sz + k;
        d[node] = x;
        foreach (_; 0 .. lg) {
            node >>= 1;
            update(node);
        }
    }

    /// Returns opTT over elements [a, b); an empty range yields opTT(eT, eT).
    T sum(uint a, uint b) {
        assert(0 <= a && a <= b && b <= n);
        T accLeft = eT, accRight = eT;
        for (uint lo = a + sz, hi = b + sz; lo < hi; lo >>= 1, hi >>= 1) {
            if (lo & 1) accLeft = opTT(accLeft, d[lo++]);
            if (hi & 1) accRight = opTT(d[--hi], accRight);
        }
        return opTT(accLeft, accRight);
    }
}
147 
/**
Binary search over a SimpleSegEngine; this is that engine's `BinSearch`.

Elements are accumulated with opTT in array order, starting from eT. `pred` is
assumed to be monotone: once it holds for the aggregate of a range, it keeps
holding as the range is extended.

Params:
    rev  = false: extend the range rightward from `a`; true: extend it leftward from `b`
    pred = predicate on an aggregate value
    t    = the engine to search
    a    = left end of the search window (inclusive)
    b    = right end of the search window (exclusive)

Returns:
    rev == false: `a - 1` if pred(eT) holds; otherwise the smallest i in [a, b)
    such that pred holds for the aggregate of elements a .. i (inclusive),
    or `b` if there is none.
    rev == true: `b` if pred(eT) holds; otherwise the largest i in [a, b)
    such that pred holds for the aggregate of elements i .. b-1,
    or `a - 1` if there is none.
*/
int binSearchSimple(bool rev, alias pred, TR)(TR t, int a, int b) {
    import std.traits : TemplateArgsOf;
    alias args = TemplateArgsOf!TR;   // (T, opTT, eT) of SimpleSegEngine
    alias opTT = args[1];
    auto x = args[2];                 // running aggregate, starts at eT
    with (t) {
        static if (!rev) {
            // left: grow [a, pos) rightward, appending whole nodes
            if (pred(x)) return a-1;
            int pos = a;
            // Visit node k, which covers leaves [l, r).
            void f(int a, int b, int l, int r, int k) {
                if (b <= l || r <= a) return;
                if (a <= l && r <= b && !pred(opTT(x, d[k]))) {
                    // the whole node can be appended without satisfying pred
                    x = opTT(x, d[k]);
                    pos = r;
                    return;
                }
                if (l+1 == r) return;   // the leaf at which pred becomes true
                int md = (l+r)/2;
                f(a, b, l, md, 2*k);
                // continue right only if the search did not stop in the left half
                if (pos >= md) f(a, b, md, r, 2*k+1);
            }
            f(a, b, 0, sz, 1);
            return pos;
        } else {
            // right: grow [pos+1, b) leftward, prepending whole nodes
            if (pred(x)) return b;
            int pos = b-1;
            void f(int a, int b, int l, int r, int k) {
                if (b <= l || r <= a) return;
                // Node k lies to the LEFT of the current aggregate, so the candidate
                // value is opTT(d[k], x); testing opTT(x, d[k]) is wrong when opTT
                // is not commutative.
                if (a <= l && r <= b && !pred(opTT(d[k], x))) {
                    x = opTT(d[k], x);
                    pos = l-1;
                    return;
                }
                if (l+1 == r) return;   // the leaf at which pred becomes true
                int md = (l+r)/2;
                f(a, b, md, r, 2*k+1);
                // continue left only if the search did not stop in the right half
                if (pos < md) f(a, b, l, md, 2*k);
            }
            f(a, b, 0, sz, 1);
            return pos;
        }
    }
}
193 
/**
Lazy-propagation SegTree

For an array a of T, supports both range updates a[l..r] += x (x of type L)
and range queries opTT(a[l..r]) efficiently.

Params:
    opTT = (T, T) operation that combines results
    opTL = (T, L) operation that applies an update to a value
    opLL = (L, L) operation that composes two updates
    eT = identity element of T
    eL = identity element of L
    Engine = implementation backend (LazySegEngine by default)
*/
template LazySeg(T, L, alias opTT, alias opTL, alias opLL, T eT, L eL, alias Engine = LazySegEngine) {
    alias LazySeg = SegTree!(Engine, T, L, binaryFun!opTT, binaryFun!opTL, binaryFun!opLL, eT, eL);
}
208 
///
unittest {
    import std.algorithm : max;
    // range max combined with range add
    auto seg = LazySeg!(int, int,
        (a, b) => max(a, b), (a, b) => a+b, (a, b) => a+b, 0, 0)([2, 1, 4]);

    // contents: [2, 1, 4]
    foreach (i, v; [2, 1, 4]) seg[i] = v;
    assert(seg[0 .. 3].sum == 4);

    // contents: [2, 1, 5]
    seg[2] = 5;
    assert(seg[0 .. 2].sum == 2);
    assert(seg[0 .. 3].sum == 5);

    // contents: [12, 11, 5]
    seg[0 .. 2] += 10;
    assert(seg[0 .. 3].sum == 12);
}
229 
230 
/**
Engine backing LazySeg: a segment tree with lazy propagation.

The tree is stored 1-indexed in `s`: node k has children 2k and 2k+1, and
element i lives in leaf sz+i, where sz = 2^^lg is the smallest power of two
that is >= n (sz is 1 when n <= 1). Each node holds `d`, the aggregate of its
subtree including the node's own pending update, and `lz`, a pending update
that has already been applied to `d` but not yet pushed to the children.
A node's `d` is therefore exact once every one of its ancestors has been pushed.

In the comments below, "level t" is a shift amount: the node at level t on the
path to leaf index p is p >> t (level 0 is the leaf itself, level lg the root).
*/
struct LazySegEngine(T, L, alias opTT, alias opTL, alias opLL, T eT, L eL) {
    import std.typecons : Tuple;
    alias DataType = T;
    alias LazyType = L;
    alias BinSearch = binSearchLazy;
    alias S = Tuple!(T, "d", L, "lz");   // per-node aggregate and pending update
    uint n, sz, lg;
    S[] s;
    /// Creates `n` elements equal to eT, with every pending update set to eL.
    this(uint n) {
        import std.conv : to;   // NOTE(review): `to` is not used in this constructor
        import std.algorithm : each;
        this.n = n;
        uint lg = 0;            // local shadows the field; copied into this.lg below
        while ((2^^lg) < n) lg++;
        this.lg = lg;
        sz = 2^^lg;
        s = new S[](2*sz);
        s.each!((ref x) => x = S(eT, eL));
    }
    /// Creates a tree holding the values of `first`; internal nodes are built bottom-up.
    this(T[] first) {
        import std.conv : to;
        import std.algorithm : each;
        n = first.length.to!uint;
        uint lg = 0;
        while ((2^^lg) < n) lg++;
        this.lg = lg;
        sz = 2^^lg;

        s = new S[](2*sz);
        s.each!((ref x) => x = S(eT, eL));
        foreach (i; 0..n) {
            s[sz+i].d = first[i];
        }
        // children (2i, 2i+1) have larger indices, so they are built before node i
        foreach_reverse (i; 1..sz) {
            update(i);
        }
    }
    /// Number of elements.
    @property uint length() const { return n; }
    pragma(inline):
    /// Applies update x to node k: composes it after the node's pending update
    /// (opLL(lz, x)) and applies it to the node's aggregate (opTL(d, x)).
    private void lzAdd(uint k, in L x) {
        s[k].lz = opLL(s[k].lz, x);
        s[k].d = opTL(s[k].d, x);
    }
    /// Moves node k's pending update down to both children and resets it to eL.
    /// Does nothing when the pending update already equals eL.
    public void push(uint k) {
        if (s[k].lz == eL) return;
        lzAdd(2*k, s[k].lz);
        lzAdd(2*k+1, s[k].lz);
        s[k].lz = eL;
    }
    /// Recomputes node k's aggregate from its children. This overwrites `d`, so it is
    /// only correct when node k has no pending update; the callers here push first.
    private void update(uint k) {
        s[k].d = opTT(s[2*k].d, s[2*k+1].d);
    }
    /// Returns element k after pushing every ancestor of its leaf, top-down.
    T single(uint k) {
        k += sz;
        foreach_reverse (uint i; 1..lg+1) {
            push(k>>i);
        }
        return s[k].d;
    }
    /// Sets element k to x: pushes its ancestors top-down, writes the leaf,
    /// then recomputes the ancestors bottom-up.
    void singleSet(uint k, T x) {
        k += sz;
        foreach_reverse (uint i; 1..lg+1) {
            push(k>>i);
        }
        s[k].d = x;
        foreach (uint i; 1..lg+1) {
            update(k>>i);
        }
    }
    /// Returns opTT over elements [a, b) (eT when a == b).
    T sum(uint a, uint b) {
        assert(0 <= a && a <= b && b <= n);
        if (a == b) return eT;
        a += sz; b--; b += sz;   // a and b are now inclusive leaf indices
        uint tlg = lg;
        // Descend from the root while leaves a and b are under the same node k.
        while (true) {
            uint k = a >> tlg;
            if (a >> tlg != b >> tlg) {
                tlg++;   // back up to the lowest level where a and b share a node (their LCA)
                break;
            }
            // Node k covers exactly [a, b]: a-1 falls in its left neighbour and b+1
            // in its right neighbour, so its aggregate is the answer.
            if (((a-1) >> tlg) + 2 == (b+1) >> tlg) return s[k].d;
            push(k);
            tlg--;
        }
        T sm = eT;
        // Left boundary: walk from the LCA's left child down toward leaf a,
        // prepending every node that lies fully inside the range.
        foreach_reverse (l; 0..tlg) {
            uint k = a >> l;
            if ((a-1)>>l != a>>l) {
                // node k starts exactly at a, so it is fully inside the range
                sm = opTT(s[k].d, sm);
                break;
            }
            push(k);
            // If a is in k's left child, k's right child is fully inside the range.
            // (l >= 1 here: at l == 0 the check above always breaks.)
            if (!((a >> (l-1)) & 1)) sm = opTT(s[2*k+1].d, sm);
        }
        // Right boundary: symmetric walk from the LCA's right child toward leaf b,
        // appending every node that lies fully inside the range.
        foreach_reverse (l; 0..tlg) {
            uint k = b >> l;
            if (b>>l != (b+1)>>l) {
                // node k ends exactly at b, so it is fully inside the range
                sm = opTT(sm, s[k].d);
                break;
            }
            push(k);
            // If b is in k's right child, k's left child is fully inside the range.
            if ((b >> (l-1)) & 1) sm = opTT(sm, s[2*k].d);
        }
        return sm;
    }
    /// Applies update x to every element in [a, b) (no-op when a == b).
    /// Uses the same traversal as `sum`, tagging maximal covered nodes with lzAdd
    /// and then recomputing the aggregates on the touched paths.
    void add(uint a, uint b, L x) {
        assert(0 <= a && a <= b && b <= n);
        if (a == b) return;
        a += sz; b--; b += sz;   // a and b are now inclusive leaf indices
        uint tlg = lg;
        while (true) {
            uint k = a >> tlg;
            if (a >> tlg != b >> tlg) {
                tlg++;   // level of the LCA of leaves a and b
                break;
            }
            if (((a-1) >> tlg) + 2 == (b+1) >> tlg) {
                // Node k covers exactly [a, b]: tag it, then recompute its ancestors up to the root.
                lzAdd(k, x);
                foreach (l; tlg+1..lg+1) {
                    update(a >> l);
                }
                return;
            }
            push(k);
            tlg--;
        }
        // Left boundary: tag the maximal covered nodes along the path to leaf a.
        foreach_reverse (l; 0..tlg) {
            uint k = a >> l;
            if ((a-1)>>l != a>>l) {
                lzAdd(k, x);
                // recompute the path between this node and the LCA (both exclusive)
                foreach (h; l+1..tlg) {
                    update(a >> h);
                }
                break;
            }
            push(k);
            if (!((a >> (l-1)) & 1)) lzAdd(2*k+1, x);
        }
        // Right boundary: tag the maximal covered nodes along the path to leaf b.
        foreach_reverse (l; 0..tlg) {
            uint k = b >> l;
            if (b>>l != (b+1)>>l) {
                lzAdd(k, x);
                foreach (h; l+1..tlg) {
                    update(b >> h);
                }
                break;
            }
            push(k);
            if ((b >> (l-1)) & 1) lzAdd(2*k, x);
        }
        // Finally recompute the LCA and all of its ancestors.
        foreach (l; tlg..lg+1) {
            update(a >> l);
        }
    }
}
386 
unittest {
    // Regression test for issue 17466 (presumably the D bug tracker; TODO confirm):
    // LazySeg must instantiate with static-array types (long[2]) for both T and L.
    auto seg = LazySeg!(long[2], long[2],
        (a, b) => a, (a, b) => a, (a, b) => a, [0L, 0L], [0L, 0L])(10);
    assert(seg.length == 10);
    // every operation returns its first argument, so any aggregate of all-zero data is zero
    long[2] zero = [0L, 0L];
    assert(seg[0 .. 10].sum == zero);
}
393 
/**
Binary search over a LazySegEngine; this is that engine's `BinSearch`.

Elements are accumulated with opTT in array order, starting from eT. `pred` is
assumed to be monotone: once it holds for the aggregate of a range, it keeps
holding as the range is extended. Pending updates are pushed on the way down,
so the engine's arrays may be modified.

Params:
    rev  = false: extend the range rightward from `a`; true: extend it leftward from `b`
    pred = predicate on an aggregate value
    t    = the engine to search
    a    = left end of the search window (inclusive)
    b    = right end of the search window (exclusive)

Returns:
    rev == false: `a - 1` if pred(eT) holds; otherwise the smallest i in [a, b)
    such that pred holds for the aggregate of elements a .. i (inclusive),
    or `b` if there is none.
    rev == true: `b` if pred(eT) holds; otherwise the largest i in [a, b)
    such that pred holds for the aggregate of elements i .. b-1,
    or `a - 1` if there is none.
*/
int binSearchLazy(bool rev, alias pred, TR)(TR t, int a, int b) {
    import std.traits : TemplateArgsOf;
    alias args = TemplateArgsOf!TR;   // (T, L, opTT, opTL, opLL, eT, eL) of LazySegEngine
    alias opTT = args[2];
    auto x = args[5];                 // running aggregate, starts at eT
    with (t) {
        static if (!rev) {
            // left: grow [a, pos) rightward, appending whole nodes
            if (pred(x)) return a-1;
            int pos = a;
            // Visit node k, which covers leaves [l, r).
            void f(int a, int b, int l, int r, int k) {
                if (b <= l || r <= a) return;
                if (a <= l && r <= b && !pred(opTT(x, s[k].d))) {
                    // the whole node can be appended without satisfying pred
                    x = opTT(x, s[k].d);
                    pos = r;
                    return;
                }
                if (l+1 == r) return;   // the leaf at which pred becomes true
                push(k);                // children are read next; flush k's pending update
                int md = (l+r)/2;
                f(a, b, l, md, 2*k);
                // continue right only if the search did not stop in the left half
                if (pos >= md) f(a, b, md, r, 2*k+1);
            }
            f(a, b, 0, sz, 1);
            return pos;
        } else {
            // right: grow [pos+1, b) leftward, prepending whole nodes
            if (pred(x)) return b;
            int pos = b-1;
            void f(int a, int b, int l, int r, int k) {
                if (b <= l || r <= a) return;
                // Node k lies to the LEFT of the current aggregate, so the candidate
                // value is opTT(s[k].d, x); testing opTT(x, s[k].d) is wrong when
                // opTT is not commutative.
                if (a <= l && r <= b && !pred(opTT(s[k].d, x))) {
                    x = opTT(s[k].d, x);
                    pos = l-1;
                    return;
                }
                if (l+1 == r) return;   // the leaf at which pred becomes true
                push(k);                // children are read next; flush k's pending update
                int md = (l+r)/2;
                f(a, b, md, r, 2*k+1);
                // continue left only if the search did not stop in the right half
                if (pos < md) f(a, b, l, md, 2*k);
            }
            f(a, b, 0, sz, 1);
            return pos;
        }
    }
}