
[mlir][linalg] Convert linalg.named to linalg.elementwise op. #148424


Open
wants to merge 1 commit into base: main

Conversation

javedabsar1
Contributor

Convert linalg named ops which are elementwise (e.g. add/exp) to `linalg.elementwise`. Currently, such named ops have to be dropped to linalg.generic (`--linalg-generalize-named-ops`), and one then has to work out which generics are actually elementwise. Likewise, folding of broadcast or transpose can only happen at the generic level. With this rewrite, both can now happen directly on `linalg.elementwise`.
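
To illustrate, a minimal before/after sketch (mirroring the unary `exp` case from the new test in this PR; binary ops like `add` follow the same pattern, and the shapes are just the ones used in the test):

  %exp = linalg.exp ins(%A : tensor<16x8xf32>)
                    outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>

becomes

  %exp = linalg.elementwise kind=#linalg.elementwise_kind<exp>
                            ins(%A : tensor<16x8xf32>)
                            outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>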

@llvmbot
Member

llvmbot commented Jul 13, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Javed Absar (javedabsar1)

Changes

Convert linalg named ops which are elementwise (e.g. add/exp) to `linalg.elementwise`. Currently, such named ops have to be dropped to linalg.generic (`--linalg-generalize-named-ops`), and one then has to work out which generics are actually elementwise. Likewise, folding of broadcast or transpose can only happen at the generic level. With this rewrite, both can now happen directly on `linalg.elementwise`.
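
A rough sketch of the folding motivation (the folding itself is the existing fold-into-elementwise patterns referenced in the diff below; this assumes the explicit indexing_maps form of `linalg.elementwise`, with illustrative shapes and names):

  // Before: a transpose feeding an elementwise op.
  %empty = tensor.empty() : tensor<16x8xf32>
  %t = linalg.transpose ins(%A : tensor<8x16xf32>)
                        outs(%empty : tensor<16x8xf32>) permutation = [1, 0]
  %exp_t = linalg.elementwise kind=#linalg.elementwise_kind<exp>
             ins(%t : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>

  // After folding: the transposed access lives in the indexing maps instead.
  %exp_folded = linalg.elementwise kind=#linalg.elementwise_kind<exp>
                  indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
                                   affine_map<(d0, d1) -> (d0, d1)>]
                  ins(%A : tensor<8x16xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>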


Full diff: https://github.com/llvm/llvm-project/pull/148424.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+5)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp (+118)
  • (added) mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir (+38)
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 373842c9b03de..f2c1b99b138bc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgNamedToElementwisePass : Pass<"linalg-named-to-elementwise"> {
+  let summary = "Convert linalg named ops to elementwise where possible";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
   let summary = "Fold transform, broadcast and other ops into elementwise";
   let dependentDialects = ["linalg::LinalgDialect"];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 74280fdd82f4e..086073c11c80a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1810,6 +1810,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 void populateLinalgGenericOpsSpecializationPatterns(
     RewritePatternSet &patterns);
 
+/// Populates `patterns` that convert linalg named ops e.g. `linalg.add`
+/// to equivalent `linalg.elementwise`.
+void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that fold operations like
 /// `linalg.transform` into elementwise op map.
 void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 69e6fdabf9a58..7cb83377fa0d8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   TransposeMatmul.cpp
   MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
+  NamedToElementwise.cpp
   BlockPackMatmul.cpp
   PackAndUnpackPatterns.cpp
   Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
new file mode 100644
index 0000000000000..1303b7cb5f6f9
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
@@ -0,0 +1,118 @@
+//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewriting those linalg named ops that are essentially
+// elementwise e.g. `linalg.exp`, to `linalg.elementwise`. This allows further
+// optimization on `linalg.elementwise` such as folding transpose, broadcast.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGNAMEDTOELEMENTWISEPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-named-to-elementwise"
+
+namespace {
+ElementwiseKind getKind(Operation *op) {
+  return llvm::TypeSwitch<Operation *, ElementwiseKind>(op)
+      .Case([](SelectOp) { return ElementwiseKind::select; })
+      .Case([](AddOp) { return ElementwiseKind::add; })
+      .Case([](SubOp) { return ElementwiseKind::sub; })
+      .Case([](MulOp) { return ElementwiseKind::mul; })
+      .Case([](DivOp) { return ElementwiseKind::div; })
+      .Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; })
+      .Case([](PowFOp) { return ElementwiseKind::powf; })
+      .Case([](ExpOp) { return ElementwiseKind::exp; })
+      .Case([](LogOp) { return ElementwiseKind::log; })
+      .Case([](AbsOp) { return ElementwiseKind::abs; })
+      .Case([](CeilOp) { return ElementwiseKind::ceil; })
+      .Case([](FloorOp) { return ElementwiseKind::floor; })
+      .Case([](NegFOp) { return ElementwiseKind::negf; })
+      .Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; })
+      .Case([](RoundOp) { return ElementwiseKind::round; })
+      .Case([](SqrtOp) { return ElementwiseKind::sqrt; })
+      .Case([](RsqrtOp) { return ElementwiseKind::rsqrt; })
+      .Case([](SquareOp) { return ElementwiseKind::square; })
+      .Case([](TanhOp) { return ElementwiseKind::tanh; })
+      .Case([](ErfOp) { return ElementwiseKind::erf; })
+      .Default([&](Operation *op) {
+        assert(false && "unexpected op");
+        return ElementwiseKind::sub;
+      });
+}
+
+template <typename NamedOpTy>
+struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
+  using OpRewritePattern<NamedOpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(NamedOpTy op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<NamedAttribute> attrs;
+    auto kindAttr = ElementwiseKindAttr::get(op.getContext(), getKind(op));
+    attrs.push_back(rewriter.getNamedAttr("kind", kindAttr));
+    attrs.push_back(
+        rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps()));
+
+    rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
+                                               op.getDpsInits(), attrs);
+    return success();
+  }
+};
+
+struct LinalgNamedToElementwisePass
+    : public impl::LinalgNamedToElementwisePassBase<
+          LinalgNamedToElementwisePass> {
+  using impl::LinalgNamedToElementwisePassBase<
+      LinalgNamedToElementwisePass>::LinalgNamedToElementwisePassBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateLinalgNamedToElementwisePatterns(patterns);
+
+    if (failed(applyPatternsGreedily(op, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateLinalgNamedToElementwisePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<NamedToElementwisePattern<AddOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<SubOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<MulOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<DivOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<DivUnsignedOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<PowFOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<ExpOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<LogOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<AbsOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<CeilOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<FloorOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<NegFOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<ReciprocalOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<RoundOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<SqrtOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<RsqrtOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<SquareOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<TanhOp>>(patterns.getContext());
+  patterns.add<NamedToElementwisePattern<ErfOp>>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir b/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir
new file mode 100644
index 0000000000000..3dc8275117336
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -linalg-named-to-elementwise -split-input-file | FileCheck %s
+
+// CHECK: @exp(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME:       kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME:       ins(%[[A]] : tensor<16x8xf32>)
+// CHECK-SAME:       outs(%[[B]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B :  tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %exp :  tensor<16x8xf32>
+}
+
+// ----
+
+// CHECK: @add(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME:       kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:       ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
+// CHECK-SAME:       outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @add(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %add = linalg.add ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C :  tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %add :  tensor<16x8xf32>
+}
+
+// ----
+
+// CHECK: @sub(%[[A:.+]]: memref<16x8xf32>, %[[B:.+]]: memref<16x8xf32>, %[[C:.+]]: memref<16x8xf32>) {
+// CHECK:      linalg.elementwise
+// CHECK-SAME:       kind=#linalg.elementwise_kind<sub>
+// CHECK-SAME:       ins(%[[A]], %[[B]] : memref<16x8xf32>, memref<16x8xf32>)
+// CHECK-SAME:       outs(%[[C]] : memref<16x8xf32>)
+//
+func.func @sub(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C : memref<16x8xf32>) {
+  linalg.sub ins(%A, %B : memref<16x8xf32>, memref<16x8xf32>) outs(%C :  memref<16x8xf32>)
+  return
+}

@llvmbot
Member

llvmbot commented Jul 13, 2025

@llvm/pr-subscribers-mlir

Author: Javed Absar (javedabsar1)


@rengolin
Member

I'd like to have a longer-term plan for this. Given the operation tree discussion, we probably should start thinking about "partial (de-)generalization" instead of loose passes.

generic <--> elementwise <--> add/mul/etc
          ^--> contract <--> matmul/etc

And also look at fitting the non-perfectly-nested ops like contract and composite ops like softmax to a DAG of operations:

DAG(generic ops) <--> DAG(named ops) <--> softmax

So that users can pick and choose separately which stage they stop at for each step of the way.
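
As a concrete sketch of those levels on a simple add (shapes and SSA names here are illustrative; the elementwise form is the one this PR emits, and the generic form is what generalization produces today):

  #map = affine_map<(d0, d1) -> (d0, d1)>

  // Named level.
  %0 = linalg.add ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>)
                  outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>

  // Category level (linalg.elementwise).
  %1 = linalg.elementwise kind=#linalg.elementwise_kind<add>
         ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>)
         outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>

  // Generic level.
  %2 = linalg.generic {indexing_maps = [#map, #map, #map],
                       iterator_types = ["parallel", "parallel"]}
         ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>)
         outs(%C : tensor<16x8xf32>) {
    ^bb0(%a: f32, %b: f32, %acc: f32):
      %sum = arith.addf %a, %b : f32
      linalg.yield %sum : f32
  } -> tensor<16x8xf32>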

@javedabsar1
Contributor Author

javedabsar1 commented Jul 14, 2025

users can pick and choose separately which stage they stop at for each step of the way.

Thanks @rengolin -- good thinking on extending this concept/PR.

We need conversions from existing linalg named ops to the new ops in the tree (i.e. elementwise/contraction, ... https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586), otherwise we don't have many use cases for the new ops.

  1. named --> generic: already exists (mlir-opt --linalg-generalize-named-ops).
  2. named --> linalg.elementwise, linalg.contraction: this PR (mlir-opt -linalg-named-to-elementwise). We can rename it to cover the contraction op etc. and extend the implementation (-linalg-named-to-category-ops?).
  3. new ops elementwise/contraction --> named ops: Do we want this? I don't see a use case, at least for me.
  4. generic --> named: exists (-linalg-specialize-generic-ops). I did some work on this previously, as you know, but in general this path is not always possible (fused, handwritten generics).
  5. generic --> elementwise/contract: I think this falls under specialize.

We could discuss this further at the next Lighthouse meeting? Meanwhile, perhaps we can settle on some common agreement -- e.g. item (2) above?

@rengolin
Member

rengolin commented Jul 14, 2025

  • named --> generic: already exists (mlir-opt --linalg-generalize-named-ops).

Yes, but it goes straight to generic, can't stop at the middle ops (contract, elementwise).

  • named --> linalg.elementwise, linalg.contraction: this PR (mlir-opt -linalg-named-to-elementwise). We can rename it to cover the contraction op etc. and extend the implementation (-linalg-named-to-category-ops?).

Yes, but it's yet-another pass that users can easily miss. There's no "structure" in these passes.

  • new ops elementwise/contraction --> named ops: Do we want this? I don't see a use case, at least for me.

I don't particularly need it, but since there are named-to-TOSA conversions and none from the more general forms, this may be useful. But we don't generally implement what may be useful, so I'd wait for a real user before doing this. Again, better if it's part of a structured conversion.

  • generic --> named: exists (-linalg-specialize-generic-ops). I did some work on this previously, as you know, but in general this path is not always possible (fused, handwritten generics).

Yes, but it goes straight to named, can't stop at the middle ops (contract, elementwise).

  • generic --> elementwise/contract: I think this falls under specialize.

I think this is the most valuable of the bunch. In our cases, at least, it's a lot easier to match against elementwise/contract than against generic; many named patterns would need a DAG to represent the same thing, and we don't (yet) have a good structured matcher. It's also a lot easier to transform a contract into another (different) contract by just operating on affine maps and iterator types than it is to know the particularities of batch_matmul or batch_reduce_matmul and their specific usages.
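
For example (a rough sketch, assuming `linalg.contract`'s explicit indexing_maps form; shapes and names are illustrative):

  // Named form: the semantics are baked into the op name.
  %0 = linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
                     outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>

  // Category form: one op, and the matmul/batch_matmul/transposed variants
  // differ only in the affine maps and the iteration space.
  %1 = linalg.contract
         indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                          affine_map<(d0, d1, d2) -> (d2, d1)>,
                          affine_map<(d0, d1, d2) -> (d0, d1)>]
         ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
         outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>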

@adam-smnk
Contributor

If we had a name for each abstraction level, there could be a single pass that converts all matching ops to that representation.
For everything to generic, there's generalize. For going to named ops, it's specialize.

named --> linalg.elementwise, linalg.contraction: this PR (mlir-opt -linalg-named-to-elementwise). We can rename it to cover the contraction op etc. and extend the implementation (-linalg-named-to-category-ops?).

This seems reasonable, with a linalg-categorize-ops pass for the conversion: generic -> eltwise/contract etc. <- named.

@javedabsar1
Contributor Author

javedabsar1 commented Jul 14, 2025

If we had a name for each abstraction level, there could be a single pass that converts all matching ops to that representation. For everything to generic, there's generalize. For going to named ops, it's specialize.

named --> linalg.elementwise, linalg.contraction: this PR (mlir-opt -linalg-named-to-elementwise). We can rename it to cover the contraction op etc. and extend the implementation (-linalg-named-to-category-ops?).

This seems reasonable, with a linalg-categorize-ops pass for the conversion: generic -> eltwise/contract etc. <- named.

Thanks @rengolin / @adam-smnk -- we seem to be converging to something good!

An umbrella name for all of this could be "morphism" (borrowing a term from mathematics).

Named (linalg.add)
\/ categorize-(all) named-ops
Category (linalg.contraction)
\/ generalize-(all) named-ops /\ categorize (all) generic-ops
Generic (linalg.generic)

  1. --linalg-morph-ops
  2. --linalg-morph-ops=generalize-named-ops converts all named ops to generic (calls --linalg-generalize-named-ops)
  3. --linalg-morph-ops=generalize-all-ops converts all named/contraction/elementwise ops to generic
  4. --linalg-morph-ops=categorize-all-ops converts named and generic ops to elementwise/contraction/...

I may not have the best names or option names, so please suggest better ones.

@rengolin
Member

Named (linalg.add)
\/ categorize-(all) named-ops
Category (linalg.contraction)
\/ generalize-(all) named-ops /\ categorize (all) generic-ops
Generic (linalg.generic)

Awesome, this is similar to what I had in mind. And to be clear, we're just talking strategy here. This PR is a much simpler thing and good for what it is; we just need to know where it will live in the general strategy, and how to get the naming / conversion right.

My main concerns are on how we represent the various steps:

  1. Do we make every generalization and categorization step-wise and then users have to "iterate until it gets to the level they want"?
    1. How do we manage the categorization of a DAG where only some ops can go to named ops while others must remain as a category or generic?
    2. What if the DAG is in separate representations? Do we step-wise generalize/categorize individual ops to the next step, or do we find a common stage and convert all to that stage? And is that max or min of the final stages?
  2. Or do we have names to each morphism that relates to the end state (named/category/generic) and convert all to that stage?
    1. What if some ops can't be moved to that form? Is this a soft fail or a hard fail?
    2. What if we want some ops to be at a particular form while others are different? The current softmax lowering produces both generic and named ops. Is that a good thing, even?

Alternatively, we can have both. So, if you ask for a "generalization morphism", you go step-wise. If you ask for a "morphism into a particular stage", you get a normalization of all ops into that form (generalize and specialize). This may end up being a lot of different patterns, especially for generalization strategies (match all named ops), so we also need care in how we do this.

I'd also make any issues in conversion soft failures; then both the step-wise and into-form variants would behave similarly, and the latter can just be an iteration of the former until all conversions return soft failures.

javedabsar1 reopened this Jul 14, 2025
@javedabsar1
Contributor Author

javedabsar1 commented Jul 14, 2025

  1. Do we make every generalization and categorization step-wise
    -linalg-morph-ops=<option>

The options would be:

  • -linalg-morph-ops=generalize-named-ops (exactly as the legacy --linalg-generalize-named-ops, which converts named to generic)
  • generalize-all-ops: named + category ops
  • categorize-named-ops: the rewrite patterns from this PR for elementwise, plus new ones for contraction etc.
    We can iterate on the exact names. The options just pack the populate*Patterns, so it's not a headache.

How do we manage the categorization of a DAG where only some ops can go to named ops while others must remain as a category or generic?

-linalg-morph-ops=categorize{...}

  1. What if the DAG is in separate representations? Do we step-wise generalize/categorize individual ops to the next step, or do we find a common stage and convert all to that stage? And is that max or min of the final stages?

Part of what you correctly said earlier is a strategy. To me that means having a building-blocks (step-wise) option, and a pipeline (not in the literal sense) that is semantically always named->category->generic but whose options can pack different mixes -- which, if I understand you correctly, is the case where the input IR is a mix of named+category+generic.

  1. What if some ops can't be moved to that form? Is this a soft fail or a hard fail?

soft fail --- that's what specialize and named do currently.

Alternatively, we can have both. So, if you ask for a "generalization morphism", you go step-wise. If you ask for a "morphism into a particular stage", you get a normalization of all ops into that form (generalize and specialize). This may end up being a lot of different patterns, especially for generalization strategies (match all named ops), so we also need care in how we do this.

Good point. To me, named -> categorize -> generalize is no-fail; generic -> specialize (to category or named) is best-effort.

@adam-smnk
Contributor

How do we manage the categorization of a DAG where only some ops can go to named ops while others must remain as a category or generic?

Depends a bit on how granular you want to go. If the choice is per operation, then one should roll their own pass/transformation that can apply custom logic.
It's a design decision whether the morphism logic should take a control callback or expose the rewrites as an API (like linalg::vectorize).

@javedabsar1
Contributor Author

How do we manage the categorization of a DAG where only some ops can go to named ops while others must remain as a category or generic?

Depends a bit on how granular you want to go. If the choice is per operation, then one should roll their own pass/transformation that can apply custom logic. It's a design decision whether the morphism logic should take a control callback or expose the rewrites as an API (like linalg::vectorize).

A callback, i.e. a controlFn, sounds better than a command-line option.

@adam-smnk
Contributor

Do we make every generalization and categorization step-wise and then users have to "iterate until it gets to the level they want"?

I would avoid having to repeatedly call "abstract one level up". It seems annoying to use when you start with mixed-abstraction IR.
IMO, a general utility should make a best effort to rewrite all (selected) ops to the chosen level.
For anything more specific, users should supply their own custom controls.

Comment on lines +57 to +58
assert(false && "unexpected op");
return ElementwiseKind::sub;
Contributor


The default case should only contain llvm_unreachable.
Otherwise, this will have different behavior in release mode.
