Skip to content

Commit ce44bbc

Browse files
Cherry pick guidance fix into 0.11.1 release (#1872)
### Description

This PR cherry-picks the guidance fix PR into rel-0.11.1.

### Motivation and Context

This cherry-pick needs to be included for the 0.11.1 patch release.
1 parent 495bfac commit ce44bbc

File tree

7 files changed

+13
-6
lines changed

7 files changed

+13
-6
lines changed

src/constrained_logits_processor.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,9 @@ void GuidanceLogitsProcessor::ResetWithoutCompute() {
229229
}
230230
llg_constraints_[i] = std::unique_ptr<LlgConstraint, LlgConstraintDeleter>(constraint_ptr);
231231
}
232+
for (int i = 0; i < ff_tokens_batch_.size(); i++) {
233+
ff_tokens_batch_[i].clear();
234+
}
232235
}
233236

234237
// Reset the masks and llguidance constraints and then recompute the mask

src/generators.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ void Generator::SetRuntimeOption(const char* key, const char* value) {
444444
state_->SetRunOption(key, value);
445445
}
446446

447-
bool Generator::IsDone() const {
447+
bool Generator::IsDone() {
448448
ThrowErrorIfSessionTerminated(state_->session_terminated_);
449449
if (computed_logits_) {
450450
return false;
@@ -453,6 +453,10 @@ bool Generator::IsDone() const {
453453
bool is_done = search_->IsDone();
454454
if (is_done) {
455455
state_->Finalize(search_->GetSequenceLength());
456+
if (guidance_logits_processor_) {
457+
guidance_logits_processor_->Reset();
458+
last_action_ = Action::standard;
459+
}
456460
}
457461

458462
return is_done;

src/generators.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams>, LeakChec
9494
struct Generator : LeakChecked<Generator> {
9595
Generator(const Model& model, const GeneratorParams& params);
9696

97-
bool IsDone() const;
97+
bool IsDone();
9898
void AppendTokens(cpu_span<const int32_t> input_ids);
9999
void GenerateNextToken();
100100
void RewindToLength(size_t new_length); // Rewind state to new_length

src/ort_genai.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ struct OgaGenerator : OgaAbstract {
444444
return std::unique_ptr<OgaGenerator>(p);
445445
}
446446

447-
bool IsDone() const {
447+
bool IsDone() {
448448
return OgaGenerator_IsDone(this);
449449
}
450450

src/ort_genai_c.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ OgaResult* OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* p
409409
OGA_CATCH
410410
}
411411

412-
bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator) {
412+
bool OGA_API_CALL OgaGenerator_IsDone(OgaGenerator* generator) {
413413
return generator->IsDone();
414414
}
415415

src/ort_genai_c.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* generator);
445445
* \param[in] generator The generator to check if it is done with generating all sequences.
446446
* \return True if the generator has finished generating all the sequences, false otherwise.
447447
*/
448-
OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator);
448+
OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(OgaGenerator* generator);
449449
OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsSessionTerminated(const OgaGenerator* generator);
450450

451451
/**

src/python/python.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ struct PyGenerator {
256256
generator_->RewindTo(new_length);
257257
}
258258

259-
bool IsDone() const {
259+
bool IsDone() {
260260
return generator_->IsDone();
261261
}
262262

0 commit comments

Comments
 (0)