1 module dkh.matrix;
2 
3 import dkh.bitop;
4 
/// Fixed-size dense matrix whose dimensions H x W are compile-time constants.
///
/// Elements are stored inline, row-major, in a static array, so the struct is
/// a value type: assignment and `dup` produce independent copies of the elements.
struct SMatrix(T, size_t H, size_t W) {
    /// Element type, exposed for generic algorithms such as `determinent`.
    alias DataType = T;
    /// Row-major storage: `data[row][column]`.
    T[W][H] data;

    /// Builds the matrix from exactly H*W values given in row-major order.
    this(Args...)(Args args) {
        static assert(args.length == H*W);
        foreach (i, v; args) {
            data[i/W][i%W] = v;
        }
    }

    /// Returns a copy; the elements live inline, so a plain copy suffices.
    SMatrix dup() const { return this; }

    @property static size_t height() {return H;}
    @property static size_t width() {return W;}

    /// Element access by (row, column).
    ref inout(T) opIndex(size_t i1, size_t i2) inout {
        return data[i1][i2];
    }

    /// Element-wise sum with a matrix of the same compile-time shape.
    auto opBinary(string op:"+", R)(in R r) const
    if(height == R.height && width == R.width) {
        SMatrix res = this;
        foreach (y; 0..H) foreach (x; 0..W) res[y, x] += r[y, x];
        return res;
    }

    /// Matrix product; the result is an `SMatrix!(T, H, R.width)`.
    auto opBinary(string op:"*", R)(in R r) const
    if(width == R.height) {
        // Transpose r once so each inner product reads both operands row-wise.
        auto rBuf = SMatrix!(T, R.width, R.height)();
        foreach (y; 0..R.height) {
            foreach (x; 0..R.width) {
                rBuf[x, y] = r[y, x];
            }
        }
        auto res = SMatrix!(T, H, R.width)();
        foreach (y; 0..H) {
            auto lv = this.data[y][];
            foreach (x; 0..R.width) {
                auto rv = rBuf.data[x][];
                // T.init is NaN for floating-point T, so the accumulator must
                // start from an explicit zero there; other types keep T.init.
                static if (__traits(isFloating, T)) T sm = 0;
                else T sm;
                foreach (k; 0..W) {
                    sm += lv[k] * rv[k];
                }
                res[y, x] = sm;
            }
        }
        return res;
    }

    /// Compound assignment (`+=`, `*=`), evaluated as `this = this op r`.
    auto opOpAssign(string op, R)(in R r) {return mixin ("this=this"~op~"r");}

    /// Swaps rows x and y in place.
    void swapLine(size_t x, size_t y) {
        import std.algorithm : swap;
        foreach (i; 0..W) swap(data[x][i], data[y][i]);
    }
}
61 
62 import dkh.foundation, dkh.modint;
/// Fixed-size H x W matrix over GF(2), with elements read and written as `ModInt!2`.
///
/// Each row is bit-packed into `L` 64-bit words: column `c` of a row lives in
/// bit `c % B` of word `c / B`.
struct SMatrixMod2(size_t H, size_t W) {
    alias DataType = ModInt!2;
    /// Bits per storage word.
    static immutable B = 64;
    /// Number of words needed to hold one row of W bits.
    static immutable L = (W + B-1) / B;
    /// data[row][word]: packed bits of that row.
    ulong[L][H] data;

    /// Builds the matrix from exactly H*W values given in row-major order.
    this(Args...)(Args args) {
        static assert(args.length == H*W);
        foreach (k, val; args) this[k / W, k % W] = val;
    }

    /// Returns an independent copy (the bit array is stored inline).
    SMatrixMod2 dup() const { return this; }

    @property static size_t height() {return H;}
    @property static size_t width() {return W;}

    /// Reads element (i1, i2) as a `ModInt!2` (0 or 1).
    const(DataType) opIndex(size_t i1, size_t i2) const {
        immutable bit = (data[i1][i2 / B] >> (i2 % B)) & 1UL;
        return DataType(bit != 0 ? 1 : 0);
    }

    // Sets (on == true) or clears the storage bit of element (i1, i2).
    private void setBit(size_t i1, size_t i2, bool on) {
        immutable ulong mask = 1UL << (i2 % B);
        if (on) data[i1][i2 / B] |= mask;
        else data[i1][i2 / B] &= ~mask;
    }

    /// Writes element (i1, i2); any nonzero `d.v` stores a 1.
    void opIndexAssign(DataType d, size_t i1, size_t i2) {
        setBit(i1, i2, d.v != 0);
    }
    /// Writes element (i1, i2) from a bool.
    void opIndexAssign(bool d, size_t i1, size_t i2) {
        setBit(i1, i2, d);
    }
    /// `m[i1, i2] op= d`, evaluated as `m[i1, i2] = m[i1, i2] op d`.
    auto opIndexOpAssign(string op)(DataType d, size_t i1, size_t i2) {
        return mixin("this[i1,i2]=this[i1,i2]"~op~"d");
    }

    /// Element-wise sum over GF(2): XOR of the packed words.
    auto opBinary(string op:"+", R)(in R r) const
    if(height == R.height && width == R.width) {
        auto res = this.dup;
        foreach (row; 0..H) {
            res.data[row][] ^= r.data[row][];
        }
        return res;
    }

    /// Matrix product over GF(2); the result is `SMatrixMod2!(H, R.width)`.
    auto opBinary(string op:"*", R)(in R r) const
    if(width == R.height) {
        // Pack the columns of r as rows so each dot product is a word-wise AND.
        auto rT = SMatrixMod2!(R.width, R.height)();
        foreach (i; 0..R.height) foreach (j; 0..R.width) {
            rT[j, i] = r[i, j];
        }
        auto res = SMatrixMod2!(H, R.width)();
        foreach (i; 0..H) {
            foreach (j; 0..R.width) {
                ulong acc = 0;
                foreach (w; 0..L) acc ^= data[i][w] & rT.data[j][w];
                // The GF(2) dot product is the parity of the shared set bits.
                res[i, j] = poppar(acc);
            }
        }
        return res;
    }

    /// Compound assignment, evaluated as `this = this op r`.
    auto opOpAssign(string op, T)(T r) {return mixin ("this=this"~op~"r");}

    /// Swaps rows x and y in place.
    void swapLine(size_t x, size_t y) {
        auto tmp = data[x];
        data[x] = data[y];
        data[y] = tmp;
    }
}
129 
/// Dense matrix whose dimensions h x w are chosen at run time.
///
/// Elements are stored row-major in a GC-allocated slice. Because `data` is a
/// slice, a plain struct copy shares storage; use `dup` for an independent copy.
struct DMatrix(T) {
    size_t h, w;
    T[] data;
    /// Allocates an h x w matrix with every element set to T.init.
    this(size_t h, size_t w) {
        this.h = h; this.w = w;
        data = new T[h*w];
    }
    /// Allocates an h x w matrix and copies `d` (row-major, length h*w) into it.
    this(size_t h, size_t w, in T[] d) {
        this(h, w);
        assert(d.length == h*w);
        data[] = d[];
    }
    /// Returns a copy with its own storage.
    DMatrix dup() const { return DMatrix(h, w, data); }

    @property size_t height() const {return h;}
    @property size_t width() const {return w;}

    /// Element access by (row, column).
    ref inout(T) opIndex(size_t i1, size_t i2) inout {
        return data[i1*width+i2];
    }
    /// Element-wise sum; both operands must have the same shape (asserted).
    auto opBinary(string op:"+", R)(in R r) const {
        assert(height == r.height && width == r.width);
        auto res = this.dup;
        foreach (y; 0..height) foreach (x; 0..width) res[y, x] += r[y, x];
        return res;
    }
    /// Matrix product; requires `width == r.height` and yields height x r.width.
    auto opBinary(string op:"*", R)(in R r) const {
        assert(width == r.height);
        // The operand's dimensions are runtime values, so they must be read from
        // r itself (R.width/R.height would refer to this matrix or not compile).
        // Store r transposed so each inner product walks both operands row-wise.
        auto rBuf = DMatrix!(T)(r.width, r.height);
        foreach (y; 0..r.height) {
            foreach (x; 0..r.width) {
                rBuf[x, y] = r[y, x];
            }
        }
        auto res = DMatrix!(T)(height, r.width);
        foreach (y; 0..height) {
            foreach (x; 0..r.width) {
                // T.init is NaN for floating-point T, so start from an explicit
                // zero there; other types keep T.init as before.
                static if (__traits(isFloating, T)) T sm = 0;
                else T sm;
                foreach (k; 0..width) {
                    sm += this[y, k]*rBuf[x, k];
                }
                res[y, x] = sm;
            }
        }
        return res;
    }
    /// Compound assignment (`+=`, `*=`), evaluated as `this = this op r`.
    auto opOpAssign(string op, R)(in R r) {return mixin ("this=this"~op~"r");}
}
177 
///
unittest {
    import dkh.numeric.primitive;
    auto step = DMatrix!int(2, 2, [0, 1, 1, 1]);
    auto unit = DMatrix!int(2, 2, [1, 0, 0, 1]);
    // Powers of [[0, 1], [1, 1]] generate Fibonacci numbers; `unit` is the 2x2
    // identity, presumably used by pow as the starting value.
    auto p = pow(step, 10, unit);
    assert(p[0, 0] == 34); //Fib_10
}
184 
unittest {
    // `+` must return the element-wise sum in a fresh matrix and leave both
    // operands untouched.
    auto mat1 = DMatrix!int(2, 2, [1, 1, 1, 1]);
    auto mat2 = DMatrix!int(2, 2, [2, 2, 2, 2]);
    auto mat3 = mat1 + mat2;
    assert(mat3.height == 2 && mat3.width == 2);
    assert(mat3.data == [3, 3, 3, 3]);
    assert(mat1.data == [1, 1, 1, 1]);
    assert(mat2.data == [2, 2, 2, 2]);
}
192 
/// Builds an H x W `SMatrix` whose element (y, x) is `pred(y, x)`.
///
/// The element type is whatever `pred` returns for `size_t` arguments, which
/// are the index types passed while filling the matrix.
auto matrix(size_t H, size_t W, alias pred)() {
    // Probe with size_t arguments, the same types the loops below pass.
    // Probing with int literals picks the wrong element type for index-dependent
    // lambdas such as `(i, j) => i + j`, and the assignment then fails to compile.
    alias E = typeof(pred(size_t.init, size_t.init));
    SMatrix!(E, H, W) res;
    foreach (y; 0..H) {
        foreach (x; 0..W) {
            res[y, x] = pred(y, x);
        }
    }
    return res;
}
/// Builds an H x W `SMatrixMod2` whose element (y, x) is `pred(y, x)`.
auto matrixMod2(size_t H, size_t W, alias pred)() {
    SMatrixMod2!(H, W) result;
    foreach (row; 0..H)
        foreach (col; 0..W)
            result[row, col] = pred(row, col);
    return result;
}
213 
/// Computes the determinant of a square matrix by Gaussian elimination.
///
/// `Mat` must provide `DataType`, `dup`, `height`, `width`, `swapLine` and
/// `m[i, j]` indexing with compound assignment. `DataType` must be
/// constructible from an int.
///
/// When a diagonal entry is exactly zero, the first row below it with a
/// nonzero entry in that column is swapped in; there is no pivoting by
/// magnitude. Rows are normalised with `M(1) / pivot`, so the element type is
/// expected to be a field with exact arithmetic (e.g. `ModInt` with a prime
/// modulus). TODO: confirm this for other element types.
///
/// The argument is not modified; elimination runs on a copy.
/// (The name keeps its historical spelling for compatibility.)
auto determinent(Mat)(in Mat _m) {
    auto m = _m.dup;
    assert(m.height == m.width);
    alias M = Mat.DataType;
    size_t N = m.height;
    M base = 1;
    foreach (i; 0..N) {
        if (m[i, i] == M(0)) {
            // Swap in a lower row with a nonzero entry in column i. A single
            // row swap flips the sign of the determinant, so swap exactly once.
            foreach (j; i+1..N) {
                if (m[j, i] != M(0)) {
                    m.swapLine(i, j);
                    base *= M(-1);
                    break;
                }
            }
            // No usable pivot in this column: the matrix is singular.
            if (m[i, i] == M(0)) return M(0);
        }
        base *= m[i, i];
        // Scale row i so its pivot becomes 1 (the factor is already in base).
        M im = M(1)/m[i, i];
        foreach (j; 0..N) {
            m[i, j] *= im;
        }
        // Clear column i below the pivot; this does not change the determinant.
        foreach (j; i+1..N) {
            M x = m[j, i];
            foreach (k; 0..N) {
                m[j, k] -= m[i, k] * x;
            }
        }
    }
    return base;
}
246 
// Randomised cross-check of `determinent` against the 3x3 Leibniz formula.
// Products over even permutations are added and products over odd
// permutations are subtracted. The check runs on SMatrix with ModInt!2, !3 and
// !11 and on the bit-packed SMatrixMod2, and prints the benchmark timings.
unittest {
    import std.random, std.stdio, std.algorithm;
    import dkh.modint;
    // One random trial for SMatrix!(ModInt!Mod, 3, 3).
    void f(uint Mod)() {
        alias Mint = ModInt!Mod;
        alias Mat = SMatrix!(Mint, 3, 3);
        alias Vec = SMatrix!(Mint, 3, 1);
        static Mint rndM() {
            return Mint(uniform(0, Mod));
        }
        Mat m = matrix!(3, 3, (i, j) => rndM())();
        Mint sm = 0;
        // [0, 1, 2] is even; nextEvenPermutation steps through the remaining
        // permutations of the same parity in lexicographic order.
        auto idx = [0, 1, 2];
        do {
            Mint buf = 1;
            foreach (i; 0..3) {
                buf *= m[i, idx[i]];
            }
            sm += buf;
        } while (idx.nextEvenPermutation);
        // [0, 2, 1] is the lexicographically smallest odd permutation.
        idx = [0, 2, 1];
        do {
            Mint buf = 1;
            foreach (i; 0..3) {
                buf *= m[i, idx[i]];
            }
            sm -= buf;
        } while (idx.nextEvenPermutation);
        auto _m = m.dup;
        auto u = m.determinent;
        assert(sm == m.determinent);
        // determinent must leave its argument unchanged.
        assert(_m == m);
    }
    // Same trial for the bit-packed GF(2) matrix type.
    void fMod2() {
        alias Mint = ModInt!2;
        alias Mat = SMatrixMod2!(3, 3);
        alias Vec = SMatrixMod2!(3, 1);
        static Mint rndM() {
            return Mint(uniform(0, 2));
        }
        Mat m = matrixMod2!(3, 3, (i, j) => rndM())();
        Mint sm = 0;
        auto idx = [0, 1, 2];
        do {
            Mint buf = 1;
            foreach (i; 0..3) {
                buf *= m[i, idx[i]];
            }
            sm += buf;
        } while (idx.nextEvenPermutation);
        idx = [0, 2, 1];
        do {
            Mint buf = 1;
            foreach (i; 0..3) {
                buf *= m[i, idx[i]];
            }
            sm -= buf;
        } while (idx.nextEvenPermutation);
        auto _m = m.dup;
        auto u = m.determinent;
        // On mismatch, dump both values and the matrix before the assert fires.
        if (sm != m.determinent) {
            writeln(sm, " ", m.determinent);
            foreach (i; 0..3) {
                foreach (j; 0..3) {
                    write(m[i, j], " ");
                }
                writeln;
            }
            writeln(m);
        }
        assert(sm == m.determinent);
        assert(_m == m);
    }    
    import dkh.stopwatch;
    // Runs each trial 10000 times via dkh.stopwatch's benchmark (presumably one
    // duration per function) and prints the timings in milliseconds.
    writeln("Det: ", benchmark!(f!2, f!3, f!11, fMod2)(10000)[].map!(a => a.toMsecs));
}
323 
324 
/// Solves the linear system `m * v = r` for the column vector `v`.
///
/// Runs Gauss-Jordan elimination on `m` and `r`, which are taken by value, so
/// the caller's matrices are untouched. It then back-substitutes from the
/// pivot rows.
///
/// Free variables keep `Vec`'s default-initialised value (presumably 0 for
/// ModInt). Elements must expose `.v` for the nonzero test and support field
/// division, so the function is written for ModInt-style element types.
///
/// An inconsistent system is not detected: the returned vector then does not
/// satisfy `m * v == r`.
Vec solveLinear(Mat, Vec)(Mat m, Vec r) {
    immutable size_t rows = m.height, cols = m.width;
    size_t rank = 0;
    foreach (col; 0..cols) {
        // Look for a row at or below `rank` with a nonzero entry in this column.
        size_t pivot = rows;
        foreach (y; rank..rows) {
            if (m[y, col].v) {
                pivot = y;
                break;
            }
        }
        if (pivot == rows) continue; // no pivot: this column is free
        m.swapLine(rank, pivot);
        r.swapLine(rank, pivot);
        // Eliminate the column from every other row (reduced row echelon form).
        foreach (y; 0..rows) {
            if (y == rank || m[y, col].v == 0) continue;
            auto factor = m[y, col] / m[rank, col];
            foreach (k; 0..cols) m[y, k] -= factor * m[rank, k];
            r[y, 0] -= factor * r[rank, 0];
        }
        if (++rank == rows) break;
    }
    // Back-substitute from the last pivot row upward. Each row's first nonzero
    // column is its pivot column.
    Vec v;
    foreach_reverse (y; 0..rank) {
        ptrdiff_t lead = -1;
        Mat.DataType acc;
        foreach (x; 0..cols) {
            if (lead == -1 && m[y, x].v) lead = x;
            acc += m[y, x] * v[x, 0];
        }
        v[lead, 0] += (r[y, 0] - acc) / m[y, lead];
    }
    return v;
}
367 
unittest {
    // Round trip over a large prime modulus: build rhs = coeffs * expected from
    // random data, solve for the unknowns, and check that they reproduce rhs.
    import std.random, std.stdio;
    import dkh.modint;
    enum P = 10^^9 + 7;
    alias Mint = ModInt!P;
    static Mint randomElement() {
        return Mint(uniform(0, P));
    }
    auto coeffs = matrix!(3, 3, (i, j) => randomElement())();
    auto expected = matrix!(3, 1, (i, j) => randomElement())();
    auto rhs = coeffs * expected;
    auto solved = solveLinear(coeffs, rhs);
    assert(coeffs * solved == rhs);
}
383 
// Randomised check of solveLinear on SMatrix with ModInt!2, !3 and !11 and on
// the bit-packed SMatrixMod2, with benchmark timings printed.
// With such small moduli m is often singular. Because r is built as m * x,
// the system is always consistent, so solveLinear must still return some x2
// with m * x2 == r, though not necessarily x itself.
unittest {
    import std.random, std.stdio, std.algorithm;
    import dkh.modint;
    // One random trial for SMatrix!(ModInt!Mod, 3, 3).
    void f(uint Mod)() {
        alias Mint = ModInt!Mod;
        alias Mat = SMatrix!(Mint, 3, 3);
        alias Vec = SMatrix!(Mint, 3, 1);
        static Mint rndM() {
            return Mint(uniform(0, Mod));
        }
        Mat m = matrix!(3, 3, (i, j) => rndM())();
        Vec x = matrix!(3, 1, (i, j) => rndM())();
        Vec r = m * x;
        Mat _m = m.dup;
        Vec x2 = solveLinear(m, r);
        // solveLinear takes m by value, so the caller's matrix must be unchanged.
        assert(m == _m);
        assert(m * x2 == r);
    }
    // Same trial for the bit-packed GF(2) matrix type.
    void fMod2() {
        alias Mint = ModInt!2;
        alias Mat = SMatrixMod2!(3, 3);
        alias Vec = SMatrixMod2!(3, 1);
        static Mint rndM() {
            return Mint(uniform(0, 2));
        }
        Mat m = matrixMod2!(3, 3, (i, j) => rndM())();
        Vec x = matrixMod2!(3, 1, (i, j) => rndM())();
        Vec r = m * x;
        Mat _m = m.dup;
        Vec x2 = solveLinear(m, r);
        assert(m == _m);
        assert(m * x2 == r);
    }
    import dkh.stopwatch;
    // Runs each trial 10000 times via dkh.stopwatch's benchmark (presumably one
    // duration per function) and prints the timings in milliseconds.
    writeln("SolveLinear: ", benchmark!(f!2, f!3, f!11, fMod2)(10000)[].map!(a => a.toMsecs));
}