@@ -463,14 +463,15 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
         }
 
         if self.ep == "trt-rtx" and self.window_size is not None and self.window_size > 0:
-            genai_config["model"]["decoder"]["sliding_window"] = {"window_size": self.window_size, "slide_key_value_cache": False, "slide_inputs": False}
+            # Compute layer indices that use sliding window attention
+            layer_idxs = [layer_id for layer_id in range(self.num_layers) if hasattr(self, "is_local") and self.is_local(layer_id)]
 
-            # Add layer indices for sliding window layers if model has alternating attention patterns
-            layer_types = self.get_layer_types()
-            if layer_types is not None:
-                # Export list of layer indices that use sliding window attention
-                sliding_layers = [i for i, lt in enumerate(layer_types) if lt == "sliding_attention"]
-                genai_config["model"]["decoder"]["sliding_window"]["layers"] = sliding_layers
+            genai_config["model"]["decoder"]["sliding_window"] = {
+                "window_size": self.window_size,
+                "slide_key_value_cache": False,
+                "slide_inputs": False,
+                "layers": layer_idxs
+            }
 
         if self.ep != "cpu":
             ep_name = self.ep.replace("trt-rtx", "NvTensorRtRtx")
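For reference, here is a minimal standalone sketch (not part of the diff) of the sliding_window entry this hunk now emits, assuming a hypothetical 4-layer model whose is_local method marks odd layers as sliding-window layers; the layer count, window size, and is_local pattern below are illustrative assumptions only:

import json

# Hypothetical stand-ins for the builder's attributes (illustrative values only)
num_layers = 4
window_size = 4096

def is_local(layer_id):
    # Assumed alternating pattern: odd layers use sliding-window attention
    return layer_id % 2 == 1

# Same comprehension shape as the new diff line above
layer_idxs = [layer_id for layer_id in range(num_layers) if is_local(layer_id)]

genai_config = {"model": {"decoder": {}}}
genai_config["model"]["decoder"]["sliding_window"] = {
    "window_size": window_size,
    "slide_key_value_cache": False,
    "slide_inputs": False,
    "layers": layer_idxs,  # [1, 3] for this assumed pattern
}

print(json.dumps(genai_config["model"]["decoder"]["sliding_window"], indent=4))

Folding the layer list into the same dictionary replaces the earlier two-step assignment guarded by get_layer_types and keeps the sliding_window block self-contained.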
@@ -481,15 +482,7 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
         with open(os.path.join(out_dir, "genai_config.json"), "w") as f:
             json.dump(genai_config, f, indent=4)
 
-    def get_layer_types(self):
-        """
-        Returns a list of attention types for each layer.
-        Override in subclasses to provide layer-specific attention patterns.
-        Returns None for models with uniform attention across all layers.
-        """
-        return None
-
-    def make_kv_value_cache_shape(self, layer_id, shape):
+    def make_key_value_cache_shape(self, layer_id, shape):
        """
        Modifies KV cache shape dimension names for models with alternating attention patterns.
        For TensorRT EP with sliding window layers, replaces 'sequence' with 'sliding' in dimension name.
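The hunk above only shows the rename and its docstring, not the method body, so the following is a hedged sketch of the described 'sequence' to 'sliding' replacement; the shape layout, dimension names, and the is_local default are assumptions for illustration:

# Sketch only: assumes `shape` is a list mixing ints and symbolic dimension names,
# and that sliding-window ("local") layers get their own 'sliding' dynamic axis
# instead of the shared 'sequence' axis when the EP is TensorRT-RTX.
def make_key_value_cache_shape_sketch(layer_id, shape, ep="trt-rtx", is_local=lambda i: i % 2 == 1):
    if ep == "trt-rtx" and is_local(layer_id):
        return [dim.replace("sequence", "sliding") if isinstance(dim, str) else dim for dim in shape]
    return shape

# Example with a hypothetical KV cache shape:
#   make_key_value_cache_shape_sketch(1, ["batch_size", 8, "past_sequence_length", 128])
#   -> ["batch_size", 8, "past_sliding_length", 128]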
@@ -676,20 +669,20 @@ def make_inputs_and_outputs(self):
         for i in range(self.num_layers):
             # Add KV cache to inputs
             key_name = f"past_key_values.{i}.key"
-            key_shape = self.make_kv_value_cache_shape(i, self.input_shapes["past_key_values.key"])
+            key_shape = self.make_key_value_cache_shape(i, self.input_shapes["past_key_values.key"])
             inputs.append(self.make_value(key_name, dtype=self.input_types["past_key_values.key"], shape=key_shape))
 
             value_name = f"past_key_values.{i}.value"
-            value_shape = self.make_kv_value_cache_shape(i, self.input_shapes["past_key_values.value"])
+            value_shape = self.make_key_value_cache_shape(i, self.input_shapes["past_key_values.value"])
             inputs.append(self.make_value(value_name, dtype=self.input_types["past_key_values.value"], shape=value_shape))
 
             # Add KV cache to outputs
             key_name = f"present.{i}.key"
-            key_shape = self.make_kv_value_cache_shape(i, self.output_shapes["present.key"])
+            key_shape = self.make_key_value_cache_shape(i, self.output_shapes["present.key"])
             outputs.append(self.make_value(key_name, dtype=self.output_types["present.key"], shape=key_shape))
 
             value_name = f"present.{i}.value"
-            value_shape = self.make_kv_value_cache_shape(i, self.output_shapes["present.value"])
+            value_shape = self.make_key_value_cache_shape(i, self.output_shapes["present.value"])
             outputs.append(self.make_value(value_name, dtype=self.output_types["present.value"], shape=value_shape))
 
     def make_constant(self, name):
@@ -3485,20 +3478,6 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
         super().make_attention(layer_id, attention, root_input, **kwargs)
         self.window_size = original_window_size
 
-    def get_layer_types(self):
-        """
-        Gemma2 uses alternating attention patterns:
-        - Even layers (0, 2, 4, ...): full_attention
-        - Odd layers (1, 3, 5, ...): sliding_attention
-        """
-        layer_types = []
-        for layer_id in range(self.num_layers):
-            if self.is_local(layer_id):
-                layer_types.append("sliding_attention")
-            else:
-                layer_types.append("full_attention")
-        return layer_types
-
 
 class Phi3MiniModel(MistralModel):
     def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
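With get_layer_types gone, the alternating pattern its docstring described (even layers full attention, odd layers sliding attention) is expressed only through is_local, which the new make_genai_config code queries. A plausible sketch of such a predicate, matching the removed docstring rather than any verified Gemma2 configuration, is:

# Sketch only: follows the removed docstring's description of Gemma2-style alternation.
def is_local(layer_id: int) -> bool:
    # Odd layers (1, 3, 5, ...) use sliding-window ("local") attention;
    # even layers (0, 2, 4, ...) use full attention.
    return layer_id % 2 == 1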