module dkh.modint;

import dkh.numeric.primitive;

/**
Integer modulo the compile-time constant `MD`.

`v` always holds the canonical representative in [0, MD).
`+`, `-` and `*` do not depend on primality; the bound `MD < int.max`
keeps the intermediate sums `v + r.v` and `v + MD - r.v` inside `uint`.
`/` and `inv` compute the inverse as `x ^^ (MD - 2)` (Fermat's little
theorem), so they are only correct when `MD` is prime.
*/
struct ModInt(uint MD) if (0 < MD && MD < int.max) {
    import std.conv : to;

    uint v; /// canonical representative in [0, MD)

    this(int v) {this(long(v));}
    /// Reduces `v` into [0, MD); negative inputs are handled.
    this(long v) {this.v = (v%MD+MD)%MD;}

    /// Reduces x, assumed to be < 2*MD, into [0, MD) with one conditional subtraction.
    static auto normS(uint x) {return (x<MD)?x:x-MD;}
    /// Wraps an already-reduced value without reducing it again.
    static auto make(uint x) {ModInt m; m.v = x; return m;}

    /// Behaves like the built-in integer operators (division is slower: it computes an inverse).
    auto opBinary(string op:"+")(ModInt r) const {return make(normS(v+r.v));}
    /// ditto
    auto opBinary(string op:"-")(ModInt r) const {return make(normS(v+MD-r.v));}
    /// ditto
    auto opBinary(string op:"*")(ModInt r) const {return make((ulong(v)*r.v%MD).to!uint);}
    /// ditto
    auto opBinary(string op:"/")(ModInt r) const {return this*inv(r);}
    /// ditto
    auto opBinary(string op:"^^", T)(T r) const {return pow(this, r, ModInt(1));}
    /// Compound assignment (`+=`, `-=`, ...) implemented through the binary operators.
    auto opOpAssign(string op)(ModInt r) {return mixin ("this=this"~op~"r");}

    /// return 1/x, computed as x^^(MD-2); only valid when MD is prime
    static ModInt inv(ModInt x) {return x^^(MD-2);}

    string toString() const {return v.to!string;}
}

///
unittest {
    alias Mint = ModInt!(107);
    assert((Mint(100) + Mint(10)).v == 3);
    assert((Mint(10) * Mint(12)).v == 13);
    assert((Mint(1) / Mint(2)).v == 108/2);
    assert((Mint(2) ^^ 7).v == 21);
}

unittest {
    static assert( is(ModInt!(uint(1000000000) * 2))); //not overflow
    static assert(!is(ModInt!(uint(1145141919) * 2))); //overflow!
    static assert(!is(ModInt!0)); // a zero modulus is rejected by the constraint
    alias Mint = ModInt!(10^^9+7);
    // negative check
    assert(Mint(-1).v == 10^^9 + 6);
    assert(Mint(-1L).v == 10^^9 + 6);

    Mint a = 48;
    Mint b = Mint.inv(a);
    assert(b.v == 520833337);

    Mint c = Mint(15);
    Mint d = Mint(3);
    assert((c/d).v == 5);
}

/**
Integer modulo a value `MD` that is chosen at run time; `MD` does not have to be prime.
Every instantiation `DModInt!name` has its own `static` modulus `MD`
(thread-local, like any non-`shared` static in D). Differently named
instantiations can therefore use different moduli at the same time.

`MD` must be assigned before any value is constructed and must lie in
[1, int.max]: `+`/`-` add two reduced values in `uint`, and `inv` passes
`MD` to `extGcd!int`. `/` and `inv` additionally need the divisor to be
coprime to `MD`; this is not checked.
*/
struct DModInt(string name) {
    import std.conv : to;

    static uint MD; /// modulus shared by every value of this instantiation
    uint v; /// canonical representative in [0, MD)

    this(int v) {this(long(v));}
    /// Reduces `v` into [0, MD); negative inputs are handled.
    this(long v) {
        // Fail loudly on an unset (zero) or oversized modulus instead of
        // dividing by zero or overflowing `uint` later on.
        assert(0 < MD && MD <= int.max, "DModInt: MD must be in [1, int.max]");
        this.v = ((v%MD+MD)%MD).to!uint;
    }

    /// Reduces x, assumed to be < 2*MD, into [0, MD) with one conditional subtraction.
    static auto normS(uint x) {return (x<MD)?x:x-MD;}
    /// Wraps an already-reduced value (MD is static, so nothing else needs copying).
    static auto make(uint x) {DModInt m; m.v = x; return m;}

    /// Can be used like the built-in integer types; only division is slow.
    auto opBinary(string op:"+")(DModInt r) const {return make(normS(v+r.v));}
    /// ditto
    auto opBinary(string op:"-")(DModInt r) const {return make(normS(v+MD-r.v));}
    /// ditto
    auto opBinary(string op:"*")(DModInt r) const {return make((ulong(v)*r.v%MD).to!uint);}
    /// ditto
    auto opBinary(string op:"/")(DModInt r) const {return this*inv(r);}
    /// ditto
    auto opBinary(string op:"^^", T)(T r) const {return pow(this, r, DModInt(1));}
    /// Compound assignment (`+=`, `-=`, ...) implemented through the binary operators.
    auto opOpAssign(string op)(DModInt r) {return mixin ("this=this"~op~"r");}

    /// Returns the inverse of x, taken from the first coefficient returned by
    /// `extGcd!int(x.v, MD)`.
    static DModInt inv(DModInt x) {
        return DModInt(extGcd!int(x.v, MD)[0]);
    }

    string toString() const {return v.to!string;}
}

///
unittest {
    alias Mint1 = DModInt!"mod1";
    alias Mint2 = DModInt!"mod2";
    Mint1.MD = 7;
    Mint2.MD = 9;
    assert((Mint1(5)+Mint1(5)).v == 3); // (5+5) % 7
    assert((Mint2(5)+Mint2(5)).v == 1); // (5+5) % 9
}

unittest {
    alias Mint = DModInt!"default";
    Mint.MD = 10^^9 + 7;
    //negative check
    assert(Mint(-1).v == 10^^9 + 6);
    assert(Mint(-1L).v == 10^^9 + 6);
    const Mint a = Mint(48);
    const Mint b = Mint.inv(a);
    assert((a*b).v == 1);
    assert(b.v == 520833337);
    assert(a.toString() == "48"); // toString is callable on const values
    Mint c = Mint(15);
    Mint d = Mint(3);
    assert((c/d).v == 5);
    c += d;
    assert(c.v == 18);
    assert((Mint(2) ^^ 10).v == 1024);
}

/// true when T is (implicitly convertible to) some `ModInt!MD` or `DModInt!name`
template isModInt(T) {
    enum isModInt =
        is(T : ModInt!MD, uint MD) || is(T : DModInt!S, string S);
}

unittest {
    static assert(isModInt!(ModInt!7));
    static assert(isModInt!(DModInt!"default"));
    static assert(!isModInt!int);
}

/// return [0!, 1!, 2!, ..., (length-1)!]
T[] factTable(T)(size_t length) if (isModInt!T) {
    import std.range : take, recurrence;
    import std.array : array;
    // a[0] = 1, a[n] = a[n-1] * n
    return T(1).recurrence!((a, n) => a[n-1]*T(n)).take(length).array;
}

/// return [1/0!, 1/1!, 1/2!, ..., 1/(length-1)!]
T[] invFactTable(T)(size_t length) if (isModInt!T) {
    import std.algorithm : map, reduce;
    import std.range : iota;
    if (length == 0) return null;
    auto res = new T[length];
    // res[$-1] = 1 / (length-1)!. The product is seeded with T(1) so that
    // length == 1 (an empty product) yields 1 instead of making `reduce`
    // throw on an empty range.
    res[$-1] = T(1) / reduce!"a*b"(T(1), iota(1, length).map!T);
    // 1/i! = 1/(i+1)! * (i+1), so only one modular inverse is needed.
    foreach_reverse (i; 0 .. length-1) {
        res[i] = res[i+1] * T(i+1);
    }
    return res;
}

/// return [0, 1/1, 1/2, 1/3, ..., 1/(length-1)]; index 0 holds T.init (value 0)
T[] invTable(T)(size_t length) if (isModInt!T) {
    auto f = factTable!T(length);
    auto invf = invFactTable!T(length);
    auto res = new T[length];
    // 1/i = (i-1)! / i!
    foreach (i; 1..length) {
        res[i] = invf[i] * f[i-1];
    }
    return res;
}

unittest {
    alias Mint = ModInt!(10^^9 + 7);
    auto r = factTable!Mint(20);
    Mint a = 1;
    assert(r[0] == Mint(1));
    foreach (i; 1..20) {
        a *= Mint(i);
        assert(r[i] == a);
    }
    auto p = invFactTable!Mint(20);
    foreach (i; 0..20) {
        assert((r[i]*p[i]).v == 1);
    }
    // degenerate lengths must not throw
    assert(invFactTable!Mint(0).length == 0);
    assert(invFactTable!Mint(1) == [Mint(1)]);
    auto q = invTable!Mint(20);
    assert(q[0].v == 0);
    foreach (i; 1..20) {
        assert((q[i]*Mint(i)).v == 1);
    }
    assert(invTable!Mint(0).length == 0);
    assert(invTable!Mint(1).length == 1);
}