Skip to content

Commit 6ce9178

Browse files
authored
Merge pull request #275 from JuliaDiff/ox/st_con
Fix frule for static array constructor that converts eltype
2 parents 1e57ef7 + 2ddb972 commit 6ce9178

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

src/extra_rules.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,14 @@ end
179179
Base.view(t::Tangent{T}, inds) where T<:SVector = view(T(ChainRulesCore.backing(t.data)), inds)
180180
Base.getindex(t::Tangent{<:SVector, <:NamedTuple}, ind::Int) = ChainRulesCore.backing(t.data)[ind]
181181

182-
function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L}
183-
SArray{S, T, N, L}(x), SArray{S}(∂x)
182+
function ChainRules.frule(
183+
(_, ∂x)::Tuple{Any, Tangent{TUP}},
184+
::Type{SArray{S, T, N, L}},
185+
x::TUP,
186+
) where {L, TUP<:NTuple{L, Number}, S, T<:Number, N}
187+
y = SArray{S, T, N, L}(x)
188+
∂y = SArray{S, T, N, L}(ChainRulesCore.backing(∂x))
189+
return y, ∂y
184190
end
185191

186192
@ChainRulesCore.non_differentiable StaticArrays.promote_tuple_eltype(T)

test/extra_rules.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using Diffractor
2+
using StaticArrays
3+
using ChainRulesCore
4+
using Test
5+
6+
@testset "StaticArrays constructor" begin
7+
#frule(::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.Tangent{Tuple{Int64, Vararg{Float64, 9}}, Tuple{Int64, Vararg{Float64, 9}}}}, ::Type{StaticArraysCore.SVector{10, Float64}}, x::Tuple{Int64, Vararg{Float64, 9}})
8+
# @ Diffractor ~/.julia/packages/Diffractor/yCsbI/src/extra_rules.jl:183
9+
10+
@testset "homogenious type" begin
11+
x = (10.0, 20.0, 30.0)
12+
= zero_tangent(x)
13+
y, ẏ = frule((NoTangent(), ẋ), StaticArraysCore.SVector{3, Float64}, x)
14+
@test y == @SVector [10.0, 20.0, 30.0]
15+
@test== @SVector [0.0, 0.0, 0.0]
16+
end
17+
18+
@testset "convertable type" begin
19+
x::Tuple{Int, Float64, Float64} = (10, 20.0, 30.0)
20+
= zero_tangent(x)
21+
y, ẏ = frule((NoTangent(), ẋ), StaticArraysCore.SVector{3, Float64}, x)
22+
# all are float
23+
@test y == @SVector [10.0, 20.0, 30.0]
24+
@test== @SVector [0.0, 0.0, 0.0]
25+
end
26+
27+
@testset "convertable type with ZeroTangent()" begin
28+
x = (10, 20.0, 30.0)
29+
= Tangent{typeof(x)}(ZeroTangent(), 1.0, 2.0)
30+
y, ẏ = frule((NoTangent(), ẋ), StaticArraysCore.SVector{3, Float64}, x)
31+
# all are float
32+
@test y == @SVector [10.0, 20.0, 30.0]
33+
@test== @SVector [0.0, 1.0, 2.0]
34+
end
35+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ const bwd = Diffractor.PrimeDerivativeBack
1414
@testset verbose=true "Diffractor.jl" begin # overall testset, ensures all tests run
1515

1616
@testset "$file" for file in (
17+
"extra_rules.jl"
1718
"stage2_fwd.jl",
1819
"tangent.jl",
1920
"forward_diff_no_inf.jl",

0 commit comments

Comments
 (0)