-
Notifications
You must be signed in to change notification settings - Fork 3.5k
ORT 1.23.2 cherrypick 1 #26347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: release-1.23.2
Are you sure you want to change the base?
ORT 1.23.2 cherrypick 1 #26347
Conversation
…tion opt (#26103) ### Description This is an internal branch dupe of #25255 + some minor cosmetic changes to account for Copilot feedback ### Motivation and Context Improve performance of NCHW Conv - Both grouped convolutions and batched inputs should benefit from this change. For a detailed understanding of perf improvement, please refer to the numbers in #25255. Credit to @zoeczy and team for this improvement and code change --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Edward Chen <[email protected]>
### Description Fix a bug in the TRT Execution Provider where the DDS output tensor was not bound after an engine update. ### Motivation and Context The `dds_output_allocator_map` is not cleared on engine update, so that it will mis-recognized as a known DDS and will not bind the output allocation. Script to reproduce the issue: ```:python # create an onnx model with: # inputs: data -> NonZeros(data) -> GatherND -> output # then run the model with onnxruntime def create_model(): import onnx from onnx import helper, TensorProto input = helper.make_tensor_value_info("data", TensorProto.FLOAT, ["d1", "d2"]) output = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["nzr"]) nonzeros_node = helper.make_node("NonZero", ["data"], ["nonzeros"], "nonzeros_node") transpose_node = helper.make_node( "Transpose", ["nonzeros"], ["nonzeros_t"], "transpose_node" ) gathernd_node = helper.make_node( "GatherND", ["data", "nonzeros_t"], ["output"], "gathernd_node" ) value_info = [ helper.make_tensor_value_info("nonzeros", TensorProto.INT64, [2, "nzr"]), helper.make_tensor_value_info("nonzeros_t", TensorProto.INT64, ["nzr", 2]), ] graph = helper.make_graph( [nonzeros_node, transpose_node, gathernd_node], "test_graph", [input], [output], value_info=value_info, ) model = helper.make_model(graph) onnx.save(model, "model_dds.onnx") def run_model(): import onnxruntime as ort import numpy as np sess = ort.InferenceSession("model_dds.onnx", providers=["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]) print("Running with data shape (3,4)") data = np.random.randn(3, 4).astype(np.float32) sess.run(None, {"data": data}) print("Running with data shape (5,6)") data = np.random.randn(5, 6).astype(np.float32) sess.run(None, {"data": data}) create_model() run_model() ``` Before the change: > IExecutionContext::enqueueV3: Error Code 3: API Usage Error (Parameter check failed, condition: mContext.profileObliviousBindings.at(profileObliviousIndex) || getPtrOrNull(mOutputAllocators, profileObliviousIndex). Neither address or allocator is set for output tensor scores. Call setOutputTensorAddress, setTensorAddress or setOutputAllocator before enqueue/execute.) ... Status Message: TensorRT EP execution context enqueue failed.
## Description Fixes #26261 This PR resolves a regression introduced in v1.23.0 where models with Constant nodes containing tensors larger than 127 bytes fail to load with a shape inference error. ### Root Cause Commit 3b97d79 (PR #25320) introduced an optimization to convert large Constant node tensors (> 127 bytes) into OrtValues with in-memory external data references for better memory management. However, ONNX shape inference cannot distinguish between in-memory and file-based external data, and rejects any TensorProto with `data_location = EXTERNAL`. ### The Fix Modified `InferenceContextImpl::getInputData()` to: 1. Detect tensors with in-memory external data using `utils::HasExternalDataInMemory()` 2. Retrieve the corresponding OrtValue 3. Create a temporary TensorProto with embedded data (not external reference) 4. Provide this temporary proto to ONNX shape inference This allows ONNX shape inference to access the actual tensor data without rejecting it as external. ### Memory Impact This fix introduces a minor and temporary increase in memory usage during the model loading phase. - **When:** The additional memory is allocated only when the shape inference engine needs to access the data of a constant tensor that is larger than 127 bytes. This is a one-time event during the initial analysis of the model. - **What:** The fix creates a temporary in-memory copy of the tensor data. - **Duration:** This temporary copy is released as soon as shape inference is complete. The impact on the overall peak memory usage of the application is expected to be negligible. The memory usage during inference is not affected. While it is theoretically possible for the temporary tensor to be large if a multi-gigabyte constant tensor is used for shape inference, this is a highly unlikely scenario in practice for well-designed models. ### Testing - Tested with the problematic model from issue #26261 - All optimization levels now work correctly (DISABLE_ALL, BASIC, EXTENDED, ALL) - Unit tests to be added ### Changes - **onnxruntime/core/graph/graph.cc**: - Modified `getInputData()` method in `InferenceContextImpl` class - Added `temp_tensor_protos_` member to store temporary TensorProtos during shape inference ## TODO - [ ] Add unit tests - [ ] Run full test suite --------- Co-authored-by: Dmitri Smirnov <[email protected]>
Users with RTX 5090 GPUs are experiencing runtime errors when using onnxruntime-gpu: ``` [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Slice node. Name:'Slice_34' Status Message: CUDA error cudaErrorNoKernelImageForDevice: no kernel image is available for execution on the device ``` This occurs because RTX 5090 uses CUDA compute architecture 12.0 (SM 12.0). The incompatibility of `onnxruntime-gpu` 1.23 was built with `90a-virtual`. The `90a` architecture is a specialized, non-forward-compatible version of the Hopper architecture, making it incompatible with future GPU generations like Blackwell. This change will revert `90a-virtual` back to `90-virtual` as used in 1.22. This shall bring back the compatibility in Blackwell GPU. The FPA_INTB_GEMM is disabled by default. It need some extra work to make it compatible with 90-virtual and no 90a-real use case. Related: #26002 #26226 #26181
### Description Fix logic flow bug where rpc polling interval is set to 9999 when perf performance is NOT burst. The interval should be set to 9999 when the perf performance is burst ### Motivation and Context Co-authored-by: quic_calvnguy <quic_calvnguy@quic_inc.com>
Update operator spec to support block quantization in qMoE. Implementation will come later.
### Description Add new API to VitisAI to save graph as a string ### Motivation and Context to support in-memory flow --------- Co-authored-by: yifei <[email protected]>
@@ -0,0 +1,26 @@ | |||
import onnx |
Check notice
Code scanning / CodeQL
Module is imported with 'import' and 'import from' Note test
Module 'onnxruntime.test.onnx' is imported with both 'import' and 'import from'.
Show autofix suggestion
Hide autofix suggestion
Copilot Autofix
AI 3 days ago
To fix the problem, we should remove from onnx import TensorProto, helper
, and reference these items via the already-imported onnx
module as onnx.TensorProto
and onnx.helper
. This fixes the double import. Specifically, wherever TensorProto
or helper
is used, we need to replace them with onnx.TensorProto
and onnx.helper
, respectively. No functionality will change, but the code will be consistent and clearer. The only changes are in this file: remove line 2, and update all uses of TensorProto
and helper
to their fully qualified names.
-
Copy modified lines R4-R5 -
Copy modified lines R7-R9 -
Copy modified lines R12-R13 -
Copy modified line R16 -
Copy modified line R17
@@ -1,20 +1,19 @@ | ||
import onnx | ||
from onnx import TensorProto, helper | ||
|
||
# Create a simple ONNX model with DDS output | ||
input = helper.make_tensor_value_info("data", TensorProto.FLOAT, ["d1", "d2"]) | ||
output = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["nzr"]) | ||
input = onnx.helper.make_tensor_value_info("data", onnx.TensorProto.FLOAT, ["d1", "d2"]) | ||
output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, ["nzr"]) | ||
|
||
nonzeros_node = helper.make_node("NonZero", ["data"], ["nonzeros"], "nonzeros_node") | ||
transpose_node = helper.make_node("Transpose", ["nonzeros"], ["nonzeros_t"], "transpose_node") | ||
gathernd_node = helper.make_node("GatherND", ["data", "nonzeros_t"], ["output"], "gathernd_node") | ||
nonzeros_node = onnx.helper.make_node("NonZero", ["data"], ["nonzeros"], "nonzeros_node") | ||
transpose_node = onnx.helper.make_node("Transpose", ["nonzeros"], ["nonzeros_t"], "transpose_node") | ||
gathernd_node = onnx.helper.make_node("GatherND", ["data", "nonzeros_t"], ["output"], "gathernd_node") | ||
|
||
value_info = [ | ||
helper.make_tensor_value_info("nonzeros", TensorProto.INT64, [2, "nzr"]), | ||
helper.make_tensor_value_info("nonzeros_t", TensorProto.INT64, ["nzr", 2]), | ||
onnx.helper.make_tensor_value_info("nonzeros", onnx.TensorProto.INT64, [2, "nzr"]), | ||
onnx.helper.make_tensor_value_info("nonzeros_t", onnx.TensorProto.INT64, ["nzr", 2]), | ||
] | ||
|
||
graph = helper.make_graph( | ||
graph = onnx.helper.make_graph( | ||
[nonzeros_node, transpose_node, gathernd_node], | ||
"test_graph", | ||
[input], | ||
@@ -22,5 +14,5 @@ | ||
value_info=value_info, | ||
) | ||
|
||
model = helper.make_model(graph) | ||
model = onnx.helper.make_model(graph) | ||
onnx.save(model, "ort_github_issue_26272_dds.onnx") |
Adds the following commits to the release-1.23.2 branch for ORT 1.23.2: