Skip to content

Commit 6833733

Browse files
committed
recompute KV cache for Phi3 when switching from short to long factor
1 parent 7735e10 commit 6833733

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

src/generators.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,19 @@ void Generator::GenerateNextToken() {
352352
ThrowErrorIfSessionTerminated(state_->session_terminated_);
353353
if (search_->GetSequenceLength() == 0 && !computed_logits_)
354354
throw std::runtime_error("GenerateNextToken called with no prior state. Please call AppendTokens, SetLogits, or params.SetInputs before calling GenerateNextToken.");
355+
356+
// TODO: Extend the solution to make it work for batch size > 1 and num beams > 1
357+
// Phi3 model switches from short factor to long factor at 4097 (original_max_position_embeddings+1) token, needs Recomputation of Position IDs and KV Cache
358+
// at this stage which is achieved by rewinding to zero and appending the current sequence
359+
if (model_->config_->model.type == "phi3" && search_->params_->search.batch_size == 1 && params.search.num_beams == 1) {
360+
if (search_->GetSequenceLength() == 4097 && first_switch) {
361+
first_switch = false;
362+
auto current_seq = cpu_span<int32_t>(GetSequence(0).CpuSpan());
363+
RewindToLength(0);
364+
AppendTokens(current_seq);
365+
}
366+
}
367+
355368
if (!computed_logits_) {
356369
auto next_tokens = search_->GetNextTokens();
357370
if (last_action_ == Action::rewound)

src/generators.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ struct Generator : LeakChecked<Generator> {
125125
std::unique_ptr<State> state_;
126126
std::unique_ptr<Search> search_;
127127
bool computed_logits_{}; // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio
128+
bool first_switch{true};
128129

129130
private:
130131
DeviceSpan<int32_t> AllocateInputIdsOnDevice(const cpu_span<int32_t> input_ids);

0 commit comments

Comments
 (0)