|
1 | 1 | function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol,Val})
|
2 |
| - return getproperty(x, s), Δ -> getproperty_adjoint(Δ, x, s) |
| 2 | + return getproperty(x, s), Δ -> getproperty_adjoint(ChainRulesCore.unthunk(Δ), x, s) |
3 | 3 | end
|
4 | 4 |
|
5 | 5 | function getproperty_adjoint(Δ, x, s)
|
@@ -28,9 +28,9 @@ function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig{>:ChainRulesCore.Ha
|
28 | 28 | return y_, pb_f
|
29 | 29 | end
|
30 | 30 |
|
31 |
| -ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ -> (ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x))) |
| 31 | +ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ -> (ChainRulesCore.NoTangent(), ComponentArray(ChainRulesCore.unthunk(Δ), getaxes(x))) |
32 | 32 |
|
33 |
| -ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ -> (ChainRulesCore.NoTangent(), getdata(Δ), ChainRulesCore.NoTangent()) |
| 33 | +ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ -> (ChainRulesCore.NoTangent(), getdata(ChainRulesCore.unthunk(Δ)), ChainRulesCore.NoTangent()) |
34 | 34 |
|
35 | 35 | function ChainRulesCore.ProjectTo(ca::ComponentArray)
|
36 | 36 | return ChainRulesCore.ProjectTo{ComponentArray}(; project=ChainRulesCore.ProjectTo(getdata(ca)), axes=getaxes(ca))
|
|
49 | 49 | function ChainRulesCore.rrule(::Type{CA}, nt::NamedTuple) where {CA<:ComponentArray}
|
50 | 50 | y = CA(nt)
|
51 | 51 |
|
| 52 | + ∇NamedTupleToComponentArray(Δ) = ∇NamedTupleToComponentArray(ChainRulesCore.unthunk(Δ)) |
| 53 | + |
52 | 54 | function ∇NamedTupleToComponentArray(Δ::AbstractArray)
|
53 | 55 | if length(Δ) == length(y)
|
54 | 56 | return ∇NamedTupleToComponentArray(ComponentArray(vec(Δ), getaxes(y)))
|
|
0 commit comments