'calling njit function in python numba jitclass fails
@njit
def cumutrapz(x:np.array, y:np.array):
return np.append(0, [
np.trapz(y=y[i-2:i], x=x[i-2:i]) for i in range(2, len(x) + 1)]).cumsum()
from numba import float64
@jitclass([
('a', float64[:]),
('b', float64[:]),
('c', float64[:]),
])
class Testaroo(object):
def __init__(self, a, b):
self.a = a
self.b = b
self.c = np.zeros(len(self.a), dtype=np.float64)
def set_c(self):
self.c = cumutrapz(self.a, self.b)
testaroo = Testaroo(
np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))
testaroo.set_c()
The above fails, but the following two very similar examples work:
cumutrapz(np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))
and
from numba import float64
@jitclass([
('a', float64[:]),
('b', float64[:]),
('c', float64[:]),
])
class Testaroo(object):
def __init__(self, a, b):
self.a = a
self.b = b
self.c = np.zeros(len(self.a), dtype=np.float64)
def set_c(self):
self.c = (self.a * self.b).cumsum()
testaroo = Testaroo(
np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))
testaroo.set_c()
This latter example will work for me for now, but I'd like to know if there's a way to get the cumutrapz function working inside of a jitclass.
I'm using numba version '0.53.1'.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|
