1 module dkh.modint;
2 
3 import dkh.numeric.primitive;
4 
/**
Integer modulo the compile-time constant `MD`.

`MD` must be positive; the template constraint additionally requires
`MD < int.max` so that the sum of two residues never overflows a `uint`.
`+`, `-` and `*` are valid for any such `MD`. `/` and `inv` use Fermat's
little theorem (`x^^(MD-2)`), so they are only correct when `MD` is prime
and the divisor is non-zero modulo `MD`.
 */
struct ModInt(uint MD) if (MD < int.max) {
    import std.conv : to;
    import std.traits : isUnsigned;
    /// Residue in [0, MD) as produced by the constructors and operators.
    uint v;
    /// Construct from a signed integer; negative values wrap into [0, MD).
    this(int v) {this(long(v));}
    /// ditto
    this(long v) {this.v = (v%MD+MD)%MD;}
    /// Construct from an unsigned integer, reducing it directly.
    /// Without this overload a `uint` argument would bind to `this(int)`
    /// (and a `ulong` to `this(long)`), so values above `int.max`
    /// (or `long.max`) would first wrap to a negative number.
    this(T)(T v) if (isUnsigned!T) {this.v = cast(uint)(v % MD);}
    /// Reduce x in [0, 2*MD) into [0, MD) with one conditional subtraction.
    static auto normS(uint x) {return (x<MD)?x:x-MD;}
    /// Wrap an already-reduced value without normalizing it again.
    static auto make(uint x) {ModInt m; m.v = x; return m;}
    /// Usable just like an int (only division is slow: it exponentiates).
    auto opBinary(string op:"+")(ModInt r) const {return make(normS(v+r.v));}
    /// ditto
    auto opBinary(string op:"-")(ModInt r) const {return make(normS(v+MD-r.v));}
    /// ditto
    auto opBinary(string op:"*")(ModInt r) const {return make((ulong(v)*r.v%MD).to!uint);}
    /// ditto
    auto opBinary(string op:"/")(ModInt r) const {return this*inv(r);}
    /// ditto
    // Delegates to `pow` (dkh.numeric.primitive) with identity ModInt(1).
    // NOTE(review): behaviour for negative exponents depends on `pow` - confirm.
    auto opBinary(string op:"^^", T)(T r) const {return pow(this, r, ModInt(1));}
    /// `x op= r` for any op supported by opBinary; returns the updated value.
    auto opOpAssign(string op)(ModInt r) {return mixin ("this=this"~op~"r");}
    /// Return 1/x computed as x^^(MD-2); requires MD prime and x != 0 (mod MD).
    static ModInt inv(ModInt x) {return x^^(MD-2);}
    string toString() const {return v.to!string;}
}
30 
///
unittest {
    alias Mint = ModInt!(107);
    auto sum  = Mint(100) + Mint(10);  // 110 = 107 + 3
    auto prod = Mint(10) * Mint(12);   // 120 = 107 + 13
    auto half = Mint(1) / Mint(2);     // 2 * 54 = 108 = 107 + 1
    auto pw   = Mint(2) ^^ 7;          // 128 = 107 + 21
    assert(sum.v == 3);
    assert(prod.v == 13);
    assert(half.v == 108/2);
    assert(pw.v == 21);
}
39 
unittest {
    // The constraint MD < int.max accepts 2*10^9 but rejects 2290283838.
    static assert( is(ModInt!(uint(1000000000) * 2)));
    static assert(!is(ModInt!(uint(1145141919) * 2)));

    alias Mint = ModInt!(10^^9+7);
    enum minusOne = 10^^9 + 6;
    // Negative inputs (int and long) are mapped into [0, MD).
    assert(Mint(-1).v == minusOne);
    assert(Mint(-1L).v == minusOne);

    // 48 * 520833337 = 25 * (10^9+7) + 1
    Mint b = Mint.inv(Mint(48));
    assert(b.v == 520833337);

    assert((Mint(15) / Mint(3)).v == 5);
}
56 
/**
Integer modulo a modulus `MD` that is set at run time.

All values sharing the same `name` share one `MD`; different `name`s are
independent. `MD` does not have to be prime, but division only works for
divisors coprime to `MD`.

`MD` must be assigned a non-zero value before any value is constructed, and
should stay below `int.max`: `+` adds two residues in a `uint`, and `inv`
works through `extGcd!int`. `MD` is a plain `static` (thread-local in D),
so each thread that uses this type must set it.
 */
struct DModInt(string name) {
    import std.conv : to;
    /// The modulus shared by all `DModInt!name` values.
    static uint MD;
    /// Residue in [0, MD) as produced by the constructors and operators.
    uint v;
    /// Construct from a signed integer; negative values wrap into [0, MD).
    this(int v) {this(long(v));}
    /// ditto
    this(long v) {this.v = ((v%MD+MD)%MD).to!uint;}
    /// Reduce x in [0, 2*MD) into [0, MD) with one conditional subtraction.
    static auto normS(uint x) {return (x<MD)?x:x-MD;}
    /// Wrap an already-reduced value without normalizing it again.
    static auto make(uint x) {DModInt m; m.v = x; return m;}
    /// Usable just like an integer type; only division is slow.
    auto opBinary(string op:"+")(DModInt r) const {return make(normS(v+r.v));}
    /// ditto
    auto opBinary(string op:"-")(DModInt r) const {return make(normS(v+MD-r.v));}
    /// ditto
    auto opBinary(string op:"*")(DModInt r) const {return make((ulong(v)*r.v%MD).to!uint);}
    /// ditto
    auto opBinary(string op:"/")(DModInt r) const {return this*inv(r);}
    /// ditto
    // Same delegation to `pow` (dkh.numeric.primitive) as ModInt uses.
    auto opBinary(string op:"^^", T)(T r) const {return pow(this, r, DModInt(1));}
    /// `x op= r` for any op supported by opBinary; returns the updated value.
    auto opOpAssign(string op)(DModInt r) {return mixin ("this=this"~op~"r");}
    /// Return the multiplicative inverse of x modulo MD.
    /// An inverse exists only when gcd(x.v, MD) == 1; the assert catches
    /// non-invertible inputs that would otherwise yield a meaningless value.
    static DModInt inv(DModInt x) {
        auto res = DModInt(extGcd!int(x.v, MD)[0]);
        assert((x*res).v == 1 % MD, "DModInt.inv: x is not invertible modulo MD");
        return res;
    }
    string toString() const {return v.to!string;}
}
83 
///
unittest {
    // Each name gets its own, independently configured modulus.
    alias Mint1 = DModInt!"mod1";
    alias Mint2 = DModInt!"mod2";
    Mint1.MD = 7;
    Mint2.MD = 9;
    auto x = Mint1(5) + Mint1(5);
    auto y = Mint2(5) + Mint2(5);
    assert(x.v == 3); // (5+5) % 7
    assert(y.v == 1); // (5+5) % 9
}
93 
unittest {
    alias Mint = DModInt!"default";
    Mint.MD = 10^^9 + 7;
    // Negative inputs (int and long) wrap into [0, MD).
    assert(Mint(-1).v == 10^^9 + 6);
    assert(Mint(-1L).v == 10^^9 + 6);

    // inv works on const operands and multiplies back to 1.
    const a = Mint(48);
    const b = Mint.inv(a);
    assert(b.v == 520833337);
    assert((a*b).v == 1);

    // Division and compound assignment.
    auto c = Mint(15), d = Mint(3);
    assert((c/d).v == 5);
    c += d;
    assert(c.v == 18);
}
110 
/// Compile-time trait: true when `T` matches `ModInt!MD` for some `uint MD`
/// or `DModInt!S` for some `string S`; e.g. `isModInt!int` is false.
template isModInt(T) {
    enum bool isModInt =
        is(T : ModInt!MD, uint MD) || is(T : DModInt!S, string S);
}
115 
/// Returns the factorial table `[0!, 1!, 2!, ..., (length-1)!]` in `T`.
/// An empty table is returned for `length == 0`.
T[] factTable(T)(size_t length) if (isModInt!T) {
    auto table = new T[length];
    T running = T(1);
    foreach (k; 0 .. length) {
        // running holds k! once the k-th factor has been multiplied in
        if (k > 0) running *= T(k);
        table[k] = running;
    }
    return table;
}
122 
/// Returns `[1/0!, 1/1!, 1/2!, ..., 1/(length-1)!]` in `T`.
///
/// Only one modular division is performed: `1/(length-1)!` is computed
/// directly, and the smaller entries follow from `1/i! = (i+1) * 1/(i+1)!`.
/// Requires `(length-1)!` to be invertible modulo the modulus of `T`.
/// Returns an empty array when `length == 0`.
T[] invFactTable(T)(size_t length) if (isModInt!T) {
    if (length == 0) return null;
    auto res = new T[length];
    // (length-1)! as a seeded product, so length == 1 yields the empty
    // product 1 (an unseeded reduce over the empty range would throw).
    auto fact = T(1);
    foreach (i; 1 .. length) fact *= T(i);
    res[$-1] = T(1) / fact;
    foreach_reverse (i; 0 .. length-1) {
        res[i] = res[i+1] * T(i+1);
    }
    return res;
}
135 
/// Returns the reciprocal table `[0, 1/1, 1/2, ..., 1/(length-1)]` in `T`.
///
/// Uses the identity 1/i = (i-1)! * (1/i!), built from the factorial and
/// inverse-factorial tables. Entry 0 is left as `T.init` (value 0).
T[] invTable(T)(size_t length) if (isModInt!T) {
    auto facts = factTable!T(length);
    auto invFacts = invFactTable!T(length);
    auto table = new T[length];
    foreach (i; 1 .. length)
        table[i] = facts[i-1] * invFacts[i];
    return table;
}
146 
unittest {
    alias Mint = ModInt!(10^^9 + 7);
    enum n = 20;

    // factTable: r[i] == i!
    auto r = factTable!Mint(n);
    assert(r[0] == Mint(1));
    Mint a = 1;
    foreach (i; 1..n) {
        a *= Mint(i);
        assert(r[i] == a);
    }

    // invFactTable: p[i] is the inverse of i! (including i == 0)
    auto p = invFactTable!Mint(n);
    foreach (i; 0..n) {
        assert((r[i]*p[i]).v == 1);
    }

    // invTable: q[i] is the inverse of i for i >= 1; q[0] stays 0
    auto q = invTable!Mint(n);
    assert(q[0].v == 0);
    foreach (i; 1..n) {
        assert((q[i]*Mint(i)).v == 1);
    }
}