1 module dkh.matrix;
2 
3 import dkh.bitop;
4 
/// Matrix library: a fixed-size (compile-time H x W) dense matrix with
/// elements of type `T`, stored row-major in `data`.
///
/// `T` must support `+`, `*` and `+=`, and be initialisable from the integer
/// literal `0` (built-in numeric types are).
struct SMatrix(T, size_t H, size_t W) {
    alias DataType = T;
    T[W][H] data;

    /// Builds the matrix from exactly H*W elements given in row-major order.
    this(Args...)(Args args) {
        static assert(args.length == H*W);
        foreach (i, v; args) {
            data[i/W][i%W] = v;
        }
    }

    /// Returns a copy; SMatrix is a value type, so this is a plain copy.
    SMatrix dup() const { return this; }

    @property static size_t height() {return H;}
    @property static size_t width() {return W;}

    /// Element access: row `i1`, column `i2`.
    ref inout(T) opIndex(size_t i1, size_t i2) inout {
        return data[i1][i2];
    }

    /// Element-wise sum with a matrix `r` of the same shape.
    auto opBinary(string op:"+", R)(in R r) const
    if(height == R.height && width == R.width) {
        SMatrix res = this;
        foreach (y; 0..height) foreach (x; 0..width) res[y, x] += r[y, x];
        return res;
    }

    /// Matrix product `this * r`; the result is height x R.width.
    auto opBinary(string op:"*", R)(in R r) const
    if(width == R.height) {
        // Transposed copy of r, so each inner product walks two contiguous rows.
        SMatrix!(T, R.width, R.height) rBuf;
        foreach (y; 0..R.height) {
            foreach (x; 0..R.width) {
                rBuf[x, y] = r[y, x];
            }
        }
        SMatrix!(T, height, R.width) res;
        foreach (y; 0..height) {
            foreach (x; 0..R.width) {
                // Explicit zero: `T sm;` would start from T.init, which is NaN
                // for floating-point T.
                T sm = 0;
                auto lv = this.data[y][];
                auto rv = rBuf.data[x][];
                foreach (k; 0..width) {
                    sm += lv[k] * rv[k];
                }
                res[y, x] = sm;
            }
        }
        return res;
    }

    /// `this op= r`, implemented as `this = this op r`.
    auto opOpAssign(string op, R)(in R r) {return mixin ("this=this"~op~"r");}

    /// Swaps rows `x` and `y` in place.
    void swapLine(size_t x, size_t y) {
        import std.algorithm : swap;
        foreach (i; 0..W) swap(data[x][i], data[y][i]);
    }
}
58 
59 import dkh.foundation, dkh.modint;
/// Matrix library over GF(2): a fixed-size H x W matrix whose elements are
/// ModInt!2.
///
/// Each row is packed into `L` 64-bit words: element (y, x) is bit `x % B` of
/// `data[y][x / B]`. Operations in this struct only set bits for in-range
/// columns (x < W), so padding bits stay zero.
struct SMatrixMod2(size_t H, size_t W) {
    alias DataType = ModInt!2;
    static immutable B = 64;              // bits per storage word
    static immutable L = (W + B-1) / B;   // storage words per row
    ulong[L][H] data;

    /// Builds the matrix from exactly H*W elements given in row-major order.
    this(Args...)(Args args) {
        static assert(args.length == H*W);
        foreach (i, v; args) {
            this[i/W, i%W] = v;
        }
    }

    /// Returns a copy; the storage is a static array, so this is a plain copy.
    SMatrixMod2 dup() const { return this; }

    @property static size_t height() {return H;}
    @property static size_t width() {return W;}

    /// Reads element (i1, i2) as a ModInt!2 holding 0 or 1.
    const(DataType) opIndex(size_t i1, size_t i2) const {
        assert(i1 < H && i2 < W);
        return DataType(((data[i1][i2/B] >> (i2%B)) & 1UL) ? 1 : 0);
    }

    /// Sets element (i1, i2) to `d`; a nonzero `d.v` stores 1.
    void opIndexAssign(DataType d, size_t i1, size_t i2) {
        setBit(i1, i2, cast(bool) d.v);
    }

    /// Sets element (i1, i2) to 1 if `d` is true, else 0.
    void opIndexAssign(bool d, size_t i1, size_t i2) {
        setBit(i1, i2, d);
    }

    /// Shared bit-level store for both opIndexAssign overloads.
    private void setBit(size_t i1, size_t i2, bool on) {
        assert(i1 < H && i2 < W);
        immutable ulong mask = 1UL << (i2 % B);
        if (on) data[i1][i2/B] |= mask;
        else data[i1][i2/B] &= ~mask;
    }

    /// `this[i1, i2] op= d`, done element-wise via ModInt!2 arithmetic.
    auto opIndexOpAssign(string op)(DataType d, size_t i1, size_t i2) {
        return mixin("this[i1,i2]=this[i1,i2]"~op~"d");
    }

    /// Element-wise sum over GF(2): a word-wise XOR of the packed rows.
    auto opBinary(string op:"+", R)(in R r) const
    if(height == R.height && width == R.width) {
        auto res = this.dup;
        foreach (y; 0..height) foreach (x; 0..L) {
            res.data[y][x] ^= r.data[y][x];
        }
        return res;
    }

    /// Matrix product over GF(2). Each output element is the parity of the
    /// word-wise AND of a row of `this` and a row of the transposed `r`.
    auto opBinary(string op:"*", R)(in R r) const
    if(width == R.height) {
        SMatrixMod2!(R.width, R.height) r2; // transpose of r
        foreach (y; 0..R.height) foreach (x; 0..R.width) {
            r2[x, y] = r[y, x];
        }
        SMatrixMod2!(height, R.width) res;
        foreach (y; 0..height) {
            foreach (x; 0..R.width) {
                ulong sm = 0;
                foreach (k; 0..L) {
                    sm ^= data[y][k]&r2.data[x][k];
                }
                res[y, x] = poppar(sm);
            }
        }
        return res;
    }

    /// `this op= r`, implemented as `this = this op r`.
    auto opOpAssign(string op, T)(T r) {return mixin ("this=this"~op~"r");}

    /// Swaps rows `x` and `y` in place (whole storage words at a time).
    void swapLine(size_t x, size_t y) {
        import std.algorithm : swap;
        foreach (i; 0..L) swap(data[x][i], data[y][i]);
    }
}
129 
/// Matrix library, dynamically sized: an h x w dense matrix with elements of
/// type `T`, stored row-major in the heap array `data`.
///
/// `T` must support `+`, `*` and `+=`, and be initialisable from the integer
/// literal `0`.
struct DMatrix(T) {
    alias DataType = T;
    size_t h, w;
    T[] data;

    /// Allocates an h x w matrix with every element set to `T.init`.
    this(size_t h, size_t w) {
        this.h = h; this.w = w;
        data = new T[h*w];
    }

    /// Builds an h x w matrix from `d` in row-major order; `d.length` must be h*w.
    this(size_t h, size_t w, in T[] d) {
        this(h, w);
        assert(d.length == h*w);
        data[] = d[];
    }

    /// Deep copy: the element buffer is duplicated.
    DMatrix dup() const { return DMatrix(h, w, data); }

    @property size_t height() const {return h;}
    @property size_t width() const {return w;}

    /// Element access: row `i1`, column `i2`. The column check matters: an
    /// overflowing column would otherwise silently read the next row.
    ref inout(T) opIndex(size_t i1, size_t i2) inout {
        assert(i1 < h && i2 < w);
        return data[i1*width+i2];
    }

    /// Element-wise sum with a matrix of the same shape (returns a new matrix).
    auto opBinary(string op:"+", R)(in R r) const {
        assert(height == r.height && width == r.width);
        auto res = this.dup;
        foreach (y; 0..height) foreach (x; 0..width) res[y, x] += r[y, x];
        return res;
    }

    /// Matrix product `this * r`; requires `width == r.height`.
    auto opBinary(string op:"*", R)(in R r) const {
        assert(width == r.height);
        // Take r's dimensions from the instance. DMatrix's width/height are
        // non-static, so inside this member `R.width` would resolve to
        // `this.width` whenever R is DMatrix!T.
        immutable rh = r.height, rw = r.width;
        // Transposed copy of r, so the inner loop reads both operands row-wise.
        auto rBuf = DMatrix!(T)(rw, rh);
        foreach (y; 0..rh) {
            foreach (x; 0..rw) {
                rBuf[x, y] = r[y, x];
            }
        }
        auto res = DMatrix!(T)(height, rw);
        foreach (y; 0..height) {
            foreach (x; 0..rw) {
                // Explicit zero: res starts at T.init, which is NaN for
                // floating-point T.
                T sm = 0;
                foreach (k; 0..width) {
                    sm += this[y, k] * rBuf[x, k];
                }
                res[y, x] = sm;
            }
        }
        return res;
    }

    /// `this op= r`, implemented as `this = this op r`.
    auto opOpAssign(string op, R)(in R r) {return mixin ("this=this"~op~"r");}

    /// Swaps rows `x` and `y` in place.
    void swapLine(size_t x, size_t y) {
        import std.algorithm : swap;
        foreach (i; 0..w) swap(this[x, i], this[y, i]);
    }
}
182 
183 ///
unittest {
    import dkh.numeric.primitive;
    // Fibonacci by matrix power: the (0, 0) entry of [[0, 1], [1, 1]]^10 is
    // F(9) = 34 (with F(0) = 0, F(1) = 1). The third argument to `pow` is the
    // 2x2 identity, presumably used as the multiplicative unit; confirm in
    // dkh.numeric.primitive.
    auto step = DMatrix!int(2, 2, [0, 1, 1, 1]);
    auto powered = pow(step, 10, DMatrix!int(2, 2, [1, 0, 0, 1]));
    assert(powered[0, 0] == 34);
}
189 
unittest {
    // `+` must produce a fresh matrix: both operands stay unchanged and the
    // result does not share storage with either of them.
    auto mat1 = DMatrix!int(2, 2, [1, 1, 1, 1]);
    auto mat2 = DMatrix!int(2, 2, [2, 2, 2, 2]);
    auto mat3 = mat1 + mat2;
    assert(mat1[0, 0] == 1);
    assert(mat2[0, 0] == 2);
    // The sum itself was previously computed but never checked.
    foreach (y; 0..2) foreach (x; 0..2) assert(mat3[y, x] == 3);
    mat3[0, 0] = 10;
    assert(mat1[0, 0] == 1);
}
197 
/// Builds a compile-time-sized H x W `SMatrix` whose element (y, x) is
/// `pred(y, x)`. The element type is the type of `pred(0, 0)`.
/// `pred` is called once per element, row by row (row-major order).
auto matrix(size_t H, size_t W, alias pred)() {
    SMatrix!(typeof(pred(0, 0)), H, W) res;
    foreach (y; 0..H) {
        foreach (x; 0..W) {
            res[y, x] = pred(y, x);
        }
    }
    return res;
}
208 
/// Builds a run-time-sized H x W `DMatrix` whose element (y, x) is
/// `pred(y, x)`. The element type is the type of `pred(0, 0)`.
/// `pred` is called once per element, row by row (row-major order).
auto matrix(alias pred)(size_t H, size_t W) {
    auto res = DMatrix!(typeof(pred(0, 0)))(H, W);
    foreach (y; 0..H) {
        foreach (x; 0..W) {
            res[y, x] = pred(y, x);
        }
    }
    return res;
}
219 
/// Builds an H x W `SMatrixMod2` whose element (y, x) is `pred(y, x)`.
/// `pred`'s result must be assignable to an SMatrixMod2 element.
/// `pred` is called once per element, row by row (row-major order).
auto matrixMod2(size_t H, size_t W, alias pred)() {
    SMatrixMod2!(H, W) res;
    foreach (y; 0..H) {
        foreach (x; 0..W) {
            res[y, x] = pred(y, x);
        }
    }
    return res;
}
230 
/// Returns the determinant of the square matrix `_m`.
///
/// Works by Gaussian elimination on a copy (`_m.dup`), so `_m` itself is not
/// modified. The pivot is inverted with `M(1) / pivot`, so `Mat.DataType` must
/// behave like a field; integer element types give wrong results.
/// `M(0)`, `M(1)` and `M(-1)` must be constructible.
/// (The misspelled name is kept so existing callers keep working.)
auto determinent(Mat)(in Mat _m) {
    auto m = _m.dup;
    assert(m.height == m.width);
    alias M = Mat.DataType;
    immutable size_t N = m.height;
    M base = 1;
    foreach (i; 0..N) {
        if (m[i, i] == M(0)) {
            // Swap up a lower row with a nonzero entry in column i; each row
            // swap flips the sign of the determinant.
            foreach (j; i+1..N) {
                if (m[j, i] != M(0)) {
                    m.swapLine(i, j);
                    base *= M(-1);
                    break;
                }
            }
            // No pivot in this column: the matrix is singular.
            if (m[i, i] == M(0)) return M(0);
        }
        base *= m[i, i];
        // Normalise the pivot row and eliminate below it. Columns < i are
        // never read again, so both loops start at column i.
        M im = M(1)/m[i, i];
        foreach (j; i..N) {
            m[i, j] *= im;
        }
        foreach (j; i+1..N) {
            M x = m[j, i];
            foreach (k; i..N) {
                m[j, k] -= m[i, k] * x;
            }
        }
    }
    return base;
}
263 
unittest {
    import std.random, std.stdio, std.algorithm;
    import dkh.modint;

    // Reference determinant of a 3x3 matrix by the Leibniz formula. Products
    // along the permutations reached from [0, 1, 2] by nextEvenPermutation
    // are added. Those reached from [0, 2, 1] (the odd ones) are subtracted.
    static Mint leibniz3(Mint, Mat)(in Mat m) {
        Mint sm = 0;
        foreach (odd; [false, true]) {
            auto idx = odd ? [0, 2, 1] : [0, 1, 2];
            do {
                Mint buf = 1;
                foreach (i; 0..3) {
                    buf *= m[i, idx[i]];
                }
                if (odd) sm -= buf;
                else sm += buf;
            } while (idx.nextEvenPermutation);
        }
        return sm;
    }

    // Checks `determinent` against the Leibniz reference, and checks that it
    // leaves its argument unmodified. The failing case is printed before the
    // assertion fires.
    static void check(Mint, Mat)(Mat m) {
        auto expected = leibniz3!(Mint, Mat)(m);
        auto saved = m.dup;
        auto got = m.determinent;
        if (expected != got) {
            writeln(expected, " ", got);
            writeln(m);
        }
        assert(expected == got);
        assert(saved == m);
    }

    void f(uint Mod)() {
        alias Mint = ModInt!Mod;
        static Mint rndM() {
            return Mint(uniform(0, Mod));
        }
        check!Mint(matrix!(3, 3, (i, j) => rndM())());
    }
    void fD(uint Mod)() {
        alias Mint = ModInt!Mod;
        static Mint rndM() {
            return Mint(uniform(0, Mod));
        }
        check!Mint(matrix!((i, j) => rndM())(3, 3));
    }
    void fMod2() {
        alias Mint = ModInt!2;
        static Mint rndM() {
            return Mint(uniform(0, 2));
        }
        check!Mint(matrixMod2!(3, 3, (i, j) => rndM())());
    }
    import dkh.stopwatch;
    writeln("Det: ", benchmark!(f!2, f!3, f!11, fD!2, fD!3, fD!11, fMod2)(10000)[].map!(a => a.toMsecs));
}
370 
371 
/// Solves the linear system `m * v = r` by Gauss-Jordan elimination and
/// returns one solution `v`.
///
/// Params:
///     m = N x M coefficient matrix, taken by value. NOTE: for DMatrix the
///         element buffer is shared with the caller's copy, so it would be
///         modified.
///     r = N x 1 right-hand side, taken by value.
/// Returns: an M x 1 column vector `v`. If the system has several solutions,
///     variables whose column has no pivot keep their default (`.init`) value.
///     Inconsistent systems are NOT detected; the returned vector then does
///     not satisfy `m * v == r`.
///
/// Requirements (from the operations used below):
///     - `Mat.DataType` has a member `.v` that is zero exactly for the zero
///       element (as ModInt does).
///     - `Mat.DataType` supports `/`, `*`, `-`, `+=` and `-=`.
///     - `Vec` already has the right size when default-constructed
///       (SMatrix / SMatrixMod2 do; DMatrix does not).
Vec solveLinear(Mat, Vec)(Mat m, Vec r) {
    immutable size_t N = m.height, M = m.width;
    size_t c = 0; // number of pivot rows found so far
    foreach (x; 0..M) {
        if (c == N) break; // every row already has a pivot
        // Find a row at or below c with a nonzero entry in column x.
        size_t my = N;
        foreach (y; c..N) {
            if (m[y, x].v) {
                my = y;
                break;
            }
        }
        if (my == N) continue; // no pivot: x is a free variable
        m.swapLine(c, my);
        r.swapLine(c, my);
        // Eliminate column x from every other row (Gauss-Jordan).
        foreach (y; 0..N) {
            if (y == c || m[y, x].v == 0) continue;
            auto freq = m[y, x] / m[c, x];
            foreach (k; 0..M) {
                m[y, k] -= freq * m[c, k];
            }
            r[y, 0] -= freq * r[c, 0];
        }
        c++;
    }
    // Back substitution. Each pivot row has exactly one pivot column, and
    // every other pivot column is zero in that row.
    Vec v;
    foreach_reverse (y; 0..c) {
        size_t f = 0; // first nonzero column of row y = its pivot column
        while (!m[y, f].v) f++;
        Mat.DataType sm;
        foreach (x; 0..M) {
            sm += m[y, x] * v[x, 0];
        }
        v[f, 0] += (r[y, 0] - sm) / m[y, f];
    }
    return v;
}
414 
unittest {
    import std.random, std.stdio;
    import dkh.modint;
    // Round trip over the prime field mod 1e9+7: pick a random 3x3
    // coefficient matrix and a random 3x1 vector, form their product, solve
    // for the vector again, and check that it reproduces the product.
    alias Mint = ModInt!(10^^9 + 7);
    static Mint randomElem() {
        return Mint(uniform(0, 10^^9 + 7));
    }
    auto coef = matrix!(3, 3, (i, j) => randomElem())();
    auto truth = matrix!(3, 1, (i, j) => randomElem())();
    auto rhs = coef * truth;
    auto solved = solveLinear(coef, rhs);
    assert(coef * solved == rhs);
}
430 
unittest {
    import std.random, std.stdio, std.algorithm;
    import dkh.modint;

    // Forms r = m * x, solves m * x2 = r, and checks two things:
    // solveLinear left the caller's m unchanged, and x2 reproduces r.
    static void check(Mat, Vec)(Mat m, Vec x) {
        Vec r = m * x;
        Mat saved = m.dup;
        Vec x2 = solveLinear(m, r);
        assert(m == saved);
        assert(m * x2 == r);
    }

    void f(uint Mod)() {
        alias Mint = ModInt!Mod;
        static Mint rndM() {
            return Mint(uniform(0, Mod));
        }
        auto m = matrix!(3, 3, (i, j) => rndM())();
        auto x = matrix!(3, 1, (i, j) => rndM())();
        check(m, x);
    }
    void fMod2() {
        alias Mint = ModInt!2;
        static Mint rndM() {
            return Mint(uniform(0, 2));
        }
        auto m = matrixMod2!(3, 3, (i, j) => rndM())();
        auto x = matrixMod2!(3, 1, (i, j) => rndM())();
        check(m, x);
    }
    import dkh.stopwatch;
    writeln("SolveLinear: ", benchmark!(f!2, f!3, f!11, fMod2)(10000)[].map!(a => a.toMsecs));
}