@compile gradient of NNlib.scatter! with Enzyme fails #1423

Description

@JulianTrommer

I encountered an issue while attempting to compile one of the NNlib.scatter test cases for differentiation with Enzyme.

The following minimal example reproduces the error (adapted from a Lux.jl training loop):

using Enzyme, Lux, NNlib, Reactant

device = reactant_device()

dst = Float32[3 3 4 4 5
    5 5 6 6 7] |> device

src = ones(Float32, 2, 5) |> device

idx = [4, 2, 1, 5, 3] |> device

function test_scatter(dsts, srcs, idxs)
    return sum(NNlib.scatter!(+, dsts, srcs, idxs))
end

function test_gradient(objective_function, dsts, srcs, idxs)
    derivs, val = Enzyme.gradient(
        Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
        Const(objective_function),
        dsts,
        srcs,
        idxs,
    )
    return derivs, val
end

test_gradient_compiled = @compile test_gradient(test_scatter, dst, src, idx)
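
For reference, if compilation succeeded, the compiled function would be called with the same arguments used for tracing. Since idx is a permutation of 1:5, every column of src is added into dst exactly once, so the objective reduces to sum(dst) + sum(src), and the gradients with respect to dst and src should both be arrays of ones. A sketch (not part of the original report):

# Sketch, assuming compilation succeeded (it currently fails with the error below):
derivs, val = test_gradient_compiled(test_scatter, dst, src, idx)
# Expected: the derivatives w.r.t. dst and src are arrays of ones,
# and val == sum(dst) + sum(src).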

The following error is thrown:

loc("scatter"("/home/julian/.julia/packages/Reactant/kCRbM/src/Ops.jl":1701:0)): error: AutoDiffScatterRev only supports Setindex operations
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_wHnyOF/module_000_vxwR_post_all_pm.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/kCRbM/src/mlir/IR/Pass.jl:119
ERROR: "failed to run pass manager on module"
Stacktrace:
 [1] run!(pm::Reactant.MLIR.IR.PassManager, mod::Reactant.MLIR.IR.Module, key::String)
   @ Reactant.MLIR.IR ~/.julia/packages/Reactant/kCRbM/src/mlir/IR/Pass.jl:163
 [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String, key::String; enable_verifier::Bool)
   @ Reactant.Compiler ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:1223
 [3] run_pass_pipeline!
   @ ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:1218 [inlined]
 [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, kwargs::@Kwargs{})
   @ Reactant.Compiler ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:1647
 [5] compile_mlir! (repeats 2 times)
   @ ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:1455 [inlined]
 [6] compile_xla(f::Function, args::Tuple{typeof(test_scatter), ConcretePJRTArray{…}, ConcretePJRTArray{…}, Vector{…}}; client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
   @ Reactant.Compiler ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:3369
 [7] compile_xla
   @ ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:3349 [inlined]
 [8] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{})
   @ Reactant.Compiler ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:3434
 [9] top-level scope
   @ ~/.julia/packages/Reactant/kCRbM/src/Compiler.jl:2529

This is the MLIR module that was written to the file referenced in the error output above:

#loc = loc(unknown)
#loc1 = loc("arg1 (path=(Symbol(\22##autodiffarg#234\22), 1))")
#loc2 = loc("arg2 (path=(Symbol(\22##autodiffarg#234\22), 2))")
#loc4 = loc("arg1 (path=(:args, 1))")
#loc5 = loc("arg2 (path=(:args, 2))")
#loc7 = loc("arg3 (path=(:args, 3))")
module {
  func.func private @"Const{typeof(test_scatter)}(Main.test_scatter)_autodiff"(%arg0: tensor<5x2xf32> loc("arg1 (path=(Symbol(\22##autodiffarg#234\22), 1))"), %arg1: tensor<5x2xf32> loc("arg2 (path=(Symbol(\22##autodiffarg#234\22), 2))")) -> (tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
    %c = stablehlo.constant dense<[[3, 1, 0, 4, 2]]> : tensor<1x5xi64> loc(#loc)
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32> loc(#loc)
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<5x2xf32>) -> tensor<2x5xf32> loc(#loc)
    %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<5x2xf32>) -> tensor<2x5xf32> loc(#loc)
    %2 = "stablehlo.scatter"(%0, %c, %1) <{scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true}> ({
    ^bb0(%arg2: tensor<f32> loc("arg1 (path=(:args, 1))"), %arg3: tensor<f32> loc("arg2 (path=(:args, 2))")):
      %5 = stablehlo.add %arg2, %arg3 : tensor<f32> loc(#loc9)
      stablehlo.return %5 : tensor<f32> loc(#loc)
    }) : (tensor<2x5xf32>, tensor<1x5xi64>, tensor<2x5xf32>) -> tensor<2x5xf32> loc(#loc8)
    %3 = stablehlo.reduce(%2 init: %cst) applies stablehlo.add across dimensions = [0, 1] : (tensor<2x5xf32>, tensor<f32>) -> tensor<f32> loc(#loc)
    %4 = stablehlo.transpose %2, dims = [1, 0] : (tensor<2x5xf32>) -> tensor<5x2xf32> loc(#loc)
    return %3, %4, %arg1 : tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32> loc(#loc)
  } loc(#loc)
  func.func @test_gradient(%arg0: tensor<5x2xf32> loc("arg2 (path=(:args, 2))"), %arg1: tensor<5x2xf32> loc("arg3 (path=(:args, 3))")) -> (tensor<5x2xf32>, tensor<5x2xf32>, tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32> loc(#loc)
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<5x2xf32> loc(#loc)
    %0:5 = enzyme.autodiff @"Const{typeof(test_scatter)}(Main.test_scatter)_autodiff"(%arg0, %arg1, %cst, %cst_0, %cst_0) {activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>]} : (tensor<5x2xf32>, tensor<5x2xf32>, tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32>) -> (tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32>, tensor<5x2xf32>, tensor<5x2xf32>) loc(#loc)
    return %0#3, %0#4, %0#0, %0#1, %0#2 : tensor<5x2xf32>, tensor<5x2xf32>, tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32> loc(#loc)
  } loc(#loc)
  func.func private @"diffeConst{typeof(test_scatter)}(Main.test_scatter)_autodiff"(%arg0: tensor<5x2xf32> loc("arg1 (path=(Symbol(\22##autodiffarg#234\22), 1))"), %arg1: tensor<5x2xf32> loc("arg2 (path=(Symbol(\22##autodiffarg#234\22), 2))"), %arg2: tensor<f32> loc(unknown), %arg3: tensor<5x2xf32> loc(unknown), %arg4: tensor<5x2xf32> loc(unknown)) -> (tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32>, tensor<5x2xf32>, tensor<5x2xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
    %0 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<5x2xf32>> loc(#loc1)
    %cst = arith.constant dense<0.000000e+00> : tensor<5x2xf32> loc(#loc1)
    "enzyme.set"(%0, %cst) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc1)
    %1 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<2x5xf32>> loc(#loc)
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<2x5xf32> loc(#loc)
    "enzyme.set"(%1, %cst_0) : (!enzyme.Gradient<tensor<2x5xf32>>, tensor<2x5xf32>) -> () loc(#loc)
    %2 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<2x5xf32>> loc(#loc)
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<2x5xf32> loc(#loc)
    "enzyme.set"(%2, %cst_1) : (!enzyme.Gradient<tensor<2x5xf32>>, tensor<2x5xf32>) -> () loc(#loc)
    %3 = "enzyme.init"() : () -> !enzyme.Cache<tensor<1x5xi64>> loc(#loc)
    %4 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<2x5xf32>> loc(#loc8)
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<2x5xf32> loc(#loc8)
    "enzyme.set"(%4, %cst_2) : (!enzyme.Gradient<tensor<2x5xf32>>, tensor<2x5xf32>) -> () loc(#loc8)
    %5 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<5x2xf32>> loc(#loc2)
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<5x2xf32> loc(#loc2)
    "enzyme.set"(%5, %cst_3) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc2)
    %6 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<5x2xf32>> loc(#loc)
    %cst_4 = arith.constant dense<0.000000e+00> : tensor<5x2xf32> loc(#loc)
    "enzyme.set"(%6, %cst_4) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc)
    %7 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<f32>> loc(#loc)
    %cst_5 = arith.constant dense<0.000000e+00> : tensor<f32> loc(#loc)
    "enzyme.set"(%7, %cst_5) : (!enzyme.Gradient<tensor<f32>>, tensor<f32>) -> () loc(#loc)
    %c = stablehlo.constant dense<[[3, 1, 0, 4, 2]]> : tensor<1x5xi64> loc(#loc)
    %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<f32> loc(#loc)
    %8 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<5x2xf32>) -> tensor<2x5xf32> loc(#loc)
    %9 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<5x2xf32>) -> tensor<2x5xf32> loc(#loc)
    "enzyme.push"(%3, %c) : (!enzyme.Cache<tensor<1x5xi64>>, tensor<1x5xi64>) -> () loc(#loc)
    %10 = "stablehlo.scatter"(%8, %c, %9) <{scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true}> ({
    ^bb0(%arg5: tensor<f32> loc("arg1 (path=(:args, 1))"), %arg6: tensor<f32> loc("arg2 (path=(:args, 2))")):
      %37 = stablehlo.add %arg5, %arg6 : tensor<f32> loc(#loc9)
      stablehlo.return %37 : tensor<f32> loc(#loc)
    }) : (tensor<2x5xf32>, tensor<1x5xi64>, tensor<2x5xf32>) -> tensor<2x5xf32> loc(#loc8)
    %11 = stablehlo.reduce(%10 init: %cst_6) applies stablehlo.add across dimensions = [0, 1] : (tensor<2x5xf32>, tensor<f32>) -> tensor<f32> loc(#loc)
    %12 = stablehlo.transpose %10, dims = [1, 0] : (tensor<2x5xf32>) -> tensor<5x2xf32> loc(#loc)
    cf.br ^bb1 loc(#loc)
  ^bb1:  // pred: ^bb0
    %13 = "enzyme.get"(%7) : (!enzyme.Gradient<tensor<f32>>) -> tensor<f32> loc(#loc)
    %14 = arith.addf %13, %arg2 : tensor<f32> loc(#loc)
    "enzyme.set"(%7, %14) : (!enzyme.Gradient<tensor<f32>>, tensor<f32>) -> () loc(#loc)
    %15 = "enzyme.get"(%6) : (!enzyme.Gradient<tensor<5x2xf32>>) -> tensor<5x2xf32> loc(#loc)
    %16 = arith.addf %15, %arg3 : tensor<5x2xf32> loc(#loc)
    "enzyme.set"(%6, %16) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc)
    %17 = "enzyme.get"(%5) : (!enzyme.Gradient<tensor<5x2xf32>>) -> tensor<5x2xf32> loc(#loc2)
    %18 = arith.addf %17, %arg4 : tensor<5x2xf32> loc(#loc2)
    "enzyme.set"(%5, %18) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc2)
    %19 = "enzyme.get"(%6) : (!enzyme.Gradient<tensor<5x2xf32>>) -> tensor<5x2xf32> loc(#loc)
    %cst_7 = arith.constant dense<0.000000e+00> : tensor<5x2xf32> loc(#loc)
    "enzyme.set"(%6, %cst_7) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc)
    %20 = stablehlo.transpose %19, dims = [1, 0] : (tensor<5x2xf32>) -> tensor<2x5xf32> loc(#loc)
    %21 = "enzyme.get"(%4) : (!enzyme.Gradient<tensor<2x5xf32>>) -> tensor<2x5xf32> loc(#loc8)
    %22 = arith.addf %21, %20 : tensor<2x5xf32> loc(#loc8)
    "enzyme.set"(%4, %22) : (!enzyme.Gradient<tensor<2x5xf32>>, tensor<2x5xf32>) -> () loc(#loc8)
    %cst_8 = arith.constant dense<0.000000e+00> : tensor<2x5xf32> loc(#loc)
    %23 = "enzyme.get"(%7) : (!enzyme.Gradient<tensor<f32>>) -> tensor<f32> loc(#loc)
    %cst_9 = arith.constant dense<0.000000e+00> : tensor<f32> loc(#loc)
    "enzyme.set"(%7, %cst_9) : (!enzyme.Gradient<tensor<f32>>, tensor<f32>) -> () loc(#loc)
    %24 = stablehlo.broadcast_in_dim %23, dims = [] : (tensor<f32>) -> tensor<2x5xf32> loc(#loc)
    %25 = "enzyme.get"(%4) : (!enzyme.Gradient<tensor<2x5xf32>>) -> tensor<2x5xf32> loc(#loc8)
    %26 = arith.addf %25, %24 : tensor<2x5xf32> loc(#loc8)
    "enzyme.set"(%4, %26) : (!enzyme.Gradient<tensor<2x5xf32>>, tensor<2x5xf32>) -> () loc(#loc8)
    %27 = "enzyme.get"(%2) : (!enzyme.Gradient<tensor<2x5xf32>>) -> tensor<2x5xf32> loc(#loc)
    %cst_10 = arith.constant dense<0.000000e+00> : tensor<2x5xf32> loc(#loc)
    "enzyme.set"(%2, %cst_10) : (!enzyme.Gradient<tensor<2x5xf32>>, tensor<2x5xf32>) -> () loc(#loc)
    %28 = stablehlo.transpose %27, dims = [1, 0] : (tensor<2x5xf32>) -> tensor<5x2xf32> loc(#loc)
    %29 = "enzyme.get"(%5) : (!enzyme.Gradient<tensor<5x2xf32>>) -> tensor<5x2xf32> loc(#loc2)
    %30 = arith.addf %29, %28 : tensor<5x2xf32> loc(#loc2)
    "enzyme.set"(%5, %30) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc2)
    %31 = "enzyme.get"(%1) : (!enzyme.Gradient<tensor<2x5xf32>>) -> tensor<2x5xf32> loc(#loc)
    %cst_11 = arith.constant dense<0.000000e+00> : tensor<2x5xf32> loc(#loc)
    "enzyme.set"(%1, %cst_11) : (!enzyme.Gradient<tensor<2x5xf32>>, tensor<2x5xf32>) -> () loc(#loc)
    %32 = stablehlo.transpose %31, dims = [1, 0] : (tensor<2x5xf32>) -> tensor<5x2xf32> loc(#loc)
    %33 = "enzyme.get"(%0) : (!enzyme.Gradient<tensor<5x2xf32>>) -> tensor<5x2xf32> loc(#loc1)
    %34 = arith.addf %33, %32 : tensor<5x2xf32> loc(#loc1)
    "enzyme.set"(%0, %34) : (!enzyme.Gradient<tensor<5x2xf32>>, tensor<5x2xf32>) -> () loc(#loc1)
    %35 = "enzyme.get"(%0) : (!enzyme.Gradient<tensor<5x2xf32>>) -> tensor<5x2xf32> loc(#loc1)
    %36 = "enzyme.get"(%5) : (!enzyme.Gradient<tensor<5x2xf32>>) -> tensor<5x2xf32> loc(#loc2)
    return %11, %12, %arg1, %35, %36 : tensor<f32>, tensor<5x2xf32>, tensor<5x2xf32>, tensor<5x2xf32>, tensor<5x2xf32> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc3 = loc("/home/julian/.julia/packages/Reactant/kCRbM/src/Ops.jl":1701:0)
#loc6 = loc("/home/julian/.julia/packages/Reactant/kCRbM/src/Ops.jl":376:0)
#loc8 = loc("scatter"(#loc3))
#loc9 = loc("add"(#loc6))
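
The op the pass rejects appears to be the "stablehlo.scatter" above: its update region computes stablehlo.add on the existing and incoming values, while the AutoDiffScatterRev error indicates that only setindex-style (overwriting) scatters are currently handled in the reverse pass. As a hypothetical workaround sketch (my own reformulation, not verified against this Reactant version), the scatter-add can be rewritten as a dense one-hot projection so that no stablehlo.scatter is emitted:

# Hypothetical workaround sketch, not from the original report:
# dst[:, idx[k]] += src[:, k] is equivalent to dst .+ src * onehot,
# where onehot[k, j] == 1 exactly when idx[k] == j.
function test_scatter_onehot(dsts, srcs, idxs)
    onehot = Float32.(idxs .== (1:size(dsts, 2))')
    return sum(dsts .+ srcs * onehot)
end

Unlike NNlib.scatter!, this version does not mutate dsts and materializes a dense 5×5 matrix, so it is only meant to illustrate a way around the unsupported scatter reverse rule.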
