diff --git a/src/axis.jl b/src/axis.jl index 375e9ff..1119a3f 100644 --- a/src/axis.jl +++ b/src/axis.jl @@ -124,6 +124,12 @@ ViewAxis{Inds,IdxMap,Ax}() where {Inds,IdxMap,Ax} = ViewAxis(Inds, Ax()) ViewAxis(Inds, IdxMap) = ViewAxis(Inds, Axis(IdxMap)) ViewAxis(Inds) = Inds +Base.length(ax::ViewAxis{Inds}) where Inds = length(Inds) +# Fix https://github.com/Deltares/Ribasim/issues/2028 +Base.getindex(::ViewAxis{Inds, IdxMap, <:ComponentArrays.Shaped1DAxis}, idx::Integer) where {Inds,IdxMap} = Inds[idx] +Base.iterate(::ViewAxis{Inds, IdxMap, <:ComponentArrays.Shaped1DAxis}) where {Inds,IdxMap} = iterate(Inds) +Base.iterate(::ViewAxis{Inds, IdxMap, <:ComponentArrays.Shaped1DAxis}, idx) where {Inds,IdxMap} = iterate(Inds, idx) + const View = ViewAxis const NullOrFlatView{Inds,IdxMap} = ViewAxis{Inds,IdxMap,<:NullorFlatAxis} diff --git a/src/componentindex.jl b/src/componentindex.jl index b63317c..a821468 100644 --- a/src/componentindex.jl +++ b/src/componentindex.jl @@ -13,6 +13,8 @@ const NullComponentIndex{Idx} = ComponentIndex{Idx, NullAxis} Base.:(==)(ci1::ComponentIndex, ci2::ComponentIndex) = ci1.idx == ci2.idx && ci1.ax == ci2.ax +Base.length(ci::ComponentIndex) = length(ci.idx) + """ KeepIndex(idx) diff --git a/test/runtests.jl b/test/runtests.jl index 257a769..f281f61 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -180,6 +180,8 @@ end x = ComponentArray(b=1, a=2) @test merge(NamedTuple(), x) == NamedTuple(x) @test kw_fun(; x...) == 2 + + @test length(ViewAxis(2:7, ShapedAxis((2,3)))) == 6 end @testset "Get" begin @@ -385,6 +387,12 @@ end @test ax[(:a, :c)] == ax[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = r2v(2:3))) ax2 = getaxes(ca2)[1] @test ax2[(:a, :c)] == ax2[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3:8...], Axis(a = 1, c = ViewAxis(2:7, ShapedAxis((2,3))))) + + @test length(ComponentArrays.ComponentIndex(1, ComponentArrays.NullAxis())) == 1 + @test length(ComponentArrays.ComponentIndex(3:4, ShapedAxis(size(3:4)))) == 2 + @test length(ComponentArrays.ComponentIndex(5:8, Axis(a = r2v(1:3), b = 4))) == 4 + @test length(ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = r2v(2:3)))) == 3 + @test length(ComponentArrays.ComponentIndex([1, 3:8...], Axis(a = 1, c = ViewAxis(2:7, ShapedAxis((2,3)))))) == 7 end @testset "KeepIndex" begin @@ -843,6 +851,49 @@ end @test all(Xstack4_dcolon[:a, :, :] .== Xstack4_noca_dcolon[1, :, :]) @test all(Xstack4_dcolon[:b, :, :] .== Xstack4_noca_dcolon[2:3, :, :]) end + + # Test fix https://github.com/Deltares/Ribasim/issues/2028 + a = range(0.0, 1.0, length=0) |> collect + b = range(0.0, 1.0; length=2) |> collect + c = range(0.0, 1.0, length=3) |> collect + d = range(0.0, 1.0; length=0) |> collect + u = ComponentVector(a=a, b=b, c=c, d=d) + + function get_state_index( + idx::Int, + ::ComponentVector{A, B, <:Tuple{<:Axis{NT}}}, + component_name::Symbol + ) where {A, B, NT} + for (comp, range) in pairs(NT) + if comp == component_name + return range[idx] + end + end + return nothing + end + + @test_throws BoundsError get_state_index(1, u, :a) + @test_throws BoundsError get_state_index(2, u, :a) + @test get_state_index(1, u, :b) == 1 + @test get_state_index(2, u, :b) == 2 + @test get_state_index(1, u, :c) == 3 + @test get_state_index(2, u, :c) == 4 + @test get_state_index(3, u, :c) == 5 + @test_throws BoundsError get_state_index(1, u, :d) + @test_throws BoundsError get_state_index(2, u, :d) + + # Must be a better way to make sure we can `Base.iterate` the `ViewAxis{UnitRange, Shaped1DAxis}`. + nt = ComponentArrays.indexmap(getaxes(u)[1]) + for (i, idx) in enumerate(nt.a) + end + for (i, idx) in enumerate(nt.b) + @test idx == i + end + for (i, idx) in enumerate(nt.c) + @test idx == i + 2 + end + for (i, idx) in enumerate(nt.d) + end end @testset "axpy! / axpby!" begin