module dkh.modint;

import dkh.numeric.primitive;

/**
Integer type that automatically reduces every value modulo the compile-time
constant `MD`.

The constraint `MD < int.max` guarantees that the sum of two reduced values
(each below MD) still fits in a `uint`, so `+` and `-` cannot overflow.
*/
struct ModInt(uint MD) if (MD < int.max) {
    import std.conv : to;
    /// Stored representative, always in [0, MD).
    uint v;
    this(int v) {this(long(v));}
    /// Negative inputs are mapped to their non-negative representative.
    this(long v) {this.v = (v%MD+MD)%MD;}
    /// Subtracts MD once if x >= MD; callers pass values below 2*MD.
    static auto normS(uint x) {return (x<MD)?x:x-MD;}
    /// Wraps an already-reduced value without reducing it again.
    static auto make(uint x) {ModInt m; m.v = x; return m;}
    /// Arithmetic works like a built-in integer type; only division is slow.
    auto opBinary(string op:"+")(ModInt r) const {return make(normS(v+r.v));}
    /// ditto
    auto opBinary(string op:"-")(ModInt r) const {return make(normS(v+MD-r.v));}
    /// ditto
    auto opBinary(string op:"*")(ModInt r) const {return make((ulong(v)*r.v%MD).to!uint);}
    /// ditto
    auto opBinary(string op:"/")(ModInt r) const {return this*inv(r);}
    /// ditto
    auto opBinary(string op:"^^", T)(T r) const {return pow(this, r, ModInt(1));}
    /// Compound assignment (`+=`, `-=`, `*=`, `/=`) via the matching opBinary.
    auto opOpAssign(string op)(ModInt r) {return mixin ("this=this"~op~"r");}
    /**
    Returns the multiplicative inverse of x, computed as x ^^ (MD - 2).

    This relies on Fermat's little theorem, so the result is only correct
    when MD is prime.
    NOTE(review): x.v == 0 is not rejected; the result then depends on `pow`.
    */
    static ModInt inv(ModInt x) {return x^^(MD-2);}
    string toString() const {return v.to!string;}
}

///
unittest {
    alias Mint = ModInt!(107);
    assert((Mint(100) + Mint(10)).v == 3);
    assert(( Mint(10) * Mint(12)).v == 13);
    assert(( Mint(1) / Mint(2)).v == 108/2);
    assert((Mint(2) ^^ 7).v == 21);
}

unittest {
    static assert( is(ModInt!(uint(1000000000) * 2))); // 2e9 < int.max: accepted
    static assert(!is(ModInt!(uint(1145141919) * 2))); // exceeds int.max: rejected
    alias Mint = ModInt!(10^^9+7);
    // negative check
    assert(Mint(-1).v == 10^^9 + 6);
    assert(Mint(-1L).v == 10^^9 + 6);

    Mint a = 48;
    Mint b = Mint.inv(a);
    assert(b.v == 520833337);

    Mint c = Mint(15);
    Mint d = Mint(3);
    assert((c/d).v == 5);
}

/**
Integer type that automatically reduces modulo a modulus chosen at runtime.

The modulus is the static field `MD`, shared by all values of the same
`DModInt!name` instantiation (each distinct `name` has its own modulus).
As a D `static` variable it is thread-local.
`MD` must be assigned before any value is constructed: with MD == 0 the
constructor divides by zero.
NOTE(review): unlike ModInt there is no bound on MD; `v+r.v` in `+` can
overflow `uint` if MD exceeds int.max.
*/
struct DModInt(string name) {
    import std.conv : to;
    static uint MD;
    /// Stored representative, always in [0, MD).
    uint v;
    this(int v) {this(long(v));}
    /// Negative inputs are mapped to their non-negative representative.
    this(long v) {this.v = ((v%MD+MD)%MD).to!uint;}
    /// Subtracts MD once if x >= MD; callers pass values below 2*MD.
    static auto normS(uint x) {return (x<MD)?x:x-MD;}
    /// Wraps an already-reduced value without reducing it again.
    static auto make(uint x) {DModInt m; m.v = x; return m;}
    /// Arithmetic works like a built-in integer type; only division is slow.
    auto opBinary(string op:"+")(DModInt r) const {return make(normS(v+r.v));}
    /// ditto
    auto opBinary(string op:"-")(DModInt r) const {return make(normS(v+MD-r.v));}
    /// ditto
    auto opBinary(string op:"*")(DModInt r) const {return make((ulong(v)*r.v%MD).to!uint);}
    /// ditto
    auto opBinary(string op:"/")(DModInt r) const {return this*inv(r);}
    /// Compound assignment (`+=`, `-=`, `*=`, `/=`) via the matching opBinary.
    auto opOpAssign(string op)(DModInt r) {return mixin ("this=this"~op~"r");}
    /**
    Returns the multiplicative inverse of x, using extGcd from
    dkh.numeric.primitive.

    Presumably index 0 of extGcd's result is the Bezout coefficient of its
    first argument, and an inverse exists only when gcd(x.v, MD) == 1
    (TODO confirm).
    NOTE(review): both arguments are passed as int, so this assumes
    MD <= int.max.
    */
    static DModInt inv(DModInt x) {
        return DModInt(extGcd!int(x.v, MD)[0]);
    }
    string toString() const {return v.to!string;}
}

///
unittest {
    alias Mint1 = DModInt!"mod1";
    alias Mint2 = DModInt!"mod2";
    Mint1.MD = 7;
    Mint2.MD = 9;
    assert((Mint1(5)+Mint1(5)).v == 3); // (5+5) % 7
    assert((Mint2(5)+Mint2(5)).v == 1); // (5+5) % 9
}

unittest {
    alias Mint = DModInt!"default";
    Mint.MD = 10^^9 + 7;
    //negative check
    assert(Mint(-1).v == 10^^9 + 6);
    assert(Mint(-1L).v == 10^^9 + 6);
    const Mint a = Mint(48);
    const Mint b = Mint.inv(a);
    assert((a*b).v == 1);
    assert(b.v == 520833337);
    assert(b.toString() == "520833337");
    Mint c = Mint(15);
    Mint d = Mint(3);
    assert((c/d).v == 5);
    c += d;
    assert(c.v == 18);
}

/// True if T is (or implicitly converts to) an instantiation of ModInt or DModInt.
template isModInt(T) {
    enum isModInt =
        is(T : ModInt!MD, uint MD) || is(T : DModInt!S, string S);
}

/**
Returns [0!, 1!, ..., (length-1)!] as values of T.
An empty array is returned when length == 0.
*/
T[] factTable(T)(size_t length) if (isModInt!T) {
    import std.range : take, recurrence;
    import std.array : array;
    return T(1).recurrence!((a, n) => a[n-1]*T(n)).take(length).array;
}

/**
Returns [1/0!, 1/1!, ..., 1/(length-1)!] as values of T.

Only one modular division is performed: 1/(length-1)! is computed first and
the remaining entries are derived downward via 1/i! = 1/(i+1)! * (i+1).
An empty array is returned when length == 0.
*/
T[] invFactTable(T)(size_t length) if (isModInt!T) {
    import std.algorithm : map, reduce;
    import std.range : iota;
    auto res = new T[length];
    if (length == 0) return res;
    // Seeded reduce: for length == 1 the range is empty and the product is 1 (= 0!).
    res[$-1] = T(1) / reduce!"a*b"(T(1), iota(1, length).map!T);
    foreach_reverse (i; 0 .. length - 1) {
        res[i] = res[i+1] * T(i+1);
    }
    return res;
}

/**
Returns a table whose entry i (1 <= i < length) is the inverse of i,
computed as (i-1)! / i!.
Entry 0 is left as T.init (v == 0), since 0 has no inverse.
*/
T[] invTable(T)(size_t length) if (isModInt!T) {
    auto f = factTable!T(length);
    auto invf = invFactTable!T(length);
    auto res = new T[length];
    foreach (i; 1..length) {
        res[i] = invf[i] * f[i-1];
    }
    return res;
}

unittest {
    alias Mint = ModInt!(10^^9 + 7);
    auto r = factTable!Mint(20);
    Mint a = 1;
    assert(r[0] == Mint(1));
    foreach (i; 1..20) {
        a *= Mint(i);
        assert(r[i] == a);
    }
    auto p = invFactTable!Mint(20);
    foreach (i; 1..20) {
        assert((r[i]*p[i]).v == 1);
    }
}

unittest {
    alias Mint = ModInt!(10^^9 + 7);
    // empty and single-element tables must not throw
    assert(factTable!Mint(0).length == 0);
    assert(invFactTable!Mint(0).length == 0);
    assert(invTable!Mint(0).length == 0);
    auto one = invFactTable!Mint(1);
    assert(one.length == 1 && one[0].v == 1);
    auto inv = invTable!Mint(10);
    assert(inv[0].v == 0);
    foreach (i; 1..10) {
        assert((inv[i] * Mint(i)).v == 1);
    }
}