The following example shows that `.==` produces incorrect gradients when executed on the GPU within a gradient computation:
```julia
using Metal, CUDA
using Flux

device = gpu_device()
# device = cpu_device()

f = x -> begin
    y = [0, 1, 2] |> device
    mask = y .== 1
    return sum(x[mask])
end

x = Float32[1, 2, 3] |> device
grad = Flux.gradient(f, x)  # should be [0.0, 1.0, 0.0], got [0.0, 0.0, 0.0]
```
Per this discussion, this is due to the specific way broadcasted functions are differentiated on the GPU using ForwardDiff.
The problem can be avoided by replacing

```julia
mask = y .== 1
```

with

```julia
mask = Flux.@ignore_derivatives y .== 1
```
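Applied to the full example above, the workaround looks like this (same setup as before; `@ignore_derivatives` hides the `.==` broadcast from the AD so the mask is treated as a constant):

```julia
using Metal, CUDA
using Flux

device = gpu_device()

f = x -> begin
    y = [0, 1, 2] |> device
    # Mask is computed outside the differentiated graph, so the GPU
    # ForwardDiff broadcast path is never invoked for the comparison.
    mask = Flux.@ignore_derivatives y .== 1
    return sum(x[mask])
end

x = Float32[1, 2, 3] |> device
grad = Flux.gradient(f, x)  # now gives [0.0, 1.0, 0.0] on GPU as well
```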
The problem seems to be automatically circumvented on CPU:

Zygote.jl/src/lib/broadcast.jl, lines 206 to 211 at 1b914d9:

```julia
@adjoint broadcasted(::AbstractArrayStyle, f::F, args...) where {F} = _broadcast_generic(__context__, f, args...)

@inline function _broadcast_generic(__context__, f::F, args...) where {F}
    T = Broadcast.combine_eltypes(f, args)
    # Avoid generic broadcasting in two easy cases:
    if T == Bool
        return (f.(args...), _ -> nothing)
```
but not on GPU:

Zygote.jl/src/lib/broadcast.jl, lines 359 to 363 at 1b914d9:

```julia
# Ordinary broadcasting calls broadcast_forward anyway when certain its' safe,
# so perhaps this can be deleted? Possible edge case here:
# https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
@adjoint broadcasted(::AbstractGPUArrayStyle, f, args...) =
    broadcast_forward(f, args...)
```
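For comparison, the CPU fast path can be checked without any GPU: because the broadcast's combined eltype is `Bool`, the adjoint returns `nothing` for the mask and the gradient flows only through `x` (a minimal sketch using plain Zygote, no Flux or GPU packages required):

```julia
using Zygote

function f(x)
    y = [0, 1, 2]
    # Bool-eltype broadcast: the CPU adjoint above short-circuits here
    # and treats the mask as non-differentiable.
    mask = y .== 1
    return sum(x[mask])
end

Zygote.gradient(f, Float32[1, 2, 3])  # ([0.0, 1.0, 0.0],) — the correct result
```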
Because this unexpected behavior is difficult to spot, it may be worth treating as a bug that warrants fixing.