
Commit 6fbb95f

refactor: Simplify builder.py per review - use layer indices instead of type array
1 parent 6d5fd84 commit 6fbb95f

File tree: 4 files changed (+51 -45 lines)

  src/config.cpp
  src/config.h
  src/models/kv_cache.cpp
  src/python/py/models/builder.py

src/config.cpp

Lines changed: 18 additions & 3 deletions

@@ -364,6 +364,17 @@ struct StringArray_Element : JSON::Element {
   std::vector<std::string>& v_;
 };
 
+struct IntArray_Element : JSON::Element {
+  explicit IntArray_Element(std::vector<int>& v) : v_{v} {}
+
+  void OnValue(std::string_view name, JSON::Value value) override {
+    v_.push_back(static_cast<int>(JSON::Get<double>(value)));
+  }
+
+ private:
+  std::vector<int>& v_;
+};
+
 struct StringStringMap_Element : JSON::Element {
   explicit StringStringMap_Element(std::unordered_map<std::string, std::string>& v) : v_{v} {}
 
@@ -469,15 +480,19 @@ struct SlidingWindow_Element : JSON::Element {
   }
 
   Element& OnArray(std::string_view name) override {
-    if (name == "layer_types") {
-      return layer_types_;
+    if (name == "layers") {
+      // Lazy initialize layers_ when first accessed
+      if (!layers_) {
+        layers_ = std::make_unique<IntArray_Element>(v_->layers);
+      }
+      return *layers_;
     }
     throw JSON::unknown_value_error{};
   }
 
  private:
   std::optional<Config::Model::Decoder::SlidingWindow>& v_;
-  StringArray_Element layer_types_{v_->layer_types};
+  std::unique_ptr<IntArray_Element> layers_;
 };
 
 struct Encoder_Element : JSON::Element {
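
For readers less familiar with the element-based JSON parser used here, the following is a minimal Python analogue of what the new IntArray_Element does (the class and callback names are illustrative, not the real JSON::Element API): the parser hands each array entry to the element as a number, and the element truncates it to an int and appends it to the target vector.

# Minimal Python analogue of IntArray_Element (illustrative only, not the
# real JSON::Element API): collect numeric array entries into a list of ints.
class IntArrayElement:
    def __init__(self, target):
        # Mirrors the std::vector<int>& reference held by the C++ struct.
        self.target = target

    def on_value(self, name, value):
        # JSON numbers arrive as doubles; truncate to int, like
        # static_cast<int>(JSON::Get<double>(value)) above.
        self.target.append(int(value))

layers = []
element = IntArrayElement(layers)
for v in (0.0, 2.0, 4.0):  # e.g. entries of a "layers": [0, 2, 4] array
    element.on_value("layers", v)
print(layers)  # [0, 2, 4]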

src/config.h

Lines changed: 1 addition & 1 deletion

@@ -200,7 +200,7 @@ struct Config {
       std::string alignment{"right"};        // The alignment of the window, either "left" or "right"
       bool slide_key_value_cache{true};      // Whether to slide the key-value cache along with the input prompt
       bool slide_inputs{true};               // Whether to slide the input prompt along with the key-value cache
-      std::vector<std::string> layer_types;  // Layer-specific attention types: "full_attention" or "sliding_attention"
+      std::vector<int> layers;               // Layer indices that use sliding window attention (for models with alternating patterns)
     };
     std::optional<SlidingWindow> sliding_window;
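
Taken together, the config.cpp and config.h changes alter the shape of the sliding_window section in genai_config.json. The sketch below shows that section before and after this commit; the window size and the 4-layer alternating pattern are hypothetical values chosen for illustration.

# Hypothetical 4-layer model where layers 0 and 2 use sliding window
# attention; all concrete values here are for illustration only.

# Before: per-layer attention types exported as strings.
sliding_window_before = {
    "window_size": 512,
    "slide_key_value_cache": False,
    "slide_inputs": False,
    "layer_types": ["sliding_attention", "full_attention",
                    "sliding_attention", "full_attention"],
}

# After: only the indices of the sliding window layers are exported.
sliding_window_after = {
    "window_size": 512,
    "slide_key_value_cache": False,
    "slide_inputs": False,
    "layers": [0, 2],
}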

src/models/kv_cache.cpp

Lines changed: 12 additions & 5 deletions

@@ -6,6 +6,7 @@
 #include "kv_cache.h"
 #include "windowed_kv_cache.h"
 #include "../openvino/interface.h"
+#include <unordered_set>
 
 namespace Generators {
 
@@ -179,21 +180,27 @@ DefaultKeyValueCache::DefaultKeyValueCache(State& state)
   if (state.model_.p_device_->GetType() == DeviceType::NvTensorRtRtx &&
       model_.config_->model.decoder.sliding_window.has_value() &&
       model_.config_->model.decoder.sliding_window->window_size > 0 &&
-      !model_.config_->model.decoder.sliding_window->layer_types.empty()) {
-    // Use per-layer allocation based on layer_types
+      !model_.config_->model.decoder.sliding_window->layers.empty()) {
+    // Use per-layer allocation based on sliding window layer indices
     use_layer_types_ = true;
     layer_shapes_.resize(layer_count_);
 
     int sliding_window_size = model_.config_->model.decoder.sliding_window->window_size;
     int max_length = state_.params_->search.max_length;
 
+    // Create a set of sliding window layer indices for fast lookup
+    std::unordered_set<int> sliding_layers(
+        model_.config_->model.decoder.sliding_window->layers.begin(),
+        model_.config_->model.decoder.sliding_window->layers.end());
+
     for (int layer_idx = 0; layer_idx < layer_count_; ++layer_idx) {
       layer_shapes_[layer_idx] = shape_;  // Copy base shape
 
-      const std::string& layer_type = model_.config_->model.decoder.sliding_window->layer_types[layer_idx];
-      if (layer_type == "sliding_attention") {
+      if (sliding_layers.count(layer_idx) > 0) {
+        // Sliding window layer
        layer_shapes_[layer_idx][2] = std::min(max_length, sliding_window_size);
-      } else {  // "full_attention"
+      } else {
+        // Full attention layer
         layer_shapes_[layer_idx][2] = max_length;
       }
     }
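
The per-layer allocation above reduces to a simple rule: a sliding window layer gets a cache of min(max_length, window_size) tokens along the sequence dimension, while a full attention layer gets max_length. The following is a minimal Python sketch of that rule, using hypothetical shape values rather than the real runtime state.

# Minimal sketch of the per-layer KV cache shape selection shown above.
# All concrete numbers are hypothetical, not taken from a real model.
def compute_layer_shapes(base_shape, layer_count, sliding_layers, window_size, max_length):
    """Return one KV cache shape per layer; index 2 is the sequence dimension."""
    sliding = set(sliding_layers)  # fast lookup, like the unordered_set above
    shapes = []
    for layer_idx in range(layer_count):
        shape = list(base_shape)  # copy base shape
        if layer_idx in sliding:
            # Sliding window layer: cache at most window_size tokens
            shape[2] = min(max_length, window_size)
        else:
            # Full attention layer: cache up to max_length tokens
            shape[2] = max_length
        shapes.append(shape)
    return shapes

# Example: batch=1, 8 KV heads, head_size=64, layers 0 and 2 slide.
print(compute_layer_shapes([1, 8, 0, 64], 4, [0, 2], window_size=512, max_length=2048))
# [[1, 8, 512, 64], [1, 8, 2048, 64], [1, 8, 512, 64], [1, 8, 2048, 64]]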

src/python/py/models/builder.py

Lines changed: 20 additions & 36 deletions

@@ -465,10 +465,12 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
         if self.ep == "trt-rtx" and self.window_size is not None and self.window_size > 0:
             genai_config["model"]["decoder"]["sliding_window"] = {"window_size": self.window_size, "slide_key_value_cache": False, "slide_inputs": False}
 
-            # Add layer-specific attention types if model has alternating attention patterns
+            # Add layer indices for sliding window layers if model has alternating attention patterns
             layer_types = self.get_layer_types()
             if layer_types is not None:
-                genai_config["model"]["decoder"]["sliding_window"]["layer_types"] = layer_types
+                # Export list of layer indices that use sliding window attention
+                sliding_layers = [i for i, lt in enumerate(layer_types) if lt == "sliding_attention"]
+                genai_config["model"]["decoder"]["sliding_window"]["layers"] = sliding_layers
 
         if self.ep != "cpu":
             ep_name = self.ep.replace("trt-rtx", "NvTensorRtRtx")

@@ -487,16 +489,14 @@ def get_layer_types(self):
         """
         return None
 
-    def use_alternating_kv_dimensions(self):
+    def make_kv_value_cache_shape(self, layer_id, shape):
         """
-        Returns True if this model needs alternating KV cache dimension names.
-        This is needed for models with alternating attention patterns when using TensorRT.
+        Modifies KV cache shape dimension names for models with alternating attention patterns.
+        For TensorRT EP with sliding window layers, replaces 'sequence' with 'sliding' in dimension name.
         """
-        # Enable for models with layer_types when using TensorRT EP
-        if self.ep == "trt-rtx" and hasattr(self, 'get_layer_types'):
-            layer_types = self.get_layer_types()
-            return layer_types is not None
-        return False
+        if self.ep == "trt-rtx" and hasattr(self, "is_local") and self.is_local(layer_id):
+            return [shape[0], shape[1], shape[2].replace("sequence", "sliding"), shape[3]]
+        return shape
 
     def save_processing(self, model_name_or_path, extra_kwargs, out_dir):
         tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token=self.hf_token, trust_remote_code=self.hf_remote, **extra_kwargs)

@@ -674,39 +674,23 @@ def make_inputs_and_outputs(self):
 
         # Add KV cache to inputs and outputs
         for i in range(self.num_layers):
-            # Use alternating dimension names if needed (for TensorRT with alternating attention)
-            if self.use_alternating_kv_dimensions():
-                layer_types = self.get_layer_types()
-                layer_type = layer_types[i] if layer_types and i < len(layer_types) else "full_attention"
-
-                # Use dimension name based on attention type
-                if layer_type == "sliding_attention":
-                    dim_suffix = "_sliding"
-                else:  # "full_attention"
-                    dim_suffix = "_full"
-
-                past_key_shape = ["batch_size", self.num_kv_heads, f"past_sequence_length{dim_suffix}", self.head_size]
-                past_value_shape = ["batch_size", self.num_kv_heads, f"past_sequence_length{dim_suffix}", self.head_size]
-                present_key_shape = ["batch_size", self.num_kv_heads, f"total_sequence_length{dim_suffix}", self.head_size]
-                present_value_shape = ["batch_size", self.num_kv_heads, f"total_sequence_length{dim_suffix}", self.head_size]
-            else:
-                # Use standard dimension names (current behavior)
-                past_key_shape = self.input_shapes["past_key_values.key"]
-                past_value_shape = self.input_shapes["past_key_values.value"]
-                present_key_shape = self.output_shapes["present.key"]
-                present_value_shape = self.output_shapes["present.value"]
-
             # Add KV cache to inputs
             key_name = f"past_key_values.{i}.key"
-            inputs.append(self.make_value(key_name, dtype=self.input_types["past_key_values.key"], shape=past_key_shape))
+            key_shape = self.make_kv_value_cache_shape(i, self.input_shapes["past_key_values.key"])
+            inputs.append(self.make_value(key_name, dtype=self.input_types["past_key_values.key"], shape=key_shape))
+
             value_name = f"past_key_values.{i}.value"
-            inputs.append(self.make_value(value_name, dtype=self.input_types["past_key_values.value"], shape=past_value_shape))
+            value_shape = self.make_kv_value_cache_shape(i, self.input_shapes["past_key_values.value"])
+            inputs.append(self.make_value(value_name, dtype=self.input_types["past_key_values.value"], shape=value_shape))
 
             # Add KV cache to outputs
             key_name = f"present.{i}.key"
-            outputs.append(self.make_value(key_name, dtype=self.output_types["present.key"], shape=present_key_shape))
+            key_shape = self.make_kv_value_cache_shape(i, self.output_shapes["present.key"])
+            outputs.append(self.make_value(key_name, dtype=self.output_types["present.key"], shape=key_shape))
+
             value_name = f"present.{i}.value"
-            outputs.append(self.make_value(value_name, dtype=self.output_types["present.value"], shape=present_value_shape))
+            value_shape = self.make_kv_value_cache_shape(i, self.output_shapes["present.value"])
+            outputs.append(self.make_value(value_name, dtype=self.output_types["present.value"], shape=value_shape))
 
     def make_constant(self, name):
         # Make constant ops for 0, 1, 2, 3, etc.
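
To see the effect of the new make_kv_value_cache_shape hook on the exported ONNX dimension names, here is a self-contained sketch. The SketchBuilder class and its is_local helper are simplified stand-ins rather than the real builder API; only the renaming logic is taken from the diff above.

# Standalone sketch of the renaming done by make_kv_value_cache_shape.
# SketchBuilder and is_local are simplified stand-ins, not the real builder.
class SketchBuilder:
    def __init__(self, ep, sliding_layers):
        self.ep = ep
        self._sliding_layers = set(sliding_layers)

    def is_local(self, layer_id):
        # Hypothetical helper: True for layers using sliding window attention.
        return layer_id in self._sliding_layers

    def make_kv_value_cache_shape(self, layer_id, shape):
        # Same logic as the builder.py change: rename the sequence dimension
        # for sliding window layers when targeting the TensorRT EP.
        if self.ep == "trt-rtx" and hasattr(self, "is_local") and self.is_local(layer_id):
            return [shape[0], shape[1], shape[2].replace("sequence", "sliding"), shape[3]]
        return shape

b = SketchBuilder(ep="trt-rtx", sliding_layers=[0, 2])
base = ["batch_size", 8, "past_sequence_length", 64]
print(b.make_kv_value_cache_shape(0, base))  # ['batch_size', 8, 'past_sliding_length', 64]
print(b.make_kv_value_cache_shape(1, base))  # ['batch_size', 8, 'past_sequence_length', 64]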
