@@ -1096,13 +1096,15 @@ def _maybe_compute_stride_kjt(
1096
1096
stride : Optional [int ],
1097
1097
lengths : Optional [torch .Tensor ],
1098
1098
offsets : Optional [torch .Tensor ],
1099
- stride_per_key_per_rank : Optional [List [ List [ int ]] ],
1099
+ stride_per_key_per_rank : Optional [torch . IntTensor ],
1100
1100
) -> int :
1101
1101
if stride is None :
1102
1102
if len (keys ) == 0 :
1103
1103
stride = 0
1104
- elif stride_per_key_per_rank is not None and len (stride_per_key_per_rank ) > 0 :
1105
- stride = max ([sum (s ) for s in stride_per_key_per_rank ])
1104
+ elif (
1105
+ stride_per_key_per_rank is not None and stride_per_key_per_rank .numel () > 0
1106
+ ):
1107
+ stride = int (stride_per_key_per_rank .sum (dim = 1 ).max ().item ())
1106
1108
elif offsets is not None and offsets .numel () > 0 :
1107
1109
stride = (offsets .numel () - 1 ) // len (keys )
1108
1110
elif lengths is not None :
@@ -1481,8 +1483,8 @@ def _strides_from_kjt(
1481
1483
def _kjt_empty_like (kjt : "KeyedJaggedTensor" ) -> "KeyedJaggedTensor" :
1482
1484
# empty like function fx wrapped, also avoids device hardcoding
1483
1485
stride , stride_per_key_per_rank = (
1484
- (None , kjt .stride_per_key_per_rank () )
1485
- if kjt .variable_stride_per_key ()
1486
+ (None , kjt ._stride_per_key_per_rank )
1487
+ if kjt ._stride_per_key_per_rank is not None and kjt . variable_stride_per_key ()
1486
1488
else (kjt .stride (), None )
1487
1489
)
1488
1490
@@ -1668,14 +1670,20 @@ def _maybe_compute_lengths_offset_per_key(
1668
1670
1669
1671
def _maybe_compute_stride_per_key (
1670
1672
stride_per_key : Optional [List [int ]],
1671
- stride_per_key_per_rank : Optional [List [ List [ int ]] ],
1673
+ stride_per_key_per_rank : Optional [torch . IntTensor ],
1672
1674
stride : Optional [int ],
1673
1675
keys : List [str ],
1674
1676
) -> Optional [List [int ]]:
1675
1677
if stride_per_key is not None :
1676
1678
return stride_per_key
1677
1679
elif stride_per_key_per_rank is not None :
1678
- return [sum (s ) for s in stride_per_key_per_rank ]
1680
+ if stride_per_key_per_rank .dim () != 2 :
1681
+ # after permute the kjt could be empty
1682
+ return []
1683
+ rt : List [int ] = stride_per_key_per_rank .sum (dim = 1 ).tolist ()
1684
+ if not torch .jit .is_scripting () and is_torchdynamo_compiling ():
1685
+ pt2_checks_all_is_size (rt )
1686
+ return rt
1679
1687
elif stride is not None :
1680
1688
return [stride ] * len (keys )
1681
1689
else :
@@ -1766,7 +1774,9 @@ def __init__(
1766
1774
lengths : Optional [torch .Tensor ] = None ,
1767
1775
offsets : Optional [torch .Tensor ] = None ,
1768
1776
stride : Optional [int ] = None ,
1769
- stride_per_key_per_rank : Optional [List [List [int ]]] = None ,
1777
+ stride_per_key_per_rank : Optional [
1778
+ Union [torch .IntTensor , List [List [int ]]]
1779
+ ] = None ,
1770
1780
# Below exposed to ensure torch.script-able
1771
1781
stride_per_key : Optional [List [int ]] = None ,
1772
1782
length_per_key : Optional [List [int ]] = None ,
@@ -1788,8 +1798,14 @@ def __init__(
1788
1798
self ._lengths : Optional [torch .Tensor ] = lengths
1789
1799
self ._offsets : Optional [torch .Tensor ] = offsets
1790
1800
self ._stride : Optional [int ] = stride
1791
- self ._stride_per_key_per_rank : Optional [List [List [int ]]] = (
1792
- stride_per_key_per_rank
1801
+ if not torch .jit .is_scripting () and is_torchdynamo_compiling ():
1802
+ # in pt2.compile the stride_per_key_per_rank has to be torch.Tensor or None
1803
+ # does not take List[List[int]]
1804
+ assert not isinstance (stride_per_key_per_rank , list )
1805
+ self ._stride_per_key_per_rank : Optional [torch .IntTensor ] = (
1806
+ torch .IntTensor (stride_per_key_per_rank , device = "cpu" )
1807
+ if isinstance (stride_per_key_per_rank , list )
1808
+ else stride_per_key_per_rank
1793
1809
)
1794
1810
self ._stride_per_key : Optional [List [int ]] = stride_per_key
1795
1811
self ._length_per_key : Optional [List [int ]] = length_per_key
@@ -1815,10 +1831,8 @@ def _init_pt2_checks(self) -> None:
1815
1831
return
1816
1832
if self ._stride_per_key is not None :
1817
1833
pt2_checks_all_is_size (self ._stride_per_key )
1818
- if self ._stride_per_key_per_rank is not None :
1819
- # pyre-ignore [16]
1820
- for s in self ._stride_per_key_per_rank :
1821
- pt2_checks_all_is_size (s )
1834
+ # this is only needed for torch.compile case
1835
+ self ._pt2_stride_per_key_per_rank : Optional [List [List [int ]]] = None
1822
1836
1823
1837
@staticmethod
1824
1838
def from_offsets_sync (
@@ -2028,7 +2042,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
2028
2042
kjt_stride , kjt_stride_per_key_per_rank = (
2029
2043
(stride_per_key [0 ], None )
2030
2044
if all (s == stride_per_key [0 ] for s in stride_per_key )
2031
- else (None , [[ stride ] for stride in stride_per_key ] )
2045
+ else (None , torch . IntTensor ( stride_per_key , device = "cpu" ). reshape ( - 1 , 1 ) )
2032
2046
)
2033
2047
kjt = KeyedJaggedTensor (
2034
2048
keys = kjt_keys ,
@@ -2193,12 +2207,32 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
2193
2207
Returns:
2194
2208
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
2195
2209
"""
2196
- stride_per_key_per_rank = self ._stride_per_key_per_rank
2197
- return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
2210
+ # making a local reference to the class variable to make jit.script behave
2211
+ _stride_per_key_per_rank = self ._stride_per_key_per_rank
2212
+ if (
2213
+ not torch .jit .is_scripting ()
2214
+ and is_torchdynamo_compiling ()
2215
+ and _stride_per_key_per_rank is not None
2216
+ ):
2217
+ if self ._pt2_stride_per_key_per_rank is not None :
2218
+ return self ._pt2_stride_per_key_per_rank
2219
+ stride_per_key_per_rank = _stride_per_key_per_rank .tolist ()
2220
+ for stride_per_rank in stride_per_key_per_rank :
2221
+ pt2_checks_all_is_size (stride_per_rank )
2222
+ self ._pt2_stride_per_key_per_rank = stride_per_key_per_rank
2223
+ return stride_per_key_per_rank
2224
+ return (
2225
+ []
2226
+ if _stride_per_key_per_rank is None
2227
+ else _stride_per_key_per_rank .tolist ()
2228
+ )
2198
2229
2199
2230
def variable_stride_per_key (self ) -> bool :
2200
2231
"""
2201
2232
Returns whether the KeyedJaggedTensor has variable stride per key.
2233
+ NOTE: `self._variable_stride_per_key` could be `False` when `self._stride_per_key_per_rank`
2234
+ is not `None`. It might be assigned to False externally/intentionally, usually the
2235
+ `self._stride_per_key_per_rank` is trivial.
2202
2236
2203
2237
Returns:
2204
2238
bool: whether the KeyedJaggedTensor has variable stride per key.
@@ -2343,13 +2377,16 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
2343
2377
start_offset = 0
2344
2378
_length_per_key = self .length_per_key ()
2345
2379
_offset_per_key = self .offset_per_key ()
2380
+ # use local copy/ref for self._stride_per_key_per_rank to satisfy jit.script
2381
+ _stride_per_key_per_rank = self ._stride_per_key_per_rank
2346
2382
for segment in segments :
2347
2383
end = start + segment
2348
2384
end_offset = _offset_per_key [end ]
2349
2385
keys : List [str ] = self ._keys [start :end ]
2350
2386
stride_per_key_per_rank = (
2351
- self . stride_per_key_per_rank () [start :end ]
2387
+ _stride_per_key_per_rank [start :end , : ]
2352
2388
if self .variable_stride_per_key ()
2389
+ and _stride_per_key_per_rank is not None
2353
2390
else None
2354
2391
)
2355
2392
if segment == len (self ._keys ):
@@ -2496,18 +2533,21 @@ def permute(
2496
2533
)
2497
2534
2498
2535
length_per_key = self .length_per_key ()
2536
+ permuted_stride_per_key = None if self ._stride_per_key is None else []
2499
2537
permuted_keys : List [str ] = []
2500
- permuted_stride_per_key_per_rank : List [List [int ]] = []
2501
2538
permuted_length_per_key : List [int ] = []
2502
2539
permuted_length_per_key_sum = 0
2503
2540
for index in indices :
2504
2541
key = self .keys ()[index ]
2505
2542
permuted_keys .append (key )
2506
2543
permuted_length_per_key .append (length_per_key [index ])
2507
- if self .variable_stride_per_key ():
2508
- permuted_stride_per_key_per_rank .append (
2509
- self .stride_per_key_per_rank ()[index ]
2510
- )
2544
+ if permuted_stride_per_key is not None :
2545
+ permuted_stride_per_key .append (self ._stride_per_key [index ])
2546
+ _stride_per_key_per_rank = self ._stride_per_key_per_rank
2547
+ if self .variable_stride_per_key () and _stride_per_key_per_rank is not None :
2548
+ permuted_stride_per_key_per_rank = _stride_per_key_per_rank [indices , :]
2549
+ else :
2550
+ permuted_stride_per_key_per_rank = None
2511
2551
2512
2552
permuted_length_per_key_sum = sum (permuted_length_per_key )
2513
2553
if not torch .jit .is_scripting () and is_non_strict_exporting ():
@@ -2559,18 +2599,16 @@ def permute(
2559
2599
self .weights_or_none (),
2560
2600
permuted_length_per_key_sum ,
2561
2601
)
2562
- stride_per_key_per_rank = (
2563
- permuted_stride_per_key_per_rank if self .variable_stride_per_key () else None
2564
- )
2602
+
2565
2603
kjt = KeyedJaggedTensor (
2566
2604
keys = permuted_keys ,
2567
2605
values = permuted_values ,
2568
2606
weights = permuted_weights ,
2569
2607
lengths = permuted_lengths .view (- 1 ),
2570
2608
offsets = None ,
2571
2609
stride = self ._stride ,
2572
- stride_per_key_per_rank = stride_per_key_per_rank ,
2573
- stride_per_key = None ,
2610
+ stride_per_key_per_rank = permuted_stride_per_key_per_rank ,
2611
+ stride_per_key = permuted_stride_per_key ,
2574
2612
length_per_key = permuted_length_per_key if len (permuted_keys ) > 0 else None ,
2575
2613
lengths_offset_per_key = None ,
2576
2614
offset_per_key = None ,
@@ -2887,7 +2925,7 @@ def dist_init(
2887
2925
2888
2926
if variable_stride_per_key :
2889
2927
assert stride_per_rank_per_key is not None
2890
- stride_per_key_per_rank_tensor : torch .Tensor = stride_per_rank_per_key .view (
2928
+ stride_per_key_per_rank : torch .Tensor = stride_per_rank_per_key .view (
2891
2929
num_workers , len (keys )
2892
2930
).T .cpu ()
2893
2931
@@ -2924,23 +2962,18 @@ def dist_init(
2924
2962
weights ,
2925
2963
)
2926
2964
2927
- stride_per_key_per_rank = torch .jit .annotate (
2928
- List [List [int ]], stride_per_key_per_rank_tensor .tolist ()
2929
- )
2965
+ if stride_per_key_per_rank .numel () == 0 :
2966
+ stride_per_key_per_rank = torch .zeros (
2967
+ (len (keys ), 1 ), device = "cpu" , dtype = torch .int64
2968
+ )
2930
2969
2931
- if not stride_per_key_per_rank :
2932
- stride_per_key_per_rank = [[0 ]] * len (keys )
2933
2970
if stagger > 1 :
2934
- stride_per_key_per_rank_stagger : List [List [int ]] = []
2935
2971
local_world_size = num_workers // stagger
2936
- for i in range (len (keys )):
2937
- stride_per_rank_stagger : List [int ] = []
2938
- for j in range (local_world_size ):
2939
- stride_per_rank_stagger .extend (
2940
- stride_per_key_per_rank [i ][j ::local_world_size ]
2941
- )
2942
- stride_per_key_per_rank_stagger .append (stride_per_rank_stagger )
2943
- stride_per_key_per_rank = stride_per_key_per_rank_stagger
2972
+ indices = [
2973
+ list (range (i , num_workers , local_world_size ))
2974
+ for i in range (local_world_size )
2975
+ ]
2976
+ stride_per_key_per_rank = stride_per_key_per_rank [:, indices ]
2944
2977
2945
2978
kjt = KeyedJaggedTensor (
2946
2979
keys = keys ,
0 commit comments