-
-
Notifications
You must be signed in to change notification settings - Fork 217
Open
FluxML/ZygoteRules.jl
#35
Labels
bug: Something isn't working
Description
Originally:
julia> using Flux
julia> let e = Embedding(2=>2)
x = Flux.onehotbatch([1 2; 2 1], 1:2)
# x = Array(x) # similar error with Array or OneHotArray
Flux.gradient(m -> sum(abs2, m(x)), e)
end
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64, Int64, Int64})
(Edit: see the Zygote-only MWE below.)
Closest candidates are:
reshape(::ChainRulesCore.AbstractThunk, ::Any...)
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_types/thunks.jl:62
reshape(::Array{T, M}, ::NTuple{N, Int64}) where {T, N, M}
@ Base reshapedarray.jl:40
reshape(::BitArray{N}, ::NTuple{N, Int64}) where N
@ Base bitarray.jl:479
...
Stacktrace:
[1] (::Zygote.var"#617#621"{OneHotArrays.OneHotArray{UInt32, 2, 3, Matrix{UInt32}}, Tuple{Int64, Colon}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/lib/array.jl:107
[2] (::Zygote.var"#2783#back#623"{Zygote.var"#617#621"{…}})(Δ::ChainRulesCore.Thunk{ChainRules.var"#546#549"{…}})
@ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
[3] Embedding
@ ~/.julia/packages/Flux/3711C/src/layers/basic.jl:776 [inlined]
[4] (::Zygote.Pullback{…})(Δ::ChainRulesCore.InplaceableThunk{…})
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
[5] #197
@ ./REPL[408]:4 [inlined]
[6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
[7] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:97
[8] gradient(f::Function, args::Embedding{Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:154
[9] #gradient#1
@ ~/.julia/packages/Flux/3711C/src/gradient.jl:44 [inlined]
[10] gradient(f::Function, args::Embedding{Matrix{Float32}})
@ Flux ~/.julia/packages/Flux/3711C/src/gradient.jl:31
[11] top-level scope
@ REPL[408]:4
Some type information was truncated. Use `show(err)` to see complete types.
(@v1.11) pkg> st Flux Zygote
Status `~/.julia/environments/v1.11/Project.toml`
[587475ba] Flux v0.16.3
[e88e6eb3] Zygote v0.7.5
I presume the problem is Zygote 0.7 and its use of thunks, as the same code works fine on earlier versions:
julia> let e = Embedding(2=>2)
x = Flux.onehotbatch([1 2; 2 1], 1:2)
# x = Array(x)
Flux.gradient(m -> sum(abs2, m(x)), e)
end
((weight = Float32[6.834647 3.3733022; 5.7237077 0.9229657],),)
julia> let e = Embedding(2=>2)
x = Flux.onehotbatch([1 2; 2 1], 1:2)
x = Array(x)
Flux.gradient(m -> sum(abs2, m(x)), e)
end
((weight = Float32[1.961737 -1.5491782; -0.6510874 11.824801],),)
(jl_ZbRV0D) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_ZbRV0D/Project.toml`
⌃ [587475ba] Flux v0.14.25
⌅ [e88e6eb3] Zygote v0.6.75
Edit: with `Dense`, the problem occurs only with a `OneHotArray` input, and not with an `Array`:
julia> let d = Dense(2=>2)
x = Flux.onehotbatch([1 2; 2 1], 1:2)
x = Array(x)
Flux.gradient(m -> sum(abs2, m(x)), d)
end
((weight = Float32[0.6652966 -3.0755887; 1.8529012 2.833063], bias = Float32[-2.4102921, 4.685964], σ = nothing),)
julia> let d = Dense(2=>2)
x = Flux.onehotbatch([1 2; 2 1], 1:2)
# x = Array(x)
Flux.gradient(m -> sum(abs2, m(x)), d)
end
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64, Int64, Int64})
The function `reshape` exists, but no method is defined for this combination of argument types.
Closest candidates are:
reshape(::ChainRulesCore.AbstractThunk, ::Any...)
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_types/thunks.jl:62
reshape(::Array{T, M}, ::NTuple{N, Int64}) where {T, N, M}
@ Base reshapedarray.jl:40
reshape(::BitArray{N}, ::NTuple{N, Int64}) where N
@ Base bitarray.jl:479
...
Stacktrace:
[1] (::Zygote.var"#617#621"{OneHotArrays.OneHotArray{UInt32, 2, 3, Matrix{UInt32}}, Tuple{Int64, Colon}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/lib/array.jl:107
[2] (::Zygote.var"#2783#back#623"{Zygote.var"#617#621"{…}})(Δ::ChainRulesCore.Thunk{ChainRules.var"#546#549"{…}})
@ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
[3] Dense
@ ~/.julia/packages/Flux/3711C/src/layers/basic.jl:204 [inlined]
[4] (::Zygote.Pullback{…})(Δ::ChainRulesCore.InplaceableThunk{…})
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
[5] #215
@ ./REPL[419]:4 [inlined]
[6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
Metadata
Assignees
Labels
bug: Something isn't working