@@ -71,19 +71,30 @@ def generate(
71
71
]
72
72
] = None ,
73
73
variable_batch_size : bool = False ,
74
- long_indices : bool = True ,
75
74
tables_pooling : Optional [List [int ]] = None ,
76
75
weighted_tables_pooling : Optional [List [int ]] = None ,
77
76
randomize_indices : bool = True ,
78
77
device : Optional [torch .device ] = None ,
79
78
max_feature_lengths : Optional [List [int ]] = None ,
80
79
input_type : str = "kjt" ,
80
+ use_offsets : bool = False ,
81
+ indices_dtype : torch .dtype = torch .int64 ,
82
+ offsets_dtype : torch .dtype = torch .int64 ,
83
+ lengths_dtype : torch .dtype = torch .int64 ,
84
+ long_indices : bool = True , # TODO - remove this once code base is updated to support more than long_indices spec
81
85
) -> Tuple ["ModelInput" , List ["ModelInput" ]]:
82
86
"""
83
87
Returns a global (single-rank training) batch
84
88
and a list of local (multi-rank training) batches of world_size.
85
89
"""
86
-
90
+ if long_indices :
91
+ indices_dtype = torch .int64
92
+ lengths_dtype = torch .int64
93
+ use_offsets = False
94
+ else :
95
+ indices_dtype = torch .int32
96
+ lengths_dtype = torch .int32
97
+ use_offsets = False
87
98
batch_size_by_rank = [batch_size ] * world_size
88
99
if variable_batch_size :
89
100
batch_size_by_rank = [
@@ -119,7 +130,6 @@ def _validate_pooling_factor(
119
130
if tables [idx ].num_embeddings_post_pruning is not None
120
131
else tables [idx ].num_embeddings
121
132
)
122
-
123
133
idlist_features_to_max_length [feature ] = (
124
134
max_feature_lengths [feature_idx ] if max_feature_lengths else None
125
135
)
@@ -144,18 +154,21 @@ def _validate_pooling_factor(
144
154
145
155
idlist_pooling_factor = list (idlist_features_to_pooling_factor .values ())
146
156
idscore_pooling_factor = weighted_tables_pooling
147
-
148
157
idlist_max_lengths = list (idlist_features_to_max_length .values ())
149
158
150
159
# Generate global batch.
151
160
global_idlist_lengths = []
152
161
global_idlist_indices = []
162
+ global_idlist_offsets = []
163
+
153
164
global_idscore_lengths = []
154
165
global_idscore_indices = []
166
+ global_idscore_offsets = []
155
167
global_idscore_weights = []
156
168
157
169
for idx in range (len (idlist_ind_ranges )):
158
170
ind_range = idlist_ind_ranges [idx ]
171
+
159
172
if idlist_pooling_factor :
160
173
lengths_ = torch .max (
161
174
torch .normal (
@@ -165,17 +178,19 @@ def _validate_pooling_factor(
165
178
device = device ,
166
179
),
167
180
torch .tensor (1.0 , device = device ),
168
- ).int ( )
181
+ ).to ( lengths_dtype )
169
182
else :
170
183
lengths_ = torch .abs (
171
184
torch .randn (batch_size * world_size , device = device ) + pooling_avg ,
172
- ).int ( )
185
+ ).to ( lengths_dtype )
173
186
174
187
if idlist_max_lengths [idx ]:
175
188
lengths_ = torch .clamp (lengths_ , max = idlist_max_lengths [idx ])
176
189
177
190
if variable_batch_size :
178
- lengths = torch .zeros (batch_size * world_size , device = device ).int ()
191
+ lengths = torch .zeros (batch_size * world_size , device = device ).to (
192
+ lengths_dtype
193
+ )
179
194
for r in range (world_size ):
180
195
lengths [r * batch_size : r * batch_size + batch_size_by_rank [r ]] = (
181
196
lengths_ [
@@ -186,42 +201,30 @@ def _validate_pooling_factor(
186
201
lengths = lengths_
187
202
188
203
num_indices = cast (int , torch .sum (lengths ).item ())
204
+
189
205
if randomize_indices :
190
206
indices = torch .randint (
191
207
0 ,
192
208
ind_range ,
193
209
(num_indices ,),
194
- dtype = torch . long if long_indices else torch . int32 ,
210
+ dtype = indices_dtype ,
195
211
device = device ,
196
212
)
197
213
else :
198
214
indices = torch .zeros (
199
- (num_indices ),
200
- dtype = torch . long if long_indices else torch . int32 ,
215
+ (num_indices , ),
216
+ dtype = indices_dtype ,
201
217
device = device ,
202
218
)
219
+
220
+ # Calculate offsets from lengths
221
+ offsets = torch .cat (
222
+ [torch .tensor ([0 ], device = device ), lengths .cumsum (0 )]
223
+ ).to (offsets_dtype )
224
+
203
225
global_idlist_lengths .append (lengths )
204
226
global_idlist_indices .append (indices )
205
-
206
- if input_type == "kjt" :
207
- global_idlist_input = KeyedJaggedTensor (
208
- keys = idlist_features ,
209
- values = torch .cat (global_idlist_indices ),
210
- lengths = torch .cat (global_idlist_lengths ),
211
- )
212
- elif input_type == "td" :
213
- dict_of_nt = {
214
- k : torch .nested .nested_tensor_from_jagged (
215
- values = values ,
216
- lengths = lengths ,
217
- )
218
- for k , values , lengths in zip (
219
- idlist_features , global_idlist_indices , global_idlist_lengths
220
- )
221
- }
222
- global_idlist_input = TensorDict (source = dict_of_nt )
223
- else :
224
- raise ValueError (f"For IdList features, unknown input type { input_type } " )
227
+ global_idlist_offsets .append (offsets )
225
228
226
229
for idx , ind_range in enumerate (idscore_ind_ranges ):
227
230
lengths_ = torch .abs (
@@ -231,9 +234,12 @@ def _validate_pooling_factor(
231
234
if idscore_pooling_factor
232
235
else pooling_avg
233
236
)
234
- ).int ()
237
+ ).to (lengths_dtype )
238
+
235
239
if variable_batch_size :
236
- lengths = torch .zeros (batch_size * world_size , device = device ).int ()
240
+ lengths = torch .zeros (batch_size * world_size , device = device ).to (
241
+ lengths_dtype
242
+ )
237
243
for r in range (world_size ):
238
244
lengths [r * batch_size : r * batch_size + batch_size_by_rank [r ]] = (
239
245
lengths_ [
@@ -242,39 +248,68 @@ def _validate_pooling_factor(
242
248
)
243
249
else :
244
250
lengths = lengths_
251
+
245
252
num_indices = cast (int , torch .sum (lengths ).item ())
253
+
246
254
if randomize_indices :
247
255
indices = torch .randint (
248
256
0 ,
249
257
# pyre-ignore [6]
250
258
ind_range ,
251
259
(num_indices ,),
252
- dtype = torch . long if long_indices else torch . int32 ,
260
+ dtype = indices_dtype ,
253
261
device = device ,
254
262
)
255
263
else :
256
264
indices = torch .zeros (
257
- (num_indices ),
258
- dtype = torch . long if long_indices else torch . int32 ,
265
+ (num_indices , ),
266
+ dtype = indices_dtype ,
259
267
device = device ,
260
268
)
261
269
weights = torch .rand ((num_indices ,), device = device )
270
+ # Calculate offsets from lengths
271
+ offsets = torch .cat (
272
+ [torch .tensor ([0 ], device = device ), lengths .cumsum (0 )]
273
+ ).to (offsets_dtype )
274
+
262
275
global_idscore_lengths .append (lengths )
263
276
global_idscore_indices .append (indices )
264
277
global_idscore_weights .append (weights )
278
+ global_idscore_offsets .append (offsets )
265
279
266
280
if input_type == "kjt" :
281
+ global_idlist_input = KeyedJaggedTensor (
282
+ keys = idlist_features ,
283
+ values = torch .cat (global_idlist_indices ),
284
+ offsets = torch .cat (global_idlist_offsets ) if use_offsets else None ,
285
+ lengths = torch .cat (global_idlist_lengths ) if not use_offsets else None ,
286
+ )
287
+
267
288
global_idscore_input = (
268
289
KeyedJaggedTensor (
269
290
keys = idscore_features ,
270
291
values = torch .cat (global_idscore_indices ),
271
- lengths = torch .cat (global_idscore_lengths ),
292
+ offsets = torch .cat (global_idscore_offsets ) if use_offsets else None ,
293
+ lengths = (
294
+ torch .cat (global_idscore_lengths ) if not use_offsets else None
295
+ ),
272
296
weights = torch .cat (global_idscore_weights ),
273
297
)
274
298
if global_idscore_indices
275
299
else None
276
300
)
277
301
elif input_type == "td" :
302
+ dict_of_nt = {
303
+ k : torch .nested .nested_tensor_from_jagged (
304
+ values = values ,
305
+ lengths = lengths ,
306
+ )
307
+ for k , values , lengths in zip (
308
+ idlist_features , global_idlist_indices , global_idlist_lengths
309
+ )
310
+ }
311
+ global_idlist_input = TensorDict (source = dict_of_nt )
312
+
278
313
assert (
279
314
len (idscore_features ) == 0
280
315
), "TensorDict does not support weighted features"
@@ -295,14 +330,20 @@ def _validate_pooling_factor(
295
330
296
331
# Split global batch into local batches.
297
332
local_inputs = []
333
+
298
334
for r in range (world_size ):
299
335
local_idlist_lengths = []
300
336
local_idlist_indices = []
337
+ local_idlist_offsets = []
338
+
301
339
local_idscore_lengths = []
302
340
local_idscore_indices = []
303
341
local_idscore_weights = []
342
+ local_idscore_offsets = []
304
343
305
- for lengths , indices in zip (global_idlist_lengths , global_idlist_indices ):
344
+ for lengths , indices , offsets in zip (
345
+ global_idlist_lengths , global_idlist_indices , global_idlist_offsets
346
+ ):
306
347
local_idlist_lengths .append (
307
348
lengths [r * batch_size : r * batch_size + batch_size_by_rank [r ]]
308
349
)
@@ -312,9 +353,15 @@ def _validate_pooling_factor(
312
353
local_idlist_indices .append (
313
354
indices [lengths_cumsum [r ] : lengths_cumsum [r + 1 ]]
314
355
)
356
+ local_idlist_offsets .append (
357
+ offsets [r * batch_size : r * batch_size + batch_size_by_rank [r ] + 1 ]
358
+ )
315
359
316
- for lengths , indices , weights in zip (
317
- global_idscore_lengths , global_idscore_indices , global_idscore_weights
360
+ for lengths , indices , weights , offsets in zip (
361
+ global_idscore_lengths ,
362
+ global_idscore_indices ,
363
+ global_idscore_weights ,
364
+ global_idscore_offsets ,
318
365
):
319
366
local_idscore_lengths .append (
320
367
lengths [r * batch_size : r * batch_size + batch_size_by_rank [r ]]
@@ -329,18 +376,32 @@ def _validate_pooling_factor(
329
376
weights [lengths_cumsum [r ] : lengths_cumsum [r + 1 ]]
330
377
)
331
378
379
+ local_idscore_offsets .append (
380
+ offsets [r * batch_size : r * batch_size + batch_size_by_rank [r ] + 1 ]
381
+ )
382
+
332
383
if input_type == "kjt" :
333
384
local_idlist_input = KeyedJaggedTensor (
334
385
keys = idlist_features ,
335
386
values = torch .cat (local_idlist_indices ),
336
- lengths = torch .cat (local_idlist_lengths ),
387
+ offsets = torch .cat (local_idlist_offsets ) if use_offsets else None ,
388
+ lengths = (
389
+ torch .cat (local_idlist_lengths ) if not use_offsets else None
390
+ ),
337
391
)
338
392
339
393
local_idscore_input = (
340
394
KeyedJaggedTensor (
341
395
keys = idscore_features ,
342
396
values = torch .cat (local_idscore_indices ),
343
- lengths = torch .cat (local_idscore_lengths ),
397
+ offsets = (
398
+ torch .cat (local_idscore_offsets ) if use_offsets else None
399
+ ),
400
+ lengths = (
401
+ torch .cat (local_idscore_lengths )
402
+ if not use_offsets
403
+ else None
404
+ ),
344
405
weights = torch .cat (local_idscore_weights ),
345
406
)
346
407
if local_idscore_indices
@@ -353,15 +414,16 @@ def _validate_pooling_factor(
353
414
lengths = lengths ,
354
415
)
355
416
for k , values , lengths in zip (
356
- idlist_features , local_idlist_indices , local_idlist_lengths
417
+ idlist_features ,
418
+ local_idlist_indices ,
419
+ local_idlist_lengths ,
357
420
)
358
421
}
359
422
local_idlist_input = TensorDict (source = dict_of_nt )
360
423
assert (
361
424
len (idscore_features ) == 0
362
425
), "TensorDict does not support weighted features"
363
426
local_idscore_input = None
364
-
365
427
else :
366
428
raise ValueError (
367
429
f"For weighted features, unknown input type { input_type } "
0 commit comments