module dkh.numeric.convolution;

/// Zeta transform over bit masks, performed in place.
///
/// For every bit position `fe` and every index `i` that does not have that
/// bit set, `v[i | (1 << fe)]` is added to `v[i]`; after all bits have been
/// processed each `v[i]` holds the sum of the original values at all indices
/// whose bit mask is a superset of `i`. With `rev == true` the same sweep
/// subtracts instead of adds, which undoes the forward transform.
///
/// Params:
///   v   = data to transform; its length must be a non-zero power of two
///   rev = false for the summing transform, true for the subtracting one
/// Returns: the same slice `v`, modified in place.
T[] zeta(T)(T[] v, bool rev) {
    import core.bitop : bsr;
    // bsr(0) is undefined, so reject the empty slice explicitly
    // (fft and nft in this module do the same).
    assert(v.length);
    int n = bsr(v.length);
    assert(1<<n == v.length);
    foreach (fe; 0..n) {
        foreach (i, _; v) {
            // Only handle each pair once, from the index with the bit cleared.
            if (i & (1<<fe)) continue;
            if (!rev) {
                v[i] += v[i|(1<<fe)];
            } else {
                v[i] -= v[i|(1<<fe)];
            }
        }
    }
    return v;
}

/// Hadamard transform, performed in place.
///
/// For every bit position `fe`, each pair `(v[i], v[i | (1 << fe)])` with the
/// bit cleared in `i` is replaced by `(l + r, l - r)`. With `rev == true`
/// both results are additionally divided by 2 at every stage, which undoes
/// the forward transform (for integer `T` the division is `T`'s own `/`).
///
/// Params:
///   v   = data to transform; its length must be a non-zero power of two
///   rev = false for the plain butterfly, true for the halving butterfly
/// Returns: the same slice `v`, modified in place.
T[] hadamard(T)(T[] v, bool rev) {
    import core.bitop : bsr;
    // bsr(0) is undefined, so reject the empty slice explicitly
    // (fft and nft in this module do the same).
    assert(v.length);
    int n = bsr(v.length);
    assert(1<<n == v.length);
    foreach (fe; 0..n) {
        foreach (i, _; v) {
            // Only handle each pair once, from the index with the bit cleared.
            if (i & (1<<fe)) continue;
            auto l = v[i], r = v[i|(1<<fe)];
            if (!rev) {
                v[i] = l+r;
                v[i|(1<<fe)] = l-r;
            } else {
                v[i] = (l+r)/2;
                v[i|(1<<fe)] = (l-r)/2;
            }
        }
    }
    return v;
}

import std.complex;

/// Returns a sine table for a transform of size `N = 2^S`.
///
/// Entry `i` holds `sin(2*PI*i/N)` for `0 <= i <= 3*N/4`. Only the first
/// quarter is computed with `sin`; the rest is filled in by symmetry.
/// Tables are built on first use and kept in a static array with one slot
/// per `S` (30 slots), so repeated calls with the same `S` return the same
/// slice; callers must not modify it.
///
/// Params:
///   S = log2 of the transform size; must be at least 2
double[] fftSinList(size_t S) {
    import std.math : PI, sin;
    assert(2 <= S);
    size_t N = 1<<S;
    static double[][30] buf;
    if (!buf[S].length) {
        buf[S] = new double[3*N/4+1];
        foreach (i; 0..N/4+1) {
            buf[S][i] = sin(i*2*double(PI)/N);
            // sin(pi - t) == sin(t) and sin(pi + t) == -sin(t)
            buf[S][N/2-i] = buf[S][i];
            buf[S][N/2+i] = -buf[S][i];
        }
    }
    return buf[S];
}

/// Unnormalized fast Fourier transform of `c`, written back into `c`.
///
/// The twiddle factor used for offset `y` is `cos(t) + i*sin(t)` with
/// `t = 2*PI*y/N` when `type` is false and its complex conjugate when `type`
/// is true. No scaling by `1/N` is applied in either direction; `multiply`
/// in this module calls `fft!false`, then `fft!true`, and divides by `N`
/// itself.
///
/// Params:
///   type = false for the non-conjugated twiddles, true for the conjugated ones
///   c    = data to transform; its length must be a non-zero power of two
void fft(bool type)(Complex!double[] c) {
    import std.algorithm : swap;
    import core.bitop : bsr;
    alias P = Complex!double;
    size_t N = c.length;
    assert(N);
    size_t S = bsr(N);
    assert(1<<S == N);
    if (S == 0) {
        // A one-point transform is the identity. Without this guard the code
        // below would call fftSinList(0), which requires S >= 2.
        return;
    }
    if (S == 1) {
        // Two-point transform: a single butterfly, no twiddle table needed.
        auto x = c[0], y = c[1];
        c[0] = x+y;
        c[1] = x-y;
        return;
    }
    auto rot = fftSinList(S);
    // Each stage reads from a and writes to b, then the buffers are swapped.
    P[] a = c.dup, b = new P[c.length];
    foreach (i; 1..S+1) {
        size_t W = 1<<(S-i);
        for (size_t y = 0; y < N/2; y += W) {
            // rot[y + N/4] is the cosine matching the sine rot[y].
            P now = P(rot[y + N/4], rot[y]);
            if (type) now = conj(now);
            foreach (x; 0..W) {
                auto l = a[y<<1 | x];
                auto r = now * a[y<<1 | x | W];
                b[y | x] = l+r;
                b[y | x | N>>1] = l-r;
            }
        }
        swap(a, b);
    }
    c[] = a[];
}

/// Multiplies two polynomials with `double` coefficients.
///
/// `a` is packed into the real parts and `b` into the imaginary parts of one
/// complex array so that a single `fft!false` call transforms both. The two
/// spectra are then separated through the values at indices `i` and `N - i`,
/// multiplied pointwise, transformed back with `fft!true` and scaled by `1/N`.
///
/// Params:
///   a = coefficients of the first polynomial
///   b = coefficients of the second polynomial
/// Returns: an array of length `a.length + b.length - 1` with the product's
///          coefficients, or an empty array if either input is empty.
double[] multiply(in double[] a, in double[] b) {
    alias P = Complex!double;
    size_t A = a.length, B = b.length;
    if (!A || !B) return [];
    // lg starts at 1, so the transform size is always at least 2.
    size_t lg = 1;
    while ((1<<lg) < A+B-1) lg++;
    size_t N = 1<<lg;
    P[] d = new P[N];
    d[] = P(0, 0);
    foreach (i; 0..A) d[i].re = a[i];
    foreach (i; 0..B) d[i].im = b[i];
    fft!false(d);
    foreach (i; 0..N/2+1) {
        auto j = i ? (N-i) : 0;
        // x is twice the spectrum of a, y twice the spectrum of b, hence the /4.
        P x = P(d[i].re+d[j].re, d[i].im-d[j].im);
        P y = P(d[i].im+d[j].im, -d[i].re+d[j].re);
        d[i] = x * y / 4;
        // The mirrored entry is set to the conjugate of the one just computed.
        if (i != j) d[j] = conj(d[i]);
    }
    fft!true(d);
    double[] c = new double[A+B-1];
    foreach (i; 0..A+B-1) {
        c[i] = d[i].re / N;
    }
    return c;
}

// Stress test: compares the FFT-based product with the schoolbook product
// for all length pairs below 20 and prints the elapsed time.
unittest {
    import std.algorithm, std.stdio, std.random, std.math;
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    foreach (L; 1..20) {
        foreach (R; 1..20) {
            foreach (ph; 0..10) {
                double[] a = new double[L];
                double[] b = new double[R];
                foreach (ref x; a) x = 100 * uniform01;
                foreach (ref x; b) x = 100 * uniform01;
                double[] c1 = multiply(a, b);
                double[] c2 = new double[L+R-1]; c2[] = 0.0;
                foreach (i; 0..L) {
                    foreach (j; 0..R) {
                        c2[i+j] += a[i] * b[j];
                    }
                }
                assert(c1.length == c2.length);
                foreach (i; 0..L+R-1) {
                    assert(approxEqual(c1[i], c2[i]));
                }
            }
        }
    }
    writeln("FFT Stress: ", sw.peek.toMsecs);
}

import dkh.modint, dkh.numeric.primitive;

/// Unnormalized number-theoretic transform of `c`, written back into `c`.
///
/// Stage `i` uses `base = G ^^ (Mint(-1).v / 2^i)` as its twiddle step, or
/// the multiplicative inverse of that value when `type` is true. No scaling
/// by `1/N` is applied; the `multiply!(G, Mint)` overload in this module does
/// that itself.
/// NOTE(review): presumably `G` must be a primitive root of `Mint`'s modulus
/// and `N` must divide `modulus - 1` for `base` to be a proper root of unity;
/// confirm against the callers.
///
/// Params:
///   G    = generator the twiddle factors are derived from
///   type = false to use `base`, true to use its inverse
///   c    = data to transform; its length must be a non-zero power of two
void nft(uint G, bool type, Mint)(Mint[] c) {
    import std.algorithm : swap;
    import core.bitop : bsr;
    size_t N = c.length;
    assert(N);
    size_t S = bsr(N);
    assert(1<<S == N);

    // Each stage reads from a and writes to b, then the buffers are swapped.
    Mint[] a = c.dup, b = new Mint[N];
    foreach (i; 1..S+1) {
        size_t W = 1<<(S-i);
        Mint base = pow(Mint(G), Mint(-1).v/(1<<i));
        if (type) base = Mint(1)/base;
        Mint now = 1;
        for (size_t y = 0; y < N/2; y += W) {
            foreach (x; 0..W) {
                auto l = a[y<<1 | x];
                auto r = now * a[y<<1 | x | W];
                b[y | x] = l+r;
                b[y | x | N>>1] = l-r;
            }
            now *= base;
        }
        swap(a, b);
    }
    c[] = a[];
}

/// Multiplies two polynomials over `Mint` with the number-theoretic transform.
///
/// Both inputs are zero-padded to a power-of-two length, transformed with
/// `nft!(G, false)`, multiplied pointwise, transformed back with
/// `nft!(G, true)` and finally scaled by the modular inverse of the
/// transform length.
///
/// Params:
///   G = generator forwarded to `nft`
///   a = coefficients of the first polynomial
///   b = coefficients of the second polynomial
/// Returns: an array of length `a.length + b.length - 1` with the product's
///          coefficients, or an empty array if either input is empty.
Mint[] multiply(uint G, Mint)(in Mint[] a, in Mint[] b) {
    immutable lenA = a.length, lenB = b.length;
    if (lenA == 0 || lenB == 0) return [];
    immutable outLen = lenA + lenB - 1;

    // Smallest power of two (at least 2) that can hold the whole product.
    size_t bits = 1;
    while ((1<<bits) < outLen) bits++;
    immutable size_t size = 1<<bits;

    auto fa = new Mint[size];
    auto fb = new Mint[size];
    foreach (i, e; a) fa[i] = e;
    foreach (i, e; b) fb[i] = e;

    nft!(G, false)(fa);
    nft!(G, false)(fb);
    foreach (i, ref e; fa) e *= fb[i];
    nft!(G, true)(fa);

    // nft does not normalize, so divide by the transform length here.
    Mint invSize = Mint(1) / Mint(size);
    auto result = new Mint[outLen];
    foreach (i, ref e; result) {
        e = fa[i] * invSize;
    }
    return result;
}

/// Multiplies two polynomials over `Mint` using floating-point FFTs.
///
/// Every coefficient is split into `M` digits of `W` bits each. For each
/// digit position the digits of `a` go into the real parts and the digits of
/// `b` into the imaginary parts of one complex array, which is transformed
/// once and then separated into the two spectra. For every `ph` in `0 .. M`
/// the digit products whose positions sum to `ph` are accumulated unchanged,
/// and those whose positions sum to `ph + M` are accumulated multiplied by
/// the imaginary unit, so that one inverse transform yields both groups as
/// the real and imaginary parts. The rounded values are added into the
/// result weighted by `2^(W*ph)` and `2^(W*(ph+M))` respectively.
///
/// Params:
///   M = number of digits each coefficient is split into
///   W = number of bits per digit
///   a = coefficients of the first polynomial
///   b = coefficients of the second polynomial
/// Returns: an array of length `a.length + b.length - 1` with the product's
///          coefficients, or an empty array if either input is empty.
Mint[] multiply(Mint, size_t M = 3, size_t W = 10)(in Mint[] a, in Mint[] b)
if (isModInt!Mint) {
    import std.math : round;
    alias Cpx = Complex!double;

    immutable lenA = a.length, lenB = b.length;
    if (lenA == 0 || lenB == 0) return [];
    immutable outLen = lenA + lenB - 1;

    // Smallest power of two (at least 2) that can hold the whole product.
    size_t bits = 1;
    while ((1<<bits) < outLen) bits++;
    immutable size_t size = 1<<bits;

    // specA[d] / specB[d]: spectra of the d-th digit of a / b.
    Cpx[][M] specA, specB;
    auto packed = new Cpx[size];
    foreach (digit; 0..M) {
        specA[digit] = new Cpx[size];
        specB[digit] = new Cpx[size];
        immutable shift = digit*W;

        packed[] = Cpx(0, 0);
        foreach (i, e; a) packed[i].re = (e.v >> shift) % (1<<W);
        foreach (i, e; b) packed[i].im = (e.v >> shift) % (1<<W);
        fft!false(packed);
        foreach (ref e; packed) e *= 0.5;

        // Separate the two spectra using the entry at the mirrored index.
        foreach (i; 0..size) {
            immutable mirror = (i == 0) ? 0 : size - i;
            immutable cur = packed[i], opp = packed[mirror];
            specA[digit][i] = Cpx(cur.re+opp.re, cur.im-opp.im);
            specB[digit][i] = Cpx(cur.im+opp.im, -cur.re+opp.re);
        }
    }

    auto result = new Mint[outLen];
    // Weights of the real part (low group) and imaginary part (high group).
    Mint weightLo = 1, weightHi = pow(Mint(1<<W), M);
    auto acc = new Cpx[size];
    foreach (ph; 0..M) {
        acc[] = Cpx(0, 0);
        foreach (af; 0..M) {
            if (af <= ph) {
                // Digit positions sum to ph: contributes to the real part.
                immutable bf = ph - af;
                foreach (i; 0..size) {
                    acc[i] += specA[af][i] * specB[bf][i];
                }
            } else {
                // Digit positions sum to ph + M: rotated into the imaginary part.
                immutable bf = ph + M - af;
                foreach (i; 0..size) {
                    acc[i] += specA[af][i] * specB[bf][i] * Cpx(0, 1);
                }
            }
        }
        fft!true(acc);
        foreach (i; 0..outLen) {
            acc[i] *= 1.0/size;
            result[i] += Mint(cast(long)(round(acc[i].re)))*weightLo;
            result[i] += Mint(cast(long)(round(acc[i].im)))*weightHi;
        }
        weightLo *= Mint(1<<W);
        weightHi *= Mint(1<<W);
    }
    return result;
}

// Stress test for the NTT-based multiply: compares it with the schoolbook
// product for all length pairs below 20 and prints the elapsed time.
unittest {
    alias Mint = ModInt!924844033;
    import std.algorithm, std.stdio, std.random, std.math;
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    Mint rndM() { return Mint(uniform(0, 924844033)); }
    foreach (L; 1..20) {
        foreach (R; 1..20) {
            foreach (ph; 0..10) {
                auto a = new Mint[L];
                auto b = new Mint[R];
                foreach (ref e; a) e = rndM();
                foreach (ref e; b) e = rndM();

                auto fast = multiply!5(a, b);
                auto naive = new Mint[L+R-1];
                foreach (i, ea; a) {
                    foreach (j, eb; b) {
                        naive[i+j] += ea * eb;
                    }
                }

                assert(fast.length == naive.length);
                foreach (i; 0..L+R-1) {
                    // Dump the inputs and both results before failing.
                    if (fast[i] != naive[i]) {
                        writeln(a);
                        writeln(b);
                        writeln(fast);
                        writeln(naive);
                    }
                    assert(fast[i] == naive[i]);
                }
            }
        }
    }
    writeln("NFT Stress: ", sw.peek.toMsecs);
}


// Stress test for the FFT-based ModInt multiply: compares it with the
// schoolbook product for all length pairs below 20 and prints the elapsed time.
unittest {
    alias Mint = ModInt!(10^^9 + 7);
    import std.algorithm, std.stdio, std.random, std.math;
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    Mint rndM() { return Mint(uniform(0, 10^^9 + 7)); }
    foreach (L; 1..20) {
        foreach (R; 1..20) {
            foreach (ph; 0..10) {
                auto a = new Mint[L];
                auto b = new Mint[R];
                foreach (ref e; a) e = rndM();
                foreach (ref e; b) e = rndM();

                auto fast = multiply(a, b);
                auto naive = new Mint[L+R-1];
                foreach (i, ea; a) {
                    foreach (j, eb; b) {
                        naive[i+j] += ea * eb;
                    }
                }

                assert(fast.length == naive.length);
                foreach (i; 0..L+R-1) {
                    // Dump the inputs and both results before failing.
                    if (fast[i] != naive[i]) {
                        writeln(a);
                        writeln(b);
                        writeln(fast);
                        writeln(naive);
                    }
                    assert(fast[i] == naive[i]);
                }
            }
        }
    }
    writeln("FFT(ModInt) Stress: ", sw.peek.toMsecs);
}