Inheritance and std::shared_ptr in Cython

Suppose I have the following simple example of C++ inheritance in file.h:

class Base {};
class Derived : public Base {};

Then, the following code compiles; that is, I can assign std::shared_ptr<Derived> to std::shared_ptr<Base>:

Derived* foo = new Derived();
std::shared_ptr<Derived> shared_foo = std::make_shared<Derived>(*foo);
std::shared_ptr<Base> bar = shared_foo;
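A standalone C++ check of that assignment, as a sketch: the converted std::shared_ptr<Base> shares ownership (one control block) with the original std::shared_ptr<Derived>. It uses std::make_shared<Derived>() instead of copying from a raw pointer, and the function name use_count_after_upcast is hypothetical.

```cpp
#include <cassert>
#include <memory>

// Same hierarchy as in file.h.
class Base {};
class Derived : public Base {};

// Hypothetical helper: performs the Derived -> Base conversion and reports
// how many shared_ptr handles own the object afterwards.
long use_count_after_upcast() {
    std::shared_ptr<Derived> shared_foo = std::make_shared<Derived>();
    std::shared_ptr<Base> bar = shared_foo;  // implicit converting constructor
    // bar points at the Base subobject of the same Derived object.
    assert(bar.get() == static_cast<Base*>(shared_foo.get()));
    return bar.use_count();  // shared_foo and bar share one control block
}
```

Calling it from any main() should yield 2, since both handles own the same object.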

Let's also say I've added the types to a decl.pxd:

cdef extern from "file.h":
    cdef cppclass Base:
        pass
    cdef cppclass Derived(Base):
        pass

Then, what I'm trying to do is mimic the above C++ assignment in Cython in a file.pyx:

cimport decl
from libcpp.memory cimport make_shared, shared_ptr

def do_stuff():
    cdef decl.Derived* foo = new decl.Derived()
    cdef shared_ptr[decl.Derived] shared_foo = make_shared[decl.Derived](foo)
    cdef shared_ptr[decl.Base] bar = shared_foo

Unlike the C++ case, this now fails with the following error (using Cython 3.0a6):

cdef shared_ptr[decl.Base] bar = shared_foo
                                ^
---------------------------------------------------------------

 Cannot assign type 'shared_ptr[Derived]' to 'shared_ptr[Base]'

Should I expect this behavior? Is there any way to mimic what the C++ example does with Cython?

Edit: As noted in the comments on the accepted answer below, the relevant functionality has been added to Cython and is available as of version 3.0a7.



Solution 1:[1]

It should work for Cython>=3.0, since @fuglede made a PR fixing the issue described below (the issue is still present for Cython<3.0).


The issue is that Cython's wrapper of std::shared_ptr is missing the converting assignment operator

template <class U> shared_ptr& operator= (const shared_ptr<U>& x) noexcept;

of the std::shared_ptr class.

If you patch the wrapper like this:

cdef extern from "<memory>" namespace "std" nogil:
    cdef cppclass shared_ptr[T]:
        ...
        shared_ptr[T]& operator=[Y](const shared_ptr[Y]& ptr)
        # a templated constructor such as shared_ptr[Y](shared_ptr[Y]&) isn't accepted

your code will compile.

You might ask why operator= is needed rather than the templated constructor shared_ptr[Y], because

...
cdef shared_ptr[decl.Base] bar = shared_foo

looks like it uses the converting constructor template <class U> shared_ptr (const shared_ptr<U>& x) noexcept;, which is not explicit. However, this is one of Cython's quirks with C++: the above code is translated to

std::shared_ptr<Base> __pyx_v_bar;
...
__pyx_v_bar = __pyx_v_shared_foo;

and not

std::shared_ptr<Base> __pyx_v_bar = __pyx_v_shared_foo;

so Cython checks for the existence of operator= (lucky for us, because Cython does not seem to support templated constructors, but does support templated operators).
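A minimal C++ sketch of that two-step pattern (default construction followed by assignment), assuming the question's Base/Derived hierarchy; the function name is hypothetical. The static_asserts check that the standard library provides both the implicit converting constructor and the converting operator=, so the generated code is valid C++ once Cython knows about operator=.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>

class Base {};
class Derived : public Base {};

// The standard library provides both conversions.
static_assert(std::is_convertible<std::shared_ptr<Derived>, std::shared_ptr<Base>>::value,
              "implicit converting constructor");
static_assert(std::is_assignable<std::shared_ptr<Base>&, const std::shared_ptr<Derived>&>::value,
              "templated operator=");

// Hypothetical helper mirroring the shape of Cython's generated code:
// default-construct first, assign afterwards.
long use_count_after_two_step_assignment() {
    std::shared_ptr<Derived> shared_foo = std::make_shared<Derived>();
    std::shared_ptr<Base> bar;   // like: std::shared_ptr<Base> __pyx_v_bar;
    assert(bar == nullptr);      // starts out empty
    bar = shared_foo;            // like: __pyx_v_bar = __pyx_v_shared_foo;
    assert(bar.get() == static_cast<Base*>(shared_foo.get()));
    return bar.use_count();
}
```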


If you also want to distribute your module to systems without a patched memory.pxd, you have two options:

  1. wrap std::shared_ptr correctly yourself, or
  2. write a small utility function, for example:

%%cython
...
cdef extern from *:
    """
    #include <memory>
    template<typename T1, typename T2>
    void assign_shared_ptr(std::shared_ptr<T1>& lhs, const std::shared_ptr<T2>& rhs){
         lhs = rhs;
    }
    """
    void assign_shared_ptr[T1, T2](shared_ptr[T1]& lhs, shared_ptr[T2]& rhs)

...
cdef shared_ptr[decl.Derived] shared_foo
# cdef shared_ptr[decl.Base] bar = shared_foo
# must be replaced by:
cdef shared_ptr[decl.Base] bar
assign_shared_ptr(bar, shared_foo)
...
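The C++ half of option 2 can be checked on its own. This sketch compiles the same assign_shared_ptr template outside Cython with the question's classes; the check function helper_shares_ownership is hypothetical.

```cpp
#include <cassert>
#include <memory>

class Base {};
class Derived : public Base {};

// Same template as in the verbatim C++ string of option 2.
template<typename T1, typename T2>
void assign_shared_ptr(std::shared_ptr<T1>& lhs, const std::shared_ptr<T2>& rhs) {
    lhs = rhs;  // resolved by std::shared_ptr's templated operator=
}

// Hypothetical check: T1 = Base and T2 = Derived are deduced from the arguments.
bool helper_shares_ownership() {
    std::shared_ptr<Derived> shared_foo = std::make_shared<Derived>();
    std::shared_ptr<Base> bar;
    assign_shared_ptr(bar, shared_foo);
    assert(bar.use_count() == 2);
    return bar.get() == static_cast<Base*>(shared_foo.get());
}
```

Because the helper is a plain function template, Cython only has to match an ordinary function call, which sidesteps the missing operator= declaration.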

Both options have drawbacks, so depending on your scenario you might prefer one or the other.

Solution 2:[2]

I have not tried Cython, but the standard library provides a static cast function for std::shared_ptr, std::static_pointer_cast. I think this will work:

std::shared_ptr<Base> bar = std::static_pointer_cast<Base>(shared_foo);


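cimport decl
from libcpp.memory cimport make_shared, shared_ptr, static_pointer_cast

def do_stuff():
    cdef decl.Derived* foo = new decl.Derived()
    # foo[0] copies *foo, as in the C++ version (and, as there, foo is never freed)
    cdef shared_ptr[decl.Derived] shared_foo = make_shared[decl.Derived](foo[0])
    cdef shared_ptr[decl.Base] bar = static_pointer_cast[decl.Base, decl.Derived](shared_foo)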
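On the C++ side, std::static_pointer_cast can be checked directly. This sketch (hypothetical function name) upcasts with it and then casts back down. For an upcast, plain assignment already works; static_pointer_cast is mainly useful for the downcast, which is only valid if the object really is a Derived.

```cpp
#include <cassert>
#include <memory>

class Base {};
class Derived : public Base {};

// Hypothetical helper: upcast with static_pointer_cast, then cast back down.
long static_pointer_cast_round_trip() {
    std::shared_ptr<Derived> shared_foo = std::make_shared<Derived>();
    std::shared_ptr<Base> bar = std::static_pointer_cast<Base>(shared_foo);
    // The downcast is only valid because bar really points at a Derived.
    std::shared_ptr<Derived> back = std::static_pointer_cast<Derived>(bar);
    assert(back.get() == shared_foo.get());
    return shared_foo.use_count();  // shared_foo, bar and back all share ownership
}
```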

As a side note:

The way you create shared_foo is probably not what you want. Here you first create a dynamically allocated Derived, then a second, shared Derived that is a copy of the original. Nothing ever deletes foo, so the original object leaks.

// this allocates one Derived
Derived* foo = new Derived(); 
// This allocates a new copy, it does not take ownership of foo
std::shared_ptr<Derived> shared_foo = std::make_shared<Derived>(*foo); 

What you probably want is either:

Derived* foo = new Derived();
std::shared_ptr<Derived> shared_foo(foo); // This now takes ownership of foo

Or just:

// This makes a default constructed shared Derived
auto shared_foo = std::make_shared<Derived>(); 
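To make the difference visible, here is a sketch using a hypothetical instrumented class Counted (not from the question) that counts live instances. Copying through make_shared leaves two objects alive at once, while constructing the shared_ptr from the raw pointer leaves only one.

```cpp
#include <cassert>
#include <memory>

// Hypothetical instrumented class: counts live instances.
struct Counted {
    static int live;
    Counted() { ++live; }
    Counted(const Counted&) { ++live; }
    ~Counted() { --live; }
};
int Counted::live = 0;

// Pattern from the question: make_shared copies *raw, leaving two objects.
int live_after_make_shared_copy() {
    Counted* raw = new Counted();
    std::shared_ptr<Counted> shared = std::make_shared<Counted>(*raw);
    int live = Counted::live;   // raw plus the copy owned by shared
    delete raw;                 // not owned by shared; without this it leaks
    return live;
}

// Taking ownership of the raw pointer: only one object exists.
int live_after_taking_ownership() {
    Counted* raw = new Counted();
    std::shared_ptr<Counted> shared(raw);  // shared now owns raw
    assert(shared.get() == raw);
    return Counted::live;
}
```

Both functions clean up before returning, so Counted::live is back to 0 afterwards.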

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 darthbith
Solution 2 jo-art