Skip to content

Commit 317452c

Browse files
authored
reduce shape for merged emb size (#1620)
1 parent 4682ce3 commit 317452c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/cpu/test_merged_embeddingbag.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,13 @@ class TestMergedEmbedding(TestCase):
160160

161161
table0 = nn.EmbeddingBag(100, 16, mode='mean', sparse=False).double()
162162
table1 = nn.EmbeddingBag(50, 32, mode='sum', sparse=False)
163-
table2 = nn.EmbeddingBag(18000000, 128, mode='sum', include_last_offset=True, _weight=torch.empty(18000000, 128, dtype=torch.bfloat16), sparse=False)
163+
table2 = nn.EmbeddingBag(18000000, 8, mode='sum', include_last_offset=True, _weight=torch.empty(18000000, 8, dtype=torch.bfloat16), sparse=False)
164164
table3 = nn.EmbeddingBag(100, 16, mode='mean', sparse=True).double()
165165
merged = MergedEmbeddingBag.from_embeddingbag_list([table0, table1, table2])
166166
merged2 = MergedEmbeddingBag([
167167
(100, 16, 'mean', table0.weight.dtype, table0.weight.detach(), False),
168168
(50, 32, 'sum', table1.weight.dtype, table1.weight.detach(), False),
169-
(18000000, 128, 'sum', table2.weight.dtype, table2.weight.detach(), False),
169+
(18000000, 8, 'sum', table2.weight.dtype, table2.weight.detach(), False),
170170
])
171171
input = [
172172
[torch.LongTensor([10, 10, 15, 10, 20, 25]), torch.LongTensor([[0, 30], [21, 15], [30, 11]]), torch.LongTensor([10, 15, 17999999])],

0 commit comments

Comments
 (0)