Skip to content

Commit 74ec1be

Browse files
authored
Refactor past_present_share_buffer logic into reusable function (#1839)
- Add IsPastPresentShareBufferEnabled() method to GeneratorParams - Consolidate logic for determining if past_present_share_buffer should be enabled - Replace inline conditions in kv_cache.cpp and position_inputs.cpp with function call - Improves maintainability by providing single source of truth - Logic: enabled only when config is true AND (num_beams == 1 OR model is Whisper)
1 parent 0cf5c73 commit 74ec1be

File tree

4 files changed

+14
-2
lines changed

4 files changed

+14
-2
lines changed

src/generators.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,14 @@ void GeneratorParams::SetGuidance(std::string_view type, std::string_view data)
268268
guidance_data = data;
269269
}
270270

271+
bool GeneratorParams::IsPastPresentShareBufferEnabled(const std::string& model_type) const {
272+
// past_present_share_buffer is only actually enabled when:
273+
// 1. The config option is set to true, AND
274+
// 2. Either num_beams == 1 OR the model is Whisper
275+
return search.past_present_share_buffer &&
276+
(search.num_beams == 1 || model_type == "whisper");
277+
}
278+
271279
std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params) {
272280
return std::make_unique<Generator>(model, params);
273281
}

src/generators.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams>, LeakChec
8484
std::string guidance_type; // e.g. json_schema or regex
8585
std::string guidance_data; // e.g. rules data in json_schema or regex
8686
void SetGuidance(std::string_view type, std::string_view data);
87+
88+
// Determines if past_present_share_buffer is actually enabled based on config and runtime conditions
89+
// Returns true only if config option is true AND (num_beams == 1 OR model is Whisper)
90+
bool IsPastPresentShareBufferEnabled(const std::string& model_type) const;
8791
};
8892

8993
struct Generator : LeakChecked<Generator> {

src/models/kv_cache.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ void CombinedKeyValueCache::PickPastState(DeviceSpan<int32_t> beam_indices, int
149149
DefaultKeyValueCache::DefaultKeyValueCache(State& state)
150150
: state_{state},
151151
layer_count_{model_.config_->model.decoder.num_hidden_layers},
152-
past_present_share_buffer_{state_.params_->search.past_present_share_buffer && (state_.params_->search.num_beams == 1 || model_.config_->model.type == "whisper")},
152+
past_present_share_buffer_{state_.params_->IsPastPresentShareBufferEnabled(model_.config_->model.type)},
153153
shape_{state_.params_->BatchBeamSize(), model_.config_->model.decoder.num_key_value_heads, 0, model_.config_->model.decoder.head_size} {
154154
if (g_log.enabled && g_log.warning && past_present_share_buffer_ != state_.params_->search.past_present_share_buffer)
155155
Log("warning", "past_present_share_buffer search option set to true, but has been disabled due to the current configuration. See https://aka.ms/generate_config for details");

src/models/position_inputs.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ void DefaultPositionInputs::RewindMask(size_t index) {
327327

328328
bool DefaultPositionInputs::ShouldUseStaticMaskHandling() const {
329329
return state_.params_->use_graph_capture ||
330-
(state_.params_->search.past_present_share_buffer &&
330+
(state_.params_->IsPastPresentShareBufferEnabled(model_.config_->model.type) &&
331331
model_.p_device_->GetType() == DeviceType::NvTensorRtRtx);
332332
}
333333

0 commit comments

Comments
 (0)