1 module dkh.datastructure.fastset;
2 
/**
Set of integers in the range [0, n). Behaves almost the same as bool[int],
except that the key range is restricted to [0, n).
 */
struct FastSet {
    import std.range : back;
    import core.bitop : bsr, bsf;

    // n: capacity (valid keys are 0 .. n-1); len: number of keys currently stored.
    private size_t n, len;
    // Bit levels. seg[0] has one bit per key; bit j of seg[i+1] is set iff
    // word j of seg[i] is nonzero. The constructor makes the last level a
    // single word.
    ulong[][] seg;

    /// make set for [0, 1, ..., n-1] (n == 0 is treated as n == 1)
    this(size_t n) {
        if (n == 0) n = 1;
        this.n = n;
        while (true) {
            seg ~= new ulong[(n + 63) / 64];
            if (n == 1) break;
            n = (n + 63) / 64;
        }
    }

    /// true iff the set contains no element
    bool empty() const {
        // The single top-level word is nonzero iff some key is present.
        return seg.back[0] == 0;
    }

    /// number of elements currently in the set
    size_t length() const { return len; }

    /// membership test; i must be less than the capacity passed to the constructor
    bool opBinaryRight(string op : "in")(size_t i) const {
        assert(i < n);
        //todo: consider bitop.bt
        size_t D = i / 64, R = i % 64;
        return (seg[0][D] & (1UL << R)) != 0;
    }

    /// add x to the set (no-op if already present)
    void insert(size_t x) {
        if (x in this) return;
        len++;
        // Mark x on level 0 and its enclosing word index on every upper level.
        foreach (i; 0 .. seg.length) {
            size_t D = x / 64, R = x % 64;
            seg[i][D] |= (1UL << R);
            x /= 64;
        }
    }

    /// remove x from the set (no-op if absent)
    void remove(size_t x) {
        if (x !in this) return;
        len--;
        size_t D = x / 64, R = x % 64;
        seg[0][D] &= ~(1UL << R);
        // Clear the summary bit one level up only while the word below became zero.
        foreach (i; 1 .. seg.length) {
            x /= 64;
            if (seg[i-1][x]) break;
            D = x / 64; R = x % 64;
            seg[i][D] &= ~(1UL << R);
        }
    }

    /// return minimum element that isn't less than x, or n if there is none
    ptrdiff_t next(ptrdiff_t x) const {
        if (x < 0) x = 0;
        if (cast(size_t) x >= n) return n;
        foreach (i; 0 .. seg.length) {
            // x moved past the last word of this level: nothing further right.
            if (x == seg[i].length * 64) break;
            size_t D = x / 64, R = x % 64;
            ulong B = seg[i][D] >> R;
            if (!B) {
                // Nothing at or after x in this word; resume one level up
                // at the bit representing the following word.
                x = x / 64 + 1;
                continue;
            }
            x += bsf(B);
            // Walk back down, always taking the lowest set bit.
            foreach_reverse (j; 0 .. i) {
                x *= 64;
                x += bsf(seg[j][x / 64]);
            }
            return x;
        }
        return n;
    }

    /// return maximum element that isn't greater than x, or -1 if there is none
    ptrdiff_t prev(ptrdiff_t x) const {
        // Test the sign first: comparing a negative x with the unsigned n
        // would convert x to a huge value and wrongly clamp it to n-1.
        if (x < 0) return -1;
        if (cast(size_t) x >= n) x = n - 1;
        foreach (i; 0 .. seg.length) {
            // x moved before the first word of this level: nothing further left.
            if (x == -1) break;
            size_t D = x / 64, R = x % 64;
            ulong B = seg[i][D] << (63 - R);
            if (!B) {
                // Nothing at or before x in this word; resume one level up
                // at the bit representing the preceding word.
                x = x / 64 - 1;
                continue;
            }
            x += bsr(B) - 63;
            // Walk back down, always taking the highest set bit.
            foreach_reverse (j; 0 .. i) {
                x *= 64;
                x += bsr(seg[j][x / 64]);
            }
            return x;
        }
        return -1;
    }

    /// return range of the elements strictly less than x
    Range lowerBound(ptrdiff_t x) {
        return Range(&this, next(0), prev(x - 1));
    }
    /// return range of the elements strictly greater than x
    Range upperBound(ptrdiff_t x) {
        return Range(&this, next(x + 1), prev(n - 1));
    }
    /// range over all elements in ascending order
    Range opIndex() {
        return Range(&this, next(0), prev(n - 1));
    }

    /// bidirectional range over the elements in [lower, upper] of the referenced set
    static struct Range {
        FastSet* fs;
        ptrdiff_t lower, upper;

        @property bool empty() const { return upper < lower; }

        size_t front() const { return lower; }
        size_t back() const { return upper; }
        void popFront() {
            assert(!empty);
            lower = fs.next(lower + 1);
        }
        void popBack() {
            assert(!empty);
            // prev(-1) yields -1, so popping the element 0 empties the range.
            upper = fs.prev(upper - 1);
        }
    }
}
127 
128 ///
unittest {
    import std.algorithm : equal, map;
    import std.range : iota;

    auto set = FastSet(10);
    foreach (v; [1, 5, 6]) set.insert(v);
    set.remove(5);
    set.insert(4);
    // contents are now {1, 4, 6}

    foreach (v; [1, 6]) assert(v in set);
    foreach (v; [2, 5]) assert(v !in set);
    assert(equal([1, 4, 6], set[]));

    // successor / predecessor answers for every query key in 0 .. 7
    immutable expectNext = [1, 1, 4, 4, 4, 6, 6, 10];
    immutable expectPrev = [-1, 1, 1, 1, 4, 4, 6, 6];
    auto nexts = iota(8).map!(k => set.next(k));
    auto prevs = iota(8).map!(k => set.prev(k));
    assert(equal(nexts, expectNext));
    assert(equal(prevs, expectPrev));

    // bounds are strict on both sides
    assert(equal([1], set.lowerBound(4)));
    assert(equal([1, 4], set.lowerBound(5)));
    assert(equal([1, 4, 6], set.upperBound(0)));
    assert(equal([4, 6], set.upperBound(1)));
}