1 module dkh.numeric.convolution;
2 
3 import std.complex;
4 
/// Cached sine table for a size-2^S FFT: entry k holds sin(2*PI*k/N) with
/// N = 2^S, for k in [0, 3N/4]. Only the first quarter is evaluated with
/// sin(); the remaining entries are filled from the symmetries of sine.
/// The table for each S is built on first use and then reused.
private double[] fftSinList(size_t S) {
    import std.math : PI, sin;
    assert(2 <= S);
    size_t N = 1<<S;
    static double[][30] buf;
    if (buf[S].length) return buf[S];

    auto table = new double[3*N/4+1];
    buf[S] = table;
    immutable size_t half = N/2;
    foreach (k; 0..N/4+1) {
        immutable double s = sin(k*2*double(PI)/N);
        table[k] = s;
        table[half-k] = s;   // sin(PI - t) == sin(t)
        table[half+k] = -s;  // sin(PI + t) == -sin(t)
    }
    return table;
}
20 
/**
 * In-place radix-2 FFT (Stockham form, ping-pong between two buffers).
 *
 * Params:
 *   type = false: transform with twiddles cos(2*PI*k/N) + i*sin(2*PI*k/N);
 *          true: the same transform with conjugated twiddles. Neither
 *          direction is normalized; callers divide by N themselves.
 *   c = data, transformed in place; its length must be a power of two.
 */
void fft(bool type)(Complex!double[] c) {
    import std.algorithm : swap;
    import core.bitop : bsr;
    alias P = Complex!double;
    size_t N = c.length;
    assert(N);
    size_t S = bsr(N);
    assert((size_t(1) << S) == N);
    // A length-1 DFT is the identity. fftSinList requires S >= 2, so the
    // general path below must not be reached for N == 1.
    if (S == 0) return;
    if (S == 1) {
        auto x = c[0], y = c[1];
        c[0] = x+y;
        c[1] = x-y;
        return;
    }
    auto rot = fftSinList(S);
    P[] a = c.dup, b = new P[N];
    foreach (i; 1..S+1) {
        size_t W = size_t(1) << (S-i);
        for (size_t y = 0; y < N/2; y += W) {
            // twiddle: rot[y + N/4] = sin(2*PI*y/N + PI/2) = cos(2*PI*y/N)
            P now = P(rot[y + N/4], rot[y]);
            static if (type) now = conj(now);
            foreach (x; 0..W) {
                auto l =       a[y<<1 | x];
                auto r = now * a[y<<1 | x | W];
                b[y | x]        = l+r;
                b[y | x | N>>1] = l-r;
            }
        }
        swap(a, b);
    }
    c[] = a[];
}
54 
/// Linear convolution of two real sequences: result[k] = sum over i+j=k of a[i]*b[j].
/// a is packed into the real part and b into the imaginary part of one
/// complex buffer, so a single forward and a single inverse FFT suffice.
/// Returns an empty array when either input is empty.
double[] multiply(in double[] a, in double[] b) {
    alias P = Complex!double;
    immutable size_t lenA = a.length, lenB = b.length;
    if (lenA == 0 || lenB == 0) return [];
    immutable size_t outLen = lenA + lenB - 1;
    size_t lg = 1;
    while ((1<<lg) < outLen) ++lg;
    immutable size_t N = 1<<lg;

    auto d = new P[N];
    d[] = P(0, 0); // Complex!double.init is NaN, so zero-fill explicitly
    foreach (i, v; a) d[i].re = v;
    foreach (i, v; b) d[i].im = v;
    fft!false(d);

    // With D = FFT(d) and D' = D[(N-i) mod N]:
    //   FFT(a)[i] = (D + conj(D')) / 2,  FFT(b)[i] = (D - conj(D')) / (2i).
    // lhs = D + conj(D') and rhs = (D - conj(D')) / i, so the product is lhs*rhs/4.
    // The product spectrum is conjugate-symmetric, so each pair (i, N-i) is filled at once.
    for (size_t i = 0; i <= N/2; ++i) {
        immutable size_t j = (N - i) % N;
        P lhs = P(d[i].re + d[j].re, d[i].im - d[j].im);
        P rhs = P(d[i].im + d[j].im, -d[i].re + d[j].re);
        d[i] = lhs * rhs / 4;
        if (j != i) d[j] = conj(d[i]);
    }
    fft!true(d);

    auto c = new double[outLen];
    foreach (k, ref v; c) v = d[k].re / N;
    return c;
}
82 
unittest {
    import std.algorithm, std.stdio, std.random, std.math;
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    // Compare the FFT-based multiply against the schoolbook O(L*R) convolution
    // for every length pair in [1, 20) x [1, 20), with 10 random trials each.
    foreach (L; 1..20) foreach (R; 1..20) foreach (trial; 0..10) {
        auto a = new double[L], b = new double[R];
        foreach (ref v; a) v = 100 * uniform01;
        foreach (ref v; b) v = 100 * uniform01;
        auto got = multiply(a, b);
        auto expected = new double[L+R-1];
        expected[] = 0.0; // double.init is NaN
        foreach (i, ai; a) foreach (j, bj; b) expected[i+j] += ai * bj;
        assert(got.length == expected.length);
        foreach (k; 0..L+R-1) assert(approxEqual(got[k], expected[k]));
    }
    writeln("FFT Stress: ", sw.peek.toMsecs);
}
110 
111 import dkh.modint, dkh.numeric.primitive;
112 
/**
 * In-place number-theoretic transform (Stockham form, ping-pong buffers).
 *
 * Params:
 *   G = primitive root of Mint's modulus.
 *   type = false: transform using powers of G^((mod-1)/2^i);
 *          true: the same with inverted roots. Neither direction is
 *          normalized; callers multiply by 1/N themselves.
 *   c = data, transformed in place; its length N must be a power of two
 *       that divides (modulus - 1).
 */
void nft(uint G, bool type, Mint)(Mint[] c) {
    import std.algorithm : swap;
    import core.bitop : bsr;
    size_t N = c.length;
    assert(N);
    size_t S = bsr(N);
    assert((size_t(1) << S) == N);
    // A primitive N-th root of unity exists only when N divides (mod - 1).
    // Otherwise the exponent below truncates and the output is silently wrong.
    assert(Mint(-1).v % N == 0, "nft: length must divide modulus - 1");

    Mint[] a = c.dup, b = new Mint[N];
    foreach (i; 1..S+1) {
        size_t W = size_t(1) << (S-i);
        // primitive 2^i-th root of unity: G^((mod-1)/2^i)
        Mint base = pow(Mint(G), Mint(-1).v/(1<<i));
        static if (type) base = Mint(1)/base;
        Mint now = 1;
        for (size_t y = 0; y < N/2; y += W) {
            foreach (x; 0..W) {
                auto l =       a[y<<1 | x];
                auto r = now * a[y<<1 | x | W];
                b[y | x]        = l+r;
                b[y | x | N>>1] = l-r;
            }
            now *= base;
        }
        swap(a, b);
    }
    c[] = a[];
}
141 
/// Convolution of two Mint sequences via NTT. G must be a primitive root of
/// Mint's modulus, and the padded power-of-two length must divide
/// (modulus - 1). Returns an empty array when either input is empty.
Mint[] multiply(uint G, Mint)(in Mint[] a, in Mint[] b) {
    immutable size_t lenA = a.length, lenB = b.length;
    if (lenA == 0 || lenB == 0) return [];
    immutable size_t outLen = lenA + lenB - 1;
    size_t lg = 1;
    while ((1<<lg) < outLen) ++lg;
    size_t N = 1<<lg;

    // zero-padded copies of both inputs
    auto fa = new Mint[N], fb = new Mint[N];
    foreach (i, e; a) fa[i] = e;
    foreach (i, e; b) fb[i] = e;

    nft!(G, false)(fa);
    nft!(G, false)(fb);
    foreach (i, ref e; fa) e *= fb[i];
    nft!(G, true)(fa);

    // the inverse transform is unnormalized; scale by 1/N
    Mint invN = Mint(1) / Mint(N);
    auto c = new Mint[outLen];
    foreach (k, ref e; c) e = fa[k] * invN;
    return c;
}
164 
/**
 * Convolution of two Mint sequences for an arbitrary modulus, computed with
 * floating-point FFTs.
 *
 * Each value is split into M limbs of W bits (value = sum of limb_k * 2^(k*W)).
 * All limb-pair convolutions are computed with complex FFTs, rounded back to
 * integers and recombined modulo Mint's modulus. Exactness relies on the
 * rounded limb convolutions staying well within double precision; that
 * bound depends on input length and W and is not checked here.
 *
 * Params:
 *   M = number of limbs per value.
 *   W = bits per limb; M*W bits must be able to hold (modulus - 1).
 *   a = first input sequence.
 *   b = second input sequence.
 * Returns: c with c[k] = sum over i+j=k of a[i]*b[j], or an empty array when
 *   either input is empty.
 */
Mint[] multiply(Mint, size_t M = 3, size_t W = 10)(in Mint[] a, in Mint[] b)
if (isModInt!Mint) {
    import std.math : round;
    alias P = Complex!double;
    // 1<<W is an int below, and limb shifts must stay below 64 bits.
    static assert(M >= 1 && W >= 1 && W <= 30 && (M - 1) * W < 64,
                  "multiply: unsupported limb parameters M, W");

    size_t A = a.length, B = b.length;
    if (!A || !B) return [];
    // Bits above M*W would be silently dropped by the limb split.
    static if (M * W < 64) {
        assert(cast(ulong)(Mint(-1).v) < (ulong(1) << (M * W)),
               "multiply: modulus does not fit in M*W bits");
    }
    auto N = A + B - 1;
    size_t lg = 1;
    while ((1<<lg) < N) lg++;
    N = 1<<lg;

    // x[ph] / y[ph]: spectra of limb ph of a / b
    P[][M] x, y;
    P[] w = new P[N];
    foreach (ph; 0..M) {
        x[ph] = new P[N];
        y[ph] = new P[N];
        w[] = P(0, 0);
        // Pack limb ph of a into the real part and limb ph of b into the
        // imaginary part. Widen first so the shift is valid for any v width.
        foreach (i; 0..A) w[i].re = (cast(ulong)(a[i].v) >> (ph*W)) % (1UL << W);
        foreach (i; 0..B) w[i].im = (cast(ulong)(b[i].v) >> (ph*W)) % (1UL << W);
        fft!false(w);
        foreach (i; 0..N) w[i] *= 0.5;
        // Unpack the two real spectra from the combined one.
        foreach (i; 0..N) {
            auto j = i ? N-i : 0;
            x[ph][i] = P(w[i].re+w[j].re,  w[i].im-w[j].im);
            y[ph][i] = P(w[i].im+w[j].im, -w[i].re+w[j].re);
        }
    }

    Mint[] c = new Mint[A+B-1];
    // basel = 2^(W*ph), baser = 2^(W*(ph+M)): the weights of the real and
    // imaginary parts of z for output phase ph.
    Mint basel = 1, baser = pow(Mint(1<<W), M);
    P[] z = new P[N];
    foreach (ph; 0..M) {
        z[] = P(0, 0);
        // real part: limb pairs with af + bf == ph
        foreach (af; 0..ph+1) {
            auto bf = ph - af;
            foreach (i; 0..N) {
                z[i] += x[af][i] * y[bf][i];
            }
        }
        // imaginary part (multiplied by i): limb pairs with af + bf == ph + M
        foreach (af; ph+1..M) {
            auto bf = ph + M - af;
            foreach (i; 0..N) {
                z[i] += x[af][i] * y[bf][i] * P(0, 1);
            }
        }
        fft!true(z);
        foreach (i; 0..A+B-1) {
            z[i] *= 1.0/N;
            c[i] += Mint(cast(long)(round(z[i].re)))*basel;
            c[i] += Mint(cast(long)(round(z[i].im)))*baser;
        }
        basel *= Mint(1<<W);
        baser *= Mint(1<<W);
    }
    return c;
}
223 
unittest {
    alias Mint = ModInt!924844033;
    import std.algorithm, std.stdio, std.random, std.math;
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    Mint rndM() { return Mint(uniform(0, 924844033)); }
    // NTT-based multiply (primitive root 5) vs. the schoolbook convolution.
    foreach (L; 1..20) foreach (R; 1..20) foreach (trial; 0..10) {
        auto a = new Mint[L], b = new Mint[R];
        foreach (ref v; a) v = rndM();
        foreach (ref v; b) v = rndM();
        auto got = multiply!5(a, b);
        auto expected = new Mint[L+R-1];
        foreach (i; 0..L) foreach (j; 0..R) expected[i+j] += a[i] * b[j];
        assert(got.length == expected.length);
        foreach (k; 0..L+R-1) {
            if (got[k] != expected[k]) {
                // dump the failing case before the assertion fires
                writeln(a);
                writeln(b);
                writeln(got);
                writeln(expected);
            }
            assert(got[k] == expected[k]);
        }
    }
    writeln("NFT Stress: ", sw.peek.toMsecs);
}
259 
260 
unittest {
    alias Mint = ModInt!(10^^9 + 7);
    import std.algorithm, std.stdio, std.random, std.math;
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    Mint rndM() { return Mint(uniform(0, 10^^9 + 7)); }
    // Arbitrary-modulus FFT multiply vs. the schoolbook convolution.
    foreach (L; 1..20) foreach (R; 1..20) foreach (trial; 0..10) {
        auto a = new Mint[L], b = new Mint[R];
        foreach (ref v; a) v = rndM();
        foreach (ref v; b) v = rndM();
        auto got = multiply(a, b);
        auto expected = new Mint[L+R-1];
        foreach (i; 0..L) foreach (j; 0..R) expected[i+j] += a[i] * b[j];
        assert(got.length == expected.length);
        foreach (k; 0..L+R-1) {
            if (got[k] != expected[k]) {
                // dump the failing case before the assertion fires
                writeln(a);
                writeln(b);
                writeln(got);
                writeln(expected);
            }
            assert(got[k] == expected[k]);
        }
    }
    writeln("FFT(ModInt) Stress: ", sw.peek.toMsecs);
}