1 module dkh.bigint;
2 
3 import core.checkedint, core.bitop;
4 import dkh.int128, dkh.foundation;
5 
/**
Fixed-width unsigned multi-word integer.

The width is chosen statically: $(D uintN!N) stores $(D N * 64) bits in
$(D N) 64-bit words, least significant word first. For example
$(D uintN!2) can be used as a 128-bit unsigned integer.
 */
struct uintN(int N) if (N >= 1) {
    /// The words of the value; $(D d[0]) is the least significant one.
    ulong[N] d;

    /// Constructs the value $(D x); the remaining words stay zero.
    this(ulong x) { d[0] = x; }

    /**
    Constructs the value from a decimal string.

    Each character is folded in as $(D c - '0') without any validation, so
    the caller must pass decimal digits only. The accumulation uses this
    type's own $(D *) and $(D +) operators.
     */
    this(string s) {
        foreach (c; s) {
            this *= 10;
            this += uintN(c-'0');
        }
    }

    /// Returns the decimal representation without leading zeros ("0" for zero).
    string toString() {
        import std.conv : to;
        import std.algorithm : reverse;
        import dkh.container.stackpayload;
        StackPayload!char s;
        auto x = this;
        if (!x) return "0";
        while (x) {
            // Peel off 18 decimal digits at a time; 10^18 fits in a ulong.
            static immutable B = 10UL^^18;
            ulong z = (x % B);
            x /= B;
            bool last = (!x);
            foreach (i; 0..18) {
                // Only the most significant chunk drops its leading zeros;
                // every other chunk is emitted as exactly 18 digits.
                if (last && !z) break;
                s ~= cast(char)('0' + z % 10);
                z /= 10;
            }
        }
        // Digits were produced least significant first.
        reverse(s.data);
        return s.data.idup;
    }

    /// Accesses word $(D idx) (word 0 is the least significant).
    ref inout(ulong) opIndex(int idx) inout { return d[idx]; }

    /// Converts to $(D bool): true when any word is non-zero.
    T opCast(T: bool)() {
        import std.algorithm, std.range;
        return d[].find!"a!=0".empty == false;
    }

    // bitwise operations

    /// Bitwise complement of every word.
    uintN opUnary(string op)() const if (op == "~") {
        uintN res;
        foreach (i; 0..N) {
            res[i] = ~d[i];
        }
        return res;
    }

    /// Word-wise $(D &), $(D |) and $(D ^).
    uintN opBinary(string op)(in uintN r) const
    if (op == "&" || op == "|" || op == "^") {
        uintN res;
        foreach (i; 0..N) {
            res[i] = mixin("d[i]" ~ op ~ "r.d[i]");
        }
        return res;
    }

    /**
    Logical left shift by $(D n) bits.

    Shifting by $(D N * 64) bits or more yields zero. Bits shifted past the
    most significant word are discarded.
     */
    uintN opBinary(string op : "<<")(int n) const {
        if (N * 64 <= n) return uintN(0);
        uintN res;
        int ws = n / 64; // whole words to move
        int bs = n % 64; // remaining bits inside a word
        if (bs == 0) {
            // Handled separately: the general path would shift by 64 bits.
            res.d[ws..N][] = d[0..N-ws][];
            return res;
        }
        foreach_reverse (i; 1..N-ws) {
            res[i+ws] = (d[i] << bs) | (d[i-1] >> (64-bs));
        }
        res[ws] = (d[0] << bs);
        return res;
    }

    /**
    Logical right shift by $(D n) bits.

    Shifting by $(D N * 64) bits or more yields zero. Vacated high bits are
    filled with zeros.
     */
    uintN opBinary(string op : ">>")(int n) const {
        if (N * 64 <= n) return uintN(0);
        uintN res;
        int ws = n / 64; // whole words to move
        int bs = n % 64; // remaining bits inside a word
        if (bs == 0) {
            // Handled separately: the general path would shift by 64 bits.
            res.d[0..N-ws][] = d[ws..N][];
            return res;
        }
        foreach (i; 0..N-ws-1) {
            // Low part comes from the source word shifted down, high part
            // from the low bits of the next more significant word.
            res[i] = (d[i+ws] >> bs) | (d[i+ws+1] << (64-bs));
        }
        // The top source word has no higher neighbour to borrow bits from.
        res[N-ws-1] = (d[N-1] >> bs);
        return res;
    }

    // comparison

    /// Three-way comparison as unsigned integers.
    int opCmp(in uintN r) const {
        return cmpMultiWord(d, r.d);
    }

    // arithmetic

    /// Pre-increment; wraps to zero on overflow.
    uintN opUnary(string op)() if (op == "++") {
        foreach (i; 0..N) {
            d[i]++;
            // A word that did not wrap to zero produced no carry.
            if (d[i]) break;
        }
        return this;
    }

    /// Pre-decrement; wraps to the maximum value on underflow.
    uintN opUnary(string op)() if (op == "--") {
        foreach (i; 0..N) {
            d[i]--;
            // A word that did not wrap to ulong.max produced no borrow.
            if (d[i] != ulong.max) break;
        }
        return this;
    }

    /// Unary plus (identity) and unary minus (two's complement negation).
    uintN opUnary(string op)() const if (op=="+" || op=="-") {
        static if (op == "+") {
            return this;
        } else {
            // -x == ~x + 1
            return ++(~this);
        }
    }

    /// Addition; the carry out of the top word is dropped.
    uintN opBinary(string op : "+")(in uintN r) const {
        uintN res;
        addMultiWord(d, r.d, res.d);
        return res;
    }

    /// Subtraction; the borrow out of the top word is dropped.
    uintN opBinary(string op : "-")(in uintN r) const {
        uintN res;
        subMultiWord(d, r.d, res.d);
        return res;
    }

    /// Multiplication, keeping only the low $(D N) words of the product.
    uintN opBinary(string op : "*")(in uintN r) const {
        uintN res;
        static if (N == 2) {
            // NOTE(review): mul128 is presumably the full 64x64 product as
            // [low, high] words - confirm against dkh.int128.
            auto u = mul128(d[0], r[0]);
            res[0] = u[0];
            res[1] = u[1] + d[0]*r[1] + d[1]*r[0];
            return res;
        } else {
            foreach (i; 0..N) {
                ulong carry = 0;
                // Partial products landing strictly below the top word need
                // their high half and the carries propagated.
                foreach (j; 0..N-1-i) {
                    int s = i+j;
                    bool of;
                    auto u = mul128(d[i], r[j]);
                    res[s] = addu(res[s], carry, of);
                    carry = u[1];
                    if (of) carry++;
                    of = false;
                    res[s] = addu(res[s], u[0], of);
                    if (of) carry++;
                }
                // In the top word only the low 64 bits matter, so a plain
                // wrapping multiply and add is enough.
                res[N-1] += d[i] * r[N-1-i] + carry;
            }
            return res;
        }
    }

    /// Multiplication by a single word, keeping the low $(D N) words.
    uintN opBinary(string op : "*")(in ulong r) const {
        uintN res;
        mulMultiWord(d, r, res.d);
        return res;
    }

    /**
    Division by a single word, processed from the most significant word down.

    $(D rr) must be non-zero; the $(D assert) in the loop fails otherwise.
     */
    uintN opBinary(string op : "/")(in ulong rr) const {
        uintN res;
        ulong back = 0; // running remainder, always < rr
        foreach_reverse (i; 0..N) {
            assert(back < rr);
            // NOTE(review): div128 presumably divides the two-word value
            // [low, high] by rr - confirm against dkh.int128.
            ulong pred = div128([d[i], back], rr);
            res[i] = pred;
            // New remainder, computed modulo 2^64.
            back = d[i]-(rr*pred);
        }
        return res;
    }

    /**
    Division by another $(D uintN).

    $(D rr) must be non-zero (checked with $(D assert)). A divisor that fits
    in one word is forwarded to the single-word overload. Otherwise both
    operands are shifted left so that the divisor's top non-zero word has its
    highest bit set, each quotient word is estimated from the leading words,
    and the estimate is corrected by repeated subtraction.
     */
    uintN opBinary(string op : "/")(in uintN rr) const {
        // Locate the most significant non-zero word of the divisor and the
        // number of leading zero bits in it.
        int up = -1, shift;
        foreach_reverse (i; 0..N) {
            if (rr[i]) {
                up = i;
                shift = 63 - bsr(rr[i]);
                break;
            }
        }
        assert(up != -1);
        if (up == 0) {
            return this / ulong(rr[0]);
        }
        // The dividend gets one extra word to hold the bits shifted out.
        ulong[N+1] l;
        l[0..N] = d[0..N];
        shiftLeftMultiWord(l, shift, l);
        auto r = (rr << shift);
        uintN res;
        foreach_reverse (i; 0..N-up) {
            // Estimate res[i] from the window l[i .. i+up+2]. Dividing by
            // r[up]+1 keeps the estimate from exceeding the true quotient
            // word; when r[up] is ulong.max that divisor would be 2^64, so
            // the quotient is simply the high word.
            ulong pred = (r[up] == ulong.max) ? l[i+up+1] : div128([l[i+up], l[i+up+1]], r[up]+1);
            res[i] = pred;
            ulong[N+1] buf;
            mulMultiWord(r.d[], pred, buf); // r * pred
            subMultiWord(l[i..i+up+2], buf[], l[i..i+up+2]);
            // Fix up the under-estimate.
            while (cmpMultiWord(l[i..i+up+2], r.d[]) != -1) {
                res[i]++;
                subMultiWord(l[i..i+up+2], r.d[], l[i..i+up+2]);
            }
        }
        return res;
    }

    /// Remainder of division by a single word.
    ulong opBinary(string op : "%")(in ulong r) const {
        static if (N == 2) {
            // Reducing the high word first keeps it below r.
            // NOTE(review): mod128 presumably takes [low, high] - confirm.
            return mod128([d[0], d[1] % r], r);
        } else {
            return (this % uintN(r)).d[0];
        }
    }

    /// Remainder of division by another $(D uintN).
    uintN opBinary(string op : "%")(in uintN r) const {
        static if (N == 2) {
            if (r[1] == 0) return uintN(this % ulong(r[0]));
        }
        return this - this/r*r;
    }

    /// Compound assignment, implemented as $(D this = this op r).
    auto opOpAssign(string op, T)(in T r) {
        return mixin("this=this" ~ op ~ "r");
    }
}
222 
223 
/**
Computes res = l + r on little-endian word arrays.

Exactly res.length words are written. Words missing from l or r are
treated as zero, and a carry out of the last result word is dropped.
 */
void addMultiWord(in ulong[] l, in ulong[] r, ulong[] res) {
    bool carry = false;
    foreach (idx, ref word; res) {
        // Read both operands before storing: res may alias l or r.
        const ulong a = (idx < l.length) ? l[idx] : 0UL;
        const ulong b = (idx < r.length) ? r[idx] : 0UL;
        bool overflow;
        ulong sum = addu(a, b, overflow);
        if (carry) {
            sum++;
            // Adding the incoming carry can itself wrap to zero.
            overflow |= (sum == 0);
        }
        word = sum;
        carry = overflow;
    }
}
239 
unittest {
    // The carry has to ripple through two saturated words into the third;
    // the result word beyond both inputs must come out as zero.
    ulong[] lhs = [ulong.max, ulong.max, 0UL];
    ulong[] rhs = [1UL];
    auto sum = new ulong[4];
    addMultiWord(lhs, rhs, sum);
    assert(sum == [0UL, 0UL, 1UL, 0UL]);
}
248 
249 // res = l-r
/**
Computes res = l - r on little-endian word arrays.

Exactly res.length words are written. Words missing from l or r are
treated as zero, and a borrow out of the last result word is dropped.
 */
void subMultiWord(in ulong[] l, in ulong[] r, ulong[] res) {
    bool borrow = false;
    foreach (idx, ref word; res) {
        // Read both operands before storing: res may alias l or r.
        const ulong a = (idx < l.length) ? l[idx] : 0UL;
        const ulong b = (idx < r.length) ? r[idx] : 0UL;
        bool underflow;
        ulong diff = subu(a, b, underflow);
        if (borrow) {
            diff--;
            // Taking the incoming borrow can itself wrap to ulong.max.
            underflow |= (diff == ulong.max);
        }
        word = diff;
        borrow = underflow;
    }
}
265 
unittest {
    // The borrow has to ripple through two zero words and consume the third;
    // the result word beyond both inputs must come out as zero.
    ulong[] lhs = [0UL, 0UL, 1UL];
    ulong[] rhs = [1UL];
    auto diff = new ulong[4];
    subMultiWord(lhs, rhs, diff);
    assert(diff == [ulong.max, ulong.max, 0UL, 0UL]);
}
274 
/**
Computes res = l * r for a little-endian word array l and a single word r.

Exactly res.length words are written. Words missing from l are treated as
zero, and the carry out of the last result word is dropped.
 */
void mulMultiWord(in ulong[] l, in ulong r, ulong[] res) {
    ulong carry;
    foreach (idx, ref word; res) {
        const ulong factor = (idx < l.length) ? l[idx] : 0UL;
        // NOTE(review): mul128 is presumably the full product as
        // [low, high] words - confirm against dkh.int128.
        auto prod = mul128(factor, r);
        ulong high = prod[1];
        bool overflow;
        word = addu(prod[0], carry, overflow);
        if (overflow) high++;
        carry = high;
    }
}
286 
/**
Computes res = l << n on little-endian word arrays.

Exactly res.length words are written; source words outside l are treated as
zero. Words are processed from the most significant down and each one reads
only source words at the same or a lower index, so res may be the same
array as l.

Params:
    l = source words, least significant first
    n = shift distance in bits; the code assumes it is non-negative
    res = destination words, least significant first
 */
void shiftLeftMultiWord(in ulong[] l, int n, ulong[] res) {
    size_t N = res.length;
    int ws = n / 64; // whole words to move
    int bs = n % 64; // remaining bits inside a word
    foreach_reverse (ptrdiff_t i; 0..N) {
        // Source word that supplies the high part of res[i].
        ulong b = (0 <= i-ws && i-ws < l.length) ? l[i-ws] : 0UL;
        if (bs == 0) res[i] = b; // avoids shifting by 64 bits below
        else {
            // Next lower source word supplies the bits shifted in from below.
            ulong a = (0 <= i-ws-1 && i-ws-1 < l.length) ? l[i-ws-1] : 0UL;
            res[i] = (b << bs) | (a >> (64-bs));
        }
    }
}
301 
302 // std.algorithm.cmp, reverse ver
/**
Three-way comparison of two little-endian word arrays as unsigned integers.

Arrays of different length are compared as if the shorter one were padded
with zero words. Returns -1, 0 or 1 for l < r, l == r and l > r.
 */
int cmpMultiWord(in ulong[] l, in ulong[] r) {
    import std.algorithm : max;
    size_t idx = max(l.length, r.length);
    // Walk from the most significant word down; the first difference decides.
    while (idx-- > 0) {
        const ulong lw = (idx < l.length) ? l[idx] : 0UL;
        const ulong rw = (idx < r.length) ? r[idx] : 0UL;
        if (lw != rw) return (lw < rw) ? -1 : 1;
    }
    return 0;
}
314 
315 ///
unittest {
    import std.conv : to;
    // Multiplication and division on a 1280-bit integer built from decimal strings.
    alias Wide = uintN!20;
    auto a = Wide("31415926535897969393238462");
    auto b = Wide("1145141919810893");
    string product = (a*b).to!string;
    string quotient = (a/b).to!string;
    assert(product == "35975694425956177975650270094479894166566");
    assert(quotient == "27434090039");
}
324 
unittest {
    import std.conv : to;
    // Remainder with a divisor that spans more than one word.
    alias U256 = uintN!4;
    auto dividend = U256("115792089237316195417293883273301227089434195242432897623355228563449095127040");
    auto divisor = U256("340282366920938463500268095579187314687");
    string remainder = (dividend%divisor).to!string;
    assert(remainder == "340282366920938463186673446326124937222");
}
332 
unittest {
    import std.conv;
    // toString is called three times on the same value to check that the
    // result is repeatable.
    uintN!10 value = uintN!10("114514");
    foreach (round; 0..3) {
        assert(value.toString == "114514");
    }
}
340 
unittest {
    // Cross-checks uintN!N against std.bigint.BigInt (reduced modulo
    // 2^(64*N)) on a grid of edge-case values.
    void check(int N)() {
        import std.bigint;
        alias Uint = uintN!N;

        // Collect every value whose words are drawn from {0, 1, ulong.max},
        // plus ulong.max - 1 for the small widths.
        Uint[] samples;
        Uint current;
        void enumerate(int pos) {
            if (pos == N) {
                samples ~= current;
                return;
            }
            static if (N <= 3) {
                static immutable ulong[] words = [0UL, 1UL, ulong.max, ulong.max - 1];
            } else {
                static immutable ulong[] words = [0UL, 1UL, ulong.max];
            }
            foreach (w; words) {
                current.d[pos] = w;
                enumerate(pos+1);
            }
        }
        enumerate(0);

        BigInt modulus = BigInt(1) << (64*N);

        // Binary operator: compare decimal strings of both implementations.
        void checkBinary(string op, R)(Uint lhs, R rhs) {
            import std.conv;
            auto refLhs = BigInt(lhs.to!string);
            auto refRhs = BigInt(rhs.to!string);
            auto got = mixin("lhs" ~ op ~ "rhs");
            auto want = mixin("refLhs" ~ op ~ "refRhs");
            // Map the reference result into [0, 2^(64*N)).
            want = (want % modulus + modulus) % modulus;
            string gotText = got.to!string;
            string wantText = want.to!string;
            assert(gotText == wantText);
        }

        // Unary operator: check both the returned value and the operand,
        // which ++ and -- modify in place.
        void checkUnary(string op)(Uint operand) {
            import std.conv;
            auto refOperand = BigInt(operand.to!string);
            auto got = mixin(op ~ "operand");
            auto want = mixin(op ~ "refOperand");
            refOperand = (refOperand % modulus + modulus) % modulus;
            want = (want % modulus + modulus) % modulus;
            assert(operand.to!string == refOperand.to!string);
            string gotText = got.to!string;
            string wantText = want.to!string;
            assert(gotText == wantText);
        }

        ulong[4] smallDivisors = [1UL, 2UL, ulong.max, ulong.max - 1];
        foreach (a; samples) {
            checkUnary!"++"(a);
            checkUnary!"--"(a);
            checkUnary!"~"(a);
            foreach (k; smallDivisors) {
                checkBinary!"/"(a, k);
            }
            foreach (k; smallDivisors) {
                checkBinary!"%"(a, k);
            }
            foreach (b; samples) {
                checkBinary!"+"(a, b);
                checkBinary!"-"(a, b);
                checkBinary!"*"(a, b);
                if (b != Uint(0)) {
                    checkBinary!"/"(a, b);
                    checkBinary!"%"(a, b);
                }
                checkBinary!"&"(a, b);
                checkBinary!"|"(a, b);
                checkBinary!"^"(a, b);
            }
        }
    }

    import std.stdio, std.algorithm;
    import dkh.stopwatch;
    // Run each width once and report the elapsed time per width.
    auto ti = benchmark!(check!1, check!2, check!3, check!4)(1);
    writeln("BigInt: ", ti[].map!(a => a.toMsecs()));
}