module dkh.graph.directedmst;

import std.stdio;

import dkh.algorithm;
import dkh.graph.primitive;

/**
Minimum-cost spanning arborescence directed towards a root.

Every vertex `i != r` is assigned exactly one edge taken from `_g[i]` (an
edge leaving `i`) such that following the chosen edges from any vertex
eventually reaches `r`, and the total `dist` of the chosen edges is minimal.

Implementation: Chu-Liu/Edmonds cycle contraction. The candidate edges of
each (super-)vertex are kept in a meldable pairing heap whose root carries a
lazy additive offset; contracted vertices are tracked with a union-find and
expanded again at the end.

Params:
    _g = adjacency list; `_g[i]` is a range of edges leaving `i`, each edge
         exposing `.to` (target vertex index) and `.dist` (cost).
    r  = index of the root vertex; must satisfy `r < _g.length`.

Returns:
    A `DirectedMSTInfo` whose `cost` is the total weight and whose `res[i]`
    is the edge chosen for vertex `i` (`res[r]` is left as `_E.init`).

Every vertex must be able to reach `r`; otherwise an assertion fails.
*/
DirectedMSTInfo!(_E, D) directedMST(T, _E = EdgeType!T, D = typeof(_E.dist))(T _g, size_t r) {
    import std.algorithm, std.range, std.conv, std.typecons;
    // A heap entry: the edge plus the vertex it leaves from.
    alias E = Tuple!(int, "from", _E, "edge");

    /*
    Min pairing heap keyed by `edge.dist` plus accumulated offsets.
    `Node.offset` is relative to the parent node; for a root it is the
    absolute amount added to every element of the heap, so "add a constant
    to all elements" is a single write through `offset()`.
    `front()` returns the stored entry, whose `edge.dist` does NOT include
    the offset.
    */
    static struct PairingHeapAllAdd {
        alias NP = Node*;
        static struct Node {
            E e;
            D offset;
            NP head, next;
            this(E e) {
                this.e = e;
                offset = D(0);
            }
        }
        NP n;
        size_t length;
        this(E[] e) {
            length = e.length;
            foreach (d; e) {
                n = merge(n, new Node(d));
            }
        }
        // Merge two roots (absolute offsets); the larger becomes a child of the
        // smaller and its offset is made relative to the new parent.
        static NP merge(NP x, NP y) {
            if (!x) return y;
            if (!y) return x;
            if (x.e.edge.dist+x.offset > y.e.edge.dist+y.offset) swap(x, y);
            y.offset -= x.offset;
            y.next = x.head;
            x.head = y;
            return x;
        }
        void C() { assert(n); }
        E front() {C; return n.e; }
        // Two-pass removal of the root; children are turned into roots by
        // adding the removed root's offset to their relative offsets.
        void removeFront() {
            assert(n);
            assert(length > 0);
            length--;
            NP x;
            NP s = n.head;
            while (s) {
                NP a, b;
                a = s; s = s.next; a.next = null; a.offset += n.offset;
                if (s) {
                    b = s; s = s.next; b.next = null; b.offset += n.offset;
                }
                a = merge(a, b);
                assert(a);
                if (!x) x = a;
                else {
                    a.next = x.next;
                    x.next = a;
                }
            }
            n = null;
            while (x) {
                NP a = x; x = x.next;
                n = merge(a, n);
            }
        }
        void meld(PairingHeapAllAdd r) {
            length += r.length;
            n = merge(n, r.n);
        }
        // Reference to the root's (absolute) offset.
        ref D offset() {C; return n.offset; }
    }

    auto n = _g.length;
    assert(r < n);
    // Ids 0..n-1 are the original vertices; contracted super-vertices get
    // fresh ids n, n+1, ... (at most n-1 of them), hence the 2*n sizing.
    auto heap = new PairingHeapAllAdd[2*n];
    foreach (i; 0..n) {
        heap[i] = PairingHeapAllAdd(_g[i].map!(e => E(i.to!int, e)).array);
    }

    // Union-find over (super-)vertices: uf[i] is the parent, -1 for a representative.
    // tr[i] is the super-vertex i was contracted into (-1 if none); unlike uf it is
    // never path-compressed and is used to expand the contractions at the end.
    int[] tr = new int[2*n]; tr[] = -1;
    int[] uf = new int[2*n]; uf[] = -1;
    int root(int i) {
        if (uf[i] == -1) return i;
        return uf[i] = root(uf[i]);
    }

    // used[v]: 0 = not visited yet, 1 = the root r, otherwise the id `c` of the
    // mark() call that visited v.
    int[] used = new int[2*n];
    // res[v]: the cheapest edge chosen for (super-)vertex v.
    E[] res = new E[2*n];
    int c = 1, pc = n.to!int;
    used[r] = 1;
    // Walk from p along chosen edges until reaching a vertex finished by an
    // earlier call (or r), contracting every cycle met on the way.
    void mark(int p) {
        c++;
        while (used[p] == 0 || used[p] == c) {
            if (used[p] == c) {
                // p lies on the current path: the chosen edges form a cycle.
                // Contract the whole cycle into the new super-vertex np.
                int np = pc++;
                int q = p;
                do {
                    // Re-base heap[q] so each candidate costs its reduced cost minus
                    // that of res[q] (which is heap[q]'s root entry): the new root
                    // offset o' must satisfy res[q].edge.dist + o' == 0. Assigning
                    // it directly keeps the chosen edge exactly at 0 for floats.
                    heap[q].offset() = cast(D)(-res[q].edge.dist);
                    heap[np].meld(heap[q]);
                    tr[q] = uf[q] = np;
                    q = root(res[q].edge.to);
                } while (q != np);
                p = np;
            }
            assert(used[p] == 0);
            used[p] = c;

            assert(root(p) == p);
            // Lazily discard edges that now point back into p itself.
            while (heap[p].length && root(heap[p].front.edge.to) == p) {
                heap[p].removeFront;
            }
            assert(heap[p].length);
            E mi = heap[p].front;
            res[p] = mi;
            p = root(mi.edge.to);
        }
    }
    foreach (i; 0..n) {
        if (used[i]) continue;
        mark(i.to!int);
    }

    // Expand the contractions, newest super-vertex first. A (super-)vertex that
    // is not yet covered keeps its chosen edge; the edge's source vertex and all
    // super-vertices containing it (via tr) become covered.
    auto info = DirectedMSTInfo!(_E, D)(n);
    bool[] vis = new bool[pc];
    foreach_reverse (i; 0..pc) {
        if (i == r) continue;
        if (vis[i]) continue;
        int f = res[i].from;
        while (f != -1 && !vis[f]) {
            vis[f] = true;
            f = tr[f];
        }
        info.cost += res[i].edge.dist;
        info.res[res[i].from] = res[i].edge;
    }
    return info;
}

/**
Result of `directedMST` / `directedMSTSlow`.

`cost` is the total `dist` of the chosen edges, and `res[i]` is the edge
chosen for vertex `i`. Both functions in this module leave `res[root]` as
`E.init`.
*/
struct DirectedMSTInfo(E, C) {
    C cost;
    E[] res;
    this(size_t n) {
        cost = C(0);
        res = new E[n];
    }
}

/**
Reference implementation of `directedMST`. It uses recursive
Chu-Liu/Edmonds contraction on explicit adjacency lists, and the unittests
below use it as the oracle.

Params:
    g = adjacency list; `g[i]` is an array-like range of edges leaving `i`
        with `.to` and `.dist`.
    r = index of the root vertex; must satisfy `r < g.length`.

Returns:
    Same contract as `directedMST`.
*/
DirectedMSTInfo!(E, typeof(E.dist)) directedMSTSlow(T, E = EdgeType!T)(T g, size_t r) {
    import std.algorithm : filter;
    import std.math : abs;
    alias D = typeof(E.dist);
    // Tolerance when mapping a contracted edge back to an original one;
    // it truncates to 0 for integral distance types.
    immutable D EPS = cast(D)(1e-9);
    auto n = g.length;
    assert(r < n);
    auto info = DirectedMSTInfo!(E, D)(n);
    with (info) {
        // Greedy step: every non-root vertex takes its cheapest non-self-loop edge.
        foreach (i; 0..n) {
            if (i == r) continue;
            assert(g[i].filter!(e => e.to != i).empty == false);
            res[i] = g[i].filter!(e => e.to != i).minimum!"a.dist < b.dist";
            cost += res[i].dist;
        }
        // i2g[v]: group of v in the contracted graph; r is group 0 and every
        // cycle of chosen edges collapses into a single group.
        int[] i2g = new int[n]; i2g[] = -1;
        i2g[r] = 0;

        int gc = 1;
        for (int i = 0; i < n; i++) {
            if (i2g[i] != -1) continue;
            // Follow chosen edges from i, numbering fresh vertices, until hitting
            // an already-numbered one.
            int j = i;
            do {
                i2g[j] = gc++;
                j = res[j].to;
            } while (i2g[j] == -1);
            // Stopped at a vertex numbered before this walk: no new cycle.
            if (i2g[j] < i2g[i]) continue;
            // Stopped on the current walk: collapse the cycle through j into j's group.
            int k = j;
            do {
                i2g[k] = i2g[j];
                k = res[k].to;
            } while(k != j);
            gc = i2g[j]+1;
        }
        // No cycle: the greedy choice is already an arborescence.
        if (gc == n) return info;
        // Contracted graph: drop intra-group edges and reduce each edge by the
        // cost already paid for its source vertex.
        E[][] ng = new E[][](gc);
        for (int i = 0; i < n; i++) {
            if (i == r) continue;
            foreach (e; g[i]) {
                if (i2g[e.to] == i2g[i]) continue;
                e.to = i2g[e.to];
                e.dist = e.dist - res[i].dist;
                ng[i2g[i]] ~= e;
            }
        }
        auto nme = directedMSTSlow(ng, 0).res;
        // Expand: for each group, one member swaps its chosen edge for an
        // original edge that matches the contracted graph's choice.
        bool[] ok = new bool[gc];
        for (int i = 0; i < n; i++) {
            if (i == r || ok[i2g[i]]) continue;
            foreach (e; g[i]) {
                if (abs(e.dist - res[i].dist - nme[i2g[i]].dist) <= EPS && i2g[e.to] == nme[i2g[i]].to) {
                    ok[i2g[i]] = true;
                    res[i] = e;
                    cost += nme[i2g[i]].dist;
                    break;
                }
            }
        }
    }
    return info;
}

unittest {
    import std.typecons;
    alias E = Tuple!(int, "to", int, "dist");

    E[][] g = new E[][4];
    g[0] ~= E(1, 10);
    g[2] ~= E(1, 10);
    g[3] ~= E(1, 3);
    g[2] ~= E(3, 4);
    auto info = directedMSTSlow(g, 1);
    assert(info.cost == 17);
    // The heap-based implementation must agree.
    assert(directedMST(g, 1).cost == 17);
}

unittest {
    // A 2-cycle (1 <-> 2) that both implementations must contract and break.
    import std.typecons;
    alias E = Tuple!(int, "to", int, "dist");

    E[][] g = new E[][3];
    g[1] ~= E(2, 1);
    g[1] ~= E(0, 10);
    g[2] ~= E(1, 1);
    g[2] ~= E(0, 5);
    auto slow = directedMSTSlow(g, 0);
    auto fast = directedMST(g, 0);
    assert(slow.cost == 6 && fast.cost == 6);
    assert(slow.res[1] == E(2, 1) && slow.res[2] == E(0, 5));
    assert(fast.res[1] == E(2, 1) && fast.res[2] == E(0, 5));
}

unittest {
    import std.range, std.algorithm, std.typecons, std.random, std.conv;
    alias E = Tuple!(int, "to", int, "dist");
    auto gen = Random(114514);
    void test() {
        size_t n = uniform(1, 20, gen);
        size_t m = uniform(1, 100, gen);
        E[][] g = new E[][n];
        foreach (i; 0..m) {
            auto a = uniform(0, n, gen);
            auto b = uniform(0, n, gen);
            int c = uniform(0, 15, gen);
            g[a] ~= E(b.to!int, c);
            g[b] ~= E(a.to!int, c);
        }
        size_t r = uniform(0, n, gen);
        // Guarantee that every vertex can reach r.
        foreach (i; 0..n) {
            g[i] ~= E(r.to!int, 10^^6);
        }

        // Valid iff each non-root vertex uses one of its own edges, the chosen
        // edges form no cycle (checked with a union-find) and cost is their sum.
        bool check(I)(I info) {
            import dkh.datastructure.unionfind;
            auto uf = UnionFind(n.to!int);
            int sm = 0;
            foreach (i; 0..n) {
                if (i == r) continue;
                sm += info.res[i].dist;
                if (!g[i].count(info.res[i])) return false;
                if (uf.same(i, info.res[i].to)) return false;
                uf.merge(i, info.res[i].to);
            }
            if (sm != info.cost) return false;
            return true;
        }
        auto info1 = directedMSTSlow(g, r);
        auto info2 = directedMST(g, r);

        if (!check(info1)) {
            writeln("EEEEE");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(check(info1));
        if (info1.cost != info2.cost || !check(info2)) {
            writeln("FIND ERROR!");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(info1.cost == info2.cost);
        assert(check(info2));
    }
    import dkh.stopwatch;
    auto ti = benchmark!(test)(1000);
    writeln("DirectedMST int Random1000: ", ti[0].toMsecs);
}

unittest {
    import std.range, std.algorithm, std.typecons, std.random, std.conv, std.math;
    alias E = Tuple!(int, "to", double, "dist");
    auto gen = Random(114514);
    void test() {
        size_t n = uniform(1, 20, gen);
        size_t m = uniform(1, 100, gen);
        E[][] g = new E[][n];
        foreach (i; 0..m) {
            auto a = uniform(0, n, gen);
            auto b = uniform(0, n, gen);
            double c = uniform(0.0, 15.0, gen);
            g[a] ~= E(b.to!int, c);
            g[b] ~= E(a.to!int, c);
        }
        size_t r = uniform(0, n, gen);
        // Guarantee that every vertex can reach r.
        foreach (i; 0..n) {
            g[i] ~= E(r.to!int, 10^^6);
        }

        // Same validity check as the int test, with a tolerance on the cost.
        bool check(I)(I info) {
            import dkh.datastructure.unionfind;
            auto uf = UnionFind(n.to!int);
            double sm = 0;
            foreach (i; 0..n) {
                if (i == r) continue;
                sm += info.res[i].dist;
                if (!g[i].count(info.res[i])) return false;
                if (uf.same(i, info.res[i].to)) return false;
                uf.merge(i, info.res[i].to);
            }
            if (abs(sm - info.cost) > 1e-4) return false;
            return true;
        }
        auto info1 = directedMSTSlow(g, r);

        auto info2 = directedMST(g, r);

        if (!check(info1)) {
            writeln("EEEEE");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(check(info1));
        if (abs(info1.cost - info2.cost) > 1e-4 || !check(info2)) {
            writeln("FIND ERROR!");
            writeln(r);
            writeln(g.map!(to!string).join("\n"));
            writeln(info1);
            writeln(info2);
        }
        assert(abs(info1.cost - info2.cost) <= 1e-4);
        assert(check(info2));
    }
    import dkh.stopwatch;
    auto ti = benchmark!(test)(1000);
    writeln("DirectedMST double Random1000: ", ti[0].toMsecs);
}