
Commit 6d5fd84

Implement interleaved KV cache management for local and global KV
1 parent 420e000 commit 6d5fd84

7 files changed: +244 −40 lines changed


benchmark/python/benchmark_e2e.py

Lines changed: 3 additions & 0 deletions
@@ -247,6 +247,9 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
         args.chat_template = '<s>{input}'
     elif model_type.startswith("qwen2"):
         args.chat_template = '<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n'
+    elif model_type.startswith("gemma"):
+        # Gemma and Gemma2 models use this format
+        args.chat_template = '<start_of_turn>user\n{input}<end_of_turn>\n<start_of_turn>model\n'
     else:
         raise ValueError(f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template")

examples/python/model-qa.py

Lines changed: 19 additions & 3 deletions
@@ -146,10 +146,26 @@ def main(args):
 
         generator = og.Generator(model, params)
         if args.verbose: print("Generator created")
-        if guidance_type == "json_schema" or guidance_type == "lark_grammar":
-            messages = f"""[{{"role": "system", "content": "{system_prompt}", "tools": "{prompt_tool_input}"}}, {{"role": "user", "content": "{text}"}}]"""
+
+        # Create messages with proper JSON encoding
+        # Gemma2 models don't support system role, so we prepend system prompt to user message
+        if model.type == "gemma2":
+            combined_message = f"{system_prompt}\n\n{text}" if system_prompt else text
+            messages_list = [{"role": "user", "content": combined_message}]
+        elif guidance_type == "json_schema" or guidance_type == "lark_grammar":
+            messages_list = [
+                {"role": "system", "content": system_prompt, "tools": prompt_tool_input},
+                {"role": "user", "content": text}
+            ]
         else:
-            messages = f"""[{{"role": "system", "content": "{system_prompt}"}}, {{"role": "user", "content": "{text}"}}]"""
+            messages_list = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": text}
+            ]
+
+        # Convert to JSON string for tokenizer
+        messages = json.dumps(messages_list)
+
         # Apply Chat Template
         if model.type == "marian-ssru":
             prompt = text
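
For reference, a small standalone sketch (values invented) of the JSON the two branches above produce before the chat template is applied:

```python
import json

system_prompt = "You are a helpful assistant."   # example values, not from the script
text = "What is a KV cache?"

# Gemma2 path: no system role, so the system prompt is folded into the user turn
gemma2_messages = json.dumps([
    {"role": "user", "content": f"{system_prompt}\n\n{text}"},
])

# Default path: separate system and user turns
default_messages = json.dumps([
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": text},
])

print(gemma2_messages)
print(default_messages)
```

Using `json.dumps` instead of hand-built f-strings also keeps quotes and backslashes in the user text properly escaped.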

src/config.cpp

Lines changed: 8 additions & 0 deletions
@@ -468,8 +468,16 @@ struct SlidingWindow_Element : JSON::Element {
     }
   }
 
+  Element& OnArray(std::string_view name) override {
+    if (name == "layer_types") {
+      return layer_types_;
+    }
+    throw JSON::unknown_value_error{};
+  }
+
  private:
   std::optional<Config::Model::Decoder::SlidingWindow>& v_;
+  StringArray_Element layer_types_{v_->layer_types};
 };
 
 struct Encoder_Element : JSON::Element {
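
To show what the new `OnArray` handler consumes, here is a hedged sketch of the corresponding `genai_config.json` fragment, generated with Python for illustration; the window size and layer list are example values only:

```python
import json

# Example sliding_window section with the new layer_types array (values are illustrative).
sliding_window = {
    "window_size": 4096,
    "slide_key_value_cache": False,
    "slide_inputs": False,
    "layer_types": ["full_attention", "sliding_attention"] * 13,  # e.g. 26 alternating layers
}

print(json.dumps({"sliding_window": sliding_window}, indent=4))
```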

src/config.h

Lines changed: 1 addition & 0 deletions
@@ -200,6 +200,7 @@ struct Config {
         std::string alignment{"right"};  // The alignment of the window, either "left" or "right"
         bool slide_key_value_cache{true};  // Whether to slide the key-value cache along with the input prompt
         bool slide_inputs{true};  // Whether to slide the input prompt along with the key-value cache
+        std::vector<std::string> layer_types;  // Layer-specific attention types: "full_attention" or "sliding_attention"
       };
       std::optional<SlidingWindow> sliding_window;

src/models/kv_cache.cpp

Lines changed: 145 additions & 33 deletions
@@ -175,23 +175,64 @@ DefaultKeyValueCache::DefaultKeyValueCache(State& state)
   }
 
   // Set the size after empty_past_ has been created with 0 for this field
+  // Check if we need to use per-layer allocation for models with alternating attention patterns
   if (state.model_.p_device_->GetType() == DeviceType::NvTensorRtRtx &&
       model_.config_->model.decoder.sliding_window.has_value() &&
-      model_.config_->model.decoder.sliding_window->window_size > 0) {
+      model_.config_->model.decoder.sliding_window->window_size > 0 &&
+      !model_.config_->model.decoder.sliding_window->layer_types.empty()) {
+    // Use per-layer allocation based on layer_types
+    use_layer_types_ = true;
+    layer_shapes_.resize(layer_count_);
+
+    int sliding_window_size = model_.config_->model.decoder.sliding_window->window_size;
+    int max_length = state_.params_->search.max_length;
+
+    for (int layer_idx = 0; layer_idx < layer_count_; ++layer_idx) {
+      layer_shapes_[layer_idx] = shape_;  // Copy base shape
+
+      const std::string& layer_type = model_.config_->model.decoder.sliding_window->layer_types[layer_idx];
+      if (layer_type == "sliding_attention") {
+        layer_shapes_[layer_idx][2] = std::min(max_length, sliding_window_size);
+      } else {  // "full_attention"
+        layer_shapes_[layer_idx][2] = max_length;
+      }
+    }
+  } else if (state.model_.p_device_->GetType() == DeviceType::NvTensorRtRtx &&
+             model_.config_->model.decoder.sliding_window.has_value() &&
+             model_.config_->model.decoder.sliding_window->window_size > 0) {
+    // Uniform sliding window allocation (backward compatibility)
     shape_[2] = std::min(state_.params_->search.max_length,
                          model_.config_->model.decoder.sliding_window->window_size);
   } else if (past_present_share_buffer_) {
     shape_[2] = state_.params_->search.max_length;
   }
 
   try {
-    for (int i = 0; i < layer_count_ * 2; ++i) {
-      presents_.push_back(OrtValue::CreateTensor(Allocator(), shape_, type_));
-
-      // Zero the memory so we don't leak any data from the previous run
-      // WebGPU device has no Zero() implementation yet. Since this zeroing is optional we disable it for WebGPU for now
-      if (Device().GetType() != DeviceType::WEBGPU) {
-        ByteWrapTensor(Device(), *presents_.back()).Zero();
+    if (use_layer_types_) {
+      // Allocate per-layer with different shapes
+      for (int layer_idx = 0; layer_idx < layer_count_; ++layer_idx) {
+        // Key tensor
+        presents_.push_back(OrtValue::CreateTensor(Allocator(), layer_shapes_[layer_idx], type_));
+        if (Device().GetType() != DeviceType::WEBGPU) {
+          ByteWrapTensor(Device(), *presents_.back()).Zero();
+        }
+
+        // Value tensor
+        presents_.push_back(OrtValue::CreateTensor(Allocator(), layer_shapes_[layer_idx], type_));
+        if (Device().GetType() != DeviceType::WEBGPU) {
+          ByteWrapTensor(Device(), *presents_.back()).Zero();
+        }
+      }
+    } else {
+      // Uniform allocation (existing behavior)
+      for (int i = 0; i < layer_count_ * 2; ++i) {
+        presents_.push_back(OrtValue::CreateTensor(Allocator(), shape_, type_));
+
+        // Zero the memory so we don't leak any data from the previous run
+        // WebGPU device has no Zero() implementation yet. Since this zeroing is optional we disable it for WebGPU for now
+        if (Device().GetType() != DeviceType::WEBGPU) {
+          ByteWrapTensor(Device(), *presents_.back()).Zero();
+        }
       }
     }
   } catch (const Ort::Exception&) {
@@ -240,10 +281,30 @@ void DefaultKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, int total_length) {
     }
   }
 
-  shape_[2] = total_length;
-  for (int i = 0; i < layer_count_ * 2; i++) {
-    presents_[i] = OrtValue::CreateTensor(Allocator(), shape_, type_);
-    state_.outputs_[output_index_ + i] = presents_[i].get();
+  if (use_layer_types_) {
+    // Update per-layer shapes based on total_length, but respect max allocations
+    for (int layer_idx = 0; layer_idx < layer_count_; ++layer_idx) {
+      int max_cache_length = static_cast<int>(layer_shapes_[layer_idx][2]);
+      int actual_length = std::min(total_length, max_cache_length);
+
+      std::array<int64_t, 4> current_shape = layer_shapes_[layer_idx];
+      current_shape[2] = actual_length;
+
+      // Key tensor
+      presents_[layer_idx * 2] = OrtValue::CreateTensor(Allocator(), current_shape, type_);
+      state_.outputs_[output_index_ + layer_idx * 2] = presents_[layer_idx * 2].get();
+
+      // Value tensor
+      presents_[layer_idx * 2 + 1] = OrtValue::CreateTensor(Allocator(), current_shape, type_);
+      state_.outputs_[output_index_ + layer_idx * 2 + 1] = presents_[layer_idx * 2 + 1].get();
+    }
+  } else {
+    // Uniform shape update (existing behavior)
+    shape_[2] = total_length;
+    for (int i = 0; i < layer_count_ * 2; i++) {
+      presents_[i] = OrtValue::CreateTensor(Allocator(), shape_, type_);
+      state_.outputs_[output_index_ + i] = presents_[i].get();
+    }
   }
 
   is_first_update_ = false;
@@ -271,39 +332,90 @@ void DefaultKeyValueCache::RewindTo(size_t index) {
 
 template <typename T>
 void DefaultKeyValueCache::RewindPastTensorsTo(size_t index) {
-  assert(index > 0 && shape_[2] >= static_cast<int64_t>(index) && !past_present_share_buffer_);
-  std::array<int64_t, 4> new_shape = shape_;
-  new_shape[2] = static_cast<int>(index);
-  auto batch_x_num_heads = new_shape[0] * new_shape[1];
-  auto new_length_x_head_size = new_shape[2] * new_shape[3];
-  auto old_length_x_head_size = shape_[2] * new_shape[3];
-  shape_[2] = new_shape[2];
-
-  for (int i = 0; i < layer_count_ * 2; i++) {
-    OrtValue& present = *presents_[i];
-    std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(Allocator(), shape_, type_);
+  assert(index > 0 && !past_present_share_buffer_);
+
+  if (use_layer_types_) {
+    // Handle per-layer shapes
+    for (int i = 0; i < layer_count_ * 2; i++) {
+      int layer_idx = i / 2;
+      std::array<int64_t, 4> layer_shape = layer_shapes_[layer_idx];
+      int max_cache_length = static_cast<int>(layer_shape[2]);
+
+      // Ensure we don't rewind beyond what's available
+      if (static_cast<int>(index) > max_cache_length) {
+        throw std::runtime_error("Requested rewind length is greater than the layer's cache length.");
+      }
+
+      std::array<int64_t, 4> new_shape = layer_shape;
+      new_shape[2] = static_cast<int>(index);
+      auto batch_x_num_heads = new_shape[0] * new_shape[1];
+      auto new_length_x_head_size = new_shape[2] * new_shape[3];
+
+      OrtValue& present = *presents_[i];
+      auto present_shape = present.GetTensorTypeAndShapeInfo()->GetShape();
+      auto old_length_x_head_size = present_shape[2] * new_shape[3];
+
+      std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(Allocator(), new_shape, type_);
+      auto past_span = WrapTensor<T>(Device(), *past);
+      auto present_span = WrapTensor<T>(Device(), present);
+
+      for (int j = 0; j < batch_x_num_heads; j++) {
+        auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size);
+        auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size);
+        past_data.CopyFrom(present_data);
+      }
+      pasts_[i] = std::move(past);
+      state_.inputs_[input_index_ + i] = pasts_[i].get();
+    }
+  } else {
+    // Uniform shape handling (existing behavior)
+    assert(shape_[2] >= static_cast<int64_t>(index));
+    std::array<int64_t, 4> new_shape = shape_;
+    new_shape[2] = static_cast<int>(index);
+    auto batch_x_num_heads = new_shape[0] * new_shape[1];
+    auto new_length_x_head_size = new_shape[2] * new_shape[3];
+    auto old_length_x_head_size = shape_[2] * new_shape[3];
+    shape_[2] = new_shape[2];
 
-    auto past_span = WrapTensor<T>(Device(), *past);
-    auto present_span = WrapTensor<T>(Device(), present);
+    for (int i = 0; i < layer_count_ * 2; i++) {
+      OrtValue& present = *presents_[i];
+      std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(Allocator(), shape_, type_);
 
-    for (int j = 0; j < batch_x_num_heads; j++) {
-      auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size);
-      auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size);
-      past_data.CopyFrom(present_data);
+      auto past_span = WrapTensor<T>(Device(), *past);
+      auto present_span = WrapTensor<T>(Device(), present);
+
+      for (int j = 0; j < batch_x_num_heads; j++) {
+        auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size);
+        auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size);
+        past_data.CopyFrom(present_data);
+      }
+      pasts_[i] = std::move(past);
+      state_.inputs_[input_index_ + i] = pasts_[i].get();
     }
-    pasts_[i] = std::move(past);
-    state_.inputs_[input_index_ + i] = pasts_[i].get();
   }
 }
 
 // Copy present state to past state reordered by the beam_indices
 template <typename ScoreType>
 void DefaultKeyValueCache::PickPastState(DeviceSpan<int32_t> beam_indices_device, int index) {
   std::span<int32_t> beam_indices = beam_indices_device.CopyDeviceToCpu();
-  auto block_size_per_beam = shape_[1] * shape_[2] * shape_[3];
+
+  std::array<int64_t, 4> tensor_shape;
+  if (use_layer_types_) {
+    // Get shape from the actual tensor for per-layer allocation
+    OrtValue& present_value = *presents_[index];
+    auto present_shape = present_value.GetTensorTypeAndShapeInfo()->GetShape();
+    for (size_t i = 0; i < 4; i++) {
+      tensor_shape[i] = present_shape[i];
+    }
+  } else {
+    tensor_shape = shape_;
+  }
+
+  auto block_size_per_beam = tensor_shape[1] * tensor_shape[2] * tensor_shape[3];
 
   OrtValue& present_value = *presents_[index];
-  std::unique_ptr<OrtValue> past_value = OrtValue::CreateTensor<ScoreType>(Allocator(), shape_);
+  std::unique_ptr<OrtValue> past_value = OrtValue::CreateTensor<ScoreType>(Allocator(), tensor_shape);
 
   auto past_span = WrapTensor<ScoreType>(Device(), *past_value);
   auto present_span = WrapTensor<ScoreType>(Device(), present_value);
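
The allocation rule above is easy to sanity-check outside C++. Here is a minimal Python sketch of the per-layer shape computation; batch size, head counts, lengths, and the layer pattern are example values, and `min(max_length, window_size)` mirrors the sliding branch:

```python
# Sketch of the per-layer KV shape rule used above (illustrative numbers).
batch_size, num_kv_heads, head_size = 1, 4, 256
max_length, window_size = 8192, 4096
layer_types = ["full_attention", "sliding_attention"] * 13  # alternating pattern

layer_shapes = []
for layer_type in layer_types:
    seq_dim = min(max_length, window_size) if layer_type == "sliding_attention" else max_length
    layer_shapes.append([batch_size, num_kv_heads, seq_dim, head_size])

per_layer_elems = sum(s[0] * s[1] * s[2] * s[3] for s in layer_shapes)
uniform_elems = len(layer_types) * batch_size * num_kv_heads * max_length * head_size
print(f"per-layer allocation uses {per_layer_elems / uniform_elems:.0%} of a uniform max_length cache")
# With these numbers the sliding layers hold half as many entries, so the total is 75%.
```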

src/models/kv_cache.h

Lines changed: 4 additions & 0 deletions
@@ -97,6 +97,10 @@ struct DefaultKeyValueCache : KeyValueCache {
   std::array<int64_t, 4> shape_;
   ONNXTensorElementDataType type_;
 
+  // Support for per-layer KV cache shapes (for models with alternating attention patterns)
+  bool use_layer_types_{false};
+  std::vector<std::array<int64_t, 4>> layer_shapes_;
+
   std::unique_ptr<OrtValue> empty_past_;
   std::vector<std::unique_ptr<OrtValue>> pasts_, presents_;
   std::vector<std::string> input_name_strings_, output_name_strings_;
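
As the indexing in `Update` and `RewindPastTensorsTo` shows, `presents_` stays interleaved as key/value pairs per layer; only the sequence dimension differs by layer when `use_layer_types_` is set. A small sketch of that layout, with hypothetical shapes for illustration only:

```python
# Hypothetical illustration of the presents_ layout: key at 2*layer, value at 2*layer + 1.
layer_shapes = [
    [1, 4, 8192, 256],  # layer 0: full_attention
    [1, 4, 4096, 256],  # layer 1: sliding_attention
]

presents = []
for layer_idx, shape in enumerate(layer_shapes):
    presents.append(("key", layer_idx, shape))    # index 2 * layer_idx
    presents.append(("value", layer_idx, shape))  # index 2 * layer_idx + 1

for i, (kind, layer_idx, shape) in enumerate(presents):
    print(f"presents_[{i}] -> layer {layer_idx} {kind}, shape {shape}")
```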

src/python/py/models/builder.py

Lines changed: 64 additions & 4 deletions
@@ -464,6 +464,11 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
 
         if self.ep == "trt-rtx" and self.window_size is not None and self.window_size > 0:
             genai_config["model"]["decoder"]["sliding_window"] = {"window_size": self.window_size, "slide_key_value_cache": False, "slide_inputs": False}
+
+            # Add layer-specific attention types if model has alternating attention patterns
+            layer_types = self.get_layer_types()
+            if layer_types is not None:
+                genai_config["model"]["decoder"]["sliding_window"]["layer_types"] = layer_types
 
         if self.ep != "cpu":
             ep_name = self.ep.replace("trt-rtx", "NvTensorRtRtx")
@@ -474,6 +479,25 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
         with open(os.path.join(out_dir,"genai_config.json"), "w") as f:
             json.dump(genai_config, f, indent=4)
 
+    def get_layer_types(self):
+        """
+        Returns a list of attention types for each layer.
+        Override in subclasses to provide layer-specific attention patterns.
+        Returns None for models with uniform attention across all layers.
+        """
+        return None
+
+    def use_alternating_kv_dimensions(self):
+        """
+        Returns True if this model needs alternating KV cache dimension names.
+        This is needed for models with alternating attention patterns when using TensorRT.
+        """
+        # Enable for models with layer_types when using TensorRT EP
+        if self.ep == "trt-rtx" and hasattr(self, 'get_layer_types'):
+            layer_types = self.get_layer_types()
+            return layer_types is not None
+        return False
+
     def save_processing(self, model_name_or_path, extra_kwargs, out_dir):
         tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token=self.hf_token, trust_remote_code=self.hf_remote, **extra_kwargs)
         print(f"Saving processing files in {out_dir} for GenAI")
@@ -650,17 +674,39 @@ def make_inputs_and_outputs(self):
 
         # Add KV cache to inputs and outputs
         for i in range(self.num_layers):
+            # Use alternating dimension names if needed (for TensorRT with alternating attention)
+            if self.use_alternating_kv_dimensions():
+                layer_types = self.get_layer_types()
+                layer_type = layer_types[i] if layer_types and i < len(layer_types) else "full_attention"
+
+                # Use dimension name based on attention type
+                if layer_type == "sliding_attention":
+                    dim_suffix = "_sliding"
+                else:  # "full_attention"
+                    dim_suffix = "_full"
+
+                past_key_shape = ["batch_size", self.num_kv_heads, f"past_sequence_length{dim_suffix}", self.head_size]
+                past_value_shape = ["batch_size", self.num_kv_heads, f"past_sequence_length{dim_suffix}", self.head_size]
+                present_key_shape = ["batch_size", self.num_kv_heads, f"total_sequence_length{dim_suffix}", self.head_size]
+                present_value_shape = ["batch_size", self.num_kv_heads, f"total_sequence_length{dim_suffix}", self.head_size]
+            else:
+                # Use standard dimension names (current behavior)
+                past_key_shape = self.input_shapes["past_key_values.key"]
+                past_value_shape = self.input_shapes["past_key_values.value"]
+                present_key_shape = self.output_shapes["present.key"]
+                present_value_shape = self.output_shapes["present.value"]
+
             # Add KV cache to inputs
             key_name = f"past_key_values.{i}.key"
-            inputs.append(self.make_value(key_name, dtype=self.input_types["past_key_values.key"], shape=self.input_shapes["past_key_values.key"]))
+            inputs.append(self.make_value(key_name, dtype=self.input_types["past_key_values.key"], shape=past_key_shape))
             value_name = f"past_key_values.{i}.value"
-            inputs.append(self.make_value(value_name, dtype=self.input_types["past_key_values.value"], shape=self.input_shapes["past_key_values.value"]))
+            inputs.append(self.make_value(value_name, dtype=self.input_types["past_key_values.value"], shape=past_value_shape))
 
             # Add KV cache to outputs
             key_name = f"present.{i}.key"
-            outputs.append(self.make_value(key_name, dtype=self.output_types["present.key"], shape=self.output_shapes["present.key"]))
+            outputs.append(self.make_value(key_name, dtype=self.output_types["present.key"], shape=present_key_shape))
             value_name = f"present.{i}.value"
-            outputs.append(self.make_value(value_name, dtype=self.output_types["present.value"], shape=self.output_shapes["present.value"]))
+            outputs.append(self.make_value(value_name, dtype=self.output_types["present.value"], shape=present_value_shape))
 
     def make_constant(self, name):
         # Make constant ops for 0, 1, 2, 3, etc.
@@ -3455,6 +3501,20 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
         super().make_attention(layer_id, attention, root_input, **kwargs)
         self.window_size = original_window_size
 
+    def get_layer_types(self):
+        """
+        Gemma2 uses alternating attention patterns:
+        - Even layers (0, 2, 4, ...): full_attention
+        - Odd layers (1, 3, 5, ...): sliding_attention
+        """
+        layer_types = []
+        for layer_id in range(self.num_layers):
+            if self.is_local(layer_id):
+                layer_types.append("sliding_attention")
+            else:
+                layer_types.append("full_attention")
+        return layer_types
+
 
 class Phi3MiniModel(MistralModel):
     def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
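
A rough sketch, outside the builder and with a made-up layer count, of what `get_layer_types` and the `_sliding`/`_full` dimension suffixes produce for a Gemma2-style model where odd layers are local:

```python
# Illustrative stand-in for the builder logic above; num_layers is an example value.
num_layers = 6

def is_local(layer_id: int) -> bool:
    # Mirrors the Gemma2 convention in this commit: odd layers use sliding attention.
    return layer_id % 2 == 1

layer_types = ["sliding_attention" if is_local(i) else "full_attention" for i in range(num_layers)]
print(layer_types)
# ['full_attention', 'sliding_attention', 'full_attention', 'sliding_attention', ...]

for i, layer_type in enumerate(layer_types):
    dim_suffix = "_sliding" if layer_type == "sliding_attention" else "_full"
    print(f"past_key_values.{i}.key -> past_sequence_length{dim_suffix}")
```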
