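/**
Library of 64-bit op 64-bit -> 128-bit multiplication and division routines
(full 128-bit products, and quotient/remainder of a 128-bit value by a 64-bit divisor).
 */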
4 
5 module dkh.int128;
6 
7 version(LDC) {
8     import dkh.ldc.inline;
9 }
10 
// The LLVM-IR based implementations below are only enabled when compiling
// with LDC for an x86-64 target.
version (LDC)
{
    version (X86_64) version = LDC_IR;
}
14 
/**
Full-width unsigned multiplication of two 64-bit integers.

Params:
    a = first factor
    b = second factor
Returns: the 128-bit product as two 64-bit words `[low, high]`, i.e.
    `a * b == (return[1] << 64) + return[0]`.
*/
ulong[2] mul128(ulong a, ulong b) {
    version(LDC_IR) {
        // Widen both operands to i128, multiply, then store the low and
        // high 64-bit halves through the supplied pointers.
        ulong upper, lower;
        inlineIR!(`
            %r0 = zext i64 %0 to i128 
            %r1 = zext i64 %1 to i128
            %r2 = mul i128 %r1, %r0
            %r3 = trunc i128 %r2 to i64
            %r4 = lshr i128 %r2, 64
            %r5 = trunc i128 %r4 to i64
            store i64 %r3, i64* %2
            store i64 %r5, i64* %3`, void)(a, b, &lower, &upper);
        return [lower, upper];
    } else version(D_InlineAsm_X86_64) {
        // `mul r/m64` multiplies RAX by the operand and leaves the
        // 128-bit result in RDX:RAX.
        ulong upper, lower;
        asm {
            mov RAX, a;
            mul b;
            mov lower, RAX;
            mov upper, RDX;
        }
        return [lower, upper];
    } else {
        // Portable fallback: schoolbook multiplication on base-2^32 digits.
        // Each digit product is at most (2^32-1)^2 < 2^64, so it cannot overflow.
        ulong B = 2UL^^32;
        ulong[2] a2 = [a % B, a / B];
        ulong[2] b2 = [b % B, b / B];
        ulong[4] c;
        foreach (i; 0..2) {
            foreach (j; 0..2) {
                c[i+j] += a2[i] * b2[j] % B;
                c[i+j+1] += a2[i] * b2[j] / B;
            }
        }
        // Propagate carries so every digit is < 2^32 again.
        foreach (i; 0..3) {
            c[i+1] += c[i] / B;
            c[i] %= B;
        }
        return [c[0] + c[1] * B, c[2] + c[3] * B];
    }
}
57 
unittest {
    import std.algorithm, std.conv, std.random, std.stdio;
    import dkh.stopwatch;

    StopWatch sw;
    sw.start;

    // Reference product computed exactly with BigInt, split into [low, high].
    ulong[2] naiveMul(ulong x, ulong y) {
        import std.bigint, std.conv;
        auto product = BigInt(x) * BigInt(y);
        auto base = BigInt(1) << 64;
        auto low = (product % base).to!string.to!ulong;
        auto high = (product / base).to!string.to!ulong;
        return [low, high];
    }

    // Values near both ends of the ulong range, then random samples.
    ulong[] samples;
    foreach (k; 0 .. 100) {
        samples ~= k;
        samples ~= ulong.max - k;
    }
    foreach (_; 0 .. 100)
        samples ~= uniform(0UL, ulong.max);

    foreach (x; samples)
        foreach (y; samples)
            assert(equal(mul128(x, y)[], naiveMul(x, y)[]));

    writeln("Mul128: ", sw.peek.toMsecs);
}
84 
/**
Quotient of a 128-bit unsigned integer by a 64-bit divisor.

The dividend is `a[1] * 2^^64 + a[0]` (`a[1]` is the high word).
The quotient must fit in 64 bits, which holds exactly when `a[1] < b`.
If this is violated, the x86-64 `div` instruction raises a divide error and
the other code paths would return a meaningless value. Debug builds assert it.

Params:
    a = dividend as 64-bit words `[low, high]`
    b = divisor; must be greater than `a[1]` (and hence non-zero)
Returns: `(a[1] * 2^^64 + a[0]) / b`
*/
ulong div128(ulong[2] a, ulong b) {
    assert(a[1] < b, "div128: quotient does not fit in 64 bits (requires a[1] < b)");
    version(LDC_IR) {
        // Build the i128 dividend from its halves and use a native udiv.
        return inlineIR!(`
            %r0 = zext i64 %0 to i128
            %r1 = zext i64 %1 to i128
            %r2 = shl i128 %r1, 64
            %r3 = add i128 %r0, %r2
            %r4 = zext i64 %2 to i128
            %r5 = udiv i128 %r3, %r4
            %r6 = trunc i128 %r5 to i64
            ret i64 %r6`,ulong)(a[0], a[1], b);
    } else version(D_InlineAsm_X86_64) {
        // `div r/m64` divides RDX:RAX by the operand; quotient ends up in RAX.
        ulong upper = a[1], lower = a[0];
        ulong res;
        asm {
            mov RDX, upper;
            mov RAX, lower;
            div b;
            mov res, RAX;
        }
        return res;
    } else {
        // Without this guard the normalisation loop below never terminates
        // for b == 0. assert(0) is kept even in -release builds.
        if (b == 0) assert(0, "div128: division by zero");
        if (b == 1) return a[0];
        // Normalise: shift divisor (and dividend) left until the divisor's
        // top bit is set. With a[1] < b no dividend bits are shifted out.
        while (!(b & (1UL << 63))) {
            a[1] <<= 1;
            if (a[0] & (1UL << 63)) a[1] |= 1;
            a[0] <<= 1;
            b <<= 1;
        }
        // Restoring binary long division, producing one quotient bit per step.
        // a[1] holds the running remainder. `up` records the bit shifted out of
        // it: the true remainder then exceeds 2^64 > b, and the wrapped
        // subtraction still yields the correct value.
        ulong ans = 0;
        foreach (i; 0..64) {
            bool up = (a[1] & (1UL << 63)) != 0;
            a[1] <<= 1;
            if (a[0] & (1UL << 63)) a[1] |= 1;
            a[0] <<= 1;

            ans <<= 1;
            if (up || b <= a[1]) {
                a[1] -= b;
                ans++;
            }
        }
        return ans;
    }
}
131 
132 
/**
Remainder of a 128-bit unsigned integer by a 64-bit divisor.

The dividend is `a[1] * 2^^64 + a[0]` (`a[1]` is the high word).
Any dividend is accepted. The high word is first reduced modulo `b`, which
keeps the intermediate quotient within 64 bits so the hardware `div`
cannot fault.

Params:
    a = dividend as 64-bit words `[low, high]`
    b = divisor; must be non-zero
Returns: `(a[1] * 2^^64 + a[0]) % b`
*/
ulong mod128(ulong[2] a, ulong b) {
    assert(b != 0, "mod128: division by zero");
    // (h * 2^64 + l) % b == ((h % b) * 2^64 + l) % b. Once h < b, the
    // quotient of the reduced dividend fits in 64 bits.
    if (a[1] >= b) a[1] %= b;
    version(D_InlineAsm_X86_64) {
        // `div r/m64` divides RDX:RAX by the operand; remainder ends up in RDX.
        ulong upper = a[1], lower = a[0];
        ulong res;
        asm {
            mov RDX, upper;
            mov RAX, lower;
            div b;
            mov res, RDX;
        }
        return res;
    } else {
        // a - q*b is the remainder (< 2^64), so computing it modulo 2^64
        // using only the low word is exact.
        return a[0] - div128(a, b) * b;
    }
}
149 
unittest {
    import std.algorithm, std.random, std.stdio;
    import std.bigint, std.conv;
    import dkh.stopwatch;

    StopWatch sw;
    sw.start;

    // Exact value of the 128-bit number stored as [low, high].
    BigInt widen(ulong[2] v) {
        return (BigInt(v[1]) << 64) + BigInt(v[0]);
    }
    // True when the exact quotient is too large for a ulong.
    bool quotientOverflows(ulong[2] a, ulong b) {
        return widen(a) / b > BigInt(ulong.max);
    }
    ulong expectedDiv(ulong[2] a, ulong b) {
        return (widen(a) / b).to!string.to!ulong;
    }
    ulong expectedMod(ulong[2] a, ulong b) {
        return (widen(a) % b).to!string.to!ulong;
    }

    ulong[2][] dividends;
    ulong[] divisors;
    foreach (k; 0 .. 50) {
        dividends ~= [k, 0UL];
        dividends ~= [ulong.max - k, 0UL];
        divisors ~= k;
        divisors ~= ulong.max - k;
    }
    foreach (_; 0 .. 50)
        dividends ~= [uniform(0UL, ulong.max), 0UL];
    foreach (_; 0 .. 50)
        dividends ~= [uniform(0UL, ulong.max), uniform(0UL, ulong.max)];
    foreach (_; 0 .. 50)
        divisors ~= uniform(0UL, ulong.max);
    dividends ~= [0, ulong.max];
    dividends ~= [ulong.max, ulong.max-1];

    foreach (a; dividends) {
        foreach (b; divisors) {
            if (b == 0 || quotientOverflows(a, b)) continue;
            assert(div128(a, b) == expectedDiv(a, b));
            assert(mod128(a, b) == expectedMod(a, b));
        }
    }
    writeln("Div128: ", sw.peek.toMsecs);
}