
Commit 43fa6ab

Allow some known extra inputs in the model (#1167)
1 parent 9143cfd commit 43fa6ab

File tree

5 files changed: +138 -2 lines changed


src/models/extra_inputs.cpp

Lines changed: 42 additions & 0 deletions
@@ -5,6 +5,46 @@
 
 namespace Generators {
 
+PresetExtraInputs::PresetExtraInputs(State& state)
+    : state_(state),
+      registry_{
+          {"num_logits_to_keep", [&state = state_]() -> std::unique_ptr<OrtValue> {
+             std::vector<int64_t> shape{1};
+             auto num_logits_to_keep = OrtValue::CreateTensor<int64_t>(state.model_.allocator_cpu_, shape);
+             *num_logits_to_keep->GetTensorMutableData<int64_t>() = 0;
+             return num_logits_to_keep;
+           }}} {}
+
+void PresetExtraInputs::Add() {
+  const auto input_names_vector = state_.model_.session_info_->GetInputNames();
+  const std::unordered_set<std::string> input_names(state_.input_names_.begin(), state_.input_names_.end());
+  std::vector<std::string> unclaimed_input_names;
+  // Add any model input for which we don't have a corresponding input in the state to the unclaimed_input_names
+  for (const auto& input_name : input_names_vector) {
+    if (input_names.find(input_name) == input_names.end()) {
+      unclaimed_input_names.push_back(input_name);
+    }
+  }
+
+  // Try to claim the unclaimed inputs from the registry
+  for (const auto& input_name : unclaimed_input_names) {
+    auto it = registry_.find(input_name);
+    if (it != registry_.end()) {
+      extra_input_names_.push_back(input_name);
+      extra_inputs_.push_back(it->second());
+      state_.input_names_.push_back(extra_input_names_.back().c_str());
+      state_.inputs_.push_back(extra_inputs_.back().get());
+    } else if (input_name.rfind("onnx::Neg_", 0) == 0) {
+      // The unclaimed input has a prefix of onnx::Neg_, which is a special case.
+      // We treat this as an alias to num_logits_to_keep.
+      extra_input_names_.push_back(input_name);
+      extra_inputs_.push_back(registry_.at("num_logits_to_keep")());
+      state_.input_names_.push_back(extra_input_names_.back().c_str());
+      state_.inputs_.push_back(extra_inputs_.back().get());
+    }
+  }
+}
+
 ExtraInputs::ExtraInputs(State& state)
     : state_{state} {
   extra_inputs_.reserve(state_.params_->extra_inputs.size());
@@ -78,6 +118,8 @@ void ExtraInputs::Add() {
         throw std::runtime_error("Unsupported device for graph capture");
     }
   }
+
+  registrar_.Add();
 }
 
 #pragma warning(pop)
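
Note (not part of the commit): the PresetExtraInputs type above follows a registry-of-factories pattern: each known preset input name maps to a callable that manufactures a default value, and any model input the state does not already feed is either claimed from that registry or, when its name starts with onnx::Neg_, treated as an alias of num_logits_to_keep. The sketch below illustrates the same control flow with only standard-library types; ScalarValue is a hypothetical stand-in for OrtValue and the input names are illustrative.

// Minimal, self-contained sketch of the claiming logic (assumes C++17).
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for an OrtValue holding a single int64 scalar.
struct ScalarValue { int64_t value; };

int main() {
  // Registry of known preset inputs: name -> factory that builds a default value.
  const std::unordered_map<std::string, std::function<std::unique_ptr<ScalarValue>()>> registry{
      {"num_logits_to_keep", [] { return std::make_unique<ScalarValue>(ScalarValue{0}); }}};

  // Inputs declared by the model vs. inputs the state already provides (names are illustrative).
  const std::vector<std::string> model_inputs{"input_ids", "num_logits_to_keep", "onnx::Neg_67", "abcde"};
  const std::unordered_set<std::string> provided{"input_ids"};

  for (const auto& name : model_inputs) {
    if (provided.count(name)) continue;  // already fed by the state
    if (auto it = registry.find(name); it != registry.end()) {
      std::cout << name << ": claimed with default " << it->second()->value << "\n";
    } else if (name.rfind("onnx::Neg_", 0) == 0) {
      // Prefix match: treated as an alias of num_logits_to_keep.
      std::cout << name << ": claimed as alias, default " << registry.at("num_logits_to_keep")()->value << "\n";
    } else {
      std::cout << name << ": left unclaimed (the session would report a missing input)\n";
    }
  }
}

In the real implementation the factory creates an int64 OrtValue on the CPU allocator, and the claimed names and values are appended to the state's input lists, as in the diff above.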

src/models/extra_inputs.h

Lines changed: 13 additions & 0 deletions
@@ -4,6 +4,18 @@
 
 namespace Generators {
 
+struct PresetExtraInputs {
+  PresetExtraInputs(State& state);
+  void Add();
+
+ private:
+  using FuncType = std::function<std::unique_ptr<OrtValue>()>;
+  State& state_;
+  std::unordered_map<std::string, FuncType> registry_;
+  std::vector<std::unique_ptr<OrtValue>> extra_inputs_;
+  std::vector<std::string> extra_input_names_;
+};
+
 struct ExtraInputs {
   ExtraInputs(State& state);
   void Add();
@@ -14,6 +26,7 @@ struct ExtraInputs {
   std::vector<OrtValue*> extra_inputs_;
   std::vector<std::unique_ptr<OrtValue>> owned_extra_inputs_;
   std::unordered_map<std::string, StaticBuffer*> sb_extra_inputs_;
+  PresetExtraInputs registrar_{state_};
 };
 
 }  // namespace Generators

src/models/model.cpp

Lines changed: 8 additions & 0 deletions
@@ -284,6 +284,14 @@ ONNXTensorElementDataType SessionInfo::GetOutputDataType(const std::string& name
   return result->second;
 }
 
+std::vector<std::string> SessionInfo::GetInputNames() const {
+  std::vector<std::string> names;
+  names.reserve(inputs_.size());
+  for (const auto& input : inputs_)
+    names.push_back(input.first);
+  return names;
+}
+
 Model::Model(std::unique_ptr<Config> config) : config_{std::move(config)} {
   CreateSessionOptions();
 }

src/models/model.h

Lines changed: 2 additions & 0 deletions
@@ -122,6 +122,8 @@ struct SessionInfo {
   ONNXTensorElementDataType GetInputDataType(const std::string& name) const;
   ONNXTensorElementDataType GetOutputDataType(const std::string& name) const;
 
+  std::vector<std::string> GetInputNames() const;
+
  private:
   std::unordered_map<std::string, ONNXTensorElementDataType> inputs_, outputs_;
 };

test/python/test_onnxruntime_genai_api.py

Lines changed: 73 additions & 2 deletions
@@ -628,12 +628,83 @@ def _export_adapter(adapter, adapter_file_name):
     params = og.GeneratorParams(model)
     params.set_search_options(max_length=20, batch_size=len(prompts))
 
-    print(len(adapter_paths))
-
     generator = og.Generator(model, params)
     for i in range(len(adapter_paths)):
         generator.set_active_adapter(adapters, f"adapter_{i}")
 
     generator.append_tokens(tokenizer.encode_batch(prompts))
     while not generator.is_done():
         generator.generate_next_token()
+
+
+@pytest.mark.parametrize("device", devices)
+@pytest.mark.skipif(
+    sysconfig.get_platform().endswith("arm64"),
+    reason="ONNX is not available on ARM64",
+)
+@pytest.mark.parametrize("extra_inputs", [("num_logits_to_keep", True), ("onnx::Neg_67", True), ("abcde", False)])
+def test_preset_extra_inputs(test_data_path, device, phi2_for, extra_inputs):
+    def _prepare_model(test_data_path):
+        phi2_model_path = phi2_for(device)
+        relative_model_path = "preset_extra_inputs"
+        extra_inputs_model_path = os.fspath(Path(test_data_path) / relative_model_path)
+
+        shutil.copytree(phi2_model_path, extra_inputs_model_path, dirs_exist_ok=True)
+
+        # Create the model with the extra inputs
+        model = onnx.load(Path(extra_inputs_model_path) / "model.onnx")
+
+        for node in model.graph.node:
+            if node.name == "/lm_head/Add":
+                node.output[0] = "logits_0"
+                break
+
+        extra_input_name, valid = extra_inputs
+        extra_input = onnx.helper.make_tensor_value_info(
+            extra_input_name,
+            onnx.TensorProto.INT64,
+            [],
+        )
+
+        model.graph.input.append(extra_input)
+
+        cast_node = onnx.helper.make_node(
+            "Cast", [extra_input_name], [f"{extra_input_name}_cast"], to=onnx.TensorProto.FLOAT if device == "cpu" else onnx.TensorProto.FLOAT16
+        )
+        add_node = onnx.helper.make_node(
+            "Add", [f"{extra_input_name}_cast", "logits_0"], ["logits"], name="add_to_logits"
+        )
+        model.graph.node.extend([cast_node, add_node])
+
+        onnx.save(
+            model,
+            Path(extra_inputs_model_path) / "model.onnx",
+            save_as_external_data=True,
+            location="model.data",
+        )
+
+        return extra_inputs_model_path, valid
+
+    model_path, valid_model = _prepare_model(test_data_path)
+    model = og.Model(model_path)
+    tokenizer = og.Tokenizer(model)
+    prompts = [
+        "This is a test.",
+        "Rats are awesome pets!",
+        "The quick brown fox jumps over the lazy dog.",
+    ]
+
+    params = og.GeneratorParams(model)
+    params.set_search_options(max_length=20, batch_size=len(prompts))
+
+    generator = og.Generator(model, params)
+    if not valid_model:
+        with pytest.raises(og.OrtException) as exc_info:
+            generator.append_tokens(tokenizer.encode_batch(prompts))
+
+        assert f"Missing Input: {extra_inputs[0]}" in str(exc_info.value)
+    else:
+        generator.append_tokens(tokenizer.encode_batch(prompts))
+
+        while not generator.is_done():
+            generator.generate_next_token()
