1 module dkh.graph.directedmst;
2 
3 import std.stdio;
4 
/// Computes a minimum-weight directed spanning tree (arborescence) rooted at `r`.
///
/// Edge orientation: `_g[i]` lists the edges leaving vertex `i` (each one is
/// paired with `from = i` below). Every vertex other than `r` selects exactly
/// one of its own outgoing edges; `r` never selects an edge.
///
/// Implementation (contraction based, Chu-Liu/Edmonds style):
/// - Every node owns a meldable pairing heap of candidate edges whose keys can
///   all be shifted at once through a lazy offset.
/// - Paths are grown from unvisited vertices by repeatedly following the
///   cheapest edge that leaves the current node.
/// - When a path closes a cycle, the cycle is contracted into a new supernode.
///   The supernode's heap is the meld of the members' heaps, with each member's
///   keys first reduced by the key of the edge that member had selected.
/// - Finally the contraction tree is unwound to choose the actual edges.
///
/// Params:
///     _g = adjacency lists; `_g[i]` holds the edges leaving vertex `i`.
///          Each edge must have `to` and `dist` members.
///     r  = root vertex index.
/// Returns: a `DirectedMSTInfo` whose `res[v]` is the edge chosen for vertex
///     `v` (`res[r]` is left default-initialized) and whose `cost` is the sum of
///     the chosen edges' `dist`.
///
/// Precondition: no set of vertices may be left without an edge leaving it
/// (i.e. `r` must be reachable from every vertex); otherwise the
/// `assert(heap[p].length)` below fails. NOTE(review): `r < _g.length` is
/// assumed but not checked.
DirectedMSTInfo!(_E, D) directedMST(T, _E = EdgeType!T, D = typeof(_E.dist))(T _g, size_t r) {
    import std.algorithm, std.range, std.conv, std.typecons;
    // Heap element: an original edge together with the original vertex it leaves.
    alias E = Tuple!(int, "from", _E, "edge");

    // Min-heap (pairing heap) of E keyed by `edge.dist` plus lazy offsets.
    // A node's `offset` is stored relative to its parent. The key of a node is
    // therefore `e.edge.dist` plus the sum of the offsets on the path from the
    // root to that node. Adjusting the root's offset (through `offset()`)
    // shifts every key in the heap at once ("all add").
    static struct PairingHeapAllAdd {
        alias NP = Node*;
        static struct Node {
            E e;
            D offset;       // relative to the parent node; absolute for a root
            NP head, next;  // first child / next sibling
            this(E e) {
                this.e = e;
                offset = D(0);
            }
        }
        NP n;           // root node; null when the heap is empty
        size_t length;  // number of stored elements
        // Builds a heap containing every element of `e`.
        this(E[] e) {
            length = e.length;
            foreach (d; e) {
                n = merge(n, new Node(d));
            }
        }
        // Links two heap roots (either may be null). The root with the larger
        // key becomes the first child of the other, and its offset is rebased
        // so that it stays relative to its new parent.
        static NP merge(NP x, NP y) {
            if (!x) return y;
            if (!y) return x;
            if (x.e.edge.dist+x.offset > y.e.edge.dist+y.offset) swap(x, y);
            y.offset -= x.offset;
            y.next = x.head;
            x.head = y;
            return x;
        }
        // Asserts that the heap is non-empty.
        void C() { assert(n); }
        // Element with the minimum key. Note: the stored edge is returned
        // unchanged; its current key is `front.edge.dist + offset()`.
        E front() {C; return n.e; }
        // Removes the element with the minimum key.
        void removeFront() {
            assert(n);
            assert(length > 0);
            length--;
            NP x;
            NP s = n.head;
            // First pass over the root's children, taken two at a time:
            // - detach each child and push the root's offset down into it, so
            //   the child becomes a standalone root with an absolute offset;
            // - merge each pair;
            // - collect the results in the list starting at `x` (every new
            //   result is inserted right after the list head).
            while (s) {
                NP a, b;
                a = s; s = s.next; a.next = null; a.offset += n.offset;
                if (s) {
                    b = s; s = s.next; b.next = null; b.offset += n.offset;
                }
                a = merge(a, b);
                assert(a);
                if (!x) x = a;
                else {
                    a.next = x.next;
                    x.next = a;
                }
            }
            n = null;
            // Second pass: merge every collected heap into a single one.
            while (x) {
                NP a = x; x = x.next;
                n = merge(a, n);
            }
        }
        // Moves all elements of `r` into this heap. `r` is taken by value and
        // is not cleared, so its nodes are now shared; do not use it afterwards.
        void meld(PairingHeapAllAdd r) {
            length += r.length;
            n = merge(n, r.n);
        }
        // Reference to the root's offset; adding to it shifts every key in the heap.
        ref D offset() {C; return n.offset; }
    }

    auto n = _g.length;
    // Slots [0, n) are the original vertices. Slots [n, 2n) are reserved for
    // supernodes created by cycle contraction (fewer than n of them are needed).
    auto heap = new PairingHeapAllAdd[2*n];
    foreach (i; 0..n) {
        heap[i] = PairingHeapAllAdd(_g[i].map!(e => E(i.to!int, e)).array);
    }

    //union find
    // tr[v]: supernode that v was contracted into, or -1 (the contraction tree).
    // uf[v]: union-find parent with path compression, or -1 for a
    //        representative; root(v) gives the outermost current node containing v.
    int[] tr = new int[2*n]; tr[] = -1;
    int[] uf = new int[2*n]; uf[] = -1;
    int root(int i) {
        if (uf[i] == -1) return i;
        return uf[i] = root(uf[i]);
    }

    // used[v]: 0 if v has not been visited, otherwise the id `c` of the walk
    // that visited it. The root is pre-marked with 1 so that walks stop there.
    int[] used = new int[2*n];
    // res[v]: edge selected for node v (the front of heap[v] when v was visited).
    E[] res = new E[2*n];
    // c: id of the current walk; pc: next free supernode slot.
    int c = 1, pc = n.to!int;
    used[r] = 1;
    // Grows a path from p by following the cheapest edge that leaves each node.
    // Every cycle closed within the current walk is contracted. The walk stops
    // on reaching a node visited by an earlier walk (or the root).
    void mark(int p) {
        c++;
        while (used[p] == 0 || used[p] == c) {
            if (used[p] == c) {
                //compress
                // p was already visited in this walk, so p -> ... -> p is a
                // cycle. Contract it into the new supernode np.
                int np = pc++;
                int q = p;
                do {
                    // res[q] is still the front of heap[q] at this point, so
                    // this sets the root offset to -res[q].edge.dist. That
                    // subtracts the selected edge's key from every key in
                    // heap[q], and the selected edge's key becomes 0.
                    heap[q].offset -= res[q].edge.dist + heap[q].offset;
                    heap[np].meld(heap[q]);
                    tr[q] = uf[q] = np;
                    q = root(res[q].edge.to);
                } while (q != np);
                p = np;
            }
            assert(used[p] == 0);
            used[p] = c;

            assert(root(p) == p);
            // Discard edges whose target now lies inside p itself
            // (they became self-loops after contraction).
            while (heap[p].length && root(heap[p].front.edge.to) == p) {
                heap[p].removeFront;
            }
            // Fails when no edge leaves p (then r cannot be reached from p's vertices).
            assert(heap[p].length);
            E mi = heap[p].front;
            res[p] = mi;
            p = root(mi.edge.to);
        }
    }
    foreach (i; 0..n) {
        if (used[i]) continue;
        mark(i.to!int);
    }

    // Unwind the contraction tree, newest supernode first. A node that is not
    // yet covered keeps its selected edge res[i]. That edge's tail vertex, and
    // every enclosing supernode not yet covered, are then marked as covered,
    // so their own selections are discarded.
    auto info = DirectedMSTInfo!(_E, D)(n);
    bool[] vis = new bool[pc];
    foreach_reverse (i; 0..pc) {
        if (i == r) continue;
        if (vis[i]) continue;
        int f = res[i].from.to!int;
        while (f != -1 && !vis[f]) {
            vis[f] = true;
            f = tr[f];
        }
        info.cost += res[i].edge.dist;
        info.res[res[i].from] = res[i].edge;
    }
    return info;
}
138 
139 
140 
141 import dkh.algorithm;
142 import dkh.graph.primitive;
/// Result of a minimum directed spanning tree computation.
///
/// Holds one edge slot per vertex plus the accumulated weight of the
/// selected edges.
struct DirectedMSTInfo(E, C) {
    /// Total weight of the selected edges. The constructor sets it to zero
    /// explicitly, because a floating-point `C` would otherwise start as NaN.
    C cost;
    /// `res[v]` holds the edge selected for vertex `v`.
    E[] res;

    /// Allocates `n` default-initialized edge slots and zeroes `cost`.
    this(size_t n) {
        res = new E[](n);
        cost = C(0);
    }
}
151 
152 
153 
/// Reference implementation of the minimum directed spanning tree. It is kept
/// simple so that it can cross-check `directedMST` in the unit tests.
///
/// Chu-Liu/Edmonds with explicit recursion:
/// 1. Every vertex other than `r` takes its cheapest non-self-loop edge.
/// 2. If those choices contain cycles, each cycle is contracted into one vertex.
/// 3. Every remaining edge's weight is reduced by the weight of the edge its
///    tail vertex selected.
/// 4. The contracted graph is solved recursively, and the answer is mapped
///    back to original edges.
///
/// Params:
///     g = adjacency lists; `g[i]` holds the edges leaving vertex `i`
///         (each edge has `to` and `dist` members).
///     r = root vertex; it does not select an edge.
/// Returns: `res[i]` is the edge selected for vertex `i` (`res[r]` stays
///     default-initialized), and `cost` is the total `dist` of the selection.
///
/// Every vertex other than `r` must have at least one edge that is not a
/// self-loop (asserted). NOTE(review): presumably `r` must also be reachable
/// from every vertex; otherwise the same assertion fires in a recursive call.
DirectedMSTInfo!(E, typeof(E.dist)) directedMSTSlow(T, E = EdgeType!T)(T g, size_t r) {
    import std.algorithm : filter;
    import std.math : abs;

    alias Dist = typeof(E.dist);
    // Tolerance used when mapping the recursive answer back to original edges.
    // It is typed with `E`'s distance type (not `EdgeType!T`) so that an
    // explicitly supplied `E` is respected. For integral distances the cast
    // yields 0, i.e. an exact comparison.
    immutable Dist EPS = cast(Dist)(1e-9);

    auto n = g.length;
    auto info = DirectedMSTInfo!(E, Dist)(n);
    with (info) {
        // Greedy step: each non-root vertex takes its cheapest edge to another vertex.
        foreach (i; 0..n) {
            if (i == r) continue;
            auto candidates = g[i].filter!(e => e.to != i);
            assert(candidates.empty == false);
            res[i] = candidates.minimum!"a.dist < b.dist";
            cost += res[i].dist;
        }

        // i2g[v]: id of v's vertex in the contracted graph. The root is group 0,
        // and every cycle formed by the greedy choices becomes a single group.
        int[] i2g = new int[n]; i2g[] = -1;
        i2g[r] = 0;

        int gc = 1;
        for (int i = 0; i < n; i++) {
            if (i2g[i] != -1) continue;
            // Follow the greedy choices from i, handing out fresh ids,
            // until a vertex j that already has an id is reached.
            int j = i;
            do {
                i2g[j] = gc++;
                j = res[j].to;
            } while (i2g[j] == -1);
            // j was labelled before this walk started: no new cycle.
            if (i2g[j] < i2g[i]) continue;
            // loop: j lies on this walk, so j -> ... -> j is a cycle. Give all of
            // its members j's id and reclaim the ids handed out after j.
            int k = j;
            do {
                i2g[k] = i2g[j];
                k = res[k].to;
            } while(k != j);
            gc = i2g[j]+1;
        }
        // No cycle: the greedy choice is already optimal.
        if (gc == n) return info;

        // Contracted graph: drop edges inside a group and reduce each edge's
        // weight by the weight of its tail vertex's greedy edge.
        E[][] ng = new E[][](gc);
        for (int i = 0; i < n; i++) {
            if (i == r) continue;
            foreach (e; g[i]) {
                if (i2g[e.to] == i2g[i]) continue;
                e.to = i2g[e.to];
                e.dist = e.dist - res[i].dist;
                ng[i2g[i]] ~= e;
            }
        }
        auto nme = directedMSTSlow(ng, 0).res;

        // For each non-root group, find an original edge that realizes the edge
        // chosen for that group in the contracted graph (same target group, same
        // reduced weight). That edge replaces its tail vertex's greedy edge.
        bool[] ok = new bool[gc];
        for (int i = 0; i < n; i++) {
            if (i == r || ok[i2g[i]]) continue;
            foreach (e; g[i]) {
                if (abs(e.dist - res[i].dist - nme[i2g[i]].dist) <= EPS && i2g[e.to] == nme[i2g[i]].to) {
                    ok[i2g[i]] = true;
                    res[i] = e;
                    cost += nme[i2g[i]].dist;
                    break;
                }
            }
        }
    }
    return info;
}
215 
unittest {
    import std.typecons;
    // g[v] lists the edges leaving v; vertex 1 is the root.
    alias E = Tuple!(int, "to", int, "dist");

    auto g = new E[][](4);
    g[0] = [E(1, 10)];
    g[2] = [E(1, 10), E(3, 4)];
    g[3] = [E(1, 3)];

    // Optimal selection: 0->1 (10), 2->3 (4), 3->1 (3).
    immutable expected = 17;
    assert(directedMSTSlow(g, 1).cost == expected);
}
228 
unittest {
    import std.range, std.algorithm, std.typecons, std.random, std.conv;
    alias E = Tuple!(int, "to", int, "dist");
    auto gen = Random(114514);
    // One randomized comparison of directedMST against directedMSTSlow on an
    // int-weighted graph.
    void test() {
        size_t n = uniform(1, 20, gen);
        size_t m = uniform(1, 100, gen);
        E[][] g = new E[][n];
        foreach (i; 0..m) {
            auto a = uniform(0, n, gen);
            auto b = uniform(0, n, gen);
            int c = uniform(0, 15, gen);
            g[a] ~= E(b.to!int, c);
            g[b] ~= E(a.to!int, c);
        }
        size_t r = uniform(0, n, gen);
        // A heavy edge from every vertex to r guarantees that a solution exists.
        foreach (i; 0..n) {
            g[i] ~= E(r.to!int, 10^^6);
        }

        // True iff every non-root vertex selected one of its own edges, the
        // selected edges contain no cycle, and `cost` equals their total weight.
        bool check(I)(I info) {
            import dkh.datastructure.unionfind;
            auto uf = UnionFind(n.to!int);
            int sm = 0;
            foreach (i; 0..n) {
                if (i == r) continue;
                sm += info.res[i].dist;
                if (!g[i].count(info.res[i])) return false;
                if (uf.same(i, info.res[i].to)) return false;
                uf.merge(i, info.res[i].to);
            }
            if (sm != info.cost) return false;
            return true;
        }
        auto info1 = directedMSTSlow(g, r);
        auto info2 = directedMST(g, r);
        immutable ok1 = check(info1);
        immutable ok2 = check(info2);

        if (!ok1) {
            writeln("EEEEE");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(ok1);
        if (info1.cost != info2.cost || !ok2) {
            writeln("FIND ERROR!");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(info1.cost == info2.cost);
        // The fast result must also be a valid arborescence, not merely have the right cost.
        assert(ok2);
    }
    import dkh.stopwatch;
    auto ti = benchmark!(test)(1000);
    writeln("DirectedMST int Random1000: ", ti[0].toMsecs);
}
287 
unittest {
    import std.range, std.algorithm, std.typecons, std.random, std.conv, std.math;
    alias E = Tuple!(int, "to", double, "dist");
    auto gen = Random(114514);
    // One randomized comparison of directedMST against directedMSTSlow on a
    // double-weighted graph (costs compared with a 1e-4 tolerance).
    void test() {
        size_t n = uniform(1, 20, gen);
        size_t m = uniform(1, 100, gen);
        E[][] g = new E[][n];
        foreach (i; 0..m) {
            auto a = uniform(0, n, gen);
            auto b = uniform(0, n, gen);
            double c = uniform(0.0, 15.0, gen);
            g[a] ~= E(b.to!int, c);
            g[b] ~= E(a.to!int, c);
        }
        size_t r = uniform(0, n, gen);
        // A heavy edge from every vertex to r guarantees that a solution exists.
        foreach (i; 0..n) {
            g[i] ~= E(r.to!int, 10^^6);
        }

        // True iff every non-root vertex selected one of its own edges, the
        // selected edges contain no cycle, and `cost` matches their total
        // weight within 1e-4.
        bool check(I)(I info) {
            import dkh.datastructure.unionfind;
            auto uf = UnionFind(n.to!int);
            double sm = 0;
            foreach (i; 0..n) {
                if (i == r) continue;
                sm += info.res[i].dist;
                if (!g[i].count(info.res[i])) return false;
                if (uf.same(i, info.res[i].to)) return false;
                uf.merge(i, info.res[i].to);
            }
            if (abs(sm - info.cost) > 1e-4) return false;
            return true;
        }
        auto info1 = directedMSTSlow(g, r);

        auto info2 = directedMST(g, r);
        immutable ok1 = check(info1);
        immutable ok2 = check(info2);

        if (!ok1) {
            writeln("EEEEE");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(ok1);
        if (abs(info1.cost - info2.cost) > 1e-4 || !ok2) {
            writeln("FIND ERROR!");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(abs(info1.cost - info2.cost) <= 1e-4);
        // The fast result must also be a valid arborescence, not merely have the right cost.
        assert(ok2);
    }
    import dkh.stopwatch;
    auto ti = benchmark!(test)(1000);
    writeln("DirectedMST double Random1000: ", ti[0].toMsecs);
}