File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -8,7 +8,7 @@ namespace Generators {
88DefaultInputIDs::DefaultInputIDs (State& state)
99 : state_{state} {
1010 name_ = model_.config_ ->model .decoder .inputs .input_ids .c_str ();
11- shape_ = {state_.params_ ->BatchBeamSize () , 0 };
11+ shape_ = {state_.params_ ->search . batch_size , 0 };
1212 type_ = model_.session_info_ ->GetInputDataType (name_);
1313
1414 if (model_.session_info_ ->HasInput (model_.config_ ->model .decoder .inputs .current_sequence_length ) &&
@@ -47,7 +47,7 @@ void DefaultInputIDs::Add() {
4747
4848void DefaultInputIDs::Update (DeviceSpan<int32_t >& new_tokens) {
4949 if (!value_) {
50- shape_[1 ] = static_cast <int64_t >(new_tokens.size ());
50+ shape_[1 ] = static_cast <int64_t >(new_tokens.size ()) / shape_[ 0 ] ;
5151
5252 // If 64-bit, convert from 32-bit to 64-bit
5353 auto input_ids = new_tokens.CopyDeviceToCpu ();
You can’t perform that action at this time.
0 commit comments