Skip to content

Commit a826dfb

Browse files
authored
fix: force unthunk in rrule (#293)
1 parent dd0e4c0 commit a826dfb

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ComponentArrays"
22
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
33
authors = ["Jonnie Diegelman <[email protected]>"]
4-
version = "0.15.21"
4+
version = "0.15.22"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/compat/chainrulescore.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol,Val})
2-
return getproperty(x, s), Δ -> getproperty_adjoint(Δ, x, s)
2+
return getproperty(x, s), Δ -> getproperty_adjoint(ChainRulesCore.unthunk(Δ), x, s)
33
end
44

55
function getproperty_adjoint(Δ, x, s)
@@ -28,9 +28,9 @@ function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig{>:ChainRulesCore.Ha
2828
return y_, pb_f
2929
end
3030

31-
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ -> (ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x)))
31+
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ -> (ChainRulesCore.NoTangent(), ComponentArray(ChainRulesCore.unthunk(Δ), getaxes(x)))
3232

33-
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ -> (ChainRulesCore.NoTangent(), getdata(Δ), ChainRulesCore.NoTangent())
33+
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ -> (ChainRulesCore.NoTangent(), getdata(ChainRulesCore.unthunk(Δ)), ChainRulesCore.NoTangent())
3434

3535
function ChainRulesCore.ProjectTo(ca::ComponentArray)
3636
return ChainRulesCore.ProjectTo{ComponentArray}(; project=ChainRulesCore.ProjectTo(getdata(ca)), axes=getaxes(ca))
@@ -49,6 +49,8 @@ end
4949
function ChainRulesCore.rrule(::Type{CA}, nt::NamedTuple) where {CA<:ComponentArray}
5050
y = CA(nt)
5151

52+
∇NamedTupleToComponentArray(Δ) = ∇NamedTupleToComponentArray(ChainRulesCore.unthunk(Δ))
53+
5254
function ∇NamedTupleToComponentArray::AbstractArray)
5355
if length(Δ) == length(y)
5456
return ∇NamedTupleToComponentArray(ComponentArray(vec(Δ), getaxes(y)))

0 commit comments

Comments
 (0)