-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][linalg] Convert linalg.named to linalg.elementwise op. #148424
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Convert linalg.named ops which are elementwise (e.g. add/exp) to `linalg.elementwise`. Currently, named ops have to drop to linalg.generic (--generalize-named-ops), where one figures out which generic are elementwise. Also, folding of broadcast or transpose can occur then only at generic level. Instead, with this rewrite, these can happen now at linalg.elementwise.
@llvm/pr-subscribers-mlir-linalg Author: Javed Absar (javedabsar1) ChangesConvert linalg.named ops which are elementwise (e.g. add/exp) to Full diff: https://github.com/llvm/llvm-project/pull/148424.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 373842c9b03de..f2c1b99b138bc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
let dependentDialects = ["linalg::LinalgDialect"];
}
+def LinalgNamedToElementwisePass : Pass<"linalg-named-to-elementwise"> {
+ let summary = "Convert linalg named ops to elementwise where possible";
+ let dependentDialects = ["linalg::LinalgDialect"];
+}
+
def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
let summary = "Fold transform, broadcast and other ops into elementwise";
let dependentDialects = ["linalg::LinalgDialect"];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 74280fdd82f4e..086073c11c80a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1810,6 +1810,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
void populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns);
+/// Populates `patterns` that convert linalg named ops e.g. `linalg.add`
+/// to equivalent `linalg.elementwise`.
+void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns);
+
/// Populates `patterns` with patterns that fold operations like
/// `linalg.transform` into elementwise op map.
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 69e6fdabf9a58..7cb83377fa0d8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
TransposeMatmul.cpp
MeshShardingInterfaceImpl.cpp
NamedOpConversions.cpp
+ NamedToElementwise.cpp
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
new file mode 100644
index 0000000000000..1303b7cb5f6f9
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
@@ -0,0 +1,118 @@
+//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewriting those linalg named ops that are essentially
+// elementwise e.g. `linalg.exp`, to `linalg.elementwise`. This allows further
+// optimization on `linalg.elementwise` such as folding transpose, broadcast.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGNAMEDTOELEMENTWISEPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-named-to-elementwise"
+
+namespace {
+ElementwiseKind getKind(Operation *op) {
+ return llvm::TypeSwitch<Operation *, ElementwiseKind>(op)
+ .Case([](SelectOp) { return ElementwiseKind::select; })
+ .Case([](AddOp) { return ElementwiseKind::add; })
+ .Case([](SubOp) { return ElementwiseKind::sub; })
+ .Case([](MulOp) { return ElementwiseKind::mul; })
+ .Case([](DivOp) { return ElementwiseKind::div; })
+ .Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; })
+ .Case([](PowFOp) { return ElementwiseKind::powf; })
+ .Case([](ExpOp) { return ElementwiseKind::exp; })
+ .Case([](LogOp) { return ElementwiseKind::log; })
+ .Case([](AbsOp) { return ElementwiseKind::abs; })
+ .Case([](CeilOp) { return ElementwiseKind::ceil; })
+ .Case([](FloorOp) { return ElementwiseKind::floor; })
+ .Case([](NegFOp) { return ElementwiseKind::negf; })
+ .Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; })
+ .Case([](RoundOp) { return ElementwiseKind::round; })
+ .Case([](SqrtOp) { return ElementwiseKind::sqrt; })
+ .Case([](RsqrtOp) { return ElementwiseKind::rsqrt; })
+ .Case([](SquareOp) { return ElementwiseKind::square; })
+ .Case([](TanhOp) { return ElementwiseKind::tanh; })
+ .Case([](ErfOp) { return ElementwiseKind::erf; })
+ .Default([&](Operation *op) {
+ assert(false && "unexpected op");
+ return ElementwiseKind::sub;
+ });
+}
+
+template <typename NamedOpTy>
+struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
+ using OpRewritePattern<NamedOpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(NamedOpTy op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<NamedAttribute> attrs;
+ auto kindAttr = ElementwiseKindAttr::get(op.getContext(), getKind(op));
+ attrs.push_back(rewriter.getNamedAttr("kind", kindAttr));
+ attrs.push_back(
+ rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps()));
+
+ rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
+ op.getDpsInits(), attrs);
+ return success();
+ }
+};
+
+struct LinalgNamedToElementwisePass
+ : public impl::LinalgNamedToElementwisePassBase<
+ LinalgNamedToElementwisePass> {
+ using impl::LinalgNamedToElementwisePassBase<
+ LinalgNamedToElementwisePass>::LinalgNamedToElementwisePassBase;
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ RewritePatternSet patterns(op->getContext());
+ populateLinalgNamedToElementwisePatterns(patterns);
+
+ if (failed(applyPatternsGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
+
+void mlir::linalg::populateLinalgNamedToElementwisePatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<NamedToElementwisePattern<AddOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<SubOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<MulOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<DivOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<DivUnsignedOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<PowFOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<ExpOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<LogOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<AbsOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<CeilOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<FloorOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<NegFOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<ReciprocalOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<RoundOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<SqrtOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<RsqrtOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<SquareOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<TanhOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<ErfOp>>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir b/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir
new file mode 100644
index 0000000000000..3dc8275117336
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -linalg-named-to-elementwise -split-input-file | FileCheck %s
+
+// CHECK: @exp(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: ins(%[[A]] : tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[B]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %exp : tensor<16x8xf32>
+}
+
+// ----
+
+// CHECK: @add(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @add(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %add = linalg.add ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %add : tensor<16x8xf32>
+}
+
+// ----
+
+// CHECK: @sub(%[[A:.+]]: memref<16x8xf32>, %[[B:.+]]: memref<16x8xf32>, %[[C:.+]]: memref<16x8xf32>) {
+// CHECK: linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<sub>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<16x8xf32>, memref<16x8xf32>)
+// CHECK-SAME: outs(%[[C]] : memref<16x8xf32>)
+//
+func.func @sub(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C : memref<16x8xf32>) {
+ linalg.sub ins(%A, %B : memref<16x8xf32>, memref<16x8xf32>) outs(%C : memref<16x8xf32>)
+ return
+}
|
@llvm/pr-subscribers-mlir Author: Javed Absar (javedabsar1) ChangesConvert linalg.named ops which are elementwise (e.g. add/exp) to Full diff: https://github.com/llvm/llvm-project/pull/148424.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 373842c9b03de..f2c1b99b138bc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
let dependentDialects = ["linalg::LinalgDialect"];
}
+def LinalgNamedToElementwisePass : Pass<"linalg-named-to-elementwise"> {
+ let summary = "Convert linalg named ops to elementwise where possible";
+ let dependentDialects = ["linalg::LinalgDialect"];
+}
+
def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
let summary = "Fold transform, broadcast and other ops into elementwise";
let dependentDialects = ["linalg::LinalgDialect"];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 74280fdd82f4e..086073c11c80a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1810,6 +1810,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
void populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns);
+/// Populates `patterns` that convert linalg named ops e.g. `linalg.add`
+/// to equivalent `linalg.elementwise`.
+void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns);
+
/// Populates `patterns` with patterns that fold operations like
/// `linalg.transform` into elementwise op map.
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 69e6fdabf9a58..7cb83377fa0d8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
TransposeMatmul.cpp
MeshShardingInterfaceImpl.cpp
NamedOpConversions.cpp
+ NamedToElementwise.cpp
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
new file mode 100644
index 0000000000000..1303b7cb5f6f9
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
@@ -0,0 +1,118 @@
+//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewriting those linalg named ops that are essentially
+// elementwise e.g. `linalg.exp`, to `linalg.elementwise`. This allows further
+// optimization on `linalg.elementwise` such as folding transpose, broadcast.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGNAMEDTOELEMENTWISEPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-named-to-elementwise"
+
+namespace {
+ElementwiseKind getKind(Operation *op) {
+ return llvm::TypeSwitch<Operation *, ElementwiseKind>(op)
+ .Case([](SelectOp) { return ElementwiseKind::select; })
+ .Case([](AddOp) { return ElementwiseKind::add; })
+ .Case([](SubOp) { return ElementwiseKind::sub; })
+ .Case([](MulOp) { return ElementwiseKind::mul; })
+ .Case([](DivOp) { return ElementwiseKind::div; })
+ .Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; })
+ .Case([](PowFOp) { return ElementwiseKind::powf; })
+ .Case([](ExpOp) { return ElementwiseKind::exp; })
+ .Case([](LogOp) { return ElementwiseKind::log; })
+ .Case([](AbsOp) { return ElementwiseKind::abs; })
+ .Case([](CeilOp) { return ElementwiseKind::ceil; })
+ .Case([](FloorOp) { return ElementwiseKind::floor; })
+ .Case([](NegFOp) { return ElementwiseKind::negf; })
+ .Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; })
+ .Case([](RoundOp) { return ElementwiseKind::round; })
+ .Case([](SqrtOp) { return ElementwiseKind::sqrt; })
+ .Case([](RsqrtOp) { return ElementwiseKind::rsqrt; })
+ .Case([](SquareOp) { return ElementwiseKind::square; })
+ .Case([](TanhOp) { return ElementwiseKind::tanh; })
+ .Case([](ErfOp) { return ElementwiseKind::erf; })
+ .Default([&](Operation *op) {
+ assert(false && "unexpected op");
+ return ElementwiseKind::sub;
+ });
+}
+
+template <typename NamedOpTy>
+struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
+ using OpRewritePattern<NamedOpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(NamedOpTy op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<NamedAttribute> attrs;
+ auto kindAttr = ElementwiseKindAttr::get(op.getContext(), getKind(op));
+ attrs.push_back(rewriter.getNamedAttr("kind", kindAttr));
+ attrs.push_back(
+ rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps()));
+
+ rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
+ op.getDpsInits(), attrs);
+ return success();
+ }
+};
+
+struct LinalgNamedToElementwisePass
+ : public impl::LinalgNamedToElementwisePassBase<
+ LinalgNamedToElementwisePass> {
+ using impl::LinalgNamedToElementwisePassBase<
+ LinalgNamedToElementwisePass>::LinalgNamedToElementwisePassBase;
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ RewritePatternSet patterns(op->getContext());
+ populateLinalgNamedToElementwisePatterns(patterns);
+
+ if (failed(applyPatternsGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
+
+void mlir::linalg::populateLinalgNamedToElementwisePatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<NamedToElementwisePattern<AddOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<SubOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<MulOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<DivOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<DivUnsignedOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<PowFOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<ExpOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<LogOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<AbsOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<CeilOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<FloorOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<NegFOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<ReciprocalOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<RoundOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<SqrtOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<RsqrtOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<SquareOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<TanhOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<ErfOp>>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir b/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir
new file mode 100644
index 0000000000000..3dc8275117336
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -linalg-named-to-elementwise -split-input-file | FileCheck %s
+
+// CHECK: @exp(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: ins(%[[A]] : tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[B]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %exp : tensor<16x8xf32>
+}
+
+// ----
+
+// CHECK: @add(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @add(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %add = linalg.add ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %add : tensor<16x8xf32>
+}
+
+// ----
+
+// CHECK: @sub(%[[A:.+]]: memref<16x8xf32>, %[[B:.+]]: memref<16x8xf32>, %[[C:.+]]: memref<16x8xf32>) {
+// CHECK: linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<sub>
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<16x8xf32>, memref<16x8xf32>)
+// CHECK-SAME: outs(%[[C]] : memref<16x8xf32>)
+//
+func.func @sub(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C : memref<16x8xf32>) {
+ linalg.sub ins(%A, %B : memref<16x8xf32>, memref<16x8xf32>) outs(%C : memref<16x8xf32>)
+ return
+}
|
I'd like to have a longer term plan for this. Given the operation tree discussion, we probably should start thinking about "partial (de-)generalization" instead of lose passes.
And also look at fitting the non-perfectly-nested ops like
So that users can pick-and-choose separately which stage they stop for each step of the way. |
Thanks @rengolin -- good thinking on extending this concept/PR. We need conversion from existing linalg.named ops to the new ops in the tree ( i.e. elementwise/contaction, ... https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586) otherwise we don't have that many use case for new ops .
We could discuss further in next Lighthouse meeting? Meanwhile, some common agreement -- e.g. item (2) above? |
Yes, but it goes straight to generic, can't stop at the middle ops (
Yes, but it's yet-another pass that users can easily miss. There's no "structure" in these passes.
I don't particularly need it, but since there's named-to-TOSA conversions and not from the more general forms, this may be useful. But we don't generally implement what may be useful, so I'd wait for a real user to do this. Again, better if it's part of a structured conversion.
Yes, but it goes straight to named, can't stop at the middle ops (
I think this is the most valuable of the bunch. In our cases, at least, it's a lot easier to match against |
If we had a name for each abstraction level, there could be a single pass that converts all matching ops to that representation.
This seems reasonable with a |
Thanks @rengolin / @adam-smnk -- we seem to be converging to something good! An umbrella name for all this could be (borrowing a term from mathematics -- morphism). Named (linalg.add)
I may not be having the best names or option names so please suggest better ones. |
Awesome, this is similar to what I had in mind. And to be clear, we're just talking strategy, here. This PR is a much simpler thing and good for what it is, we just need to know where in the general strategy it will live, and how to get the naming / conversion right. My main concerns are on how we represent the various steps:
Alternatively, we can have both. So, if you ask for a "generalization morphism", you go step-wise. If you ask for a "morphism into a particular stage", you get a normalization of all ops into that form (generalize and specialize). This may end up being a lot of different patterns, especially for generalization strategies (match all named ops), so we also need care in how we do this. I'd also make any issues in conversion soft failures, and then both step-wise and into form variants would behave similar, and the latter can be just an iteration of the former until all conversions return soft failures. |
|
The options are -
Part of what you said correctly earlier is a strategy. And to me that means having building-blocks (step-wise) option, having a pipeline (not in literal sense) that semantically is always
soft fail --- thats what specialize and named do currently.
Good point. To me |
Depends a bit how granular you want to go. If the choice is per operation then one should roll their own pass/transformation that can apply custom logic. |
call-back i.e. controlFn sounds better than command-line option. |
I would avoid having to repeatedly call "abstract 1 level up". It seems annoying to use when you start with mixed abstraction IR. |
assert(false && "unexpected op"); | ||
return ElementwiseKind::sub; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default case should only contain llvm_unreachable
.
Otherwise, this will have different behavior in release mode.
Convert linalg.named ops which are elementwise (e.g. add/exp) to
linalg.elementwise
. Currently, named ops have to drop to linalg.generic (--generalize-named-ops), where one figures out which generic are elementwise. Also, folding of broadcast or transpose can occur then only at generic level. Instead, with this rewrite, these can happen now at linalg.elementwise.