Skip to content

Commit 8e4adfb

Browse files
[imex-opt] port serializespirvpass from main branch (#307)
* [imex-opt] port serializespirvpass from main to refactor * add more comment and documentation
1 parent 77c3a90 commit 8e4adfb

File tree

11 files changed

+311
-2
lines changed

11 files changed

+311
-2
lines changed

docs/Transforms/SerializeSPIRV.md

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# SerializeSPIRV Pass (serialize-spirv)
2+
3+
The SerializeSPIRV pass utilizes the upstream spirv::serialize() function to serialize MLIR SPIR-V module to SPIR-V binary and attaches the binary as a string attribute to the gpuModule(attributes {gpu.binary}).
4+
This pass works like the upstream SerializeToCubin and SerializeToHsaco, only that the other two passes translate llvm dialect to llvm IR and then translate to ISA binary.
5+
6+
## Example
7+
8+
```
9+
// -----// IR Dump Before SerializeSPIRV //----- //
10+
module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, #spv.resource_limits<>>} {
11+
spv.module @__spv__addt_kernel Physical64 OpenCL requires #spv.vce<v1.0, [Int64, Addresses, Kernel], []> {
12+
spv.GlobalVariable @__builtin_var_WorkgroupId__ built_in("WorkgroupId") : !spv.ptr<vector<3xi64>, Input>
13+
spv.func @addt_kernel(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1: !spv.ptr<f32, CrossWorkgroup>, %arg2: !spv.ptr<f32, CrossWorkgroup>) "None" attributes {spv.entry_point_abi = #spv.entry_point_abi<>, workgroup_attributions = 0 : i64} {
14+
%cst5_i64 = spv.Constant 5 : i64
15+
%__builtin_var_WorkgroupId___addr = spv.mlir.addressof @__builtin_var_WorkgroupId__ : !spv.ptr<vector<3xi64>, Input>
16+
%0 = spv.Load "Input" %__builtin_var_WorkgroupId___addr : vector<3xi64>
17+
%1 = spv.CompositeExtract %0[0 : i32] : vector<3xi64>
18+
%__builtin_var_WorkgroupId___addr_0 = spv.mlir.addressof @__builtin_var_WorkgroupId__ : !spv.ptr<vector<3xi64>, Input>
19+
.
20+
.
21+
.
22+
%13 = spv.IMul %1, %cst5_i64 : i64
23+
%14 = spv.IAdd %13, %3 : i64
24+
%15 = spv.InBoundsPtrAccessChain %arg2[%14] : !spv.ptr<f32, CrossWorkgroup>, i64
25+
spv.Store "CrossWorkgroup" %15, %12 ["Aligned", 4] : f32
26+
spv.Return
27+
}
28+
spv.EntryPoint "Kernel" @addt_kernel, @__builtin_var_WorkgroupId__
29+
}
30+
gpu.module @addt_kernel {
31+
gpu.func @addt_kernel(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) kernel attributes {spv.entry_point_abi = #spv.entry_point_abi<>} {
32+
%c5 = arith.constant 5 : index
33+
%0 = gpu.block_id x
34+
%1 = gpu.block_id y
35+
.
36+
.
37+
.
38+
%9 = arith.muli %0, %c5 : index
39+
%10 = arith.addi %9, %1 : index
40+
memref.store %8, %arg2[%10] : memref<?xf32>
41+
gpu.return
42+
}
43+
}
44+
}
45+
46+
```
47+
48+
The Pass will change the IR to:
49+
50+
```
51+
// -----// IR Dump After SerializeSPIRV //----- //
52+
module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Ve ctor16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, #spv.re source_limits<>>} {
53+
gpu.module @addt_kernel attributes {gpu.binary = "\03\02#\07\00\00\01\00 ... \00"} {
54+
gpu.func @addt_kernel(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) kernel attributes {spv.entry_point_abi = #spv.entry_point_abi<> } {
55+
%c5 = arith.constant 5 : index
56+
%0 = gpu.block_id x
57+
%1 = gpu.block_id y
58+
.
59+
.
60+
.
61+
%9 = arith.muli %0, %c5 : index
62+
%10 = arith.addi %9, %1 : index
63+
memref.store %8, %arg2[%10] : memref<?xf32>
64+
gpu.return
65+
}
66+
}
67+
}
68+
69+
```
70+
71+
As shown in the example above, the spv module op was serialized to spv binary and attached to the gpu module op as a "gpu.binary" attribute
72+
73+
74+
## Reason for this Custom Pass:
75+
76+
Upstream does not have a standalone pass which wraps this function. We can upstream this if it proves this is a common flow needed.

include/imex/InitIMEXPasses.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
// #include <imex/Transforms/IMEXPasses.h>
2020
#include <imex/Dialect/Dist/Transforms/Passes.h>
2121
// #include <imex/Dialect/*/Transforms/Passes.h>
22+
#include "imex/Transforms/Passes.h"
2223

2324
#include <cstdlib>
2425

@@ -33,7 +34,7 @@ namespace imex {
3334
/// The global registry is interesting to interact with the command-line tools.
3435
inline void registerAllPasses() {
3536
// General passes
36-
// registerTransformsPasses();
37+
registerTransformsPasses();
3738

3839
// Conversion passes
3940
registerConversionPasses();

include/imex/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ mlir_tablegen(Transforms.capi.h.inc -gen-pass-capi-header --prefix Transforms)
44
mlir_tablegen(Transforms.capi.cpp.inc -gen-pass-capi-impl --prefix Transforms)
55
add_public_tablegen_target(IMEXTransformsPassIncGen)
66

7-
add_mlir_doc(Passes GeneralPasses ./ -gen-pass-doc)
7+
add_mlir_doc(Passes IMEXGeneralPasses ./ -gen-pass-doc)

include/imex/Transforms/Passes.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
2+
//
3+
// Copyright 2022 Intel Corporation
4+
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This header file defines prototypes for IMEX transformation passes
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef IMEX_TRANSFORMS_PASSES_H_
15+
#define IMEX_TRANSFORMS_PASSES_H_
16+
17+
#include "mlir/Pass/Pass.h"
18+
19+
namespace imex {
20+
//===----------------------------------------------------------------------===//
21+
// Passes
22+
//===----------------------------------------------------------------------===//
23+
std::unique_ptr<mlir::Pass> createSerializeSPIRVPass();
24+
25+
//===----------------------------------------------------------------------===//
26+
// Registration
27+
//===----------------------------------------------------------------------===//
28+
29+
/// Generate the code for registering passes.
30+
#define GEN_PASS_REGISTRATION
31+
#include "imex/Transforms/Passes.h.inc"
32+
} // namespace imex
33+
34+
#endif // IMEX_TRANSFORMS_PASSES_H_

include/imex/Transforms/Passes.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,18 @@
1717

1818
include "mlir/Pass/PassBase.td"
1919

20+
def SerializeSPIRVPass : Pass<"serialize-spirv", "::mlir::ModuleOp"> {
21+
let summary = "serialize MLIR SPIR-V module to SPIR-V binary";
22+
let description = [{
23+
This pass iterates all the SPIR-V modules in the top module and serializes
24+
each SPIR-V module to SPIR-V binary and then attachs the binary blob as a
25+
string attribute to the corresponding gpu module.
26+
}];
27+
let constructor = "imex::createSerializeSPIRVPass()";
28+
let dependentDialects = [
29+
"mlir::gpu::GPUDialect",
30+
"mlir::spirv::SPIRVDialect"
31+
];
32+
}
33+
2034
#endif // _IMEX_TRANSFORMS_PASSES_TD_INCLUDED_

lib/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(Dialect)
22
add_subdirectory(Conversion)
3+
add_subdirectory(Transforms)

lib/Transforms/CMakeLists.txt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_library(IMEXTransforms
2+
SerializeSPIRV.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${PROJECT_SOURCE_DIR}/imex/Transforms
6+
7+
LINK_LIBS PUBLIC
8+
MLIRGPUOps
9+
MLIRSPIRVDialect
10+
MLIRFuncDialect
11+
MLIRPass
12+
MLIRSupport
13+
MLIRTransformUtils
14+
15+
DEPENDS
16+
IMEXTransformsPassIncGen
17+
)

lib/Transforms/PassDetail.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===- PassDetail.h - Transforms Pass class details -------------*- C++ -*-===//
2+
//
3+
// Copyright 2022 Intel Corporation
4+
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef TRANSFORMS_PASSDETAIL_H_
11+
#define TRANSFORMS_PASSDETAIL_H_
12+
13+
#include "mlir/IR/BuiltinOps.h"
14+
#include "mlir/IR/DialectRegistry.h"
15+
#include "mlir/Pass/Pass.h"
16+
#include "mlir/Transforms/Passes.h"
17+
18+
namespace mlir {
19+
// Forward declaration from Dialect.h
20+
template <typename ConcreteDialect>
21+
void registerDialect(DialectRegistry &registry);
22+
23+
namespace gpu {
24+
class GPUDialect;
25+
}
26+
namespace spirv {
27+
class SPIRVDialect;
28+
}
29+
30+
} // end namespace mlir
31+
32+
namespace imex {
33+
#define GEN_PASS_CLASSES
34+
#include "imex/Transforms/Passes.h.inc"
35+
36+
} // namespace imex
37+
38+
#endif // TRANSFORMS_PASSDETAIL_H

lib/Transforms/SerializeSPIRV.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
//===- SerializeSPIRV.cpp - SPIR-V serialize pass --------------*- C++ -*-===//
2+
//
3+
// Copyright 2022 Intel Corporation
4+
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
///
10+
/// \file
11+
/// This pass iterates all the SPIR-V modules in the top module and serializes
12+
/// each SPIR-V module to SPIR-V binary and then attachs the binary blob as a
13+
/// string attribute to the corresponding gpu module.
14+
///
15+
//===----------------------------------------------------------------------===//
16+
17+
#include "PassDetail.h"
18+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
19+
#include "mlir/Dialect/GPU/Transforms/Passes.h"
20+
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
21+
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
22+
#include "mlir/Target/SPIRV/Serialization.h"
23+
24+
using namespace mlir;
25+
using namespace imex;
26+
27+
namespace {
28+
struct SerializeSPIRVPass : public SerializeSPIRVPassBase<SerializeSPIRVPass> {
29+
public:
30+
void runOnOperation() override {
31+
auto mod = getOperation();
32+
llvm::SmallVector<uint32_t, 0> spvBinary;
33+
for (auto gpuMod : mod.getOps<gpu::GPUModuleOp>()) {
34+
auto name = gpuMod.getName();
35+
// check that the spv module has the same name with gpu module except the
36+
// prefix "__spv__"
37+
auto isSameMod = [&](spirv::ModuleOp spvMod) -> bool {
38+
auto spvModName = spvMod.getName();
39+
return spvModName->consume_front("__spv__") && spvModName == name;
40+
};
41+
auto spvMods = mod.getOps<spirv::ModuleOp>();
42+
auto it = llvm::find_if(spvMods, isSameMod);
43+
if (it == spvMods.end()) {
44+
gpuMod.emitError() << "Unable to find corresponding SPIR-V module";
45+
signalPassFailure();
46+
return;
47+
}
48+
auto spvMod = *it;
49+
50+
spvBinary.clear();
51+
// serialize the spv module to spv binary
52+
if (mlir::failed(spirv::serialize(spvMod, spvBinary))) {
53+
spvMod.emitError() << "Failed to serialize SPIR-V module";
54+
signalPassFailure();
55+
return;
56+
}
57+
58+
// attach the spv binary to the gpu module
59+
auto spvData =
60+
llvm::StringRef(reinterpret_cast<const char *>(spvBinary.data()),
61+
spvBinary.size() * sizeof(uint32_t));
62+
auto spvAttr = mlir::StringAttr::get(&getContext(), spvData);
63+
gpuMod->setAttr(gpu::getDefaultGpuBinaryAnnotation(), spvAttr);
64+
spvMod->erase();
65+
}
66+
}
67+
};
68+
} // namespace
69+
70+
namespace imex {
71+
std::unique_ptr<mlir::Pass> createSerializeSPIRVPass() {
72+
return std::make_unique<SerializeSPIRVPass>();
73+
}
74+
} // namespace imex

test/Transforms/serialize-spirv.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// RUN: imex-opt -serialize-spirv %s | FileCheck %s
2+
module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, #spv.resource_limits<>>} {
3+
// CHECK: gpu.module @addt_kernel attributes {gpu.binary =
4+
spv.module @__spv__addt_kernel Physical64 OpenCL requires #spv.vce<v1.0, [Int64, Addresses, Kernel], []> {
5+
spv.GlobalVariable @__builtin_var_WorkgroupId__ built_in("WorkgroupId") : !spv.ptr<vector<3xi64>, Input>
6+
spv.func @addt_kernel(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1: !spv.ptr<f32, CrossWorkgroup>, %arg2: !spv.ptr<f32, CrossWorkgroup>) "None" attributes {spv.entry_point_abi = #spv.entry_point_abi<>, workgroup_attributions = 0 : i64} {
7+
%cst5_i64 = spv.Constant 5 : i64
8+
%__builtin_var_WorkgroupId___addr = spv.mlir.addressof @__builtin_var_WorkgroupId__ : !spv.ptr<vector<3xi64>, Input>
9+
%0 = spv.Load "Input" %__builtin_var_WorkgroupId___addr : vector<3xi64>
10+
%1 = spv.CompositeExtract %0[0 : i32] : vector<3xi64>
11+
%__builtin_var_WorkgroupId___addr_0 = spv.mlir.addressof @__builtin_var_WorkgroupId__ : !spv.ptr<vector<3xi64>, Input>
12+
%2 = spv.Load "Input" %__builtin_var_WorkgroupId___addr_0 : vector<3xi64>
13+
%3 = spv.CompositeExtract %2[1 : i32] : vector<3xi64>
14+
spv.Branch ^bb1
15+
^bb1: // pred: ^bb0
16+
%4 = spv.IMul %1, %cst5_i64 : i64
17+
%5 = spv.IAdd %4, %3 : i64
18+
%6 = spv.InBoundsPtrAccessChain %arg0[%5] : !spv.ptr<f32, CrossWorkgroup>, i64
19+
%7 = spv.Load "CrossWorkgroup" %6 ["Aligned", 4] : f32
20+
%8 = spv.IMul %1, %cst5_i64 : i64
21+
%9 = spv.IAdd %8, %3 : i64
22+
%10 = spv.InBoundsPtrAccessChain %arg1[%9] : !spv.ptr<f32, CrossWorkgroup>, i64
23+
%11 = spv.Load "CrossWorkgroup" %10 ["Aligned", 4] : f32
24+
%12 = spv.FAdd %7, %11 : f32
25+
%13 = spv.IMul %1, %cst5_i64 : i64
26+
%14 = spv.IAdd %13, %3 : i64
27+
%15 = spv.InBoundsPtrAccessChain %arg2[%14] : !spv.ptr<f32, CrossWorkgroup>, i64
28+
spv.Store "CrossWorkgroup" %15, %12 ["Aligned", 4] : f32
29+
spv.Return
30+
}
31+
spv.EntryPoint "Kernel" @addt_kernel, @__builtin_var_WorkgroupId__
32+
}
33+
gpu.module @addt_kernel {
34+
gpu.func @addt_kernel(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) kernel attributes {spv.entry_point_abi = #spv.entry_point_abi<>} {
35+
%c5 = arith.constant 5 : index
36+
%0 = gpu.block_id x
37+
%1 = gpu.block_id y
38+
cf.br ^bb1
39+
^bb1: // pred: ^bb0
40+
%2 = arith.muli %0, %c5 : index
41+
%3 = arith.addi %2, %1 : index
42+
%4 = memref.load %arg0[%3] : memref<?xf32>
43+
%5 = arith.muli %0, %c5 : index
44+
%6 = arith.addi %5, %1 : index
45+
%7 = memref.load %arg1[%6] : memref<?xf32>
46+
%8 = arith.addf %4, %7 : f32
47+
%9 = arith.muli %0, %c5 : index
48+
%10 = arith.addi %9, %1 : index
49+
memref.store %8, %arg2[%10] : memref<?xf32>
50+
gpu.return
51+
}
52+
}
53+
}

0 commit comments

Comments
 (0)