Skip to content

Commit 3b62426

Browse files
authored
Fix static_pointer_cast build failure with virtual inheritance in holder_caster_foreign_helpers.h (#6014)
* Add regression test for #5989: static_pointer_cast fails with virtual inheritance When a class uses virtual inheritance and its holder type is shared_ptr, passing a shared_ptr of the derived type as a method argument triggers a compilation error because static_pointer_cast cannot downcast through a virtual base (dynamic_pointer_cast is needed instead). Made-with: Cursor * Fix #5989: use dynamic_pointer_cast for virtual inheritance in esft downcast Replace the unconditional static_pointer_cast in set_via_shared_from_this with a SFINAE-dispatched esft_downcast helper that falls back to dynamic_pointer_cast when static_cast through a virtual base is ill-formed. Also add a workaround in the test binding (.def("name") on SftVirtDerived2) for a separate pre-existing issue with inherited method dispatch through virtual bases. Made-with: Cursor
1 parent c0bbd8b commit 3b62426

3 files changed

Lines changed: 69 additions & 1 deletion

File tree

include/pybind11/detail/holder_caster_foreign_helpers.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,31 @@ struct holder_caster_foreign_helpers {
3131
PyObject *o;
3232
};
3333

34+
// Downcast shared_ptr from the enable_shared_from_this base to the target type.
35+
// SFINAE probe: use static_pointer_cast when the static downcast is valid (common case),
36+
// fall back to dynamic_pointer_cast when it isn't (virtual inheritance — issue #5989).
37+
// We can't use dynamic_pointer_cast unconditionally because it requires polymorphic types;
38+
// we can't use is_polymorphic to choose because that's orthogonal to virtual inheritance.
39+
// (The implementation uses the "tag dispatch via overload priority" trick.)
40+
template <typename type, typename esft_base>
41+
static auto esft_downcast(const std::shared_ptr<esft_base> &existing, int /*preferred*/)
42+
-> decltype(static_cast<type *>(std::declval<esft_base *>()), std::shared_ptr<type>()) {
43+
return std::static_pointer_cast<type>(existing);
44+
}
45+
46+
template <typename type, typename esft_base>
47+
static std::shared_ptr<type> esft_downcast(const std::shared_ptr<esft_base> &existing,
48+
... /*fallback*/) {
49+
return std::dynamic_pointer_cast<type>(existing);
50+
}
51+
3452
template <typename type>
3553
static auto set_via_shared_from_this(type *value, std::shared_ptr<type> *holder_out)
3654
-> decltype(value->shared_from_this(), bool()) {
3755
// object derives from enable_shared_from_this;
3856
// try to reuse an existing shared_ptr if one is known
3957
if (auto existing = try_get_shared_from_this(value)) {
40-
*holder_out = std::static_pointer_cast<type>(existing);
58+
*holder_out = esft_downcast<type>(existing, 0);
4159
return true;
4260
}
4361
return false;

tests/test_smart_ptr.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,28 @@ struct SharedFromThisVBase : std::enable_shared_from_this<SharedFromThisVBase> {
247247
};
248248
struct SharedFromThisVirt : virtual SharedFromThisVBase {};
249249

250+
// Issue #5989: static_pointer_cast where dynamic_pointer_cast is needed
251+
// (virtual inheritance with shared_ptr holder)
252+
struct SftVirtBase : std::enable_shared_from_this<SftVirtBase> {
253+
SftVirtBase() = default;
254+
virtual ~SftVirtBase() = default;
255+
static std::shared_ptr<SftVirtBase> create() { return std::make_shared<SftVirtBase>(); }
256+
virtual std::string name() { return "SftVirtBase"; }
257+
};
258+
struct SftVirtDerived : SftVirtBase {
259+
using SftVirtBase::SftVirtBase;
260+
static std::shared_ptr<SftVirtDerived> create() { return std::make_shared<SftVirtDerived>(); }
261+
std::string name() override { return "SftVirtDerived"; }
262+
};
263+
struct SftVirtDerived2 : virtual SftVirtDerived {
264+
using SftVirtDerived::SftVirtDerived;
265+
static std::shared_ptr<SftVirtDerived2> create() {
266+
return std::make_shared<SftVirtDerived2>();
267+
}
268+
std::string name() override { return "SftVirtDerived2"; }
269+
std::string call_name(const std::shared_ptr<SftVirtDerived2> &d2) { return d2->name(); }
270+
};
271+
250272
// test_move_only_holder
251273
struct C {
252274
C() { print_created(this); }
@@ -522,6 +544,21 @@ TEST_SUBMODULE(smart_ptr, m) {
522544
py::class_<SharedFromThisVirt, std::shared_ptr<SharedFromThisVirt>>(m, "SharedFromThisVirt")
523545
.def_static("get", []() { return sft.get(); });
524546

547+
// Issue #5989: static_pointer_cast where dynamic_pointer_cast is needed
548+
py::class_<SftVirtBase, std::shared_ptr<SftVirtBase>>(m, "SftVirtBase")
549+
.def(py::init<>(&SftVirtBase::create))
550+
.def("name", &SftVirtBase::name);
551+
py::class_<SftVirtDerived, SftVirtBase, std::shared_ptr<SftVirtDerived>>(m, "SftVirtDerived")
552+
.def(py::init<>(&SftVirtDerived::create));
553+
py::class_<SftVirtDerived2, SftVirtDerived, std::shared_ptr<SftVirtDerived2>>(
554+
m, "SftVirtDerived2")
555+
.def(py::init<>(&SftVirtDerived2::create))
556+
// TODO: Remove this once inherited methods work through virtual bases.
557+
// Without it, d2.name() segfaults because pybind11 uses an incorrect
558+
// pointer offset when dispatching through the virtual inheritance chain.
559+
.def("name", &SftVirtDerived2::name)
560+
.def("call_name", &SftVirtDerived2::call_name, py::arg("d2"));
561+
525562
// test_move_only_holder
526563
py::class_<C, custom_unique_ptr<C>>(m, "TypeWithMoveOnlyHolder")
527564
.def_static("make", []() { return custom_unique_ptr<C>(new C); })

tests/test_smart_ptr.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,19 @@ def test_shared_ptr_from_this_and_references():
251251
assert y is z
252252

253253

254+
def test_shared_from_this_virt_shared_ptr_arg():
255+
"""Issue #5989: static_pointer_cast fails with virtual inheritance."""
256+
b = m.SftVirtBase()
257+
assert b.name() == "SftVirtBase"
258+
259+
d = m.SftVirtDerived()
260+
assert d.name() == "SftVirtDerived"
261+
262+
d2 = m.SftVirtDerived2()
263+
assert d2.name() == "SftVirtDerived2"
264+
assert d2.call_name(d2) == "SftVirtDerived2"
265+
266+
254267
@pytest.mark.skipif("env.GRAALPY", reason="Cannot reliably trigger GC")
255268
def test_move_only_holder():
256269
a = m.TypeWithMoveOnlyHolder.make()

0 commit comments

Comments
 (0)