1 module dkh.modint;
2 
3 import dkh.numeric.primitive;
4 
/**
Integer modulo a compile-time constant MD, reduced automatically after every operation.

The stored value `v` is always kept in the canonical range [0, MD).
MD must satisfy 0 < MD < int.max:
$(UL
  $(LI MD == 0 would make every reduction divide by zero.)
  $(LI MD < int.max keeps `v + r.v` and `v + MD - r.v` inside uint, and lets MD be passed to `extGcd!int`.)
)
 */
struct ModInt(uint MD) if (0 < MD && MD < int.max) {
    import std.conv : to;
    /// Canonical representative in [0, MD).
    uint v;
    /// Builds the residue of v; negative inputs are mapped into [0, MD).
    this(int v) {this(long(v));}
    /// ditto
    this(long v) {this.v = (v%MD+MD)%MD;}
    /// Reduces x from [0, 2*MD) into [0, MD) with a single conditional subtraction.
    static auto normS(uint x) {return (x<MD)?x:x-MD;}
    /// Wraps an already-reduced value without re-normalizing it.
    static auto make(uint x) {ModInt m; m.v = x; return m;}
    /// Arithmetic works like the built-in integer types; only division is slow.
    auto opBinary(string op:"+")(ModInt r) const {return make(normS(v+r.v));}
    /// ditto
    auto opBinary(string op:"-")(ModInt r) const {return make(normS(v+MD-r.v));}
    /// ditto
    auto opBinary(string op:"*")(ModInt r) const {
        // Both factors are below MD < 2^31, so the product fits in ulong and the
        // remainder is below MD; the narrowing cast cannot lose bits.
        return make(cast(uint)(ulong(v)*r.v%MD));
    }
    /// ditto
    auto opBinary(string op:"/")(ModInt r) const {return this*inv(r);}
    /// Compound assignment (`+=`, `-=`, `*=`, `/=`) via the matching opBinary.
    auto opOpAssign(string op)(ModInt r) {return mixin ("this=this"~op~"r");}
    /// Multiplicative inverse of x via extGcd. Only meaningful when x.v and MD
    /// are coprime (e.g. MD prime and x != 0); this is not checked.
    static ModInt inv(ModInt x) {return ModInt(extGcd!int(x.v, MD)[0]);}
    string toString() const {return v.to!string;}
}
28 
29 ///
unittest {
    alias Mint = ModInt!(107);

    // 100 + 10 = 110 = 107 + 3
    auto sum = Mint(100) + Mint(10);
    assert(sum.v == 3);

    // 10 * 12 = 120 = 107 + 13
    auto product = Mint(10) * Mint(12);
    assert(product.v == 13);

    // 1 / 2 is 54 because 2 * 54 = 108 = 107 + 1
    auto half = Mint(1) / Mint(2);
    assert(half.v == 108/2);
}
36 
unittest {
    // 2 * 10^9 still fits the template constraint; 2 * 1145141919 exceeds int.max.
    static assert( is(ModInt!(uint(1000000000) * 2))); //not overflow
    static assert(!is(ModInt!(uint(1145141919) * 2))); //overflow!

    alias Mint = ModInt!(10^^9+7);
    enum minusOne = 10^^9 + 6;

    // Negative inputs wrap into [0, MD), for both the int and long constructors.
    assert(Mint(-1).v == minusOne);
    assert(Mint(-1L).v == minusOne);

    // 48 * 520833337 == 1 (mod 10^9+7)
    const inv48 = Mint.inv(Mint(48));
    assert(inv48.v == 520833337);

    // Exact division 15 / 3 == 5
    const quotient = Mint(15) / Mint(3);
    assert(quotient.v == 5);
}
53 
/**
Integer modulo a run-time modulus, reduced automatically after every operation.

Every distinct `name` produces a separate type with its own modulus `MD`, so
several moduli can coexist (e.g. `DModInt!"mod1"` and `DModInt!"mod2"`).
Assign `MD` before constructing values. It must satisfy 0 < MD < int.max, the
same bound `ModInt` enforces at compile time. The constructor checks this with
an `assert`, which is removed in -release builds.
 */
struct DModInt(string name) {
    import std.conv : to;
    /// Modulus shared by all values of this type (thread-local, like any D `static` field).
    static uint MD;
    /// Canonical representative in [0, MD).
    uint v;
    /// Builds the residue of v; negative inputs are mapped into [0, MD).
    this(int v) {this(long(v));}
    /// ditto
    this(long v) {
        // MD == 0 would divide by zero; MD >= int.max would overflow the uint
        // sums in `+`/`-` and the int arguments handed to extGcd in inv.
        assert(0 < MD && MD < int.max, "DModInt: MD must satisfy 0 < MD < int.max");
        // (v % MD + MD) % MD lies in [0, MD), so the cast cannot lose bits.
        this.v = cast(uint)((v%MD+MD)%MD);
    }
    /// Reduces x from [0, 2*MD) into [0, MD) with a single conditional subtraction.
    static auto normS(uint x) {return (x<MD)?x:x-MD;}
    /// Wraps an already-reduced value without re-normalizing it.
    static auto make(uint x) {DModInt m; m.v = x; return m;}
    /// Arithmetic works like the built-in integer types; only division is slow.
    auto opBinary(string op:"+")(DModInt r) const {return make(normS(v+r.v));}
    /// ditto
    auto opBinary(string op:"-")(DModInt r) const {return make(normS(v+MD-r.v));}
    /// ditto
    auto opBinary(string op:"*")(DModInt r) const {
        // Both factors are below MD < 2^31, so the product fits in ulong and
        // the remainder fits in uint.
        return make(cast(uint)(ulong(v)*r.v%MD));
    }
    /// ditto
    auto opBinary(string op:"/")(DModInt r) const {return this*inv(r);}
    /// Compound assignment (`+=`, `-=`, `*=`, `/=`) via the matching opBinary.
    auto opOpAssign(string op)(DModInt r) {return mixin ("this=this"~op~"r");}
    /// Multiplicative inverse of x via extGcd. Only meaningful when x.v and MD
    /// are coprime (e.g. MD prime and x != 0); this is not checked.
    static DModInt inv(DModInt x) {
        return DModInt(extGcd!int(x.v, MD)[0]);
    }
    /// Decimal representation of the canonical value; callable on const values.
    string toString() const {return v.to!string;}
}
80 
81 ///
unittest {
    // Differently named instantiations keep independent moduli.
    alias Mint1 = DModInt!"mod1";
    alias Mint2 = DModInt!"mod2";
    Mint1.MD = 7;
    Mint2.MD = 9;

    const five1 = Mint1(5);
    const five2 = Mint2(5);
    assert((five1 + five1).v == 3); // (5+5) % 7
    assert((five2 + five2).v == 1); // (5+5) % 9
}
90 
unittest {
    alias Mint = DModInt!"default";
    Mint.MD = 10^^9 + 7;
    enum minusOne = 10^^9 + 6;

    // Negative inputs wrap into [0, MD), for both the int and long constructors.
    assert(Mint(-1).v == minusOne);
    assert(Mint(-1L).v == minusOne);

    // Inverse of 48, and inverse/multiply round trip on const values.
    const Mint x = Mint(48);
    const Mint xInv = Mint.inv(x);
    assert((x * xInv).v == 1);
    assert(xInv.v == 520833337);

    // Division, then compound assignment.
    Mint acc = Mint(15);
    Mint three = Mint(3);
    assert((acc / three).v == 5);
    acc += three;
    assert(acc.v == 18);
}
107 
/**
True when T is an instantiation of `ModInt` (any modulus) or of `DModInt` (any name).

Both branches deduce the template argument in the `is` expression. The second
branch must test `T`; the pattern variable `s` only captures DModInt's name.
 */
template isModInt(T) {
    enum bool isModInt =
        is(T : ModInt!MD, uint MD) || is(T : DModInt!s, string s);
}
112 
113 
/**
Returns the factorial table [0!, 1!, ..., (length-1)!] computed in T.
An empty array is returned when length == 0.
 */
T[] factTable(T)(size_t length) if (isModInt!T) {
    auto table = new T[length];
    T acc = T(1);  // running product, equal to k! once slot k is reached
    foreach (k, ref slot; table) {
        if (k > 0) acc = acc * T(k);
        slot = acc;
    }
    return table;
}
119 
/**
Returns the inverse-factorial table [1/0!, 1/1!, ..., 1/(length-1)!] computed in T.

Only one modular inversion is performed. 1/(length-1)! is computed first, and
the remaining entries follow from 1/i! = 1/(i+1)! * (i+1).
An empty array is returned when length == 0. The entries are only meaningful
while (length-1)! is invertible in T (e.g. length <= MD for a prime MD).
 */
T[] invFactTable(T)(size_t length) if (isModInt!T) {
    auto res = new T[length];
    if (length == 0) return res;  // no entries; avoids indexing res[$-1]

    // (length-1)! -- an empty product (length == 1) correctly leaves it at 1.
    T fact = T(1);
    foreach (i; 1 .. length) fact *= T(i);
    res[$-1] = T(1) / fact;

    foreach_reverse (i; 0 .. length - 1) {
        res[i] = res[i+1] * T(i+1);
    }
    return res;
}
131 
/**
Returns the table of inverses [_, 1/1, 1/2, ..., 1/(length-1)] computed in T.

Entry 0 has no inverse and is left as `T.init`. Each 1/i is derived as
(i-1)! * 1/i!, so the table costs a single modular inversion. The entries are
only meaningful while (length-1)! is invertible in T.
 */
T[] invTable(T)(size_t length) if (isModInt!T) {
    auto res = new T[length];
    // With length <= 1 there are no invertible indices; return early instead
    // of building factorial tables for an empty range.
    if (length <= 1) return res;
    auto f = factTable!T(length);
    auto invf = invFactTable!T(length);
    foreach (i; 1 .. length) {
        res[i] = invf[i] * f[i-1];
    }
    return res;
}
141 
unittest {
    alias Mint = ModInt!(10^^9 + 7);

    // factTable agrees with a running product.
    auto r = factTable!Mint(20);
    Mint a = 1;
    assert(r[0] == Mint(1));
    foreach (i; 1..20) {
        a *= Mint(i);
        assert(r[i] == a);
    }

    // invFactTable entries are inverses of the factorials (including 0! = 1).
    auto p = invFactTable!Mint(20);
    foreach (i; 0..20) {
        assert((r[i]*p[i]).v == 1);
    }

    // invTable entries are inverses of their indices.
    auto q = invTable!Mint(20);
    foreach (i; 1..20) {
        assert((q[i]*Mint(i)).v == 1);
    }
}