1 module dkh.matrix;
2 
3 import dkh.bitop;
4 
/// Matrix library: a fixed-size H x W matrix of T stored row-major in a
/// static array, so the struct has value semantics (assignment copies it).
struct SMatrix(T, size_t H, size_t W) {
    alias DataType = T;
    T[W][H] data;

    /// Builds the matrix from exactly H*W values given in row-major order.
    this(Args...)(Args args) {
        static assert(args.length == H*W);
        foreach (i, v; args) {
            data[i/W][i%W] = v;
        }
    }

    /// Returns a copy (a plain struct copy, since the storage is a static array).
    SMatrix dup() const { return this; }

    @property static size_t height() {return H;}
    @property static size_t width() {return W;}

    /// Reference to the entry at row i1, column i2.
    ref inout(T) opIndex(size_t i1, size_t i2) inout {
        return data[i1][i2];
    }

    /// Element-wise sum with a matrix of the same shape.
    auto opBinary(string op:"+", R)(in R r) const
    if(height == R.height && width == R.width) {
        SMatrix res = this;
        foreach (y; 0..height) foreach (x; 0..width) res[y, x] += r[y, x];
        return res;
    }

    /// Matrix product; the result is a height x R.width SMatrix of T.
    auto opBinary(string op:"*", R)(in R r) const
    if(width == R.height) {
        // Transpose r so the inner loop walks both operands along rows.
        auto rBuf = SMatrix!(T, R.width, R.height)();
        foreach (y; 0..R.height) {
            foreach (x; 0..R.width) {
                rBuf[x, y] = r[y, x];
            }
        }
        auto res = SMatrix!(T, height, R.width)();
        foreach (y; 0..height) {
            foreach (x; 0..R.width) {
                // T.init is NaN for floating-point T, so the accumulator has
                // to be zeroed explicitly there; other types keep T.init.
                static if (__traits(isFloating, T)) {
                    T sm = 0;
                } else {
                    T sm;
                }
                auto lv = this.data[y][];
                auto rv = rBuf.data[x][];
                foreach (k; 0..width) {
                    sm += lv[k] * rv[k];
                }
                res[y, x] = sm;
            }
        }
        return res;
    }

    /// `a op= b` is evaluated as `a = a op b`.
    auto opOpAssign(string op, R)(in R r) {return mixin ("this=this"~op~"r");}

    /// Swaps rows x and y in place.
    void swapLine(size_t x, size_t y) {
        import std.algorithm : swap;
        foreach (i; 0..W) swap(data[x][i], data[y][i]);
    }
}
58 
59 import dkh.foundation, dkh.modint;
/// Matrix library (mod 2): a fixed-size H x W matrix over ModInt!2, bit-packed
/// so that entry (i, j) is bit `j % B` of the word `data[i][j / B]`.
/// Addition is a word-wise XOR; multiplication ANDs packed rows.
struct SMatrixMod2(size_t H, size_t W) {
    alias DataType = ModInt!2;
    static immutable B = 64;             // bits per storage word
    static immutable L = (W + B-1) / B;  // storage words per row
    ulong[L][H] data;

    /// Builds the matrix from exactly H*W values given in row-major order.
    this(Args...)(Args args) {
        static assert(args.length == H*W);
        foreach (i, v; args) {
            this[i/W, i%W] = v;
        }
    }

    /// Returns a copy (a plain struct copy, since the storage is a static array).
    SMatrixMod2 dup() const { return this; }

    @property static size_t height() {return H;}
    @property static size_t width() {return W;}

    /// Entry (i1, i2) as a ModInt!2 value. It is returned by value because a
    /// single bit cannot be referenced.
    const(DataType) opIndex(size_t i1, size_t i2) const {
        assert(i1 < H && i2 < W);
        return DataType(((data[i1][i2/B] >> (i2%B)) & 1UL) ? 1 : 0);
    }
    /// Sets entry (i1, i2) to `d`.
    void opIndexAssign(DataType d, size_t i1, size_t i2) {
        this[i1, i2] = (d.v != 0);
    }
    /// Sets entry (i1, i2) to 1 when `d` is true, to 0 otherwise.
    void opIndexAssign(bool d, size_t i1, size_t i2) {
        assert(i1 < H && i2 < W);
        immutable ulong mask = 1UL << (i2 % B);
        if (d) data[i1][i2/B] |= mask;
        else data[i1][i2/B] &= ~mask;
    }
    /// `m[i1, i2] op= d`, done as a read-modify-write of the single bit.
    auto opIndexOpAssign(string op)(DataType d, size_t i1, size_t i2) {
        return mixin("this[i1,i2]=this[i1,i2]"~op~"d");
    }
    /// Element-wise sum; over GF(2) this is XOR, applied a whole word at a time.
    auto opBinary(string op:"+", R)(in R r) const
    if(height == R.height && width == R.width) {
        auto res = this.dup;
        foreach (y; 0..height) foreach (x; 0..L) {
            res.data[y][x] ^= r.data[y][x];
        }
        return res;
    }
    /// Matrix product; the result is a height x R.width SMatrixMod2.
    auto opBinary(string op:"*", R)(in R r) const
    if(width == R.height) {
        // Transpose r so that column x of r becomes packed row x of r2.
        auto r2 = SMatrixMod2!(R.width, R.height)();
        foreach (y; 0..R.height) foreach (x; 0..R.width) {
            r2[x, y] = r[y, x];
        }
        auto res = SMatrixMod2!(height, R.width)();
        foreach (y; 0..height) {
            foreach (x; 0..R.width) {
                // The GF(2) dot product is the parity of the bits set in
                // (row AND column); poppar (dkh.bitop) is assumed to return
                // that parity — TODO confirm.
                ulong sm = 0;
                foreach (k; 0..L) {
                    sm ^= data[y][k]&r2.data[x][k];
                }
                res[y, x] = poppar(sm);
            }
        }
        return res;
    }
    /// `a op= b` is evaluated as `a = a op b`.
    auto opOpAssign(string op, R)(R r) {return mixin ("this=this"~op~"r");}

    /// Swaps rows x and y in place, one storage word at a time.
    void swapLine(size_t x, size_t y) {
        import std.algorithm : swap;
        foreach (i; 0..L) swap(data[x][i], data[y][i]);
    }
}
129 
/// Matrix library: a matrix of T whose size is chosen at run time, stored
/// row-major in the slice `data`. Copying the struct shares that slice;
/// use `dup` for an independent copy.
struct DMatrix(T) {
    alias DataType = T;
    size_t h, w;
    T[] data;
    /// Allocates an h x w matrix; every entry starts as T.init.
    this(size_t h, size_t w) {
        this.h = h; this.w = w;
        data = new T[h*w];
    }
    /// Allocates an h x w matrix and copies `d` (row-major, length h*w) into it.
    this(size_t h, size_t w, in T[] d) {
        this(h, w);
        assert(d.length == h*w);
        data[] = d[];
    }
    /// Returns a copy with its own storage.
    DMatrix dup() const { return DMatrix(h, w, data); }

    @property size_t height() const {return h;}
    @property size_t width() const {return w;}

    /// Reference to the entry at row i1, column i2.
    ref inout(T) opIndex(size_t i1, size_t i2) inout {
        return data[i1*width+i2];
    }
    /// Element-wise sum with a matrix of the same shape (checked by assert).
    auto opBinary(string op:"+", R)(in R r) const {
        assert(height == r.height && width == r.width);
        auto res = this.dup;
        foreach (y; 0..height) foreach (x; 0..width) res[y, x] += r[y, x];
        return res;
    }
    /// Matrix product; the result is a height x r.width DMatrix of T.
    auto opBinary(string op:"*", R)(in R r) const {
        assert(width == r.height);
        // Transpose r so the inner loop reads both operands along rows.
        // The sizes must come from the instance r: `R.width` inside a member
        // of the same type silently means `this.width`.
        auto rBuf = DMatrix!(T)(r.width, r.height);
        foreach (y; 0..r.height) {
            foreach (x; 0..r.width) {
                rBuf[x, y] = r[y, x];
            }
        }
        auto res = DMatrix!(T)(height, r.width);
        foreach (y; 0..height) {
            foreach (x; 0..r.width) {
                // T.init is NaN for floating-point T, so zero the accumulator
                // explicitly there; other types keep starting from T.init.
                static if (__traits(isFloating, T)) {
                    T sm = 0;
                } else {
                    T sm;
                }
                foreach (k; 0..width) {
                    sm += this[y, k]*rBuf[x, k];
                }
                res[y, x] = sm;
            }
        }
        return res;
    }
    /// `a op= b` is evaluated as `a = a op b`.
    auto opOpAssign(string op, R)(in R r) {return mixin ("this=this"~op~"r");}
    /// Swaps rows x and y in place.
    void swapLine(size_t x, size_t y) {
        import std.algorithm : swap;
        foreach (i; 0..w) swap(this[x, i], this[y, i]);
    }
}
182 
/// Fibonacci numbers by fast matrix exponentiation. With F(0) = 0, F(1) = 1:
/// [[0, 1], [1, 1]]^n == [[F(n-1), F(n)], [F(n), F(n+1)]],
/// so the top-left entry of the 10th power is F(9) = 34.
/// `pow` (dkh.numeric.primitive) is given the identity matrix as its unit.
unittest {
    import dkh.numeric.primitive;
    auto step = DMatrix!int(2, 2, [0, 1, 1, 1]);
    auto unit = DMatrix!int(2, 2, [1, 0, 0, 1]);
    auto tenth = pow(step, 10, unit);
    assert(tenth[0, 0] == 34);
}
189 
// DMatrix addition: the sum is element-wise, and both operands are left
// unchanged (opBinary!"+" works on a dup, not on the shared slice).
unittest {
    auto mat1 = DMatrix!int(2, 2, [1, 1, 1, 1]);
    auto mat2 = DMatrix!int(2, 2, [2, 2, 2, 2]);
    auto mat3 = mat1 + mat2;
    assert(mat3.height == 2 && mat3.width == 2);
    foreach (y; 0 .. 2) foreach (x; 0 .. 2) assert(mat3[y, x] == 3);
    assert(mat1.data == [1, 1, 1, 1]);
    assert(mat2.data == [2, 2, 2, 2]);
}
197 
/// Builds an H x W SMatrix whose (y, x) entry is `pred(y, x)`; the element
/// type is `typeof(pred(0, 0))`. `pred` is called once per entry, row by row.
auto matrix(size_t H, size_t W, alias pred)() {
    alias Elem = typeof(pred(0, 0));
    SMatrix!(Elem, H, W) built;
    for (size_t row = 0; row < H; ++row)
        for (size_t col = 0; col < W; ++col)
            built[row, col] = pred(row, col);
    return built;
}
208 
/// Builds an H x W DMatrix (size given at run time) whose (y, x) entry is
/// `pred(y, x)`; the element type is `typeof(pred(0, 0))`. `pred` is called
/// once per entry, row by row.
auto matrix(alias pred)(size_t H, size_t W) {
    alias Elem = typeof(pred(0, 0));
    auto built = DMatrix!Elem(H, W);
    for (size_t row = 0; row < H; ++row)
        for (size_t col = 0; col < W; ++col)
            built[row, col] = pred(row, col);
    return built;
}
219 
/// Builds an H x W SMatrixMod2 whose (y, x) entry is `pred(y, x)`.
/// `pred` is called once per entry, row by row.
auto matrixMod2(size_t H, size_t W, alias pred)() {
    SMatrixMod2!(H, W) built;
    for (size_t row = 0; row < H; ++row)
        for (size_t col = 0; col < W; ++col)
            built[row, col] = pred(row, col);
    return built;
}
230 
/// Determinant of the square matrix `_m`, by Gaussian elimination on a `dup`
/// of it (the argument itself is not modified).
/// Every pivot row is scaled by the inverse of its pivot, so the element type
/// must divide exactly (e.g. ModInt with a prime modulus); with `int` entries
/// `/` truncates and the result is wrong.
/// Note: the name misspells "determinant" and is kept for existing callers.
auto determinent(Mat)(in Mat _m) {
    alias M = Mat.DataType;
    auto m = _m.dup;
    assert(m.height == m.width);
    size_t n = m.height;
    M det = 1;
    for (size_t col = 0; col < n; ++col) {
        if (m[col, col] == M(0)) {
            // Look below the diagonal for a row with a nonzero entry.
            size_t pivot = col + 1;
            while (pivot < n && m[pivot, col] == M(0)) ++pivot;
            if (pivot == n) return M(0);  // the whole column is zero: singular
            m.swapLine(col, pivot);
            det *= M(-1);  // a row swap flips the sign
        }
        det *= m[col, col];
        M inv = M(1) / m[col, col];
        foreach (k; 0 .. n) m[col, k] *= inv;
        // Clear the entries below the (now unit) pivot.
        foreach (row; col + 1 .. n) {
            M factor = m[row, col];
            foreach (k; 0 .. n) m[row, k] -= m[col, k] * factor;
        }
    }
    return det;
}
263 
264 
/// Rank of `_m` by forward Gaussian elimination on a `dup` of it (the
/// argument itself is not modified).
/// Row multipliers are computed with `/`, so the element type must divide
/// exactly (e.g. ModInt with a prime modulus); with `int` entries the result
/// is unreliable because `/` truncates.
size_t rank(Mat)(in Mat _m) {
    alias M = Mat.DataType;
    auto m = _m.dup;
    immutable size_t h = m.height, w = m.width;
    size_t rnk;
    foreach (c; 0..w) {
        // Find a not-yet-used row with a nonzero entry in column c.
        size_t mr = h;
        foreach (r; rnk..h) {
            if (m[r, c] != M(0)) {
                mr = r;
                break;
            }
        }
        if (mr == h) continue;  // no pivot in this column
        m.swapLine(rnk, mr);
        // Clear column c below the pivot row.
        foreach (r; rnk + 1..h) {
            if (m[r, c] == M(0)) continue;
            auto freq = m[r, c] / m[rnk, c];
            foreach (i; c..w) m[r, i] -= freq * m[rnk, i];
        }
        rnk++;
        if (rnk == h) break;  // every row already holds a pivot
    }
    return rnk;
}
290 
// determinent vs. the Leibniz formula on random 3x3 matrices over several
// moduli and for all three matrix types, plus a timing printout.
unittest {
    import std.random, std.stdio, std.algorithm;
    import dkh.modint;

    // Reference determinant of a 3x3 matrix via the Leibniz formula: products
    // over even permutations are added, over odd permutations subtracted.
    // nextEvenPermutation started from [0, 1, 2] walks the even permutations;
    // started from [0, 2, 1] it walks the odd ones.
    auto leibniz3(Mat)(Mat m) {
        alias Mint = Mat.DataType;
        Mint sm = 0;
        auto idx = [0, 1, 2];
        do {
            Mint buf = 1;
            foreach (i; 0..3) buf *= m[i, idx[i]];
            sm += buf;
        } while (idx.nextEvenPermutation);
        idx = [0, 2, 1];
        do {
            Mint buf = 1;
            foreach (i; 0..3) buf *= m[i, idx[i]];
            sm -= buf;
        } while (idx.nextEvenPermutation);
        return sm;
    }

    // Checks determinent(m) against the reference and that determinent leaves
    // its argument unchanged; dumps the matrix before failing on a mismatch.
    void check(Mat)(Mat m) {
        auto expected = leibniz3(m);
        auto backup = m.dup;
        auto actual = m.determinent;
        if (expected != actual) {
            writeln(expected, " ", actual);
            foreach (i; 0..3) {
                foreach (j; 0..3) {
                    write(m[i, j], " ");
                }
                writeln;
            }
            writeln(m);
        }
        assert(expected == actual);
        assert(backup == m);
    }

    void f(uint Mod)() {
        alias Mint = ModInt!Mod;
        static Mint rndM() {
            return Mint(uniform(0, Mod));
        }
        check(matrix!(3, 3, (i, j) => rndM())());
    }
    void fD(uint Mod)() {
        alias Mint = ModInt!Mod;
        static Mint rndM() {
            return Mint(uniform(0, Mod));
        }
        check(matrix!((i, j) => rndM())(3, 3));
    }
    void fMod2() {
        alias Mint = ModInt!2;
        static Mint rndM() {
            return Mint(uniform(0, 2));
        }
        check(matrixMod2!(3, 3, (i, j) => rndM())());
    }
    import dkh.stopwatch;
    writeln("Det: ", benchmark!(f!2, f!3, f!11, fD!2, fD!3, fD!11, fMod2)(10000)[].map!(a => a.toMsecs));
}
397 
398 
/// Solves the linear system `m * v = r` (r is a column vector) and returns
/// one solution `v`, using Gauss-Jordan elimination.
///
/// The elimination runs on `dup`s of `m` and `r`, so the caller's matrices
/// are never modified. Without the copies, slice-backed matrices such as
/// DMatrix would be modified, because their by-value copies share storage.
/// Entries are tested for zero through their `.v` field, so the element type
/// is expected to be ModInt-like, and `/` must be exact (prime modulus).
/// Entries of `v` for columns without a pivot keep their initial value.
/// An inconsistent system is not detected; the returned `v` is then simply
/// not a solution.
/// NOTE(review): for compile-time-sized matrices `Vec` is r's type, so `v`
/// has m.height rows; this only works when m.width <= m.height.
Vec solveLinear(Mat, Vec)(Mat m, Vec r) {
    m = m.dup;
    r = r.dup;
    size_t N = m.height, M = m.width;
    size_t c = 0;  // number of pivot rows found so far
    foreach (x; 0..M) {
        // Find a pivot for column x among the rows not used yet.
        ptrdiff_t my = -1;
        foreach (y; c..N) {
            if (m[y, x].v) {
                my = y;
                break;
            }
        }
        if (my == -1) continue;
        m.swapLine(c, my);
        r.swapLine(c, my);
        // Clear column x in every other row, above and below the pivot.
        foreach (y; 0..N) {
            if (c == y) continue;
            if (m[y, x].v == 0) continue;
            auto freq = m[y, x] / m[c, x];
            foreach (k; 0..M) {
                m[y, k] -= freq * m[c, k];
            }
            r[y, 0] -= freq * r[c, 0];
        }
        c++;
        if (c == N) break;
    }
    // `static if` introduces no scope, so `v` is visible below either way.
    static if (__traits(compiles, { enum sh = Vec.height; })) {
        // Compile-time-sized vector: entries start at DataType.init.
        Vec v;
    } else {
        // Run-time-sized vector (e.g. DMatrix): allocate M x 1 explicitly;
        // a default-initialized one would have no storage at all.
        auto v = Vec(M, 1);
    }
    // Back-substitution, last pivot row first.
    foreach_reverse (y; 0..c) {
        ptrdiff_t f = -1;  // pivot column of row y
        Mat.DataType sm;
        foreach (x; 0..M) {
            if (m[y, x].v && f == -1) {
                f = x;
            }
            sm += m[y, x] * v[x, 0];
        }
        v[f, 0] += (r[y, 0] - sm) / m[y, f];
    }
    return v;
}
441 
// solveLinear on a random 3x3 system over ModInt!(10^^9 + 7): the returned
// vector must reproduce the right-hand side.
unittest {
    import std.random, std.stdio;
    import dkh.modint;
    enum P = 10^^9 + 7;
    alias Mint = ModInt!P;
    static Mint randomElem() {
        return Mint(uniform(0, P));
    }
    SMatrix!(Mint, 3, 3) coeffs = matrix!(3, 3, (i, j) => randomElem())();
    SMatrix!(Mint, 3, 1) hidden = matrix!(3, 1, (i, j) => randomElem())();
    SMatrix!(Mint, 3, 1) rhs = coeffs * hidden;
    SMatrix!(Mint, 3, 1) found = solveLinear(coeffs, rhs);
    assert(coeffs * found == rhs);
}
457 
// solveLinear on random 3x3 systems over small moduli and GF(2), plus a
// timing printout.
unittest {
    import std.random, std.stdio, std.algorithm;
    import dkh.modint;

    // Builds r = m * x, solves m * x2 = r, and checks that m is untouched
    // and that x2 reproduces r.
    void verify(Mat, Vec)(Mat m, Vec x) {
        Vec r = m * x;
        Mat backup = m.dup;
        Vec x2 = solveLinear(m, r);
        assert(m == backup);
        assert(m * x2 == r);
    }

    void f(uint Mod)() {
        alias Mint = ModInt!Mod;
        static Mint rndM() {
            return Mint(uniform(0, Mod));
        }
        SMatrix!(Mint, 3, 3) m = matrix!(3, 3, (i, j) => rndM())();
        SMatrix!(Mint, 3, 1) x = matrix!(3, 1, (i, j) => rndM())();
        verify(m, x);
    }
    void fMod2() {
        alias Mint = ModInt!2;
        static Mint rndM() {
            return Mint(uniform(0, 2));
        }
        SMatrixMod2!(3, 3) m = matrixMod2!(3, 3, (i, j) => rndM())();
        SMatrixMod2!(3, 1) x = matrixMod2!(3, 1, (i, j) => rndM())();
        verify(m, x);
    }
    import dkh.stopwatch;
    writeln("SolveLinear: ", benchmark!(f!2, f!3, f!11, fMod2)(10000)[].map!(a => a.toMsecs));
}