@@ -152,6 +152,10 @@ def _test_sharding(
         use_inter_host_allreduce: bool = False,
         allow_zero_batch_size: bool = False,
         custom_all_reduce: bool = False,
+        use_offsets: bool = False,
+        indices_dtype: torch.dtype = torch.int64,
+        offsets_dtype: torch.dtype = torch.int64,
+        lengths_dtype: torch.dtype = torch.int64,
     ) -> None:
         self._build_tables_and_groups(data_type=data_type)
         self._run_multi_process_test(
@@ -176,6 +180,10 @@ def _test_sharding(
             use_inter_host_allreduce=use_inter_host_allreduce,
             allow_zero_batch_size=allow_zero_batch_size,
             custom_all_reduce=custom_all_reduce,
+            use_offsets=use_offsets,
+            indices_dtype=indices_dtype,
+            offsets_dtype=offsets_dtype,
+            lengths_dtype=lengths_dtype,
         )
 
 
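The four new keyword arguments default to the current behavior (lengths-based inputs with int64 tensors), so existing call sites are unaffected. Below is a minimal sketch of how such knobs could be consumed downstream when building a KeyedJaggedTensor input; the helper build_kjt and the feature key "f1" are hypothetical and not part of this diff.

    # Hypothetical helper, not part of this PR: shows how use_offsets and the
    # three dtype knobs could shape the sparse input tensors.
    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    def build_kjt(
        values: torch.Tensor,
        lengths: torch.Tensor,
        use_offsets: bool = False,
        indices_dtype: torch.dtype = torch.int64,
        offsets_dtype: torch.dtype = torch.int64,
        lengths_dtype: torch.dtype = torch.int64,
    ) -> KeyedJaggedTensor:
        values = values.to(indices_dtype)
        if use_offsets:
            # Derive offsets [0, l0, l0+l1, ...] from per-bag lengths.
            offsets = torch.cat(
                [torch.zeros(1, dtype=lengths.dtype), lengths.cumsum(0)]
            ).to(offsets_dtype)
            return KeyedJaggedTensor(keys=["f1"], values=values, offsets=offsets)
        return KeyedJaggedTensor(
            keys=["f1"], values=values, lengths=lengths.to(lengths_dtype)
        )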
@@ -901,3 +909,58 @@ def test_sharding_grid_8gpu(
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
             pooling=pooling,
         )
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    # pyre-fixme[56]
+    @given(
+        dtype=st.sampled_from([torch.int32, torch.int64]),
+        use_offsets=st.booleans(),
+        sharder_type=st.sampled_from(
+            [
+                SharderType.EMBEDDING_BAG_COLLECTION.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+            ],
+        ),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
+    def test_sharding_diff_table_index_type(
+        self,
+        dtype: torch.dtype,
+        use_offsets: bool,
+        sharder_type: str,
+        kernel_type: str,
+    ) -> None:
+        """
+        Test that the model correctly handles input indices and offsets
+        with both int32 and int64 data types.
+        """
+        sharders = [
+            cast(
+                ModuleSharder[nn.Module],
+                create_test_sharder(
+                    sharder_type=sharder_type,
+                    sharding_type=ShardingType.ROW_WISE.value,  # or any other relevant sharding type
+                    kernel_type=kernel_type,
+                    device=self.device,
+                ),
+            ),
+        ]
+        # TODO - how to pass dtype so that sampled data uses different type indices/offsets?
+        self._test_sharding(
+            sharders=sharders,
+            backend=self.backend,
+            apply_optimizer_in_backward_config=None,
+            variable_batch_size=False,
+            pooling=PoolingType.SUM,
+            use_offsets=use_offsets,
+            indices_dtype=dtype,
+            offsets_dtype=dtype,
+            lengths_dtype=dtype,
+        )
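For reference, the property the new test exercises, namely that both int32 and int64 indices and offsets are accepted, can be sanity-checked standalone with plain PyTorch. This is a sketch under that assumption, not the TorchRec test harness; nn.EmbeddingBag documents support for both IntTensor and LongTensor inputs.

    # Standalone sanity check (plain PyTorch, outside the TorchRec harness):
    # the same indices/offsets work as IntTensor and as LongTensor.
    import torch

    bag = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")
    for dtype in (torch.int32, torch.int64):
        indices = torch.tensor([1, 2, 4, 5, 4, 3], dtype=dtype)
        offsets = torch.tensor([0, 3], dtype=dtype)  # two bags of three indices
        out = bag(indices, offsets)
        assert out.shape == (2, 4)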