@@ -17,30 +17,45 @@ class Solution {
17
17
count[i] -= pair_count;
18
18
count[size (count) - i] -= pair_count;
19
19
}
20
- int max_mask = accumulate (cbegin (count), cend (count), 1 ,
21
- [](int total, int c) {
22
- return total * (c + 1 );
20
+ unordered_map<int , int > new_count;
21
+ for (int i = 0 ; i < size (count); ++i) {
22
+ if (count[i]) {
23
+ new_count[i] = count[i];
24
+ }
25
+ }
26
+ int max_mask = accumulate (cbegin (new_count), cend (new_count), 1 ,
27
+ [](int total, pair<int , int > kvp) {
28
+ return total * (kvp.second + 1 );
23
29
});
24
30
vector<int > lookup (max_mask);
25
- return result + memoization (count , max_mask - 1 , 0 , &lookup);
31
+ return result + memoization (batchSize, new_count , max_mask - 1 , 0 , &lookup);
26
32
}
27
33
28
34
private:
29
- int memoization (const vector< int >& count, int mask, int remain, vector<int > *lookup) {
35
+ int memoization (int batchSize, const unordered_map< int , int >& count, int mask, int remain, vector<int > *lookup) {
30
36
if (!(*lookup)[mask]) {
31
- int curr = mask, basis = 1 , i = 0 ;
32
- for (; i < remain; basis *= (count[i] + 1 ), curr /= (count[i] + 1 ), ++i);
33
- // mask: a0 + a1 * (c0 + 1) + a2 * (c0 + 1) * (c1 + 1) + ... + a(b-1) * (c0 + 1) * (c1 + 1) * ... * (c(b-2) + 1)
34
- int a_remain = curr % (count[remain] + 1 );
37
+ int a_remain = 0 ;
38
+ int curr = mask, basis = 1 ;
39
+ if (count.count (remain)) {
40
+ for (const auto & [i, c] : count) {
41
+ if (i == remain) {
42
+ break ;
43
+ }
44
+ basis *= (c + 1 ), curr /= (c + 1 );
45
+ }
46
+ // mask: a0 + a1 * (c0 + 1) + a2 * (c0 + 1) * (c1 + 1) + ... + a(b-1) * (c0 + 1) * (c1 + 1) * ... * (c(b-2) + 1)
47
+ a_remain = curr % (count.at (remain) + 1 );
48
+ }
35
49
int result = 0 ;
36
50
if (a_remain) { // greedily use remain
37
- result = max (result, (remain == 0 ) + memoization (count, mask - basis, 0 , lookup));
51
+ result = max (result, (remain == 0 ) + memoization (batchSize, count, mask - basis, 0 , lookup));
38
52
} else {
39
- for (int curr = mask, basis = 1 , i = 0 ; i < size (count); basis *= (count[i] + 1 ), curr /= (count[i] + 1 ), ++i) {
40
- if (curr % (count[i] + 1 ) == 0 ) {
41
- continue ;
53
+ int curr = mask, basis = 1 ;
54
+ for (const auto & [i, c] : count) {
55
+ if (curr % (c + 1 )) {
56
+ result = max (result, (remain == 0 ) + memoization (batchSize, count, mask - basis, ((remain - i) + batchSize) % batchSize, lookup));
42
57
}
43
- result = max (result, (remain == 0 ) + memoization (count, mask - basis, ((remain - i) + size (count)) % size (count), lookup) );
58
+ basis *= (c + 1 ), curr /= (c + 1 );
44
59
}
45
60
}
46
61
(*lookup)[mask] = result;
@@ -67,21 +82,28 @@ class Solution2 {
67
82
count[i] -= pair_count;
68
83
count[size (count) - i] -= pair_count;
69
84
}
70
- int max_mask = accumulate (cbegin (count), cend (count), 1 ,
71
- [](int total, int c) {
72
- return total * (c + 1 );
85
+ unordered_map<int , int > new_count;
86
+ for (int i = 0 ; i < size (count); ++i) {
87
+ if (count[i]) {
88
+ new_count[i] = count[i];
89
+ }
90
+ }
91
+ int max_mask = accumulate (cbegin (new_count), cend (new_count), 1 ,
92
+ [](int total, pair<int , int > kvp) {
93
+ return total * (kvp.second + 1 );
73
94
});
74
95
vector<int > dp (max_mask);
75
96
for (int mask = 0 ; mask < size (dp); ++mask) {
76
97
int remain = 0 ;
77
- for ( int curr = mask, basis = 1 , i = 0 ; i < size (count) ;
78
- basis *= (count[i] + 1 ), curr /= (count[i] + 1 ), ++i) { // decode mask
98
+ int curr = mask, basis = 1 ;
99
+ for ( const auto & [i, c] : new_count) {
79
100
// mask: a0 + a1 * (c0 + 1) + a2 * (c0 + 1) * (c1 + 1) + ... + a(b-1) * (c0 + 1) * (c1 + 1) * ... * (c(b-2) + 1)
80
101
int ai = curr % (count[i] + 1 );
81
102
if (ai > 0 ) {
82
103
dp[mask] = max (dp[mask], dp[mask - basis]);
83
104
}
84
- remain = (remain + ai * i) % size (count);
105
+ remain = (remain + ai * i) % batchSize;
106
+ basis *= (c + 1 ), curr /= (c + 1 );
85
107
}
86
108
if (mask != size (dp) - 1 && remain == 0 ) {
87
109
++dp[mask];
0 commit comments