module dkh.bigint;

import core.checkedint, core.bitop;
import dkh.int128, dkh.foundation;

/**
Fixed-length unsigned big integer. The width is $(D (N*64)) bits.
For example, $(D uintN!2) is a 128-bit unsigned integer.

Words are stored least significant first in `d` (`d[0]` is the lowest word).
Arithmetic (`+`, `-`, `*`, `++`, `--`, unary `-`) wraps around modulo 2^(64*N).
*/
struct uintN(int N) if (N >= 1) {
    ulong[N] d;  // d[0] is the least significant 64-bit word

    /// Construct from a single 64-bit value.
    this(ulong x) { d[0] = x; }

    /// Parse a decimal string. Characters are not validated (each is assumed
    /// to be '0'..'9'); values that do not fit in N words wrap around.
    this(string s) {
        foreach (c; s) {
            this *= 10;
            this += uintN(c-'0');
        }
    }

    /// Decimal representation of the value.
    string toString() {
        import std.algorithm : reverse;
        import dkh.container.stackpayload;
        StackPayload!char s;
        auto x = this;
        if (!x) return "0";
        while (x) {
            // peel off 18 decimal digits per iteration (10^18 fits in a ulong);
            // digits are appended least significant first and reversed at the end
            static immutable B = 10UL^^18;
            ulong z = (x % B);
            x /= B;
            bool last = (!x);
            foreach (i; 0..18) {
                // the most significant chunk is not zero-padded
                if (last && !z) break;
                s ~= cast(char)('0' + z % 10);
                z /= 10;
            }
        }
        reverse(s.data);
        return s.data.idup;
    }

    /// Word access: `x[i]` refers to `x.d[i]`.
    ref inout(ulong) opIndex(int idx) inout { return d[idx]; }

    /// True iff the value is nonzero.
    T opCast(T: bool)() const {
        import std.algorithm, std.range;
        return d[].find!"a!=0".empty == false;
    }

    //bit op
    /// Bitwise complement of every word.
    uintN opUnary(string op)() const if (op == "~") {
        uintN res;
        foreach (i; 0..N) {
            res[i] = ~d[i];
        }
        return res;
    }

    /// Word-wise `&`, `|` and `^`.
    uintN opBinary(string op)(in uintN r) const
    if (op == "&" || op == "|" || op == "^") {
        uintN res;
        foreach (i; 0..N) {
            res[i] = mixin("d[i]" ~ op ~ "r.d[i]");
        }
        return res;
    }

    /// Logical left shift by n bits (n is expected to be non-negative);
    /// n >= N*64 yields 0.
    uintN opBinary(string op : "<<")(int n) const {
        if (N * 64 <= n) return uintN(0);
        uintN res;
        int ws = n / 64;  // whole-word part of the shift
        int bs = n % 64;  // bit part of the shift
        if (bs == 0) {
            // handled separately: `x >> (64-bs)` below would be a 64-bit shift of a ulong
            res.d[ws..N][] = d[0..N-ws][];
            return res;
        }
        foreach_reverse (i; 1..N-ws) {
            res[i+ws] = (d[i] << bs) | (d[i-1] >> (64-bs));
        }
        res[ws] = (d[0] << bs);
        return res;
    }

    /// Logical right shift by n bits (n is expected to be non-negative);
    /// n >= N*64 yields 0.
    uintN opBinary(string op : ">>")(int n) const {
        if (N * 64 <= n) return uintN(0);
        uintN res;
        int ws = n / 64;  // whole-word part of the shift
        int bs = n % 64;  // bit part of the shift
        if (bs == 0) {
            // handled separately: `x << (64-bs)` below would be a 64-bit shift of a ulong
            res.d[0..N-ws][] = d[ws..N][];
            return res;
        }
        // each result word takes the high bits of d[i+ws] and the low bits of d[i+ws+1]
        foreach (i; 0..N-ws-1) {
            res[i] = (d[i+ws] >> bs) | (d[i+ws+1] << (64-bs));
        }
        // the topmost surviving word has nothing above it to borrow bits from
        res[N-ws-1] = (d[N-1] >> bs);
        return res;
    }

    //cmp
    /// Numeric comparison.
    int opCmp(in uintN r) const {
        return cmpMultiWord(d, r.d);
    }

    //arit
    /// Increment in place (max wraps to 0); returns the new value.
    uintN opUnary(string op)() if (op == "++") {
        foreach (i; 0..N) {
            d[i]++;
            if (d[i]) break;  // no carry out of this word
        }
        return this;
    }

    /// Decrement in place (0 wraps to max); returns the new value.
    uintN opUnary(string op)() if (op == "--") {
        foreach (i; 0..N) {
            d[i]--;
            if (d[i] != ulong.max) break;  // no borrow out of this word
        }
        return this;
    }

    /// Unary plus, and two's-complement negation: -x == ~x + 1 (mod 2^(64*N)).
    uintN opUnary(string op)() const if (op=="+" || op=="-") {
        static if (op == "+") {
            return this;
        } else {
            auto res = ~this;
            ++res;
            return res;
        }
    }

    /// Addition modulo 2^(64*N).
    uintN opBinary(string op : "+")(in uintN r) const {
        uintN res;
        addMultiWord(d, r.d, res.d);
        return res;
    }

    /// Subtraction modulo 2^(64*N).
    uintN opBinary(string op : "-")(in uintN r) const {
        uintN res;
        subMultiWord(d, r.d, res.d);
        return res;
    }

    /// Multiplication modulo 2^(64*N).
    uintN opBinary(string op : "*")(in uintN r) const {
        uintN res;
        // NOTE(review): mul128(a, b) is used as returning [low, high] words of a*b;
        // see dkh.int128 to confirm.
        static if (N == 2) {
            // the cross terms only affect the high word (their own high halves overflow out)
            auto u = mul128(d[0], r[0]);
            res[0] = u[0];
            res[1] = u[1] + d[0]*r[1] + d[1]*r[0];
            return res;
        } else {
            // schoolbook multiplication that keeps only the low N words
            foreach (i; 0..N) {
                ulong carry = 0;
                // words below the top one: full 128-bit partial products with carry
                foreach (j; 0..N-1-i) {
                    int s = i+j;
                    bool of;
                    auto u = mul128(d[i], r[j]);
                    res[s] = addu(res[s], carry, of);
                    carry = u[1];
                    if (of) carry++;
                    of = false;
                    res[s] = addu(res[s], u[0], of);
                    if (of) carry++;
                }
                // top word: only the low 64 bits of the partial product survive
                res[N-1] += d[i] * r[N-1-i] + carry;
            }
            return res;
        }
    }

    /// Multiplication by a single word, modulo 2^(64*N).
    uintN opBinary(string op : "*")(in ulong r) const {
        uintN res;
        mulMultiWord(d, r, res.d);
        return res;
    }

    /// Division by a single word, rounding toward zero. rr must be nonzero
    /// (checked only by assert).
    uintN opBinary(string op : "/")(in ulong rr) const {
        uintN res;
        ulong back = 0;  // running remainder, always < rr
        foreach_reverse (i; 0..N) {
            assert(back < rr);
            // NOTE(review): assumes div128([lo, hi], r) == (hi*2^64 + lo) / r; see dkh.int128
            ulong pred = div128([d[i], back], rr);
            res[i] = pred;
            back = d[i]-(rr*pred);  // the remainder is < rr, so its low word is exact
        }
        return res;
    }

    /// Division rounding toward zero. rr must be nonzero (checked only by assert).
    uintN opBinary(string op : "/")(in uintN rr) const {
        // locate the most significant nonzero word of the divisor
        int up = -1, shift;
        foreach_reverse (i; 0..N) {
            if (rr[i]) {
                up = i;
                shift = 63 - bsr(rr[i]);
                break;
            }
        }
        assert(up != -1);
        if (up == 0) {
            return this / ulong(rr[0]);
        }
        // normalize: shift dividend and divisor so the divisor's top bit is set;
        // the dividend gets one extra word to hold the bits shifted out
        ulong[N+1] l;
        l[0..N] = d[0..N];
        shiftLeftMultiWord(l, shift, l);
        auto r = (rr << shift);
        uintN res;
        foreach_reverse (i; 0..N-up) {
            //compare l[i, i+up+1] -> res[i]
            // under-estimate the quotient word by dividing by r[up]+1,
            // then fix it up by repeated subtraction below
            ulong pred = (r[up] == ulong.max) ? l[i+up+1] : div128([l[i+up], l[i+up+1]], r[up]+1);
            res[i] = pred;
            ulong[N+1] buf;
            mulMultiWord(r.d[], pred, buf); // r * pred
            subMultiWord(l[i..i+up+2], buf[], l[i..i+up+2]);
            while (cmpMultiWord(l[i..i+up+2], r.d[]) != -1) {
                res[i]++;
                subMultiWord(l[i..i+up+2], r.d[], l[i..i+up+2]);
            }
        }
        return res;
    }

    /// Remainder of division by a single word. r must be nonzero.
    ulong opBinary(string op : "%")(in ulong r) const {
        static if (N == 2) {
            // reduce the high word below r first (presumably so that mod128's
            // internal 128-by-64 division cannot overflow; see dkh.int128)
            return mod128([d[0], d[1] % r], r);
        } else {
            return (this % uintN(r)).d[0];
        }
    }

    /// Remainder of division. r must be nonzero.
    uintN opBinary(string op : "%")(in uintN r) const {
        static if (N == 2) {
            // fast path for a single-word divisor
            if (r[1] == 0) return uintN(this % ulong(r[0]));
        }
        return this - this/r*r;
    }

    /// Compound assignment `x op= r`, implemented as `x = x op r`.
    auto opOpAssign(string op, T)(in T r) {
        return mixin("this=this" ~ op ~ "r");
    }
}


/// res = l + r, truncated to res.length words (a final carry is dropped).
/// Words missing from l or r are treated as 0.
void addMultiWord(in ulong[] l, in ulong[] r, ulong[] res) {
    auto N = res.length;
    bool of = false;  // carry into the current word
    foreach (i; 0..N) {
        bool nof;  // carry out of the current word
        res[i] = addu(
            (i < l.length) ? l[i] : 0UL,
            (i < r.length) ? r[i] : 0UL, nof);
        if (of) {
            res[i]++;
            nof |= (res[i] == 0);
        }
        of = nof;
    }
}

unittest {
    import std.algorithm;
    auto l = [ulong.max, ulong.max, 0UL];
    auto r = [1UL];
    ulong[] res = new ulong[4];
    addMultiWord(l, r, res[]);
    assert(equal(res, [0UL, 0UL, 1UL, 0UL]));
}

// res = l-r
/// res = l - r modulo 2^(64*res.length) (a final borrow is dropped).
/// Words missing from l or r are treated as 0.
void subMultiWord(in ulong[] l, in ulong[] r, ulong[] res) {
    auto N = res.length;
    bool of = false;  // borrow into the current word
    foreach (i; 0..N) {
        bool nof;  // borrow out of the current word
        res[i] = subu(
            (i < l.length) ? l[i] : 0UL,
            (i < r.length) ? r[i] : 0UL, nof);
        if (of) {
            res[i]--;
            nof |= (res[i] == ulong.max);
        }
        of = nof;
    }
}

unittest {
    import std.algorithm;
    auto l = [0UL, 0UL, 1UL];
    auto r = [1UL];
    ulong[] res = new ulong[4];
    subMultiWord(l, r, res[]);
    assert(equal(res, [ulong.max, ulong.max, 0UL, 0UL]));
}

/// res = l * r, truncated to res.length words. Words missing from l are treated as 0.
void mulMultiWord(in ulong[] l, in ulong r, ulong[] res) {
    auto N = res.length;
    ulong ca;  // carry (high word of the previous partial product)
    foreach (i; 0..N) {
        auto u = mul128((i < l.length) ? l[i] : 0UL, r);
        bool of;
        res[i] = addu(u[0], ca, of);
        if (of) u[1]++;
        ca = u[1];
    }
}

/// res = l << n, truncated to res.length words. Words missing from l are treated as 0.
/// res may be the same array as l (as in uintN's division): words are written from
/// the highest index down and each write reads only l at lower-or-equal indices.
void shiftLeftMultiWord(in ulong[] l, int n, ulong[] res) {
    size_t N = res.length;
    int ws = n / 64;  // whole-word part of the shift
    int bs = n % 64;  // bit part of the shift
    foreach_reverse (ptrdiff_t i; 0..N) {
        ulong b = (0 <= i-ws && i-ws < l.length) ? l[i-ws] : 0UL;
        if (bs == 0) res[i] = b;
        else {
            ulong a = (0 <= i-ws-1 && i-ws-1 < l.length) ? l[i-ws-1] : 0UL;
            res[i] = (b << bs) | (a >> (64-bs));
        }
    }
}

// std.algorithm.cmp, reverse ver
/// Compares l and r as numbers stored least significant word first; the shorter
/// operand is padded with zero words. Returns -1, 0 or 1.
int cmpMultiWord(in ulong[] l, in ulong[] r) {
    import std.algorithm : max;
    auto N = max(l.length, r.length);
    foreach_reverse (i; 0..N) {
        auto ld = (i < l.length) ? l[i] : 0UL;
        auto rd = (i < r.length) ? r[i] : 0UL;
        if (ld < rd) return -1;
        if (ld > rd) return 1;
    }
    return 0;
}

///
unittest {
    import std.conv;
    alias Uint = uintN!20;
    auto x = Uint("31415926535897969393238462");
    auto y = Uint("1145141919810893");
    assert((x*y).to!string == "35975694425956177975650270094479894166566");
    assert((x/y).to!string == "27434090039");
}

unittest {
    import std.conv;
    alias Uint = uintN!4;
    auto x = Uint("115792089237316195417293883273301227089434195242432897623355228563449095127040");
    auto y = Uint("340282366920938463500268095579187314687");
    assert((x%y).to!string == "340282366920938463186673446326124937222");
}

unittest {
    import std.conv;
    uintN!10 x = uintN!10("114514");
    assert(x.toString == "114514");
    assert(x.toString == "114514");
    assert(x.toString == "114514");
}

unittest {
    // shifts: x >> n must equal x / 2^n, and (x << n) >> n keeps the low (N*64 - n) bits
    alias Uint = uintN!3;
    Uint x;
    x.d = [0x0123456789abcdefUL, 0xfedcba9876543210UL, 0x0f0f0f0f0f0f0f0fUL];
    foreach (n; 0..3*64) {
        assert((x >> n) == x / (Uint(1) << n));
        assert(((x << n) >> n) == (x & ((Uint(1) << (3*64 - n)) - Uint(1))));
    }
    assert((x >> 3*64) == Uint(0));
    assert((uintN!1(2) >> 1) == uintN!1(1));
}

unittest {
    void check(int N)() {
        alias Uint = uintN!N;
        Uint[] v;
        Uint buf;
        void dfs(int p) {
            if (p == N) {
                v ~= buf;
                return;
            }
            buf.d[p] = 0;
            dfs(p+1);
            buf.d[p] = 1;
            dfs(p+1);
            buf.d[p] = ulong.max;
            dfs(p+1);
            if (N <= 3) {
                buf.d[p] = ulong.max - 1;
                dfs(p+1);
            }
        }
        dfs(0);
        import std.bigint;
        BigInt mask = BigInt(1) << (64*N);
        void f(string op, R)(Uint x, R y) {
            import std.conv;
            auto x2 = BigInt(x.to!string);
            auto y2 = BigInt(y.to!string);
            auto z = mixin("x" ~ op ~ "y");
            auto z2 = mixin("x2" ~ op ~ "y2");
            z2 = (z2 % mask + mask) % mask;
            string s1 = z.to!string;
            string s2 = z2.to!string;
            assert(s1 == s2);
        }
        void g(string op)(Uint x) {
            import std.conv;
            auto x2 = BigInt(x.to!string);
            auto z = mixin(op ~ "x");
            auto z2 = mixin(op ~ "x2");
            x2 = (x2 % mask + mask) % mask;
            z2 = (z2 % mask + mask) % mask;
            assert(x.to!string == x2.to!string);
            string s1 = z.to!string;
            string s2 = z2.to!string;
            assert(s1 == s2);
        }
        foreach (d; v) {
            g!"++"(d);
            g!"--"(d);
            g!"~"(d);
            f!"/"(d, ulong(1));
            f!"/"(d, ulong(2));
            f!"/"(d, ulong(ulong.max));
            f!"/"(d, ulong(ulong.max-1));
            f!"%"(d, ulong(1));
            f!"%"(d, ulong(2));
            f!"%"(d, ulong(ulong.max));
            f!"%"(d, ulong(ulong.max-1));
            foreach (e; v) {
                f!"+"(d, e);
                f!"-"(d, e);
                f!"*"(d, e);
                if (e != Uint(0)) {
                    f!"/"(d, e);
                    f!"%"(d, e);
                }
                f!"&"(d, e);
                f!"|"(d, e);
                f!"^"(d, e);
            }
        }
    }

    import std.stdio, std.algorithm;
    import dkh.stopwatch;
    auto ti = benchmark!(check!1, check!2, check!3, check!4)(1);
    writeln("BigInt: ", ti[].map!(a => a.toMsecs()));
}