Gradient of reshape(::Array{Bool}, ...) does not handle thunks #1567

@mcabbott

Originally (edit: see the Zygote-only MWE below):

julia> using Flux

julia> let e = Embedding(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           # x = Array(x)  # similar error with Array or OneHotArray
           Flux.gradient(m -> sum(abs2, m(x)), e)
       end
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64, Int64, Int64})

Closest candidates are:
  reshape(::ChainRulesCore.AbstractThunk, ::Any...)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_types/thunks.jl:62
  reshape(::Array{T, M}, ::NTuple{N, Int64}) where {T, N, M}
   @ Base reshapedarray.jl:40
  reshape(::BitArray{N}, ::NTuple{N, Int64}) where N
   @ Base bitarray.jl:479
  ...

Stacktrace:
  [1] (::Zygote.var"#617#621"{OneHotArrays.OneHotArray{UInt32, 2, 3, Matrix{UInt32}}, Tuple{Int64, Colon}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/lib/array.jl:107
  [2] (::Zygote.var"#2783#back#623"{Zygote.var"#617#621"{}})(Δ::ChainRulesCore.Thunk{ChainRules.var"#546#549"{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
  [3] Embedding
    @ ~/.julia/packages/Flux/3711C/src/layers/basic.jl:776 [inlined]
  [4] (::Zygote.Pullback{…})(Δ::ChainRulesCore.InplaceableThunk{…})
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
  [5] #197
    @ ./REPL[408]:4 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
  [7] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:97
  [8] gradient(f::Function, args::Embedding{Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:154
  [9] #gradient#1
    @ ~/.julia/packages/Flux/3711C/src/gradient.jl:44 [inlined]
 [10] gradient(f::Function, args::Embedding{Matrix{Float32}})
    @ Flux ~/.julia/packages/Flux/3711C/src/gradient.jl:31
 [11] top-level scope
    @ REPL[408]:4
Some type information was truncated. Use `show(err)` to see complete types.

(@v1.11) pkg> st Flux Zygote
Status `~/.julia/environments/v1.11/Project.toml`
  [587475ba] Flux v0.16.3
  [e88e6eb3] Zygote v0.7.5

I presume the problem is how Zygote 0.7 handles thunks, since the same code works fine on earlier versions:

julia> let e = Embedding(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           # x = Array(x)
           Flux.gradient(m -> sum(abs2, m(x)), e)
       end
((weight = Float32[6.834647 3.3733022; 5.7237077 0.9229657],),)

julia> let e = Embedding(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           x = Array(x)
           Flux.gradient(m -> sum(abs2, m(x)), e)
       end
((weight = Float32[1.961737 -1.5491782; -0.6510874 11.824801],),)

(jl_ZbRV0D) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_ZbRV0D/Project.toml`
⌃ [587475ba] Flux v0.14.25
⌅ [e88e6eb3] Zygote v0.6.75

Edit: with Dense, the problem occurs only with a OneHotArray input, not with an Array:

julia> let d = Dense(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           x = Array(x)
           Flux.gradient(m -> sum(abs2, m(x)), d)
       end
((weight = Float32[0.6652966 -3.0755887; 1.8529012 2.833063], bias = Float32[-2.4102921, 4.685964], σ = nothing),)

julia> let d = Dense(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           # x = Array(x)
           Flux.gradient(m -> sum(abs2, m(x)), d)
       end
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64, Int64, Int64})
The function `reshape` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  reshape(::ChainRulesCore.AbstractThunk, ::Any...)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_types/thunks.jl:62
  reshape(::Array{T, M}, ::NTuple{N, Int64}) where {T, N, M}
   @ Base reshapedarray.jl:40
  reshape(::BitArray{N}, ::NTuple{N, Int64}) where N
   @ Base bitarray.jl:479
  ...

Stacktrace:
  [1] (::Zygote.var"#617#621"{OneHotArrays.OneHotArray{UInt32, 2, 3, Matrix{UInt32}}, Tuple{Int64, Colon}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/lib/array.jl:107
  [2] (::Zygote.var"#2783#back#623"{Zygote.var"#617#621"{}})(Δ::ChainRulesCore.Thunk{ChainRules.var"#546#549"{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
  [3] Dense
    @ ~/.julia/packages/Flux/3711C/src/layers/basic.jl:204 [inlined]
  [4] (::Zygote.Pullback{…})(Δ::ChainRulesCore.InplaceableThunk{…})
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
  [5] #215
    @ ./REPL[419]:4 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
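
Both traces end in a reshape pullback that receives `Δ::Nothing`, while the frame above it was handed a `Thunk`. The following is a package-free sketch of that failure shape only, under the assumption that the thunk unwraps to `nothing` before `reshape` is called. The names `naive_reshape_back` and `guarded_reshape_back` are hypothetical and are not Zygote's code, and the guard is one possible shape of a fix, not necessarily the right one.

```julia
# Hypothetical stand-ins for a reshape pullback; this is NOT Zygote's actual code.

# Naive version: assumes the cotangent Δ is always an array.
naive_reshape_back(Δ, sz) = reshape(Δ, sz)

# Guarded version: passes a `nothing` cotangent straight through.
guarded_reshape_back(Δ, sz) = reshape(Δ, sz)
guarded_reshape_back(::Nothing, sz) = nothing

sz = (2, 2, 2)

# An array cotangent reshapes as expected.
@assert size(naive_reshape_back(ones(Float32, 2, 4), sz)) == sz

# A `nothing` cotangent hits the same kind of MethodError as in the report.
err = try
    naive_reshape_back(nothing, sz)
catch e
    e
end
@assert err isa MethodError

# The guarded version returns `nothing` instead of throwing.
@assert guarded_reshape_back(nothing, sz) === nothing
```

This runs in plain Julia with no packages loaded, so it says nothing about where in Zygote 0.7 the `nothing` actually originates.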