Skip to content

rrule for map for tuples is outdated with respect to JuliaLang/julia#42216 #798

Open
@nsajko

Description

@nsajko

function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tuple...) where {F}
length_y = minimum(length, xs)
hobbits = ntuple(length_y) do i
args = getindex.(xs, i)
rrule_via_ad(config, f, args...)
end
y = map(first, hobbits)
num_xs = Val(length(xs))
paddings = map(x -> ntuple(Returns(NoTangent()), (length(x) - length_y)), xs)
all(isempty, paddings) || @error """map(f, xs::Tuple...) does not allow mistmatched lengths!
But its `rrule` does; when JuliaLang/julia #42216 is fixed this warning should be removed."""
function map_pullback(dy_raw)
dy = unthunk(dy_raw)
# We want to call the pullbacks in `rrule_via_ad` in reverse sequence to the forward pass:
backevals = ntuple(length_y) do i
rev_i = length_y - i + 1
last(hobbits[rev_i])(dy[rev_i])
end |> reverse
# This df doesn't infer, could test Base.issingletontype(F), but it's not the only inference problem.
df = ProjectTo(f)(sum(first, backevals))
# Now unzip that. Because `map` like `zip` should when any `x` stops, some `dx`s may need padding.
# Although in fact, `map(+, (1,2), (3,4,5))` is an error... https://github.com/JuliaLang/julia/issues/42216
dxs = ntuple(num_xs) do k
dx_short = map(bv -> bv[k+1], backevals)
ProjectTo(xs[k])((dx_short..., paddings[k]...)) # ProjectTo makes the Tangent for us
end
return (NoTangent(), df, dxs...)
end
map_back(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...)
return y, map_pullback
end

The referenced Julia issue is resolved: JuliaLang/julia#42216

Metadata

Metadata

Assignees

No one assigned

    Labels

    bugSomething isn't workinggood first issueGood for newcomers

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions