1 module dkh.modpoly;
2 
3 import dkh.numeric.primitive;
4 import dkh.numeric.convolution;
5 import dkh.modint;
6 import dkh.container.stack;
7 
8 
/// Dense polynomial over `ModInt!MD`, coefficients stored lowest degree first
/// (`d[i]` is the coefficient of x^i).
///
/// The constructor and `opIndexAssign` drop trailing zero coefficients via
/// `shrink`, so a value built through them has `length == degree + 1` and the
/// zero polynomial has `length == 0`.
struct ModPoly(uint MD) if (MD < int.max) {
    alias Mint = ModInt!MD;
    import std.algorithm : min, max;

    /// Coefficient storage, lowest degree first.
    Stack!Mint d;

    /// Removes trailing zero coefficients.
    void shrink() { while (!d.empty && d.back == Mint(0)) d.removeBack; }
    /// Number of stored coefficients.
    @property size_t length() const { return d.length; }
    /// The stored coefficients as a slice.
    @property inout(Mint)[] data() inout { return d.data; }

    /// Builds a polynomial from a copy of `v` (`v[i]` = coefficient of x^i).
    this(in Mint[] v) {
        d = v.dup;
        shrink();
    }

    /// Coefficient of x^i; indices at or past `length` read as 0.
    const(Mint) opIndex(size_t i) const {
        if (i < d.length) return d[i];
        return Mint(0);
    }

    /// Sets the coefficient of x^i.
    /// Assigning past the end zero-fills the gap (assigning 0 there is a
    /// no-op). Assigning inside may expose trailing zeros, which are removed.
    void opIndexAssign(Mint x, size_t i) {
        if (i < d.length) {
            d[i] = x;
            shrink();
            return;
        }
        if (x == Mint(0)) return;
        while (d.length < i) d.insertBack(Mint(0));
        d.insertBack(x);
    }

    /// Coefficient-wise sum.
    ModPoly opBinary(string op : "+")(in ModPoly r) const {
        size_t N = length, M = r.length;
        Mint[] res = new Mint[max(N, M)];
        foreach (i; 0..max(N, M)) res[i] = this[i] + r[i];
        return ModPoly(res);
    }
    /// Coefficient-wise difference.
    ModPoly opBinary(string op : "-")(in ModPoly r) const {
        size_t N = length, M = r.length;
        Mint[] res = new Mint[max(N, M)];
        foreach (i; 0..max(N, M)) res[i] = this[i] - r[i];
        return ModPoly(res);
    }
    /// Polynomial product, computed by `multiply` from dkh.numeric.convolution.
    /// A zero operand yields the zero polynomial without calling `multiply`.
    ModPoly opBinary(string op : "*")(in ModPoly r) const {
        size_t N = length, M = r.length;
        if (min(N, M) == 0) return ModPoly();
        return ModPoly(multiply(data, r.data));
    }
    /// Multiplies every coefficient by the scalar `r`.
    ModPoly opBinary(string op : "*")(in Mint r) const {
        Mint[] res = new Mint[length];
        foreach (i; 0..length) res[i] = this[i]*r;
        return ModPoly(res);
    }
    /// Polynomial quotient floor(this / r). `r` must be nonzero (`inv` asserts).
    ModPoly opBinary(string op : "/")(in ModPoly r) const {
        // divWithInv is exact when length <= B; B >= 1 keeps B-1 from underflowing.
        size_t B = max(cast(size_t) 1, length, r.length);
        return divWithInv(r.inv(B), B);
    }
    /// Polynomial remainder this - r * floor(this / r). `r` must be nonzero.
    ModPoly opBinary(string op : "%")(in ModPoly r) const {
        // Same B as the "/" operator, so the quotient used here is exact.
        size_t B = max(cast(size_t) 1, length, r.length);
        return remWithInv(r, r.inv(B), B);
    }
    /// Multiplies by x^n (shifts coefficients up by n).
    ModPoly opBinary(string op : "<<")(size_t n) const {
        Mint[] res = new Mint[n+length];
        foreach (i; 0..length) res[i+n] = this[i];
        return ModPoly(res);
    }
    /// Floor-divides by x^n (drops the n lowest coefficients).
    ModPoly opBinary(string op : ">>")(size_t n) const {
        if (length <= n) return ModPoly();
        Mint[] res = new Mint[length-n];
        foreach (i; n..length) res[i-n] = this[i];
        return ModPoly(res);
    }
    /// `this = this op r` for any supported binary operator.
    ModPoly opOpAssign(string op)(in ModPoly r) {
        return mixin("this=this"~op~"r");
    }

    /// Keeps only the first `n` coefficients (this mod x^n).
    ModPoly strip(size_t n) const {
        // The constructor copies, so slicing without dup is enough.
        return ModPoly(data[0 .. min(n, length)]);
    }
    /// Quotient floor(this / r), given `ir = r.inv(B)` (= floor(x^(B-1) / r)).
    /// Barrett-style: the result is exact when `length <= B`.
    ModPoly divWithInv(in ModPoly ir, size_t B) const {
        return (this * ir) >> (B-1);
    }
    /// Remainder this - r * floor(this / r), with `ir = r.inv(B)` and `length <= B`.
    ModPoly remWithInv(in ModPoly r, in ModPoly ir, size_t B) const {
        return this - r * divWithInv(ir, B);
    }
    /// Reverses the first `n` coefficients, treating `this` as having degree n-1:
    /// result coefficient j equals `this[n-1-j]`. Coefficients past `length`
    /// are zero, so `n > length` is allowed; the missing high coefficients
    /// become low zeros of the result. The default `n = -1` reverses all
    /// `length` coefficients.
    ModPoly rev(ptrdiff_t n = -1) const {
        assert(n >= -1);
        immutable size_t m = (n == -1) ? length : cast(size_t) n;
        auto res = new Mint[m];
        foreach (j; 0 .. m) res[j] = this[m-1-j];
        return ModPoly(res);
    }
    /// Returns floor(x^(n-1) / this), the reciprocal consumed by `divWithInv`
    /// and `remWithInv`. The result has at most n+1-length coefficients.
    ///
    /// Newton iteration on c = rev(): each round g <- g * (2 - c*g) mod x^(2i)
    /// doubles the number of correct coefficients of 1/c. The first
    /// n+1-length coefficients are then reversed back.
    /// Requires `this` nonzero and `n >= length - 1`.
    ModPoly inv(size_t n) const {
        assert(length >= 1);
        assert(n >= length-1);
        ModPoly c = rev();
        auto two = ModPoly([Mint(2)]);
        auto g = ModPoly([Mint(1)/c[0]]);
        for (size_t i = 1; i+length-2 < n; i *= 2) {
            // Only c mod x^(2i) affects the result after strip(2*i).
            g = (g * (two - c.strip(2*i) * g)).strip(2*i);
        }
        // g may be shorter than n+1-length (trailing zeros were shrunk away);
        // rev zero-pads in that case.
        return g.rev(n+1-length);
    }

    /// Renders as "a0x^0 + a1x^1 + ...", lowest degree first.
    string toString() {
        import std.conv : to;
        import std.range : join;
        string[] l = new string[length];
        foreach (i; 0..length) {
            l[i] = (this[i]).toString ~ "x^" ~ i.to!string;
        }
        return l.join(" + ");
    }
}
121 
/// Computes x^n mod `mod` by left-to-right square-and-multiply.
///
/// All reductions reuse one precomputed reciprocal `mod.inv(B)`, where
/// B = 2 * mod.length - 1. When mod.length >= 2, the largest operand is a
/// squared residue, of length at most 2 * mod.length - 3.
///
/// Params:
///     mod = modulus polynomial; must be nonzero (asserted).
///     n   = exponent.
/// Returns: the remainder of x^n divided by `mod`.
ModPoly!MD nthMod(uint MD)(in ModPoly!MD mod, ulong n) {
    import core.bitop : bsr;
    alias Mint = ModInt!MD;
    assert(mod.length);
    // x^0 = 1 is already reduced, unless `mod` is a nonzero constant.
    // In that case every remainder, including 1 mod c, is the zero polynomial.
    if (n == 0) return (mod.length == 1) ? ModPoly!MD() : ModPoly!MD([Mint(1)]);
    size_t B = mod.length * 2 - 1;
    auto modInv = mod.inv(B);
    auto p = ModPoly!MD([Mint(1)]);
    auto m = bsr(n);
    foreach_reverse (i; 0 .. m+1) {
        // Consume bit i: multiply by x when it is set...
        if (n & (1UL << i)) {
            p = (p<<1).remWithInv(mod, modInv, B);
        }
        // ...then square, except after the last bit.
        if (i) {
            p = (p*p).remWithInv(mod, modInv, B);
        }
    }
    return p;
}
141 
/// Berlekamp-Massey (per the function name): searches for a linear recurrence
/// satisfied by the sequence `s`.
///
/// `c` is the current recurrence. Its last entry starts at -1 and never
/// changes: both update branches subtract `freq * b` aligned at the end of
/// `c`, and `b`'s last entry is the 0 appended earlier in the same step.
/// `b` is an earlier value of `c`, shifted by one appended 0 per step.
/// `y` is the nonzero discrepancy `b` had when it was saved.
///
/// Params:
///     s = the input sequence.
/// Returns: ModPoly!MD(c), so c[i] becomes the coefficient of x^i.
///     Presumably (the standard Berlekamp-Massey guarantee) the result satisfies
///     sum_i c[i] * s[k+i] == 0 for every k with k + c.length <= s.length.
///     TODO confirm.
/// NOTE(review): the access s[ed-L+i] assumes c.length <= ed at every step.
///     The usual Berlekamp-Massey length bound should ensure this; confirm.
ModPoly!MD berlekampMassey(uint MD)(in ModInt!MD[] s) {
    alias Mint = ModInt!MD;
    // Both start as the trivial recurrence [-1].
    Mint[] b = [Mint(-1)], c = [Mint(-1)];
    // Discrepancy associated with b.
    Mint y = 1;
    // ed = number of leading elements of s considered in this step.
    foreach (ed; 1..s.length+1) {
        auto L = c.length, M = b.length;
        // x = discrepancy of c on the window s[ed-L .. ed-1].
        Mint x = 0;
        foreach (i; 0..L) {
            x += c[i] * s[ed-L+i];
        }
        // Append 0 to b. Since b is aligned against c at its end, this shifts
        // b one position further back; it happens even when x == 0.
        b ~= Mint(0); M++;
        if (x == Mint(0)) {
            // c already fits this window; nothing to correct.
            continue;
        }
        // Scale that cancels x using b's saved discrepancy y.
        auto freq = x/y;
        if (L < M) {
            // The shifted b is longer than c: c grows to length M,
            // and the old c (with its discrepancy x) becomes the new b / y.
            auto tmp = c;
            import std.range : repeat, take, array;
            // Pad c with zeros at the low-index side so c and b align at their ends.
            c = Mint(0).repeat.take(M-L).array ~ c;
            foreach (i; 0..M) {
                c[M-1-i] -= freq*b[M-1-i];
            }
            b = tmp;
            y = x;
        } else {
            // c is at least as long as b: correct c in place, end-aligned.
            foreach (i; 0..M) {
                c[L-1-i] -= freq*b[M-1-i];
            }
        }
    }
    return ModPoly!MD(c);
}
174 
unittest {
    import std.stdio : writeln;
    enum uint MD = 7;
    alias Mint = ModInt!MD;
    alias MPol = ModPoly!MD;
    // Two equal polynomials 3 + 2x over Z/7Z. Print their sum, their
    // difference (the zero polynomial) and their product.
    MPol lhs, rhs;
    lhs[0] = Mint(3);
    lhs[1] = Mint(2);
    rhs[0] = Mint(3);
    rhs[1] = Mint(2);
    writeln(lhs + rhs);
    writeln(lhs - rhs);
    writeln(lhs * rhs);
}
187 
unittest {
    enum uint MD = 7;
    alias Mint = ModInt!MD;
    // Setting a coefficient far past the end must zero-fill the gap,
    // leaving exactly degree + 1 stored coefficients.
    ModPoly!MD poly;
    poly[10] = Mint(1);
    assert(poly.length == 11);
}