module dkh.modpoly;

import dkh.numeric.primitive;
import dkh.numeric.convolution;
import dkh.modint;
import dkh.container.stack;


/**
Polynomial whose coefficients are `ModInt!MD` values, stored lowest degree
first (index `i` holds the coefficient of `x^i`).

The coefficient list is kept normalized: every constructor and mutating
operation strips trailing zero coefficients, so `length` is `degree + 1` for a
nonzero polynomial and `0` for the zero polynomial.
*/
struct ModPoly(uint MD) if (MD < int.max) {
    alias Mint = ModInt!MD;
    import std.algorithm : min, max, reverse;

    /// Coefficient storage (container type comes from dkh.container.stack).
    Stack!Mint d;

    /// Drops trailing zero coefficients so the highest stored one is nonzero.
    void shrink() {
        while (!d.empty && d.back == Mint(0)) d.removeBack;
    }

    /// Number of stored coefficients (0 for the zero polynomial).
    @property size_t length() const { return d.length; }

    /// Direct view of the stored coefficients, lowest degree first.
    @property inout(Mint)[] data() inout { return d.data; }

    /// Builds a polynomial from coefficients `v` (lowest degree first) and normalizes it.
    this(in Mint[] v) {
        d = v.dup;
        shrink();
    }

    /// Coefficient of `x^i`; indices past the stored length read as zero.
    const(Mint) opIndex(size_t i) const {
        if (i < d.length) return d[i];
        return Mint(0);
    }

    /**
    Sets the coefficient of `x^i` to `x`.

    Writing inside the stored range may create trailing zeros, so the result is
    re-normalized. Writing a nonzero value past the end zero-fills the gap;
    writing zero past the end is a no-op.
    */
    void opIndexAssign(Mint x, size_t i) {
        if (i < d.length) {
            d[i] = x;
            shrink();
            return;
        }
        if (x == Mint(0)) return;
        while (d.length < i) d.insertBack(Mint(0));
        d.insertBack(x);
    }

    /// Coefficient-wise sum.
    ModPoly opBinary(string op : "+")(in ModPoly r) const {
        size_t N = length, M = r.length;
        Mint[] res = new Mint[max(N, M)];
        foreach (i; 0..max(N, M)) res[i] = this[i] + r[i];
        return ModPoly(res);
    }

    /// Coefficient-wise difference.
    ModPoly opBinary(string op : "-")(in ModPoly r) const {
        size_t N = length, M = r.length;
        Mint[] res = new Mint[max(N, M)];
        foreach (i; 0..max(N, M)) res[i] = this[i] - r[i];
        return ModPoly(res);
    }

    /// Polynomial product, delegated to `multiply` (in scope via the module imports).
    ModPoly opBinary(string op : "*")(in ModPoly r) const {
        size_t N = length, M = r.length;
        // The zero polynomial is handled here so `multiply` never sees an empty operand.
        if (min(N, M) == 0) return ModPoly();
        return ModPoly(multiply(data, r.data));
    }

    /// Product with a scalar.
    ModPoly opBinary(string op : "*")(in Mint r) const {
        Mint[] res = new Mint[length];
        foreach (i; 0..length) res[i] = this[i]*r;
        return ModPoly(res);
    }

    /**
    Quotient of the division of this polynomial by `r`.

    `r` must be nonzero (`inv` asserts `r.length >= 1`).
    */
    ModPoly opBinary(string op : "/")(in ModPoly r) const {
        // B bounds the length of the dividend, which is what divWithInv needs.
        size_t B = max(size_t(1), length, r.length);
        return divWithInv(r.inv(B), B);
    }

    /**
    Remainder of the division of this polynomial by `r`.

    `r` must be nonzero (same precondition as `/`).
    */
    ModPoly opBinary(string op : "%")(in ModPoly r) const {
        // remainder = dividend - divisor * quotient
        return this - r * (this / r);
    }

    /// Multiplies by `x^n` (shifts coefficients towards higher degrees).
    ModPoly opBinary(string op : "<<")(size_t n) const {
        // The low n slots keep the default value of a freshly allocated Mint.
        Mint[] res = new Mint[n+length];
        foreach (i; 0..length) res[i+n] = this[i];
        return ModPoly(res);
    }

    /// Divides by `x^n`, discarding the `n` lowest coefficients.
    ModPoly opBinary(string op : ">>")(size_t n) const {
        if (length <= n) return ModPoly();
        Mint[] res = new Mint[length-n];
        foreach (i; n..length) res[i-n] = this[i];
        return ModPoly(res);
    }

    /// Compound assignment (`+=`, `-=`, `*=`, `/=`, `%=`) with a polynomial operand.
    ModPoly opOpAssign(string op)(in ModPoly r) {
        return mixin("this=this"~op~"r");
    }

    /// Keeps only the `n` lowest coefficients (i.e. reduces modulo `x^n`).
    ModPoly strip(size_t n) const {
        auto res = d.data.dup;
        res = res[0..min(n, length)];
        return ModPoly(res);
    }

    /**
    Quotient by a divisor whose reciprocal `ir` was obtained from `inv(B)`.

    Every caller in this file passes a `B` that is at least `length`.
    */
    ModPoly divWithInv(in ModPoly ir, size_t B) const {
        return (this * ir) >> (B-1);
    }

    /// Remainder modulo `r`, given `ir = r.inv(B)`.
    ModPoly remWithInv(in ModPoly r, in ModPoly ir, size_t B) const {
        return this - r * divWithInv(ir, B);
    }

    /**
    Reverses the coefficient order.

    With the default `n = -1` the stored coefficients are reversed as they are.
    Otherwise the polynomial is first treated as having exactly `n`
    coefficients: truncated to its `n` lowest ones, or zero-padded when fewer
    than `n` are stored. Padding matters because normalization may have
    stripped trailing zeros that still count as positions to reverse.
    */
    ModPoly rev(ptrdiff_t n = -1) const {
        assert(n >= -1);
        auto res = d.data.dup;
        if (n != -1) {
            immutable want = cast(size_t) n;
            if (want <= res.length) {
                res = res[0..want];
            } else {
                while (res.length < want) res ~= Mint(0);
            }
        }
        reverse(res);
        return ModPoly(res);
    }

    /**
    Reciprocal used by `divWithInv` / `remWithInv`: the quotient of `x^(n-1)`
    by this polynomial, which has `n + 1 - length` coefficients before
    normalization.

    Requires a nonzero polynomial and `n >= length - 1`.
    */
    ModPoly inv(size_t n) const {
        assert(length >= 1);
        assert(n >= length-1);
        // Work on the reversed polynomial; its constant term is the (nonzero)
        // leading coefficient of this polynomial, so it is invertible mod x^k.
        ModPoly c = rev();
        ModPoly d = ModPoly([Mint(1)/c[0]]);
        // Newton iteration d <- d * (2 - c*d); each pass doubles the number of
        // low-order coefficients of 1/c that d reproduces.
        for (ptrdiff_t i = 1; i+length-2 < n; i *= 2) {
            d = (d * (ModPoly([Mint(2)]) - c*d)).strip(2*i);
        }
        return d.rev(n+1-length);
    }

    /// Formats every stored term as `<coef>x^<i>`, joined by " + ".
    string toString() const {
        import std.conv : to;
        import std.range : join;
        string[] l = new string[length];
        foreach (i; 0..length) {
            l[i] = (this[i]).toString ~ "x^" ~ i.to!string;
        }
        return l.join(" + ");
    }
}

/**
Computes `x^n` modulo the polynomial `mod`.

Uses binary exponentiation from the most significant bit of `n` down,
reducing after every step with a reciprocal of `mod` computed once.
`mod` must be nonzero.
*/
ModPoly!MD nthMod(uint MD)(in ModPoly!MD mod, ulong n) {
    import core.bitop : bsr;
    alias Mint = ModInt!MD;
    assert(mod.length);
    // Large enough for the longest intermediate (the square of a reduced value).
    size_t B = mod.length * 2 - 1;
    auto modInv = mod.inv(B);
    auto p = ModPoly!MD([Mint(1)]);
    if (n == 0) return p;
    auto m = bsr(n);
    foreach_reverse (i; 0..m+1) {
        if (n & (1UL << i)) {
            // Bit set: multiply the accumulator by x.
            p = (p<<1).remWithInv(mod, modInv, B);
        }
        if (i) {
            // More bits follow: square before consuming the next one.
            p = (p*p).remWithInv(mod, modInv, B);
        }
    }
    return p;
}

/**
Berlekamp-Massey: finds a shortest linear recurrence satisfied by `s`.

The returned polynomial `c` (length `L`) is built so that its last coefficient
is `-1` and `sum(c[i] * s[k+i], i < L) == 0` for every window
`k + L <= s.length`, i.e. `s[k+L-1]` equals the combination of the preceding
`L-1` terms weighted by `c[0..L-1]`.
*/
ModPoly!MD berlekampMassey(uint MD)(in ModInt!MD[] s) {
    import std.range : repeat, take, array;
    alias Mint = ModInt!MD;
    // c: current recurrence; b: recurrence saved at the last length change;
    // y: the discrepancy that b produced at that time.
    Mint[] b = [Mint(-1)], c = [Mint(-1)];
    Mint y = 1;
    foreach (ed; 1..s.length+1) {
        auto L = c.length, M = b.length;
        // Discrepancy of c on the window ending at s[ed-1].
        Mint x = 0;
        foreach (i; 0..L) {
            x += c[i] * s[ed-L+i];
        }
        // Shift b by one position so it stays aligned with the window end.
        b ~= Mint(0); M++;
        if (x == Mint(0)) {
            continue;
        }
        auto freq = x/y;
        if (L < M) {
            // c has to grow: keep the old c as the new b.
            auto tmp = c;
            c = Mint(0).repeat.take(M-L).array ~ c;
            foreach (i; 0..M) {
                c[M-1-i] -= freq*b[M-1-i];
            }
            b = tmp;
            y = x;
        } else {
            foreach (i; 0..M) {
                c[L-1-i] -= freq*b[M-1-i];
            }
        }
    }
    return ModPoly!MD(c);
}

unittest {
    import std.stdio;
    static immutable int MD = 7;
    alias Mint = ModInt!MD;
    alias MPol = ModPoly!MD;
    auto p = MPol(), q = MPol();
    p[0] = Mint(3); p[1] = Mint(2);
    q[0] = Mint(3); q[1] = Mint(2);
    writeln(p+q);
    writeln(p-q);
    writeln(p*q);
}

unittest {
    import std.stdio;
    static immutable int MD = 7;
    alias Mint = ModInt!MD;
    alias MPol = ModPoly!MD;
    auto p = MPol();
    p[10] = Mint(1);
    assert(p.length == 11);
}

unittest {
    static immutable int MD = 7;
    alias Mint = ModInt!MD;
    alias MPol = ModPoly!MD;

    // x^2 + 1 = (x + 1)(x - 1) + 2, so the remainder is the constant 2.
    auto a = MPol([Mint(1), Mint(0), Mint(1)]);
    auto b = MPol([Mint(1), Mint(1)]);
    auto r = a % b;
    assert(r.length == 1);
    assert(r[0] == Mint(2));

    // Divisor with a zero constant term: x^4 / x = x^3.
    auto x4 = MPol();
    x4[4] = Mint(1);
    auto x1 = MPol();
    x1[1] = Mint(1);
    auto q = x4 / x1;
    assert(q.length == 4);
    assert(q[3] == Mint(1));
}