1 module dkh.functional;
2 
/**
Memoization helper that caches results in flat arrays.

Unlike `std.functional.memoize`, every argument must be an integral value
taken from a contiguous closed range fixed by `init`. Results are stored in a
dense array indexed by the argument tuple instead of a hash table, which
makes lookups fast.

Call `init` with one `[lo, hi]` pair per parameter of `pred` before the
first call.
 */
struct memoCont(alias pred) {
    import std.traits : ReturnType, Parameters, isIntegral;
    import std.meta : allSatisfy;

    /// Return type of `pred`; also the type of every cached value.
    alias R = ReturnType!pred;
    /// Parameter types of `pred`; each one must be integral.
    alias Args = Parameters!pred;
    static assert (allSatisfy!(isIntegral, Args),
                   "memoCont: every parameter of pred must be an integral type");
    /// Number of parameters of `pred` (dimensions of the cache table).
    enum N = Args.length;

    private int[2][N] rng;  // closed interval [lo, hi] for each argument
    int[N] len;             // hi - lo + 1 for each argument
    R[] dp;                 // cached results, one slot per argument tuple
    bool[] used;            // used[i] becomes true once slot i has been claimed

    /**
    Sets the argument ranges and allocates a fresh cache, discarding any
    previously cached values.

    Params:
        rng = one closed interval `[lo, hi]` per parameter of `pred`.
     */
    void init(in int[2][N] rng) {
        this.rng = rng;
        // Accumulate the table size in size_t: the product of the widths can
        // exceed int.max even when every individual width fits in an int.
        size_t sz = 1;
        foreach (i, r; rng) {
            assert(r[0] <= r[1], "memoCont.init: each range must satisfy lo <= hi");
            len[i] = r[1] - r[0] + 1;
            sz *= len[i];
        }
        dp = new R[sz];
        used = new bool[sz];
    }

    /**
    Returns `pred(args)`, evaluating `pred` at most once per argument tuple.

    The slot is claimed before `pred` runs. A re-entrant call with the same
    arguments (cyclic recursion) therefore gets the not-yet-computed `R.init`
    instead of recursing forever. If `pred` throws, the claim is released, so
    a later call retries instead of returning `R.init`.

    Throws:
        `core.exception.RangeError` (only when assertions are enabled) if an
        argument lies outside the range given to `init`.
     */
    R opCall(Args args) {
        import core.exception : RangeError;
        size_t idx, base = 1;
        foreach (i, v; args) {
            version (assert) {
                // Checked explicitly: an out-of-range value could otherwise
                // map onto the slot of a different argument tuple.
                if (v < rng[i][0] || rng[i][1] < v) {
                    throw new RangeError;
                }
            }
            // Mixed-radix index: the first argument varies fastest.
            idx += base * (v - rng[i][0]);
            base *= len[i];
        }
        if (used[idx]) return dp[idx];
        used[idx] = true;
        // Do not leave a poisoned R.init entry behind if pred throws.
        scope (failure) used[idx] = false;
        auto r = pred(args);
        dp[idx] = r;
        return r;
    }
}
50 
///
unittest {
    import dkh.numeric.primitive;
    import dkh.modint;
    alias Mint = ModInt!(10^^9+7);

    struct A {
        // Reference answer: closed form n! / (k! (n-k)!), assuming
        // factTable / invFactTable hold factorials and their inverses.
        static auto fact = factTable!Mint(100);
        static auto iFac = invFactTable!Mint(100);
        static Mint C1(int n, int k) {
            auto numer = fact[n];
            return numer * iFac[k] * iFac[n - k];
        }

        // nCk via Pascal's rule, memoized through memoCont.
        static memoCont!C2base C2;
        static Mint C2base(int n, int k) {
            if (k == 0) return Mint(1);  // C(n, 0) = 1
            if (n == 0) return Mint(0);  // C(0, k) = 0 for k > 0
            auto withLast = C2(n - 1, k - 1);
            auto withoutLast = C2(n - 1, k);
            return withLast + withoutLast;
        }
    }

    // Both n and k range over the closed interval [0, 99].
    A.C2.init([[0, 99], [0, 99]]);
    foreach (n; 0 .. 100)
        foreach (k; 0 .. n + 1)
            assert(A.C1(n, k) == A.C2(n, k));
}