
Commit d4eabac

Layer-wise KV Cache Allocation for Models with Alternating Attention Patterns (#1832)
## 📋 Problem Statement

### Background
Gemma2 models use an alternating attention pattern:
- **Even layers (0, 2, 4, ...)**: Full attention (requires full context, e.g., 8K tokens)
- **Odd layers (1, 3, 5, ...)**: Sliding window attention (only needs 4K tokens)

PR #1523 applied a **uniform** sliding window size (4K) to all layers for the NvTensorRtRtx EP. This was incorrect.

## ✅ Solution

### Approach
1. **Export layer-specific attention types** during model building.
2. **Use different dimension names** in ONNX to satisfy TensorRT constraints:
   - `past_sequence_length_full` for full attention layers
   - `past_sequence_length_sliding` for sliding window layers
   - Because the shapes for full attention and sliding window attention differ, TensorRT-RTX requires distinct dimension names.
3. **Implement per-layer KV cache allocation** in the runtime, based on each layer's attention type.
4. **Allocate the optimal amount of memory** for each layer type (Gemma2, for example):
   - Full attention: 8192 tokens
   - Sliding window: 4096 tokens

## Benefits for Multiple Models
This approach enables significant memory optimization for various model architectures with mixed attention patterns (the sketch after this description reproduces the estimates below):

### Gemma2
- **Pattern**: Alternating (every other layer: full, sliding, full, sliding, ...)
- **Memory**: 8K (full) vs 4K (sliding)
- **Savings**: ~25% reduction (13 full + 13 sliding vs 26 full)

### Gemma3-4B
- **Pattern**: Every 6th layer uses global attention (5 sliding + 1 full)
- **Layers**: 34 layers (29 sliding + 5 full)
- **Memory**: 128K (full) vs 1K (sliding)
- **Savings**: **~85% reduction** in KV cache memory requirements

### GPT-OSS
- **Pattern**: Alternating global and sliding window layers
- **Memory**: 131K (full) vs 128 (sliding)
- **Savings**: **~99.9% reduction** for sliding layers, enabling extremely long context windows with a manageable memory footprint
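The per-layer savings quoted above can be reproduced with a small sketch (not part of this PR). The layer counts and window sizes come from the description; the flat "tokens cached per layer" cost model and the helper names are illustrative only:

```python
def kv_tokens(num_layers, full_len, sliding_len, num_sliding_layers):
    """Tokens held in the KV cache across all layers (per key or value tensor)."""
    num_full_layers = num_layers - num_sliding_layers
    return num_full_layers * full_len + num_sliding_layers * sliding_len

def savings(num_layers, full_len, sliding_len, num_sliding_layers):
    uniform = num_layers * full_len  # every layer allocated at the full-attention length
    per_layer = kv_tokens(num_layers, full_len, sliding_len, num_sliding_layers)
    return 1 - per_layer / uniform

# Gemma2: 26 layers, alternating full/sliding, 8K vs 4K tokens
print(f"Gemma2:    {savings(26, 8192, 4096, 13):.0%}")         # ~25%
# Gemma3-4B: 34 layers (29 sliding + 5 full), 128K vs 1K tokens
print(f"Gemma3-4B: {savings(34, 131072, 1024, 29):.0%}")       # ~85%
# GPT-OSS: each sliding layer caches 128 tokens instead of 131072
print(f"GPT-OSS (per sliding layer): {1 - 128 / 131072:.1%}")  # ~99.9%
```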
1 parent 26f5fd2 commit d4eabac

File tree

7 files changed: +213 −40 lines changed

benchmark/python/benchmark_e2e.py

Lines changed: 3 additions & 0 deletions
@@ -247,6 +247,9 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
         args.chat_template = '<s>{input}'
     elif model_type.startswith("qwen2"):
         args.chat_template = '<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n'
+    elif model_type.startswith("gemma"):
+        # Gemma and Gemma2 models use this format
+        args.chat_template = '<start_of_turn>user\n{input}<end_of_turn>\n<start_of_turn>model\n'
     else:
         raise ValueError(f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template")
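As a quick sanity check (not part of the diff), the new template expands as follows, assuming the benchmark substitutes `{input}` the same way it does for the other chat templates in this file:

```python
chat_template = '<start_of_turn>user\n{input}<end_of_turn>\n<start_of_turn>model\n'
print(chat_template.replace('{input}', 'Tell me a joke.'))
# <start_of_turn>user
# Tell me a joke.<end_of_turn>
# <start_of_turn>model
```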

examples/python/model-qa.py

Lines changed: 19 additions & 3 deletions
@@ -146,10 +146,26 @@ def main(args):

     generator = og.Generator(model, params)
     if args.verbose: print("Generator created")
-    if guidance_type == "json_schema" or guidance_type == "lark_grammar":
-        messages = f"""[{{"role": "system", "content": "{system_prompt}", "tools": "{prompt_tool_input}"}}, {{"role": "user", "content": "{text}"}}]"""
+
+    # Create messages with proper JSON encoding
+    # Gemma2 models don't support system role, so we prepend system prompt to user message
+    if model.type == "gemma2":
+        combined_message = f"{system_prompt}\n\n{text}" if system_prompt else text
+        messages_list = [{"role": "user", "content": combined_message}]
+    elif guidance_type == "json_schema" or guidance_type == "lark_grammar":
+        messages_list = [
+            {"role": "system", "content": system_prompt, "tools": prompt_tool_input},
+            {"role": "user", "content": text}
+        ]
     else:
-        messages = f"""[{{"role": "system", "content": "{system_prompt}"}}, {{"role": "user", "content": "{text}"}}]"""
+        messages_list = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": text}
+        ]
+
+    # Convert to JSON string for tokenizer
+    messages = json.dumps(messages_list)
+
     # Apply Chat Template
     if model.type == "marian-ssru":
         prompt = text
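The switch from hand-built f-strings to `json.dumps` matters because a system prompt or user text containing quotes or newlines previously produced an invalid JSON message list. A minimal sketch of the failure mode and the fix (the example strings are made up):

```python
import json

system_prompt = 'You are a "helpful" assistant.\nBe brief.'
text = 'What is 2+2?'

# Old approach: hand-built JSON via an f-string; inner quotes and newlines are not escaped
broken = f"""[{{"role": "system", "content": "{system_prompt}"}}, {{"role": "user", "content": "{text}"}}]"""
try:
    json.loads(broken)
except json.JSONDecodeError as e:
    print("f-string message list is not valid JSON:", e)

# New approach: build a Python list and let json.dumps handle escaping
messages = json.dumps([
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": text},
])
print(json.loads(messages)[0]["content"])  # round-trips cleanly
```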

src/config.cpp

Lines changed: 23 additions & 0 deletions
@@ -359,6 +359,17 @@ struct StringArray_Element : JSON::Element {
   std::vector<std::string>& v_;
 };

+struct IntArray_Element : JSON::Element {
+  explicit IntArray_Element(std::vector<int>& v) : v_{v} {}
+
+  void OnValue(std::string_view name, JSON::Value value) override {
+    v_.push_back(static_cast<int>(JSON::Get<double>(value)));
+  }
+
+ private:
+  std::vector<int>& v_;
+};
+
 struct StringStringMap_Element : JSON::Element {
   explicit StringStringMap_Element(std::unordered_map<std::string, std::string>& v) : v_{v} {}

@@ -470,8 +481,20 @@ struct SlidingWindow_Element : JSON::Element {
     }
   }

+  Element& OnArray(std::string_view name) override {
+    if (name == "layers") {
+      // Lazy initialize layers_ when first accessed
+      if (!layers_) {
+        layers_ = std::make_unique<IntArray_Element>(v_->layers);
+      }
+      return *layers_;
+    }
+    throw JSON::unknown_value_error{};
+  }
+
  private:
   std::optional<Config::Model::Decoder::SlidingWindow>& v_;
+  std::unique_ptr<IntArray_Element> layers_;
 };

 struct Encoder_Element : JSON::Element {

src/config.h

Lines changed: 1 addition & 0 deletions
@@ -207,6 +207,7 @@ struct Config {
       std::string alignment{"right"};     // The alignment of the window, either "left" or "right"
       bool slide_key_value_cache{true};   // Whether to slide the key-value cache along with the input prompt
       bool slide_inputs{true};            // Whether to slide the input prompt along with the key-value cache
+      std::vector<int> layers;            // Layer indices that use sliding window attention (for models with alternating patterns)
     };
     std::optional<SlidingWindow> sliding_window;
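Together with the `IntArray_Element` parser above, this field lets `genai_config.json` describe which layers slide. An illustrative fragment (window size and layer indices assume a Gemma2-style 26-layer model where odd layers use sliding window attention; the real values are written by the model builder):

```python
import json

config = {
    "model": {
        "decoder": {
            "sliding_window": {
                "window_size": 4096,
                "slide_key_value_cache": False,
                "slide_inputs": False,
                "layers": list(range(1, 26, 2)),  # [1, 3, 5, ..., 25]
            }
        }
    }
}
print(json.dumps(config, indent=4))
```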

src/models/kv_cache.cpp

Lines changed: 136 additions & 32 deletions
@@ -6,6 +6,7 @@
 #include "kv_cache.h"
 #include "windowed_kv_cache.h"
 #include "../openvino/interface.h"
+#include <algorithm>

 namespace Generators {

@@ -175,21 +176,49 @@ DefaultKeyValueCache::DefaultKeyValueCache(State& state)
   }

   // Set the size after empty_past_ has been created with 0 for this field
-  if (state.model_.p_device_->GetType() == DeviceType::NvTensorRtRtx &&
-      model_.config_->model.decoder.sliding_window.has_value() &&
+  if (model_.config_->model.decoder.sliding_window.has_value() &&
      model_.config_->model.decoder.sliding_window->window_size > 0) {
-    shape_[2] = std::min(state_.params_->search.max_length,
-                         model_.config_->model.decoder.sliding_window->window_size);
+    const int sliding_window_size = model_.config_->model.decoder.sliding_window->window_size;
+    const int max_length = state_.params_->search.max_length;
+
+    // Check if we need per-layer allocation for models with alternating attention patterns
+    if (!model_.config_->model.decoder.sliding_window->layers.empty()) {
+      // Use per-layer allocation based on sliding window layer indices
+      layer_shapes_.resize(layer_count_);
+
+      // Initialize all layers with base shape and max_length
+      for (int layer_idx = 0; layer_idx < layer_count_; ++layer_idx) {
+        layer_shapes_[layer_idx] = shape_;
+        layer_shapes_[layer_idx][2] = max_length;
+      }
+
+      // Update sliding window layers with constrained cache size
+      for (int layer_idx : model_.config_->model.decoder.sliding_window->layers) {
+        layer_shapes_[layer_idx][2] = std::min(max_length, sliding_window_size);
+      }
+      // Set shape_[2] to max of all layer shapes for RewindTo bounds checking
+      shape_[2] = max_length;
+    } else {
+      // Uniform sliding window allocation (backward compatibility)
+      shape_[2] = std::min(max_length, sliding_window_size);
+    }
   } else if (past_present_share_buffer_) {
     shape_[2] = state_.params_->search.max_length;
   }

   try {
+    // Allocate KV cache tensors - 2 per layer (key and value)
+    // For per-layer shapes: alternates between key and value for each layer
+    // For uniform shape: all tensors use the same shape
     for (int i = 0; i < layer_count_ * 2; ++i) {
-      presents_.push_back(OrtValue::CreateTensor(Allocator(), shape_, type_));
+      std::array<int64_t, 4> tensor_shape = shape_;
+      if (!layer_shapes_.empty()) {
+        // Per-layer allocation: use layer-specific shape
+        // i/2 gives us the layer index since we have 2 tensors per layer
+        tensor_shape = layer_shapes_[i / 2];
+      }

-      // Zero the memory so we don't leak any data from the previous run
-      // WebGPU device has no Zero() implementation yet. Since this zeroing is optional we disable it for WebGPU for now
+      presents_.push_back(OrtValue::CreateTensor(Allocator(), tensor_shape, type_));
       if (Device().GetType() != DeviceType::WEBGPU) {
         ByteWrapTensor(Device(), *presents_.back()).Zero();
       }
@@ -240,10 +269,30 @@ void DefaultKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, int total_length) {
     }
   }

-  shape_[2] = total_length;
-  for (int i = 0; i < layer_count_ * 2; i++) {
-    presents_[i] = OrtValue::CreateTensor(Allocator(), shape_, type_);
-    state_.outputs_[output_index_ + i] = presents_[i].get();
+  if (!layer_shapes_.empty()) {
+    // Update per-layer shapes based on total_length, but respect max allocations
+    for (int layer_idx = 0; layer_idx < layer_count_; ++layer_idx) {
+      const int max_cache_length = static_cast<int>(layer_shapes_[layer_idx][2]);
+      const int actual_length = std::min(total_length, max_cache_length);
+
+      std::array<int64_t, 4> current_shape = layer_shapes_[layer_idx];
+      current_shape[2] = actual_length;
+
+      // Key tensor
+      presents_[layer_idx * 2] = OrtValue::CreateTensor(Allocator(), current_shape, type_);
+      state_.outputs_[output_index_ + layer_idx * 2] = presents_[layer_idx * 2].get();
+
+      // Value tensor
+      presents_[layer_idx * 2 + 1] = OrtValue::CreateTensor(Allocator(), current_shape, type_);
+      state_.outputs_[output_index_ + layer_idx * 2 + 1] = presents_[layer_idx * 2 + 1].get();
+    }
+  } else {
+    // Uniform shape update (existing behavior)
+    shape_[2] = total_length;
+    for (int i = 0; i < layer_count_ * 2; i++) {
+      presents_[i] = OrtValue::CreateTensor(Allocator(), shape_, type_);
+      state_.outputs_[output_index_ + i] = presents_[i].get();
+    }
   }

   is_first_update_ = false;
@@ -271,39 +320,94 @@ void DefaultKeyValueCache::RewindTo(size_t index) {

 template <typename T>
 void DefaultKeyValueCache::RewindPastTensorsTo(size_t index) {
-  assert(index > 0 && shape_[2] >= static_cast<int64_t>(index) && !past_present_share_buffer_);
-  std::array<int64_t, 4> new_shape = shape_;
-  new_shape[2] = static_cast<int>(index);
-  auto batch_x_num_heads = new_shape[0] * new_shape[1];
-  auto new_length_x_head_size = new_shape[2] * new_shape[3];
-  auto old_length_x_head_size = shape_[2] * new_shape[3];
-  shape_[2] = new_shape[2];
-
-  for (int i = 0; i < layer_count_ * 2; i++) {
-    OrtValue& present = *presents_[i];
-    std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(Allocator(), shape_, type_);
+  assert(index > 0 && !past_present_share_buffer_);
+
+  if (!layer_shapes_.empty()) {
+    // Handle per-layer shapes
+    // First validate that index doesn't exceed the global max_length
+    int max_length = static_cast<int>(shape_[2]);  // Set to max_length in constructor
+    if (static_cast<int>(index) > max_length) {
+      throw std::runtime_error("Requested rewind length exceeds max_length.");
+    }

-    auto past_span = WrapTensor<T>(Device(), *past);
-    auto present_span = WrapTensor<T>(Device(), present);
+    for (int i = 0; i < layer_count_ * 2; i++) {
+      const int layer_idx = i / 2;
+      const std::array<int64_t, 4> layer_shape = layer_shapes_[layer_idx];
+      const int layer_max_cache = static_cast<int>(layer_shape[2]);
+
+      // For each layer, rewind to min(index, layer's max capacity)
+      // - Full attention layers: min(index, max_length)
+      // - Sliding window layers: min(index, sliding_window_size)
+      const int actual_rewind_length = std::min(static_cast<int>(index), layer_max_cache);
+
+      std::array<int64_t, 4> new_shape = layer_shape;
+      new_shape[2] = actual_rewind_length;
+      const auto batch_x_num_heads = new_shape[0] * new_shape[1];
+      const auto new_length_x_head_size = new_shape[2] * new_shape[3];
+
+      OrtValue& present = *presents_[i];
+      const auto present_shape = present.GetTensorTypeAndShapeInfo()->GetShape();
+      const auto old_length_x_head_size = present_shape[2] * new_shape[3];
+
+      std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(Allocator(), new_shape, type_);
+      auto past_span = WrapTensor<T>(Device(), *past);
+      auto present_span = WrapTensor<T>(Device(), present);
+
+      for (int j = 0; j < batch_x_num_heads; j++) {
+        auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size);
+        auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size);
+        past_data.CopyFrom(present_data);
+      }
+      pasts_[i] = std::move(past);
+      state_.inputs_[input_index_ + i] = pasts_[i].get();
+    }
+  } else {
+    // Uniform shape handling (existing behavior)
+    assert(shape_[2] >= static_cast<int64_t>(index));
+    std::array<int64_t, 4> new_shape = shape_;
+    new_shape[2] = static_cast<int>(index);
+    auto batch_x_num_heads = new_shape[0] * new_shape[1];
+    auto new_length_x_head_size = new_shape[2] * new_shape[3];
+    auto old_length_x_head_size = shape_[2] * new_shape[3];
+    shape_[2] = new_shape[2];

-    for (int j = 0; j < batch_x_num_heads; j++) {
-      auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size);
-      auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size);
-      past_data.CopyFrom(present_data);
+    for (int i = 0; i < layer_count_ * 2; i++) {
+      OrtValue& present = *presents_[i];
+      std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(Allocator(), shape_, type_);
+
+      auto past_span = WrapTensor<T>(Device(), *past);
+      auto present_span = WrapTensor<T>(Device(), present);
+
+      for (int j = 0; j < batch_x_num_heads; j++) {
+        auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size);
+        auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size);
+        past_data.CopyFrom(present_data);
+      }
+      pasts_[i] = std::move(past);
+      state_.inputs_[input_index_ + i] = pasts_[i].get();
     }
-    pasts_[i] = std::move(past);
-    state_.inputs_[input_index_ + i] = pasts_[i].get();
   }
 }

 // Copy present state to past state reordered by the beam_indices
 template <typename ScoreType>
 void DefaultKeyValueCache::PickPastState(DeviceSpan<int32_t> beam_indices_device, int index) {
   std::span<int32_t> beam_indices = beam_indices_device.CopyDeviceToCpu();
-  auto block_size_per_beam = shape_[1] * shape_[2] * shape_[3];
+
+  std::array<int64_t, 4> tensor_shape;
+  if (!layer_shapes_.empty()) {
+    // Get shape from the actual tensor for per-layer allocation
+    OrtValue& present_value = *presents_[index];
+    const auto present_shape = present_value.GetTensorTypeAndShapeInfo()->GetShape();
+    std::copy(present_shape.begin(), present_shape.end(), tensor_shape.begin());
+  } else {
+    tensor_shape = shape_;
+  }
+
+  auto block_size_per_beam = tensor_shape[1] * tensor_shape[2] * tensor_shape[3];

   OrtValue& present_value = *presents_[index];
-  std::unique_ptr<OrtValue> past_value = OrtValue::CreateTensor<ScoreType>(Allocator(), shape_);
+  std::unique_ptr<OrtValue> past_value = OrtValue::CreateTensor<ScoreType>(Allocator(), tensor_shape);

   auto past_span = WrapTensor<ScoreType>(Device(), *past_value);
   auto present_span = WrapTensor<ScoreType>(Device(), present_value);
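For reference, the allocation rule in the constructor above reduces to the following Python sketch (illustrative only: the batch/head dimensions and the layer pattern are made up; only the `window_size`, `max_length`, and `layers` inputs mirror the config fields the C++ code reads):

```python
window_size, max_length = 4096, 8192
num_layers = 26
sliding_layers = set(range(1, num_layers, 2))    # Gemma2-style: odd layers slide

base_shape = [1, 4, 0, 256]                      # [batch, kv_heads, length, head_size]
layer_shapes = []
for layer_idx in range(num_layers):
    shape = list(base_shape)
    shape[2] = max_length                        # full-attention default
    if layer_idx in sliding_layers:
        shape[2] = min(max_length, window_size)  # constrained sliding-window cache
    layer_shapes.append(shape)

# Two tensors per layer (key, value): tensor i uses layer_shapes[i // 2]
assert all(layer_shapes[i // 2][2] in (max_length, window_size) for i in range(2 * num_layers))
```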

src/models/kv_cache.h

Lines changed: 3 additions & 0 deletions
@@ -97,6 +97,9 @@ struct DefaultKeyValueCache : KeyValueCache {
   std::array<int64_t, 4> shape_;
   ONNXTensorElementDataType type_;

+  // Support for per-layer KV cache shapes (for models with alternating attention patterns)
+  std::vector<std::array<int64_t, 4>> layer_shapes_;
+
   std::unique_ptr<OrtValue> empty_past_;
   std::vector<std::unique_ptr<OrtValue>> pasts_, presents_;
   std::vector<std::string> input_name_strings_, output_name_strings_;

src/python/py/models/builder.py

Lines changed: 28 additions & 5 deletions
@@ -464,7 +464,15 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
         }

         if self.ep == "trt-rtx" and self.window_size is not None and self.window_size > 0:
-            genai_config["model"]["decoder"]["sliding_window"] = {"window_size": self.window_size, "slide_key_value_cache": False, "slide_inputs": False}
+            # Compute layer indices that use sliding window attention
+            layer_idxs = [layer_id for layer_id in range(self.num_layers) if hasattr(self, "is_local") and self.is_local(layer_id)]
+
+            genai_config["model"]["decoder"]["sliding_window"] = {
+                "window_size": self.window_size,
+                "slide_key_value_cache": False,
+                "slide_inputs": False,
+                "layers": layer_idxs
+            }

         if self.ep != "cpu":
             ep_name = self.ep.replace("trt-rtx", "NvTensorRtRtx")
@@ -475,6 +483,15 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
         with open(os.path.join(out_dir,"genai_config.json"), "w") as f:
             json.dump(genai_config, f, indent=4)

+    def make_key_value_cache_shape(self, layer_id, shape):
+        """
+        Modifies KV cache shape dimension names for models with alternating attention patterns.
+        For TensorRT EP with sliding window layers, replaces 'sequence' with 'sliding' in dimension name.
+        """
+        if self.ep == "trt-rtx" and hasattr(self, "is_local") and self.is_local(layer_id):
+            return [shape[0], shape[1], shape[2].replace("sequence", "sliding"), shape[3]]
+        return shape
+
     def save_processing(self, model_name_or_path, extra_kwargs, out_dir):
         tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token=self.hf_token, trust_remote_code=self.hf_remote, **extra_kwargs)
         print(f"Saving processing files in {out_dir} for GenAI")
@@ -653,15 +670,21 @@ def make_inputs_and_outputs(self):
         for i in range(self.num_layers):
             # Add KV cache to inputs
             key_name = f"past_key_values.{i}.key"
-            inputs.append(self.make_value(key_name, dtype=self.input_types["past_key_values.key"], shape=self.input_shapes["past_key_values.key"]))
+            key_shape = self.make_key_value_cache_shape(i, self.input_shapes["past_key_values.key"])
+            inputs.append(self.make_value(key_name, dtype=self.input_types["past_key_values.key"], shape=key_shape))
+
             value_name = f"past_key_values.{i}.value"
-            inputs.append(self.make_value(value_name, dtype=self.input_types["past_key_values.value"], shape=self.input_shapes["past_key_values.value"]))
+            value_shape = self.make_key_value_cache_shape(i, self.input_shapes["past_key_values.value"])
+            inputs.append(self.make_value(value_name, dtype=self.input_types["past_key_values.value"], shape=value_shape))

             # Add KV cache to outputs
             key_name = f"present.{i}.key"
-            outputs.append(self.make_value(key_name, dtype=self.output_types["present.key"], shape=self.output_shapes["present.key"]))
+            key_shape = self.make_key_value_cache_shape(i, self.output_shapes["present.key"])
+            outputs.append(self.make_value(key_name, dtype=self.output_types["present.key"], shape=key_shape))
+
             value_name = f"present.{i}.value"
-            outputs.append(self.make_value(value_name, dtype=self.output_types["present.value"], shape=self.output_shapes["present.value"]))
+            value_shape = self.make_key_value_cache_shape(i, self.output_shapes["present.value"])
+            outputs.append(self.make_value(value_name, dtype=self.output_types["present.value"], shape=value_shape))

     def make_constant(self, name):
         # Make constant ops for 0, 1, 2, 3, etc.
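A hypothetical illustration of the resulting dimension names (the real `input_shapes` entry is not shown in this diff, so the symbolic shape below is a placeholder; only the `replace("sequence", "sliding")` rewrite comes from the code above):

```python
def make_key_value_cache_shape(is_sliding_layer, shape):
    # Mirrors: self.ep == "trt-rtx" and self.is_local(layer_id)
    if is_sliding_layer:
        return [shape[0], shape[1], shape[2].replace("sequence", "sliding"), shape[3]]
    return shape

shape = ["batch_size", 8, "past_sequence_length", 256]
print(make_key_value_cache_shape(False, shape))  # full-attention layer: unchanged
print(make_key_value_cache_shape(True, shape))   # sliding layer: ['batch_size', 8, 'past_sliding_length', 256]
```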

0 commit comments
