Skip to content

[mlir][linalg] Morphism across linalg named, category and generic ops. #148424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
let dependentDialects = ["linalg::LinalgDialect"];
}

def LinalgNamedToElementwisePass : Pass<"linalg-named-to-elementwise"> {
let summary = "Convert linalg named ops to elementwise where possible";
let dependentDialects = ["linalg::LinalgDialect"];
}

def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
let summary = "Fold transform, broadcast and other ops into elementwise";
let dependentDialects = ["linalg::LinalgDialect"];
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1810,6 +1810,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
void populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns);

/// Populates `patterns` that convert linalg named ops e.g. `linalg.add`
/// to equivalent `linalg.elementwise`.
void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that fold operations like
/// `linalg.transform` into elementwise op map.
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
TransposeMatmul.cpp
MeshShardingInterfaceImpl.cpp
NamedOpConversions.cpp
NamedToElementwise.cpp
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
Padding.cpp
Expand Down
118 changes: 118 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting those linalg named ops that are essentially
// elementwise e.g. `linalg.exp`, to `linalg.elementwise`. This allows further
// optimization on `linalg.elementwise` such as folding transpose, broadcast.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGNAMEDTOELEMENTWISEPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-named-to-elementwise"

namespace {
ElementwiseKind getKind(Operation *op) {
return llvm::TypeSwitch<Operation *, ElementwiseKind>(op)
.Case([](SelectOp) { return ElementwiseKind::select; })
.Case([](AddOp) { return ElementwiseKind::add; })
.Case([](SubOp) { return ElementwiseKind::sub; })
.Case([](MulOp) { return ElementwiseKind::mul; })
.Case([](DivOp) { return ElementwiseKind::div; })
.Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; })
.Case([](PowFOp) { return ElementwiseKind::powf; })
.Case([](ExpOp) { return ElementwiseKind::exp; })
.Case([](LogOp) { return ElementwiseKind::log; })
.Case([](AbsOp) { return ElementwiseKind::abs; })
.Case([](CeilOp) { return ElementwiseKind::ceil; })
.Case([](FloorOp) { return ElementwiseKind::floor; })
.Case([](NegFOp) { return ElementwiseKind::negf; })
.Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; })
.Case([](RoundOp) { return ElementwiseKind::round; })
.Case([](SqrtOp) { return ElementwiseKind::sqrt; })
.Case([](RsqrtOp) { return ElementwiseKind::rsqrt; })
.Case([](SquareOp) { return ElementwiseKind::square; })
.Case([](TanhOp) { return ElementwiseKind::tanh; })
.Case([](ErfOp) { return ElementwiseKind::erf; })
.Default([&](Operation *op) {
assert(false && "unexpected op");
return ElementwiseKind::sub;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default case should only contain llvm_unreachable.
Otherwise, this will have different behavior in release mode.

});
}

template <typename NamedOpTy>
struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
using OpRewritePattern<NamedOpTy>::OpRewritePattern;

LogicalResult matchAndRewrite(NamedOpTy op,
PatternRewriter &rewriter) const override {
SmallVector<NamedAttribute> attrs;
auto kindAttr = ElementwiseKindAttr::get(op.getContext(), getKind(op));
attrs.push_back(rewriter.getNamedAttr("kind", kindAttr));
attrs.push_back(
rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps()));

rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
op.getDpsInits(), attrs);
return success();
}
};

struct LinalgNamedToElementwisePass
: public impl::LinalgNamedToElementwisePassBase<
LinalgNamedToElementwisePass> {
using impl::LinalgNamedToElementwisePassBase<
LinalgNamedToElementwisePass>::LinalgNamedToElementwisePassBase;

void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
populateLinalgNamedToElementwisePatterns(patterns);

if (failed(applyPatternsGreedily(op, std::move(patterns))))
return signalPassFailure();
}
};
} // namespace

void mlir::linalg::populateLinalgNamedToElementwisePatterns(
RewritePatternSet &patterns) {
patterns.add<NamedToElementwisePattern<AddOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<SubOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<MulOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<DivOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<DivUnsignedOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<PowFOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<ExpOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<LogOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<AbsOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<CeilOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<FloorOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<NegFOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<ReciprocalOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<RoundOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<SqrtOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<RsqrtOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<SquareOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<TanhOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<ErfOp>>(patterns.getContext());
}
38 changes: 38 additions & 0 deletions mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: mlir-opt %s -linalg-named-to-elementwise -split-input-file | FileCheck %s

// CHECK: @exp(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
// CHECK: {{.*}} = linalg.elementwise
// CHECK-SAME: kind=#linalg.elementwise_kind<exp>
// CHECK-SAME: ins(%[[A]] : tensor<16x8xf32>)
// CHECK-SAME: outs(%[[B]] : tensor<16x8xf32>) -> tensor<16x8xf32>
//
func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
%exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
return %exp : tensor<16x8xf32>
}

// ----

// CHECK: @add(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
// CHECK: {{.*}} = linalg.elementwise
// CHECK-SAME: kind=#linalg.elementwise_kind<add>
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
//
func.func @add(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
%add = linalg.add ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
return %add : tensor<16x8xf32>
}

// ----

// CHECK: @sub(%[[A:.+]]: memref<16x8xf32>, %[[B:.+]]: memref<16x8xf32>, %[[C:.+]]: memref<16x8xf32>) {
// CHECK: linalg.elementwise
// CHECK-SAME: kind=#linalg.elementwise_kind<sub>
// CHECK-SAME: ins(%[[A]], %[[B]] : memref<16x8xf32>, memref<16x8xf32>)
// CHECK-SAME: outs(%[[C]] : memref<16x8xf32>)
//
func.func @sub(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C : memref<16x8xf32>) {
linalg.sub ins(%A, %B : memref<16x8xf32>, memref<16x8xf32>) outs(%C : memref<16x8xf32>)
return
}