module dkh.container.sortedtree;

/**
Storage for `SortedTree`: a weight-balanced binary search tree in which every
node also records the size of its subtree, so positional access (`opIndex`,
`removeAt`) and rank queries (`lowerCount`) follow a single root-to-leaf path.

Balance invariant (asserted by `validCheck`): with weight = subtree size + 1,
each child's weight is at most 5/2 of its sibling's weight.

Params:
    T = element type; each node stores its element as `immutable T`
    less = strict ordering predicate, called as `less(a, b)`
    allowDuplicates = when false, inserting an element equivalent to one
        already present leaves the tree unchanged
*/
struct SortedTreePayload(T, alias less, bool allowDuplicates = false) {
    alias NP = Node*;
    static struct Node {
        NP[2] ch;   // ch[0]: left (smaller) subtree, ch[1]: right subtree
        NP par;     // parent node; null at the root
        uint len;   // number of nodes in this subtree
        immutable T v;
        this(T v) {
            len = 1;
            this.v = v;
        }
        /// Size of child subtree `ty` (0 when the child is absent).
        uint chLength(uint ty) const {
            return ch[ty] ? ch[ty].len : 0;
        }
        /// Balance weight of child `ty`: its size plus one.
        uint chWeight(uint ty) const {
            return chLength(ty) + 1;
        }
        /// Recomputes `len` from the children.
        void update() {
            len = 1 + chLength(0) + chLength(1);
        }
        /// Rotates child `ty` up into this node's position; returns the new subtree root.
        NP rot(uint ty) {
            NP n = ch[ty]; n.par = par;
            ch[ty] = n.ch[1-ty];
            if (ch[ty]) ch[ty].par = &this;
            n.ch[1-ty] = &this; par = n;
            update(); n.update();
            return n;
        }
        /**
        Restores the weight-balance invariant at this node with a single or
        double rotation. Called after one insertion/removal below this node.
        Returns: the new subtree root.
        */
        NP balanced() {
            foreach (f; 0..2) {
                if (chWeight(f) * 2 <= chWeight(1-f) * 5) continue;
                // A single rotation would leave either the demoted node (inner
                // grandchild too heavy) or the new root (outer grandchild too
                // light) unbalanced: rotate the heavy child first.
                if (ch[f].chWeight(1-f) * 2 > chWeight(1-f) * 5 ||
                    ch[f].chWeight(f) * 5 < (ch[f].chWeight(1-f) + chWeight(1-f)) * 2) {
                    ch[f] = ch[f].rot(1-f);
                    ch[f].par = &this;
                    update();
                }
                return rot(f);
            }
            return &this;
        }
        /**
        Inserts `x` into this subtree. Elements equivalent to `v` go to the
        right subtree when `allowDuplicates`, otherwise they are ignored.
        Returns: the new (rebalanced) subtree root.
        */
        NP insert(in T x) {
            if (less(x, v)) {
                if (!ch[0]) ch[0] = new Node(x);
                else ch[0] = ch[0].insert(x);
                ch[0].par = &this;
            } else if (allowDuplicates || less(v, x)) {
                if (!ch[1]) ch[1] = new Node(x);
                else ch[1] = ch[1].insert(x);
                ch[1].par = &this;
            } else {
                return &this;
            }
            update();
            return balanced();
        }
        /// Returns the `i`-th smallest element (0-based) of this subtree; requires `i < len`.
        T at(uint i) const {
            if (i < chLength(0)) return ch[0].at(i);
            else if (i == chLength(0)) return v;
            else return ch[1].at(i - chLength(0) - 1);
        }
        /**
        Detaches the smallest node of this subtree.
        Returns: [new subtree root (may be null), removed node]
        */
        NP[2] removeBegin() {
            if (!ch[0]) return [ch[1], &this];
            auto u = ch[0].removeBegin;
            ch[0] = u[0];
            if (ch[0]) ch[0].par = &this;
            update();
            return [balanced(), u[1]];
        }
        /**
        Removes this node itself. The smallest node of the right subtree is
        promoted into its place; with no right subtree, the left one is returned.
        Returns: the new subtree root (may be null).
        */
        private NP removeSelf() {
            if (!ch[1]) return ch[0];
            auto u = ch[1].removeBegin;
            auto n = u[1];
            n.ch[0] = ch[0];
            n.ch[1] = u[0];
            if (n.ch[0]) n.ch[0].par = n;
            if (n.ch[1]) n.ch[1].par = n;
            n.update();
            return n.balanced();
        }
        /// Removes the `i`-th smallest element (0-based); requires `i < len`.
        /// Returns: the new subtree root (may be null).
        NP removeAt(uint i) {
            if (i < chLength(0)) {
                ch[0] = ch[0].removeAt(i);
                if (ch[0]) ch[0].par = &this;
            } else if (i > chLength(0)) {
                ch[1] = ch[1].removeAt(i - chLength(0) - 1);
                if (ch[1]) ch[1].par = &this;
            } else {
                return removeSelf();
            }
            update();
            return balanced();
        }
        /// Removes one element equivalent to `x`, if present.
        /// Returns: the new subtree root (may be null).
        NP removeKey(in T x) {
            if (less(x, v)) {
                if (!ch[0]) return &this;
                ch[0] = ch[0].removeKey(x);
                if (ch[0]) ch[0].par = &this;
            } else if (less(v, x)) {
                if (!ch[1]) return &this;
                ch[1] = ch[1].removeKey(x);
                if (ch[1]) ch[1].par = &this;
            } else {
                return removeSelf();
            }
            update();
            return balanced();
        }
        /// Counts the elements `e` of this subtree with `less(e, x)`.
        uint lowerCount(in T x) const {
            if (less(v, x)) {
                return chLength(0) + 1 + (!ch[1] ? 0 : ch[1].lowerCount(x));
            } else {
                return !ch[0] ? 0 : ch[0].lowerCount(x);
            }
        }
        /// Debug aid: recursively asserts sizes, parent links, ordering and balance.
        void validCheck() {
            assert(len == chLength(0) + chLength(1) + 1);
            if (ch[0]) {
                assert(ch[0].par == &this);
                assert(!less(v, ch[0].v));
            }
            if (ch[1]) {
                assert(ch[1].par == &this);
                assert(!less(ch[1].v, v));
            }
            assert(chWeight(0) * 2 <= chWeight(1) * 5);
            assert(chWeight(1) * 2 <= chWeight(0) * 5);
            if (ch[0]) ch[0].validCheck();
            if (ch[1]) ch[1].validCheck();
        }
    }
    NP n; // root node; null when the tree is empty
    @property size_t length() const { return !n ? 0 : n.len; }
    /// Inserts `x` (ignored if an equivalent element exists and `!allowDuplicates`).
    void insert(in T x) {
        if (!n) n = new Node(x);
        else n = n.insert(x);
        n.par = null;
    }
    /// Returns the `i`-th smallest element (0-based).
    T opIndex(size_t i) const {
        assert(i < length);
        return n.at(cast(uint)(i));
    }
    /// Removes the `i`-th smallest element (0-based).
    void removeAt(uint i) {
        assert(i < length);
        n = n.removeAt(i);
        if (n) n.par = null;
    }
    /// Removes one element equivalent to `x`, if present.
    void removeKey(in T x) {
        if (n) n = n.removeKey(x);
        if (n) n.par = null;
    }
    /// Number of elements `e` with `less(e, x)`.
    size_t lowerCount(in T x) const {
        return !n ? 0 : n.lowerCount(x);
    }
    /// Debug aid: asserts every structural invariant of the tree.
    void validCheck() {
        if (n) {
            assert(!n.par);
            n.validCheck();
        }
    }
}


/**
Sorted container offering a subset of the `std.container.rbtree` interface
(`insert`, `removeKey`, `length`, `empty`) plus positional access (`opIndex`,
`removeAt`) and rank queries (`lowerCount`), backed by the weight-balanced
tree `SortedTreePayload`.

The storage is reached through a pointer that is allocated on the first
`insert`; copies of a `SortedTree` made after that point share the same
elements.
*/
struct SortedTree(T, alias less, bool allowDuplicates = false) {
    alias Payload = SortedTreePayload!(T, less, allowDuplicates);
    Payload* _p;
    @property bool empty() const { return !_p || _p.length == 0; }
    @property size_t length() const { return !_p ? 0 : _p.length; }
    alias opDollar = length;
    /// Inserts `x` (ignored if an equivalent element exists and `!allowDuplicates`).
    void insert(in T x) {
        if (!_p) _p = new Payload();
        _p.insert(x);
    }
    /// Returns the `i`-th smallest element (0-based).
    T opIndex(size_t i) const {
        assert(i < length);
        return (*_p)[i];
    }
    /// Removes the `i`-th smallest element (0-based).
    void removeAt(uint i) {
        assert(i < length);
        _p.removeAt(i);
    }
    /// Removes one element equivalent to `x`, if present.
    void removeKey(in T x) {
        // The payload is allocated lazily; without it there is nothing to remove.
        if (_p) _p.removeKey(x);
    }
    /// Number of elements `e` with `less(e, x)`.
    size_t lowerCount(in T x) const {
        return !_p ? 0 : _p.lowerCount(x);
    }
    /// Debug aid: asserts every structural invariant of the tree.
    void validCheck() {
        if (_p) _p.validCheck();
    }
}

unittest {
    // Operations on a tree that never received an element must not touch the null payload.
    auto tr = SortedTree!(int, (a, b) => a<b)();
    assert(tr.empty);
    tr.removeKey(1);
    assert(tr.length == 0);
    assert(tr.lowerCount(1) == 0);
    tr.insert(3); tr.insert(1); tr.insert(2); tr.insert(2);
    assert(tr.length == 3);
    assert(tr[0] == 1 && tr[1] == 2 && tr[2] == 3);
    const ctr = tr;
    assert(ctr.lowerCount(3) == 2);
    tr.removeKey(2);
    assert(tr.length == 2 && tr[1] == 3);
    tr.validCheck();
}

unittest {
    import std.random;
    import std.algorithm;
    import std.conv;
    import std.container.rbtree;
    import std.stdio;
    import std.range;

    void check(bool allowDup)() {
        auto nv = redBlackTree!(allowDup, int)([]);
        auto tr = SortedTreePayload!(int, (a, b) => a<b, allowDup)();
        foreach (ph; 0..10000) {
            int ty = uniform(0, 3);
            if (ty == 0) {
                int x = uniform(0, 100);
                nv.insert(x);
                tr.insert(x);
            } else if (ty == 1) {
                if (!nv.length) continue;
                int i = uniform(0, nv.length.to!int);
                auto u = nv[];
                foreach (_; 0..i) u.popFront();
                assert(u.front == tr[i]);
                int x = tr[i];
                nv.removeKey(x);
                if (uniform(0, 2) == 0) {
                    tr.removeAt(i);
                } else {
                    tr.removeKey(x);
                }
            } else {
                int x = uniform(0, 101);
                assert(nv.lowerBound(x).array.length == tr.lowerCount(x));
            }
            tr.validCheck();
            assert(nv.length == tr.length);
        }
    }
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    check!true();
    check!false();
    writeln("Set TEST: ", sw.peek.toMsecs);
}


unittest {
    import std.random;
    import std.algorithm;
    import std.conv;
    import std.container.rbtree;
    import std.stdio;
    import std.range;

    void check(bool allowDup)() {
        auto nv = redBlackTree!(allowDup, int)([]);
        auto tr = SortedTree!(int, (a, b) => a<b, allowDup)();
        foreach (ph; 0..10000) {
            int ty = uniform(0, 3);
            if (ty == 0) {
                int x = uniform(0, 100);
                nv.insert(x);
                tr.insert(x);
            } else if (ty == 1) {
                if (!nv.length) continue;
                int i = uniform(0, nv.length.to!int);
                auto u = nv[];
                foreach (_; 0..i) u.popFront();
                assert(u.front == tr[i]);
                int x = tr[i];
                nv.removeKey(x);
                if (uniform(0, 2) == 0) {
                    tr.removeAt(i);
                } else {
                    tr.removeKey(x);
                }
            } else {
                int x = uniform(0, 101);
                assert(nv.lowerBound(x).array.length == tr.lowerCount(x));
            }
            tr.validCheck();
            assert(nv.length == tr.length);
        }
    }
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    check!true();
    check!false();
    writeln("Set TEST: ", sw.peek.toMsecs);
}