1 module dkh.container.sortedtree;
2 
/**
Weight-balanced binary search tree of `T`, ordered by `less`.

Every node caches the size of its subtree (`len`), which is what makes
positional access (`opIndex`, `removeAt`) and rank queries (`lowerCount`)
possible without walking the whole tree. After every insert/remove the
nodes on the modified path are rebalanced so that neither child's weight
(subtree size + 1) exceeds 5/2 of the other's (the invariant asserted by
`validCheck`).

When `allowDuplicates` is false, inserting a value equivalent (under
`less`) to one already stored leaves the tree unchanged.

NOTE(review): values are stored as `immutable T` and built from `in T`
arguments, so `T` is presumably expected to be a value type without
mutable indirections — confirm before using with class/pointer types.
*/
struct SortedTreePayload(T, alias less, bool allowDuplicates = false) {
    alias NP = Node*;

    /// One tree node; all mutating methods return the new root of the
    /// subtree they were called on.
    static struct Node {
        NP[2] ch;   // ch[0]: left (smaller) child, ch[1]: right child
        NP par;     // parent pointer; the payload resets it to null at the root
        uint len;   // number of nodes in this subtree, including this one
        immutable T v;

        this(T v) {
            len = 1;
            this.v = v;
        }

        /// Size of child subtree `ty` (0 when the child is absent).
        uint chLength(uint ty) const {
            return ch[ty] ? ch[ty].len : 0;
        }

        /// Balance weight of child `ty`: its size plus one.
        uint chWeight(uint ty) const {
            return chLength(ty) + 1;
        }

        /// Recompute `len` from the children.
        void update() {
            len = 1 + chLength(0) + chLength(1);
        }

        /// Rotate child `ch[ty]` up into this node's position and return it
        /// as the new subtree root. Sizes and parent links are kept in sync.
        NP rot(uint ty) {
            NP n = ch[ty]; n.par = par;
            ch[ty] = n.ch[1-ty];
            if (ch[ty]) ch[ty].par = &this;
            n.ch[1-ty] = &this; par = n;
            update(); n.update();
            return n;
        }

        /// Restore the 5:2 weight bound at this node, if it is violated, with
        /// a single or double rotation. Returns the (possibly new) subtree root.
        NP balanced() {
            foreach (f; 0..2) {
                // Side f is within 5/2 of the other side: nothing to do for f.
                if (chWeight(f) * 2 <= chWeight(1-f) * 5) continue;
                // A single rotation would leave either the demoted node
                // (ch[f]'s inner child vs. our 1-f side) or the new root
                // (ch[f]'s outer child vs. everything else) out of bounds:
                // first rotate ch[f]'s inner child up (double rotation).
                if (ch[f].chWeight(1-f) * 2 > chWeight(1-f) * 5 ||
                    ch[f].chWeight(f) * 5 < (ch[f].chWeight(1-f) + chWeight(1-f)) * 2) {
                    ch[f] = ch[f].rot(1-f);
                    ch[f].par = &this;
                    update();
                }
                return rot(f);
            }
            return &this;
        }

        /// Insert `x` into this subtree; returns the new subtree root.
        NP insert(in T x) {
            if (less(x, v)) {
                if (!ch[0]) ch[0] = new Node(x);
                else ch[0] = ch[0].insert(x);
                ch[0].par = &this;
            } else if (allowDuplicates || less(v, x)) {
                // equivalent keys go to the right when duplicates are allowed
                if (!ch[1]) ch[1] = new Node(x);
                else ch[1] = ch[1].insert(x);
                ch[1].par = &this;
            } else {
                // equivalent key already present and duplicates disallowed
                return &this;
            }
            update();
            return balanced();
        }

        /// The `i`-th smallest value (0-based) of this subtree.
        T at(uint i) const {
            if (i < chLength(0)) return ch[0].at(i);
            else if (i == chLength(0)) return v;
            else return ch[1].at(i - chLength(0) - 1);
        }

        /// Detach the minimum node of this subtree.
        /// Returns [new subtree root (may be null), detached node].
        NP[2] removeBegin() {
            if (!ch[0]) return [ch[1], &this];
            auto u = ch[0].removeBegin;
            ch[0] = u[0];
            if (ch[0]) ch[0].par = &this;
            update();
            return [balanced(), u[1]];
        }

        /// Remove the `i`-th smallest node; returns the new subtree root
        /// (null if the subtree becomes empty).
        NP removeAt(uint i) {
            if (i < chLength(0)) {
                ch[0] = ch[0].removeAt(i);
                if (ch[0]) ch[0].par = &this;
            } else if (i > chLength(0)) {
                ch[1] = ch[1].removeAt(i - chLength(0) - 1);
                if (ch[1]) ch[1].par = &this;
            } else {
                if (!ch[1]) return ch[0];
                // Replace this node by its in-order successor (the minimum
                // of the right subtree).
                auto u = ch[1].removeBegin;
                auto n = u[1];
                n.ch[0] = ch[0];
                n.ch[1] = u[0];
                if (n.ch[0]) n.ch[0].par = n;
                if (n.ch[1]) n.ch[1].par = n;
                n.update();
                return n.balanced();
            }
            update();
            return balanced();
        }

        /// Remove one node equivalent to `x`, if any; returns the new
        /// subtree root (null if the subtree becomes empty).
        NP removeKey(in T x) {
            if (less(x, v)) {
                if (!ch[0]) return &this;
                ch[0] = ch[0].removeKey(x);
                if (ch[0]) ch[0].par = &this;
            } else if (less(v, x)) {
                if (!ch[1]) return &this;
                ch[1] = ch[1].removeKey(x);
                if (ch[1]) ch[1].par = &this;
            } else {
                if (!ch[1]) return ch[0];
                // Same successor splice as in removeAt.
                auto u = ch[1].removeBegin;
                auto n = u[1];
                n.ch[0] = ch[0];
                n.ch[1] = u[0];
                if (n.ch[0]) n.ch[0].par = n;
                if (n.ch[1]) n.ch[1].par = n;
                n.update();
                return n.balanced();
            }
            update();
            return balanced();
        }

        /// Number of values in this subtree that are strictly less than `x`.
        uint lowerCount(in T x) const {
            if (less(v, x)) {
                return chLength(0) + 1 + (!ch[1] ? 0 : ch[1].lowerCount(x));
            } else {
                return !ch[0] ? 0 : ch[0].lowerCount(x);
            }
        }

        /// Debug helper: assert sizes, parent links, ordering and balance.
        void validCheck() {
            assert(len == chLength(0) + chLength(1) + 1);
            if (ch[0]) {
                assert(ch[0].par == &this);
                assert(!less(v, ch[0].v));
            }
            if (ch[1]) {
                assert(ch[1].par == &this);
                assert(!less(ch[1].v, v));
            }
            assert(chWeight(0) * 2 <= chWeight(1) * 5);
            assert(chWeight(1) * 2 <= chWeight(0) * 5);
            if (ch[0]) ch[0].validCheck();
            if (ch[1]) ch[1].validCheck();
        }
    }

    NP n; // root; null when empty

    /// Number of stored values.
    @property size_t length() const { return !n ? 0 : n.len; }

    /// Insert `x` (ignored if an equivalent value exists and duplicates are disallowed).
    void insert(in T x) {
        if (!n) n = new Node(x);
        else n = n.insert(x);
        n.par = null;
    }

    /// The `i`-th smallest value (0-based); `i` must be `< length`.
    T opIndex(size_t i) const {
        assert(i < length);
        return n.at(cast(uint)(i));
    }

    /// Remove the `i`-th smallest value (0-based); `i` must be `< length`.
    /// Takes `size_t` for consistency with `opIndex`; `uint` callers still work.
    void removeAt(size_t i) {
        assert(i < length);
        n = n.removeAt(cast(uint)(i));
        if (n) n.par = null;
    }

    /// Remove one value equivalent to `x`, if present.
    void removeKey(in T x) {
        if (n) n = n.removeKey(x);
        if (n) n.par = null;
    }

    /// Number of stored values strictly less than `x`.
    size_t lowerCount(in T x) const {
        return !n ? 0 : n.lowerCount(x);
    }

    /// Debug helper: assert the tree invariants.
    void validCheck() {
        if (n) {
            assert(!n.par);
            n.validCheck();
        }
    }
}
171 
172 
/**
Sorted container with a `std.container.rbtree`-like interface, implemented
on a weight-balanced tree (`SortedTreePayload`) that additionally supports
positional access (`opIndex`, `removeAt`) and rank queries (`lowerCount`).

The payload lives behind a pointer that is allocated by the first `insert`.
Copies of a `SortedTree` therefore share their contents only once that
allocation has happened; copying a never-inserted-into tree yields an
independent tree.
*/
struct SortedTree(T, alias less, bool allowDuplicates = false) {
    alias Payload = SortedTreePayload!(T, less, allowDuplicates);
    Payload* _p; // null until the first insert

    /// True when the tree holds no elements (including before the first insert).
    @property bool empty() const { return !_p || _p.length == 0; }

    /// Number of stored elements.
    @property size_t length() const { return !_p ? 0 : _p.length; }
    alias opDollar = length;

    /// Insert `x`, allocating the payload on first use.
    void insert(in T x) {
        if (!_p) _p = new Payload();
        _p.insert(x);
    }

    /// The `i`-th smallest element (0-based); `i` must be `< length`.
    T opIndex(size_t i) const {
        assert(i < length);
        return (*_p)[i];
    }

    /// Remove the `i`-th smallest element (0-based); `i` must be `< length`.
    void removeAt(uint i) {
        assert(i < length);
        _p.removeAt(i);
    }

    /// Remove one element equivalent to `x`, if present.
    void removeKey(in T x) {
        // The payload does not exist before the first insert; there is
        // nothing to remove then, and dereferencing it would crash.
        if (_p) _p.removeKey(x);
    }

    /// Number of elements strictly less than `x`.
    size_t lowerCount(in T x) {
        return !_p ? 0 : _p.lowerCount(x);
    }

    /// Debug helper: assert the underlying tree invariants.
    void validCheck() {
        if (_p) _p.validCheck();
    }
}
205 
// Randomized differential test of SortedTreePayload against
// std.container.rbtree, with and without duplicates.
unittest {
    import std.random;
    import std.algorithm;
    import std.conv;
    import std.container.rbtree;
    import std.stdio;
    import std.range;

    // Explicit generator with a reported seed: a failing run can be replayed
    // by substituting the printed value for unpredictableSeed.
    immutable uint seed = unpredictableSeed;
    auto rng = Random(seed);
    scope (failure) writeln("SortedTreePayload unittest failed with seed ", seed);

    void check(bool allowDup)() {
        auto nv = redBlackTree!(allowDup, int)([]);
        auto tr = SortedTreePayload!(int, (a, b) => a<b, allowDup)();
        foreach (ph; 0..10000) {
            int ty = uniform(0, 3, rng);
            if (ty == 0) {
                // insert the same random key into both containers
                int x = uniform(0, 100, rng);
                nv.insert(x);
                tr.insert(x);
            } else if (ty == 1) {
                // compare the i-th element, then remove it by position or key
                if (!nv.length) continue;
                int i = uniform(0, nv.length.to!int, rng);
                auto u = nv[];
                foreach (_; 0..i) u.popFront();
                assert(u.front == tr[i]);
                int x = tr[i];
                nv.removeKey(x);
                if (uniform(0, 2, rng) == 0) {
                    tr.removeAt(i);
                } else {
                    tr.removeKey(x);
                }
            } else {
                // rank query: count of elements strictly below x
                int x = uniform(0, 101, rng);
                assert(nv.lowerBound(x).walkLength == tr.lowerCount(x));
            }
            tr.validCheck();
            assert(nv.length == tr.length);
        }
    }
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    check!true();
    check!false();
    writeln("Set TEST: ", sw.peek.toMsecs);
}
250 
251 
// Randomized differential test of the SortedTree wrapper against
// std.container.rbtree, with and without duplicates.
unittest {
    import std.random;
    import std.algorithm;
    import std.conv;
    import std.container.rbtree;
    import std.stdio;
    import std.range;

    // Explicit generator with a reported seed: a failing run can be replayed
    // by substituting the printed value for unpredictableSeed.
    immutable uint seed = unpredictableSeed;
    auto rng = Random(seed);
    scope (failure) writeln("SortedTree unittest failed with seed ", seed);

    void check(bool allowDup)() {
        auto nv = redBlackTree!(allowDup, int)([]);
        auto tr = SortedTree!(int, (a, b) => a<b, allowDup)();
        foreach (ph; 0..10000) {
            int ty = uniform(0, 3, rng);
            if (ty == 0) {
                // insert the same random key into both containers
                int x = uniform(0, 100, rng);
                nv.insert(x);
                tr.insert(x);
            } else if (ty == 1) {
                // compare the i-th element, then remove it by position or key
                if (!nv.length) continue;
                int i = uniform(0, nv.length.to!int, rng);
                auto u = nv[];
                foreach (_; 0..i) u.popFront();
                assert(u.front == tr[i]);
                int x = tr[i];
                nv.removeKey(x);
                if (uniform(0, 2, rng) == 0) {
                    tr.removeAt(i);
                } else {
                    tr.removeKey(x);
                }
            } else {
                // rank query: count of elements strictly below x
                int x = uniform(0, 101, rng);
                assert(nv.lowerBound(x).walkLength == tr.lowerCount(x));
            }
            tr.validCheck();
            assert(nv.length == tr.length);
        }
    }
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    check!true();
    check!false();
    writeln("Set TEST: ", sw.peek.toMsecs);
}