/**
Directed minimum spanning tree (minimum-cost arborescence).
*/

module dkh.graph.directedmst;

import dkh.algorithm;
import dkh.graph.primitive;

/**
Result of a directed MST computation.

`res` has one slot per vertex. For every vertex `v` other than the root,
`res[v]` is the edge chosen for `v`; the functions in this module leave the
root's slot as `E.init`.
*/
struct DirectedMSTInfo(E, C) {
    C cost; /// sum of `dist` over the chosen edges
    E[] res; /// chosen edge of each vertex, indexed by vertex
    this(size_t n) {
        cost = C(0);
        res = new E[n];
    }
}

/**
Calculate a directed minimum spanning tree.

Every vertex `v != r` chooses exactly one edge from its own list `_g[v]` so
that following the chosen edges (`e.to`) from any vertex leads to `r`, and
the total `dist` of the chosen edges is minimum.

Implementation: cycle contraction using union-find plus meldable pairing
heaps that support a lazy "add to every element" offset.

Params:
    _g = adjacency list; `_g[v]` holds the candidate edges of vertex `v`
    r = root vertex
Returns:
    a `DirectedMSTInfo` whose `res[v]` is the chosen edge of `v`
    and whose `cost` is their total `dist`

Precondition: every vertex can reach `r` through edges of `_g`;
otherwise an assertion fails.
*/
DirectedMSTInfo!(_E, D) directedMST(T, _E = EdgeType!T, D = typeof(_E.dist))(T _g, size_t r) {
    import std.algorithm, std.range, std.conv, std.typecons;
    // A heap element remembers which original vertex the edge belongs to.
    alias E = Tuple!(int, "from", _E, "edge");

    // Min pairing heap keyed by `edge.dist`, with a lazy offset.
    // The root's `offset` is absolute; every other node's `offset` is
    // relative to its parent, so adding to the root's offset shifts all keys.
    static struct PairingHeapAllAdd {
        alias NP = Node*;
        static struct Node {
            E e;
            D offset;
            NP head, next;
            this(E e) {
                this.e = e;
                offset = D(0);
            }
        }
        NP n;
        size_t length;
        this(E[] e) {
            length = e.length;
            foreach (d; e) {
                n = merge(n, new Node(d));
            }
        }
        // Merge two heap roots (both carrying absolute offsets).
        static NP merge(NP x, NP y) {
            if (!x) return y;
            if (!y) return x;
            if (x.e.edge.dist + x.offset > y.e.edge.dist + y.offset) swap(x, y);
            // y becomes a child of x: store its offset relative to x
            y.offset -= x.offset;
            y.next = x.head;
            x.head = y;
            return x;
        }
        void ensureNonEmpty() { assert(n); }
        E front() { ensureNonEmpty(); return n.e; }
        void removeFront() {
            assert(n);
            assert(length > 0);
            length--;
            // First pass: detach the root's children, make their offsets
            // absolute again, and merge them in pairs.
            NP x;
            NP s = n.head;
            while (s) {
                NP a, b;
                a = s; s = s.next; a.next = null; a.offset += n.offset;
                if (s) {
                    b = s; s = s.next; b.next = null; b.offset += n.offset;
                }
                a = merge(a, b);
                assert(a);
                if (!x) x = a;
                else {
                    a.next = x.next;
                    x.next = a;
                }
            }
            // Second pass: merge every pair into the new root.
            n = null;
            while (x) {
                NP a = x; x = x.next;
                a.next = null; // drop the work-list link so the root has none
                n = merge(a, n);
            }
        }
        void meld(PairingHeapAllAdd r) {
            length += r.length;
            n = merge(n, r.n);
        }
        // Absolute offset of the root (i.e. of the whole heap).
        ref D offset() { ensureNonEmpty(); return n.offset; }
    }

    auto n = _g.length;
    // Indices 0 .. n-1 are the original vertices; each contracted cycle gets a
    // fresh index n, n+1, ... Every contraction removes at least one
    // representative, so fewer than n are created and 2*n slots suffice.
    auto heap = new PairingHeapAllAdd[2*n];
    foreach (i; 0..n) {
        heap[i] = PairingHeapAllAdd(_g[i].map!(e => E(i.to!int, e)).array);
    }

    // tr[v]: the supervertex v was contracted into (-1 if none)
    int[] tr = new int[2*n]; tr[] = -1;
    // union-find over (super)vertices; uf[v] == -1 marks a representative
    int[] uf = new int[2*n]; uf[] = -1;
    int root(int i) {
        if (uf[i] == -1) return i;
        return uf[i] = root(uf[i]);
    }

    // used[v]: 0 = not visited yet, otherwise the id `c` of the walk that visited v
    int[] used = new int[2*n];
    // res[v]: edge currently chosen for (super)vertex v
    E[] res = new E[2*n];
    int c = 1, pc = n.to!int;
    used[r] = 1;
    // Follow cheapest edges from p until reaching the root or a vertex settled
    // by an earlier walk, contracting each cycle met on the way.
    void mark(int p) {
        c++;
        while (used[p] == 0 || used[p] == c) {
            if (used[p] == c) {
                // p was already visited by this walk: contract the cycle into np
                int np = pc++;
                int q = p;
                do {
                    // res[q] is still heap[q].front, so this sets the root offset
                    // to -res[q].edge.dist: the chosen edge's reduced cost becomes
                    // 0 and every other edge of q is reduced by the same amount.
                    heap[q].offset -= res[q].edge.dist + heap[q].offset;
                    heap[np].meld(heap[q]);
                    tr[q] = uf[q] = np;
                    q = root(res[q].edge.to);
                } while (q != np);
                p = np;
            }
            assert(used[p] == 0);
            used[p] = c;

            assert(root(p) == p);
            // discard edges that now lead back into p itself
            while (heap[p].length && root(heap[p].front.edge.to) == p) {
                heap[p].removeFront();
            }
            assert(heap[p].length);
            E mi = heap[p].front;
            res[p] = mi;
            p = root(mi.edge.to);
        }
    }
    foreach (i; 0..n) {
        if (used[i]) continue;
        mark(i.to!int);
    }

    // Expand the contractions, newest supervertex first. The chosen edge of
    // (super)vertex i belongs to original vertex res[i].from; every
    // (super)vertex from there up the contraction tree is thereby settled.
    auto info = DirectedMSTInfo!(_E, D)(n);
    bool[] vis = new bool[pc];
    foreach_reverse (i; 0..pc) {
        if (i == r) continue;
        if (vis[i]) continue;
        int f = res[i].from;
        while (f != -1 && !vis[f]) {
            vis[f] = true;
            f = tr[f];
        }
        info.cost += res[i].edge.dist;
        info.res[res[i].from] = res[i].edge;
    }
    return info;
}

///
unittest {
    import std.typecons;
    alias E = Tuple!(int, "to", int, "dist");

    E[][] g = new E[][4];
    g[0] ~= E(1, 10);
    g[2] ~= E(1, 10);
    g[3] ~= E(1, 3);
    g[2] ~= E(3, 4);
    auto info = directedMST(g, 1);
    assert(info.cost == 17);
    assert(info.res[0] == E(1, 10));
    assert(info.res[2] == E(3, 4));
    assert(info.res[3] == E(1, 3));
    assert(directedMSTSlow(g, 1).cost == 17);
}

/**
Reference implementation of `directedMST` (Chu-Liu/Edmonds with explicit
recursion on the contracted graph). Same contract as `directedMST`; it is
intended for cross-checking the fast version.

Params:
    g = adjacency list; `g[v]` holds the candidate edges of vertex `v`
    r = root vertex
Returns:
    a `DirectedMSTInfo` whose `res[v]` is the chosen edge of `v`
    and whose `cost` is their total `dist`
*/
DirectedMSTInfo!(E, typeof(E.dist)) directedMSTSlow(T, E = EdgeType!T)(T g, size_t r) {
    import std.algorithm : filter;
    import std.math : abs;
    auto n = g.length;
    auto info = DirectedMSTInfo!(E, typeof(E.dist))(n);
    // Tolerance used to match reduced costs back to original edges
    // (truncates to 0 for integral distances, i.e. exact matching).
    immutable EPS = cast(typeof(E.dist))(1e-9);
    with (info) {
        // Every non-root vertex takes its cheapest non-self-loop edge.
        foreach (i; 0..n) {
            if (i == r) continue;
            assert(g[i].filter!(e => e.to != i).empty == false);
            res[i] = g[i].filter!(e => e.to != i).minimum!"a.dist < b.dist";
            cost += res[i].dist;
        }

        // Group vertices: each cycle of chosen edges becomes one group, every
        // other vertex its own group. The root's group is 0; ids are contiguous.
        int[] i2g = new int[n]; i2g[] = -1;
        i2g[r] = 0;

        int gc = 1;
        for (int i = 0; i < n; i++) {
            if (i2g[i] != -1) continue;
            int j = i;
            do {
                i2g[j] = gc++;
                j = res[j].to;
            } while (i2g[j] == -1);
            // reached a vertex labelled before this walk: no new cycle
            if (i2g[j] < i2g[i]) continue;
            // found a cycle through j: give the whole loop j's id
            int k = j;
            do {
                i2g[k] = i2g[j];
                k = res[k].to;
            } while (k != j);
            gc = i2g[j] + 1;
        }
        // no cycle: the cheapest edges already form the tree
        if (gc == n) return info;

        // Contracted graph with costs reduced by each vertex's chosen edge.
        E[][] ng = new E[][](gc);
        for (int i = 0; i < n; i++) {
            if (i == r) continue;
            foreach (e; g[i]) {
                if (i2g[e.to] == i2g[i]) continue;
                e.to = i2g[e.to];
                e.dist = e.dist - res[i].dist;
                ng[i2g[i]] ~= e;
            }
        }
        auto nme = directedMSTSlow(ng, 0).res;

        // For each group, replace one member's edge with the original edge that
        // corresponds to the group's chosen edge in the contracted solution.
        bool[] ok = new bool[gc];
        for (int i = 0; i < n; i++) {
            if (i == r || ok[i2g[i]]) continue;
            foreach (e; g[i]) {
                if (abs(e.dist - res[i].dist - nme[i2g[i]].dist) <= EPS && i2g[e.to] == nme[i2g[i]].to) {
                    ok[i2g[i]] = true;
                    res[i] = e;
                    cost += nme[i2g[i]].dist;
                    break;
                }
            }
        }
    }
    return info;
}

unittest {
    import std.range, std.algorithm, std.typecons, std.random, std.conv, std.stdio;
    alias E = Tuple!(int, "to", int, "dist");
    auto gen = Random(114514);
    void test() {
        size_t n = uniform(1, 20, gen);
        size_t m = uniform(1, 100, gen);
        E[][] g = new E[][n];
        foreach (i; 0..m) {
            auto a = uniform(0, n, gen);
            auto b = uniform(0, n, gen);
            int c = uniform(0, 15, gen);
            g[a] ~= E(b.to!int, c);
            g[b] ~= E(a.to!int, c);
        }
        size_t r = uniform(0, n, gen);
        foreach (i; 0..n) {
            g[i] ~= E(r.to!int, 10^^6);
        }

        // valid tree: each edge comes from its own vertex's list, no cycles,
        // and cost equals the sum of the chosen edges
        bool check(I)(I info) {
            import dkh.datastructure.unionfind;
            auto uf = UnionFind(n.to!int);
            int sm = 0;
            foreach (i; 0..n) {
                if (i == r) continue;
                sm += info.res[i].dist;
                if (!g[i].count(info.res[i])) return false;
                if (uf.same(i, info.res[i].to)) return false;
                uf.merge(i, info.res[i].to);
            }
            if (sm != info.cost) return false;
            return true;
        }
        auto info1 = directedMSTSlow(g, r);
        auto info2 = directedMST(g, r);

        if (!check(info1)) {
            writeln("EEEEE");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(check(info1));
        if (info1.cost != info2.cost || !check(info2)) {
            writeln("FIND ERROR!");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(info1.cost == info2.cost);
        assert(check(info2));
    }
    import dkh.stopwatch;
    auto ti = benchmark!(test)(1000);
    writeln("DirectedMST int Random1000: ", ti[0].toMsecs);
}

unittest {
    import std.range, std.algorithm, std.typecons, std.random, std.conv, std.math, std.stdio;
    alias E = Tuple!(int, "to", double, "dist");
    auto gen = Random(114514);
    void test() {
        size_t n = uniform(1, 20, gen);
        size_t m = uniform(1, 100, gen);
        E[][] g = new E[][n];
        foreach (i; 0..m) {
            auto a = uniform(0, n, gen);
            auto b = uniform(0, n, gen);
            double c = uniform(0.0, 15.0, gen);
            g[a] ~= E(b.to!int, c);
            g[b] ~= E(a.to!int, c);
        }
        size_t r = uniform(0, n, gen);
        foreach (i; 0..n) {
            g[i] ~= E(r.to!int, 10^^6);
        }

        // valid tree: each edge comes from its own vertex's list, no cycles,
        // and cost matches the sum of the chosen edges
        bool check(I)(I info) {
            import dkh.datastructure.unionfind;
            auto uf = UnionFind(n.to!int);
            double sm = 0;
            foreach (i; 0..n) {
                if (i == r) continue;
                sm += info.res[i].dist;
                if (!g[i].count(info.res[i])) return false;
                if (uf.same(i, info.res[i].to)) return false;
                uf.merge(i, info.res[i].to);
            }
            if (abs(sm - info.cost) > 1e-4) return false;
            return true;
        }
        auto info1 = directedMSTSlow(g, r);

        auto info2 = directedMST(g, r);

        if (!check(info1)) {
            writeln("EEEEE");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(check(info1));
        if (abs(info1.cost - info2.cost) > 1e-4 || !check(info2)) {
            writeln("FIND ERROR!");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(abs(info1.cost - info2.cost) <= 1e-4);
        assert(check(info2));
    }
    import dkh.stopwatch;
    auto ti = benchmark!(test)(1000);
    writeln("DirectedMST double Random1000: ", ti[0].toMsecs);
}