@@ -160,13 +160,13 @@ class TestMergedEmbedding(TestCase):
160160
161161 table0 = nn .EmbeddingBag (100 , 16 , mode = 'mean' , sparse = False ).double ()
162162 table1 = nn .EmbeddingBag (50 , 32 , mode = 'sum' , sparse = False )
163- table2 = nn .EmbeddingBag (18000000 , 128 , mode = 'sum' , include_last_offset = True , _weight = torch .empty (18000000 , 128 , dtype = torch .bfloat16 ), sparse = False )
163+ table2 = nn .EmbeddingBag (18000000 , 8 , mode = 'sum' , include_last_offset = True , _weight = torch .empty (18000000 , 8 , dtype = torch .bfloat16 ), sparse = False )
164164 table3 = nn .EmbeddingBag (100 , 16 , mode = 'mean' , sparse = True ).double ()
165165 merged = MergedEmbeddingBag .from_embeddingbag_list ([table0 , table1 , table2 ])
166166 merged2 = MergedEmbeddingBag ([
167167 (100 , 16 , 'mean' , table0 .weight .dtype , table0 .weight .detach (), False ),
168168 (50 , 32 , 'sum' , table1 .weight .dtype , table1 .weight .detach (), False ),
169- (18000000 , 128 , 'sum' , table2 .weight .dtype , table2 .weight .detach (), False ),
169+ (18000000 , 8 , 'sum' , table2 .weight .dtype , table2 .weight .detach (), False ),
170170 ])
171171 input = [
172172 [torch .LongTensor ([10 , 10 , 15 , 10 , 20 , 25 ]), torch .LongTensor ([[0 , 30 ], [21 , 15 ], [30 , 11 ]]), torch .LongTensor ([10 , 15 , 17999999 ])],
0 commit comments