1 /**
2 Unstable
3 */
4 
5 module dkh.modpoly;
6 
7 import dkh.numeric.primitive;
8 import dkh.numeric.convolution;
9 import dkh.modint;
10 import dkh.container.stack;
11 
12 
/**
Polynomial with coefficients in `ModInt!MD`, stored lowest degree first
(index i holds the coefficient of x^i).

Every constructor and mutator calls `shrink`, so the stored coefficients never
end in a zero: `length` is degree + 1 and the zero polynomial has length 0.
*/
struct ModPoly(uint MD) if (MD < int.max) {
    alias Mint = ModInt!MD;
    import std.algorithm : min, max;

    /// Coefficient storage; d[i] is the coefficient of x^i.
    Stack!Mint d;
    /// Removes trailing zero coefficients so the last stored one is nonzero.
    void shrink() { while (!d.empty && d.back == Mint(0)) d.removeBack; }
    /// Number of stored coefficients (degree + 1, or 0 for the zero polynomial).
    @property size_t length() const { return d.length; }
    /// Raw coefficient slice, lowest degree first.
    @property inout(Mint)[] data() inout { return d.data; }

    /// Builds a polynomial from a copy of `v`, where v[i] is the coefficient of x^i.
    this(in Mint[] v) {
        d = v.dup;
        shrink();
    }

    /// Coefficient of x^i; indices at or past `length` read as zero.
    const(Mint) opIndex(size_t i) const {
        if (i < d.length) return d[i];
        return Mint(0);
    }
    /// Sets the coefficient of x^i, zero-padding when growing and shrinking when
    /// the leading coefficient becomes zero.
    void opIndexAssign(Mint x, size_t i) {
        if (i < d.length) {
            d[i] = x;
            shrink();
            return;
        }
        // Writing a zero past the end changes nothing.
        if (x == Mint(0)) return;
        while (d.length < i) d.insertBack(Mint(0));
        d.insertBack(x);
    }

    /// Coefficient-wise sum.
    ModPoly opBinary(string op : "+")(in ModPoly r) const {
        size_t N = length, M = r.length;
        Mint[] res = new Mint[max(N, M)];
        foreach (i; 0..max(N, M)) res[i] = this[i] + r[i];
        return ModPoly(res);
    }
    /// Coefficient-wise difference.
    ModPoly opBinary(string op : "-")(in ModPoly r) const {
        size_t N = length, M = r.length;
        Mint[] res = new Mint[max(N, M)];
        foreach (i; 0..max(N, M)) res[i] = this[i] - r[i];
        return ModPoly(res);
    }
    /// Polynomial product (delegates to `multiply` from dkh.numeric.convolution).
    ModPoly opBinary(string op : "*")(in ModPoly r) const {
        size_t N = length, M = r.length;
        if (min(N, M) == 0) return ModPoly();
        return ModPoly(multiply(data, r.data));
    }
    /// Multiplies every coefficient by the scalar `r`.
    ModPoly opBinary(string op : "*")(in Mint r) const {
        Mint[] res = new Mint[length];
        foreach (i; 0..length) res[i] = this[i]*r;
        return ModPoly(res);
    }
    /// Polynomial quotient floor(this / r); `r` must be nonzero (asserted in `inv`).
    ModPoly opBinary(string op : "/")(in ModPoly r) const {
        // B >= length, so `this` satisfies divWithInv's length requirement.
        size_t B = max(1, length, r.length);
        return divWithInv(r.inv(B), B);
    }
    /// Polynomial remainder this - r * (this / r); `r` must be nonzero.
    ModPoly opBinary(string op : "%")(in ModPoly r) const {
        return this - r * (this / r);
    }
    /// Multiplies by x^n.
    ModPoly opBinary(string op : "<<")(size_t n) const {
        Mint[] res = new Mint[n+length];
        foreach (i; 0..length) res[i+n] = this[i];
        return ModPoly(res);
    }
    /// Drops the n lowest coefficients (floor division by x^n).
    ModPoly opBinary(string op : ">>")(size_t n) const {
        if (length <= n) return ModPoly();
        Mint[] res = new Mint[length-n];
        foreach (i; n..length) res[i-n] = this[i];
        return ModPoly(res);
    }
    /// Compound assignment `this op= r` for any right-hand side accepted by the
    /// matching `opBinary` (ModPoly, Mint for `*=`, shift counts for `<<=`/`>>=`).
    ModPoly opOpAssign(string op, T)(in T r) {
        return mixin("this=this"~op~"r");
    }

    /// Truncation modulo x^n: keeps only the n lowest coefficients.
    ModPoly strip(size_t n) const {
        auto res = d.data.dup;
        res = res[0..min(n, length)];
        return ModPoly(res);
    }
    /**
    Quotient floor(this / r), given `ir = r.inv(B)`.
    Exact when `length <= B`; the product is shifted down by B-1.
    */
    ModPoly divWithInv(in ModPoly ir, size_t B) const {
        return (this * ir) >> (B-1);
    }
    /// Remainder this - r * floor(this / r), given `ir = r.inv(B)` and `length <= B`.
    ModPoly remWithInv(in ModPoly r, in ModPoly ir, size_t B) const {
        return this - r * divWithInv(ir, B);
    }
    /**
    Reversal of the first `n` coefficients: result[i] = this[n-1-i].

    Coefficients at or past `length` count as zero, so `n` may exceed `length`.
    A negative `n` (the default) reverses all `length` coefficients.
    */
    ModPoly rev(ptrdiff_t n = -1) const {
        immutable size_t len = (n < 0) ? length : cast(size_t) n;
        Mint[] res = new Mint[len];
        foreach (i; 0..len) res[i] = this[len-1-i];
        return ModPoly(res);
    }
    /**
    Returns floor(x^(n-1) / this), the precomputed inverse used by `divWithInv`.

    Computes 1 / rev(this) modulo x^(n+1-length) by Newton iteration, then
    reverses that truncated series to length n+1-length.
    Requires a nonzero polynomial and n >= length-1 (both asserted).
    */
    ModPoly inv(size_t n) const {
        assert(length >= 1);
        assert(n >= length-1);
        ModPoly c = rev();
        // c[0] is this polynomial's leading coefficient, nonzero after shrink().
        ModPoly g = ModPoly([Mint(1)/c[0]]);
        const two = ModPoly([Mint(2)]);
        // Invariant: at loop entry g == 1/c (mod x^i); g(2 - c g) doubles it to x^(2i).
        // Only c mod x^(2i) influences that result, so c is truncated first.
        for (size_t i = 1; i + length < n + 2; i *= 2) {
            g = (g * (two - c.strip(2*i) * g)).strip(2*i);
        }
        return g.rev(n+1-length);
    }

    /// Human-readable form "a0x^0 + a1x^1 + ...", including zero coefficients.
    string toString() const {
        import std.conv : to;
        import std.array : join;
        string[] l = new string[length];
        foreach (i; 0..length) {
            l[i] = (this[i]).toString ~ "x^" ~ i.to!string;
        }
        return l.join(" + ");
    }
}
125 
/**
Computes x^n mod `mod` by left-to-right binary exponentiation.

Params:
    mod = nonzero modulus polynomial (asserted).
    n   = exponent.
Returns: the remainder of x^n on division by `mod`; it is the zero polynomial
when `mod` is a nonzero constant.
*/
ModPoly!MD nthMod(uint MD)(in ModPoly!MD mod, ulong n) {
    import core.bitop : bsr;
    alias Mint = ModInt!MD;
    assert(mod.length);
    // Reduced operands have length <= mod.length-1, so both p<<1 and p*p fit in
    // B = 2*mod.length-1 coefficients, as remWithInv requires.
    size_t B = mod.length * 2 - 1;
    auto modInv = mod.inv(B);
    auto p = ModPoly!MD([Mint(1)]);
    // x^0 = 1 must be reduced too: 1 mod a constant polynomial is 0.
    if (n == 0) return p.remWithInv(mod, modInv, B);
    // Scan bits from the most significant: multiply by x on a set bit, then
    // square before moving to the next lower bit.
    foreach_reverse (i; 0 .. bsr(n) + 1) {
        if ((n >> i) & 1) {
            p = (p<<1).remWithInv(mod, modInv, B);
        }
        if (i) {
            p = (p*p).remWithInv(mod, modInv, B);
        }
    }
    return p;
}
145 
/**
Berlekamp–Massey: finds a shortest linear recurrence satisfied by `s`.

Returns: a polynomial c of length L whose top coefficient c[L-1] is -1, such that
for every t with L-1 <= t < s.length:
    s[t] = c[0]*s[t-L+1] + c[1]*s[t-L+2] + ... + c[L-2]*s[t-1].
Equivalently, c is -1 times the recurrence's characteristic polynomial.
For an empty `s` the result is the constant -1.

NOTE(review): divides by nonzero discrepancies, so this presumably assumes MD is
prime (ModInt division well-defined) — TODO confirm.
*/
ModPoly!MD berlekampMassey(uint MD)(in ModInt!MD[] s) {
    import std.range : repeat, take, array;
    alias Mint = ModInt!MD;
    Mint[] conn = [Mint(-1)];   // current connection polynomial
    Mint[] prev = [Mint(-1)];   // connection polynomial before its last length change
    Mint prevDisc = 1;          // discrepancy recorded when `prev` was saved
    foreach (ed; 1 .. s.length + 1) {
        immutable L = conn.length;
        // Discrepancy of `conn` against the window s[ed-L .. ed).
        Mint disc = 0;
        foreach (i, coef; conn) disc += coef * s[ed - L + i];
        // `prev` is aligned one position further back for every step.
        prev ~= Mint(0);
        immutable M = prev.length;
        if (disc == Mint(0)) continue;
        auto coefRatio = disc / prevDisc;
        if (L >= M) {
            // Length suffices: cancel the discrepancy in place.
            foreach (k; 0 .. M) conn[L - 1 - k] -= coefRatio * prev[M - 1 - k];
            continue;
        }
        // Length must grow to M: left-pad with zeros, correct, and remember the old one.
        auto saved = conn;
        conn = Mint(0).repeat.take(M - L).array ~ conn;
        foreach (k; 0 .. M) conn[M - 1 - k] -= coefRatio * prev[M - 1 - k];
        prev = saved;
        prevDisc = disc;
    }
    return ModPoly!MD(conn);
}
178 
unittest {
    static immutable int MD = 7;
    alias Mint = ModInt!MD;
    alias MPol = ModPoly!MD;
    auto p = MPol(), q = MPol();
    p[0] = Mint(3); p[1] = Mint(2);
    q[0] = Mint(3); q[1] = Mint(2);
    // (3 + 2x) + (3 + 2x) = 6 + 4x
    assert((p+q).data == [Mint(6), Mint(4)]);
    // Equal operands cancel to the zero polynomial, which stores no coefficients.
    assert((p-q).length == 0);
    // (3 + 2x)^2 = 9 + 12x + 4x^2 = 2 + 5x + 4x^2 (mod 7)
    assert((p*q).data == [Mint(2), Mint(5), Mint(4)]);
    // Scalar product: 3 * (3 + 2x) = 9 + 6x = 2 + 6x (mod 7)
    assert((p*Mint(3)).data == [Mint(2), Mint(6)]);
    // Dropping the constant term: (3 + 2x) >> 1 = 2
    assert((p >> 1).data == [Mint(2)]);
}
191 
unittest {
    static immutable int MD = 7;
    alias Mint = ModInt!MD;
    alias MPol = ModPoly!MD;
    auto p = MPol();
    // Assigning past the end zero-pads every lower coefficient.
    p[10] = Mint(1);
    assert(p.length == 11);
    foreach (i; 0 .. 10) assert(p[i] == Mint(0));
    assert(p[10] == Mint(1));
    // Reading past the end yields zero and does not grow the polynomial.
    assert(p[100] == Mint(0));
    assert(p.length == 11);
    // Clearing the leading coefficient shrinks away all the trailing zeros.
    p[10] = Mint(0);
    assert(p.length == 0);
}