module dkh.matrix;

import dkh.bitop;

/**
 * Matrix library: fixed-size matrix.
 *
 * A dense H x W matrix whose dimensions are compile-time template
 * parameters. Elements are stored row by row in a static array, so copying
 * the struct copies the whole matrix.
 *
 * Params:
 *   T = element type
 *   H = number of rows
 *   W = number of columns
 */
struct SMatrix(T, size_t H, size_t W) {
    /// Element type of the matrix.
    alias DataType = T;
    /// Row-major storage: `data[row][column]`.
    T[W][H] data;

    /// Builds the matrix from exactly `H*W` values listed in row-major order.
    this(Args...)(Args args) {
        static assert(args.length == H*W);
        foreach (i, v; args) {
            data[i/W][i%W] = v;
        }
    }

    /// Returns a copy of this matrix.
    SMatrix dup() const { return this; }

    /// Number of rows.
    @property static size_t height() {return H;}
    /// Number of columns.
    @property static size_t width() {return W;}

    /// Accesses the element at row `i1`, column `i2` by reference.
    ref inout(T) opIndex(size_t i1, size_t i2) inout {
        return data[i1][i2];
    }

    /// Element-wise sum; both operands must have identical dimensions.
    auto opBinary(string op:"+", R)(in R r) const
    if(height == R.height && width == R.width) {
        SMatrix res = this;
        foreach (y; 0..height) foreach (x; 0..W) res[y, x] += r[y, x];
        return res;
    }

    /// Matrix product; requires `width == R.height`.
    auto opBinary(string op:"*", R)(in R r) const
    if(width == R.height) {
        // Transpose the right operand first so that the inner loop below
        // walks two rows instead of a row and a column.
        auto rBuf = SMatrix!(T, R.width, R.height)();
        foreach (y; 0..R.height) {
            foreach (x; 0..R.width) {
                rBuf[x, y] = r[y, x];
            }
        }
        auto res = SMatrix!(T, height, R.width)();
        foreach (y; 0..height) {
            foreach (x; 0..R.width) {
                // The accumulator starts at T.init.
                T sm;
                auto lv = this.data[y][];
                auto rv = rBuf.data[x][];
                foreach (k; 0..width) {
                    sm += lv[k] * rv[k];
                }
                res[y, x] = sm;
            }
        }
        return res;
    }

    /// Compound assignment (`a op= b`), implemented as `a = a op b`.
    auto opOpAssign(string op, R)(in R r) {return mixin ("this=this"~op~"r");}

    /// Exchanges rows `x` and `y` in place.
    void swapLine(size_t x, size_t y) {
        import std.algorithm : swap;
        foreach (i; 0..W) swap(data[x][i], data[y][i]);
    }
}

import dkh.foundation, dkh.modint;

/**
 * Matrix library (Mod2): fixed-size matrix over `ModInt!2`.
 *
 * Each row is packed into `ulong` words, one bit per element, so addition is
 * a word-wise XOR and a dot product is the parity of a word-wise AND.
 *
 * Params:
 *   H = number of rows
 *   W = number of columns
 */
struct SMatrixMod2(size_t H, size_t W) {
    /// Element type of the matrix.
    alias DataType = ModInt!2;
    /// Number of bits stored per word.
    static immutable B = 64;
    /// Number of words needed to hold one row.
    static immutable L = (W + B-1) / B;
    /// Packed storage: bit `c % B` of `data[r][c / B]` is element (r, c).
    ulong[L][H] data;

    /// Builds the matrix from exactly `H*W` values listed in row-major order.
    this(Args...)(Args args) {
        static assert(args.length == H*W);
        foreach (i, v; args) {
            this[i/W, i%W] = v;
        }
    }

    /// Returns a copy of this matrix.
    SMatrixMod2 dup() const { return this; }

    /// Number of rows.
    @property static size_t height() {return H;}
    /// Number of columns.
    @property static size_t width() {return W;}

    /// Reads the element at row `i1`, column `i2`.
    const(DataType) opIndex(size_t i1, size_t i2) const {
        return DataType(((data[i1][i2/B] >> (i2%B)) & 1UL) ? 1 : 0);
    }

    /// Stores `d` at row `i1`, column `i2`; the bit is set when `d.v` is non-zero.
    void opIndexAssign(DataType d, size_t i1, size_t i2) {
        size_t r = i2 % B;
        if (d.v) data[i1][i2/B] |= (1UL<<r);
        else data[i1][i2/B] &= ~(1UL<<r);
    }

    /// Stores the boolean `d` at row `i1`, column `i2`.
    void opIndexAssign(bool d, size_t i1, size_t i2) {
        size_t r = i2 % B;
        if (d) data[i1][i2/B] |= (1UL<<r);
        else data[i1][i2/B] &= ~(1UL<<r);
    }

    /// Compound element assignment (`m[i, j] op= d`), as read-modify-write.
    auto opIndexOpAssign(string op)(DataType d, size_t i1, size_t i2) {
        return mixin("this[i1,i2]=this[i1,i2]"~op~"d");
    }

    /// Element-wise sum (XOR of the packed words); dimensions must match.
    auto opBinary(string op:"+", R)(in R r) const
    if(height == R.height && width == R.width) {
        auto res = this.dup;
        foreach (y; 0..height) foreach (x; 0..L) {
            res.data[y][x] ^= r.data[y][x];
        }
        return res;
    }

    /// Matrix product; requires `width == R.height`.
    auto opBinary(string op:"*", R)(in R r) const
    if(width == R.height) {
        // Transpose the right operand so both factors are packed along the
        // shared dimension and can be combined a whole word at a time.
        auto r2 = SMatrixMod2!(R.width, R.height)();
        foreach (y; 0..R.height) foreach (x; 0..R.width) {
            r2[x, y] = r[y, x];
        }
        auto res = SMatrixMod2!(height, R.width)();
        foreach (y; 0..height) {
            foreach (x; 0..R.width) {
                ulong sm = 0;
                foreach (k; 0..L) {
                    sm ^= data[y][k]&r2.data[x][k];
                }
                // poppar comes from dkh.bitop; presumably the parity of the
                // set bits, i.e. the dot product modulo 2 -- confirm there.
                res[y, x] = poppar(sm);
            }
        }
        return res;
    }

    /// Compound assignment (`a op= b`), implemented as `a = a op b`.
    auto opOpAssign(string op, R)(R r) {return mixin ("this=this"~op~"r");}

    /// Exchanges rows `x` and `y` in place.
    void swapLine(size_t x, size_t y) {
        import std.algorithm : swap;
        foreach (i; 0..L) swap(data[x][i], data[y][i]);
    }
}

/**
 * Matrix library: dynamically sized matrix.
 *
 * A dense `h` x `w` matrix whose dimensions are chosen at run time. Elements
 * live in a single row-major array, so copying the struct shares the
 * storage; use `dup` for an independent copy.
 */
struct DMatrix(T) {
    /// Element type of the matrix.
    alias DataType = T;
    /// Number of rows and columns.
    size_t h, w;
    /// Row-major storage: element (r, c) is `data[r*w + c]`.
    T[] data;

    /// Creates an `h` x `w` matrix whose elements are all `T.init`.
    this(size_t h, size_t w) {
        this.h = h; this.w = w;
        data = new T[h*w];
    }

    /// Creates an `h` x `w` matrix holding a copy of `d` (row-major, `h*w` values).
    this(size_t h, size_t w, in T[] d) {
        this(h, w);
        assert(d.length == h*w);
        data[] = d[];
    }

    /// Returns a copy with its own storage.
    DMatrix dup() const { return DMatrix(h, w, data); }

    /// Number of rows.
    @property size_t height() const {return h;}
    /// Number of columns.
    @property size_t width() const {return w;}

    /// Accesses the element at row `i1`, column `i2` by reference.
    ref inout(T) opIndex(size_t i1, size_t i2) inout {
        return data[i1*width+i2];
    }

    /// Element-wise sum; both operands must have identical dimensions.
    auto opBinary(string op:"+", R)(in R r) const {
        assert(height == r.height && width == r.width);
        auto res = this.dup;
        foreach (y; 0..height) foreach (x; 0..width) res[y, x] += r[y, x];
        return res;
    }

    /// Matrix product; requires `width == r.height`.
    auto opBinary(string op:"*", R)(in R r) const {
        assert(width == r.height);
        // The dimensions must be read from the operand `r` itself: they are
        // run-time values, and the two operands may have different shapes.
        immutable rh = r.height, rw = r.width;
        // Transposed copy of `r`, so the inner loop walks two rows.
        auto rBuf = DMatrix!(T)(rw, rh);
        foreach (y; 0..rh) {
            foreach (x; 0..rw) {
                rBuf[x, y] = r[y, x];
            }
        }
        auto res = DMatrix!(T)(height, rw);
        foreach (y; 0..height) {
            foreach (x; 0..rw) {
                foreach (k; 0..width) {
                    res[y, x] += this[y, k]*rBuf[x, k];
                }
            }
        }
        return res;
    }

    /// Compound assignment (`a op= b`), implemented as `a = a op b`.
    auto opOpAssign(string op, R)(in R r) {return mixin ("this=this"~op~"r");}

    /// Exchanges rows `x` and `y` in place.
    void swapLine(size_t x, size_t y) {
        import std.algorithm : swap;
        foreach (i; 0..w) swap(data[x*w+i], data[y*w+i]);
    }
}

///
unittest {
    import dkh.numeric.primitive;
    auto mat = DMatrix!int(2, 2, [0, 1, 1, 1]);
    // Top-left entry of the 10th power of the Fibonacci matrix.
    assert(pow(mat, 10, DMatrix!int(2, 2, [1, 0, 0, 1]))[0, 0] == 34);
}

unittest {
    // Addition must not modify its operands.
    auto mat1 = DMatrix!int(2, 2, [1, 1, 1, 1]);
    auto mat2 = DMatrix!int(2, 2, [2, 2, 2, 2]);
    auto mat3 = mat1 + mat2;
    assert(mat1[0, 0] == 1);
    assert(mat2[0, 0] == 2);
    assert(mat3[0, 0] == 3);
}

unittest {
    // Regression: product of non-square dynamic matrices.
    auto a = DMatrix!int(2, 3, [1, 2, 3, 4, 5, 6]);
    auto b = DMatrix!int(3, 1, [1, 1, 1]);
    auto c = a * b;
    assert(c.height == 2 && c.width == 1);
    assert(c[0, 0] == 6);
    assert(c[1, 0] == 15);

    // Row exchange on a dynamic matrix.
    a.swapLine(0, 1);
    assert(a[0, 0] == 4 && a[1, 2] == 3);
}

/**
 * Builds an `H` x `W` `SMatrix` whose element (y, x) is `pred(y, x)`.
 *
 * The element type is the type of `pred(0, 0)`.
 */
auto matrix(size_t H, size_t W, alias pred)() {
    SMatrix!(typeof(pred(0, 0)), H, W) res;
    foreach (y; 0..H) {
        foreach (x; 0..W) {
            res[y, x] = pred(y, x);
        }
    }
    return res;
}

/// Builds an `H` x `W` `SMatrixMod2` whose element (y, x) is `pred(y, x)`.
auto matrixMod2(size_t H, size_t W, alias pred)() {
    SMatrixMod2!(H, W) res;
    foreach (y; 0..H) {
        foreach (x; 0..W) {
            res[y, x] = pred(y, x);
        }
    }
    return res;
}

/**
 * Computes the determinant of a square matrix by Gaussian elimination.
 *
 * The argument is not modified; elimination runs on a copy obtained with
 * `dup`. The element type `Mat.DataType` must support construction from an
 * integer as well as `*`, `-` and `/`, since every pivot row is divided by
 * its pivot.
 *
 * The name keeps its historical spelling for compatibility.
 *
 * Params:
 *   _m = square matrix providing `dup`, `height`, `width`, `swapLine`,
 *        two-index element access and a `DataType` alias
 * Returns: the determinant, or `DataType(0)` when a column has no usable pivot
 */
auto determinent(Mat)(in Mat _m) {
    auto m = _m.dup;
    assert(m.height == m.width);
    alias M = Mat.DataType;
    size_t N = m.height;
    M base = 1;
    foreach (i; 0..N) {
        if (m[i, i] == M(0)) {
            // Look below the diagonal for a row with a non-zero entry in
            // this column and bring it up.
            foreach (j; i+1..N) {
                if (m[j, i] != M(0)) {
                    // Exactly one exchange: each row swap flips the sign of
                    // the determinant once.
                    m.swapLine(i, j);
                    base *= M(-1);
                    break;
                }
            }
            if (m[i, i] == M(0)) return M(0);
        }
        base *= m[i, i];
        // Scale the pivot row so the pivot becomes one ...
        M im = M(1)/m[i, i];
        foreach (j; 0..N) {
            m[i, j] *= im;
        }
        // ... then clear this column in every row below it.
        foreach (j; i+1..N) {
            M x = m[j, i];
            foreach (k; 0..N) {
                m[j, k] -= m[i, k] * x;
            }
        }
    }
    return base;
}

unittest {
    // Regression: an even-sized matrix that needs a row exchange.
    auto m = SMatrix!(double, 2, 2)(0.0, 1.0, 1.0, 0.0);
    assert(determinent(m) == -1.0);
}

unittest {
    import std.random, std.stdio, std.algorithm;
    import dkh.modint;
    // Compares determinent against the explicit 3x3 permutation expansion.
    void f(uint Mod)() {
        alias Mint = ModInt!Mod;
        alias Mat = SMatrix!(Mint, 3, 3);
        static Mint rndM() {
            return Mint(uniform(0, Mod));
        }
        Mat m = matrix!(3, 3, (i, j) => rndM())();
        Mint sm = 0;
        // Even permutations contribute with a plus sign ...
        auto idx = [0, 1, 2];
        do {
            Mint buf = 1;
            foreach (i; 0..3) {
                buf *= m[i, idx[i]];
            }
            sm += buf;
        } while (idx.nextEvenPermutation);
        // ... and odd permutations with a minus sign.
        idx = [0, 2, 1];
        do {
            Mint buf = 1;
            foreach (i; 0..3) {
                buf *= m[i, idx[i]];
            }
            sm -= buf;
        } while (idx.nextEvenPermutation);
        auto _m = m.dup;
        assert(sm == m.determinent);
        // The input must be left untouched.
        assert(_m == m);
    }
    void fMod2() {
        alias Mint = ModInt!2;
        alias Mat = SMatrixMod2!(3, 3);
        static Mint rndM() {
            return Mint(uniform(0, 2));
        }
        Mat m = matrixMod2!(3, 3, (i, j) => rndM())();
        Mint sm = 0;
        auto idx = [0, 1, 2];
        do {
            Mint buf = 1;
            foreach (i; 0..3) {
                buf *= m[i, idx[i]];
            }
            sm += buf;
        } while (idx.nextEvenPermutation);
        idx = [0, 2, 1];
        do {
            Mint buf = 1;
            foreach (i; 0..3) {
                buf *= m[i, idx[i]];
            }
            sm -= buf;
        } while (idx.nextEvenPermutation);
        auto _m = m.dup;
        // Dump the offending matrix before the assertion fires.
        if (sm != m.determinent) {
            writeln(sm, " ", m.determinent);
            foreach (i; 0..3) {
                foreach (j; 0..3) {
                    write(m[i, j], " ");
                }
                writeln;
            }
            writeln(m);
        }
        assert(sm == m.determinent);
        assert(_m == m);
    }
    import dkh.stopwatch;
    writeln("Det: ", benchmark!(f!2, f!3, f!11, fMod2)(10000)[].map!(a => a.toMsecs));
}


/**
 * Solves the linear system `m * v = r` by Gauss-Jordan elimination.
 *
 * `m` and `r` are taken by value and reduced in place; whether the caller's
 * matrices are affected depends on the copy semantics of `Mat` and `Vec`.
 * Elements must expose a `.v` member that is zero exactly when the element
 * is zero, and must support `-`, `*` and `/`.
 *
 * Only column 0 of `r` and of the result is used. Unknowns without a pivot
 * keep the value of a default-initialised `Vec`. No check is made that the
 * system is consistent.
 *
 * Params:
 *   m = coefficient matrix
 *   r = right-hand side, with as many rows as `m`
 * Returns: a vector `v` with `m.width` rows
 */
Vec solveLinear(Mat, Vec)(Mat m, Vec r) {
    size_t N = m.height, M = m.width;
    // Number of pivot rows found so far.
    size_t c = 0;
    foreach (x; 0..M) {
        // Find a row at or below `c` with a non-zero entry in column x.
        ptrdiff_t my = -1;
        foreach (y; c..N) {
            if (m[y, x].v) {
                my = y;
                break;
            }
        }
        if (my == -1) continue;
        m.swapLine(c, my);
        r.swapLine(c, my);
        // Clear column x in every other row, above and below the pivot.
        foreach (y; 0..N) {
            if (c == y) continue;
            if (m[y, x].v == 0) continue;
            auto freq = m[y, x] / m[c, x];
            foreach (k; 0..M) {
                m[y, k] -= freq * m[c, k];
            }
            r[y, 0] -= freq * r[c, 0];
        }
        c++;
        if (c == N) break;
    }
    Vec v;
    // Read the solution off the pivot rows, last row first.
    foreach_reverse (y; 0..c) {
        // Column of the first non-zero entry in this row, i.e. its pivot.
        ptrdiff_t f = -1;
        Mat.DataType sm;
        foreach (x; 0..M) {
            if (m[y, x].v && f == -1) {
                f = x;
            }
            sm += m[y, x] * v[x, 0];
        }
        v[f, 0] += (r[y, 0] - sm) / m[y, f];
    }
    return v;
}

unittest {
    import std.random, std.stdio;
    import dkh.modint;
    alias Mint = ModInt!(10^^9 + 7);
    alias Mat = SMatrix!(Mint, 3, 3);
    alias Vec = SMatrix!(Mint, 3, 1);
    static Mint rndM() {
        return Mint(uniform(0, 10^^9 + 7));
    }
    Mat m = matrix!(3, 3, (i, j) => rndM())();
    Vec x = matrix!(3, 1, (i, j) => rndM())();
    Vec r = m * x;
    Vec x2 = solveLinear(m, r);
    assert(m * x2 == r);
}

unittest {
    import std.random, std.stdio, std.algorithm;
    import dkh.modint;
    // The solution must reproduce the right-hand side, and the caller's
    // matrix must be left untouched.
    void f(uint Mod)() {
        alias Mint = ModInt!Mod;
        alias Mat = SMatrix!(Mint, 3, 3);
        alias Vec = SMatrix!(Mint, 3, 1);
        static Mint rndM() {
            return Mint(uniform(0, Mod));
        }
        Mat m = matrix!(3, 3, (i, j) => rndM())();
        Vec x = matrix!(3, 1, (i, j) => rndM())();
        Vec r = m * x;
        Mat _m = m.dup;
        Vec x2 = solveLinear(m, r);
        assert(m == _m);
        assert(m * x2 == r);
    }
    void fMod2() {
        alias Mint = ModInt!2;
        alias Mat = SMatrixMod2!(3, 3);
        alias Vec = SMatrixMod2!(3, 1);
        static Mint rndM() {
            return Mint(uniform(0, 2));
        }
        Mat m = matrixMod2!(3, 3, (i, j) => rndM())();
        Vec x = matrixMod2!(3, 1, (i, j) => rndM())();
        Vec r = m * x;
        Mat _m = m.dup;
        Vec x2 = solveLinear(m, r);
        assert(m == _m);
        assert(m * x2 == r);
    }
    import dkh.stopwatch;
    writeln("SolveLinear: ", benchmark!(f!2, f!3, f!11, fMod2)(10000)[].map!(a => a.toMsecs));
}