@@ -126,6 +126,10 @@ def __call__(
126
126
Union [List [EmbeddingTableConfig ], List [EmbeddingBagConfig ]]
127
127
] = None ,
128
128
variable_batch_size : bool = False ,
129
+ use_offsets : bool = False ,
130
+ indices_dtype : torch .dtype = torch .int64 ,
131
+ offsets_dtype : torch .dtype = torch .int64 ,
132
+ lengths_dtype : torch .dtype = torch .int64 ,
129
133
long_indices : bool = True ,
130
134
) -> Tuple ["ModelInput" , List ["ModelInput" ]]: ...
131
135
@@ -140,6 +144,10 @@ def __call__(
140
144
weighted_tables : Union [List [EmbeddingTableConfig ], List [EmbeddingBagConfig ]],
141
145
pooling_avg : int = 10 ,
142
146
global_constant_batch : bool = False ,
147
+ use_offsets : bool = False ,
148
+ indices_dtype : torch .dtype = torch .int64 ,
149
+ offsets_dtype : torch .dtype = torch .int64 ,
150
+ lengths_dtype : torch .dtype = torch .int64 ,
143
151
) -> Tuple ["ModelInput" , List ["ModelInput" ]]: ...
144
152
145
153
@@ -161,10 +169,14 @@ def gen_model_and_input(
161
169
variable_batch_size : bool = False ,
162
170
batch_size : int = 4 ,
163
171
feature_processor_modules : Optional [Dict [str , torch .nn .Module ]] = None ,
164
- long_indices : bool = True ,
172
+ use_offsets : bool = False ,
173
+ indices_dtype : torch .dtype = torch .int64 ,
174
+ offsets_dtype : torch .dtype = torch .int64 ,
175
+ lengths_dtype : torch .dtype = torch .int64 ,
165
176
global_constant_batch : bool = False ,
166
177
num_inputs : int = 1 ,
167
178
input_type : str = "kjt" , # "kjt" or "td"
179
+ long_indices : bool = True ,
168
180
) -> Tuple [nn .Module , List [Tuple [ModelInput , List [ModelInput ]]]]:
169
181
torch .manual_seed (0 )
170
182
if dedup_feature_names :
@@ -205,6 +217,10 @@ def gen_model_and_input(
205
217
tables = tables ,
206
218
weighted_tables = weighted_tables or [],
207
219
global_constant_batch = global_constant_batch ,
220
+ use_offsets = use_offsets ,
221
+ indices_dtype = indices_dtype ,
222
+ offsets_dtype = offsets_dtype ,
223
+ lengths_dtype = lengths_dtype ,
208
224
)
209
225
)
210
226
elif generate == ModelInput .generate :
@@ -218,8 +234,12 @@ def gen_model_and_input(
218
234
num_float_features = num_float_features ,
219
235
variable_batch_size = variable_batch_size ,
220
236
batch_size = batch_size ,
221
- long_indices = long_indices ,
222
237
input_type = input_type ,
238
+ use_offsets = use_offsets ,
239
+ indices_dtype = indices_dtype ,
240
+ offsets_dtype = offsets_dtype ,
241
+ lengths_dtype = lengths_dtype ,
242
+ long_indices = long_indices ,
223
243
)
224
244
)
225
245
else :
@@ -233,6 +253,10 @@ def gen_model_and_input(
233
253
num_float_features = num_float_features ,
234
254
variable_batch_size = variable_batch_size ,
235
255
batch_size = batch_size ,
256
+ use_offsets = use_offsets ,
257
+ indices_dtype = indices_dtype ,
258
+ offsets_dtype = offsets_dtype ,
259
+ lengths_dtype = lengths_dtype ,
236
260
long_indices = long_indices ,
237
261
)
238
262
)
@@ -336,6 +360,10 @@ def sharding_single_rank_test(
336
360
input_type : str = "kjt" , # "kjt" or "td"
337
361
allow_zero_batch_size : bool = False ,
338
362
custom_all_reduce : bool = False , # 2D parallel
363
+ use_offsets : bool = False ,
364
+ indices_dtype : torch .dtype = torch .int64 ,
365
+ offsets_dtype : torch .dtype = torch .int64 ,
366
+ lengths_dtype : torch .dtype = torch .int64 ,
339
367
) -> None :
340
368
with MultiProcessContext (rank , world_size , backend , local_size ) as ctx :
341
369
batch_size = (
@@ -363,6 +391,10 @@ def sharding_single_rank_test(
363
391
feature_processor_modules = feature_processor_modules ,
364
392
global_constant_batch = global_constant_batch ,
365
393
input_type = input_type ,
394
+ use_offsets = use_offsets ,
395
+ indices_dtype = indices_dtype ,
396
+ offsets_dtype = offsets_dtype ,
397
+ lengths_dtype = lengths_dtype ,
366
398
)
367
399
global_model = global_model .to (ctx .device )
368
400
global_input = inputs [0 ][0 ].to (ctx .device )
0 commit comments