1 module dkh.matrix;
2
3 import dkh.bitop;
4
/**
 * Matrix library: fixed-size matrix with `H` rows and `W` columns of element type `T`.
 *
 * Elements live row-major in a static array (`data[row][col]`), so the storage is
 * copied along with the struct.
 */
struct SMatrix(T, size_t H, size_t W) {
    alias DataType = T;
    T[W][H] data;

    /// Initializes from exactly H*W values given in row-major order.
    this(Args...)(Args args) {
        static assert(args.length == H*W);
        foreach (i, v; args) {
            data[i/W][i%W] = v;
        }
    }

    /// Returns a copy; plain struct copying already duplicates the static-array storage.
    SMatrix dup() const { return this; }

    @property static size_t height() { return H; }
    @property static size_t width() { return W; }

    /// Element at row `i1`, column `i2`.
    ref inout(T) opIndex(size_t i1, size_t i2) inout {
        return data[i1][i2];
    }

    // Additive identity used to seed sums. T.init is NaN for floating-point T,
    // so prefer T(0) whenever T can be built from an integer literal.
    private static T zero() {
        static if (is(typeof(T(0)))) return T(0);
        else return T.init;
    }

    /// Element-wise sum with a matrix of the same shape.
    auto opBinary(string op:"+", R)(in R r) const
    if (height == R.height && width == R.width) {
        SMatrix res = this;
        foreach (y; 0..H) foreach (x; 0..W) res[y, x] += r[y, x];
        return res;
    }

    /// Matrix product `this * r`; the result has H rows and R.width columns.
    auto opBinary(string op:"*", R)(in R r) const
    if (width == R.height) {
        // Transpose r so that the inner loop walks both operands row-wise.
        auto rBuf = SMatrix!(T, R.width, R.height)();
        foreach (y; 0..R.height) {
            foreach (x; 0..R.width) {
                rBuf[x, y] = r[y, x];
            }
        }
        auto res = SMatrix!(T, H, R.width)();
        foreach (y; 0..H) {
            foreach (x; 0..R.width) {
                // Seed with an explicit zero: a default-initialized float would be NaN.
                T sm = zero();
                auto lv = this.data[y][];
                auto rv = rBuf.data[x][];
                foreach (k; 0..W) {
                    sm += lv[k] * rv[k];
                }
                res[y, x] = sm;
            }
        }
        return res;
    }

    /// `this op= r`, implemented as `this = this op r`.
    auto opOpAssign(string op, R)(in R r) { return mixin("this=this"~op~"r"); }

    /// Swaps rows `x` and `y`.
    void swapLine(size_t x, size_t y) {
        import std.algorithm : swap;
        foreach (i; 0..W) swap(data[x][i], data[y][i]);
    }
}
58
59 import dkh.foundation, dkh.modint;
/**
 * Matrix library (mod 2): fixed-size H x W matrix whose elements are `ModInt!2`.
 *
 * Each row is bit-packed into `L` words of `B` = 64 bits, where bit (col % B) of word
 * (col / B) holds the element. Addition is a word-wise XOR. Multiplication ANDs a row
 * of this matrix with a row of the transposed right operand and reduces the result
 * with `poppar` (from dkh.bitop; presumably the parity of the set bits).
 */
struct SMatrixMod2(size_t H, size_t W) {
    alias DataType = ModInt!2;
    static immutable B = 64;
    static immutable L = (W + B-1) / B;
    ulong[L][H] data;

    /// Initializes from exactly H*W values given in row-major order.
    this(Args...)(Args args) {
        static assert(args.length == H*W);
        foreach (idx, val; args) this[idx / W, idx % W] = val;
    }

    /// Returns a copy; the packed words are stored by value.
    SMatrixMod2 dup() const { return this; }

    @property static size_t height() { return H; }
    @property static size_t width() { return W; }

    /// Reads element (i1, i2) as a ModInt!2.
    const(DataType) opIndex(size_t i1, size_t i2) const {
        assert(i1 < H && i2 < W);
        immutable bit = (data[i1][i2 / B] >> (i2 % B)) & 1UL;
        return DataType(bit != 0 ? 1 : 0);
    }

    /// Writes element (i1, i2); a nonzero `d.v` sets the bit, zero clears it.
    void opIndexAssign(DataType d, size_t i1, size_t i2) {
        opIndexAssign(d.v != 0, i1, i2);
    }

    /// Writes element (i1, i2) from a bool.
    void opIndexAssign(bool d, size_t i1, size_t i2) {
        assert(i1 < H && i2 < W);
        immutable ulong mask = 1UL << (i2 % B);
        if (d) data[i1][i2 / B] |= mask;
        else data[i1][i2 / B] &= ~mask;
    }

    /// Compound assignment on one element, e.g. `m[i, j] += x`.
    auto opIndexOpAssign(string op)(DataType d, size_t i1, size_t i2) {
        auto updated = mixin("this[i1, i2] " ~ op ~ " d");
        this[i1, i2] = updated;
    }

    /// Element-wise sum mod 2: XOR of the packed rows.
    auto opBinary(string op:"+", R)(in R r) const
    if (height == R.height && width == R.width) {
        auto res = this.dup;
        foreach (row; 0 .. height) {
            res.data[row][] ^= r.data[row][];
        }
        return res;
    }

    /// Matrix product mod 2; the result is height x R.width.
    auto opBinary(string op:"*", R)(in R r) const
    if (width == R.height) {
        // Transpose r so every product entry is a word-wise AND of two packed rows.
        auto rT = SMatrixMod2!(R.width, R.height)();
        foreach (i; 0 .. R.height) {
            foreach (j; 0 .. R.width) rT[j, i] = r[i, j];
        }
        auto res = SMatrixMod2!(height, R.width)();
        foreach (i; 0 .. height) {
            foreach (j; 0 .. R.width) {
                ulong acc = 0;
                foreach (k; 0 .. L) acc ^= data[i][k] & rT.data[j][k];
                res[i, j] = poppar(acc);
            }
        }
        return res;
    }

    /// `this op= r`, implemented as `this = this op r`.
    auto opOpAssign(string op, R)(R r) { return mixin("this=this"~op~"r"); }

    /// Swaps rows `x` and `y` by exchanging their packed words.
    void swapLine(size_t x, size_t y) {
        import std.algorithm : swap;
        swap(data[x], data[y]);
    }
}
129
/**
 * Dynamically sized matrix of element type `T`, stored row-major in one slice
 * (`data[row * w + col]`).
 *
 * Copying the struct copies the slice, not the elements, so copies share storage.
 * Use `dup` to get an independent matrix.
 */
struct DMatrix(T) {
    alias DataType = T;
    size_t h, w;
    T[] data;

    /// Creates an h x w matrix with every element set to T.init.
    this(size_t h, size_t w) {
        this.h = h; this.w = w;
        data = new T[h*w];
    }

    /// Creates an h x w matrix from `d` in row-major order; `d.length` must equal h*w.
    this(size_t h, size_t w, in T[] d) {
        this(h, w);
        assert(d.length == h*w);
        data[] = d[];
    }

    /// Deep copy with its own backing array.
    DMatrix dup() const { return DMatrix(h, w, data); }

    @property size_t height() const { return h; }
    @property size_t width() const { return w; }

    /// Element at row `i1`, column `i2`.
    ref inout(T) opIndex(size_t i1, size_t i2) inout {
        return data[i1*width+i2];
    }

    // Additive identity used to seed sums. T.init is NaN for floating-point T,
    // so prefer T(0) whenever T can be built from an integer literal.
    private static T zero() {
        static if (is(typeof(T(0)))) return T(0);
        else return T.init;
    }

    /// Element-wise sum; both operands must have the same shape (checked by assert).
    auto opBinary(string op:"+", R)(in R r) const {
        assert(height == r.height && width == r.width);
        auto res = this.dup;
        foreach (y; 0..height) foreach (x; 0..width) res[y, x] += r[y, x];
        return res;
    }

    /// Matrix product `this * r`. Requires width == r.height; the result is height x r.width.
    auto opBinary(string op:"*", R)(in R r) const {
        assert(width == r.height);
        // Transpose r using r's own dimensions. `R.width`/`R.height` would name this
        // matrix's members when R is DMatrix!T, which breaks non-square products.
        auto rBuf = DMatrix!(T)(r.width, r.height);
        foreach (y; 0..r.height) {
            foreach (x; 0..r.width) {
                rBuf[x, y] = r[y, x];
            }
        }
        auto res = DMatrix!(T)(height, r.width);
        foreach (y; 0..height) {
            foreach (x; 0..r.width) {
                // Accumulate from an explicit zero: new T[] starts at T.init (NaN for floats).
                T sm = zero();
                foreach (k; 0..width) {
                    sm += this[y, k]*rBuf[x, k];
                }
                res[y, x] = sm;
            }
        }
        return res;
    }

    /// `this op= r`, implemented as `this = this op r`.
    auto opOpAssign(string op, R)(in R r) { return mixin("this=this"~op~"r"); }

    /// Swaps rows `x` and `y`.
    void swapLine(size_t x, size_t y) {
        import std.algorithm : swap;
        foreach (i; 0..w) swap(this[x, i], this[y, i]);
    }
}
182
///
unittest {
    import dkh.numeric.primitive;
    // [[0, 1], [1, 1]]^n == [[F(n-1), F(n)], [F(n), F(n+1)]] with F(1) = F(2) = 1,
    // so the top-left entry of the 10th power is F(9) = 34.
    auto step = DMatrix!int(2, 2, [0, 1, 1, 1]);
    auto identity = DMatrix!int(2, 2, [1, 0, 0, 1]);
    auto p = pow(step, 10, identity);
    assert(p[0, 0] == 34);
}
189
unittest {
    // operator+ must return a fresh matrix holding the sum and leave both operands untouched
    // (DMatrix copies share storage, so an in-place add would leak into mat1).
    auto mat1 = DMatrix!int(2, 2, [1, 1, 1, 1]);
    auto mat2 = DMatrix!int(2, 2, [2, 2, 2, 2]);
    auto mat3 = mat1 + mat2;
    assert(mat3.data == [3, 3, 3, 3]);
    assert(mat1.data == [1, 1, 1, 1]);
    assert(mat2.data == [2, 2, 2, 2]);
}
197
/// Builds an H x W SMatrix whose (y, x) element is `pred(y, x)`.
/// The element type is `typeof(pred(0, 0))`; pred is evaluated in row-major order.
auto matrix(size_t H, size_t W, alias pred)() {
    SMatrix!(typeof(pred(0, 0)), H, W) res;
    foreach (row, ref line; res.data) {
        foreach (col, ref cell; line) {
            cell = pred(row, col);
        }
    }
    return res;
}
208
/// Builds an H x W DMatrix whose (y, x) element is `pred(y, x)`.
/// The element type is `typeof(pred(0, 0))`; pred is evaluated in row-major order.
auto matrix(alias pred)(size_t H, size_t W) {
    auto res = DMatrix!(typeof(pred(0, 0)))(H, W);
    // DMatrix stores rows contiguously with row length W, so flat index idx is (idx / W, idx % W).
    foreach (idx, ref cell; res.data) {
        cell = pred(idx / W, idx % W);
    }
    return res;
}
219
/// Builds an H x W SMatrixMod2 whose (y, x) element is `pred(y, x)`, in row-major order.
/// pred's result must be accepted by one of SMatrixMod2's opIndexAssign overloads (ModInt!2 or bool).
auto matrixMod2(size_t H, size_t W, alias pred)() {
    SMatrixMod2!(H, W) res;
    for (size_t row = 0; row < H; ++row)
        for (size_t col = 0; col < W; ++col)
            res[row, col] = pred(row, col);
    return res;
}
230
/**
 * Determinant of the square matrix `_m`, computed by Gaussian elimination on a copy,
 * so the argument is not modified.
 *
 * Each pivot row is scaled by the inverse of its pivot (`DataType(1) / pivot`). The element
 * type therefore needs exact division; presumably a field such as ModInt with a prime
 * modulus (integer element types would truncate).
 * Returns `DataType(0)` as soon as a column has no nonzero pivot candidate.
 */
auto determinent(Mat)(in Mat _m) {
    alias E = Mat.DataType;
    auto a = _m.dup;
    assert(a.height == a.width);
    size_t n = a.height;
    E det = 1;
    for (size_t col = 0; col < n; col++) {
        if (a[col, col] == E(0)) {
            // Look below for a row with a nonzero entry in this column; a row swap flips the sign.
            size_t pivot = col + 1;
            while (pivot < n && a[pivot, col] == E(0)) pivot++;
            if (pivot == n) return E(0);
            a.swapLine(col, pivot);
            det *= E(-1);
        }
        det *= a[col, col];
        E inv = E(1) / a[col, col];
        // Normalize the pivot row so the diagonal entry becomes 1.
        foreach (j; 0 .. n) a[col, j] *= inv;
        // Clear this column in every row below the pivot.
        foreach (row; col + 1 .. n) {
            E factor = a[row, col];
            foreach (j; 0 .. n) a[row, j] -= a[col, j] * factor;
        }
    }
    return det;
}
263
unittest {
    import std.random, std.stdio, std.algorithm;
    import dkh.modint;

    // Checks determinent on a 3x3 matrix against the Leibniz formula: products along the
    // even permutations of (0, 1, 2) are added, those along the odd ones subtracted.
    // Also verifies that determinent leaves its argument unchanged.
    static void check(E, Mat)(Mat m) {
        E expected = 0;
        foreach (parity, start; [[0, 1, 2], [0, 2, 1]]) {
            auto perm = start.dup;
            do {
                E prod = 1;
                foreach (row; 0 .. 3) prod *= m[row, perm[row]];
                if (parity == 0) expected += prod;
                else expected -= prod;
            } while (perm.nextEvenPermutation);
        }
        auto before = m.dup;
        assert(expected == m.determinent);
        assert(before == m);
    }
    void f(uint Mod)() {
        alias Mint = ModInt!Mod;
        static Mint rndM() {
            return Mint(uniform(0, Mod));
        }
        check!Mint(matrix!(3, 3, (i, j) => rndM())());
    }
    void fD(uint Mod)() {
        alias Mint = ModInt!Mod;
        static Mint rndM() {
            return Mint(uniform(0, Mod));
        }
        check!Mint(matrix!((i, j) => rndM())(3, 3));
    }
    void fMod2() {
        alias Mint = ModInt!2;
        static Mint rndM() {
            return Mint(uniform(0, 2));
        }
        check!Mint(matrixMod2!(3, 3, (i, j) => rndM())());
    }
    import dkh.stopwatch;
    writeln("Det: ", benchmark!(f!2, f!3, f!11, fD!2, fD!3, fD!11, fMod2)(10000)[].map!(a => a.toMsecs));
}
370
371
/**
 * Solves the linear system `m * v = r` by Gauss-Jordan elimination and returns one solution `v`.
 *
 * The element type must expose a `.v` field (nonzero tests read `m[y, x].v`), e.g. ModInt.
 * Columns without a pivot are free variables, and their entries in `v` keep DataType.init
 * (presumably zero). The system's consistency is NOT checked: if it has no solution, the
 * returned vector is not meaningful.
 * `v` is indexed by the columns of `m`. For statically sized types, `Vec` must therefore
 * have at least `m.width` rows (in practice `m` is square) — confirm with callers.
 *
 * Params:
 *   m = coefficient matrix (N x M); not modified.
 *   r = right-hand side (N x 1); not modified.
 */
Vec solveLinear(Mat, Vec)(Mat m, Vec r) {
    // DMatrix copies share their backing slice, so eliminate on private copies.
    // Otherwise the swaps and row updates below would clobber the caller's matrices.
    m = m.dup;
    r = r.dup;
    size_t N = m.height, M = m.width;
    size_t c = 0; // number of pivot rows fixed so far
    foreach (x; 0 .. M) {
        // Find a row at or below c with a nonzero entry in column x.
        size_t my = N;
        foreach (y; c .. N) {
            if (m[y, x].v) {
                my = y;
                break;
            }
        }
        if (my == N) continue; // no pivot in this column: x is a free variable
        m.swapLine(c, my);
        r.swapLine(c, my);
        // Eliminate column x from every other row, above and below the pivot.
        foreach (y; 0 .. N) {
            if (y == c || m[y, x].v == 0) continue;
            auto freq = m[y, x] / m[c, x];
            foreach (k; 0 .. M) {
                m[y, k] -= freq * m[c, k];
            }
            r[y, 0] -= freq * r[c, 0];
        }
        c++;
        if (c == N) break;
    }
    // A dynamic Vec needs an explicit M x 1 allocation; a default-initialized DMatrix
    // is 0 x 0 and indexing it would go out of range. Static types get their size from the type.
    static if (is(Vec == DMatrix!E, E)) {
        auto v = Vec(M, 1);
    } else {
        Vec v;
    }
    // Back-substitute from the last pivot row upwards.
    foreach_reverse (y; 0 .. c) {
        size_t f = M; // pivot (first nonzero) column of row y
        Mat.DataType sm;
        foreach (x; 0 .. M) {
            if (f == M && m[y, x].v) {
                f = x;
            }
            sm += m[y, x] * v[x, 0];
        }
        v[f, 0] += (r[y, 0] - sm) / m[y, f];
    }
    return v;
}
414
unittest {
    import std.random, std.stdio;
    import dkh.modint;
    alias Mint = ModInt!(10^^9 + 7);
    // Uniformly random element of the field.
    static Mint rndM() { return Mint(uniform(0, 10^^9 + 7)); }

    // Build a random system with a known solution, then check that the solver's
    // answer reproduces the right-hand side.
    auto a = matrix!(3, 3, (i, j) => rndM())();
    auto expected = matrix!(3, 1, (i, j) => rndM())();
    auto rhs = a * expected;
    auto got = solveLinear(a, rhs);
    assert(a * got == rhs);
}
430
unittest {
    import std.random, std.stdio, std.algorithm;
    import dkh.modint;

    // Round trip: build r = m * x, solve m * v = r, and require that the answer
    // reproduces r while solveLinear leaves m unchanged.
    static void roundTrip(Mat, Vec)(Mat m, Vec x) {
        Vec r = m * x;
        Mat before = m.dup;
        Vec x2 = solveLinear(m, r);
        assert(m == before);
        assert(m * x2 == r);
    }
    void f(uint Mod)() {
        alias Mint = ModInt!Mod;
        static Mint rndM() {
            return Mint(uniform(0, Mod));
        }
        auto m = matrix!(3, 3, (i, j) => rndM())();
        auto x = matrix!(3, 1, (i, j) => rndM())();
        roundTrip(m, x);
    }
    void fMod2() {
        alias Mint = ModInt!2;
        static Mint rndM() {
            return Mint(uniform(0, 2));
        }
        auto m = matrixMod2!(3, 3, (i, j) => rndM())();
        auto x = matrixMod2!(3, 1, (i, j) => rndM())();
        roundTrip(m, x);
    }
    import dkh.stopwatch;
    writeln("SolveLinear: ", benchmark!(f!2, f!3, f!11, fMod2)(10000)[].map!(a => a.toMsecs));
}