
Chainrule for CUDA reduction #666

Open
@renatobellotti


Hi,

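I'd like to suggest adding a differentiation rule (`rrule`) for reductions on GPU arrays. With Zygote, `sum(v)` on a `CuArray` differentiates fine, but the equivalent `reduce(+, v)` does not: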

```julia
using CUDA    # provides `cu` and CuArray
using Zygote

function my_loss(v)
    # This works:
    # l = sum(v)
    # This does not work:
    l = reduce(+, v)
    return l
end

v = cu([1., 2.])
Zygote.gradient(my_loss, v)
```

See also: FluxML/Zygote.jl#730 (comment)
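For illustration only, here is a minimal sketch of what such a rule could look like, written against ChainRulesCore's `rrule` interface. It is not taken from ChainRules.jl and is an assumption on my part: it covers `reduce(+, x)` with or without `dims`, but does not handle the `init` keyword. The pullback name `reduce_plus_pullback` is hypothetical. Since every element enters its sum with coefficient 1, the cotangent of `x` is just the incoming cotangent broadcast back to the shape of `x`.

```julia
using ChainRulesCore

# Hypothetical sketch: reverse rule for reduce(+, x), optionally with `dims`.
# The `init` keyword is NOT supported by this sketch.
function ChainRulesCore.rrule(::typeof(reduce), ::typeof(+), x::AbstractArray{<:Number}; dims=:)
    y = reduce(+, x; dims=dims)
    function reduce_plus_pullback(ȳ)
        # Each x[i] contributes with coefficient 1 to the sum it belongs to,
        # so its cotangent is the matching entry of ȳ, broadcast to size(x).
        # `zero(x)` keeps the array type, so a CuArray input yields a CuArray.
        x̄ = zero(x) .+ unthunk(ȳ)
        # No tangents for the `reduce` function itself or for `+`.
        return NoTangent(), NoTangent(), x̄
    end
    return y, reduce_plus_pullback
end
```

Building the gradient as `zero(x) .+ ȳ` rather than something like `ones(length(x))` is deliberate: it stays on the same device as `x`, which is the point of the request. With a rule like this in place, `Zygote.gradient(my_loss, v)` would be expected to return a gradient of ones with the same array type as `v`.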
