1 module dkh.numeric.convolution;
2 
/// Zeta transform (superset-sum form), applied to `v` in place; `v` is also returned.
///
/// Forward (`rev == false`): afterwards v[i] holds the sum of the original v[j]
/// over every j whose bit set contains i's bit set.
/// Inverse (`rev == true`): undoes the forward transform (Moebius inversion).
/// `v.length` must be a power of two (asserted).
T[] zeta(T)(T[] v, bool rev) {
    import core.bitop : bsr;
    immutable int levels = bsr(v.length);
    assert(1 << levels == v.length);
    foreach (level; 0 .. levels) {
        immutable size_t bit = 1 << level;
        foreach (i; 0 .. v.length) {
            // Only indices lacking `bit` pull from their partner that has it.
            if ((i & bit) == 0) {
                if (rev) v[i] -= v[i | bit];
                else     v[i] += v[i | bit];
            }
        }
    }
    return v;
}
20 
/// Walsh-Hadamard transform of `v`, performed in place; `v` is also returned.
///
/// Forward (`rev == false`): every butterfly maps (l, r) to (l + r, l - r).
/// Inverse (`rev == true`): every butterfly maps (l, r) to ((l + r) / 2, (l - r) / 2),
/// so over all log2(v.length) stages the result is divided by v.length.
/// `v.length` must be a power of two (asserted).
T[] hadamard(T)(T[] v, bool rev) {
    import core.bitop : bsr;
    immutable int levels = bsr(v.length);
    assert(1 << levels == v.length);
    foreach (level; 0 .. levels) {
        immutable size_t half = 1 << level;
        foreach (lo; 0 .. v.length) {
            if (lo & half) continue;  // visit each pair once, from its lower index
            immutable size_t hi = lo | half;
            auto l = v[lo], r = v[hi];
            if (rev) {
                v[lo] = (l + r) / 2;
                v[hi] = (l - r) / 2;
            } else {
                v[lo] = l + r;
                v[hi] = l - r;
            }
        }
    }
    return v;
}
41 
42 import std.complex;
43 
/// Returns a cached sine table for transforms of length len = 2^^S:
/// table[k] = sin(2 * PI * k / len) for k in [0, 3 * len / 4].
///
/// Only the first quarter period is evaluated with `sin`; the rest is filled
/// by symmetry. Tables are memoized per S in a function-local static array
/// of 30 slots, so S must be in [2, 29].
double[] fftSinList(size_t S) {
    import std.math : PI, sin;
    assert(2 <= S);
    static double[][30] cache;
    auto table = cache[S];
    if (table.length == 0) {
        immutable size_t len = size_t(1) << S;
        immutable half = len / 2, quarter = len / 4;
        table = new double[3 * len / 4 + 1];
        foreach (k; 0 .. quarter + 1) {
            immutable s = sin(k * 2 * double(PI) / len);
            table[k] = s;          // first quarter
            table[half - k] = s;   // sin(PI - x) == sin(x)
            table[half + k] = -s;  // sin(PI + x) == -sin(x)
        }
        cache[S] = table;
    }
    return table;
}
59 
/// fft: in-place radix-2 discrete Fourier transform of `c`.
///
/// `c.length` must be a nonzero power of two (asserted); length 1 is a no-op.
/// The butterflies ping-pong between two scratch buffers, so no separate
/// bit-reversal pass is performed.
/// `type == false` uses twiddle factors P(cos t, sin t) = exp(+i*t);
/// `type == true` uses their conjugates. Neither direction normalizes by
/// 1/N; callers such as `multiply` divide by N after the inverse transform.
void fft(bool type)(Complex!double[] c) {
    import std.algorithm : swap;
    import core.bitop : bsr;
    alias P = Complex!double;
    size_t N = c.length;
    assert(N);
    size_t S = bsr(N);
    assert(1<<S == N);
    // A length-1 DFT is the identity. It must be handled before fftSinList,
    // which asserts S >= 2 (nft already accepts N == 1).
    if (S == 0) return;
    // Length 2: a single butterfly whose twiddle is 1.
    if (S == 1) {
        auto x = c[0], y = c[1];
        c[0] = x+y;
        c[1] = x-y;
        return;
    }
    // rot[k] = sin(2*PI*k/N), so rot[y + N/4] = cos(2*PI*y/N).
    auto rot = fftSinList(S);
    P[] a = c.dup, b = new P[c.length];
    foreach (i; 1..S+1) {
        size_t W = 1<<(S-i);  // butterfly span at stage i
        for (size_t y = 0; y < N/2; y += W) {
            P now = P(rot[y + N/4], rot[y]);
            if (type) now = conj(now);
            foreach (x; 0..W) {
                auto l =       a[y<<1 | x];
                auto r = now * a[y<<1 | x | W];
                b[y | x]        = l+r;
                b[y | x | N>>1] = l-r;
            }
        }
        swap(a, b);
    }
    c[] = a[];
}
93 
/// multiply two double[]: returns their linear convolution, i.e. an array of
/// length a.length + b.length - 1 with c[k] = sum over i + j == k of a[i] * b[j].
/// Returns an empty array when either input is empty.
///
/// Both real inputs are packed into one complex vector (a in the real part,
/// b in the imaginary part), so a single forward `fft` yields both spectra.
double[] multiply(in double[] a, in double[] b) {
    alias P = Complex!double;
    size_t lenA = a.length, lenB = b.length;
    if (lenA == 0 || lenB == 0) return [];
    size_t outLen = lenA + lenB - 1;
    // Smallest power of two >= outLen, but at least 2.
    size_t lg = 1;
    while ((1 << lg) < outLen) ++lg;
    size_t N = 1 << lg;

    P[] spec = new P[N];
    spec[] = P(0, 0);
    foreach (i, v; a) spec[i].re = v;
    foreach (i, v; b) spec[i].im = v;
    fft!false(spec);

    // For real a and b, bins k and N-k of the packed spectrum give 2*FA[k]
    // and 2*FB[k]; their product divided by 4 is the spectrum of a * b.
    // That spectrum is Hermitian, so bin N-k is the conjugate of bin k.
    foreach (k; 0 .. N / 2 + 1) {
        size_t m = (k == 0) ? 0 : N - k;
        P fa2 = P(spec[k].re + spec[m].re, spec[k].im - spec[m].im);
        P fb2 = P(spec[k].im + spec[m].im, spec[m].re - spec[k].re);
        spec[k] = fa2 * fb2 / 4;
        if (k != m) spec[m] = conj(spec[k]);
    }
    fft!true(spec);

    double[] c = new double[outLen];
    foreach (i, ref v; c) v = spec[i].re / N;
    return c;
}
121 
unittest {
    import std.algorithm, std.stdio, std.random, std.math;
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    // Compare the FFT product with the schoolbook O(L*R) product on random inputs.
    foreach (L; 1 .. 20) foreach (R; 1 .. 20) foreach (ph; 0 .. 10) {
        auto a = new double[L], b = new double[R];
        foreach (ref e; a) e = 100 * uniform01;
        foreach (ref e; b) e = 100 * uniform01;
        auto fast = multiply(a, b);
        auto naive = new double[L + R - 1];
        naive[] = 0.0;
        foreach (i, x; a)
            foreach (j, y; b)
                naive[i + j] += x * y;
        assert(fast.length == naive.length);
        foreach (i; 0 .. L + R - 1) assert(approxEqual(fast[i], naive[i]));
    }
    writeln("FFT Stress: ", sw.peek.toMsecs);
}
149 
150 import dkh.modint, dkh.numeric.primitive;
151 
/// nft: in-place number-theoretic transform of `c` in Mint arithmetic. It is
/// the modular counterpart of `fft` and uses the same ping-pong butterfly layout.
///
/// `c.length` must be a nonzero power of two (asserted); length 1 is a no-op.
/// At stage s the twiddle step is pow(Mint(G), Mint(-1).v / 2^^s), inverted
/// when `type == true`. No 1/N normalization is applied in either direction.
/// NOTE(review): this presumes Mint(-1).v == mod - 1, that 2^^S divides
/// mod - 1, and that G is a primitive root of the modulus. None of these is checked.
void nft(uint G, bool type, Mint)(Mint[] c) {
    import std.algorithm : swap;
    import core.bitop : bsr;
    size_t N = c.length;
    assert(N);
    size_t S = bsr(N);
    assert(1 << S == N);

    Mint[] src = c.dup, dst = new Mint[N];
    for (size_t stage = 1; stage <= S; ++stage) {
        size_t W = 1 << (S - stage);  // butterfly span at this stage
        Mint step = pow(Mint(G), Mint(-1).v / (1 << stage));
        if (type) step = Mint(1) / step;
        Mint tw = 1;
        for (size_t y = 0; y < N / 2; y += W) {
            foreach (x; 0 .. W) {
                auto lo = src[(y << 1) | x];
                auto hi = tw * src[(y << 1) | x | W];
                dst[y | x] = lo + hi;
                dst[y | x | (N >> 1)] = lo - hi;
            }
            tw *= step;
        }
        swap(src, dst);
    }
    c[] = src[];
}
179 
/// Convolution of `a` and `b` in Mint arithmetic, computed with `nft` using root G.
///
/// Returns [] when either input is empty. Otherwise returns an array of length
/// a.length + b.length - 1 with c[k] = sum over i + j == k of a[i] * b[j].
/// NOTE(review): G must suit the modulus as required by `nft`.
Mint[] multiply(uint G, Mint)(in Mint[] a, in Mint[] b) {
    size_t A = a.length, B = b.length;
    if (A == 0 || B == 0) return [];
    size_t outLen = A + B - 1;
    // Smallest power of two >= outLen, but at least 2.
    size_t lg = 1;
    while ((1 << lg) < outLen) ++lg;
    size_t N = 1 << lg;

    Mint[] fa = new Mint[N], fb = new Mint[N];
    foreach (i, e; a) fa[i] = e;
    foreach (i, e; b) fb[i] = e;
    nft!(G, false)(fa);
    nft!(G, false)(fb);
    foreach (i, ref e; fa) e *= fb[i];  // pointwise product of spectra
    nft!(G, true)(fa);

    // nft does not normalize, so scale by 1/N here.
    Mint invN = Mint(1) / Mint(N);
    Mint[] c = new Mint[outLen];
    foreach (i, ref e; c) e = fa[i] * invN;
    return c;
}
201 
/// Convolution of two ModInt arrays for an arbitrary modulus, using
/// double-precision FFTs instead of a number-theoretic transform.
///
/// Every residue is split into M digits of W bits each, least significant
/// first. Digit-wise products are convolved with `fft` and recombined with
/// weights (2^^W)^^k. Returns [] when either input is empty; otherwise the
/// result has length a.length + b.length - 1.
///
/// NOTE(review): assumes every residue `.v` fits in M*W bits (30 bits with the
/// defaults). Higher bits would be silently dropped by the digit split.
Mint[] multiply(Mint, size_t M = 3, size_t W = 10)(in Mint[] a, in Mint[] b)
if (isModInt!Mint) {
    import std.math : round;
    alias P = Complex!double;

    size_t A = a.length, B = b.length;
    if (A == 0 || B == 0) return [];
    size_t outLen = A + B - 1;
    size_t lg = 1;
    while ((1 << lg) < outLen) ++lg;
    size_t N = 1 << lg;

    // specA[d] / specB[d]: spectra of the d-th W-bit digit of a / b, separated
    // from one packed transform (digits of a in re, digits of b in im).
    P[][M] specA, specB;
    P[] packed = new P[N];
    foreach (d; 0 .. M) {
        specA[d] = new P[N];
        specB[d] = new P[N];
        packed[] = P(0, 0);
        foreach (i, e; a) packed[i].re = (e.v >> (d * W)) % (1 << W);
        foreach (i, e; b) packed[i].im = (e.v >> (d * W)) % (1 << W);
        fft!false(packed);
        foreach (ref e; packed) e *= 0.5;
        foreach (i; 0 .. N) {
            auto j = (N - i) & (N - 1);  // bin of frequency -i (N is a power of two)
            specA[d][i] = P(packed[i].re + packed[j].re,  packed[i].im - packed[j].im);
            specB[d][i] = P(packed[i].im + packed[j].im, -packed[i].re + packed[j].re);
        }
    }

    Mint[] c = new Mint[outLen];
    // basel = (2^W)^ph weights the real part; baser = (2^W)^(ph+M) the imaginary part.
    Mint basel = 1, baser = pow(Mint(1<<W), M);
    P[] z = new P[N];
    double invN = 1.0 / N;
    foreach (ph; 0 .. M) {
        // Real part: digit pairs with af + bf == ph.
        // Imaginary part (via the factor i): digit pairs with af + bf == ph + M.
        foreach (i; 0 .. N) {
            P acc = P(0, 0);
            foreach (af; 0 .. ph + 1)
                acc += specA[af][i] * specB[ph - af][i];
            foreach (af; ph + 1 .. M)
                acc += specA[af][i] * specB[ph + M - af][i] * P(0, 1);
            z[i] = acc;
        }
        fft!true(z);
        foreach (i; 0 .. outLen) {
            z[i] *= invN;
            c[i] += Mint(cast(long)(round(z[i].re))) * basel;
            c[i] += Mint(cast(long)(round(z[i].im))) * baser;
        }
        basel *= Mint(1<<W);
        baser *= Mint(1<<W);
    }
    return c;
}
259 
unittest {
    alias Mint = ModInt!924844033;
    import std.algorithm, std.stdio, std.random, std.math;
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    Mint rndM() { return Mint(uniform(0, 924844033)); }
    // Compare the NTT product (root 5) with the schoolbook product on random inputs.
    foreach (L; 1 .. 20) foreach (R; 1 .. 20) foreach (ph; 0 .. 10) {
        auto a = new Mint[L], b = new Mint[R];
        foreach (ref e; a) e = rndM();
        foreach (ref e; b) e = rndM();
        auto fast = multiply!5(a, b);
        auto naive = new Mint[L + R - 1];
        foreach (i, x; a)
            foreach (j, y; b)
                naive[i + j] += x * y;
        assert(fast.length == naive.length);
        foreach (i; 0 .. L + R - 1) {
            if (fast[i] != naive[i]) {
                writeln(a);
                writeln(b);
                writeln(fast);
                writeln(naive);
            }
            assert(fast[i] == naive[i]);
        }
    }
    writeln("NFT Stress: ", sw.peek.toMsecs);
}
295 
296 
unittest {
    alias Mint = ModInt!(10^^9 + 7);
    import std.algorithm, std.stdio, std.random, std.math;
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    Mint rndM() { return Mint(uniform(0, 10^^9 + 7)); }
    // Compare the digit-split FFT product with the schoolbook product on random inputs.
    foreach (L; 1 .. 20) foreach (R; 1 .. 20) foreach (ph; 0 .. 10) {
        auto a = new Mint[L], b = new Mint[R];
        foreach (ref e; a) e = rndM();
        foreach (ref e; b) e = rndM();
        auto fast = multiply(a, b);
        auto naive = new Mint[L + R - 1];
        foreach (i, x; a)
            foreach (j, y; b)
                naive[i + j] += x * y;
        assert(fast.length == naive.length);
        foreach (i; 0 .. L + R - 1) {
            if (fast[i] != naive[i]) {
                writeln(a);
                writeln(b);
                writeln(fast);
                writeln(naive);
            }
            assert(fast[i] == naive[i]);
        }
    }
    writeln("FFT(ModInt) Stress: ", sw.peek.toMsecs);
}
332