-
Notifications
You must be signed in to change notification settings - Fork 23
Open
Description
I encountered an issue while attempting to compile one of the test cases for NNlib.scatter for differentiation with Enzyme.
The following minimal example, adapted from a Lux.jl training loop, reproduces the error:
using Enzyme, Lux, NNlib, Reactant
# Device-transfer function used for every input below.
device = reactant_device()

# 2×5 Float32 destination matrix that is scattered into.
dst = device(Float32[3 3 4 4 5; 5 5 6 6 7])

# 2×5 Float32 matrix of ones used as the scatter source.
src = device(ones(Float32, 2, 5))

# One index per column of `src`.
idx = device([4, 2, 1, 5, 3])
"""
    test_scatter(dsts, srcs, idxs)

Call `NNlib.scatter!` with `+` as the reduction on `dsts`, `srcs` and `idxs`, and
return the `sum` of the value it returns.
"""
function test_scatter(dsts, srcs, idxs)
    scattered = NNlib.scatter!(+, dsts, srcs, idxs)
    return sum(scattered)
end
"""
    test_gradient(objective_function, dsts, srcs, idxs)

Call `Enzyme.gradient` on `objective_function` (wrapped in `Const`) with the
arguments `dsts`, `srcs` and `idxs`, using `Enzyme.ReverseWithPrimal` with its ABI
set to `Reactant.ReactantABI`.

Return the tuple `(derivs, val)` obtained by destructuring the result of
`Enzyme.gradient` positionally.
"""
function test_gradient(objective_function, dsts, srcs, idxs)
    # Reverse mode that also yields the primal, with the Reactant ABI applied.
    mode = Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI)
    derivs, val = Enzyme.gradient(mode, Const(objective_function), dsts, srcs, idxs)
    return (derivs, val)
end
# Compile `test_gradient` for these concrete arguments with `@compile`
# (presumably Reactant's macro, given `using Reactant` above — confirm).
test_gradient_compiled = @compile test_gradient(test_scatter, dst, src, idx)
The following error is thrown:
loc("scatter"("/home/julian/.julia/packages/Reactant/kCRbM/src/Ops.jl":1701:0)): error: AutoDiffScatterRev only supports Setindex operations
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_wHnyOF/module_000_vxwR_post_all_pm.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/kCRbM/src/mlir/IR/Pass.jl:119
ERROR: "failed to run pass manager on module"
Stacktrace:
[1] run!(pm::Reactant.MLIR.IR.PassManager, mod::Reactant.MLIR.IR.Module, key::String)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/kCRbM/src/mlir/IR/Pass.jl:163
[2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String, key::String; enable_verifier::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:1223
[3] run_pass_pipeline!
@ ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:1218 [inlined]
[4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:1647
[5] compile_mlir! (repeats 2 times)
@ ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:1455 [inlined]
[6] compile_xla(f::Function, args::Tuple{typeof(test_scatter), ConcretePJRTArray{…}, ConcretePJRTArray{…}, Vector{…}}; client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
@ Reactant.Compiler ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:3369
[7] compile_xla
@ ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:3349 [inlined]
[8] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:3434
[9] top-level scope
@ ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:2529
These are the contents of the corresponding MLIR module file:
#loc = loc(unknown)
#loc1 = loc("arg1 (path=(Symbol(\22##autodiffarg#234\22), 1))")
#loc2 = loc("arg2 (path=(Symbol(\22##autodiffarg#234\22), 2))")
#loc4 = loc("arg1 (path=(:args, 1))")
#loc5 = loc("arg2 (path=(:args, 2))")
#loc7 = loc("arg3 (path=(:args, 3))")
module {
func.func private @"Const{typeof(test_scatter)}(Main.test_scatter)_autodiff"(%arg0: tensor<5x2xf32> loc("arg1 (path=(Symbol(\22##autodiffarg#234\22), 1))"), %arg1: tensor<5x2xf32> loc("arg2 (path=(Symbol(\22##autodiffarg#234\22), 2))")) -> (tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
%c = stablehlo.constant dense<[[3, 1, 0, 4, 2]]> : tensor<1x5xi64> loc(#loc)
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32> loc(#loc)
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<5x2xf32>) -> tensor<2x5xf32> loc(#loc)
%1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<5x2xf32>) -> tensor<2x5xf32> loc(#loc)
%2 = "stablehlo.scatter"(%0, %c, %1) <{scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true}> ({
^bb0(%arg2: tensor<f32> loc("arg1 (path=(:args, 1))"), %arg3: tensor<f32> loc("arg2 (path=(:args, 2))")):
%5 = stablehlo.add %arg2, %arg3 : tensor<f32> loc(#loc9)
stablehlo.return %5 : tensor<f32> loc(#loc)
}) : (tensor<2x5xf32>, tensor<1x5xi64>, tensor<2x5xf32>) -> tensor<2x5xf32> loc(#loc8)
%3 = stablehlo.reduce(%2 init: %cst) applies stablehlo.add across dimensions = [0, 1] : (tensor<2x5xf32>, tensor<f32>) -> tensor<f32> loc(#loc)
%4 = stablehlo.transpose %2, dims = [1, 0] : (tensor<2x5xf32>) -> tensor<5x2xf32> loc(#loc)
return %3, %4, %arg1 : tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32> loc(#loc)
} loc(#loc)
func.func @test_gradient(%arg0: tensor<5x2xf32> loc("arg2 (path=(:args, 2))"), %arg1: tensor<5x2xf32> loc("arg3 (path=(:args, 3))")) -> (tensor<5x2xf32>, tensor<5x2xf32>, tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32> loc(#loc)
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<5x2xf32> loc(#loc)
%0:5 = enzyme.autodiff @"Const{typeof(test_scatter)}(Main.test_scatter)_autodiff"(%arg0, %arg1, %cst, %cst_0, %cst_0) {activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>]} : (tensor<5x2xf32>, tensor<5x2xf32>, tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32>) -> (tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32>, tensor<5x2xf32>, tensor<5x2xf32>) loc(#loc)
return %0#3, %0#4, %0#0, %0#1, %0#2 : tensor<5x2xf32>, tensor<5x2xf32>, tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32> loc(#loc)
} loc(#loc)
func.func private @"diffeConst{typeof(test_scatter)}(Main.test_scatter)_autodiff"(%arg0: tensor<5x2xf32> loc("arg1 (path=(Symbol(\22##autodiffarg#234\22), 1))"), %arg1: tensor<5x2xf32> loc("arg2 (path=(Symbol(\22##autodiffarg#234\22), 2))"), %arg2: tensor<f32> loc(unknown), %arg3: tensor<5x2xf32> loc(unknown), %arg4: tensor<5x2xf32> loc(unknown)) -> (tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32>, tensor<5x2xf32>, tensor<5x2xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
%0 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<5x2xf32>> loc(#loc1)
%cst = arith.constant dense<0.000000e+00> : tensor<5x2xf32> loc(#loc1)
"enzyme.set"(%0, %cst) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc1)
%1 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<2x5xf32>> loc(#loc)
%cst_0 = arith.constant dense<0.000000e+00> : tensor<2x5xf32> loc(#loc)
"enzyme.set"(%1, %cst_0) : (!enzyme.Gradient<tensor<2x5xf32>>, tensor<2x5xf32>) -> () loc(#loc)
%2 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<2x5xf32>> loc(#loc)
%cst_1 = arith.constant dense<0.000000e+00> : tensor<2x5xf32> loc(#loc)
"enzyme.set"(%2, %cst_1) : (!enzyme.Gradient<tensor<2x5xf32>>, tensor<2x5xf32>) -> () loc(#loc)
%3 = "enzyme.init"() : () -> !enzyme.Cache<tensor<1x5xi64>> loc(#loc)
%4 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<2x5xf32>> loc(#loc8)
%cst_2 = arith.constant dense<0.000000e+00> : tensor<2x5xf32> loc(#loc8)
"enzyme.set"(%4, %cst_2) : (!enzyme.Gradient<tensor<2x5xf32>>, tensor<2x5xf32>) -> () loc(#loc8)
%5 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<5x2xf32>> loc(#loc2)
%cst_3 = arith.constant dense<0.000000e+00> : tensor<5x2xf32> loc(#loc2)
"enzyme.set"(%5, %cst_3) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc2)
%6 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<5x2xf32>> loc(#loc)
%cst_4 = arith.constant dense<0.000000e+00> : tensor<5x2xf32> loc(#loc)
"enzyme.set"(%6, %cst_4) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc)
%7 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<f32>> loc(#loc)
%cst_5 = arith.constant dense<0.000000e+00> : tensor<f32> loc(#loc)
"enzyme.set"(%7, %cst_5) : (!enzyme.Gradient<tensor<f32>>, tensor<f32>) -> () loc(#loc)
%c = stablehlo.constant dense<[[3, 1, 0, 4, 2]]> : tensor<1x5xi64> loc(#loc)
%cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<f32> loc(#loc)
%8 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<5x2xf32>) -> tensor<2x5xf32> loc(#loc)
%9 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<5x2xf32>) -> tensor<2x5xf32> loc(#loc)
"enzyme.push"(%3, %c) : (!enzyme.Cache<tensor<1x5xi64>>, tensor<1x5xi64>) -> () loc(#loc)
%10 = "stablehlo.scatter"(%8, %c, %9) <{scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true}> ({
^bb0(%arg5: tensor<f32> loc("arg1 (path=(:args, 1))"), %arg6: tensor<f32> loc("arg2 (path=(:args, 2))")):
%37 = stablehlo.add %arg5, %arg6 : tensor<f32> loc(#loc9)
stablehlo.return %37 : tensor<f32> loc(#loc)
}) : (tensor<2x5xf32>, tensor<1x5xi64>, tensor<2x5xf32>) -> tensor<2x5xf32> loc(#loc8)
%11 = stablehlo.reduce(%10 init: %cst_6) applies stablehlo.add across dimensions = [0, 1] : (tensor<2x5xf32>, tensor<f32>) -> tensor<f32> loc(#loc)
%12 = stablehlo.transpose %10, dims = [1, 0] : (tensor<2x5xf32>) -> tensor<5x2xf32> loc(#loc)
cf.br ^bb1 loc(#loc)
^bb1: // pred: ^bb0
%13 = "enzyme.get"(%7) : (!enzyme.Gradient<tensor<f32>>) -> tensor<f32> loc(#loc)
%14 = arith.addf %13, %arg2 : tensor<f32> loc(#loc)
"enzyme.set"(%7, %14) : (!enzyme.Gradient<tensor<f32>>, tensor<f32>) -> () loc(#loc)
%15 = "enzyme.get"(%6) : (!enzyme.Gradient<tensor<5x2xf32>>) -> tensor<5x2xf32> loc(#loc)
%16 = arith.addf %15, %arg3 : tensor<5x2xf32> loc(#loc)
"enzyme.set"(%6, %16) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc)
%17 = "enzyme.get"(%5) : (!enzyme.Gradient<tensor<5x2xf32>>) -> tensor<5x2xf32> loc(#loc2)
%18 = arith.addf %17, %arg4 : tensor<5x2xf32> loc(#loc2)
"enzyme.set"(%5, %18) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc2)
%19 = "enzyme.get"(%6) : (!enzyme.Gradient<tensor<5x2xf32>>) -> tensor<5x2xf32> loc(#loc)
%cst_7 = arith.constant dense<0.000000e+00> : tensor<5x2xf32> loc(#loc)
"enzyme.set"(%6, %cst_7) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc)
%20 = stablehlo.transpose %19, dims = [1, 0] : (tensor<5x2xf32>) -> tensor<2x5xf32> loc(#loc)
%21 = "enzyme.get"(%4) : (!enzyme.Gradient<tensor<2x5xf32>>) -> tensor<2x5xf32> loc(#loc8)
%22 = arith.addf %21, %20 : tensor<2x5xf32> loc(#loc8)
"enzyme.set"(%4, %22) : (!enzyme.Gradient<tensor<2x5xf32>>, tensor<2x5xf32>) -> () loc(#loc8)
%cst_8 = arith.constant dense<0.000000e+00> : tensor<2x5xf32> loc(#loc)
%23 = "enzyme.get"(%7) : (!enzyme.Gradient<tensor<f32>>) -> tensor<f32> loc(#loc)
%cst_9 = arith.constant dense<0.000000e+00> : tensor<f32> loc(#loc)
"enzyme.set"(%7, %cst_9) : (!enzyme.Gradient<tensor<f32>>, tensor<f32>) -> () loc(#loc)
%24 = stablehlo.broadcast_in_dim %23, dims = [] : (tensor<f32>) -> tensor<2x5xf32> loc(#loc)
%25 = "enzyme.get"(%4) : (!enzyme.Gradient<tensor<2x5xf32>>) -> tensor<2x5xf32> loc(#loc8)
%26 = arith.addf %25, %24 : tensor<2x5xf32> loc(#loc8)
"enzyme.set"(%4, %26) : (!enzyme.Gradient<tensor<2x5xf32>>, tensor<2x5xf32>) -> () loc(#loc8)
%27 = "enzyme.get"(%2) : (!enzyme.Gradient<tensor<2x5xf32>>) -> tensor<2x5xf32> loc(#loc)
%cst_10 = arith.constant dense<0.000000e+00> : tensor<2x5xf32> loc(#loc)
"enzyme.set"(%2, %cst_10) : (!enzyme.Gradient<tensor<2x5xf32>>, tensor<2x5xf32>) -> () loc(#loc)
%28 = stablehlo.transpose %27, dims = [1, 0] : (tensor<2x5xf32>) -> tensor<5x2xf32> loc(#loc)
%29 = "enzyme.get"(%5) : (!enzyme.Gradient<tensor<5x2xf32>>) -> tensor<5x2xf32> loc(#loc2)
%30 = arith.addf %29, %28 : tensor<5x2xf32> loc(#loc2)
"enzyme.set"(%5, %30) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc2)
%31 = "enzyme.get"(%1) : (!enzyme.Gradient<tensor<2x5xf32>>) -> tensor<2x5xf32> loc(#loc)
%cst_11 = arith.constant dense<0.000000e+00> : tensor<2x5xf32> loc(#loc)
"enzyme.set"(%1, %cst_11) : (!enzyme.Gradient<tensor<2x5xf32>>, tensor<2x5xf32>) -> () loc(#loc)
%32 = stablehlo.transpose %31, dims = [1, 0] : (tensor<2x5xf32>) -> tensor<5x2xf32> loc(#loc)
%33 = "enzyme.get"(%0) : (!enzyme.Gradient<tensor<5x2xf32>>) -> tensor<5x2xf32> loc(#loc1)
%34 = arith.addf %33, %32 : tensor<5x2xf32> loc(#loc1)
"enzyme.set"(%0, %34) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc1)
%35 = "enzyme.get"(%0) : (!enzyme.Gradient<tensor<5x2xf32>>) -> tensor<5x2xf32> loc(#loc1)
%36 = "enzyme.get"(%5) : (!enzyme.Gradient<tensor<5x2xf32>>) -> tensor<5x2xf32> loc(#loc2)
return %11, %12, %arg1, %35, %36 : tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32>, tensor<5x2xf32>, tensor<5x2xf32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc3 = loc("/home/julian/.julia/packages/Reactant/kCRbM/src/Ops.jl":1701:0)
#loc6 = loc("/home/julian/.julia/packages/Reactant/kCRbM/src/Ops.jl":376:0)
#loc8 = loc("scatter"(#loc3))
#loc9 = loc("add"(#loc6))
Metadata
Metadata
Assignees
Labels
No labels