Skip to content

Commit 6e183b8

Browse files
fix: fix get_mtkparameters_reconstructor handling of nonnumerics
1 parent 263c870 commit 6e183b8

File tree

1 file changed

+32
-8
lines changed

1 file changed

+32
-8
lines changed

src/systems/problem_utils.jl

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,8 @@ end
659659
$(TYPEDEF)
660660
661661
A callable struct which applies `p_constructor` to possibly nested arrays. It also
662-
ensures that views (including nested ones) are concretized.
662+
ensures that views (including nested ones) are concretized. This is implemented manually
663+
of using `narrow_buffer_type` to preserve type-stability.
663664
"""
664665
struct PConstructorApplicator{F}
665666
p_constructor::F
@@ -669,10 +670,18 @@ function (pca::PConstructorApplicator)(x::AbstractArray)
669670
pca.p_constructor(x)
670671
end
671672

673+
function (pca::PConstructorApplicator)(x::AbstractArray{Bool})
674+
pca.p_constructor(BitArray(x))
675+
end
676+
672677
function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray)
673678
collect(x)
674679
end
675680

681+
function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{Bool})
682+
BitArray(x)
683+
end
684+
676685
function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{<:AbstractArray})
677686
collect(pca.(x))
678687
end
@@ -695,6 +704,7 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
695704
"""
696705
function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem;
697706
initials = false, unwrap_initials = false, p_constructor = identity)
707+
_p_constructor = p_constructor
698708
p_constructor = PConstructorApplicator(p_constructor)
699709
# if we call `getu` on this (and it were able to handle empty tuples) we get the
700710
# fields of `MTKParameters` except caches.
@@ -748,14 +758,24 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
748758
Base.Fix1(broadcast, p_constructor)
749759
getu(srcsys, syms[3])
750760
end
751-
rest_getters = map(Base.tail(Base.tail(Base.tail(syms)))) do buf
752-
if buf == ()
753-
return Returns(())
754-
else
755-
return Base.Fix1(broadcast, p_constructor) getu(srcsys, buf)
756-
end
761+
const_getter = if syms[4] == ()
762+
Returns(())
763+
else
764+
Base.Fix1(broadcast, p_constructor) getu(srcsys, syms[4])
757765
end
758-
getters = (tunable_getter, initials_getter, discs_getter, rest_getters...)
766+
nonnumeric_getter = if syms[5] == ()
767+
Returns(())
768+
else
769+
ic = get_index_cache(dstsys)
770+
buftypes = Tuple(map(ic.nonnumeric_buffer_sizes) do bufsize
771+
Vector{bufsize.type}
772+
end)
773+
# nonnumerics retain the assigned buffer type without narrowing
774+
Base.Fix1(broadcast, _p_constructor)
775+
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) getu(srcsys, syms[5])
776+
end
777+
getters = (
778+
tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter)
759779
getter = let getters = getters
760780
function _getter(valp, initprob)
761781
oldcache = parameter_values(initprob).caches
@@ -768,6 +788,10 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
768788
return getter
769789
end
770790

791+
function call(f, args...)
792+
f(args...)
793+
end
794+
771795
"""
772796
$(TYPEDSIGNATURES)
773797

0 commit comments

Comments
 (0)