module dkh.numeric.convolution;

import std.complex;

/// Returns a cached table t of length 3*2^S/4 + 1 with t[i] = sin(2*PI*i / 2^S).
/// fft reads t[y] as the sine and t[y + 2^S/4] as the cosine of the same angle.
/// Requires 2 <= S; S must also be < 30 because the cache has 30 slots.
private double[] fftSinList(size_t S) {
    import std.math : PI, sin;
    assert(2 <= S);
    immutable size_t N = size_t(1) << S;
    static double[][30] buf;
    if (!buf[S].length) {
        buf[S] = new double[3*N/4+1];
        // compute the first quarter period, fill the rest by symmetry of sin
        foreach (i; 0..N/4+1) {
            buf[S][i] = sin(i*2*double(PI)/N);
            buf[S][N/2-i] = buf[S][i];
            buf[S][N/2+i] = -buf[S][i];
        }
    }
    return buf[S];
}

/// fft
///
/// In-place radix-2 FFT of c; c.length must be a power of two (length 1 is a no-op).
/// type == false uses twiddle factors cos(2*PI*t/N) + i*sin(2*PI*t/N);
/// type == true uses their conjugates (inverse transform).
/// The inverse transform is NOT divided by N; callers must scale the result.
void fft(bool type)(Complex!double[] c) {
    import std.algorithm : swap;
    import core.bitop : bsr;
    alias P = Complex!double;
    immutable size_t N = c.length;
    assert(N, "fft: input must be non-empty");
    immutable size_t S = bsr(N);
    assert((size_t(1) << S) == N, "fft: length must be a power of two");
    // length 1: the DFT is the identity (and fftSinList requires S >= 2)
    if (S == 0) return;
    // length 2: a single butterfly, no twiddle table needed
    if (S == 1) {
        auto x = c[0], y = c[1];
        c[0] = x+y;
        c[1] = x-y;
        return;
    }
    auto rot = fftSinList(S);
    // ping-pong between a and b, one radix-2 stage per iteration
    P[] a = c.dup, b = new P[N];
    foreach (i; 1..S+1) {
        immutable size_t W = N >> i;
        for (size_t y = 0; y < N/2; y += W) {
            // twiddle = cos(2*PI*y/N) + i*sin(2*PI*y/N)
            P now = P(rot[y + N/4], rot[y]);
            static if (type) now = conj(now);
            foreach (x; 0..W) {
                auto l = a[y<<1 | x];
                auto r = now * a[y<<1 | x | W];
                b[y | x] = l+r;
                b[y | x | N>>1] = l-r;
            }
        }
        swap(a, b);
    }
    c[] = a[];
}

unittest {
    // length-1 and length-2 inputs take the early-exit paths of fft
    alias P = Complex!double;
    auto one = [P(3, 4)];
    fft!false(one);
    assert(one == [P(3, 4)]);
    fft!true(one);
    assert(one == [P(3, 4)]);
    auto two = [P(1, 0), P(2, 0)];
    fft!false(two);
    assert(two == [P(3, 0), P(-1, 0)]);
}

/// multiply two double[]
///
/// Returns the linear convolution c[k] = sum_{i+j=k} a[i]*b[j]
/// (length a.length + b.length - 1), or an empty array if either input is empty.
/// Both inputs are packed into one complex FFT (a in the real parts, b in the imaginary parts).
double[] multiply(in double[] a, in double[] b) {
    alias P = Complex!double;
    size_t A = a.length, B = b.length;
    if (!A || !B) return [];
    size_t lg = 1;
    while ((size_t(1) << lg) < A+B-1) lg++;
    size_t N = size_t(1) << lg;
    P[] d = new P[N];
    d[] = P(0, 0);
    foreach (i; 0..A) d[i].re = a[i];
    foreach (i; 0..B) d[i].im = b[i];
    fft!false(d);
    foreach (i; 0..N/2+1) {
        auto j = i ? (N-i) : 0;
        // x = 2*FFT(a)[i] and y = 2*FFT(b)[i], separated via conjugate symmetry
        P x = P(d[i].re+d[j].re, d[i].im-d[j].im);
        P y = P(d[i].im+d[j].im, -d[i].re+d[j].re);
        d[i] = x * y / 4;
        // the spectrum of a real convolution is Hermitian
        if (i != j) d[j] = conj(d[i]);
    }
    fft!true(d);
    double[] c = new double[A+B-1];
    foreach (i; 0..A+B-1) {
        c[i] = d[i].re / N;
    }
    return c;
}

unittest {
    import std.algorithm, std.stdio, std.random, std.math;
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    foreach (L; 1..20) {
        foreach (R; 1..20) {
            foreach (ph; 0..10) {
                double[] a = new double[L];
                double[] b = new double[R];
                foreach (ref x; a) x = 100 * uniform01;
                foreach (ref x; b) x = 100 * uniform01;
                double[] c1 = multiply(a, b);
                double[] c2 = new double[L+R-1]; c2[] = 0.0;
                foreach (i; 0..L) {
                    foreach (j; 0..R) {
                        c2[i+j] += a[i] * b[j];
                    }
                }
                assert(c1.length == c2.length);
                foreach (i; 0..L+R-1) {
                    assert(approxEqual(c1[i], c2[i]));
                }
            }
        }
    }
    writeln("FFT Stress: ", sw.peek.toMsecs);
}

import dkh.modint, dkh.numeric.primitive;

/// nft(G must be primitive root)
///
/// In-place number-theoretic transform of c, using the same butterfly layout as fft.
/// c.length must be a power of two that divides (modulus - 1), i.e. Mint(-1).v.
/// type == true uses inverse roots; the result is NOT divided by N.
void nft(uint G, bool type, Mint)(Mint[] c) {
    import std.algorithm : swap;
    import core.bitop : bsr;
    size_t N = c.length;
    assert(N);
    size_t S = bsr(N);
    assert((size_t(1) << S) == N);
    // a 2^S-th root of unity G^((mod-1)/2^S) only exists when 2^S divides mod-1
    assert(Mint(-1).v % (size_t(1) << S) == 0, "nft: length must divide mod-1");

    Mint[] a = c.dup, b = new Mint[N];
    foreach (i; 1..S+1) {
        size_t W = N >> i;
        // base is a primitive 2^i-th root of unity (inverted for type == true)
        Mint base = pow(Mint(G), Mint(-1).v/(1<<i));
        static if (type) base = Mint(1)/base;
        Mint now = 1;
        for (size_t y = 0; y < N/2; y += W) {
            foreach (x; 0..W) {
                auto l = a[y<<1 | x];
                auto r = now * a[y<<1 | x | W];
                b[y | x] = l+r;
                b[y | x | N>>1] = l-r;
            }
            now *= base;
        }
        swap(a, b);
    }
    c[] = a[];
}

/// multiply 2 Mint[](G must be primitive root)
///
/// Convolution modulo Mint's modulus via nft; the padded length must divide mod-1.
/// Returns an empty array if either input is empty.
Mint[] multiply(uint G, Mint)(in Mint[] a, in Mint[] b) {
    size_t A = a.length, B = b.length;
    if (!A || !B) return [];
    size_t lg = 1;
    while ((size_t(1) << lg) < A+B-1) lg++;
    size_t N = size_t(1) << lg;
    Mint[] _a = new Mint[N];
    Mint[] _b = new Mint[N];
    foreach (i; 0..A) _a[i] = a[i];
    foreach (i; 0..B) _b[i] = b[i];
    nft!(G, false)(_a);
    nft!(G, false)(_b);
    foreach (i; 0..N) _a[i] *= _b[i];
    nft!(G, true)(_a);
    Mint[] c = new Mint[A+B-1];
    // the inverse transform is unnormalized: scale by 1/N
    Mint iN = Mint(1) / Mint(N);
    foreach (i; 0..A+B-1) {
        c[i] = _a[i] * iN;
    }
    return c;
}

/// multiply 2 Mint[] (arbitrary modulus)
///
/// Each residue is split into M limbs of W bits, so every input's .v must be
/// below 2^(M*W) (2^30 with the defaults). Limb convolutions are computed with
/// double-precision FFTs, rounded, and recombined modulo Mint's modulus.
/// Returns an empty array if either input is empty.
Mint[] multiply(Mint, size_t M = 3, size_t W = 10)(in Mint[] a, in Mint[] b)
if (isModInt!Mint) {
    import std.math : round;
    alias P = Complex!double;

    size_t A = a.length, B = b.length;
    if (!A || !B) return [];
    static if (M * W < 64) {
        // bits at or above M*W would be silently dropped by the limb split below
        foreach (e; a) assert((cast(ulong)e.v >> (M*W)) == 0, "multiply: value does not fit in M*W bits");
        foreach (e; b) assert((cast(ulong)e.v >> (M*W)) == 0, "multiply: value does not fit in M*W bits");
    }
    auto N = A + B - 1;
    size_t lg = 1;
    while ((size_t(1) << lg) < N) lg++;
    N = size_t(1) << lg;

    // x[ph] / y[ph]: FFT of the ph-th W-bit limb of a / b
    P[][M] x, y;
    P[] w = new P[N];
    foreach (ph; 0..M) {
        x[ph] = new P[N];
        y[ph] = new P[N];
        w[] = P(0, 0);
        foreach (i; 0..A) w[i].re = (a[i].v >> (ph*W)) % (1<<W);
        foreach (i; 0..B) w[i].im = (b[i].v >> (ph*W)) % (1<<W);
        fft!false(w);
        foreach (i; 0..N) w[i] *= 0.5;
        // separate the two packed real transforms using conjugate symmetry
        foreach (i; 0..N) {
            auto j = i ? N-i : 0;
            x[ph][i] = P(w[i].re+w[j].re, w[i].im-w[j].im);
            y[ph][i] = P(w[i].im+w[j].im, -w[i].re+w[j].re);
        }
    }

    Mint[] c = new Mint[A+B-1];
    // basel = 2^(ph*W), baser = 2^((ph+M)*W)
    Mint basel = 1, baser = pow(Mint(1<<W), M);
    P[] z = new P[N];
    foreach (ph; 0..M) {
        // real part: limb pairs with af+bf == ph; imaginary part: af+bf == ph+M
        z[] = P(0, 0);
        foreach (af; 0..ph+1) {
            auto bf = ph - af;
            foreach (i; 0..N) {
                z[i] += x[af][i] * y[bf][i];
            }
        }
        foreach (af; ph+1..M) {
            auto bf = ph + M - af;
            foreach (i; 0..N) {
                z[i] += x[af][i] * y[bf][i] * P(0, 1);
            }
        }
        fft!true(z);
        foreach (i; 0..A+B-1) {
            z[i] *= 1.0/N;
            c[i] += Mint(cast(long)(round(z[i].re)))*basel;
            c[i] += Mint(cast(long)(round(z[i].im)))*baser;
        }
        basel *= Mint(1<<W);
        baser *= Mint(1<<W);
    }
    return c;
}

unittest {
    alias Mint = ModInt!924844033;
    import std.algorithm, std.stdio, std.random, std.math;
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    Mint rndM() { return Mint(uniform(0, 924844033)); }
    foreach (L; 1..20) {
        foreach (R; 1..20) {
            foreach (ph; 0..10) {
                Mint[] a = new Mint[L];
                Mint[] b = new Mint[R];
                foreach (ref x; a) x = rndM();
                foreach (ref x; b) x = rndM();
                Mint[] c1 = multiply!5(a, b);
                Mint[] c2 = new Mint[L+R-1];
                foreach (i; 0..L) {
                    foreach (j; 0..R) {
                        c2[i+j] += a[i] * b[j];
                    }
                }
                assert(c1.length == c2.length);
                foreach (i; 0..L+R-1) {
                    if (c1[i] != c2[i]) {
                        writeln(a);
                        writeln(b);
                        writeln(c1);
                        writeln(c2);
                    }
                    assert(c1[i] == c2[i]);
                }
            }
        }
    }
    writeln("NFT Stress: ", sw.peek.toMsecs);
}


unittest {
    alias Mint = ModInt!(10^^9 + 7);
    import std.algorithm, std.stdio, std.random, std.math;
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    Mint rndM() { return Mint(uniform(0, 10^^9 + 7)); }
    foreach (L; 1..20) {
        foreach (R; 1..20) {
            foreach (ph; 0..10) {
                Mint[] a = new Mint[L];
                Mint[] b = new Mint[R];
                foreach (ref x; a) x = rndM();
                foreach (ref x; b) x = rndM();
                Mint[] c1 = multiply(a, b);
                Mint[] c2 = new Mint[L+R-1];
                foreach (i; 0..L) {
                    foreach (j; 0..R) {
                        c2[i+j] += a[i] * b[j];
                    }
                }
                assert(c1.length == c2.length);
                foreach (i; 0..L+R-1) {
                    if (c1[i] != c2[i]) {
                        writeln(a);
                        writeln(b);
                        writeln(c1);
                        writeln(c2);
                    }
                    assert(c1[i] == c2[i]);
                }
            }
        }
    }
    writeln("FFT(ModInt) Stress: ", sw.peek.toMsecs);
}