Numba messes up dtype when broadcasting

I want to save storage by using small dtypes. However, when I add a number to an array or multiply an array by a number, Numba changes the dtype to int64:

Pure Numpy

In:

import numpy as np

def f():
    a=np.ones(10, dtype=np.uint8)
    return a+1
f()

Out:

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)

Now with numba:

In:

import numpy as np
from numba import njit

@njit
def f():
    a=np.ones(10, dtype=np.uint8)
    return a+1
f()

Out:

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int64)

One solution is to replace a+1 with a+np.ones(a.shape, dtype=a.dtype) but I cannot imagine something uglier.

Thanks a lot for help!



Solution 1:[1]

You can use np.ones_like:

@njit
def f():
    a=np.ones(10, dtype=np.uint8)
    return a + np.ones_like(a)
f()

Output:

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)

...or np.full_like:

@njit
def f():
    a=np.ones(10, dtype=np.uint8)
    return a + np.full_like(a, 100)
f()

Output:

array([101, 101, 101, 101, 101, 101, 101, 101, 101, 101], dtype=uint8)

Solution 2:[2]

As you mentioned in the comments, this is probably because Numba's default integer type is int64, so the smaller uint8 dtype gets promoted to the larger int64.

Why not just convert it?

@njit
def f():
    a=np.ones(10, dtype=np.uint8)
    return (a+1).astype('uint8')
f()

Output:

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)

That's less ugly than a+np.ones(a.shape, dtype=a.dtype). ;)
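The promotion this answer describes can be checked in plain NumPy, without Numba. The following is a minimal sketch, assuming the constant ends up typed as int64; the variable names promoted, wide, and narrow are made up for illustration.

```python
import numpy as np

a = np.ones(10, dtype=np.uint8)

# Promotion rule for the two dtypes involved: uint8 combined with int64 gives int64.
promoted = np.result_type(np.uint8, np.int64)

# Adding an int64 array reproduces the upcast seen in the question.
wide = a + np.ones(10, dtype=np.int64)

# Adding a uint8 scalar keeps the small dtype.
narrow = a + np.uint8(1)

print(promoted, wide.dtype, narrow.dtype)
# int64 int64 uint8
```

Casting the result with astype, as above, works around the promotion after the fact, so the intermediate a+1 is still computed as int64.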

Solution 3:[3]

You can fix this if you are willing to make your function accept inputs. I rewrote your function using the signature_or_function argument of njit:

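import numba
import numpy as np

# Explicit signature: takes a uint8 scalar, returns a 1-D uint8 array
@numba.njit(signature_or_function='uint8[:](uint8)')
def f(x):
    a = np.ones(10, dtype=np.uint8)
    return a+x

f(1)
# array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)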

See the Numba documentation on signatures. If you define signatures, Numba will compile a specialized function for each unique signature and try to use a compatible pre-compiled signature for any call whose argument types are not explicitly covered. The signature here says that the function returns an array of unsigned 8-bit integers ('uint8[:]') and takes a single unsigned 8-bit integer value as input.

Note that in this case I had to make the function accept an input because Numba seems to treat integer literals (e.g., the 1 in a + 1) as int64 values. If you specify that the input is a uint8 and do not add a more permissive signature, the compiled function treats your input as uint8 and has no need to up-convert.
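The same idea can be sketched without adding a function argument: give the constant an explicit uint8 type so that no int64 operand is involved. This is an illustration under that assumption, not part of the answer above. The name f_typed is hypothetical, and the try/except fallback is only there so the snippet still runs where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs without Numba (plain Python, no compilation).
    def njit(func):
        return func

@njit
def f_typed():
    a = np.ones(10, dtype=np.uint8)
    # np.uint8(1) is a typed scalar, unlike the bare literal 1
    return a + np.uint8(1)

out = f_typed()
print(out.dtype)
```

As with any uint8 arithmetic, results above 255 wrap around, so this only suits cases where the values are known to stay in range.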

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1
Solution 2
Solution 3