@@ -175,23 +175,64 @@ DefaultKeyValueCache::DefaultKeyValueCache(State& state)
   }

   // Set the size after empty_past_ has been created with 0 for this field
+  // Check if we need to use per-layer allocation for models with alternating attention patterns
   if (state.model_.p_device_->GetType() == DeviceType::NvTensorRtRtx &&
       model_.config_->model.decoder.sliding_window.has_value() &&
-      model_.config_->model.decoder.sliding_window->window_size > 0) {
+      model_.config_->model.decoder.sliding_window->window_size > 0 &&
+      !model_.config_->model.decoder.sliding_window->layer_types.empty()) {
+    // Use per-layer allocation based on layer_types
+    use_layer_types_ = true;
+    layer_shapes_.resize(layer_count_);
+
+    int sliding_window_size = model_.config_->model.decoder.sliding_window->window_size;
+    int max_length = state_.params_->search.max_length;
+
+    for (int layer_idx = 0; layer_idx < layer_count_; ++layer_idx) {
+      layer_shapes_[layer_idx] = shape_;  // Copy base shape
+
+      const std::string& layer_type = model_.config_->model.decoder.sliding_window->layer_types[layer_idx];
+      if (layer_type == "sliding_attention") {
+        layer_shapes_[layer_idx][2] = std::min(max_length, sliding_window_size);
+      } else {  // "full_attention"
+        layer_shapes_[layer_idx][2] = max_length;
+      }
+    }
+  } else if (state.model_.p_device_->GetType() == DeviceType::NvTensorRtRtx &&
+             model_.config_->model.decoder.sliding_window.has_value() &&
+             model_.config_->model.decoder.sliding_window->window_size > 0) {
+    // Uniform sliding window allocation (backward compatibility)
     shape_[2] = std::min(state_.params_->search.max_length,
                          model_.config_->model.decoder.sliding_window->window_size);
   } else if (past_present_share_buffer_) {
     shape_[2] = state_.params_->search.max_length;
   }

   try {
-    for (int i = 0; i < layer_count_ * 2; ++i) {
-      presents_.push_back(OrtValue::CreateTensor(Allocator(), shape_, type_));
-
-      // Zero the memory so we don't leak any data from the previous run
-      // WebGPU device has no Zero() implementation yet. Since this zeroing is optional we disable it for WebGPU for now
-      if (Device().GetType() != DeviceType::WEBGPU) {
-        ByteWrapTensor(Device(), *presents_.back()).Zero();
+    if (use_layer_types_) {
+      // Allocate per-layer with different shapes
+      for (int layer_idx = 0; layer_idx < layer_count_; ++layer_idx) {
+        // Key tensor
+        presents_.push_back(OrtValue::CreateTensor(Allocator(), layer_shapes_[layer_idx], type_));
+        if (Device().GetType() != DeviceType::WEBGPU) {
+          ByteWrapTensor(Device(), *presents_.back()).Zero();
+        }
+
+        // Value tensor
+        presents_.push_back(OrtValue::CreateTensor(Allocator(), layer_shapes_[layer_idx], type_));
+        if (Device().GetType() != DeviceType::WEBGPU) {
+          ByteWrapTensor(Device(), *presents_.back()).Zero();
+        }
+      }
+    } else {
+      // Uniform allocation (existing behavior)
+      for (int i = 0; i < layer_count_ * 2; ++i) {
+        presents_.push_back(OrtValue::CreateTensor(Allocator(), shape_, type_));
+
+        // Zero the memory so we don't leak any data from the previous run
+        // WebGPU device has no Zero() implementation yet. Since this zeroing is optional we disable it for WebGPU for now
+        if (Device().GetType() != DeviceType::WEBGPU) {
+          ByteWrapTensor(Device(), *presents_.back()).Zero();
+        }
       }
     }
   } catch (const Ort::Exception&) {
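The constructor path above relies on two members declared elsewhere in the class. A minimal sketch of what they plausibly look like, inferred only from their usage in this hunk (`layer_shapes_[layer_idx] = shape_;` copies a `std::array<int64_t, 4>`); the actual header declarations may differ:

```cpp
#include <array>
#include <cstdint>
#include <vector>

// Sketch only: members inferred from usage in the diff, not copied from the real header.
struct PerLayerCacheMembersSketch {
  bool use_layer_types_{false};                       // true when layer_types drives per-layer allocation
  std::vector<std::array<int64_t, 4>> layer_shapes_;  // per-layer KV shape: [batch, num_heads, length, head_size]
};
```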
@@ -240,10 +281,30 @@ void DefaultKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, int total_le
     }
   }

-  shape_[2] = total_length;
-  for (int i = 0; i < layer_count_ * 2; i++) {
-    presents_[i] = OrtValue::CreateTensor(Allocator(), shape_, type_);
-    state_.outputs_[output_index_ + i] = presents_[i].get();
+  if (use_layer_types_) {
+    // Update per-layer shapes based on total_length, but respect max allocations
+    for (int layer_idx = 0; layer_idx < layer_count_; ++layer_idx) {
+      int max_cache_length = static_cast<int>(layer_shapes_[layer_idx][2]);
+      int actual_length = std::min(total_length, max_cache_length);
+
+      std::array<int64_t, 4> current_shape = layer_shapes_[layer_idx];
+      current_shape[2] = actual_length;
+
+      // Key tensor
+      presents_[layer_idx * 2] = OrtValue::CreateTensor(Allocator(), current_shape, type_);
+      state_.outputs_[output_index_ + layer_idx * 2] = presents_[layer_idx * 2].get();
+
+      // Value tensor
+      presents_[layer_idx * 2 + 1] = OrtValue::CreateTensor(Allocator(), current_shape, type_);
+      state_.outputs_[output_index_ + layer_idx * 2 + 1] = presents_[layer_idx * 2 + 1].get();
+    }
+  } else {
+    // Uniform shape update (existing behavior)
+    shape_[2] = total_length;
+    for (int i = 0; i < layer_count_ * 2; i++) {
+      presents_[i] = OrtValue::CreateTensor(Allocator(), shape_, type_);
+      state_.outputs_[output_index_ + i] = presents_[i].get();
+    }
   }

   is_first_update_ = false;
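To make the clamping rule in `Update()` concrete, here is a small self-contained sketch with made-up numbers (the window size, max length, and layer pattern below are illustrative, not taken from this change): a sliding-attention layer's cache length saturates at its window, while a full-attention layer tracks `total_length` up to `max_length`.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Illustrative values only.
  const int64_t max_length = 8192;
  const int64_t window_size = 1024;
  const int64_t total_length = 3000;  // tokens seen so far

  const std::vector<std::string> layer_types = {"sliding_attention", "full_attention"};

  for (const auto& type : layer_types) {
    // Allocation mirrors the constructor: min(max_length, window_size) for sliding layers.
    const int64_t allocated = (type == "sliding_attention") ? std::min(max_length, window_size) : max_length;
    // Update() then clamps the per-step cache length to the allocated size.
    const int64_t actual = std::min(total_length, allocated);
    std::cout << type << ": allocated=" << allocated << ", actual=" << actual << "\n";
  }
  // Prints: sliding_attention: allocated=1024, actual=1024
  //         full_attention: allocated=8192, actual=3000
}
```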
@@ -271,39 +332,90 @@ void DefaultKeyValueCache::RewindTo(size_t index) {

 template <typename T>
 void DefaultKeyValueCache::RewindPastTensorsTo(size_t index) {
-  assert(index > 0 && shape_[2] >= static_cast<int64_t>(index) && !past_present_share_buffer_);
-  std::array<int64_t, 4> new_shape = shape_;
-  new_shape[2] = static_cast<int>(index);
-  auto batch_x_num_heads = new_shape[0] * new_shape[1];
-  auto new_length_x_head_size = new_shape[2] * new_shape[3];
-  auto old_length_x_head_size = shape_[2] * new_shape[3];
-  shape_[2] = new_shape[2];
-
-  for (int i = 0; i < layer_count_ * 2; i++) {
-    OrtValue& present = *presents_[i];
-    std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(Allocator(), shape_, type_);
+  assert(index > 0 && !past_present_share_buffer_);
+
+  if (use_layer_types_) {
+    // Handle per-layer shapes
+    for (int i = 0; i < layer_count_ * 2; i++) {
+      int layer_idx = i / 2;
+      std::array<int64_t, 4> layer_shape = layer_shapes_[layer_idx];
+      int max_cache_length = static_cast<int>(layer_shape[2]);
+
+      // Ensure we don't rewind beyond what's available
+      if (static_cast<int>(index) > max_cache_length) {
+        throw std::runtime_error("Requested rewind length is greater than the layer's cache length.");
+      }
+
+      std::array<int64_t, 4> new_shape = layer_shape;
+      new_shape[2] = static_cast<int>(index);
+      auto batch_x_num_heads = new_shape[0] * new_shape[1];
+      auto new_length_x_head_size = new_shape[2] * new_shape[3];
+
+      OrtValue& present = *presents_[i];
+      auto present_shape = present.GetTensorTypeAndShapeInfo()->GetShape();
+      auto old_length_x_head_size = present_shape[2] * new_shape[3];
+
+      std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(Allocator(), new_shape, type_);
+      auto past_span = WrapTensor<T>(Device(), *past);
+      auto present_span = WrapTensor<T>(Device(), present);
+
+      for (int j = 0; j < batch_x_num_heads; j++) {
+        auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size);
+        auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size);
+        past_data.CopyFrom(present_data);
+      }
+      pasts_[i] = std::move(past);
+      state_.inputs_[input_index_ + i] = pasts_[i].get();
+    }
+  } else {
+    // Uniform shape handling (existing behavior)
+    assert(shape_[2] >= static_cast<int64_t>(index));
+    std::array<int64_t, 4> new_shape = shape_;
+    new_shape[2] = static_cast<int>(index);
+    auto batch_x_num_heads = new_shape[0] * new_shape[1];
+    auto new_length_x_head_size = new_shape[2] * new_shape[3];
+    auto old_length_x_head_size = shape_[2] * new_shape[3];
+    shape_[2] = new_shape[2];

-    auto past_span = WrapTensor<T>(Device(), *past);
-    auto present_span = WrapTensor<T>(Device(), present);
+    for (int i = 0; i < layer_count_ * 2; i++) {
+      OrtValue& present = *presents_[i];
+      std::unique_ptr<OrtValue> past = OrtValue::CreateTensor(Allocator(), shape_, type_);

-    for (int j = 0; j < batch_x_num_heads; j++) {
-      auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size);
-      auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size);
-      past_data.CopyFrom(present_data);
+      auto past_span = WrapTensor<T>(Device(), *past);
+      auto present_span = WrapTensor<T>(Device(), present);
+
+      for (int j = 0; j < batch_x_num_heads; j++) {
+        auto present_data = present_span.subspan(j * old_length_x_head_size, new_length_x_head_size);
+        auto past_data = past_span.subspan(j * new_length_x_head_size, new_length_x_head_size);
+        past_data.CopyFrom(present_data);
+      }
+      pasts_[i] = std::move(past);
+      state_.inputs_[input_index_ + i] = pasts_[i].get();
     }
-    pasts_[i] = std::move(past);
-    state_.inputs_[input_index_ + i] = pasts_[i].get();
   }
 }

 // Copy present state to past state reordered by the beam_indices
 template <typename ScoreType>
 void DefaultKeyValueCache::PickPastState(DeviceSpan<int32_t> beam_indices_device, int index) {
   std::span<int32_t> beam_indices = beam_indices_device.CopyDeviceToCpu();
-  auto block_size_per_beam = shape_[1] * shape_[2] * shape_[3];
+
+  std::array<int64_t, 4> tensor_shape;
+  if (use_layer_types_) {
+    // Get shape from the actual tensor for per-layer allocation
+    OrtValue& present_value = *presents_[index];
+    auto present_shape = present_value.GetTensorTypeAndShapeInfo()->GetShape();
+    for (size_t i = 0; i < 4; i++) {
+      tensor_shape[i] = present_shape[i];
+    }
+  } else {
+    tensor_shape = shape_;
+  }
+
+  auto block_size_per_beam = tensor_shape[1] * tensor_shape[2] * tensor_shape[3];

   OrtValue& present_value = *presents_[index];
-  std::unique_ptr<OrtValue> past_value = OrtValue::CreateTensor<ScoreType>(Allocator(), shape_);
+  std::unique_ptr<OrtValue> past_value = OrtValue::CreateTensor<ScoreType>(Allocator(), tensor_shape);

   auto past_span = WrapTensor<ScoreType>(Device(), *past_value);
   auto present_span = WrapTensor<ScoreType>(Device(), present_value);
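A rough illustration of why `PickPastState` now reads the shape from the tensor itself rather than from `shape_`: with per-layer allocation, sliding and full-attention layers have different sequence-length dimensions, so the per-beam block size differs by layer. The dimensions below are made up for the example.

```cpp
#include <array>
#include <cstdint>
#include <iostream>

int main() {
  // Illustrative shapes only: [batch_beams, num_kv_heads, cache_length, head_size].
  const std::array<int64_t, 4> sliding_layer_shape = {4, 8, 1024, 128};
  const std::array<int64_t, 4> full_layer_shape = {4, 8, 8192, 128};

  // Same formula as PickPastState: elements copied per beam index.
  auto block_size_per_beam = [](const std::array<int64_t, 4>& s) { return s[1] * s[2] * s[3]; };

  std::cout << "sliding layer block: " << block_size_per_beam(sliding_layer_shape) << "\n";  // 1048576
  std::cout << "full layer block: " << block_size_per_beam(full_layer_shape) << "\n";        // 8388608
}
```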