Numba messes up dtype when broadcasting
I want to save storage by using small dtypes. However, when I add a number to an array or multiply an array by a number, Numba changes the dtype to int64:
Pure Numpy
In:
import numpy as np
def f():
    a = np.ones(10, dtype=np.uint8)
    return a + 1
f()
Out:
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)
Now with numba:
In:
from numba import njit
@njit
def f():
    a = np.ones(10, dtype=np.uint8)
    return a + 1
f()
Out:
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int64)
One solution is to replace a+1 with a+np.ones(a.shape, dtype=a.dtype) but I cannot imagine something uglier.
Thanks a lot for help!
Solution 1:[1]
You can use np.ones_like:
@njit
def f():
    a = np.ones(10, dtype=np.uint8)
    return a + np.ones_like(a)
f()
Output:
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)
...or np.full_like:
@njit
def f():
    a = np.ones(10, dtype=np.uint8)
    return a + np.full_like(a, 100)
f()
Output:
array([101, 101, 101, 101, 101, 101, 101, 101, 101, 101], dtype=uint8)
Solution 2:[2]
As you mentioned in the comments, this is probably because Numba types the integer literal 1 as int64, so the smaller uint8 dtype gets promoted to the larger int64 when the two are added.
Why not just convert it?
@njit
def f():
    a = np.ones(10, dtype=np.uint8)
    return (a + 1).astype('uint8')
f()
Output:
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)
That's less ugly than a+np.ones(a.shape, dtype=a.dtype). ;)
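One thing worth knowing about the cast: the int64 intermediate is eight times larger than the uint8 input, and casting back down keeps only the low 8 bits, so values wrap around just as native uint8 arithmetic would. Here is a minimal sketch in plain NumPy. The explicit astype(np.int64) only mimics the promotion described above; it is an assumption for illustration, not a view into Numba's internals.

```python
import numpy as np

a = np.array([1, 100, 255], dtype=np.uint8)

# Mimic the promoted result: the sum is held in int64 ...
wide = a.astype(np.int64) + 1      # values 2, 101, 256
# ... and the cast back to uint8 keeps only the low 8 bits.
narrow = wide.astype(np.uint8)     # values 2, 101, 0

print(wide.nbytes, narrow.nbytes)  # 24 bytes vs. 3 bytes
print(narrow)
```

So the cast gives the right dtype in the end, but if memory is the concern, the temporary int64 array still gets allocated along the way.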
Solution 3:[3]
You can fix this if you are willing to make your function accept inputs. I rewrote your function using the signature_or_function argument of njit:
import numba
@numba.njit(signature_or_function='uint8[:](uint8)')
def f(x):
    a = np.ones(10, dtype=np.uint8)
    return a + x
f(1)
# array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)
See the Numba documentation on signatures for details. If you define signatures, Numba will compile a specialized function for each unique signature and try to use a compatible pre-compiled signature for any call whose types aren't explicitly covered. The signature here says the function returns an array of unsigned 8-bit integers ('uint8[:]') and takes a single unsigned 8-bit integer value as input.
Note that I had to make the function accept an input because Numba seems to treat integer literals (e.g., the 1 in a + 1) as int64 values by default. If you instead declare the input as uint8 and don't add a more permissive signature, Numba treats the input as uint8 when it compiles and runs the function, so it doesn't need to up-convert.
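If the problem really is the int64 literal, the same idea should work without changing the function's signature: wrap the constant in np.uint8 so that both operands are already uint8. This is a sketch under that assumption, not something confirmed in the thread. The name add_one is hypothetical, and the try/except fallback is only there so the snippet still runs where Numba isn't installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback for environments without Numba: run as plain Python.
    # Plain NumPy also keeps uint8 for this expression.
    def njit(func):
        return func


@njit
def add_one(a):
    # np.uint8(1) is a typed scalar rather than an int64 literal,
    # so both operands of the addition are uint8.
    return a + np.uint8(1)


result = add_one(np.ones(10, dtype=np.uint8))
print(result.dtype)
```

Unlike np.ones_like or np.full_like, this doesn't allocate a second array just to hold the constant.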
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | |
| Solution 2 | |
| Solution 3 | |
