Skip to content

Commit 987a495

Browse files
committed
fix shape inference error for ep context nodes
1 parent 9de58ac commit 987a495

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

onnxruntime/core/session/ep_plugin_provider_interfaces.cc

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,12 @@ Status PluginExecutionProvider::FusedNodeState::AddFusedNode(const Node& fused_n
244244
/// Note that the EP plugin uses the model editor API to create the OrtNode instances.
245245
/// </summary>
246246
/// <param name="ep_name">Name of the plugin EP.</param>
247+
/// <param name="fused_nodes">fused nodes provided by ORT.</param>
247248
/// <param name="plugin_ep_context_nodes">EPContext nodes provided by the plugin EP.</param>
248249
/// <param name="result_nodes">Output parameter set to the resulting array of EPContext nodes.</param>
249250
/// <param name="result_node_args">Output parameter that stores the NodeArgs used by the EPContext nodes.</param>
250251
/// <returns>A status indicating success or an error.</returns>
251-
static Status ConvertEpContextNodes(const std::string& ep_name, const std::vector<OrtNode*> plugin_ep_context_nodes,
252+
static Status ConvertEpContextNodes(const std::string& ep_name, const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes, const std::vector<OrtNode*> plugin_ep_context_nodes,
252253
/*out*/ std::vector<std::unique_ptr<Node>>& result_nodes,
253254
/*out*/ std::vector<std::unique_ptr<NodeArg>>& result_node_args) {
254255
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
@@ -260,8 +261,10 @@ static Status ConvertEpContextNodes(const std::string& ep_name, const std::vecto
260261
std::vector<std::unique_ptr<NodeArg>> ep_context_node_args_holder;
261262

262263
ep_context_nodes_holder.reserve(plugin_ep_context_nodes.size());
263-
264+
int index = -1;
264265
for (const OrtNode* ort_node : plugin_ep_context_nodes) {
266+
index = index + 1;
267+
auto& fused_node_filtered_graph = fused_nodes[index].filtered_graph;
265268
ORT_RETURN_IF_NOT(ort_node != nullptr, ep_name, ": OrtEp::Compile() returned a NULL EPContext node.");
266269

267270
const ModelEditorNode* editor_node = ModelEditorNode::ToInternal(ort_node);
@@ -276,13 +279,17 @@ static Status ConvertEpContextNodes(const std::string& ep_name, const std::vecto
276279
output_node_args.reserve(editor_node->output_names.size());
277280

278281
for (const std::string& input_name : editor_node->input_names) {
279-
auto node_arg = std::make_unique<NodeArg>(input_name, /*p_arg_type*/ nullptr); // Graph.Resolve() sets type.
282+
auto node_arg_on_fused_graph = fused_node_filtered_graph.get().GetNodeArg(input_name);
283+
const ONNX_NAMESPACE::TypeProto* p_arg_type = node_arg_on_fused_graph ? node_arg_on_fused_graph->TypeAsProto() : nullptr;
284+
auto node_arg = std::make_unique<NodeArg>(input_name, p_arg_type); // Graph.Resolve() cannot set type because EP Context OP does not have proper shape infererence function avaliable.
280285
input_node_args.push_back(node_arg.get());
281286
ep_context_node_args_holder.push_back(std::move(node_arg));
282287
}
283288

284289
for (const std::string& output_name : editor_node->output_names) {
285-
auto node_arg = std::make_unique<NodeArg>(output_name, /*p_arg_type*/ nullptr); // Graph.Resolve() sets type.
290+
auto node_arg_on_fused_graph = fused_node_filtered_graph.get().GetNodeArg(output_name);
291+
const ONNX_NAMESPACE::TypeProto* p_arg_type = node_arg_on_fused_graph ? node_arg_on_fused_graph->TypeAsProto() : nullptr;
292+
auto node_arg = std::make_unique<NodeArg>(output_name, p_arg_type); // Graph.Resolve() cannot set type because EP Context OP does not have proper shape infererence function avaliable.
286293
output_node_args.push_back(node_arg.get());
287294
ep_context_node_args_holder.push_back(std::move(node_arg));
288295
}
@@ -422,7 +429,7 @@ Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fu
422429
// We store the converted Node and NodeArg instances as members to ensure they can be returned to the ORT graph
423430
// partitioner via a call to IExecutionProvider::GetEpContextNodes().
424431
if (generate_ep_ctx_model_) {
425-
ORT_RETURN_IF_ERROR(ConvertEpContextNodes(Type(), plugin_ep_context_nodes,
432+
ORT_RETURN_IF_ERROR(ConvertEpContextNodes(Type(), fused_nodes_and_graphs, plugin_ep_context_nodes,
426433
/*out*/ ep_context_nodes_, /*out*/ ep_context_node_args_));
427434
}
428435

0 commit comments

Comments
 (0)