Skip to content

Missing Rule for Roots.jl #2035

Open
@mhauru

Description

@mhauru

Minimal working example (MWE):

# Minimal reproducer for an Enzyme "illegal type analysis" failure inside Roots.jl.
# Differentiating through the inverse of a Bijectors.PlanarLayer reaches
# Roots.init_state(::Roots.Bisection, ...). Its `__middle` helper
# (Roots/src/Bracketing/bisection.jl:135) reinterprets two Float64 values as
# integers and adds them. Enzyme's type analysis rejects that `add i64` on
# operands it has typed as Float@double.
module MWE
import Bijectors, Enzyme, StableRNGs
# Planar flow layer; the argument 3 is presumably the input dimension — TODO confirm.
b = Bijectors.PlanarLayer(3)
# Inverse of the planar layer. Per the error output, evaluating it goes through a
# Roots.jl bisection solve, which is where the Enzyme failure originates.
binv = Bijectors.inverse(b)
# Seeded 3x3 input so the failure is reproducible across runs.
x = randn(StableRNGs.StableRNG(23), (3, 3))
# Round trip: apply the inverse, then the forward layer, and reduce to a scalar.
f = x -> sum(b(binv(x)))
# Forward-mode gradient; per the report, reverse mode fails in the same way.
Enzyme.gradient(Enzyme.Forward, f, x)
end

Output:

ERROR: Enzyme compilation failed due to illegal type analysis.
Current scope:
; Function Attrs: mustprogress willreturn
define internal fastcc void @preprocess_julia_init_state_16477([4 x double]* noalias nocapture noundef nonnull sret([4 x double]) align 8 dereferenceable(32) "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %0, { [3 x double] } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(24) "enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double}" "enzymejl_parmtype"="4457150416" "enzymejl_parmtype_ref"="1" %1, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4770454480" "enzymejl_parmtype_ref"="0" %2, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4770454480" "enzymejl_parmtype_ref"="0" %3, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4770454480" "enzymejl_parmtype_ref"="0" %4, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4770454480" "enzymejl_parmtype_ref"="0" %5) unnamed_addr #59 !dbg !5084 {
top:
  %6 = alloca [4 x double], align 8
  %7 = call {}*** @julia.get_pgcstack() #60
  %ptls_field14 = getelementptr inbounds {}**, {}*** %7, i64 2
  %8 = bitcast {}*** %ptls_field14 to i64***
  %ptls_load1516 = load i64**, i64*** %8, align 8, !tbaa !61
  %9 = getelementptr inbounds i64*, i64** %ptls_load1516, i64 2
  %safepoint = load i64*, i64** %9, align 8, !tbaa !65
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #60, !dbg !5085
  fence syncscope("singlethread") seq_cst
  %10 = call double @llvm.fabs.f64(double %2) #60, !dbg !5086
  %11 = bitcast double %10 to i64, !dbg !5087
  %.not = icmp eq i64 %11, 9218868437227405312, !dbg !5087
  %12 = call double @julia_nextfloat_16460(double %2) #61, !dbg !5088
  %value_phi = select i1 %.not, double %12, double %2, !dbg !5088
  %13 = call double @llvm.fabs.f64(double %3) #60, !dbg !5089
  %14 = bitcast double %13 to i64, !dbg !5090
  %.not17 = icmp eq i64 %14, 9218868437227405312, !dbg !5090
  %15 = call double @julia_prevfloat_16488(double %3) #61, !dbg !5091
  %value_phi2 = select i1 %.not17, double %15, double %3, !dbg !5091
  %16 = fcmp uge double %value_phi, 0.000000e+00, !dbg !5092
  %17 = fcmp ule double %value_phi, 0.000000e+00, !dbg !5095
  %18 = select i1 %17, double %value_phi, double 1.000000e+00, !dbg !5097
  %19 = select i1 %16, double %18, double -1.000000e+00, !dbg !5097
  %20 = fcmp uge double %value_phi2, 0.000000e+00, !dbg !5092
  %21 = fcmp ule double %value_phi2, 0.000000e+00, !dbg !5095
  %22 = select i1 %21, double %value_phi2, double 1.000000e+00, !dbg !5097
  %23 = select i1 %20, double %22, double -1.000000e+00, !dbg !5097
  %24 = fmul double %19, %23, !dbg !5098
  %25 = fcmp uge double %24, 0.000000e+00, !dbg !5099
  br i1 %25, label %L31, label %L47, !dbg !5094

L31:                                              ; preds = %top
  %26 = call double @llvm.fabs.f64(double %value_phi) #60, !dbg !5101
  %bitcast_coercion = bitcast double %26 to i64, !dbg !5105
  %27 = call double @llvm.fabs.f64(double %value_phi2) #60, !dbg !5106
  %bitcast_coercion9 = bitcast double %27 to i64, !dbg !5108
  %28 = add i64 %bitcast_coercion9, %bitcast_coercion, !dbg !5109
  %29 = lshr i64 %28, 1, !dbg !5111
  %30 = fadd double %value_phi, %value_phi2, !dbg !5113
  %31 = fcmp uge double %30, 0.000000e+00, !dbg !5115
  %32 = fcmp ule double %30, 0.000000e+00, !dbg !5117
  %33 = select i1 %32, double %30, double 1.000000e+00, !dbg !5119
  %34 = select i1 %31, double %33, double -1.000000e+00, !dbg !5119
  %bitcast_coercion12 = bitcast i64 %29 to double, !dbg !5120
  %35 = fmul double %34, %bitcast_coercion12, !dbg !5121
  br label %L47, !dbg !5121

L47:                                              ; preds = %L31, %top
  %value_phi6 = phi double [ %35, %L31 ], [ 0.000000e+00, %top ]
  %36 = getelementptr inbounds { [3 x double] }, { [3 x double] } addrspace(11)* %1, i64 0, i32 0, i64 1, !dbg !5122
  %37 = getelementptr inbounds { [3 x double] }, { [3 x double] } addrspace(11)* %1, i64 0, i32 0, i64 2, !dbg !5122
  %unbox = load double, double addrspace(11)* %37, align 8, !dbg !5124, !tbaa !65, !alias.scope !211, !noalias !214
  %38 = fadd double %value_phi6, %unbox, !dbg !5124
  %39 = call double @julia_tanh_16435(double %38) #62, !dbg !5122
  %unbox7 = load double, double addrspace(11)* %36, align 8, !dbg !5125, !tbaa !65, !alias.scope !211, !noalias !214
  %40 = fmul double %39, %unbox7, !dbg !5125
  %41 = fadd double %value_phi6, %40, !dbg !5124
  %42 = getelementptr inbounds { [3 x double] }, { [3 x double] } addrspace(11)* %1, i64 0, i32 0, i64 0, !dbg !5122
  %unbox8 = load double, double addrspace(11)* %42, align 8, !dbg !5126, !tbaa !65, !alias.scope !211, !noalias !214
  %43 = fsub double %41, %unbox8, !dbg !5126
  call fastcc void @julia__init_state_52_16483([4 x double]* noalias nocapture noundef nonnull sret([4 x double]) align 8 dereferenceable(32) %6, double %value_phi6, double %43, double %2, double %3, double %4, double %5) #60, !dbg !5085
  %44 = bitcast [4 x double]* %0 to i8*, !dbg !5085
  %45 = bitcast [4 x double]* %6 to i8*, !dbg !5085
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* noundef nonnull align 8 dereferenceable(32) %44, i8* noundef nonnull align 8 dereferenceable(32) %45, i64 32, i1 false) #60, !dbg !5085, !noalias !5127
  ret void, !dbg !5085
}

 Type analysis state:
<analysis>
[4 x double]* %0: {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
{ [3 x double] } addrspace(11)* %1: {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
double %2: {[-1]:Float@double}, intvals: {}
double %3: {[-1]:Float@double}, intvals: {}
double %4: {[-1]:Float@double}, intvals: {}
double %5: {[-1]:Float@double}, intvals: {}
double -1.000000e+00: {[-1]:Float@double}, intvals: {}
double 1.000000e+00: {[-1]:Float@double}, intvals: {}
double 0.000000e+00: {[-1]:Anything}, intvals: {}
  %value_phi2 = select i1 %.not17, double %15, double %3, !dbg !77: {[-1]:Float@double}, intvals: {}
  %23 = select i1 %20, double %22, double -1.000000e+00, !dbg !88: {[-1]:Float@double}, intvals: {}
  %22 = select i1 %21, double %value_phi2, double 1.000000e+00, !dbg !88: {[-1]:Float@double}, intvals: {}
  %19 = select i1 %16, double %18, double -1.000000e+00, !dbg !88: {[-1]:Float@double}, intvals: {}
  %18 = select i1 %17, double %value_phi, double 1.000000e+00, !dbg !88: {[-1]:Float@double}, intvals: {}
  %value_phi = select i1 %.not, double %12, double %2, !dbg !73: {[-1]:Float@double}, intvals: {}
  %14 = bitcast double %13 to i64, !dbg !76: {[-1]:Float@double}, intvals: {}
  %bitcast_coercion = bitcast double %26 to i64, !dbg !100: {[-1]:Float@double}, intvals: {}
  %ptls_load1516 = load i64**, i64*** %8, align 8, !tbaa !61: {}, intvals: {}
  %safepoint = load i64*, i64** %9, align 8, !tbaa !65: {}, intvals: {}
  %bitcast_coercion9 = bitcast double %27 to i64, !dbg !104: {[-1]:Float@double}, intvals: {}
  %11 = bitcast double %10 to i64, !dbg !71: {[-1]:Float@double}, intvals: {}
  %8 = bitcast {}*** %ptls_field14 to i64***: {[-1]:Pointer}, intvals: {}
  %6 = alloca [4 x double], align 8: {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
  %7 = call {}*** @julia.get_pgcstack() #60: {}, intvals: {}
  %ptls_field14 = getelementptr inbounds {}**, {}*** %7, i64 2: {}, intvals: {}
  %9 = getelementptr inbounds i64*, i64** %ptls_load1516, i64 2: {[-1]:Pointer}, intvals: {}
  %10 = call double @llvm.fabs.f64(double %2) #60, !dbg !68: {[-1]:Float@double}, intvals: {}
  %12 = call double @julia_nextfloat_16460(double %2) #61, !dbg !73: {[-1]:Float@double}, intvals: {}
  %13 = call double @llvm.fabs.f64(double %3) #60, !dbg !75: {[-1]:Float@double}, intvals: {}
  %15 = call double @julia_prevfloat_16488(double %3) #61, !dbg !77: {[-1]:Float@double}, intvals: {}
  %26 = call double @llvm.fabs.f64(double %value_phi) #60, !dbg !95: {[-1]:Float@double}, intvals: {}
  %27 = call double @llvm.fabs.f64(double %value_phi2) #60, !dbg !102: {[-1]:Float@double}, intvals: {}
  %39 = call double @julia_tanh_16435(double %38) #62, !dbg !122: {[-1]:Float@double}, intvals: {}
i64 9218868437227405312: {[-1]:Anything}, intvals: {9218868437227405312,}
  %.not = icmp eq i64 %11, 9218868437227405312, !dbg !71: {[-1]:Integer}, intvals: {}
  %.not17 = icmp eq i64 %14, 9218868437227405312, !dbg !76: {[-1]:Integer}, intvals: {}
  %16 = fcmp uge double %value_phi, 0.000000e+00, !dbg !78: {[-1]:Integer}, intvals: {}
  %17 = fcmp ule double %value_phi, 0.000000e+00, !dbg !84: {[-1]:Integer}, intvals: {}
  %20 = fcmp uge double %value_phi2, 0.000000e+00, !dbg !78: {[-1]:Integer}, intvals: {}
  %21 = fcmp ule double %value_phi2, 0.000000e+00, !dbg !84: {[-1]:Integer}, intvals: {}
  %24 = fmul double %19, %23, !dbg !91: {[-1]:Float@double}, intvals: {}
  %25 = fcmp uge double %24, 0.000000e+00, !dbg !93: {[-1]:Integer}, intvals: {}
  %28 = add i64 %bitcast_coercion9, %bitcast_coercion, !dbg !105: {}, intvals: {}
  %38 = fadd double %value_phi6, %unbox, !dbg !128: {[-1]:Float@double}, intvals: {}
  %43 = fsub double %41, %unbox8, !dbg !138: {[-1]:Float@double}, intvals: {}
  %value_phi6 = phi double [ %35, %L31 ], [ 0.000000e+00, %top ]: {[-1]:Float@double}, intvals: {}
</analysis>

Illegal updateBinop Analysis   %28 = add i64 %bitcast_coercion9, %bitcast_coercion, !dbg !105
Illegal binopIn(down): 13 lhs: {[]:Float@double} rhs: {[]:Float@double}

MethodInstance for Roots.init_state(::Roots.Bisection, ::Roots.Callable_Function{Val{1}, Val{false}, Bijectors.var"#60#61"{Float64, Float64, Float64}, Nothing}, ::Float64, ::Float64, ::Float64, ::Float64)


Caused by:
Stacktrace:
 [1] +
   @ ./int.jl:87
 [2] __middle
   @ ~/.julia/packages/Roots/KNVCY/src/Bracketing/bisection.jl:135
 [3] __middle
   @ ~/.julia/packages/Roots/KNVCY/src/Bracketing/bisection.jl:124
 [4] _middle
   @ ~/.julia/packages/Roots/KNVCY/src/Bracketing/bisection.jl:117
 [5] init_state
   @ ~/.julia/packages/Roots/KNVCY/src/Bracketing/bisection.jl:34

Stacktrace:
  [1] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/projects/Enzyme.jl/src/compiler.jl:1508
  [2] EnzymeCreateForwardDiff(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…})
    @ Enzyme.API ~/projects/Enzyme.jl/src/api.jl:319
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/projects/Enzyme.jl/src/compiler.jl:4057
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/projects/Enzyme.jl/src/compiler.jl:7125
  [5] codegen
    @ ~/projects/Enzyme.jl/src/compiler.jl:5950 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/projects/Enzyme.jl/src/compiler.jl:8233
  [7] _thunk
    @ ~/projects/Enzyme.jl/src/compiler.jl:8233 [inlined]
  [8] cached_compilation
    @ ~/projects/Enzyme.jl/src/compiler.jl:8274 [inlined]
  [9] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Val{…})
    @ Enzyme.Compiler ~/projects/Enzyme.jl/src/compiler.jl:8406
 [10] #s2079#19071
    @ ~/projects/Enzyme.jl/src/compiler.jl:8543 [inlined]
 [11]
    @ Enzyme.Compiler ./none:0
 [12] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [13] runtime_generic_fwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, RT::Val{…}, f::Bijectors.Inverse{…}, df::Nothing, df_2::Nothing, df_3::Nothing, df_4::Nothing, df_5::Nothing, df_6::Nothing, df_7::Nothing, df_8::Nothing, df_9::Nothing, primal_1::Matrix{…}, shadow_1_1::Matrix{…}, shadow_1_2::Matrix{…}, shadow_1_3::Matrix{…}, shadow_1_4::Matrix{…}, shadow_1_5::Matrix{…}, shadow_1_6::Matrix{…}, shadow_1_7::Matrix{…}, shadow_1_8::Matrix{…}, shadow_1_9::Matrix{…})
    @ Enzyme.Compiler ~/projects/Enzyme.jl/src/rules/jitrules.jl:290
 [14] #1
    @ ./REPL[3]:6 [inlined]
 [15] fwddiffe9julia__1_15870wrap
    @ ./REPL[3]:0
 [16] macro expansion
    @ ~/projects/Enzyme.jl/src/compiler.jl:8163 [inlined]
 [17] enzyme_call
    @ ~/projects/Enzyme.jl/src/compiler.jl:7729 [inlined]
 [18] ForwardModeThunk
    @ ~/projects/Enzyme.jl/src/compiler.jl:7518 [inlined]
 [19] autodiff
    @ ~/projects/Enzyme.jl/src/Enzyme.jl:647 [inlined]
 [20] autodiff
    @ ~/projects/Enzyme.jl/src/Enzyme.jl:512 [inlined]
 [21] macro expansion
    @ ~/projects/Enzyme.jl/src/Enzyme.jl:2069 [inlined]
 [22] gradient(::EnzymeCore.ForwardMode{…}, ::Main.MWE.var"#1#2", ::Matrix{…}; chunk::Nothing, shadows::Tuple{…})
    @ Enzyme ~/projects/Enzyme.jl/src/Enzyme.jl:1971
 [23] gradient(::EnzymeCore.ForwardMode{false, EnzymeCore.FFIABI, false, false}, ::Main.MWE.var"#1#2", ::Matrix{Float64})
    @ Enzyme ~/projects/Enzyme.jl/src/Enzyme.jl:1971
 [24] top-level scope
    @ REPL[3]:7
Some type information was truncated. Use `show(err)` to see complete types.

This affects both forward and reverse mode. Reproduced on the current Enzyme `main` branch.

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions