@@ -465,10 +465,12 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
         if self.ep == "trt-rtx" and self.window_size is not None and self.window_size > 0:
             genai_config["model"]["decoder"]["sliding_window"] = {"window_size": self.window_size, "slide_key_value_cache": False, "slide_inputs": False}
 
-            # Add layer-specific attention types if model has alternating attention patterns
+            # Add layer indices for sliding window layers if model has alternating attention patterns
             layer_types = self.get_layer_types()
             if layer_types is not None:
-                genai_config["model"]["decoder"]["sliding_window"]["layer_types"] = layer_types
+                # Export list of layer indices that use sliding window attention
+                sliding_layers = [i for i, lt in enumerate(layer_types) if lt == "sliding_attention"]
+                genai_config["model"]["decoder"]["sliding_window"]["layers"] = sliding_layers
 
         if self.ep != "cpu":
             ep_name = self.ep.replace("trt-rtx", "NvTensorRtRtx")
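
For context, here is a minimal sketch of what the new export produces, assuming a hypothetical six-layer model whose get_layer_types() alternates sliding and full attention (the layer_types list and window_size below are made up for illustration, not taken from this change):

# Sketch only: hypothetical 6-layer model alternating sliding and full attention.
layer_types = ["sliding_attention", "full_attention"] * 3
sliding_layers = [i for i, lt in enumerate(layer_types) if lt == "sliding_attention"]
print(sliding_layers)  # [0, 2, 4]

# The resulting "sliding_window" entry in genai_config then looks roughly like:
# {"window_size": 4096, "slide_key_value_cache": False, "slide_inputs": False, "layers": [0, 2, 4]}

Consumers of genai_config now read a list of layer indices under "layers" instead of per-layer type strings under "layer_types".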
@@ -487,16 +489,14 @@ def get_layer_types(self):
487489 """
488490 return None
489491
490- def use_alternating_kv_dimensions (self ):
492+ def make_kv_value_cache_shape (self , layer_id , shape ):
491493 """
492- Returns True if this model needs alternating KV cache dimension names .
493- This is needed for models with alternating attention patterns when using TensorRT .
494+ Modifies KV cache shape dimension names for models with alternating attention patterns .
495+ For TensorRT EP with sliding window layers, replaces 'sequence' with 'sliding' in dimension name .
494496 """
495- # Enable for models with layer_types when using TensorRT EP
496- if self .ep == "trt-rtx" and hasattr (self , 'get_layer_types' ):
497- layer_types = self .get_layer_types ()
498- return layer_types is not None
499- return False
497+ if self .ep == "trt-rtx" and hasattr (self , "is_local" ) and self .is_local (layer_id ):
498+ return [shape [0 ], shape [1 ], shape [2 ].replace ("sequence" , "sliding" ), shape [3 ]]
499+ return shape
500500
501501 def save_processing (self , model_name_or_path , extra_kwargs , out_dir ):
502502 tokenizer = AutoTokenizer .from_pretrained (model_name_or_path , token = self .hf_token , trust_remote_code = self .hf_remote , ** extra_kwargs )
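
A minimal, self-contained sketch of how the new helper behaves, assuming an is_local(layer_id) hook on the model class that returns True for sliding window layers (the class name, even/odd pattern, and shape values below are illustrative, not part of this diff):

# Sketch: stand-in for a trt-rtx model whose even-numbered layers use sliding window attention.
class ExampleModel:
    ep = "trt-rtx"

    def is_local(self, layer_id):
        # Hypothetical alternating pattern for illustration only.
        return layer_id % 2 == 0

    def make_kv_value_cache_shape(self, layer_id, shape):
        # Same logic as the added lines above.
        if self.ep == "trt-rtx" and hasattr(self, "is_local") and self.is_local(layer_id):
            return [shape[0], shape[1], shape[2].replace("sequence", "sliding"), shape[3]]
        return shape

m = ExampleModel()
shape = ["batch_size", 8, "past_sequence_length", 64]
print(m.make_kv_value_cache_shape(0, shape))  # ['batch_size', 8, 'past_sliding_length', 64]
print(m.make_kv_value_cache_shape(1, shape))  # ['batch_size', 8, 'past_sequence_length', 64]

Note that the hasattr guard means a model without an is_local() method simply keeps its original shapes.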
@@ -674,39 +674,23 @@ def make_inputs_and_outputs(self):
 
         # Add KV cache to inputs and outputs
         for i in range(self.num_layers):
-            # Use alternating dimension names if needed (for TensorRT with alternating attention)
-            if self.use_alternating_kv_dimensions():
-                layer_types = self.get_layer_types()
-                layer_type = layer_types[i] if layer_types and i < len(layer_types) else "full_attention"
-
-                # Use dimension name based on attention type
-                if layer_type == "sliding_attention":
-                    dim_suffix = "_sliding"
-                else:  # "full_attention"
-                    dim_suffix = "_full"
-
-                past_key_shape = ["batch_size", self.num_kv_heads, f"past_sequence_length{dim_suffix}", self.head_size]
-                past_value_shape = ["batch_size", self.num_kv_heads, f"past_sequence_length{dim_suffix}", self.head_size]
-                present_key_shape = ["batch_size", self.num_kv_heads, f"total_sequence_length{dim_suffix}", self.head_size]
-                present_value_shape = ["batch_size", self.num_kv_heads, f"total_sequence_length{dim_suffix}", self.head_size]
-            else:
-                # Use standard dimension names (current behavior)
-                past_key_shape = self.input_shapes["past_key_values.key"]
-                past_value_shape = self.input_shapes["past_key_values.value"]
-                present_key_shape = self.output_shapes["present.key"]
-                present_value_shape = self.output_shapes["present.value"]
-
             # Add KV cache to inputs
             key_name = f"past_key_values.{i}.key"
-            inputs.append(self.make_value(key_name, dtype=self.input_types["past_key_values.key"], shape=past_key_shape))
+            key_shape = self.make_kv_value_cache_shape(i, self.input_shapes["past_key_values.key"])
+            inputs.append(self.make_value(key_name, dtype=self.input_types["past_key_values.key"], shape=key_shape))
+
             value_name = f"past_key_values.{i}.value"
-            inputs.append(self.make_value(value_name, dtype=self.input_types["past_key_values.value"], shape=past_value_shape))
+            value_shape = self.make_kv_value_cache_shape(i, self.input_shapes["past_key_values.value"])
+            inputs.append(self.make_value(value_name, dtype=self.input_types["past_key_values.value"], shape=value_shape))
 
             # Add KV cache to outputs
             key_name = f"present.{i}.key"
-            outputs.append(self.make_value(key_name, dtype=self.output_types["present.key"], shape=present_key_shape))
+            key_shape = self.make_kv_value_cache_shape(i, self.output_shapes["present.key"])
+            outputs.append(self.make_value(key_name, dtype=self.output_types["present.key"], shape=key_shape))
+
             value_name = f"present.{i}.value"
-            outputs.append(self.make_value(value_name, dtype=self.output_types["present.value"], shape=present_value_shape))
+            value_shape = self.make_kv_value_cache_shape(i, self.output_shapes["present.value"])
+            outputs.append(self.make_value(value_name, dtype=self.output_types["present.value"], shape=value_shape))
 
     def make_constant(self, name):
         # Make constant ops for 0, 1, 2, 3, etc.
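
Under the same assumptions as the sketch above (hypothetical two-layer model, layer 0 sliding, layer 1 full attention), the per-layer KV cache values end up with the following dynamic dimension names; the base names past_sequence_length and total_sequence_length come from the builder's existing input and output shapes:

# Sketch, not actual tool output: per-layer dimension names after make_kv_value_cache_shape().
def dim_name(base, is_sliding):
    return base.replace("sequence", "sliding") if is_sliding else base

for layer_id, is_sliding in [(0, True), (1, False)]:
    print(f"past_key_values.{layer_id}.key:", dim_name("past_sequence_length", is_sliding))
    print(f"present.{layer_id}.key:", dim_name("total_sequence_length", is_sliding))
# past_key_values.0.key: past_sliding_length
# present.0.key: total_sliding_length
# past_key_values.1.key: past_sequence_length
# present.1.key: total_sequence_length

Sliding layers therefore get distinct dynamic axes in the exported graph, while full-attention layers keep the existing sequence-length axes.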