1 module dkh.bigint;
2 
3 import core.checkedint, core.bitop;
4 import dkh.int128, dkh.foundation;
5 
/**
Unsigned big integer of fixed length: $(D N*64) bits, stored as $(D N)
64-bit words with the least significant word first.
Addition, subtraction, multiplication and ++/-- wrap around modulo $(D 2^^(N*64)).
For example, $(D uintN!2) is a 128-bit unsigned integer.
 */
struct uintN(int N) if (N >= 1) {
    /// Words of the value, least significant first.
    ulong[N] d;

    /// Constructs the value $(D x); all higher words are zero.
    this(ulong x) { d[0] = x; }

    /// Parses a decimal string. Characters are not validated: each one is
    /// assumed to be a digit '0'..'9'. An empty string yields 0.
    this(string s) {
        foreach (c; s) {
            this *= 10;
            this += uintN(c-'0');
        }
    }

    /// Decimal representation without leading zeros ("0" for zero).
    string toString() const {
        import std.algorithm : reverse;
        import dkh.container.stackpayload;
        // 10^^18 < 2^^64, so the value is peeled off 18 decimal digits at a
        // time using a single-word division.
        static immutable B = 10UL^^18;
        StackPayload!char s;
        uintN x = this; // mutable working copy; `this` is const here
        if (!x) return "0";
        while (x) {
            ulong z = (x % B);
            x /= B;
            bool last = (!x);
            // Digits are appended least significant first and reversed at the
            // end; only the most significant chunk is emitted without zero padding.
            foreach (i; 0..18) {
                if (last && !z) break;
                s ~= cast(char)('0' + z % 10);
                z /= 10;
            }
        }
        reverse(s.data);
        return s.data.idup;
    }

    /// Word access (index 0 is the least significant word).
    ref inout(ulong) opIndex(int idx) inout { return d[idx]; }

    /// True iff the value is non-zero.
    T opCast(T: bool)() const {
        foreach (w; d) {
            if (w) return true;
        }
        return false;
    }

    //bit op
    /// Bitwise complement of every word.
    uintN opUnary(string op)() const if (op == "~") {
        uintN res;
        foreach (i; 0..N) {
            res[i] = ~d[i];
        }
        return res;
    }
    /// Word-wise &, | and ^.
    uintN opBinary(string op)(in uintN r) const
    if (op == "&" || op == "|" || op == "^") {
        uintN res;
        foreach (i; 0..N) {
            res[i] = mixin("d[i]" ~ op ~ "r.d[i]");
        }
        return res;
    }
    /// Left shift by n bits (n >= 0); bits shifted past the top are discarded.
    uintN opBinary(string op : "<<")(int n) const {
        assert(0 <= n, "shift amount must be non-negative");
        if (N * 64 <= n) return uintN(0);
        uintN res;
        int ws = n / 64; // whole-word part of the shift
        int bs = n % 64; // remaining bit part
        if (bs == 0) {
            // pure word move; the bit-shift formula below would shift by 64
            res.d[ws..N][] = d[0..N-ws][];
            return res;
        }
        // res word i+ws takes the low bits of d[i] and the high bits of d[i-1]
        foreach_reverse (i; 1..N-ws) {
            res[i+ws] = (d[i] << bs) | (d[i-1] >> (64-bs));
        }
        res[ws] = (d[0] << bs);
        return res;
    }
    /// Logical right shift by n bits (n >= 0); zeros are shifted in from the top.
    uintN opBinary(string op : ">>")(int n) const {
        assert(0 <= n, "shift amount must be non-negative");
        if (N * 64 <= n) return uintN(0);
        uintN res;
        int ws = n / 64; // whole-word part of the shift
        int bs = n % 64; // remaining bit part
        if (bs == 0) {
            // pure word move; the bit-shift formula below would shift by 64
            res.d[0..N-ws][] = d[ws..N][];
            return res;
        }
        // res word i takes the high bits of d[i+ws] and the low bits of d[i+ws+1]
        foreach (i; 0..N-ws-1) {
            res[i] = (d[i+ws] >> bs) | (d[i+ws+1] << (64-bs));
        }
        // the topmost result word has nothing above it to borrow bits from
        res[N-ws-1] = (d[N-1] >> bs);
        return res;
    }
    //cmp
    /// Three-way comparison of the numeric values.
    int opCmp(in uintN r) const {
        return cmpMultiWord(d, r.d);
    }

    //arit
    /// Prefix increment (wraps the maximum value to 0); returns the new value.
    uintN opUnary(string op)() if (op == "++") {
        foreach (i; 0..N) {
            d[i]++;
            if (d[i]) break; // this word did not wrap, so no carry propagates
        }
        return this;
    }
    /// Prefix decrement (wraps 0 to the maximum value); returns the new value.
    uintN opUnary(string op)() if (op == "--") {
        foreach (i; 0..N) {
            d[i]--;
            if (d[i] != ulong.max) break; // no borrow from the next word
        }
        return this;
    }
    /// Unary plus (identity) and two's-complement negation modulo 2^^(N*64).
    uintN opUnary(string op)() const if (op=="+" || op=="-") {
        static if (op == "+") {
            return this;
        } else {
            // -x == ~x + 1 in two's complement; increment a named copy
            // instead of applying ++ to an rvalue
            uintN res = ~this;
            ++res;
            return res;
        }
    }

    /// Wrapping addition.
    uintN opBinary(string op : "+")(in uintN r) const {
        uintN res;
        addMultiWord(d, r.d, res.d);
        return res;
    }
    /// Wrapping subtraction.
    uintN opBinary(string op : "-")(in uintN r) const {
        uintN res;
        subMultiWord(d, r.d, res.d);
        return res;
    }

    /// Wrapping multiplication (only the low N words of the product are kept).
    uintN opBinary(string op : "*")(in uintN r) const {
        uintN res;
        static if (N == 2) {
            // low*low needs the full 128-bit product; the cross terms only
            // contribute to the high word
            auto u = mul128(d[0], r[0]);
            res[0] = u[0];
            res[1] = u[1] + d[0]*r[1] + d[1]*r[0];
            return res;
        } else {
            // schoolbook multiplication, skipping partial products that land
            // entirely above word N-1
            foreach (i; 0..N) {
                ulong carry = 0;
                foreach (j; 0..N-1-i) {
                    int s = i+j;
                    bool of;
                    auto u = mul128(d[i], r[j]);
                    res[s] = addu(res[s], carry, of);
                    carry = u[1];
                    if (of) carry++;
                    of = false;
                    res[s] = addu(res[s], u[0], of);
                    if (of) carry++;
                }
                // top word: only the low 64 bits of the product matter
                res[N-1] += d[i] * r[N-1-i] + carry;
            }
            return res;
        }
    }
    /// Wrapping multiplication by a single word.
    uintN opBinary(string op : "*")(in ulong r) const {
        uintN res;
        mulMultiWord(d, r, res.d);
        return res;
    }
    /// Division by a single word; rr must be non-zero (the assert below fails otherwise).
    uintN opBinary(string op : "/")(in ulong rr) const {
        uintN res;
        ulong back = 0; // running remainder, always < rr
        foreach_reverse (i; 0..N) {
            assert(back < rr);
            // NOTE(review): div128 is assumed to take [low, high] words and
            // return the quotient — confirm against dkh.int128.
            ulong pred = div128([d[i], back], rr);
            res[i] = pred;
            // the true remainder (back*2^^64 + d[i]) - rr*pred is < rr,
            // so computing it modulo 2^^64 is exact
            back = d[i]-(rr*pred);
        }
        return res;
    }
    /// Long division by a multi-word divisor; asserts that rr is non-zero.
    uintN opBinary(string op : "/")(in uintN rr) const {
        // locate the most significant non-zero word of the divisor
        int up = -1, shift;
        foreach_reverse (i; 0..N) {
            if (rr[i]) {
                up = i;
                shift = 63 - bsr(rr[i]);
                break;
            }
        }
        assert(up != -1);
        if (up == 0) {
            return this / ulong(rr[0]);
        }
        // normalize: shift so that the divisor's top word has its high bit
        // set; the dividend gets one extra word for the bits shifted out
        ulong[N+1] l;
        l[0..N] = d[0..N];
        shiftLeftMultiWord(l, shift, l);
        auto r = (rr << shift);
        uintN res;
        foreach_reverse (i; 0..N-up) {
            //compare l[i, i+up+1] -> res[i]
            // Estimate the quotient word by dividing by (top divisor word + 1),
            // so the estimate can only be too small; the loop below corrects it.
            ulong pred = (r[up] == ulong.max) ? l[i+up+1] : div128([l[i+up], l[i+up+1]], r[up]+1);
            res[i] = pred;
            ulong[N+1] buf;
            mulMultiWord(r.d[], pred, buf); // r * pred
            subMultiWord(l[i..i+up+2], buf[], l[i..i+up+2]);
            while (cmpMultiWord(l[i..i+up+2], r.d[]) != -1) {
                res[i]++;
                subMultiWord(l[i..i+up+2], r.d[], l[i..i+up+2]);
            }
        }
        return res;
    }
    /// Remainder modulo a single word.
    ulong opBinary(string op : "%")(in ulong r) const {
        static if (N == 2) {
            // reducing the high word first keeps the 128-bit dividend valid for mod128
            return mod128([d[0], d[1] % r], r);
        } else {
            return (this % uintN(r)).d[0];
        }
    }
    /// Remainder modulo a multi-word value.
    uintN opBinary(string op : "%")(in uintN r) const {
        static if (N == 2) {
            if (r[1] == 0) return uintN(this % ulong(r[0]));
        }
        return this - this/r*r;
    }
    /// Compound assignment for every supported binary operator.
    auto opOpAssign(string op, T)(in T r) {
        return mixin("this=this" ~ op ~ "r");
    }
}
221 
222 
// res = l + r modulo 2^^(64*res.length). Words past the end of l or r are
// treated as 0; the carry out of the top word is dropped.
void addMultiWord(in ulong[] l, in ulong[] r, ulong[] res) {
    bool carry = false;
    foreach (k, ref slot; res) {
        immutable ulong a = k < l.length ? l[k] : 0UL;
        immutable ulong b = k < r.length ? r[k] : 0UL;
        bool overflow = false;
        ulong sum = addu(a, b, overflow);
        if (carry) {
            ++sum;
            // adding the incoming carry wrapped this word around to zero
            if (sum == 0) overflow = true;
        }
        slot = sum;
        carry = overflow;
    }
}
238 
unittest {
    // the carry must ripple through two saturated words into the third
    ulong[] sum = new ulong[4];
    addMultiWord([ulong.max, ulong.max, 0UL], [1UL], sum);
    assert(sum == [0UL, 0UL, 1UL, 0UL]);
}
247 
// res = l - r modulo 2^^(64*res.length). Words past the end of l or r are
// treated as 0; the borrow out of the top word is dropped.
void subMultiWord(in ulong[] l, in ulong[] r, ulong[] res) {
    bool borrow = false;
    foreach (k, ref slot; res) {
        immutable ulong a = k < l.length ? l[k] : 0UL;
        immutable ulong b = k < r.length ? r[k] : 0UL;
        bool underflow = false;
        ulong diff = subu(a, b, underflow);
        if (borrow) {
            --diff;
            // subtracting the incoming borrow wrapped this word below zero
            if (diff == ulong.max) underflow = true;
        }
        slot = diff;
        borrow = underflow;
    }
}
264 
unittest {
    // the borrow must ripple down through two zero words
    ulong[] diff = new ulong[4];
    subMultiWord([0UL, 0UL, 1UL], [1UL], diff);
    assert(diff == [ulong.max, ulong.max, 0UL, 0UL]);
}
273 
// res = l * r, truncated to res.length words; words past the end of l are 0.
// NOTE(review): mul128 is assumed to return the 128-bit product as
// [low, high] words — confirm against dkh.int128.
void mulMultiWord(in ulong[] l, in ulong r, ulong[] res) {
    ulong carry = 0;
    foreach (idx; 0 .. res.length) {
        auto prod = mul128(idx < l.length ? l[idx] : 0UL, r);
        bool overflow = false;
        res[idx] = addu(prod[0], carry, overflow);
        // if the low word overflowed, the extra 1 moves into the next carry
        carry = overflow ? prod[1] + 1 : prod[1];
    }
}
285 
/**
res = l << n, truncated to res.length words; words of l past its end count as 0.
n must be non-negative.
Words are written from the most significant down, and res[i] only reads
l[i-ws] and l[i-ws-1] (indices <= i). Therefore res may be the very same array
as l, which gives an in-place shift.
 */
void shiftLeftMultiWord(in ulong[] l, int n, ulong[] res) {
    assert(0 <= n, "shift amount must be non-negative");
    immutable ptrdiff_t ws = n / 64; // whole-word part of the shift
    immutable int bs = n % 64;       // remaining bit part
    foreach_reverse (ptrdiff_t i; 0 .. res.length) {
        immutable ptrdiff_t hi = i - ws;     // word supplying the upper bits of res[i]
        immutable ptrdiff_t lo = i - ws - 1; // word supplying the lower bits of res[i]
        ulong b = (0 <= hi && hi < l.length) ? l[hi] : 0UL;
        if (bs == 0) {
            // pure word move; the formula below would shift by 64
            res[i] = b;
        } else {
            ulong a = (0 <= lo && lo < l.length) ? l[lo] : 0UL;
            res[i] = (b << bs) | (a >> (64 - bs));
        }
    }
}
300 
// Three-way comparison of two multi-word unsigned numbers stored least
// significant word first (like std.algorithm.cmp, but scanning from the
// most significant word down). Missing high words count as 0.
// Returns -1, 0 or 1.
int cmpMultiWord(in ulong[] l, in ulong[] r) {
    size_t i = l.length > r.length ? l.length : r.length;
    while (i > 0) {
        --i;
        immutable ulong a = i < l.length ? l[i] : 0UL;
        immutable ulong b = i < r.length ? r[i] : 0UL;
        if (a != b) return a < b ? -1 : 1;
    }
    return 0;
}
313 
///
unittest {
    import std.conv : to;
    alias U = uintN!20;
    auto a = U("31415926535897969393238462");
    auto b = U("1145141919810893");
    auto product = a * b;
    assert(product.to!string == "35975694425956177975650270094479894166566");
    auto quotient = a / b;
    assert(quotient.to!string == "27434090039");
}
323 
unittest {
    import std.conv : to;
    alias U4 = uintN!4;
    auto num = U4("115792089237316195417293883273301227089434195242432897623355228563449095127040");
    auto den = U4("340282366920938463500268095579187314687");
    assert(to!string(num % den) == "340282366920938463186673446326124937222");
}
331 
unittest {
    auto x = uintN!10("114514");
    // toString works on a copy, so repeated calls must keep producing the same text
    foreach (_; 0 .. 3) {
        assert(x.toString == "114514");
    }
}
339 
unittest {
    // Cross-checks uintN!N against std.bigint.BigInt (reduced modulo
    // 2^^(64*N)) on every value whose words come from a few edge cases.
    void check(int N)() {
        alias Uint = uintN!N;
        // Edge-case word values; ulong.max - 1 is only used for small N to
        // keep the number of samples (and of sample pairs) manageable.
        ulong[] words = [0UL, 1UL, ulong.max];
        if (N <= 3) words ~= ulong.max - 1;
        Uint[] samples;
        Uint cur;
        // Set cur.d[pos..N] to every combination of `words`, collecting each value.
        void build(int pos) {
            if (pos == N) {
                samples ~= cur;
                return;
            }
            foreach (w; words) {
                cur.d[pos] = w;
                build(pos + 1);
            }
        }
        build(0);

        import std.bigint;
        BigInt modulus = BigInt(1) << (64*N);
        // Representative in [0, modulus) of a possibly negative BigInt.
        BigInt modWrap(BigInt z) {
            return (z % modulus + modulus) % modulus;
        }
        // `lhs op rhs` on Uint must match the wrapped BigInt result.
        void binCheck(string op, R)(Uint lhs, R rhs) {
            import std.conv;
            auto lhsBig = BigInt(lhs.to!string);
            auto rhsBig = BigInt(rhs.to!string);
            auto got = mixin("lhs" ~ op ~ "rhs");
            auto want = modWrap(mixin("lhsBig" ~ op ~ "rhsBig"));
            assert(got.to!string == want.to!string);
        }
        // `op val` on Uint must match the wrapped BigInt result. For ++/--
        // the operand itself changes too, so it is compared as well.
        void unaryCheck(string op)(Uint val) {
            import std.conv;
            auto big = BigInt(val.to!string);
            auto got = mixin(op ~ "val");
            auto want = mixin(op ~ "big");
            assert(val.to!string == modWrap(big).to!string);
            assert(got.to!string == modWrap(want).to!string);
        }

        ulong[4] divisors = [1UL, 2UL, ulong.max, ulong.max - 1];
        foreach (a; samples) {
            unaryCheck!"++"(a);
            unaryCheck!"--"(a);
            unaryCheck!"~"(a);
            foreach (ulong m; divisors) binCheck!"/"(a, m);
            foreach (ulong m; divisors) binCheck!"%"(a, m);
            foreach (b; samples) {
                binCheck!"+"(a, b);
                binCheck!"-"(a, b);
                binCheck!"*"(a, b);
                if (b != Uint(0)) {
                    binCheck!"/"(a, b);
                    binCheck!"%"(a, b);
                }
                binCheck!"&"(a, b);
                binCheck!"|"(a, b);
                binCheck!"^"(a, b);
            }
        }
    }

    import std.stdio, std.algorithm;
    import dkh.stopwatch;
    auto timings = benchmark!(check!1, check!2, check!3, check!4)(1);
    writeln("BigInt: ", timings[].map!(t => t.toMsecs()));
}