Skip to content

Commit b0624a2

Browse files
committed
Fix batch size issues
1 parent 58fedf7 commit b0624a2

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/models/input_ids.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ namespace Generators {
8 8
DefaultInputIDs::DefaultInputIDs(State& state)
9 9
: state_{state} {
10 10
name_ = model_.config_->model.decoder.inputs.input_ids.c_str();
11-
shape_ = {state_.params_->BatchBeamSize(), 0};
11+
shape_ = {state_.params_->search.batch_size, 0};
12 12
type_ = model_.session_info_->GetInputDataType(name_);
13 13

14 14
if (model_.session_info_->HasInput(model_.config_->model.decoder.inputs.current_sequence_length) &&
@@ -47,7 +47,7 @@ void DefaultInputIDs::Add() {
47 47

48 48
void DefaultInputIDs::Update(DeviceSpan<int32_t>& new_tokens) {
49 49
if (!value_) {
50-
shape_[1] = static_cast<int64_t>(new_tokens.size());
50+
shape_[1] = static_cast<int64_t>(new_tokens.size()) / shape_[0];
51 51

52 52
// If 64-bit, convert from 32-bit to 64-bit
53 53
auto input_ids = new_tokens.CopyDeviceToCpu();

0 commit comments

Comments
 (0)