/**
Unstable
*/

module dkh.tree;

import std.traits;
import dkh.array;

/**
Sequence container backed by a weight-balanced binary tree.

Elements live in the leaves. Every internal node caches the fold (`opTT`) of
its subtree, so positional access and range folds descend a single path.
Sibling subtrees are kept within a 2:5 size ratio (see `Node.check`), which
keeps the depth logarithmic in the number of elements.

Params:
    T       = element type; also the type of range folds
    _opTT   = binary operation (string or callable, via `binaryFun`) that
              folds elements. Rotations regroup operands, so it must be
              associative for results to be independent of the tree shape.
    _eT     = identity element of `opTT`; returned for empty ranges
    hasLazy = enables lazy range updates (`tree[a..b] += x`)
    L       = type of a lazy update tag
    _opTL   = applies a tag to a folded value: `opTL(value, tag)`. A node's
              value is updated in place and later recomputed from its
              children after `push()`, so it must satisfy
              `opTL(opTT(a, b), x) == opTT(opTL(a, x), opTL(b, x))`.
    _opLL   = composes a pending tag with a newer one: `opLL(old, new)`
    _eL     = the "no pending update" tag
*/
struct Tree(T, alias _opTT, T _eT, bool hasLazy = false,
    L = bool, alias _opTL = "a", alias _opLL = "a", L _eL = false) {
    import std.conv : to;
    import std.functional : binaryFun;
    alias opTT = binaryFun!_opTT;
    static immutable T eT = _eT;
    static if (hasLazy) {
        alias opTL = binaryFun!_opTL;
        alias opLL = binaryFun!_opLL;
        static immutable L eL = _eL;
    }

    alias NP = Node*;

    /**
    Weighted balanced tree node.

    A node with `length == 1` is a leaf holding one element in `v`. Otherwise
    both children are non-null and `v` caches `opTT(ch[0].v, ch[1].v)`.
    */
    static struct Node {
        NP[2] ch;     /// children; both null for a leaf
        uint length;  /// number of elements (leaves) in this subtree
        T v;          /// the element (leaf) or the fold of the subtree (internal)
        /// Internal nodes: tag already applied to `v` but not yet pushed to the children.
        static if (hasLazy) L lz = eL;

        /// Creates a leaf holding `v`.
        this(in T v) {
            length = 1;
            this.v = v;
        }
        /// Creates an internal node over two non-null subtrees.
        this(NP l, NP r) {
            ch = [l, r];
            update();
        }
        /// Recomputes `length` and `v` from the children.
        /// Only valid on an internal node with no pending tag.
        void update() {
            static if (hasLazy) assert(lz == eL);
            length = ch[0].length + ch[1].length;
            v = opTT(ch[0].v, ch[1].v);
        }
        static if (hasLazy) {
            /// Applies tag `x` to this subtree: `v` is updated immediately and
            /// `x` is composed into `lz` for later propagation.
            void lzAdd(in L x) {
                v = opTL(v, x);
                lz = opLL(lz, x);
            }
        }
        /// Propagates the pending tag (if any) to both children.
        /// Only valid on an internal node; a no-op without lazy support.
        void push() {
            static if (hasLazy) {
                if (lz == eL) return;
                ch[0].lzAdd(lz);
                ch[1].lzAdd(lz);
                lz = eL;
            }
        }
        /// Single rotation that lifts `ch[type]` to the top of this subtree.
        /// Returns the new subtree root.
        NP rot(uint type) {
            // ty = 0: ((a, b), c) -> (a, (b, c))
            push();
            auto m = ch[type];
            m.push();
            ch[type] = m.ch[1-type];
            m.ch[1-type] = &this;
            // `this` is now a child of `m`, so it must be refreshed first.
            update(); m.update();
            return m;
        }
        /// Restores the 2:5 size ratio at this node after one child changed
        /// size, using a single or double rotation. Returns the new subtree root.
        NP bal() {
            push();
            foreach (f; 0..2) {
                if (ch[f].length*2 > ch[1-f].length*5) {
                    // ch[f] is too heavy.
                    ch[f].push();
                    // If a single rotation would leave either the new inner node
                    // (ch[f].ch[1-f], ch[1-f]) or the new root out of ratio,
                    // first rotate the inner grandchild up (double rotation).
                    if (ch[f].ch[1-f].length*2 > ch[1-f].length*5 ||
                        ch[f].ch[f].length*5 < (ch[f].ch[1-f].length+ch[1-f].length)*2) {
                        ch[f] = ch[f].rot(1-f);
                        update();
                    }
                    return rot(f);
                }
            }
            return &this;
        }
        /// Inserts `v` so that it becomes element `k` of this subtree
        /// (0 <= k <= length). Returns the new subtree root.
        NP insert(uint k, in T v) {
            assert(0 <= k && k <= length);
            if (length == 1) {
                // Replace the leaf with a new parent over both elements.
                if (k == 0) {
                    return new Node(new Node(v), &this);
                } else {
                    return new Node(&this, new Node(v));
                }
            }
            push();
            if (k < ch[0].length) {
                ch[0] = ch[0].insert(k, v);
            } else {
                ch[1] = ch[1].insert(k-ch[0].length, v);
            }
            update();
            return bal();
        }
        /// Removes element `k`. Returns the new subtree root, or null if the
        /// subtree became empty.
        NP removeAt(uint k) {
            assert(0 <= k && k < length);
            if (length == 1) {
                return null;
            }
            push();
            if (k < ch[0].length) {
                ch[0] = ch[0].removeAt(k);
                if (ch[0] is null) return ch[1];  // drop this node, promote the sibling
            } else {
                ch[1] = ch[1].removeAt(k-ch[0].length);
                if (ch[1] is null) return ch[0];
            }
            update();
            return bal();
        }
        /// Element `k` of this subtree.
        const(T) at(uint k) {
            assert(0 <= k && k < length);
            if (length == 1) return v;
            push();
            if (k < ch[0].length) return ch[0].at(k);
            return ch[1].at(k-ch[0].length);
        }
        /// Overwrites element `k` of this subtree with `x`.
        void atAssign(uint k, in T x) {
            assert(0 <= k && k < length);
            if (length == 1) {
                v = x;
                return;
            }
            push();
            if (k < ch[0].length) ch[0].atAssign(k, x);
            else ch[1].atAssign(k-ch[0].length, x);
            update();
        }
        /// Fold of the elements at subtree-relative positions [a, b).
        /// `a` may be negative and `b` may exceed `length`. Returns `eT` when
        /// the range misses this subtree.
        const(T) sum(int a, int b) {
            if (b <= 0 || length.to!int <= a) return eT;
            if (a <= 0 && length.to!int <= b) return v;
            push();
            // Shift the range into the right child's coordinates with signed
            // arithmetic (the result may be negative).
            immutable int mid = ch[0].length.to!int;
            return opTT(ch[0].sum(a, b), ch[1].sum(a - mid, b - mid));
        }
        static if (hasLazy) {
            /// Applies tag `x` to the elements at subtree-relative positions [a, b).
            void add(int a, int b, L x) {
                if (b <= 0 || length.to!int <= a) return;
                if (a <= 0 && length.to!int <= b) {
                    lzAdd(x);
                    return;
                }
                push();
                immutable int mid = ch[0].length.to!int;
                ch[0].add(a, b, x);
                ch[1].add(a - mid, b - mid, x);
                update();
            }
        }
        /// Asserts the size and 2:5 balance invariants of this subtree.
        void check() {
            if (length == 1) return;
            assert(length == ch[0].length + ch[1].length);
            ch[0].check();
            ch[1].check();
            assert(ch[0].length*5 >= ch[1].length*2);
            assert(ch[1].length*5 >= ch[0].length*2);
        }
        /// Debug dump of the subtree shape to stdout.
        void pr() {
            import std.stdio;
            if (length == 1) {
                writef("(%d)", v);
                return;
            }
            write("(");
            ch[0].pr();
            write(v);
            ch[1].pr();
            write(")");
        }
    }

    /**
    Concatenates two subtrees (either may be null) and returns the new root.

    When the two roots are within the size ratio, they are joined under a new
    node. If `buf` is given, it is reused as that node instead of allocating.
    `buf` must not be reachable from `l` or `r` and must carry no pending tag.
    */
    static NP merge(NP l, NP r, NP buf = null) {
        if (!l) return r;
        if (!r) return l;
        if (l.length*2 > r.length*5) {
            // l is much heavier: merge r into l's right side.
            l.push();
            l.ch[1] = merge(l.ch[1], r, buf);
            l.update();
            return l.bal();
        } else if (l.length*5 < r.length*2) {
            // r is much heavier: merge l into r's left side.
            r.push();
            r.ch[0] = merge(l, r.ch[0], buf);
            r.update();
            return r.bal();
        }
        if (buf is null) buf = new Node();
        buf.ch = [l, r];
        buf.update();
        return buf;
    }
    /// Splits subtree `n` into its first `k` elements and the rest; either
    /// half may be null. The internal node `n` is recycled as a merge buffer.
    static NP[2] split(NP n, uint k) {
        if (!n) return [null, null];
        if (n.length == 1) {
            if (k == 0) return [null, n];
            else return [n, null];
        }
        NP[2] p;
        n.push();
        if (k < n.ch[0].length) {
            p = split(n.ch[0], k);
            p[1] = merge(p[1], n.ch[1], n);
        } else {
            p = split(n.ch[1], k - n.ch[0].length);
            p[0] = merge(n.ch[0], p[0], n);
        }
        return p;
    }

    Node* tr;  /// root node; null for an empty tree
    /// Tree holding the single element `v`.
    this(T v) { tr = new Node(v); }
    /// Wraps an existing root (which may be null).
    this(Node* tr) { this.tr = tr; }
    /// Builds a tree holding the elements of `v` in order.
    this(in T[] v) {
        if (v.length == 0) return;
        if (v.length == 1) {
            tr = new Node(v[0]);
            return;
        }
        auto ltr = Tree(v[0..$/2]);
        auto rtr = Tree(v[$/2..$]);
        this = ltr.merge(rtr);
    }
    /// Number of elements.
    @property size_t length() const { return (!tr ? 0 : tr.length); }
    alias opDollar = length;

    /// Inserts `v` so that it becomes element `k` (0 <= k <= length).
    void insert(size_t k, in T v) {
        assert(0 <= k && k <= length);
        if (tr is null) {
            tr = new Node(v);
            return;
        }
        tr = tr.insert(k.to!uint, v);
    }
    /// Removes element `k`.
    void removeAt(size_t k) {
        assert(0 <= k && k < length);
        tr = tr.removeAt(k.to!uint);
    }
    /// Removes the elements [a, b) from this tree and returns them as a new tree.
    Tree trim(size_t a, size_t b) {
        auto v = split(tr, b.to!uint);
        auto u = split(v[0], a.to!uint);
        tr = merge(u[0], v[1]);
        return Tree(u[1]);
    }
    /// Keeps the first `k` elements in this tree and returns the rest as a new tree.
    Tree split(size_t k) {
        auto u = split(tr, k.to!uint);
        tr = u[0];
        return Tree(u[1]);
    }
    /// Appends all elements of `r` to this tree and returns this tree.
    /// The nodes of `r` become part of this tree, so `r` should not be used afterwards.
    ref Tree merge(Tree r) {
        tr = merge(tr, r.tr);
        return this;
    }
    static if (hasLazy) {
        /// `tree[a..b] += x`: lazily applies tag `x` (via `opTL`) to the elements [a, b).
        void opIndexOpAssign(string op : "+")(in L x, size_t[2] rng) {
            if (!tr) return;
            tr.add(rng[0].to!int, rng[1].to!int, x);
        }
    }
    /// Element `k`.
    const(T) opIndex(size_t k) {
        assert(0 <= k && k < length);
        return tr.at(k.to!uint);
    }
    /// Overwrites element `k` with `x`.
    void opIndexAssign(in T x, size_t k) {
        assert(0 <= k && k < length);
        tr.atAssign(k.to!uint, x);
    }
    /// View of the elements [start, end), produced by `tree[start..end]`.
    struct Range {
        Tree* eng;           /// tree being viewed
        size_t start, end;   /// half-open element range
        /// Fold of the viewed elements (`eT` for an empty tree).
        @property T sum() {
            if (!eng.tr) return eT;
            return eng.tr.sum(start.to!int, end.to!int);
        }
    }
    size_t[2] opSlice(size_t dim)(size_t start, size_t end) {
        assert(0 <= start && start <= end && end <= length());
        return [start, end];
    }
    Range opIndex(size_t[2] rng) {
        return Range(&this, rng[0].to!uint, rng[1].to!uint);
    }
    /// Formats as `Tree([e0, e1, ...])`, calling `opIndex` once per element.
    string toString() {
        //todo: more optimize
        import std.range : iota;
        import std.algorithm : map;
        import std.conv : to;
        string s;
        s ~= "Tree(";
        s ~= iota(length).map!(i => this[i]).to!string;
        s ~= ")";
        return s;
    }
    /// Asserts the structural invariants of the whole tree.
    void check() {
        if (tr) tr.check();
    }
    /// Debug dump of the tree shape to stdout.
    void pr() {
        if (tr) tr.pr();
    }
}

/// Tree without lazy updates.
alias SimpleTree(T, alias op, T e) = Tree!(T, op, e);
/// Tree with lazy range updates; see `Tree` for the meaning of each parameter.
alias LazyTree(T, L, alias opTT, alias opTL, alias opLL, T eT, L eL) =
    Tree!(T, opTT, eT, true, L, opTL, opLL, eL);

import std.traits : isInstanceOf;

/**
Binary search over `t[_a.._b]`, folding from the left.

Elements are folded left to right starting from `T.eT`. The result is the
index `i` of the element whose inclusion first makes `pred` true, i.e.
`!pred(fold(t[_a..i]))` and `pred(fold(t[_a..i+1]))`.

Returns `_a - 1` if `pred(T.eT)` already holds, and `_b` if `pred` never
becomes true. Whole subtrees are accepted after a single `pred` call, so
`pred` should be monotone (false ... false, true ... true) along the prefixes.
*/
ptrdiff_t binSearchLeft(alias pred, T)(T t, ptrdiff_t _a, ptrdiff_t _b)
if (isInstanceOf!(Tree, T)) {
    import std.conv : to;
    import std.traits : Unqual;
    assert(0 <= _a && _a <= _b && _b <= cast(ptrdiff_t) t.length);
    immutable int lo = _a.to!int, hi = _b.to!int;
    Unqual!(typeof(T.eT)) x = T.eT;
    if (pred(x)) return lo - 1;
    if (t.tr is null) return hi;  // empty range: pred never becomes true

    alias opTT = T.opTT;
    int pos = lo;  // every element of [lo, pos) has been folded into x
    void f(T.Node* n, int offset) {
        immutable int len = n.length.to!int;
        if (hi <= offset || offset + len <= lo) return;
        if (lo <= offset && offset + len <= hi && !pred(opTT(x, n.v))) {
            x = opTT(x, n.v);
            pos = offset + len;
            return;
        }
        if (len == 1) return;
        immutable int mid = offset + n.ch[0].length.to!int;
        f(n.ch[0], offset);
        // Continue right only if the left part was absorbed completely.
        if (pos >= mid) f(n.ch[1], mid);
    }
    f(t.tr, 0);
    return pos;
}

/**
Binary search over `t[a..b]`, folding from the right; mirror image of
`binSearchLeft`.

Elements are folded right to left starting from `T.eT`. New elements are
combined on the left: `opTT(element, acc)`. The result is the index `i` of
the element whose inclusion first makes `pred` true, i.e.
`!pred(fold(t[i+1..b]))` and `pred(fold(t[i..b]))`.

Returns `b` if `pred(T.eT)` already holds, and `a - 1` if `pred` never
becomes true. `pred` should be monotone along the suffixes.
*/
ptrdiff_t binSearchRight(alias pred, T)(T t, ptrdiff_t a, ptrdiff_t b)
if (isInstanceOf!(Tree, T)) {
    import std.conv : to;
    import std.traits : Unqual;
    assert(0 <= a && a <= b && b <= cast(ptrdiff_t) t.length);
    immutable int lo = a.to!int, hi = b.to!int;
    Unqual!(typeof(T.eT)) x = T.eT;
    if (pred(x)) return hi;
    if (t.tr is null) return lo - 1;  // empty range: pred never becomes true

    alias opTT = T.opTT;
    int pos = hi - 1;  // every element of [pos + 1, hi) has been folded into x
    void f(T.Node* n, int offset) {
        immutable int len = n.length.to!int;
        if (hi <= offset || offset + len <= lo) return;
        if (lo <= offset && offset + len <= hi && !pred(opTT(n.v, x))) {
            x = opTT(n.v, x);
            pos = offset - 1;
            return;
        }
        if (len == 1) return;
        immutable int mid = offset + n.ch[0].length.to!int;
        f(n.ch[1], mid);
        // Continue left only if the right part was absorbed completely
        // (signed comparison: pos may be -1).
        if (pos < mid) f(n.ch[0], offset);
    }
    f(t.tr, 0);
    return pos;
}

unittest {
    import std.meta : AliasSeq;
    import std.random;
    import dkh.modint;
    alias Mint = ModInt!(10^^9 + 7);
    auto rndM = (){ return Mint(uniform(0, 10^^9 + 7)); };
    void check() {
        alias T = SimpleTree!(Mint, "a+b", Mint(0));
        T t;
        Mint sm = 0;
        foreach (i; 0..100) {
            auto x = rndM();
            sm += x;
            t.insert(0, x);
        }
        assert(sm == t[0..$].sum);
    }
    check();
}


unittest {
    import std.conv : to;
    import std.meta : AliasSeq;
    import std.algorithm : swap;
    import std.random;
    import dkh.modint, dkh.foundation;
    alias Mint = ModInt!(10^^9 + 7);
    auto rndM = (){ return Mint(uniform(0, 10^^9 + 7)); };
    void check() {
        alias T1 = SimpleTree!(Mint[2], (a, b) => [a[0]+b[0], a[1]+b[1]].fixed, [Mint(0), Mint(0)].fixed);
        alias T2 = LazyTree!(Mint[2], Mint,
            (a, b) => [a[0]+b[0], a[1]+b[1]].fixed,
            (a, b) => [a[0]+a[1]*b, a[1]].fixed,
            (a, b) => a+b,
            [Mint(0), Mint(0)].fixed, Mint(0));
        T1 t1;
        T2 t2;
        foreach (ph; 0..1000) {
            assert(t1.length == t2.length);
            int L = t1.length.to!int;
            int ty = uniform(0, 3);
            if (ty == 0) {
                auto x = rndM();
                auto idx = uniform(0, L+1);
                t1.insert(idx, [x, Mint(1)]);
                t2.insert(idx, [x, Mint(1)]);
            } else if (ty == 1) {
                int l = uniform(0, L+1);
                int r = uniform(0, L+1);
                if (l > r) swap(l, r);
                // Compare a random sub-range, not just the whole sequence.
                assert(t1[l..r].sum == t2[l..r].sum);
            } else {
                int l = uniform(0, L+1);
                int r = uniform(0, L+1);
                if (l > r) swap(l, r);
                auto x = rndM();
                foreach (i; l..r) {
                    t1[i] = [t1[i][0] + x, t1[i][1]];
                }
                t2[l..r] += x;
            }
            t1.check();
            t2.check();
        }
    }
    check();
}

unittest {
    // Deterministic checks of both binary searches.
    alias T = SimpleTree!(int, "a+b", 0);
    auto t = T([1, 2, 3, 4]);
    // Prefix sums: 1 3 6 10
    assert(t.binSearchLeft!(s => s >= 6)(0, 4) == 2);
    assert(t.binSearchLeft!(s => s >= 100)(0, 4) == 4);
    assert(t.binSearchLeft!(s => s >= 0)(0, 4) == -1);
    assert(t.binSearchLeft!(s => s >= 5)(1, 3) == 2);
    // Suffix sums: 10 9 7 4
    assert(t.binSearchRight!(s => s >= 5)(0, 4) == 2);
    assert(t.binSearchRight!(s => s >= 100)(0, 4) == -1);
    assert(t.binSearchRight!(s => s >= 0)(0, 4) == 4);
    assert(t.binSearchRight!(s => s >= 3)(1, 3) == 2);
    // Empty tree: no element can satisfy pred.
    T e;
    assert(e.binSearchLeft!(s => s >= 1)(0, 0) == 0);
    assert(e.binSearchRight!(s => s >= 1)(0, 0) == -1);
}

unittest {
    import std.random;
    import std.algorithm;
    import std.conv;
    import std.container.rbtree;
    import std.stdio;

    import dkh.stopwatch;
    StopWatch sw; sw.start;
    auto nv = redBlackTree!(true, int)([]);
    alias T = SimpleTree!(int, max, int.min);
    auto tr = T();
    foreach (ph; 0..10000) {
        int ty = uniform(0, 2);
        if (ty == 0) {
            int x = uniform(0, 100);
            nv.insert(x);
            auto idx = tr.binSearchLeft!(y => x <= y)(0, tr.length);
            if (uniform(0, 2)) {
                auto tr2 = tr.trim(idx, tr.length);
                tr.merge(T(x)).merge(tr2);
            } else {
                tr.insert(idx, x);
            }
        } else {
            if (!nv.length) continue;
            int i = uniform(0, nv.length.to!int);
            auto u = nv[];
            foreach (_; 0..i) u.popFront();
            assert(u.front == tr[i]);
            int x = tr[i];
            nv.removeKey(x);
            if (uniform(0, 2)) {
                tr.removeAt(i);
            } else {
                tr.trim(i, i+1);
            }
        }
        tr.check();
        assert(nv.length == tr.length);
    }
    writeln("Set TEST: ", sw.peek.toMsecs);
}