@@ -358,7 +358,7 @@ def __init__(
358
358
self ._feature_names : List [str ] = []
359
359
self ._feature_splits : List [int ] = []
360
360
self ._length_per_key : List [int ] = []
361
- self ._features_order : List [int ] = []
361
+ self ._features_order : Optional [ List [int ]] = None
362
362
# Registering in a List instead of ModuleList because we want don't want them to be auto-registered.
363
363
# Their states will be modified via self.embedding_bags
364
364
self ._emb_modules : List [nn .Module ] = []
@@ -502,20 +502,22 @@ def forward(
502
502
kjt_keys = _get_kjt_keys (features )
503
503
# Cache the features order since the features will always have the same order of keys in inference.
504
504
if getattr (self , MODULE_ATTR_CACHE_FEATURES_ORDER , False ):
505
- if self ._features_order == [] :
506
- for k in self ._feature_names :
507
- self ._features_order . append ( kjt_keys . index ( k ))
508
- self .register_buffer (
509
- "_features_order_tensor" ,
510
- torch .tensor (
511
- data = self ._features_order ,
512
- device = features .device (),
513
- dtype = torch .int32 ,
514
- ),
515
- persistent = False ,
516
- )
505
+ if self ._features_order is None :
506
+ self . _features_order = [ kjt_keys . index ( k ) for k in self ._feature_names ]
507
+ if self ._features_order :
508
+ self .register_buffer (
509
+ "_features_order_tensor" ,
510
+ torch .tensor (
511
+ data = self ._features_order ,
512
+ device = features .device (),
513
+ dtype = torch .int32 ,
514
+ ),
515
+ persistent = False ,
516
+ )
517
517
kjt_permute = _permute_kjt (
518
- features , self ._features_order , self ._features_order_tensor
518
+ features ,
519
+ self ._features_order ,
520
+ getattr (self , "_features_order_tensor" , None ),
519
521
)
520
522
else :
521
523
kjt_permute_order = [kjt_keys .index (k ) for k in self ._feature_names ]
@@ -752,7 +754,7 @@ def __init__( # noqa C901
752
754
self .row_alignment = row_alignment
753
755
self ._key_to_tables : Dict [DataType , List [EmbeddingConfig ]] = defaultdict (list )
754
756
self ._feature_names : List [str ] = []
755
- self ._features_order : List [int ] = []
757
+ self ._features_order : Optional [ List [int ]] = None
756
758
757
759
self ._table_name_to_quantized_weights : Optional [
758
760
Dict [str , Tuple [Tensor , Tensor ]]
@@ -881,20 +883,22 @@ def forward(
881
883
kjt_keys = _get_kjt_keys (features )
882
884
# Cache the features order since the features will always have the same order of keys in inference.
883
885
if getattr (self , MODULE_ATTR_CACHE_FEATURES_ORDER , False ):
884
- if self ._features_order == [] :
885
- for k in self ._feature_names :
886
- self ._features_order . append ( kjt_keys . index ( k ))
887
- self .register_buffer (
888
- "_features_order_tensor" ,
889
- torch .tensor (
890
- data = self ._features_order ,
891
- device = features .device (),
892
- dtype = torch .int32 ,
893
- ),
894
- persistent = False ,
895
- )
886
+ if self ._features_order is None :
887
+ self . _features_order = [ kjt_keys . index ( k ) for k in self ._feature_names ]
888
+ if self ._features_order :
889
+ self .register_buffer (
890
+ "_features_order_tensor" ,
891
+ torch .tensor (
892
+ data = self ._features_order ,
893
+ device = features .device (),
894
+ dtype = torch .int32 ,
895
+ ),
896
+ persistent = False ,
897
+ )
896
898
kjt_permute = _permute_kjt (
897
- features , self ._features_order , self ._features_order_tensor
899
+ features ,
900
+ self ._features_order ,
901
+ getattr (self , "_features_order_tensor" , None ),
898
902
)
899
903
else :
900
904
kjt_permute_order = [kjt_keys .index (k ) for k in self ._feature_names ]
0 commit comments