|
6 | 6 | #include "kv_cache.h" |
7 | 7 | #include "windowed_kv_cache.h" |
8 | 8 | #include "../openvino/interface.h" |
| 9 | +#include <algorithm> |
9 | 10 |
|
10 | 11 | namespace Generators { |
11 | 12 |
|
@@ -175,21 +176,49 @@ DefaultKeyValueCache::DefaultKeyValueCache(State& state) |
175 | 176 | } |
176 | 177 |
|
177 | 178 | // Set the size after empty_past_ has been created with 0 for this field |
178 | | - if (state.model_.p_device_->GetType() == DeviceType::NvTensorRtRtx && |
179 | | - model_.config_->model.decoder.sliding_window.has_value() && |
| 179 | + if (model_.config_->model.decoder.sliding_window.has_value() && |
180 | 180 | model_.config_->model.decoder.sliding_window->window_size > 0) { |
181 | | - shape_[2] = std::min(state_.params_->search.max_length, |
182 | | - model_.config_->model.decoder.sliding_window->window_size); |
| 181 | + const int sliding_window_size = model_.config_->model.decoder.sliding_window->window_size; |
| 182 | + const int max_length = state_.params_->search.max_length; |
| 183 | + |
| 184 | + // Check if we need per-layer allocation for models with alternating attention patterns |
| 185 | + if (!model_.config_->model.decoder.sliding_window->layers.empty()) { |
| 186 | + // Use per-layer allocation based on sliding window layer indices |
| 187 | + layer_shapes_.resize(layer_count_); |
| 188 | + |
| 189 | + // Initialize all layers with base shape and max_length |
| 190 | + for (int layer_idx = 0; layer_idx < layer_count_; ++layer_idx) { |
| 191 | + layer_shapes_[layer_idx] = shape_; |
| 192 | + layer_shapes_[layer_idx][2] = max_length; |
| 193 | + } |
| 194 | + |
| 195 | + // Update sliding window layers with constrained cache size |
| 196 | + for (int layer_idx : model_.config_->model.decoder.sliding_window->layers) { |
| 197 | + layer_shapes_[layer_idx][2] = std::min(max_length, sliding_window_size); |
| 198 | + } |
| 199 | + // Set shape_[2] to max of all layer shapes for RewindTo bounds checking |
| 200 | + shape_[2] = max_length; |
| 201 | + } else { |
| 202 | + // Uniform sliding window allocation (backward compatibility) |
| 203 | + shape_[2] = std::min(max_length, sliding_window_size); |
| 204 | + } |
183 | 205 | } else if (past_present_share_buffer_) { |
184 | 206 | shape_[2] = state_.params_->search.max_length; |
185 | 207 | } |
186 | 208 |
|
187 | 209 | try { |
| 210 | + // Allocate the KV cache tensors - two per layer (key and value) |
| 211 | + // With per-layer shapes, tensors 2*i and 2*i+1 (layer i's key and value) use that layer's shape |
| 212 | + // With a uniform shape, every tensor uses shape_ |
188 | 213 | for (int i = 0; i < layer_count_ * 2; ++i) { |
189 | | - presents_.push_back(OrtValue::CreateTensor(Allocator(), shape_, type_)); |
| 214 | + std::array<int64_t, 4> tensor_shape = shape_; |
| 215 | + if (!layer_shapes_.empty()) { |
| 216 | + // Per-layer allocation: use layer-specific shape |
| 217 | + // i/2 gives us the layer index since we have 2 tensors per layer |
| 218 | + tensor_shape = layer_shapes_[i / 2]; |
| 219 | + } |
190 | 220 |
|
191 | | - // Zero the memory so we don't leak any data from the previous run |
192 | | - // WebGPU device has no Zero() implementation yet. Since this zeroing is optional we disable it for WebGPU for now |
| 221 | + presents_.push_back(OrtValue::CreateTensor(Allocator(), tensor_shape, type_)); |
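| | + // Zero the memory so we don't leak any data from the previous run; WebGPU has no Zero() implementation yet, so zeroing is skipped there |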
193 | 222 | if (Device().GetType() != DeviceType::WEBGPU) { |
194 | 223 | ByteWrapTensor(Device(), *presents_.back()).Zero(); |
195 | 224 | } |
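A minimal sketch of the allocation rule the hunk above applies, written as a hypothetical free-standing helper (ComputeLayerShapes is not part of this change): layers listed in sliding_window->layers are capped at the window size, every other layer gets max_length.

    #include <algorithm>
    #include <array>
    #include <cstdint>
    #include <vector>

    // Sketch only: per-layer KV-cache shapes from the base [batch, heads, seq, head_size] shape.
    std::vector<std::array<int64_t, 4>> ComputeLayerShapes(std::array<int64_t, 4> base_shape,
                                                           int layer_count, int max_length,
                                                           int window_size,
                                                           const std::vector<int>& sliding_layers) {
      base_shape[2] = max_length;  // default: full-attention layers hold up to max_length tokens
      std::vector<std::array<int64_t, 4>> shapes(layer_count, base_shape);
      for (int idx : sliding_layers)
        shapes[idx][2] = std::min(max_length, window_size);  // sliding-window layers are capped
      return shapes;
    }

For example, with layer_count = 4, max_length = 4096, window_size = 1024 and sliding_layers = {0, 2}, layers 0 and 2 are allocated 1024 sequence positions while layers 1 and 3 are allocated 4096, matching the alternating-attention pattern the per-layer path targets.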
@@ -240,10 +269,30 @@ void DefaultKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, int total_le |
240 | 269 | } |
241 | 270 | } |
242 | 271 |
|
243 | | - shape_[2] = total_length; |
244 | | - for (int i = 0; i < layer_count_ * 2; i++) { |
245 | | - presents_[i] = OrtValue::CreateTensor(Allocator(), shape_, type_); |
246 | | - state_.outputs_[output_index_ + i] = presents_[i].get(); |
| 272 | + if (!layer_shapes_.empty()) { |
| 273 | + // Resize each layer's present tensor to total_length, capped at that layer's allocated capacity |
| 274 | + for (int layer_idx = 0; layer_idx < layer_count_; ++layer_idx) { |
| 275 | + const int max_cache_length = static_cast<int>(layer_shapes_[layer_idx][2]); |
| 276 | + const int actual_length = std::min(total_length, max_cache_length); |
| 277 | + |
| 278 | + std::array<int64_t, 4> current_shape = layer_shapes_[layer_idx]; |
| 279 | + current_shape[2] = actual_length; |
| 280 | + |
| 281 | + // Key tensor |
| 282 | + presents_[layer_idx * 2] = OrtValue::CreateTensor(Allocator(), current_shape, type_); |
| 283 | + state_.outputs_[output_index_ + layer_idx * 2] = presents_[layer_idx * 2].get(); |
| 284 | + |
| 285 | + // Value tensor |
| 286 | + presents_[layer_idx * 2 + 1] = OrtValue::CreateTensor(Allocator(), current_shape, type_); |
| 287 | + state_.outputs_[output_index_ + layer_idx * 2 + 1] = presents_[layer_idx * 2 + 1].get(); |
| 288 | + } |
| 289 | + } else { |
| 290 | + // Uniform shape update (existing behavior) |
| 291 | + shape_[2] = total_length; |
| 292 | + for (int i = 0; i < layer_count_ * 2; i++) { |
| 293 | + presents_[i] = OrtValue::CreateTensor(Allocator(), shape_, type_); |
| 294 | + state_.outputs_[output_index_ + i] = presents_[i].get(); |
| 295 | + } |
247 | 296 | } |
248 | 297 |
|
249 | 298 | is_first_update_ = false; |
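A minimal sketch, with hypothetical helper names, of the two rules the per-layer Update path above relies on: layer i's key and value tensors sit at slots 2*i and 2*i + 1, and each present tensor is sized to total_length capped at that layer's allocated capacity.

    #include <algorithm>
    #include <cstdint>

    // Sketch only: slot layout for layer i in presents_ and the state outputs.
    constexpr int KeySlot(int layer_idx) { return layer_idx * 2; }
    constexpr int ValueSlot(int layer_idx) { return layer_idx * 2 + 1; }

    // Sketch only: sequence length a layer's present tensor gets on this step.
    // Sliding-window layers stop growing once total_length passes their window-sized
    // allocation; full-attention layers keep growing up to max_length.
    constexpr int64_t PresentLength(int total_length, int64_t layer_allocated_length) {
      return std::min<int64_t>(total_length, layer_allocated_length);
    }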
@@ -271,39 +320,94 @@ void DefaultKeyValueCache::RewindTo(size_t index) { |
271 | 320 |
|
272 | 321 | template <typename T> |
273 | 322 | void DefaultKeyValueCache::RewindPastTensorsTo(size_t index) { |
274 | | - assert(index > 0 && shape_[2] >= static_cast<int64_t>(index) && !past_present_share_buffer_); |
275 | | - std::array<int64_t, 4> new_shape = shape_; |
276 | | - new_shape[2] = static_cast<int>(index); |
277 | | - auto batch_x_num_heads = new_shape[0] * new_shape[1]; |
278 | | - auto new_length_x_head_size = new_shape[2] * new_shape[3]; |
279 | | - auto old_length_x_head_size = shape_[2] * new_shape[3]; |
280 | | - shape_[2] = new_shape[2]; |
281 | | - |
282 | | - for (int i = 0; i < layer_count_ * 2; i++) { |
283 | | - OrtValue& present = *presents_[i]; |
284 | | - std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(Allocator(), shape_, type_); |
| 323 | + assert(index > 0 && !past_present_share_buffer_); |
| 324 | + |
| 325 | + if (!layer_shapes_.empty()) { |
| 326 | + // Handle per-layer shapes |
| 327 | + // First validate that index doesn't exceed the global max_length |
| 328 | + const int max_length = static_cast<int>(shape_[2]); // shape_[2] is kept at max_length in the constructor |
| 329 | + if (static_cast<int>(index) > max_length) { |
| 330 | + throw std::runtime_error("Requested rewind length exceeds max_length."); |
| 331 | + } |
285 | 332 |
|
286 | | - auto past_span = WrapTensor<T>(Device(), *past); |
287 | | - auto present_span = WrapTensor<T>(Device(), present); |
| 333 | + for (int i = 0; i < layer_count_ * 2; i++) { |
| 334 | + const int layer_idx = i / 2; |
| 335 | + const std::array<int64_t, 4> layer_shape = layer_shapes_[layer_idx]; |
| 336 | + const int layer_max_cache = static_cast<int>(layer_shape[2]); |
| 337 | + |
| 338 | + // For each layer, rewind to min(index, layer's max capacity) |
| 339 | + // - Full attention layers: min(index, max_length) |
| 340 | + // - Sliding window layers: min(index, sliding_window_size) |
| 341 | + const int actual_rewind_length = std::min(static_cast<int>(index), layer_max_cache); |
| 342 | + |
| 343 | + std::array<int64_t, 4> new_shape = layer_shape; |
| 344 | + new_shape[2] = actual_rewind_length; |
| 345 | + const auto batch_x_num_heads = new_shape[0] * new_shape[1]; |
| 346 | + const auto new_length_x_head_size = new_shape[2] * new_shape[3]; |
| 347 | + |
| 348 | + OrtValue& present = *presents_[i]; |
| 349 | + const auto present_shape = present.GetTensorTypeAndShapeInfo()->GetShape(); |
| 350 | + const auto old_length_x_head_size = present_shape[2] * new_shape[3]; |
| 351 | + |
| 352 | + std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(Allocator(), new_shape, type_); |
| 353 | + auto past_span = WrapTensor<T>(Device(), *past); |
| 354 | + auto present_span = WrapTensor<T>(Device(), present); |
| 355 | + |
| 356 | + for (int j = 0; j < batch_x_num_heads; j++) { |
| 357 | + auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size); |
| 358 | + auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size); |
| 359 | + past_data.CopyFrom(present_data); |
| 360 | + } |
| 361 | + pasts_[i] = std::move(past); |
| 362 | + state_.inputs_[input_index_ + i] = pasts_[i].get(); |
| 363 | + } |
| 364 | + } else { |
| 365 | + // Uniform shape handling (existing behavior) |
| 366 | + assert(shape_[2] >= static_cast<int64_t>(index)); |
| 367 | + std::array<int64_t, 4> new_shape = shape_; |
| 368 | + new_shape[2] = static_cast<int>(index); |
| 369 | + auto batch_x_num_heads = new_shape[0] * new_shape[1]; |
| 370 | + auto new_length_x_head_size = new_shape[2] * new_shape[3]; |
| 371 | + auto old_length_x_head_size = shape_[2] * new_shape[3]; |
| 372 | + shape_[2] = new_shape[2]; |
288 | 373 |
|
289 | | - for (int j = 0; j < batch_x_num_heads; j++) { |
290 | | - auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size); |
291 | | - auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size); |
292 | | - past_data.CopyFrom(present_data); |
| 374 | + for (int i = 0; i < layer_count_ * 2; i++) { |
| 375 | + OrtValue& present = *presents_[i]; |
| 376 | + std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(Allocator(), shape_, type_); |
| 377 | + |
| 378 | + auto past_span = WrapTensor<T>(Device(), *past); |
| 379 | + auto present_span = WrapTensor<T>(Device(), present); |
| 380 | + |
| 381 | + for (int j = 0; j < batch_x_num_heads; j++) { |
| 382 | + auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size); |
| 383 | + auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size); |
| 384 | + past_data.CopyFrom(present_data); |
| 385 | + } |
| 386 | + pasts_[i] = std::move(past); |
| 387 | + state_.inputs_[input_index_ + i] = pasts_[i].get(); |
293 | 388 | } |
294 | | - pasts_[i] = std::move(past); |
295 | | - state_.inputs_[input_index_ + i] = pasts_[i].get(); |
296 | 389 | } |
297 | 390 | } |
298 | 391 |
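A minimal sketch of the per-(batch x head) copy the rewind loop above performs, written as a hypothetical CPU-only helper with raw pointers (the real code goes through WrapTensor/DeviceSpan so it also works on device memory): each block keeps its first new_len positions and drops the rest.

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Sketch only: present holds batch_x_num_heads blocks of old_len * head_size elements;
    // past receives the first new_len * head_size elements of each block, back to back.
    template <typename T>
    void RewindBlocks(const T* present, T* past, int64_t batch_x_num_heads,
                      int64_t old_len, int64_t new_len, int64_t head_size) {
      const int64_t old_block = old_len * head_size;
      const int64_t new_block = new_len * head_size;
      for (int64_t j = 0; j < batch_x_num_heads; ++j) {
        std::memcpy(past + j * new_block, present + j * old_block,
                    static_cast<std::size_t>(new_block) * sizeof(T));
      }
    }

For a sliding-window layer, new_len is min(index, window_size) while a full-attention layer uses min(index, max_length), which is why the loop reads the old length from each present tensor's own shape rather than from shape_.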
|
299 | 392 | // Copy present state to past state reordered by the beam_indices |
300 | 393 | template <typename ScoreType> |
301 | 394 | void DefaultKeyValueCache::PickPastState(DeviceSpan<int32_t> beam_indices_device, int index) { |
302 | 395 | std::span<int32_t> beam_indices = beam_indices_device.CopyDeviceToCpu(); |
303 | | - auto block_size_per_beam = shape_[1] * shape_[2] * shape_[3]; |
| 396 | + |
| 397 | + std::array<int64_t, 4> tensor_shape; |
| 398 | + if (!layer_shapes_.empty()) { |
| 399 | + // Per-layer allocation: read the shape from this slot's present tensor, |
| 400 | + // since present tensors can have different sequence lengths per layer |
| 401 | + const auto present_shape = presents_[index]->GetTensorTypeAndShapeInfo()->GetShape(); |
| 402 | + std::copy(present_shape.begin(), present_shape.end(), tensor_shape.begin()); |
| 403 | + } else { |
| 404 | + tensor_shape = shape_; |
| 405 | + } |
| 406 | + |
| 407 | + auto block_size_per_beam = tensor_shape[1] * tensor_shape[2] * tensor_shape[3]; |
304 | 408 |
|
305 | 409 | OrtValue& present_value = *presents_[index]; |
306 | | - std::unique_ptr<OrtValue> past_value = OrtValue::CreateTensor<ScoreType>(Allocator(), shape_); |
| 410 | + std::unique_ptr<OrtValue> past_value = OrtValue::CreateTensor<ScoreType>(Allocator(), tensor_shape); |
307 | 411 |
|
308 | 412 | auto past_span = WrapTensor<ScoreType>(Device(), *past_value); |
309 | 413 | auto present_span = WrapTensor<ScoreType>(Device(), present_value); |
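Since block_size_per_beam now depends on which layer the tensor at index belongs to, a small compile-time check (the concrete numbers are hypothetical; shape layout is [batch*beams, num_heads, cache_len, head_size]) illustrates why the shape has to come from the live tensor in the per-layer case:

    #include <cstdint>

    // Sketch only: elements copied per beam = num_heads * cache_len * head_size.
    constexpr int64_t BlockSizePerBeam(int64_t num_heads, int64_t cache_len, int64_t head_size) {
      return num_heads * cache_len * head_size;
    }
    static_assert(BlockSizePerBeam(8, 4096, 128) == 4194304, "full-attention layer");
    static_assert(BlockSizePerBeam(8, 1024, 128) == 1048576, "sliding-window layer");

Using shape_ here would mis-size the per-beam block for sliding-window layers, so the per-layer path reads the shape from presents_[index] instead.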
|