-
Notifications
You must be signed in to change notification settings - Fork 25
Open
Description
module @reactant_kernel_... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
// Entry function of the kernel module.
// Operands: two f32 scalars (%arg0, %arg1), a 64x64 f32 matrix (%arg2) and a 32x64 f32
// matrix (%arg3). Result: a 32x64 f32 tensor, the final accumulator of the 64-iteration
// stablehlo.while loop below. Every operand and the function carry
// `enzymexla.memory_effects = []`.
func.func @main(%arg0: tensor<f32> {enzymexla.memory_effects = []}, %arg1: tensor<f32> {enzymexla.memory_effects = []}, %arg2: tensor<64x64xf32> {enzymexla.memory_effects = []}, %arg3: tensor<32x64xf32> {enzymexla.memory_effects = []}) -> tensor<32x64xf32> attributes {enzymexla.memory_effects = []} {
// i32 zero, used as start index for the dynamic_update_slice in the loop body.
%c = stablehlo.constant dense<0> : tensor<i32>
// NOTE(review): this dense<...> payload is truncated in this paste: the hex data and the
// closing `"> :` are missing, so this line does not parse as written. The value is a
// 64x2 i32 table used as the gather indices of %6; its contents cannot be recovered
// from this source.
%c_0 = stablehlo.constant dense<"0xtensor<64x2xi32>
// Loop trip bound: the while loop runs while %iterArg < 64.
%c_1 = stablehlo.constant dense<64> : tensor<i64>
// Zero 32x64 tensor; initial value of the loop-carried accumulator %iterArg_5.
%cst = stablehlo.constant dense<0.000000e+00> : tensor<32x64xf32>
// i32 one, subtracted from the converted induction value in the loop body.
%c_2 = stablehlo.constant dense<1> : tensor<i32>
// i64 zero: initial induction value and the fixed start index for dynamic slices.
%c_3 = stablehlo.constant dense<0> : tensor<i64>
// i64 one: induction-variable step.
%c_4 = stablehlo.constant dense<1> : tensor<i64>
// ---- Loop-invariant values, computed once before the while loop ----
// %0 = transpose(%arg2); %1 = transpose(%arg3), so %1[i,j] = %arg3[j,i] (64x32).
%0 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%1 = stablehlo.transpose %arg3, dims = [1, 0] : (tensor<32x64xf32>) -> tensor<64x32xf32>
// Operand dim 0 -> result dim 2, operand dim 1 -> result dim 0:
// %2[i,j,k,0] = %arg2[k,i], broadcast along j (size 32).
%2 = stablehlo.broadcast_in_dim %arg2, dims = [2, 0] : (tensor<64x64xf32>) -> tensor<64x32x64x1xf32>
// Contracts dim 0 of %arg2 with dim 1 of %arg3: %3[i,j] = sum_k %arg2[k,i] * %arg3[j,k].
%3 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [0] x [1] : (tensor<64x64xf32>, tensor<32x64xf32>) -> tensor<64x32xf32>
// %5[i,j] = %arg0 * %arg3[j,i].
%4 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f32>) -> tensor<64x32xf32>
%5 = stablehlo.multiply %4, %1 : tensor<64x32xf32>
// Point gather from %0: %6[r] = %0[%c_0[r,0], %c_0[r,1]], one scalar per index row.
// NOTE(review): because %c_0 is truncated above, which elements are selected
// (e.g. possibly the diagonal) cannot be confirmed from this source.
%6 = "stablehlo.gather"(%0, %c_0) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<64x64xf32>, tensor<64x2xi32>) -> tensor<64xf32>
// %8[i,j] = %5[i,j] * %6[i].
%7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<64xf32>) -> tensor<64x32xf32>
%8 = stablehlo.multiply %5, %7 : tensor<64x32xf32>
// %9[i,j,k,0] = %5[i,j], broadcast along k (size 64).
%9 = stablehlo.broadcast_in_dim %5, dims = [0, 1] : (tensor<64x32xf32>) -> tensor<64x32x64x1xf32>
// %12[i,j,0,0] = %arg0 * %3[i,j].
%10 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f32>) -> tensor<64x32x1x1xf32>
%11 = stablehlo.reshape %3 : (tensor<64x32xf32>) -> tensor<64x32x1x1xf32>
%12 = stablehlo.multiply %10, %11 : tensor<64x32x1x1xf32>
// %13[i,j,k,0] = %5[i,j] * %arg2[k,i]. This materializes the full 64x32x64x1 tensor
// before the loop; the loop reads one 32x64 slice of it per iteration.
%13 = stablehlo.multiply %9, %2 : tensor<64x32x64x1xf32>
// %arg1 broadcast to the 32x1x1 column shape used inside the loop.
%14 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<f32>) -> tensor<32x1x1xf32>
// ---- Main loop ----
// Induction value %iterArg runs 0..63. Loop-carried state: the 32x64 accumulator
// %iterArg_5, starting at zero.
// Iteration i reads the accumulator yielded by iteration i-1, so iterations are not
// independent.
// The enzyme.* / enzymexla.* attributes are kept verbatim; their meaning is defined by the
// Enzyme/EnzymeXLA passes, not in this function.
%15:2 = stablehlo.while(%iterArg = %c_3, %iterArg_5 = %cst) : tensor<i64>, tensor<32x64xf32> attributes {enzyme.disable_mincut, enzymexla.symmetric_matrix = [#enzymexla<guaranteed UNKNOWN>, #enzymexla<guaranteed UNKNOWN>, #enzymexla<guaranteed UNKNOWN>, #enzymexla<guaranteed UNKNOWN>, #enzymexla<guaranteed UNKNOWN>, #enzymexla<guaranteed UNKNOWN>, #enzymexla<guaranteed NOTGUARANTEED>, #enzymexla<guaranteed UNKNOWN>, #enzymexla<guaranteed UNKNOWN>]}
cond {
// Keep looping while %iterArg < 64.
%16 = stablehlo.compare LT, %iterArg, %c_1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %16 : tensor<i1>
} do {
// %16 = %iterArg + 1: the next induction value, yielded at the end of the body.
%16 = stablehlo.add %c_4, %iterArg {enzymexla.bounds = [[1, 64]]} : tensor<i64>
// %18 = i32(%iterArg + 1) - 1, i.e. %iterArg as i32 (bounds [0, 63]).
%17 = stablehlo.convert %16 {enzymexla.bounds = [[1, 64]]} : (tensor<i64>) -> tensor<i32>
%18 = stablehlo.subtract %17, %c_2 {enzymexla.bounds = [[0, 63]]} : tensor<i32>
// %22[j,k] = acc[j,k] + %13[i,j,k]: add slice i of %13 to every column of the accumulator.
%19 = stablehlo.reshape %iterArg_5 : (tensor<32x64xf32>) -> tensor<32x64x1xf32>
%20 = stablehlo.dynamic_slice %13, %iterArg, %c_3, %c_3, %c_3, sizes = [1, 32, 64, 1] : (tensor<64x32x64x1xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x32x64x1xf32>
%21 = stablehlo.reshape %20 : (tensor<1x32x64x1xf32>) -> tensor<32x64x1xf32>
%22 = stablehlo.add %19, %21 : tensor<32x64x1xf32>
// %23 is %18 converted back to i64, again equal to %iterArg.
%23 = stablehlo.convert %18 {enzymexla.bounds = [[0, 63]]} : (tensor<i32>) -> tensor<i64>
// %24 = column i of %22 (all 32 rows); %25 scales it by %arg1.
%24 = stablehlo.dynamic_slice %22, %c_3, %23, %c_3, sizes = [32, 1, 1] : (tensor<32x64x1xf32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<32x1x1xf32>
%25 = stablehlo.multiply %14, %24 : tensor<32x1x1xf32>
// Add row i of %8 and row i of %12 (each reshaped to 32x1x1) to form the new column i, %31.
%26 = stablehlo.dynamic_slice %8, %iterArg, %c_3, sizes = [1, 32] : (tensor<64x32xf32>, tensor<i64>, tensor<i64>) -> tensor<1x32xf32>
%27 = stablehlo.reshape %26 : (tensor<1x32xf32>) -> tensor<32x1x1xf32>
%28 = stablehlo.add %25, %27 : tensor<32x1x1xf32>
%29 = stablehlo.dynamic_slice %12, %iterArg, %c_3, %c_3, %c_3, sizes = [1, 32, 1, 1] : (tensor<64x32x1x1xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x32x1x1xf32>
%30 = stablehlo.reshape %29 : (tensor<1x32x1x1xf32>) -> tensor<32x1x1xf32>
%31 = stablehlo.add %28, %30 : tensor<32x1x1xf32>
// Overwrite column i of %22 with %31.
// The 32x64x1 -> 32x1x64 reshape keeps element order, so index %18 on the last dim is column i.
%32 = stablehlo.reshape %22 : (tensor<32x64x1xf32>) -> tensor<32x1x64xf32>
%33 = stablehlo.dynamic_update_slice %32, %31, %c, %c, %18 : (tensor<32x1x64xf32>, tensor<32x1x1xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<32x1x64xf32>
%34 = stablehlo.reshape %33 : (tensor<32x1x64xf32>) -> tensor<32x64xf32>
// Yield the next induction value and the updated accumulator.
stablehlo.return %16, %34 : tensor<i64>, tensor<32x64xf32>
}
// Result: the accumulator after all 64 iterations.
return %15#1 : tensor<32x64xf32>
}
}

We are able to raise the two inner loops; the outer loop remains.
Metadata
Metadata
Assignees
Labels
No labels