1 /**
Calculate directed minimum spanning tree
3 */
4 
5 module dkh.graph.directedmst;
6 
7 import dkh.algorithm;
8 import dkh.graph.primitive;
9 
/**
Result of a directed MST computation.

Params:
    E = edge type stored in `res`
    C = type of the accumulated cost
*/
struct DirectedMSTInfo(E, C) {
    /// total cost of the chosen edges
    C cost;
    /// chosen edges, indexed by vertex
    E[] res;

    /// Starts with zero cost and `n` default-initialized edge slots.
    this(size_t n) {
        res = new E[](n);
        cost = C(0);
    }
}
21 
/**
Calculates a directed minimum spanning tree rooted at `r`.

For every vertex `v != r` exactly one edge of `_g[v]` is chosen so that the
chosen edges form no cycle, minimizing the total `dist`. Cycles of cheapest
edges are contracted into new super-vertices (Chu-Liu/Edmonds style). The
candidate edges of each (super-)vertex are kept in a pairing heap whose keys
can all be shifted at once through a lazy offset.

Params:
    _g = adjacency list; `_g[v]` holds the edges leaving `v`, each with
         `.to` (target vertex) and `.dist` (cost)
    r  = root vertex
Returns:
    `info.res[v]` is the edge chosen out of `v` (`info.res[r]` is left as
    `_E.init`) and `info.cost` is the sum of the chosen `dist`s.

Every vertex must be able to reach `r`; otherwise an assertion fails when a
(contracted) vertex has no edge leaving it.
*/
DirectedMSTInfo!(_E, D) directedMST(T, _E = EdgeType!T, D = typeof(_E.dist))(T _g, size_t r) {
    import std.algorithm, std.range, std.conv, std.typecons;
    // Heap entries remember the original source vertex of each edge.
    alias E = Tuple!(int, "from", _E, "edge");

    // Min pairing heap keyed by edge cost, supporting "add a constant to
    // every key" through per-node lazy offsets.
    static struct PairingHeapAllAdd {
        alias NP = Node*;
        // The key of a node is e.edge.dist plus the sum of `offset` over the
        // node itself and all of its ancestors.
        static struct Node {        
            E e;
            D offset;
            NP head, next; // first child / next sibling
            this(E e) {
                this.e = e;
                offset = D(0);
            }
        }
        NP n; // root (null when empty)
        size_t length;
        // Builds the heap by merging one singleton node per edge.
        this(E[] e) {
            length = e.length;            
            foreach (d; e) {
                n = merge(n, new Node(d));
            }
        }
        // Melds the heaps rooted at x and y and returns the new root. The
        // root with the larger key becomes the first child of the other; its
        // offset is rebased so that its keys do not change.
        static NP merge(NP x, NP y) {
            if (!x) return y;
            if (!y) return x;
            if (x.e.edge.dist+x.offset > y.e.edge.dist+y.offset) swap(x, y);
            y.offset -= x.offset;
            y.next = x.head;
            x.head = y;
            return x;
        }
        // Asserts that the heap is non-empty.
        void C() { assert(n); }
        // Entry with the smallest key; its edge.dist is the raw, unshifted cost.
        E front() {C; return n.e; }
        // Pops the root: the root's offset is pushed into each child, the
        // children are paired up, and then all pairs are merged into one heap.
        void removeFront() {
            assert(n);
            assert(length > 0);
            length--;
            NP x;
            NP s = n.head;
            while (s) {
                NP a, b;
                a = s; s = s.next; a.next = null; a.offset += n.offset;
                if (s) {
                    b = s; s = s.next; b.next = null; b.offset += n.offset;
                }
                a = merge(a, b);
                assert(a);
                if (!x) x = a;
                else {
                    a.next = x.next;
                    x.next = a;
                }
            }
            n = null;
            while (x) {
                NP a = x; x = x.next;
                n = merge(a, n);
            }
        }
        // Absorbs all entries of r into this heap.
        void meld(PairingHeapAllAdd r) {
            length += r.length;
            n = merge(n, r.n);
        }
        // Reference to the root's offset; changing it shifts every key in the
        // heap by the same amount.
        ref D offset() {C; return n.offset; }
    }
    
    // Ids 0..n-1 are the original vertices; ids from n upward are
    // super-vertices created by contractions. Self-loops are never selected,
    // so each contraction merges at least two current vertices and fewer than
    // 2*n ids are ever needed.
    auto n = _g.length;
    auto heap = new PairingHeapAllAdd[2*n];
    foreach (i; 0..n) {
        heap[i] = PairingHeapAllAdd(_g[i].map!(e => E(i.to!int, e)).array);
    }

    //union find
    // tr[q]: parent of q in the contraction tree (never compressed, used to
    //        rebuild the answer); -1 if q was never contracted.
    // uf[q]: union-find parent with path compression; root(q) is the
    //        super-vertex that currently contains q.
    int[] tr = new int[2*n]; tr[] = -1;
    int[] uf = new int[2*n]; uf[] = -1;
    int root(int i) {
        if (uf[i] == -1) return i;
        return uf[i] = root(uf[i]);
    }

    // used[p]: 0 = not visited yet; 1 = r; c = on the path of the current
    // mark() call. Any other value means p was handled by an earlier call.
    int[] used = new int[2*n];
    // res[p]: cheapest edge leaving (super-)vertex p, chosen when p was visited.
    E[] res = new E[2*n];
    // c: id of the current mark() call; pc: next unused super-vertex id.
    int c = 1, pc = n.to!int;
    used[r] = 1;
    // Starting at p, repeatedly select the cheapest edge leaving the current
    // (super-)vertex and follow it, until a vertex handled earlier (or r) is
    // reached. Returning to a vertex already on the current path closes a
    // cycle, which is contracted into a new super-vertex.
    void mark(int p) {
        c++;
        while (used[p] == 0 || used[p] == c) {
            if (used[p] == c) {
                //compress
                // Walk the cycle along the selected edges, starting at p.
                int np = pc++;
                int q = p;
                do {
                    // heap[q] is untouched since res[q] was taken from its
                    // top, so this sets the root offset to -res[q].edge.dist:
                    // every key in heap[q] drops by the current key of res[q],
                    // making the selected edge cost 0.
                    heap[q].offset -= res[q].edge.dist + heap[q].offset;
                    heap[np].meld(heap[q]);
                    tr[q] = uf[q] = np;
                    // Next cycle member; once every member points to np, this
                    // comes back to np and the loop stops.
                    q = root(res[q].edge.to);
                } while (q != np);
                p = np;
            }
            assert(used[p] == 0);
            used[p] = c;

            assert(root(p) == p);
            // Discard edges whose target now lies inside p itself.
            while (heap[p].length && root(heap[p].front.edge.to) == p) {
                heap[p].removeFront;
            }
            assert(heap[p].length);
            E mi = heap[p].front;
            res[p] = mi;
            p = root(mi.edge.to);
        }
    }
    foreach (i; 0..n) {
        if (used[i]) continue;
        mark(i.to!int);
    }

    // Expand the contractions. Ids are visited from the newest downward, so
    // every super-vertex is handled before its members. An unvisited node
    // keeps its selected edge. The edge's source vertex and that vertex's
    // contraction-tree ancestors are then marked, so the cycle member
    // containing the source drops its own cycle edge.
    auto info = DirectedMSTInfo!(_E, D)(n);
    bool[] vis = new bool[pc];
    foreach_reverse (i; 0..pc) {
        if (i == r) continue;
        if (vis[i]) continue;
        int f = res[i].from.to!int;
        while (f != -1 && !vis[f]) {
            vis[f] = true;
            f = tr[f];
        }
        info.cost += res[i].edge.dist;
        info.res[res[i].from] = res[i].edge;
    }
    return info;
}
156 
///
unittest {
    import std.typecons;
    alias E = Tuple!(int, "to", int, "dist");

    // g[v] lists the edges leaving v; the tree is rooted at vertex 1.
    E[][] g = new E[][4];
    g[0] ~= E(1, 10);
    g[2] ~= E(1, 10);
    g[3] ~= E(1, 3);
    g[2] ~= E(3, 4);

    auto info = directedMST(g, 1);
    assert(info.cost == 17); // 10 (0->1) + 4 (2->3) + 3 (3->1)
    assert(info.res[0] == E(1, 10));
    assert(info.res[2] == E(3, 4));
    assert(info.res[3] == E(1, 3));

    // The simple reference implementation agrees.
    assert(directedMSTSlow(g, 1).cost == 17);
}
170 
/**
Straightforward recursive directed MST, used to cross-check `directedMST`.

Picks, for every vertex `i != r`, the cheapest edge of `g[i]` that is not a
self-loop. If these edges contain no cycle they are the answer. Otherwise
every cycle is collapsed into one vertex, and the cost of each remaining edge
is reduced by the cost of the edge it would replace. The contracted graph is
solved recursively and the result is expanded back.

Params:
    g = adjacency list; `g[i]` holds the edges leaving `i`, each with
        `.to` and `.dist`
    r = root vertex
Returns:
    `res[i]` is the edge chosen out of `i` for every `i != r` (`res[r]` is
    left as `E.init`) and `cost` is the sum of their `dist`.

Asserts that every vertex other than the root has a non-self-loop edge, at
every recursion level.
*/
DirectedMSTInfo!(E, typeof(E.dist)) directedMSTSlow(T, E = EdgeType!T)(T g, size_t r) {
    import std.algorithm : filter;
    import std.math : abs;
    alias D = typeof(E.dist);
    // Tolerance for matching an original edge against the reduced edge picked
    // by the recursive call. It truncates to 0 for integral costs, so those
    // must match exactly.
    immutable D EPS = cast(D)(1e-9);

    auto n = g.length;
    auto info = DirectedMSTInfo!(E, D)(n);
    with (info) {
        // Cheapest non-self-loop edge out of every non-root vertex.
        foreach (i; 0..n) {
            if (i == r) continue;
            auto cand = g[i].filter!(e => e.to != i);
            assert(!cand.empty);
            res[i] = cand.minimum!"a.dist < b.dist";
            cost += res[i].dist;
        }

        // Group labels: the root gets 0, each cycle of chosen edges shares one
        // label, and every other vertex gets a label of its own.
        int[] i2g = new int[n]; i2g[] = -1;
        i2g[r] = 0;
        int gc = 1;
        for (int i = 0; i < n; i++) {
            if (i2g[i] != -1) continue;
            // Follow chosen edges from i, labelling as we go, until reaching a
            // vertex that is already labelled.
            int j = i;
            do {
                i2g[j] = gc++;
                j = res[j].to;
            } while (i2g[j] == -1);
            // Stopped at a vertex labelled before this walk: no new cycle.
            if (i2g[j] < i2g[i]) continue;
            // loop: j was labelled during this walk, so it starts the cycle at
            // the tail of the walk. Give the whole cycle j's label and reuse
            // the labels after it.
            int k = j;
            do {
                i2g[k] = i2g[j];
                k = res[k].to;
            } while (k != j);
            gc = i2g[j] + 1;
        }
        // No cycle: the chosen edges already form the answer.
        if (gc == n) return info;

        // Contracted graph: drop edges inside a group, and reduce each edge by
        // the cost of the chosen edge of its source vertex.
        E[][] ng = new E[][](gc);
        for (int i = 0; i < n; i++) {
            if (i == r) continue;
            foreach (e; g[i]) {
                if (i2g[e.to] == i2g[i]) continue;
                e.to = i2g[e.to];
                e.dist = e.dist - res[i].dist;
                ng[i2g[i]] ~= e;
            }
        }
        auto nme = directedMSTSlow(ng, 0).res;

        // Expand: for each group, find an original edge that matches the
        // reduced edge picked for it. That edge replaces the choice of its
        // source vertex; the other cycle members keep their cycle edges.
        bool[] ok = new bool[gc];
        for (int i = 0; i < n; i++) {
            if (i == r || ok[i2g[i]]) continue;
            foreach (e; g[i]) {
                if (abs(e.dist - res[i].dist - nme[i2g[i]].dist) <= EPS && i2g[e.to] == nme[i2g[i]].to) {
                    ok[i2g[i]] = true;
                    res[i] = e;
                    cost += nme[i2g[i]].dist;
                    break;
                }
            }
        }
    }
    return info;
}
232 
// Cross-checks directedMST against directedMSTSlow on random graphs with
// integer costs.
unittest {
    import std.range, std.algorithm, std.typecons, std.random, std.conv, std.stdio;
    alias E = Tuple!(int, "to", int, "dist");
    auto gen = Random(114514);
    void test() {
        size_t n = uniform(1, 20, gen);
        size_t m = uniform(1, 100, gen);
        E[][] g = new E[][n];
        foreach (i; 0..m) {
            auto a = uniform(0, n, gen);
            auto b = uniform(0, n, gen);
            int c = uniform(0, 15, gen);
            g[a] ~= E(b.to!int, c);
            g[b] ~= E(a.to!int, c);
        }
        size_t r = uniform(0, n, gen);
        // An expensive edge to r from every vertex keeps the root reachable.
        foreach (i; 0..n) {
            g[i] ~= E(r.to!int, 10^^6);
        }

        // Valid iff every non-root vertex uses one of its own edges, the
        // chosen edges contain no cycle, and cost equals their sum.
        bool check(I)(I info) {
            import dkh.datastructure.unionfind;
            auto uf = UnionFind(n.to!int);
            int sm = 0;
            foreach (i; 0..n) {
                if (i == r) continue;
                sm += info.res[i].dist;
                if (!g[i].count(info.res[i])) return false;
                if (uf.same(i, info.res[i].to)) return false;
                uf.merge(i, info.res[i].to);
            }
            if (sm != info.cost) return false;
            return true;
        }
        auto info1 = directedMSTSlow(g, r);
        auto info2 = directedMST(g, r);

        if (!check(info1)) {
            writeln("EEEEE");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(check(info1));
        if (info1.cost != info2.cost || !check(info2)) {
            writeln("FIND ERROR!");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(info1.cost == info2.cost);
        // Previously an invalid tree from directedMST was only printed.
        assert(check(info2));
    }
    import dkh.stopwatch;
    auto ti = benchmark!(test)(1000);
    writeln("DirectedMST int Random1000: ", ti[0].toMsecs);
}
291 
// Cross-checks directedMST against directedMSTSlow on random graphs with
// floating-point costs (compared with a 1e-4 tolerance).
unittest {
    import std.range, std.algorithm, std.typecons, std.random, std.conv, std.math, std.stdio;
    alias E = Tuple!(int, "to", double, "dist");
    auto gen = Random(114514);
    void test() {
        size_t n = uniform(1, 20, gen);
        size_t m = uniform(1, 100, gen);
        E[][] g = new E[][n];
        foreach (i; 0..m) {
            auto a = uniform(0, n, gen);
            auto b = uniform(0, n, gen);
            double c = uniform(0.0, 15.0, gen);
            g[a] ~= E(b.to!int, c);
            g[b] ~= E(a.to!int, c);
        }
        size_t r = uniform(0, n, gen);
        // An expensive edge to r from every vertex keeps the root reachable.
        foreach (i; 0..n) {
            g[i] ~= E(r.to!int, 10^^6);
        }

        // Valid iff every non-root vertex uses one of its own edges, the
        // chosen edges contain no cycle, and cost matches their sum.
        bool check(I)(I info) {
            import dkh.datastructure.unionfind;
            auto uf = UnionFind(n.to!int);
            double sm = 0;
            foreach (i; 0..n) {
                if (i == r) continue;
                sm += info.res[i].dist;
                if (!g[i].count(info.res[i])) return false;
                if (uf.same(i, info.res[i].to)) return false;
                uf.merge(i, info.res[i].to);
            }
            if (abs(sm - info.cost) > 1e-4) return false;
            return true;
        }
        auto info1 = directedMSTSlow(g, r);

        auto info2 = directedMST(g, r);

        if (!check(info1)) {
            writeln("EEEEE");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(check(info1));
        if (abs(info1.cost - info2.cost) > 1e-4 || !check(info2)) {
            writeln("FIND ERROR!");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(abs(info1.cost - info2.cost) <= 1e-4);
        // Previously an invalid tree from directedMST was only printed.
        assert(check(info2));
    }
    import dkh.stopwatch;
    auto ti = benchmark!(test)(1000);
    writeln("DirectedMST double Random1000: ", ti[0].toMsecs);
}