@@ -9,21 +9,25 @@ def maxHappyGroups(self, batchSize, groups):
9
9
:type groups: List[int]
10
10
:rtype: int
11
11
"""
12
- def memoization (count , mask , remain , lookup ):
12
+ def memoization (batchSize , count , mask , remain , lookup ):
13
13
if lookup [mask ] == 0 :
14
- curr , basis = mask , 1
15
- for i in xrange (remain ):
16
- basis *= count [i ]+ 1
17
- curr //= count [i ]+ 1
18
- a_remain = curr % (count [remain ]+ 1 )
14
+ a_remain = 0
15
+ if remain in count :
16
+ curr , basis = mask , 1
17
+ for i , c in count .iteritems ():
18
+ if i == remain :
19
+ break
20
+ basis *= c + 1
21
+ curr //= c + 1
22
+ a_remain = curr % (count [remain ]+ 1 )
19
23
result = 0
20
24
if a_remain :
21
- result = max (result , int (remain == 0 ) + memoization (count , mask - basis , 0 , lookup ))
25
+ result = max (result , int (remain == 0 ) + memoization (batchSize , count , mask - basis , 0 , lookup ))
22
26
else :
23
27
curr , basis = mask , 1
24
- for i , c in enumerate ( count ):
28
+ for i , c in count . iteritems ( ):
25
29
if curr % (c + 1 ):
26
- result = max (result , int (remain == 0 ) + memoization (count , mask - basis , (remain - i )% len ( count ) , lookup ))
30
+ result = max (result , int (remain == 0 ) + memoization (batchSize , count , mask - basis , (remain - i )% batchSize , lookup ))
27
31
basis *= c + 1
28
32
curr //= c + 1
29
33
lookup [mask ] = result
@@ -39,9 +43,10 @@ def memoization(count, mask, remain, lookup):
39
43
result += pair_count
40
44
count [i ] -= pair_count
41
45
count [len (count )- i ] -= pair_count
42
- max_mask = reduce (lambda total , c : total * (c + 1 ), count , 1 )
46
+ count = {i :c for i , c in enumerate (count ) if c }
47
+ max_mask = reduce (lambda total , c : total * (c + 1 ), count .itervalues (), 1 )
43
48
lookup = [0 ]* max_mask
44
- return result + memoization (count , max_mask - 1 , 0 , lookup )
49
+ return result + memoization (batchSize , count , max_mask - 1 , 0 , lookup )
45
50
46
51
47
52
# Time: O((b/2) * (n/(b/2)+1)^(b/2))
@@ -64,16 +69,17 @@ def maxHappyGroups(self, batchSize, groups):
64
69
result += pair_count
65
70
count [i ] -= pair_count
66
71
count [len (count )- i ] -= pair_count
67
- max_mask = reduce (lambda total , c : total * (c + 1 ), count , 1 )
72
+ count = {i :c for i , c in enumerate (count ) if c }
73
+ max_mask = reduce (lambda total , c : total * (c + 1 ), count .itervalues (), 1 )
68
74
dp = [0 ]* max_mask
69
75
for mask in xrange (len (dp )):
70
76
remain = 0
71
77
curr , basis = mask , 1
72
- for i , c in enumerate ( count ):
78
+ for i , c in count . iteritems ( ):
73
79
ai = curr % (c + 1 )
74
80
if ai :
75
81
dp [mask ] = max (dp [mask ], dp [mask - basis ])
76
- remain = (remain + ai * i )% len ( count )
82
+ remain = (remain + ai * i )% batchSize
77
83
basis *= c + 1
78
84
curr //= c + 1
79
85
if mask != len (dp )- 1 and remain == 0 :
0 commit comments