module dkh.modint;

import dkh.numeric.primitive;

/**
Integer modulo a compile-time constant MD; every operation reduces its result mod MD.

The constraint MD < int.max guarantees that `v + r.v` (both operands < MD)
never overflows uint.
*/
struct ModInt(uint MD) if (MD < int.max) {
    import std.conv : to;
    uint v; /// representative, always in [0, MD)
    this(int v) {this(long(v));}
    /// Reduces v (which may be negative) into [0, MD).
    this(long v) {this.v = (v%MD+MD)%MD;}
    /// Maps x in [0, 2*MD) to x mod MD with a single conditional subtraction.
    static auto normS(uint x) {return (x<MD)?x:x-MD;}
    /// Wraps an already-reduced value x (must be < MD) without taking the mod again.
    static auto make(uint x) {ModInt m; m.v = x; return m;}
    /// Arithmetic works like the built-in integer types; only division is slow (it computes an inverse).
    auto opBinary(string op:"+")(ModInt r) const {return make(normS(v+r.v));}
    /// ditto
    auto opBinary(string op:"-")(ModInt r) const {return make(normS(v+MD-r.v));}
    /// ditto
    auto opBinary(string op:"*")(ModInt r) const {return make((ulong(v)*r.v%MD).to!uint);}
    /// ditto
    auto opBinary(string op:"/")(ModInt r) const {return this*inv(r);}
    auto opOpAssign(string op)(ModInt r) {return mixin ("this=this"~op~"r");}
    /**
    Returns the multiplicative inverse of x modulo MD, computed via extGcd.

    NOTE(review): an inverse only exists when gcd(x.v, MD) == 1 (e.g. MD prime
    and x != 0); presumably extGcd!int(a, b)[0] is the Bezout coefficient of a.
    */
    static ModInt inv(ModInt x) {return ModInt(extGcd!int(x.v, MD)[0]);}
    string toString() const {return v.to!string;}
}

///
unittest {
    alias Mint = ModInt!(107);
    assert((Mint(100) + Mint(10)).v == 3);
    assert(( Mint(10) * Mint(12)).v == 13);
    assert(( Mint(1) / Mint(2)).v == 108/2);
}

unittest {
    static assert( is(ModInt!(uint(1000000000) * 2))); //not overflow
    static assert(!is(ModInt!(uint(1145141919) * 2))); //overflow!
    alias Mint = ModInt!(10^^9+7);
    // negative check
    assert(Mint(-1).v == 10^^9 + 6);
    assert(Mint(-1L).v == 10^^9 + 6);

    Mint a = 48;
    Mint b = Mint.inv(a);
    assert(b.v == 520833337);

    Mint c = Mint(15);
    Mint d = Mint(3);
    assert((c/d).v == 5);
}

/**
Integer modulo a modulus chosen at runtime; every operation reduces its result mod MD.

Each distinct `name` yields a separate instantiation with its own `MD`, so
several runtime moduli can coexist.
*/
struct DModInt(string name) {
    import std.conv : to;
    /**
    Modulus shared by all values of this instantiation. Set it before
    constructing values. Keep it below int.max (as ModInt requires) so that
    `v + r.v` cannot overflow uint. As a D `static` field it is thread-local.
    */
    static uint MD;
    uint v; /// representative, always in [0, MD)
    this(int v) {this(long(v));}
    /// Reduces v (which may be negative) into [0, MD).
    this(long v) {this.v = ((v%MD+MD)%MD).to!uint;}
    /// Maps x in [0, 2*MD) to x mod MD with a single conditional subtraction.
    static auto normS(uint x) {return (x<MD)?x:x-MD;}
    /// Wraps an already-reduced value x (must be < MD) without taking the mod again.
    static auto make(uint x) {DModInt m; m.v = x; return m;}
    /// Arithmetic works like the built-in integer types; only division is slow (it computes an inverse).
    auto opBinary(string op:"+")(DModInt r) const {return make(normS(v+r.v));}
    /// ditto
    auto opBinary(string op:"-")(DModInt r) const {return make(normS(v+MD-r.v));}
    /// ditto
    auto opBinary(string op:"*")(DModInt r) const {return make((ulong(v)*r.v%MD).to!uint);}
    /// ditto
    auto opBinary(string op:"/")(DModInt r) const {return this*inv(r);}
    auto opOpAssign(string op)(DModInt r) {return mixin ("this=this"~op~"r");}
    /**
    Returns the multiplicative inverse of x modulo MD, computed via extGcd.

    NOTE(review): an inverse only exists when gcd(x.v, MD) == 1 (e.g. MD prime
    and x != 0); presumably extGcd!int(a, b)[0] is the Bezout coefficient of a.
    */
    static DModInt inv(DModInt x) {
        return DModInt(extGcd!int(x.v, MD)[0]);
    }
    string toString() const {return v.to!string;}
}

///
unittest {
    alias Mint1 = DModInt!"mod1";
    alias Mint2 = DModInt!"mod2";
    Mint1.MD = 7;
    Mint2.MD = 9;
    assert((Mint1(5)+Mint1(5)).v == 3); // (5+5) % 7
    assert((Mint2(5)+Mint2(5)).v == 1); // (5+5) % 9
}

unittest {
    alias Mint = DModInt!"default";
    Mint.MD = 10^^9 + 7;
    //negative check
    assert(Mint(-1).v == 10^^9 + 6);
    assert(Mint(-1L).v == 10^^9 + 6);
    const Mint a = Mint(48);
    const Mint b = Mint.inv(a);
    assert((a*b).v == 1);
    assert(b.v == 520833337);
    Mint c = Mint(15);
    Mint d = Mint(3);
    assert((c/d).v == 5);
    c += d;
    assert(c.v == 18);
}

/// True iff T is some ModInt!MD or some DModInt!name.
template isModInt(T) {
    enum isModInt =
        is(T : ModInt!MD, uint MD) || is(T : DModInt!s, string s);
}

/// Returns [0!, 1!, ..., (length-1)!] as modular values.
T[] factTable(T)(size_t length) if (isModInt!T) {
    import std.range : take, recurrence;
    import std.array : array;
    return T(1).recurrence!((a, n) => a[n-1]*T(n)).take(length).array;
}

/**
Returns [1/0!, 1/1!, ..., 1/(length-1)!] as modular values.

Only one modular inverse is computed (of (length-1)!); the remaining entries
follow from 1/i! = 1/(i+1)! * (i+1). An empty array is returned for length 0.
*/
T[] invFactTable(T)(size_t length) if (isModInt!T) {
    auto res = new T[length];
    if (length == 0) return res;
    // (length-1)!, seeded with T(1) so that length == 1 (empty product) works
    T fact = T(1);
    foreach (i; 1 .. length) {
        fact *= T(i);
    }
    res[$-1] = T(1) / fact;
    foreach_reverse (i; 0 .. length - 1) {
        res[i] = res[i+1] * T(i+1);
    }
    return res;
}

/**
Returns a table whose entry i (1 <= i < length) is the modular inverse of i,
computed as (i-1)! / i!. Entry 0 is left as T.init.
*/
T[] invTable(T)(size_t length) if (isModInt!T) {
    auto f = factTable!T(length);
    auto invf = invFactTable!T(length);
    auto res = new T[length];
    foreach (i; 1..length) {
        res[i] = invf[i] * f[i-1];
    }
    return res;
}

unittest {
    alias Mint = ModInt!(10^^9 + 7);
    auto r = factTable!Mint(20);
    Mint a = 1;
    assert(r[0] == Mint(1));
    foreach (i; 1..20) {
        a *= Mint(i);
        assert(r[i] == a);
    }
    auto p = invFactTable!Mint(20);
    foreach (i; 1..20) {
        assert((r[i]*p[i]).v == 1);
    }
}

unittest {
    alias Mint = ModInt!(10^^9 + 7);
    // isModInt recognizes both flavours and rejects plain integers
    static assert(isModInt!Mint);
    static assert(isModInt!(DModInt!"isModIntCheck"));
    static assert(!isModInt!int);

    // small table sizes must not crash
    assert(invFactTable!Mint(0).length == 0);
    auto one = invFactTable!Mint(1);
    assert(one.length == 1 && one[0].v == 1);

    auto inv = invTable!Mint(10);
    foreach (i; 1..10) {
        assert((inv[i] * Mint(i)).v == 1);
    }

    // tables also work with runtime moduli
    alias DMint = DModInt!"factTableCheck";
    DMint.MD = 13;
    auto f = factTable!DMint(5);
    auto g = invFactTable!DMint(5);
    foreach (i; 0..5) {
        assert((f[i] * g[i]).v == 1);
    }
}