Skip to content

Commit c1126e2

Browse files
authored
Update maximum-number-of-groups-getting-fresh-donuts.cpp
1 parent db352af commit c1126e2

File tree

1 file changed

+42
-20
lines changed

1 file changed

+42
-20
lines changed

C++/maximum-number-of-groups-getting-fresh-donuts.cpp

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,45 @@ class Solution {
1717
count[i] -= pair_count;
1818
count[size(count) - i] -= pair_count;
1919
}
20-
int max_mask = accumulate(cbegin(count), cend(count), 1,
21-
[](int total, int c) {
22-
return total * (c + 1);
20+
unordered_map<int, int> new_count;
21+
for (int i = 0; i < size(count); ++i) {
22+
if (count[i]) {
23+
new_count[i] = count[i];
24+
}
25+
}
26+
int max_mask = accumulate(cbegin(new_count), cend(new_count), 1,
27+
[](int total, pair<int, int> kvp) {
28+
return total * (kvp.second + 1);
2329
});
2430
vector<int> lookup(max_mask);
25-
return result + memoization(count, max_mask - 1, 0, &lookup);
31+
return result + memoization(batchSize, new_count, max_mask - 1, 0, &lookup);
2632
}
2733

2834
private:
29-
int memoization(const vector<int>& count, int mask, int remain, vector<int> *lookup) {
35+
int memoization(int batchSize, const unordered_map<int, int>& count, int mask, int remain, vector<int> *lookup) {
3036
if (!(*lookup)[mask]) {
31-
int curr = mask, basis = 1, i = 0;
32-
for (; i < remain; basis *= (count[i] + 1), curr /= (count[i] + 1), ++i);
33-
// mask: a0 + a1 * (c0 + 1) + a2 * (c0 + 1) * (c1 + 1) + ... + a(b-1) * (c0 + 1) * (c1 + 1) * ... * (c(b-2) + 1)
34-
int a_remain = curr % (count[remain] + 1);
37+
int a_remain = 0;
38+
int curr = mask, basis = 1;
39+
if (count.count(remain)) {
40+
for (const auto& [i, c] : count) {
41+
if (i == remain) {
42+
break;
43+
}
44+
basis *= (c + 1), curr /= (c + 1);
45+
}
46+
// mask: a0 + a1 * (c0 + 1) + a2 * (c0 + 1) * (c1 + 1) + ... + a(b-1) * (c0 + 1) * (c1 + 1) * ... * (c(b-2) + 1)
47+
a_remain = curr % (count.at(remain) + 1);
48+
}
3549
int result = 0;
3650
if (a_remain) { // greedily use remain
37-
result = max(result, (remain == 0) + memoization(count, mask - basis, 0, lookup));
51+
result = max(result, (remain == 0) + memoization(batchSize, count, mask - basis, 0, lookup));
3852
} else {
39-
for (int curr = mask, basis = 1, i = 0; i < size(count); basis *= (count[i] + 1), curr /= (count[i] + 1), ++i) {
40-
if (curr % (count[i] + 1) == 0) {
41-
continue;
53+
int curr = mask, basis = 1;
54+
for (const auto& [i, c] : count) {
55+
if (curr % (c + 1)) {
56+
result = max(result, (remain == 0) + memoization(batchSize, count, mask - basis, ((remain - i) + batchSize) % batchSize, lookup));
4257
}
43-
result = max(result, (remain == 0) + memoization(count, mask - basis, ((remain - i) + size(count)) % size(count), lookup));
58+
basis *= (c + 1), curr /= (c + 1);
4459
}
4560
}
4661
(*lookup)[mask] = result;
@@ -67,21 +82,28 @@ class Solution2 {
6782
count[i] -= pair_count;
6883
count[size(count) - i] -= pair_count;
6984
}
70-
int max_mask = accumulate(cbegin(count), cend(count), 1,
71-
[](int total, int c) {
72-
return total * (c + 1);
85+
unordered_map<int, int> new_count;
86+
for (int i = 0; i < size(count); ++i) {
87+
if (count[i]) {
88+
new_count[i] = count[i];
89+
}
90+
}
91+
int max_mask = accumulate(cbegin(new_count), cend(new_count), 1,
92+
[](int total, pair<int, int> kvp) {
93+
return total * (kvp.second + 1);
7394
});
7495
vector<int> dp(max_mask);
7596
for (int mask = 0; mask < size(dp); ++mask) {
7697
int remain = 0;
77-
for (int curr = mask, basis = 1, i = 0; i < size(count);
78-
basis *= (count[i] + 1), curr /= (count[i] + 1), ++i) { // decode mask
98+
int curr = mask, basis = 1;
99+
for (const auto& [i, c] : new_count) {
79100
// mask: a0 + a1 * (c0 + 1) + a2 * (c0 + 1) * (c1 + 1) + ... + a(b-1) * (c0 + 1) * (c1 + 1) * ... * (c(b-2) + 1)
80101
int ai = curr % (count[i] + 1);
81102
if (ai > 0) {
82103
dp[mask] = max(dp[mask], dp[mask - basis]);
83104
}
84-
remain = (remain + ai * i) % size(count);
105+
remain = (remain + ai * i) % batchSize;
106+
basis *= (c + 1), curr /= (c + 1);
85107
}
86108
if (mask != size(dp) - 1 && remain == 0) {
87109
++dp[mask];

0 commit comments

Comments
 (0)