
Commit b9ee0cb

Simplify layer-wise KV logic - compute layer indices inline, remove get_layer_types()
1 parent 6fbb95f commit b9ee0cb

File tree

1 file changed: +13 -34 lines changed

1 file changed

+13
-34
lines changed

src/python/py/models/builder.py

Lines changed: 13 additions & 34 deletions
@@ -463,14 +463,15 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
         }

         if self.ep == "trt-rtx" and self.window_size is not None and self.window_size > 0:
-            genai_config["model"]["decoder"]["sliding_window"] = {"window_size": self.window_size, "slide_key_value_cache": False, "slide_inputs": False}
+            # Compute layer indices that use sliding window attention
+            layer_idxs = [layer_id for layer_id in range(self.num_layers) if hasattr(self, "is_local") and self.is_local(layer_id)]

-            # Add layer indices for sliding window layers if model has alternating attention patterns
-            layer_types = self.get_layer_types()
-            if layer_types is not None:
-                # Export list of layer indices that use sliding window attention
-                sliding_layers = [i for i, lt in enumerate(layer_types) if lt == "sliding_attention"]
-                genai_config["model"]["decoder"]["sliding_window"]["layers"] = sliding_layers
+            genai_config["model"]["decoder"]["sliding_window"] = {
+                "window_size": self.window_size,
+                "slide_key_value_cache": False,
+                "slide_inputs": False,
+                "layers": layer_idxs
+            }

         if self.ep != "cpu":
             ep_name = self.ep.replace("trt-rtx", "NvTensorRtRtx")
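Note: the new path derives the "layers" list inline from self.is_local(). Below is a minimal sketch of the resulting sliding_window block, assuming a hypothetical 4-layer model with window_size 4096 whose is_local() marks odd layers as sliding-window; the stand-in class, the values, and the alternation rule are illustrative, not taken from builder.py.

class _FakeModel:
    # Hypothetical stand-in exposing only the attributes the new code relies on.
    num_layers = 4
    window_size = 4096

    def is_local(self, layer_id):
        # Assumed alternating pattern: odd layers use sliding-window attention.
        return layer_id % 2 == 1

model = _FakeModel()

# Mirrors the inline computation added in this commit.
layer_idxs = [layer_id for layer_id in range(model.num_layers)
              if hasattr(model, "is_local") and model.is_local(layer_id)]

sliding_window = {
    "window_size": model.window_size,
    "slide_key_value_cache": False,
    "slide_inputs": False,
    "layers": layer_idxs,
}
print(sliding_window)
# {'window_size': 4096, 'slide_key_value_cache': False, 'slide_inputs': False, 'layers': [1, 3]}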
@@ -481,15 +482,7 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
         with open(os.path.join(out_dir,"genai_config.json"), "w") as f:
             json.dump(genai_config, f, indent=4)

-    def get_layer_types(self):
-        """
-        Returns a list of attention types for each layer.
-        Override in subclasses to provide layer-specific attention patterns.
-        Returns None for models with uniform attention across all layers.
-        """
-        return None
-
-    def make_kv_value_cache_shape(self, layer_id, shape):
+    def make_key_value_cache_shape(self, layer_id, shape):
         """
         Modifies KV cache shape dimension names for models with alternating attention patterns.
         For TensorRT EP with sliding window layers, replaces 'sequence' with 'sliding' in dimension name.
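The hunk above is truncated before the renamed method's body. As a rough sketch only (not the file's actual implementation), a dimension rename along the lines the docstring describes could look like the following, assuming shape is a list of symbolic dimension names such as "past_sequence_length" and that is_local() flags sliding-window layers:

    def make_key_value_cache_shape(self, layer_id, shape):
        # Sketch: swap 'sequence' for 'sliding' in the symbolic dimension names of
        # sliding-window layers so they get their own dynamic dimension on trt-rtx.
        if self.ep == "trt-rtx" and hasattr(self, "is_local") and self.is_local(layer_id):
            return [dim.replace("sequence", "sliding") if isinstance(dim, str) else dim
                    for dim in shape]
        return shape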
@@ -676,20 +669,20 @@ def make_inputs_and_outputs(self):
         for i in range(self.num_layers):
             # Add KV cache to inputs
             key_name = f"past_key_values.{i}.key"
-            key_shape = self.make_kv_value_cache_shape(i, self.input_shapes["past_key_values.key"])
+            key_shape = self.make_key_value_cache_shape(i, self.input_shapes["past_key_values.key"])
             inputs.append(self.make_value(key_name, dtype=self.input_types["past_key_values.key"], shape=key_shape))

             value_name = f"past_key_values.{i}.value"
-            value_shape = self.make_kv_value_cache_shape(i, self.input_shapes["past_key_values.value"])
+            value_shape = self.make_key_value_cache_shape(i, self.input_shapes["past_key_values.value"])
             inputs.append(self.make_value(value_name, dtype=self.input_types["past_key_values.value"], shape=value_shape))

             # Add KV cache to outputs
             key_name = f"present.{i}.key"
-            key_shape = self.make_kv_value_cache_shape(i, self.output_shapes["present.key"])
+            key_shape = self.make_key_value_cache_shape(i, self.output_shapes["present.key"])
             outputs.append(self.make_value(key_name, dtype=self.output_types["present.key"], shape=key_shape))

             value_name = f"present.{i}.value"
-            value_shape = self.make_kv_value_cache_shape(i, self.output_shapes["present.value"])
+            value_shape = self.make_key_value_cache_shape(i, self.output_shapes["present.value"])
             outputs.append(self.make_value(value_name, dtype=self.output_types["present.value"], shape=value_shape))

     def make_constant(self, name):
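For reference, the loop above emits one key/value pair of graph inputs (past_key_values.{i}.*) and outputs (present.{i}.*) per layer. A tiny illustration of the resulting names for a hypothetical num_layers = 2:

num_layers = 2  # illustrative value
inputs = [f"past_key_values.{i}.{kv}" for i in range(num_layers) for kv in ("key", "value")]
outputs = [f"present.{i}.{kv}" for i in range(num_layers) for kv in ("key", "value")]
print(inputs)   # ['past_key_values.0.key', 'past_key_values.0.value', 'past_key_values.1.key', 'past_key_values.1.value']
print(outputs)  # ['present.0.key', 'present.0.value', 'present.1.key', 'present.1.value']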
@@ -3485,20 +3478,6 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
         super().make_attention(layer_id, attention, root_input, **kwargs)
         self.window_size = original_window_size

-    def get_layer_types(self):
-        """
-        Gemma2 uses alternating attention patterns:
-        - Even layers (0, 2, 4, ...): full_attention
-        - Odd layers (1, 3, 5, ...): sliding_attention
-        """
-        layer_types = []
-        for layer_id in range(self.num_layers):
-            if self.is_local(layer_id):
-                layer_types.append("sliding_attention")
-            else:
-                layer_types.append("full_attention")
-        return layer_types
-

 class Phi3MiniModel(MistralModel):
     def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
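The deleted Gemma2 override only restated the pattern its docstring described (even layers use full attention, odd layers use sliding attention), which is why the simplified config path can rely on is_local() alone. A sketch of an is_local() consistent with that documented pattern follows; it is not necessarily the method as defined elsewhere in builder.py.

    def is_local(self, layer_id):
        # Assumed from the removed docstring: even layers -> full attention,
        # odd layers -> sliding-window attention.
        return layer_id % 2 == 1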
