module dkh.tree;

import std.traits;
import dkh.array;

/**
Sequence container backed by a weight-balanced binary tree.

Elements are stored in the leaves, in order. Every internal node caches the
number of leaves below it (`length`) and the fold of their values under
`opTT` (`v`), so indexing, insertion, removal, split, concatenation and range
folds each walk only root-to-leaf paths.

Params:
    T       = element type
    _opTT   = binary operation folding two `T` values (anything accepted by
              `std.functional.binaryFun`). It must be associative, because
              rebalancing rotations change the bracketing of the fold.
    _eT     = identity of `_opTT`; the fold of an empty range
    hasLazy = enables lazy range updates (`tree[a..b] += tag`)
    L       = lazy tag type (unused unless `hasLazy`)
    _opTL   = applies a tag to a folded value: `opTL(value, tag)`; expected to
              be compatible with `_opTT` so a whole subtree can be updated at once
    _opLL   = composes a pending tag with a newer one: `opLL(old, new)`
    _eL     = "no pending update" tag
*/
struct Tree(T, alias _opTT, T _eT, bool hasLazy = false,
        L = bool, alias _opTL = "a", alias _opLL = "a", L _eL = false) {
    import std.functional : binaryFun;
    alias opTT = binaryFun!_opTT;
    static immutable T eT = _eT;
    static if (hasLazy) {
        alias opTL = binaryFun!_opTL;
        alias opLL = binaryFun!_opLL;
        static immutable L eL = _eL;
    }

    alias NP = Node*;

    /// Weighted balanced tree node.
    ///
    /// A leaf has `length == 1` and null children; an internal node has two
    /// non-null children, `length` equal to their total, and `v` equal to
    /// `opTT` of their values.
    static struct Node {
        NP[2] ch;    // children; both null for a leaf
        uint length; // number of leaves in this subtree
        T v;         // leaf value, or fold of the whole subtree
        // Tag already applied to `v` but not yet propagated to the children.
        static if (hasLazy) L lz = eL;

        /// Creates a leaf holding `v`.
        this(in T v) {
            length = 1;
            this.v = v;
        }
        /// Creates an internal node with children `l` and `r`.
        this(NP l, NP r) {
            ch = [l, r];
            update();
        }
        /// Recomputes `length` and `v` from the children.
        /// Only valid on an internal node with no pending tag.
        void update() {
            static if (hasLazy) assert(lz == eL);
            length = ch[0].length + ch[1].length;
            v = opTT(ch[0].v, ch[1].v);
        }
        static if (hasLazy) {
            /// Applies tag `x` to this subtree: updates `v` immediately and
            /// records `x` so it can be pushed to the children later.
            void lzAdd(in L x) {
                v = opTL(v, x);
                lz = opLL(lz, x);
            }
        }
        /// Propagates the pending tag to both children (no-op without lazy
        /// support). With lazy support, calling this on a leaf that carries a
        /// pending tag would dereference a null child, so callers only push
        /// internal nodes.
        void push() {
            static if (hasLazy) {
                if (lz == eL) return;
                ch[0].lzAdd(lz);
                ch[1].lzAdd(lz);
                lz = eL;
            }
        }
        /// Single rotation lifting child `ch[type]` to the top of this subtree.
        /// type = 0: ((a, b), c) -> (a, (b, c)); type = 1 is the mirror image.
        /// Returns the new subtree root.
        NP rot(uint type) {
            push();
            auto m = ch[type];
            m.push();
            ch[type] = m.ch[1-type];
            m.ch[1-type] = &this;
            update(); m.update();
            return m;
        }
        /// Restores weight balance at this node. A child counts as too heavy
        /// when its size exceeds 2.5 times its sibling's (`heavy*2 > light*5`,
        /// the same bound `check` asserts). Returns the new subtree root.
        NP bal() {
            push();
            foreach (f; 0..2) {
                if (ch[f].length*2 > ch[1-f].length*5) {
                    ch[f].push();
                    // Use a double rotation when the inner grandchild is too
                    // heavy (or the outer one too light) for a single rotation.
                    if (ch[f].ch[1-f].length*2 > ch[1-f].length*5 ||
                        ch[f].ch[f].length*5 < (ch[f].ch[1-f].length+ch[1-f].length)*2) {
                        ch[f] = ch[f].rot(1-f);
                        update();
                    }
                    return rot(f);
                }
            }
            return &this;
        }
        /// Inserts `v` so that it becomes element `k` (0 <= k <= length) of
        /// this subtree. Returns the new subtree root.
        NP insert(uint k, in T v) {
            assert(0 <= k && k <= length);
            if (length == 1) {
                // Replace the leaf by a parent of the new leaf and this one.
                if (k == 0) {
                    return new Node(new Node(v), &this);
                } else {
                    return new Node(&this, new Node(v));
                }
            }
            push();
            if (k < ch[0].length) {
                ch[0] = ch[0].insert(k, v);
            } else {
                ch[1] = ch[1].insert(k-ch[0].length, v);
            }
            update();
            return bal();
        }
        /// Removes element `k` (0 <= k < length) of this subtree.
        /// Returns the new subtree root, or null if the subtree becomes empty.
        NP removeAt(uint k) {
            assert(0 <= k && k < length);
            if (length == 1) {
                return null;
            }
            push();
            // When a child vanishes, this node is dropped and the sibling
            // takes its place.
            if (k < ch[0].length) {
                ch[0] = ch[0].removeAt(k);
                if (ch[0] is null) return ch[1];
            } else {
                ch[1] = ch[1].removeAt(k-ch[0].length);
                if (ch[1] is null) return ch[0];
            }
            update();
            return bal();
        }
        /// Returns element `k` (0 <= k < length) of this subtree.
        const(T) at(uint k) {
            assert(0 <= k && k < length);
            if (length == 1) return v;
            push();
            if (k < ch[0].length) return ch[0].at(k);
            return ch[1].at(k-ch[0].length);
        }
        /// Replaces element `k` (0 <= k < length) of this subtree with `x`.
        void atAssign(uint k, in T x) {
            assert(0 <= k && k < length);
            if (length == 1) {
                v = x;
                return;
            }
            push();
            if (k < ch[0].length) ch[0].atAssign(k, x);
            else ch[1].atAssign(k-ch[0].length, x);
            update();
        }
        /// Fold of the elements at positions [a, b) relative to this subtree.
        /// The bounds may lie outside [0, length); they are clipped, and an
        /// empty intersection yields `eT`.
        const(T) sum(int a, int b) {
            if (b <= 0 || length.to!int <= a) return eT;
            if (a <= 0 && length.to!int <= b) return v;
            push();
            immutable int mid = ch[0].length.to!int;
            return opTT(ch[0].sum(a, b), ch[1].sum(a - mid, b - mid));
        }
        static if (hasLazy) {
            /// Applies tag `x` to the elements at positions [a, b) relative to
            /// this subtree (bounds clipped as in `sum`).
            void add(int a, int b, L x) {
                if (b <= 0 || length.to!int <= a) return;
                if (a <= 0 && length.to!int <= b) {
                    lzAdd(x);
                    return;
                }
                push();
                immutable int mid = ch[0].length.to!int;
                ch[0].add(a, b, x);
                ch[1].add(a - mid, b - mid, x);
                update();
            }
        }
        /// Debug check: asserts that cached lengths are consistent and that no
        /// child is more than 2.5 times the size of its sibling.
        void check() {
            if (length == 1) return;
            assert(length == ch[0].length + ch[1].length);
            ch[0].check();
            ch[1].check();
            assert(ch[0].length*5 >= ch[1].length*2);
            assert(ch[1].length*5 >= ch[0].length*2);
        }
        /// Debug print of the tree shape to stdout: a leaf as `(v)`, an
        /// internal node as `(` left, v, right `)`. Leaves are printed with
        /// `%d`, so this presumably expects an integral `T`.
        void pr() {
            import std.stdio;
            if (length == 1) {
                writef("(%d)", v);
                return;
            }
            write("(");
            ch[0].pr();
            write(v);
            ch[1].pr();
            write(")");
        }
    }

    /// Concatenates the sequences rooted at `l` and `r` (either may be null)
    /// and returns the root of the result.
    ///
    /// If the two sides are within the balance ratio they become children of
    /// one new parent; `buf`, when non-null, is reused as that parent instead
    /// of allocating (it must carry no pending tag, since `update` asserts
    /// that). Otherwise the smaller side is merged into the adjacent side of
    /// the larger one, which is then rebalanced.
    static NP merge(NP l, NP r, NP buf = null) {
        if (!l) return r;
        if (!r) return l;
        if (l.length*2 > r.length*5) {
            l.push();
            l.ch[1] = merge(l.ch[1], r, buf);
            l.update();
            return l.bal();
        } else if (l.length*5 < r.length*2) {
            r.push();
            r.ch[0] = merge(l, r.ch[0], buf);
            r.update();
            return r.bal();
        }
        if (buf is null) buf = new Node();
        buf.ch = [l, r];
        buf.update();
        return buf;
    }
    /// Splits the sequence rooted at `n` into its first `k` elements and the
    /// rest; either half may be null. Nodes of `n` are reused, so `n` must not
    /// be used afterwards.
    static NP[2] split(NP n, uint k) {
        if (!n) return [null, null];
        if (n.length == 1) {
            if (k == 0) return [null, n];
            else return [n, null];
        }
        NP[2] p;
        n.push();
        // `n` itself (already pushed) is handed to `merge` as a spare parent.
        if (k < n.ch[0].length) {
            p = split(n.ch[0], k);
            p[1] = merge(p[1], n.ch[1], n);
        } else {
            p = split(n.ch[1], k - n.ch[0].length);
            p[0] = merge(n.ch[0], p[0], n);
        }
        return p;
    }

    import std.conv : to;

    /// Root of the tree; null when the sequence is empty.
    Node* tr;

    /// Creates a tree holding the single element `v`.
    this(T v) { tr = new Node(v); }
    /// Wraps an existing root node (which may be null).
    this(Node* tr) { this.tr = tr; }
    /// Builds a balanced tree holding the elements of `v` in order.
    this(in T[] v) {
        if (v.length == 0) return;
        if (v.length == 1) {
            tr = new Node(v[0]);
            return;
        }
        auto ltr = Tree(v[0..$/2]);
        auto rtr = Tree(v[$/2..$]);
        this = ltr.merge(rtr);
    }
    /// Number of elements.
    @property size_t length() const { return (!tr ? 0 : tr.length); }
    alias opDollar = length;

    /// Inserts `v` so that it becomes element `k` (0 <= k <= length).
    void insert(size_t k, in T v) {
        assert(0 <= k && k <= length);
        if (tr is null) {
            tr = new Node(v);
            return;
        }
        tr = tr.insert(k.to!uint, v);
    }
    /// Removes element `k` (0 <= k < length).
    void removeAt(size_t k) {
        assert(0 <= k && k < length);
        tr = tr.removeAt(k.to!uint);
    }
    /// Cuts the elements at positions [a, b) out of this tree and returns them
    /// as a separate tree; the remaining elements stay here, in order.
    Tree trim(size_t a, size_t b) {
        auto v = split(tr, b.to!uint);
        auto u = split(v[0], a.to!uint);
        tr = merge(u[0], v[1]);
        return Tree(u[1]);
    }
    /// Keeps the first `k` elements in this tree and returns the rest.
    Tree split(size_t k) {
        auto u = split(tr, k.to!uint);
        tr = u[0];
        return Tree(u[1]);
    }
    /// Appends the elements of `r` to this tree and returns `this` for
    /// chaining. NOTE: `r` still points at nodes that now belong to this tree
    /// and may be restructured by later operations; do not use it afterwards.
    ref Tree merge(Tree r) {
        tr = merge(tr, r.tr);
        return this;
    }
    static if (hasLazy) {
        /// `tree[a..b] += x`: applies tag `x` to every element in [a, b).
        void opIndexOpAssign(string op : "+")(in L x, size_t[2] rng) {
            if (!tr) return;
            tr.add(rng[0].to!int, rng[1].to!int, x);
        }
    }
    /// Returns element `k` (0 <= k < length).
    const(T) opIndex(size_t k) {
        assert(0 <= k && k < length);
        return tr.at(k.to!uint);
    }
    /// Replaces element `k` (0 <= k < length) with `x`.
    void opIndexAssign(in T x, size_t k) {
        assert(0 <= k && k < length);
        tr.atAssign(k.to!uint, x);
    }
    /// View of positions [start, end) of the tree `eng` points to, produced
    /// by `tree[a..b]`. It stores a pointer, so it must not outlive that tree.
    struct Range {
        Tree* eng;
        size_t start, end;
        /// Fold of the viewed elements under `opTT` (`eT` when empty).
        @property T sum() {
            if (!eng.tr) return eT;
            return eng.tr.sum(start.to!int, end.to!int);
        }
    }
    /// `tree[start..end]` bounds, checked against the current length.
    size_t[2] opSlice(size_t dim)(size_t start, size_t end) {
        assert(0 <= start && start <= end && end <= length());
        return [start, end];
    }
    /// `tree[start..end]`: a `Range` over those positions.
    Range opIndex(size_t[2] rng) {
        return Range(&this, rng[0].to!uint, rng[1].to!uint);
    }
    /// Formats as `Tree([e0, e1, ...])`.
    string toString() {
        //todo: more optimize
        import std.range : iota;
        import std.algorithm : map;
        import std.conv : to;
        string s;
        s ~= "Tree(";
        s ~= iota(length).map!(i => this[i]).to!string;
        s ~= ")";
        return s;
    }
    /// Debug check of the whole tree (see `Node.check`).
    void check() {
        if (tr) tr.check();
    }
    /// Debug print of the whole tree (see `Node.pr`).
    void pr() {
        if (tr) tr.pr();
    }
}

/// Tree without lazy updates.
alias SimpleTree(T, alias op, T e) = Tree!(T, op, e);
/// Tree with lazy range updates; see `Tree` for the meaning of each parameter.
alias LazyTree(T, L, alias opTT, alias opTL, alias opLL, T eT, L eL) =
    Tree!(T, opTT, eT, true, L, opTL, opLL, eL);

import std.traits : isInstanceOf;

/**
Left-to-right search over positions [_a, _b) of tree `t`.

Folds the elements starting at `_a` with `T.opTT` (from `T.eT`) and finds where
`pred` of the running fold first becomes true. `pred` is expected to be
monotone along that scan (false ... false true ... true).

Returns:
    `_a - 1` if `pred(T.eT)` already holds; otherwise the smallest `p` in
    [_a, _b) such that `pred(fold of [_a, p])` holds, or `_b` if there is none
    (assuming `_b <= t.length`).
*/
ptrdiff_t binSearchLeft(alias pred, T)(T t, ptrdiff_t _a, ptrdiff_t _b)
if (isInstanceOf!(Tree, T)) {
    import std.conv : to;
    import std.traits : Unqual;
    int lo = _a.to!int, hi = _b.to!int;
    Unqual!(typeof(T.eT)) x = T.eT;
    if (pred(x)) return lo-1;
    int pos = lo; // end of the prefix folded so far
    if (t.tr is null) return pos;

    alias opTT = T.opTT;
    void f(T.Node* n, int offset) {
        immutable int len = n.length.to!int;
        if (hi <= offset || offset + len <= lo) return;
        // Whole subtree inside the range and still below the threshold: take it.
        if (lo <= offset && offset + len <= hi && !pred(opTT(x, n.v))) {
            x = opTT(x, n.v);
            pos = offset + len;
            return;
        }
        if (len == 1) return;
        n.push(); // children's values are stale while a lazy tag is pending
        immutable int mid = offset + n.ch[0].length.to!int;
        f(n.ch[0], offset);
        // Continue right only if the left part was consumed completely.
        if (pos >= mid) {
            f(n.ch[1], mid);
        }
    }
    f(t.tr, 0);
    return pos;
}

/**
Right-to-left search over positions [a, b) of tree `t`; the mirror image of
`binSearchLeft`.

Folds the elements ending at `b - 1` with `T.opTT` (from `T.eT`, new elements on
the left) and finds where `pred` of the running fold first becomes true.
`pred` is expected to be monotone along that scan.

Returns:
    `b` if `pred(T.eT)` already holds; otherwise the largest `p` in [a, b) such
    that `pred(fold of [p, b))` holds, or `a - 1` if there is none.
*/
ptrdiff_t binSearchRight(alias pred, T)(T t, ptrdiff_t a, ptrdiff_t b)
if (isInstanceOf!(Tree, T)) {
    import std.conv : to;
    import std.traits : Unqual;
    int lo = a.to!int, hi = b.to!int;
    Unqual!(typeof(T.eT)) x = T.eT;
    if (pred(x)) return hi;
    int pos = hi-1; // last index not yet folded
    if (t.tr is null) return pos;

    alias opTT = T.opTT;
    void f(T.Node* n, int offset) {
        immutable int len = n.length.to!int;
        if (hi <= offset || offset + len <= lo) return;
        // Whole subtree inside the range and still below the threshold: take it.
        if (lo <= offset && offset + len <= hi && !pred(opTT(n.v, x))) {
            x = opTT(n.v, x);
            pos = offset - 1;
            return;
        }
        if (len == 1) return;
        n.push(); // children's values are stale while a lazy tag is pending
        immutable int mid = offset + n.ch[0].length.to!int;
        f(n.ch[1], mid);
        // Continue left only if the right part was consumed completely.
        if (pos < mid) {
            f(n.ch[0], offset);
        }
    }
    f(t.tr, 0);
    return pos;
}

unittest {
    import std.traits : AliasSeq;
    import std.random;
    import dkh.modint;
    alias Mint = ModInt!(10^^9 + 7);
    auto rndM = (){ return Mint(uniform(0, 10^^9 + 7)); };
    void check() {
        // The fold of the whole tree equals the sum of everything inserted.
        alias T = SimpleTree!(Mint, "a+b", Mint(0));
        T t;
        Mint sm = 0;
        foreach (i; 0..100) {
            auto x = rndM();
            sm += x;
            t.insert(0, x);
        }
        assert(sm == t[0..$].sum);
    }
    check();
}


unittest {
    import std.conv : to;
    import std.traits : AliasSeq;
    import std.algorithm : swap;
    import std.random;
    import dkh.modint, dkh.foundation;
    alias Mint = ModInt!(10^^9 + 7);
    auto rndM = (){ return Mint(uniform(0, 10^^9 + 7)); };
    void check() {
        // T1 applies range additions element by element; T2 uses lazy tags.
        // Elements are pairs (value, 1): the second component of a fold counts
        // elements, so opTL can add tag * count to the folded value.
        alias T1 = SimpleTree!(Mint[2], (a, b) => [a[0]+b[0], a[1]+b[1]].fixed, [Mint(0), Mint(0)].fixed);
        alias T2 = LazyTree!(Mint[2], Mint,
            (a, b) => [a[0]+b[0], a[1]+b[1]].fixed,
            (a, b) => [a[0]+a[1]*b, a[1]].fixed,
            (a, b) => a+b,
            [Mint(0), Mint(0)].fixed, Mint(0));
        T1 t1;
        T2 t2;
        foreach (ph; 0..1000) {
            assert(t1.length == t2.length);
            int L = t1.length.to!int;
            int ty = uniform(0, 3);
            if (ty == 0) {
                // insert the same random element into both trees
                auto x = rndM();
                auto idx = uniform(0, L+1);
                t1.insert(idx, [x, Mint(1)]);
                t2.insert(idx, [x, Mint(1)]);
            } else if (ty == 1) {
                // compare folds of a random range and of the whole sequence
                int l = uniform(0, L+1);
                int r = uniform(0, L+1);
                if (l > r) swap(l, r);
                assert(t1[l..r].sum == t2[l..r].sum);
                assert(t1[0..$].sum == t2[0..$].sum);
            } else {
                // add x to every element of a random range
                int l = uniform(0, L+1);
                int r = uniform(0, L+1);
                if (l > r) swap(l, r);
                auto x = rndM();
                foreach (i; l..r) {
                    t1[i] = [t1[i][0] + x, t1[i][1]];
                }
                t2[l..r] += x;
            }
            t1.check();
            t2.check();
        }
    }
    check();
}

unittest {
    // binSearchLeft / binSearchRight on a fixed sequence, folding with max.
    import std.algorithm : max;
    alias T = SimpleTree!(int, max, int.min);
    auto t = T([1, 5, 2, 7, 3]);
    // first position whose prefix max reaches the threshold
    assert(t.binSearchLeft!(y => y >= 5)(0, t.length) == 1);
    assert(t.binSearchLeft!(y => y >= 8)(0, t.length) == 5);
    assert(t.binSearchLeft!(y => y >= 5)(2, 5) == 3);
    // last position whose suffix max reaches the threshold
    assert(t.binSearchRight!(y => y >= 5)(0, t.length) == 3);
    assert(t.binSearchRight!(y => y >= 8)(0, t.length) == -1);
    assert(t.binSearchRight!(y => y >= 5)(0, 3) == 1);
    // empty tree: "not found" in both directions
    T e;
    assert(e.binSearchLeft!(y => y >= 5)(0, 0) == 0);
    assert(e.binSearchRight!(y => y >= 5)(0, 0) == -1);
}

unittest {
    import std.random;
    import std.algorithm;
    import std.conv;
    import std.container.rbtree;
    import std.stdio;

    import dkh.stopwatch;
    StopWatch sw; sw.start;
    // Keep `tr` sorted like the multiset `nv`, using max as the fold so
    // binSearchLeft finds the first element >= x.
    auto nv = redBlackTree!(true, int)([]);
    alias T = SimpleTree!(int, max, int.min);
    auto tr = T();
    foreach (ph; 0..10000) {
        int ty = uniform(0, 2);
        if (ty == 0) {
            int x = uniform(0, 100);
            nv.insert(x);
            auto idx = tr.binSearchLeft!(y => x <= y)(0, tr.length);
            if (uniform(0, 2)) {
                auto tr2 = tr.trim(idx, tr.length);
                tr.merge(T(x)).merge(tr2);
            } else {
                tr.insert(idx, x);
            }
        } else {
            if (!nv.length) continue;
            int i = uniform(0, nv.length.to!int);
            auto u = nv[];
            foreach (_; 0..i) u.popFront();
            assert(u.front == tr[i]);
            int x = tr[i];
            nv.removeKey(x);
            if (uniform(0, 2)) {
                tr.removeAt(i);
            } else {
                tr.trim(i, i+1);
            }
        }
        tr.check();
        assert(nv.length == tr.length);
    }
    writeln("Set TEST: ", sw.peek.toMsecs);
}