Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,7 @@ struct copyable_holder_caster<
}

if (parent) {
return type_caster_base<type>::cast(
return type_caster_generic::cast_non_owning(
srcs, return_value_policy::reference_internal, parent);
}

Expand Down
12 changes: 12 additions & 0 deletions include/pybind11/detail/type_caster_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,18 @@ class type_caster_generic {
return cast(srcs, policy, parent, copy_constructor, move_constructor, existing_holder);
}

static handle cast_non_owning(const cast_sources &srcs,
return_value_policy policy,
handle parent,
const void *existing_holder = nullptr) {
// Reference-like policies alias an existing C++ object instead of creating
// a new one, so copy/move constructor callbacks must remain null here.
assert(policy == return_value_policy::reference
|| policy == return_value_policy::reference_internal
|| policy == return_value_policy::automatic_reference);
return cast(srcs, policy, parent, nullptr, nullptr, existing_holder);
}

PYBIND11_NOINLINE static handle cast(const cast_sources &srcs,
return_value_policy policy,
handle parent,
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ set(PYBIND11_TEST_FILES
test_operator_overloading
test_pickling
test_potentially_slicing_weak_ptr
test_pytorch_shared_ptr_cast_regression
test_python_multiple_inheritance
test_pytypes
test_scoped_critical_section
Expand Down
12 changes: 12 additions & 0 deletions tests/test_class_sh_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,15 @@ def test_non_smart_holder_member_type_with_smart_holder_owner_aliases_member():
legacy = obj.legacy
legacy.value = 13
assert obj.legacy.value == 13


def test_non_smart_holder_member_type_with_smart_holder_owner_aliases_member_multiple_reads():
obj = m.ShWithSimpleStructMember()

a = obj.legacy
b = obj.legacy

a.value = 13

assert b.value == 13
assert obj.legacy.value == 13
62 changes: 62 additions & 0 deletions tests/test_pytorch_shared_ptr_cast_regression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include "pybind11_tests.h"

#include <memory>
#include <string>

#if defined(__clang__)
# if __has_warning("-Wdeprecated-copy-with-user-provided-dtor")
# pragma clang diagnostic error "-Wdeprecated-copy-with-user-provided-dtor"
# endif
# if __has_warning("-Wdeprecated-copy-with-dtor")
# pragma clang diagnostic error "-Wdeprecated-copy-with-dtor"
# endif
#endif

namespace test_pytorch_regressions {

// Directly extracted from PyTorch patterns that regressed in CI.
struct TracingState : std::enable_shared_from_this<TracingState> {
TracingState() = default;
~TracingState() = default;
int value = 0;
};

const std::shared_ptr<TracingState> &get_tracing_state() {
static std::shared_ptr<TracingState> state = std::make_shared<TracingState>();
return state;
}

struct InterfaceType {
~InterfaceType() = default;
int value = 0;
};
using InterfaceTypePtr = std::shared_ptr<InterfaceType>;

struct CompilationUnit {
InterfaceTypePtr iface = std::make_shared<InterfaceType>();

InterfaceTypePtr get_interface(const std::string &) const { return iface; }
};

} // namespace test_pytorch_regressions

TEST_SUBMODULE(pybind11_pytorch_regressions, m) {
using namespace test_pytorch_regressions;

py::class_<TracingState, std::shared_ptr<TracingState>>(m, "TracingState")
.def(py::init<>())
.def_readwrite("value", &TracingState::value);

m.def("_get_tracing_state", []() { return get_tracing_state(); });

py::class_<InterfaceType, InterfaceTypePtr>(m, "InterfaceType")
.def(py::init<>())
.def_readwrite("value", &InterfaceType::value);

py::class_<CompilationUnit, std::shared_ptr<CompilationUnit>>(m, "CompilationUnit")
.def(py::init<>())
.def("get_interface",
[](const std::shared_ptr<CompilationUnit> &self, const std::string &name) {
return self->get_interface(name);
});
}
25 changes: 25 additions & 0 deletions tests/test_pytorch_shared_ptr_cast_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

from pybind11_tests import pybind11_pytorch_regressions as m


def test_pytorch_like_get_tracing_state_aliases_singleton_shared_ptr():
a = m._get_tracing_state()
b = m._get_tracing_state()

a.value = 17

assert b.value == 17
assert m._get_tracing_state().value == 17


def test_pytorch_like_compilation_unit_get_interface_aliases_member_shared_ptr():
cu = m.CompilationUnit()

a = cu.get_interface("iface")
b = cu.get_interface("iface")

a.value = 23

assert b.value == 23
assert cu.get_interface("iface").value == 23
Loading