/**
64-bit x 64-bit -> 128-bit multiplication, and 128-bit / 64-bit -> 64-bit
division and modulo (mul128, div128, mod128).

128-bit values are passed around as `ulong[2]`, least significant word first:
value = (x[1] << 64) + x[0].
*/

module dkh.int128;

version(LDC) {
    import dkh.ldc.inline;
}

// The LLVM inline-IR code paths are only selected for LDC on x86-64.
version(LDC) version(X86_64) {
    version = LDC_IR;
}

/**
Full 128-bit product of two 64-bit unsigned integers.

Params:
    a = first factor
    b = second factor

Returns: `[low, high]` such that a * b = (return[1]<<64) + return[0]
*/
ulong[2] mul128(ulong a, ulong b) {
    version(LDC_IR) {
        ulong upper, lower;
        // Widen both operands to i128, multiply, then split the product
        // into its low and high 64-bit halves.
        inlineIR!(`
            %r0 = zext i64 %0 to i128
            %r1 = zext i64 %1 to i128
            %r2 = mul i128 %r1, %r0
            %r3 = trunc i128 %r2 to i64
            %r4 = lshr i128 %r2, 64
            %r5 = trunc i128 %r4 to i64
            store i64 %r3, i64* %2
            store i64 %r5, i64* %3`, void)(a, b, &lower, &upper);
        return [lower, upper];
    } else version(D_InlineAsm_X86_64) {
        ulong upper, lower;
        asm {
            mov RAX, a;
            mul b;          // RDX:RAX = RAX * b
            mov lower, RAX;
            mov upper, RDX;
        }
        return [lower, upper];
    } else {
        // Portable schoolbook multiplication on base-2^32 digits.
        enum ulong B = 1UL << 32;
        ulong[2] a2 = [a % B, a / B];
        ulong[2] b2 = [b % B, b / B];
        ulong[4] c;
        foreach (i; 0..2) {
            foreach (j; 0..2) {
                // Each partial product is split into its own low/high digit
                // so that the accumulators cannot overflow 64 bits.
                immutable ulong p = a2[i] * b2[j];
                c[i+j] += p % B;
                c[i+j+1] += p / B;
            }
        }
        // Propagate carries so that c[0..3] become proper base-2^32 digits.
        foreach (i; 0..3) {
            c[i+1] += c[i] / B;
            c[i] %= B;
        }
        return [c[0] + c[1] * B, c[2] + c[3] * B];
    }
}

unittest {
    import std.random, std.algorithm, std.stdio, std.conv;
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    // Reference implementation based on BigInt.
    ulong[2] naive_mul(ulong a, ulong b) {
        import std.bigint, std.conv;
        auto a2 = BigInt(a), b2 = BigInt(b);
        auto c = a2*b2;
        auto m = BigInt(1)<<64;
        return [(c % m).to!string.to!ulong, (c / m).to!string.to!ulong];
    }
    // Operands: values near 0, values near ulong.max, and random values.
    ulong[] li;
    foreach (i; 0..100) {
        li ~= i;
        li ~= ulong.max - i;
    }
    foreach (i; 0..100) {
        li ~= uniform(0UL, ulong.max);
    }
    foreach (l; li) {
        foreach (r; li) {
            assert(equal(mul128(l, r)[], naive_mul(l, r)[]));
        }
    }
    writeln("Mul128: ", sw.peek.toMsecs);
}

/**
Quotient of a 128-bit unsigned value divided by a 64-bit unsigned value.

Params:
    a = dividend, low word first: value = (a[1] << 64) + a[0]
    b = divisor, must be non-zero

Returns: [a[1], a[0]] / b

Preconditions: `b != 0`, and the quotient must fit in 64 bits. If the
quotient does not fit, the result is unreliable and depends on the code path
selected at compile time (the x86-64 `div` instruction faults in that case).
*/
ulong div128(ulong[2] a, ulong b) {
    // Without this check the portable path below would loop forever on
    // b == 0, because its normalization loop shifts b until the top bit is set.
    assert(b != 0, "div128: division by zero");
    version(LDC_IR) {
        // Rebuild the i128 dividend from its two halves and divide.
        return inlineIR!(`
            %r0 = zext i64 %0 to i128
            %r1 = zext i64 %1 to i128
            %r2 = shl i128 %r1, 64
            %r3 = add i128 %r0, %r2
            %r4 = zext i64 %2 to i128
            %r5 = udiv i128 %r3, %r4
            %r6 = trunc i128 %r5 to i64
            ret i64 %r6`,ulong)(a[0], a[1], b);
    } else version(D_InlineAsm_X86_64) {
        ulong upper = a[1], lower = a[0];
        ulong res;
        asm {
            mov RDX, upper;
            mov RAX, lower;
            div b;          // RAX = RDX:RAX / b, RDX = RDX:RAX % b
            mov res, RAX;
        }
        return res;
    } else {
        // Dividing by 1: the low word is returned as the 64-bit quotient.
        if (b == 1) return a[0];
        // Normalize: shift dividend and divisor left together until the
        // divisor's top bit is set.
        while (!(b & (1UL << 63))) {
            a[1] <<= 1;
            if (a[0] & (1UL << 63)) a[1] |= 1;
            a[0] <<= 1;
            b <<= 1;
        }
        // Shift-and-subtract long division, one quotient bit per iteration.
        ulong ans = 0;
        foreach (i; 0..64) {
            // `up` remembers the bit shifted out of the high word; when it
            // is set, the subtraction below is taken unconditionally.
            bool up = (a[1] & (1UL << 63)) != 0;
            a[1] <<= 1;
            if (a[0] & (1UL << 63)) a[1] |= 1;
            a[0] <<= 1;

            ans <<= 1;
            if (up || b <= a[1]) {
                a[1] -= b;
                ans++;
            }
        }
        return ans;
    }
}


/**
Remainder of a 128-bit unsigned value divided by a 64-bit unsigned value.

Params:
    a = dividend, low word first: value = (a[1] << 64) + a[0]
    b = divisor, must be non-zero

Returns: [a[1], a[0]] % b

Preconditions: `b != 0`, and the quotient [a[1], a[0]] / b must fit in
64 bits; otherwise the result is unreliable (the x86-64 `div` instruction
faults in that case).
*/
ulong mod128(ulong[2] a, ulong b) {
    assert(b != 0, "mod128: division by zero");
    version(D_InlineAsm_X86_64) {
        ulong upper = a[1], lower = a[0];
        ulong res;
        asm {
            mov RDX, upper;
            mov RAX, lower;
            div b;          // RAX = RDX:RAX / b, RDX = RDX:RAX % b
            mov res, RDX;
        }
        return res;
    } else {
        // Computed from the low word only, using wrapping 64-bit arithmetic.
        return a[0] - div128(a, b) * b;
    }
}

unittest {
    import std.bigint, std.conv, std.stdio;
    import std.random, std.algorithm;
    import dkh.stopwatch;
    StopWatch sw; sw.start;
    // True when the quotient does not fit in 64 bits (precondition violated).
    bool overflow_check(ulong[2] a, ulong b) {
        auto a2 = (BigInt(a[1]) << 64) + BigInt(a[0]);
        return (a2 / b) > BigInt(ulong.max);
    }
    // Reference implementations based on BigInt.
    ulong naive_div(ulong[2] a, ulong b) {
        auto a2 = (BigInt(a[1]) << 64) + BigInt(a[0]);
        return (a2 / b).to!string.to!ulong;
    }
    ulong naive_mod(ulong[2] a, ulong b) {
        auto a2 = (BigInt(a[1]) << 64) + BigInt(a[0]);
        return (a2 % b).to!string.to!ulong;
    }
    ulong[2][] li;
    ulong[] ri;
    // Dividends: small, near ulong.max, random 64-bit and random 128-bit.
    foreach (i; 0..50) {
        li ~= [i, 0UL];
        li ~= [ulong.max - i, 0UL];
    }
    // Divisors: small, near ulong.max and random.
    foreach (i; 0..50) {
        ri ~= i;
        ri ~= ulong.max - i;
    }
    foreach (i; 0..50) {
        li ~= [uniform(0UL, ulong.max), 0UL];
    }
    foreach (i; 0..50) {
        li ~= [uniform(0UL, ulong.max), uniform(0UL, ulong.max)];
    }
    foreach (i; 0..50) {
        ri ~= uniform(0UL, ulong.max);
    }
    li ~= [0, ulong.max];
    li ~= [ulong.max, ulong.max-1];
    foreach (l; li) {
        foreach (r; ri) {
            // Skip inputs that violate the documented preconditions.
            if (r == 0) continue;
            if (overflow_check(l, r)) continue;
            assert(div128(l, r) == naive_div(l, r));
            assert(mod128(l, r) == naive_mod(l, r));
        }
    }
    writeln("Div128: ", sw.peek.toMsecs);
}