1 module dkh.numeric.primitive;
2 
3 import std.traits;
4 import std.bigint;
5 
/// Fast exponentiation by repeated squaring: returns `x^^n`.
///
/// Delegates to the three-argument overload, seeding the accumulator with
/// the multiplicative identity `T(1)`. Floating-point bases are excluded.
Unqual!T pow(T, U)(T x, U n)
if ((isIntegral!U || is(U == BigInt)) && !isFloatingPoint!T)
{
    auto identity = T(1);
    return pow(x, n, identity);
}
11 
/// Binary exponentiation with an explicit starting value: returns `e * x^^n`
/// using O(log n) multiplications.
///
/// `e` must have the same unqualified type as `x`. Passing the multiplicative
/// identity gives the plain power; any other value scales the result
/// (e.g. `pow(3, 5, 2) == 486`). `n` must be non-negative.
Unqual!T pow(T, U, V)(T x, U n, V e)
if ((isIntegral!U || is(U == BigInt)) && is(Unqual!T == Unqual!V)) {
    static if (isSigned!U || is(Unqual!U == BigInt)) {
        // Without this, a negative exponent silently computes x^^|n|.
        assert(n >= 0, "pow: negative exponent");
    }
    Unqual!T b = x, v = e;
    Unqual!U m = n;
    while (m) {
        if (m & 1) v *= b;
        m /= 2;
        // Square only while bits remain: a final squaring would be unused,
        // is the costliest step for BigInt, and can spuriously overflow
        // checked integer types.
        if (m) b *= b;
    }
    return v;
}
24 
unittest {
    // 3^5 with the implicit identity, then the same power scaled by a start value of 2.
    immutable plain = pow(3, 5);
    immutable scaled = pow(3, 5, 2);
    assert(plain == 243);
    assert(scaled == 2 * 243);
}
29 
/// Modular exponentiation: returns `x^^n` reduced modulo `md`, by binary
/// exponentiation.
///
/// The base is first reduced into `[0, md)`, so every intermediate product is
/// below `md^^2`. `T` must still be wide enough to hold such a product (see
/// `ulongPowMod` for 64-bit moduli). `n` must be non-negative. For `md > 0` the
/// result lies in `[0, md)`, including the case `n == 0, md == 1` (result 0).
T powMod(T, U, V)(T x, U n, V md)
if (isIntegral!U || is(U == BigInt)) {
    static if (isSigned!U || is(Unqual!U == BigInt)) {
        // A negative exponent would make `n >>= 1` spin forever at -1.
        assert(n >= 0, "powMod: negative exponent");
    }
    // Reduce before the first squaring so x*x cannot overflow for large x.
    x %= md;
    static if (!isUnsigned!T) {
        // D's % truncates toward zero, so a negative base leaves a negative residue.
        if (x < 0) x += md;
    }
    T r = T(1);
    while (n) {
        if (n & 1) r = (r*x)%md;
        x = (x*x)%md;
        n >>= 1;
    }
    return r % md;
}
41 
42 import dkh.int128;
43 
/// Modular exponentiation for 64-bit moduli: returns `x^^n` reduced modulo `md`.
///
/// Each product goes through `mul128` and is reduced with `mod128` (both from
/// `dkh.int128`). These are presumably a full 128-bit product and its
/// remainder, which keeps moduli near `ulong.max` safe. TODO confirm against
/// `dkh.int128`. `n` must be non-negative and `md` must be non-zero.
ulong ulongPowMod(U)(ulong x, U n, ulong md)
if (isIntegral!U || is(U == BigInt)) {
    assert(md != 0, "ulongPowMod: zero modulus");
    static if (isSigned!U || is(Unqual!U == BigInt)) {
        // A negative exponent would make `n >>= 1` spin forever at -1.
        assert(n >= 0, "ulongPowMod: negative exponent");
    }
    x %= md;
    ulong r = 1;
    while (n) {
        if (n & 1) {
            r = mul128(r, x).mod128(md);
        }
        x = mul128(x, x).mod128(md);
        n >>= 1;
    }
    // Needed when n == 0 and md == 1: r is still 1 and must become 0.
    return r % md;
}
58 
/// Least common multiple of `a` and `b`.
///
/// Returns 0 when either argument is 0. The division happens before the
/// multiplication, so the result overflows only if the lcm itself does.
T lcm(T)(in T a, in T b) {
    import std.numeric : gcd;
    // gcd(0, 0) == 0, so dividing by it would fault; lcm with 0 is 0.
    if (a == 0 || b == 0) return T(0);
    return a / gcd(a, b) * b;
}
64 
///
unittest {
    // Each row is (a, b, lcm(a, b)).
    int[3][] cases = [
        [2, 4, 4],
        [3, 5, 15],
        [1, 1, 1],
        [0, 100, 0],
    ];
    foreach (row; cases) {
        assert(lcm(row[0], row[1]) == row[2]);
    }
}
72 
//todo: consider binary extgcd
/// Extended Euclidean algorithm: returns `[x, y, g]` with `a*x + b*y == g`,
/// where `g` is the gcd reached by Euclid's algorithm on `a` and `b`.
/// The sign of `g` follows D's truncating `/` and `%`, so it is not forced
/// non-negative.
T[3] extGcd(T)(in T a, in T b)
if (!isIntegral!T || isSigned!T) // unsigned types are rejected: coefficients can be negative
{
    // Invariants: a*s0 + b*t0 == r0 and a*s1 + b*t1 == r1.
    T r0 = a, r1 = b;
    T s0 = T(1), s1 = T(0);
    T t0 = T(0), t1 = T(1);
    while (r1 != 0) {
        auto q = r0 / r1;
        auto r2 = r0 % r1;
        auto s2 = s0 - q * s1;
        auto t2 = t0 - q * t1;
        r0 = r1; r1 = r2;
        s0 = s1; s1 = s2;
        t0 = t1; t1 = t2;
    }
    return [s0, t0, r0];
}
85 
///
unittest {
    import std.numeric : gcd;
    enum limit = 100;
    // Exhaustively check the Bezout identity and the gcd on a small grid.
    for (int x = 0; x < limit; ++x) {
        for (int y = 0; y < limit; ++y) {
            const res = extGcd(x, y);
            assert(res[2] == gcd(x, y));
            assert(x * res[0] + y * res[1] == res[2]);
        }
    }
}