Numba Compatible Memoization
I have just discovered numba, and learnt that optimal performance requires adding @njit to most functions, such that numba rarely exits LLVM mode.
I still have a few expensive/lookup functions that could benefit from memoization, but so far none of my attempts have found a workable solution that compiles without error.
- Using common decorator functions before `@njit` results in numba not being able to do type inference.
- Using decorators after `@njit` fails to compile the decorator.
- Numba doesn't like the use of `global` variables, even when using `numba.typed.Dict`.
- Numba doesn't like using closures to store mutable state.
- Removing `@njit` also causes type errors when called from other `@njit` functions.
What is the correct way to add memoization to functions when working inside numba?
import functools
import time
import fastcache
import numba
import numpy as np
import toolz
from numba import njit
from functools import lru_cache
from fastcache import clru_cache
from toolz import memoize
# @fastcache.clru_cache(None) # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None) # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'functools._lru_cache_wrapper'>
# @toolz.memoize # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'function'>
@njit
# @fastcache.clru_cache(None) # BUG: AttributeError: 'fastcache.clru_cache' object has no attribute '__defaults__'
# @functools.lru_cache(None) # BUG: AttributeError: 'functools._lru_cache_wrapper' object has no attribute '__defaults__'
# @toolz.memoize # BUG: CALL_FUNCTION_EX with **kwargs not supported
def expensive():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks
# @fastcache.clru_cache(None) # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None) # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @toolz.memoize # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'function'>
def expensive_nojit():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks
# BUG: Failed in nopython mode pipeline (step: analyzing bytecode)
# Use of unsupported opcode (STORE_GLOBAL) found
_expensive_cache = None
@njit
def expensive_global():
    global _expensive_cache
    if _expensive_cache is None:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        _expensive_cache = bitmasks
    return _expensive_cache
# BUG: The use of a DictType[unicode_type,array(int64, 1d, A)] type, assigned to variable 'cache' in globals,
# is not supported as globals are considered compile-time constants and there is no known way to compile
# a DictType[unicode_type,array(int64, 1d, A)] type as a constant.
cache = numba.typed.Dict.empty(
    key_type = numba.types.string,
    value_type = numba.uint64[:]
)
@njit
def expensive_cache():
    global cache
    if "expensive" not in cache:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        cache["expensive"] = bitmasks
    return cache["expensive"]
# BUG: Cannot capture the non-constant value associated with variable 'cache' in a function that will escape.
@njit()
def _expensive_wrapped():
    cache = []
    def wrapper(bitmasks):
        if len(cache) is None:
            bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
            cache.append(bitmasks)
        return cache[0]
    return wrapper
expensive_wrapped = _expensive_wrapped()
@njit
def loop(count):
    for n in range(count):
        expensive()
        # expensive_nojit()
        # expensive_cache()
        # expensive_global()
        # expensive_wrapped()
def main():
    time_start = time.perf_counter()
    count = 10000
    loop(count)
    time_taken = time.perf_counter() - time_start
    print(f'{count} loops in {time_taken:.4f}s')
loop(1) # precache numba
main()
# Pure Python: 10000 loops in 0.2895s
# Numba @njit: 10000 loops in 0.0026s
Solution 1:[1]
You already mentioned that your real code is more complex, but looking at your minimal example, I would recommend the following pattern:
@njit
def loop(count):
    expensive_result = expensive()
    for i in range(count):
        do_something(count, expensive_result)
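A fuller version of this pattern might look like the sketch below. The body of `do_something` is a hypothetical stand-in (the question does not say what the loop does with the bitmasks), and the `try`/`except` fallback is an assumption added only so the snippet still runs as plain Python when Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for portability: without Numba, run the same code as plain Python.
    def njit(func):
        return func


@njit
def expensive():
    # Build the 64 single-bit masks; callers are expected to reuse the result.
    bitmasks = np.empty(64, dtype=np.uint64)
    for n in range(64):
        bitmasks[n] = np.uint64(1) << np.uint64(n)
    return bitmasks


@njit
def do_something(i, bitmasks):
    # Hypothetical loop body: is bit (i % 64) of i set?
    return (np.uint64(i) & bitmasks[i % 64]) != np.uint64(0)


@njit
def loop(count):
    bitmasks = expensive()  # computed once, outside the hot loop
    hits = 0
    for i in range(count):
        if do_something(i, bitmasks):  # result passed in explicitly
            hits += 1
    return hits
```

Calling `loop(10_000)` then evaluates `expensive()` a single time per call instead of once per iteration, with no global state involved.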
Instead of using a cache, you could pre-compute the result outside of the loop and pass it to the loop body. Rather than relying on globals, I would recommend passing every argument explicitly (always, but especially when using the numba jit).
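If the expensive function depends on its arguments, the same "pass everything explicitly" idea can be stretched into a hand-rolled memo: create a `numba.typed.Dict` outside the jitted code and hand it in as a parameter. This is only a sketch under assumptions, not something the answer above spells out. The names `make_cache` and `expensive_cached`, the harmonic-sum workload, and the plain-`dict` fallback for environments without Numba are all hypothetical.

```python
try:
    from numba import njit, types
    from numba.typed import Dict

    def make_cache():
        # Typed dict mapping int64 keys to float64 results.
        return Dict.empty(key_type=types.int64, value_type=types.float64)
except ImportError:
    # Assumption for portability: without Numba, fall back to plain Python.
    def njit(func):
        return func

    def make_cache():
        return {}


@njit
def expensive(n):
    # Hypothetical costly computation: the n-th harmonic number.
    total = 0.0
    for k in range(1, n + 1):
        total += 1.0 / k
    return total


@njit
def expensive_cached(n, cache):
    # The cache is an explicit argument, so no global or closure state is needed.
    if n not in cache:
        cache[n] = expensive(n)
    return cache[n]


@njit
def loop(count, cache):
    acc = 0.0
    for i in range(count):
        acc += expensive_cached(i % 8, cache)
    return acc


cache = make_cache()       # created once, outside the jitted code
result = loop(100, cache)  # only 8 distinct keys are ever computed
```

The caller owns the cache, so it can be reused across calls to `loop` or discarded by creating a fresh one.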
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Carlos Horn |
