1 change: 1 addition & 0 deletions src/Dialect/ONNX/Transforms/CMakeLists.txt
@@ -54,6 +54,7 @@ add_onnx_mlir_library(OMONNXRewrite
LegalizeQuarkQuantizedOps.cpp
QuantTypes.cpp
ConvertToChannelLast.cpp
+ ResultNamesUpdater.cpp

DEPENDS
OMONNXDecomposeIncGen
5 changes: 4 additions & 1 deletion src/Dialect/ONNX/Transforms/ConstProp.cpp
@@ -31,6 +31,7 @@
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp"
#include "src/Dialect/ONNX/Transforms/ConstProp.hpp"
#include "src/Dialect/ONNX/Transforms/ResultNamesUpdater.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Support/TypeUtilities.hpp"

@@ -1270,7 +1271,9 @@ void ConstPropONNXToONNXPass::runOnOperation() {

RewritePatternSet patterns(context);
getConstPropONNXToONNXPatterns(patterns);
- if (failed(applyPatternsGreedily(function, std::move(patterns))))
+ onnx_mlir::ResultNamesUpdater rnUpdater;
+ if (failed(applyPatternsGreedily(function, std::move(patterns),
+       GreedyRewriteConfig{.listener = &rnUpdater})))
signalPassFailure();
}

5 changes: 4 additions & 1 deletion src/Dialect/ONNX/Transforms/Decompose.cpp
@@ -44,6 +44,7 @@
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Dialect/ONNX/Transforms/Decompose.hpp"
#include "src/Dialect/ONNX/Transforms/DecomposeEinsum.hpp"
#include "src/Dialect/ONNX/Transforms/ResultNamesUpdater.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Support/TypeUtilities.hpp"
#include "llvm/ADT/ArrayRef.h"
@@ -4097,7 +4098,9 @@ void DecomposeONNXToONNXPass::runOnOperation() {
}
#endif

- if (failed(applyPatternsGreedily(function, std::move(patterns))))
+ onnx_mlir::ResultNamesUpdater rnUpdater;
+ if (failed(applyPatternsGreedily(function, std::move(patterns),
+       GreedyRewriteConfig{.listener = &rnUpdater})))
signalPassFailure();
}

3 changes: 3 additions & 0 deletions src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp
@@ -25,6 +25,7 @@
#include "src/Dialect/ONNX/Transforms/Decompose.hpp"
#include "src/Dialect/ONNX/Transforms/LegalizeQuarkQuantizedOps.hpp"
#include "src/Dialect/ONNX/Transforms/Recompose.hpp"
#include "src/Dialect/ONNX/Transforms/ResultNamesUpdater.hpp"
#include "src/Dialect/ONNX/Transforms/ShapeInference.hpp"
#include "src/Interface/ShapeInferenceOpInterface.hpp"
#include "src/Pass/Passes.hpp"
@@ -212,6 +213,8 @@ struct ONNXHybridTransformPass
Region &body = f.getBody();

GreedyRewriteConfig config;
+ ResultNamesUpdater rnUpdater;
+ config.listener = &rnUpdater;
config.useTopDownTraversal = true;
if (maxNumRewritesOffset == -1) {
config.maxNumRewrites = GreedyRewriteConfig::kNoLimit;
5 changes: 4 additions & 1 deletion src/Dialect/ONNX/Transforms/QDQCanonicalize.cpp
@@ -13,6 +13,7 @@

#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
#include "src/Dialect/ONNX/Transforms/ResultNamesUpdater.hpp"

using namespace mlir;

@@ -69,7 +70,9 @@ class QDQCanonicalizePass
}

void runOnOperation() override {
- if (failed(applyPatternsGreedily(getOperation(), frozenPatterns)))
+ onnx_mlir::ResultNamesUpdater rnUpdater;
+ if (failed(applyPatternsGreedily(getOperation(), frozenPatterns,
+       GreedyRewriteConfig{.listener = &rnUpdater})))
signalPassFailure();
}

5 changes: 4 additions & 1 deletion src/Dialect/ONNX/Transforms/Recompose.cpp
@@ -37,6 +37,7 @@
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Dialect/ONNX/Transforms/Recompose.hpp"
#include "src/Dialect/ONNX/Transforms/ResultNamesUpdater.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Support/TypeUtilities.hpp"

@@ -1522,7 +1523,9 @@ void RecomposeONNXToONNXPass::runOnOperation() {
onnx_mlir::getRecomposeONNXToONNXPatterns(
patterns, recomposeLayernormByTranspose);

- if (failed(applyPatternsGreedily(function, std::move(patterns))))
+ onnx_mlir::ResultNamesUpdater rnUpdater;
+ if (failed(applyPatternsGreedily(function, std::move(patterns),
+       GreedyRewriteConfig{.listener = &rnUpdater})))
signalPassFailure();
}

50 changes: 50 additions & 0 deletions src/Dialect/ONNX/Transforms/ResultNamesUpdater.cpp
@@ -0,0 +1,50 @@
// Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.
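//
// ResultNamesUpdater.cpp: implements a rewrite listener that keeps the
// "ResultNames" attribute alive across greedy pattern rewrites. When an op
// carrying the attribute is replaced, its result names are copied onto the
// single replacement op, or redistributed per result when the replacement
// values are produced by several ops.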

#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/Value.h>
#include <mlir/Support/LLVM.h>

#include "src/Dialect/ONNX/Transforms/ResultNamesUpdater.hpp"

namespace onnx_mlir {

void ResultNamesUpdater::notifyOperationReplaced(
mlir::Operation *op, mlir::ValueRange replacement) {
if (!op->hasAttrOfType<mlir::ArrayAttr>("ResultNames"))
return;

auto resultNamesArray = op->getAttrOfType<mlir::ArrayAttr>("ResultNames");

// If the op is replaced by a single op, simply copy the attribute
mlir::Operation *replSingleOp = replacement.front().getDefiningOp();
if (replSingleOp &&
llvm::all_of(replacement, [replSingleOp](mlir::Value val) -> bool {
return val.getDefiningOp() == replSingleOp;
})) {
replSingleOp->setAttr("ResultNames", resultNamesArray);
return;
}

mlir::MLIRContext *ctx = op->getContext();
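// Otherwise, distribute each name to the op defining the corresponding
// replacement value (values that are not op results are skipped), keeping
// any names that op already carries for its other results.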
for (auto [name, value] : llvm::zip_equal(resultNamesArray, replacement)) {
if (mlir::OpResult replResult = mlir::dyn_cast<mlir::OpResult>(value)) {
mlir::Operation *replOp = replResult.getOwner();

// Get new or existing ResultNames
mlir::SmallVector<mlir::Attribute> replResultNames(
replOp->getNumResults(), mlir::StringAttr::get(ctx));
if (auto existing = replOp->getAttrOfType<mlir::ArrayAttr>("ResultNames"))
replResultNames =
mlir::SmallVector<mlir::Attribute>(existing.getValue());

// Replace the ResultName of current result
replResultNames[replResult.getResultNumber()] = name;
replOp->setAttr(
"ResultNames", mlir::ArrayAttr::get(ctx, replResultNames));
}
}
}

} // namespace onnx_mlir
15 changes: 15 additions & 0 deletions src/Dialect/ONNX/Transforms/ResultNamesUpdater.hpp
@@ -0,0 +1,15 @@
// Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <mlir/IR/PatternMatch.h>

namespace onnx_mlir {

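// Rewrite listener that propagates the "ResultNames" attribute from a
// replaced op onto the ops that produce its replacement values, so that
// ONNX result names survive pattern-based transformations.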
class ResultNamesUpdater : public mlir::RewriterBase::Listener {
public:
void notifyOperationReplaced(
mlir::Operation *op, mlir::ValueRange replacement) override;
};

} // namespace onnx_mlir
71 changes: 71 additions & 0 deletions test/mlir/onnx/onnx_resultnames_prop.mlir
@@ -0,0 +1,71 @@
// (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.

// RUN: onnx-mlir-opt %s --constprop-onnx --decompose-onnx=enable-split-to-slice --onnx-hybrid-transform --qdq-canonicalize=remove-qdq-around-ops | FileCheck %s
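// Each function below exercises one of the passes above and checks that the
// ResultNames attribute ends up on the ops that replace the original ones.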

func.func @constprop() -> tensor<f32> {
%0 = onnx.Constant {ResultNames = ["const0"]} dense<1.000000e+00> : tensor<f32>
%1 = onnx.Constant {ResultNames = ["const0"]} dense<2.000000e+00> : tensor<f32>
%2 = "onnx.Add"(%0, %1) {ResultNames = ["add0"]} : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %2 : tensor<f32>
}

// CHECK-LABEL: @constprop()
// CHECK: onnx.Constant
// CHECK-SAME: ResultNames = ["add0"]
// CHECK-SAME: dense<3.000000e+00>

func.func @decompose(%arg0: tensor<8x4xf32>) -> (tensor<4x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>) {
%0 = onnx.Constant dense<[4, 2, 2]> : tensor<3xi64>
%1:3 = "onnx.Split"(%arg0, %0) {axis = 0 : si64, ResultNames = ["split_out0", "split_out1", "split_out2"]} : (tensor<8x4xf32>, tensor<3xi64>) -> (tensor<4x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>)
return %1#0, %1#1, %1#2 : tensor<4x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>
}

// CHECK-LABEL: @decompose
// CHECK: onnx.Slice
// CHECK-SAME: ResultNames = ["split_out0"]
// CHECK-NEXT: onnx.Slice
// CHECK-SAME: ResultNames = ["split_out1"]
// CHECK-NEXT: onnx.Slice
// CHECK-SAME: ResultNames = ["split_out2"]

func.func @canonicalize(%arg0: tensor<f32>) -> tensor<f32> {
%0 = onnx.Constant {ResultNames = ["const0"]} dense<2.000000e+00> : tensor<f32>
%1 = "onnx.Add"(%0, %arg0) {ResultNames = ["add0"]} : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %1 : tensor<f32>
}

// CHECK-LABEL: @canonicalize
// CHECK: "onnx.Add"(%arg0, %0)
// CHECK-SAME: ResultNames = ["add0"]

func.func @qdq_canonicalize(%arg0: tensor<1x128xf32>) -> tensor<1x1x128xf32> {
%0 = onnx.Constant {ResultNames = ["scale"]} dense<1.000000e+00> : tensor<f32>
%1 = onnx.Constant {ResultNames = ["zp"]} dense<128> : tensor<ui8>
%2 = onnx.Constant {ResultNames = ["shape"]} dense<[1, 1, 128]> : tensor<3xi64>
%3 = "onnx.QuantizeLinear"(%arg0, %0, %1) {ResultNames = ["q0"], axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x128xf32>, tensor<f32>, tensor<ui8>) -> tensor<1x128xui8>
%4 = "onnx.DequantizeLinear"(%3, %0, %1) {ResultNames = ["dq0"], axis = 1 : si64, block_size = 0 : si64} : (tensor<1x128xui8>, tensor<f32>, tensor<ui8>) -> tensor<1x128xf32>
%5 = "onnx.Reshape"(%4, %2) {ResultNames = ["reshape"], allowzero = 0 : si64} : (tensor<1x128xf32>, tensor<3xi64>) -> tensor<1x1x128xf32>
%6 = "onnx.QuantizeLinear"(%5, %0, %1) {ResultNames = ["q1"], axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x1x128xf32>, tensor<f32>, tensor<ui8>) -> tensor<1x1x128xui8>
%7 = "onnx.DequantizeLinear"(%6, %0, %1) {ResultNames = ["dq1"], axis = 1 : si64, block_size = 0 : si64} : (tensor<1x1x128xui8>, tensor<f32>, tensor<ui8>) -> tensor<1x1x128xf32>
return %7 : tensor<1x1x128xf32>
}

// CHECK-LABEL: @qdq_canonicalize
// CHECK: onnx.QuantizeLinear
// CHECK-SAME: ResultNames = ["q0"]
// CHECK-NOT: onnx.DequantizeLinear
// CHECK: onnx.Reshape
// CHECK-SAME: ResultNames = ["q1"]
// CHECK-NOT: onnx.QuantizeLinear
// CHECK: onnx.DequantizeLinear

func.func @complex_names(%arg0: tensor<f32>) -> tensor<f32> {
%0 = onnx.Constant {ResultNames = ["const0"]} dense<2.000000e+00> : tensor<f32>
%1 = "onnx.Add"(%0, %arg0) {ResultNames = [["add0", "with", "array", [1, 2, 3, 4]]]} : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %1 : tensor<f32>
}

// CHECK-LABEL: @complex_names
// CHECK: "onnx.Add"(%arg0, %0)
// CHECK-SAME: ResultNames = [
// CHECK-SAME: ["add0", "with", "array", [1, 2, 3, 4]]]