/**
Unstable

Polynomials with coefficients in `ModInt!MD`, plus two helpers built on them:
`nthMod` (x^n modulo a polynomial) and `berlekampMassey`.
*/

module dkh.modpoly;

import dkh.numeric.primitive;
import dkh.numeric.convolution;
import dkh.modint;
import dkh.container.stack;


/**
Polynomial over `ModInt!MD`, stored lowest degree first.

The constructor and `opIndexAssign` strip trailing zero coefficients, so for a
polynomial built through them `length` is degree + 1 and the zero polynomial
has length 0.
*/
struct ModPoly(uint MD) if (MD < int.max) {
    alias Mint = ModInt!MD;
    import std.algorithm : min, max, reverse;

    /// Coefficient storage; `d[i]` is the coefficient of x^i.
    Stack!Mint d;

    /// Removes trailing zero coefficients so the last stored one is non-zero.
    void shrink() { while (!d.empty && d.back == Mint(0)) d.removeBack; }

    /// Number of stored coefficients.
    @property size_t length() const { return d.length; }

    /// Direct view of the stored coefficients.
    @property inout(Mint)[] data() inout { return d.data; }

    /// Builds a polynomial from coefficients `v` (lowest degree first); `v` is copied.
    this(in Mint[] v) {
        d = v.dup;
        shrink();
    }

    /// Coefficient of x^i; indices past the stored length read as zero.
    const(Mint) opIndex(size_t i) const {
        if (i < d.length) return d[i];
        return Mint(0);
    }

    /// Sets the coefficient of x^i, growing or shrinking the storage as needed.
    void opIndexAssign(Mint x, size_t i) {
        if (i < d.length) {
            d[i] = x;
            // Writing a zero into the top coefficient lowers the degree.
            shrink();
            return;
        }
        // Writing zero beyond the end changes nothing.
        if (x == Mint(0)) return;
        // Fill the gap with zeros, then append the new top coefficient.
        while (d.length < i) d.insertBack(Mint(0));
        d.insertBack(x);
        return;
    }

    /// Coefficient-wise sum.
    ModPoly opBinary(string op : "+")(in ModPoly r) const {
        size_t N = length, M = r.length;
        Mint[] res = new Mint[max(N, M)];
        foreach (i; 0..max(N, M)) res[i] = this[i] + r[i];
        return ModPoly(res);
    }

    /// Coefficient-wise difference.
    ModPoly opBinary(string op : "-")(in ModPoly r) const {
        size_t N = length, M = r.length;
        Mint[] res = new Mint[max(N, M)];
        foreach (i; 0..max(N, M)) res[i] = this[i] - r[i];
        return ModPoly(res);
    }

    /// Polynomial product, delegated to `multiply` on the coefficient arrays.
    ModPoly opBinary(string op : "*")(in ModPoly r) const {
        size_t N = length, M = r.length;
        // A zero factor gives the zero polynomial without calling `multiply`.
        if (min(N, M) == 0) return ModPoly();
        return ModPoly(multiply(data, r.data));
    }

    /// Product with a scalar.
    ModPoly opBinary(string op : "*")(in Mint r) const {
        Mint[] res = new Mint[length];
        foreach (i; 0..length) res[i] = this[i]*r;
        return ModPoly(res);
    }

    /// Quotient of the division by `r`; `r` must be non-zero (asserted in `inv`).
    ModPoly opBinary(string op : "/")(in ModPoly r) const {
        size_t B = max(1, length, r.length);
        return divWithInv(r.inv(B), B);
    }

    /// Remainder of the division by `r`: `this - r * (this / r)`.
    ModPoly opBinary(string op : "%")(in ModPoly r) const {
        return this - r * (this / r);
    }

    /// Multiplies by x^n (shifts coefficients towards higher degree).
    ModPoly opBinary(string op : "<<")(size_t n) const {
        Mint[] res = new Mint[n+length];
        foreach (i; 0..length) res[i+n] = this[i];
        return ModPoly(res);
    }

    /// Drops the lowest `n` coefficients (quotient of the division by x^n).
    ModPoly opBinary(string op : ">>")(size_t n) const {
        if (length <= n) return ModPoly();
        Mint[] res = new Mint[length-n];
        foreach (i; n..length) res[i-n] = this[i];
        return ModPoly(res);
    }

    /// Compound assignment `this op= r`, implemented via the binary operator.
    ModPoly opOpAssign(string op)(in ModPoly r) {
        return mixin("this=this"~op~"r");
    }

    /// Keeps only the lowest `min(n, length)` coefficients (reduction modulo x^n).
    ModPoly strip(size_t n) const {
        auto res = d.data.dup;
        res = res[0..min(n, length)];
        return ModPoly(res);
    }

    /**
    Division using a precomputed inverse: returns `(this * ir) >> (B-1)`.

    `ir` is expected to be `r.inv(B)` for the divisor `r`; `opBinary!"/"` and
    `nthMod` choose `B` to be at least `length`.
    */
    ModPoly divWithInv(in ModPoly ir, size_t B) const {
        return (this * ir) >> (B-1);
    }

    /// Remainder modulo `r` using the precomputed inverse `ir` (see `divWithInv`).
    ModPoly remWithInv(in ModPoly r, in ModPoly ir, size_t B) const {
        return this - r * divWithInv(ir, B);
    }

    /**
    Reverses the coefficient order.

    With the default `n == -1` the stored coefficients are reversed as they are.
    Otherwise exactly `n` coefficients are used: extra high coefficients are
    dropped, and missing ones are treated as zero (consistent with `opIndex`).
    */
    ModPoly rev(ptrdiff_t n = -1) const {
        assert(n >= -1);
        auto res = d.data.dup;
        if (n != -1) {
            immutable want = cast(size_t) n;
            if (want <= res.length) {
                res = res[0..want];
            } else {
                // The stored array has no trailing zeros, so it can be shorter
                // than `n`; pad instead of slicing past the end.
                foreach (_; res.length..want) res ~= Mint(0);
            }
        }
        reverse(res);
        return ModPoly(res);
    }

    /**
    Inverse used by `divWithInv`.

    Inverts the reversed polynomial as a power series by Newton iteration and
    returns its lowest `n+1-length` coefficients in reversed order, i.e. the
    quotient of x^(n-1) by this polynomial.

    Requires a non-zero polynomial and `n >= length-1`.
    */
    ModPoly inv(size_t n) const {
        assert(length >= 1);
        assert(n >= length-1);
        ModPoly c = rev();
        // c[0] is the leading coefficient of `this`, hence non-zero.
        ModPoly g = ModPoly([Mint(1)/c[0]]);
        // Each step doubles the number of correct low-order coefficients of g.
        for (ptrdiff_t i = 1; i+length-2 < n; i *= 2) {
            g = (g * (ModPoly([Mint(2)]) - c*g)).strip(2*i);
        }
        return g.rev(n+1-length);
    }

    /// Text form "c0x^0 + c1x^1 + ..." listing every stored coefficient.
    string toString() const {
        import std.conv : to;
        import std.range : join;
        string[] l = new string[length];
        foreach (i; 0..length) {
            l[i] = (this[i]).toString ~ "x^" ~ i.to!string;
        }
        return l.join(" + ");
    }
}

/**
Computes x^n modulo the polynomial `mod` by binary exponentiation.

`mod` must be non-zero. The inverse of `mod` is computed once and reused for
every reduction.
*/
ModPoly!MD nthMod(uint MD)(in ModPoly!MD mod, ulong n) {
    import core.bitop : bsr;
    alias Mint = ModInt!MD;
    assert(mod.length);
    size_t B = mod.length * 2 - 1;
    auto modInv = mod.inv(B);
    auto p = ModPoly!MD([Mint(1)]);
    if (n == 0) return p;
    auto m = bsr(n);
    // Walk the bits of n from the most significant one down.
    foreach_reverse(i; 0..m+1) {
        if ((n >> i) & 1) {
            // Bit set: multiply by x.
            p = (p<<1).remWithInv(mod, modInv, B);
        }
        if (i) {
            // More bits follow: square.
            p = (p*p).remWithInv(mod, modInv, B);
        }
    }
    return p;
}

/**
Berlekamp-Massey style search for a linear recurrence of the sequence `s`.

The returned polynomial `c` has its top coefficient equal to -1; the loop
maintains it so that the sum of `c[i] * s[j+i]` over all `i` is zero for the
prefix processed so far.
*/
ModPoly!MD berlekampMassey(uint MD)(in ModInt!MD[] s) {
    alias Mint = ModInt!MD;
    // c: current candidate, b: candidate saved at the last length change,
    // y: discrepancy that b had at that time.
    Mint[] b = [Mint(-1)], c = [Mint(-1)];
    Mint y = 1;
    foreach (ed; 1..s.length+1) {
        auto L = c.length, M = b.length;
        // Discrepancy of c on the window of s ending at index ed-1.
        Mint x = 0;
        foreach (i; 0..L) {
            x += c[i] * s[ed-L+i];
        }
        // Shift b by one position for this step.
        b ~= Mint(0); M++;
        if (x == Mint(0)) {
            continue;
        }
        auto freq = x/y;
        if (L < M) {
            // c has to grow: remember the old c as the new b.
            auto tmp = c;
            import std.range : repeat, take, array;
            c = Mint(0).repeat.take(M-L).array ~ c;
            foreach (i; 0..M) {
                c[M-1-i] -= freq*b[M-1-i];
            }
            b = tmp;
            y = x;
        } else {
            // Same length: cancel the discrepancy using b, aligned at the top.
            foreach (i; 0..M) {
                c[L-1-i] -= freq*b[M-1-i];
            }
        }
    }
    return ModPoly!MD(c);
}

unittest {
    import std.stdio;
    static immutable int MD = 7;
    alias Mint = ModInt!MD;
    alias MPol = ModPoly!MD;
    auto p = MPol(), q = MPol();
    p[0] = Mint(3); p[1] = Mint(2);
    q[0] = Mint(3); q[1] = Mint(2);
    writeln(p+q);
    writeln(p-q);
    writeln(p*q);

    // (3 + 2x) + (3 + 2x) = 6 + 4x
    auto sum = p+q;
    assert(sum.length == 2);
    assert(sum[0] == Mint(6));
    assert(sum[1] == Mint(4));
    // The difference is the zero polynomial, which stores no coefficients.
    assert((p-q).length == 0);
}

unittest {
    import std.stdio;
    static immutable int MD = 7;
    alias Mint = ModInt!MD;
    alias MPol = ModPoly!MD;
    auto p = MPol();
    p[10] = Mint(1);
    assert(p.length == 11);
}