1 module dkh.datastructure.unionfind;
2 
/**
UnionFind (Disjoint Set Union).

Every element stores the id of the group it belongs to, and every group
keeps an explicit list of its members. `same` is therefore a single
comparison and `group` returns the member list directly. `merge` moves
the members of the smaller group into the larger one.
 */
struct UnionFind {
    private int[] id; /// group id of each element
    private int[][] groups; /// members of each group, indexed by group id (empty once absorbed)
    size_t count; /// number of groups
    /**
    Creates `n` singleton groups `{0}, {1}, ..., {n-1}`.

    Params:
        n = number of elements

    Throws: `std.conv.ConvOverflowException` if `n` does not fit in an `int`.
     */
    this(size_t n) {
        import std.algorithm : map;
        import std.range : iota, array;
        import std.conv : to;
        int _n = n.to!int;
        id = _n.iota.array;
        groups = _n.iota.map!"[a]".array;
        count = n;
    }
    /**
    Merges the groups containing `a` and `b`. Does nothing if they are
    already in the same group.

    The members of the smaller group are relabelled with the larger group's
    id. An element is therefore relabelled only when the size of its group
    at least doubles, i.e. at most O(log n) times over all merges.
     */
    void merge(size_t a, size_t b) {
        import std.algorithm : swap;
        if (same(a, b)) return;
        count--;
        // Same element type as `id`, so no signed/unsigned round-trip happens.
        int x = id[a], y = id[b];
        if (groups[x].length < groups[y].length) swap(x, y);
        // x is now the surviving (larger or equal) group; absorb y into it.
        foreach (e; groups[y]) id[e] = x;
        groups[x] ~= groups[y];
        groups[y] = [];
    }
    /// Returns the members of the group containing `i` (not necessarily sorted).
    const(int[]) group(size_t i) const {
        return groups[id[i]];
    }
    /**
    Returns whether `i` is the leader of its group, i.e. whether `i` equals
    its own group id. Each group has exactly one leader.
     */
    bool isLeader(size_t i) const {
        return i == id[i];
    }
    /// Returns whether `a` and `b` belong to the same group.
    bool same(size_t a, size_t b) const {
        return id[a] == id[b];
    }
}
48 
///
unittest {
    import std.algorithm : equal, sort;

    // Start with five singleton groups.
    auto dsu = UnionFind(5);
    assert(dsu.same(0, 0));
    assert(!dsu.same(1, 3));

    dsu.merge(3, 2);
    dsu.merge(1, 1); // self-merge: no effect
    dsu.merge(4, 2);
    dsu.merge(4, 3); // already joined: no effect

    // Groups are now {0}, {1}, {2, 3, 4}.
    assert(dsu.count == 3);
    assert(dsu.id[3] == dsu.id[2]);
    assert(dsu.id[4] == dsu.id[2]);
    assert(dsu.group(0).equal([0]));
    assert(dsu.group(1).equal([1]));
    auto members = dsu.group(2).dup;
    sort(members);
    assert(members.equal([2, 3, 4]));

    // Visiting only the leaders touches every element exactly once.
    size_t visited = 0;
    foreach (v; 0 .. 5) {
        if (dsu.isLeader(v))
            visited += dsu.group(v).length;
    }
    assert(visited == 5);
}
75 
unittest {
    import std.stdio, std.range;
    import dkh.stopwatch;
    // Rough timing check: run three merge patterns over N elements and
    // print the elapsed time.
    enum N = 100_000;
    StopWatch sw; sw.start;

    // Chain, merged left to right.
    auto uf = UnionFind(N);
    foreach (i; 1 .. N) uf.merge(i - 1, i);

    // Chain, merged right to left.
    uf = UnionFind(N);
    foreach_reverse (i; 1 .. N) uf.merge(i - 1, i);

    // Balanced binary tree: at level lg, join blocks of width step/2 pairwise.
    uf = UnionFind(N);
    foreach (lg; 1 .. 17) {
        immutable int step = 1 << lg;
        immutable int half = step / 2;
        foreach (i; iota(0, N - half, step)) uf.merge(i, i + half);
    }

    writeln("UnionFind Speed Test: ", sw.peek.toMsecs());
}