Binary operations between two 0-dim ndarrays don't preserve type

The documentation for ndarray states the following:

Arithmetic and comparison operations on ndarrays are defined as element-wise operations, and generally yield ndarray objects as results.

This makes sense as it would be very surprising if type(x) == type(y) but type(x {op} y) != type(x) (where {op} stands for an arbitrary binary operator).

However, np.ndarray seems to behave that way if ndim == 0 for both operands:

>>> x = np.array(1.)
>>> f'{x = }, {x.ndim = }'
'x = array(1.), x.ndim = 0'
>>> f'{x+x = }, {type(x+x) = }'
"x+x = 2.0, type(x+x) = <class 'numpy.float64'>"

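To show that this is not specific to `+`, here is a minimal sketch comparing a 0-dim array with a 1-dim array across a few operators. The `ops` table and the `result_types` helper are names made up for this illustration, not part of NumPy:

```python
import operator
import numpy as np

x0 = np.array(1.0)    # 0-dim array
x1 = np.array([1.0])  # 1-dim array, for comparison

# A few binary operators to try (illustrative selection only).
ops = {"+": operator.add, "*": operator.mul, "<": operator.lt}

def result_types(x):
    """Map each operator symbol to the type of `x {op} x`."""
    return {sym: type(op(x, x)) for sym, op in ops.items()}

types0 = result_types(x0)  # NumPy scalar types (subclasses of np.generic)
types1 = result_types(x1)  # np.ndarray for every operator

for sym in ops:
    print(sym, types0[sym].__name__, types1[sym].__name__)
```

With 0-dim operands each result is a NumPy scalar, while the 1-dim operands give back an ndarray.
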
This is inconvenient because my code assumes that binary operations preserve the type, which seems like a reasonable assumption. Static type checking doesn't catch the problem either:

from typing import TypeVar
import numpy as np

T = TypeVar('T', bound=np.ndarray)

def add(x: T, y: T) -> T:
    return x + y

a = np.array(1.)
assert isinstance(add(a, a), np.ndarray)

According to mypy, this snippet contains no issues, but the assertion fails at runtime.
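For reference, a minimal sketch of a manual workaround, assuming it is acceptable to re-wrap the result explicitly. The helper name `add_preserving` is hypothetical, and this is not a NumPy configuration option. It has to be applied at every call site, and `np.asarray` returns a base ndarray, so it would not preserve ndarray subclasses:

```python
import numpy as np

def add_preserving(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    # x + y yields a NumPy scalar when both operands are 0-dim.
    # np.asarray turns that scalar back into a 0-dim ndarray and
    # passes through results that are already ndarrays.
    return np.asarray(x + y)

a = np.array(1.0)
result = add_preserving(a, a)

assert isinstance(result, np.ndarray)
assert result.ndim == 0
```
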

Is there a way to configure NumPy so that it preserves the ndarray type even for 0-dim arrays?



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow