
Broadcasted equality (.==) inaccurate on GPU due to ForwardDiff #1570

@ivanightingale

Description


The following example shows that .== gives inaccurate results on the GPU when it is executed inside a gradient computation:

using Metal, CUDA
using Flux

device = gpu_device()
# device = cpu_device()

f = x -> begin
    y = [0, 1, 2] |> device
    mask = y .== 1
    return sum(x[mask])
end

x = Float32[1, 2, 3] |> device
grad = Flux.gradient(f, x) # should be [0.0, 1.0, 0.0], got [0.0, 0.0, 0.0]

Per this discussion, this happens because of the way broadcasted functions are differentiated on the GPU using ForwardDiff.

The problem can be avoided by replacing

mask = y .== 1

with

mask = Flux.@ignore_derivatives y .== 1
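
For reference, here is a minimal sketch of the reproducer with that workaround applied (it only wraps the mask construction in Flux.@ignore_derivatives so the comparison is excluded from differentiation; f_fixed is just an illustrative name):

using Metal, CUDA
using Flux

device = gpu_device()

# Same function as above, but the Bool mask is built outside of
# differentiation, so the GPU broadcasting adjoint never sees `.==`.
f_fixed = x -> begin
    y = [0, 1, 2] |> device
    mask = Flux.@ignore_derivatives y .== 1
    return sum(x[mask])
end

x = Float32[1, 2, 3] |> device
grad = Flux.gradient(f_fixed, x)  # gradient w.r.t. x is now [0.0, 1.0, 0.0] as expected

Since the mask carries no derivative information anyway, ignoring derivatives for its construction does not change the mathematical result; it only changes which broadcasting code path Zygote takes.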

The problem is automatically circumvented on the CPU, where the generic broadcasting adjoint returns early when the broadcast's element type is Bool:

@adjoint broadcasted(::AbstractArrayStyle, f::F, args...) where {F} = _broadcast_generic(__context__, f, args...)
@inline function _broadcast_generic(__context__, f::F, args...) where {F}
  T = Broadcast.combine_eltypes(f, args)
  # Avoid generic broadcasting in two easy cases:
  if T == Bool
    return (f.(args...), _ -> nothing)

but not on the GPU, which always goes through broadcast_forward:
# Ordinary broadcasting calls broadcast_forward anyway when certain its' safe,
# so perhaps this can be deleted? Possible edge case here:
# https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
@adjoint broadcasted(::AbstractGPUArrayStyle, f, args...) =
  broadcast_forward(f, args...)

Because this unexpected behavior is hard to spot, it may be worth treating it as a bug that warrants fixing.
