module dkh.datastructure.unionfind;

/// UnionFind (Disjoint Set Union)
///
/// Every group keeps an explicit list of its members. `merge` relabels the
/// members of the smaller group into the larger one (union by size), so any
/// single element is relabelled at most O(log n) times over all merges.
struct UnionFind {
    private int[] id; /// id[i] = group id of element i; equals the index of that group's leader
    private int[][] groups; /// groups[g] = members of group g (emptied once g is absorbed by a merge)
    size_t count; /// number of groups currently present
    /**
    Params:
        n = number of elements; they are labelled 0 .. n-1 and each starts in its own group
    */
    this(size_t n) {
        import std.algorithm : map;
        import std.range : iota, array;
        import std.conv : to;
        int _n = n.to!int; // ids are stored as int; to!int rejects n > int.max
        id = _n.iota.array;
        groups = _n.iota.map!"[a]".array;
        count = n;
    }
    /// Merge the groups containing a and b (no-op when they are already in the same group).
    void merge(size_t a, size_t b) {
        import std.algorithm : swap, each;
        if (same(a, b)) return;
        count--;
        // Keep group ids as int, matching the element type of `id`.
        int x = id[a], y = id[b];
        // Make y the smaller group so only the smaller side gets relabelled.
        if (groups[x].length < groups[y].length) swap(x, y);
        groups[y].each!(e => id[e] = x);
        groups[x] ~= groups[y];
        groups[y] = [];
    }
    /// Returns: the members of the group containing i (order depends on merge history)
    const(int[]) group(size_t i) const {
        return groups[id[i]];
    }
    /**
    Returns whether i is the leader of its group.
    Each group has exactly one leader: the element whose index equals the group id.
    */
    bool isLeader(size_t i) const {
        return i == id[i];
    }
    /// Returns: whether a and b belong to the same group
    bool same(size_t a, size_t b) const {
        return id[a] == id[b];
    }
}

///
unittest {
    import std.algorithm : equal, sort;
    auto uf = UnionFind(5);
    assert(!uf.same(1, 3));
    assert(uf.same(0, 0));

    uf.merge(3, 2);
    uf.merge(1, 1);
    uf.merge(4, 2);
    uf.merge(4, 3);

    assert(uf.count == 3);
    assert(uf.id[2] == uf.id[3]);
    assert(uf.id[2] == uf.id[4]);
    assert(equal(uf.group(0), [0]));
    assert(equal(uf.group(1), [1]));
    assert(equal(sort(uf.group(2).dup), [2, 3, 4]));

    auto cnt = 0;
    foreach (i; 0..5) {
        if (!uf.isLeader(i)) continue;
        cnt += uf.group(i).length;
    }
    assert(cnt == 5); // view all element exactly once
}

unittest {
    import std.stdio, std.range;
    import dkh.stopwatch;
    // speed check
    StopWatch sw; sw.start;
    UnionFind uf;
    // line
    uf = UnionFind(100_000);
    foreach (i; 1..100_000) {
        uf.merge(i-1, i);
    }
    // line(reverse)
    uf = UnionFind(100_000);
    foreach_reverse (i; 1..100_000) {
        uf.merge(i-1, i);
    }
    // binary tree
    uf = UnionFind(100_000);
    foreach (lg; 1..17) {
        int len = 1<<lg;
        foreach (i; iota(0, 100_000-len/2, len)) {
            uf.merge(i, i+len/2);
        }
    }
    writeln("UnionFind Speed Test: ", sw.peek.toMsecs());
}