Python: change a module's global variable across modules?
Suppose I have created a package named my_mod that contains two files, __init__.py and my_func.py:
__init__.py:
import numpy
import cupy
xp = numpy # xp defaults to numpy
# a decorator that changes `xp` based on the input array's class
def array_dispatch(fcn):
    def wrapper(x, *args, **kwargs):
        global xp
        xp = cupy if isinstance(x, cupy.ndarray) else numpy
        return fcn(x, *args, **kwargs)
    return wrapper
from .my_func import *
and my_func.py:
from my_mod import xp, array_dispatch
@array_dispatch
def print_xp(x):
    print(xp.__name__)
Basically, I'd like print_xp to print either "numpy" or "cupy" depending on the class of the input x: if x is a numpy array, it should print "numpy"; if x is a cupy array, it should print "cupy".
However, it currently always prints "numpy", the default value of xp. Can someone help me understand why, and what the solution is? Thanks!
Solution 1:[1]
To answer your specific question, don't do:
from my_mod import xp, array_dispatch
Instead, use
import my_mod
then refer to my_mod.xp in your function:
@my_mod.array_dispatch
def print_xp(x):
    print(my_mod.xp.__name__)
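Then you'll see the updates to my_mod's global namespace. The reason: from my_mod import xp creates a new name, xp, in my_func's namespace and binds it to whatever object my_mod.xp referred to at import time (numpy). The global xp statement inside wrapper refers to my_mod's namespace, because that is where wrapper is defined, so the assignment rebinds my_mod.xp and never touches the copy in my_func. Looking the name up as my_mod.xp at call time always sees the current binding.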
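The binding behaviour can be checked without numpy or cupy. This minimal sketch builds a hypothetical in-memory module, fake_mod, whose xp holds a string standing in for a module object. The from-import copies the binding once, so a later reassignment of fake_mod.xp is only visible through attribute lookup:

```python
import sys
import types

# Hypothetical throwaway module, registered so that `from fake_mod import ...` works.
fake_mod = types.ModuleType("fake_mod")
fake_mod.xp = "numpy"          # a string standing in for the numpy module
sys.modules["fake_mod"] = fake_mod

from fake_mod import xp        # binds a new local name to the *current* object

fake_mod.xp = "cupy"           # rebinds the name inside fake_mod only

print(xp)                      # numpy  (the local name still holds the old object)
print(fake_mod.xp)             # cupy   (attribute lookup sees the new binding)
```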
That said, you really should try to avoid using a global variable like this at all: the selected module persists after each call and is shared by every caller.
EDIT: Here's an approach I would take, if I understand what you want correctly.
import inspect
import cupy
import numpy
def array_dispatch(fcn):
    # validate the decorated function's signature once, at decoration time
    sig = inspect.signature(fcn)
    param = sig.parameters.get("xp")
    if param is None:
        raise ValueError("function must have an `xp` parameter")
    if param.kind is not inspect.Parameter.KEYWORD_ONLY:
        raise ValueError(f"`xp` parameter must be keyword only, got {param.kind}")
    def wrapper(x, *args, **kwargs):
        # choose the array module per call and pass it in; no global state
        if isinstance(x, cupy.ndarray):
            xp = cupy
        elif isinstance(x, numpy.ndarray):
            xp = numpy
        else:
            raise TypeError(f"expected either a numpy.ndarray or a cupy.ndarray, got {type(x)}")
        return fcn(x, *args, xp=xp, **kwargs)
    return wrapper
Then, an example function that uses this decorator:
from my_mod import array_dispatch
@array_dispatch
def frobnicate(x, *, xp):
    return xp.tanh(x) + 42
import numpy as np
print(frobnicate(np.arange(10)))
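To try the keyword-only injection pattern without a GPU or cupy installed, here is a stdlib-only sketch. The classes CpuArray and GpuArray and the modules cpu_backend and gpu_backend are hypothetical stand-ins for numpy.ndarray, cupy.ndarray, numpy and cupy. It also adds functools.wraps, which the answer does not use, so the decorated function keeps its own name and docstring:

```python
import functools
import inspect
import types


# Hypothetical stand-ins for numpy.ndarray and cupy.ndarray.
class CpuArray(list):
    pass


class GpuArray(list):
    pass


# Hypothetical stand-ins for the numpy and cupy modules (only __name__ is used).
cpu_backend = types.ModuleType("cpu_backend")
gpu_backend = types.ModuleType("gpu_backend")


def array_dispatch(fcn):
    # Validate the signature once, when the decorator is applied.
    param = inspect.signature(fcn).parameters.get("xp")
    if param is None:
        raise ValueError("function must have an `xp` parameter")
    if param.kind is not inspect.Parameter.KEYWORD_ONLY:
        raise ValueError(f"`xp` parameter must be keyword only, got {param.kind}")

    @functools.wraps(fcn)  # copy fcn's __name__, __doc__, etc. onto the wrapper
    def wrapper(x, *args, **kwargs):
        # Select the backend from the first argument's type on every call.
        if isinstance(x, GpuArray):
            xp = gpu_backend
        elif isinstance(x, CpuArray):
            xp = cpu_backend
        else:
            raise TypeError(f"expected CpuArray or GpuArray, got {type(x)}")
        return fcn(x, *args, xp=xp, **kwargs)

    return wrapper


@array_dispatch
def backend_name(x, *, xp):
    return xp.__name__


print(backend_name(CpuArray([1, 2])))  # cpu_backend
print(backend_name(GpuArray([1, 2])))  # gpu_backend
```

Because xp is a local variable of each wrapper call, two calls with different array types cannot interfere with each other, unlike the global-variable version.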
Solution 2:[2]
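Define xp as a list (a mutable container) instead of binding it directly to a module. Every module that imports xp then shares the same list object, so replacing its element in place is visible everywhere: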
import numpy
import cupy
xp = [numpy] # xp defaults to numpy
# a decorator that changes `xp` based on the input array's class
def array_dispatch(fcn):
    def wrapper(x, *args, **kwargs):
        global xp
        xp[0] = cupy if isinstance(x, cupy.ndarray) else numpy
        return fcn(x, *args, **kwargs)
    return wrapper
from .my_func import *
and my_func.py:
from my_mod import xp, array_dispatch
@array_dispatch
def print_xp(x):
    print(xp[0].__name__)
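Why the list works can be seen in a stdlib-only sketch with a hypothetical in-memory module, shared_mod, and strings standing in for numpy and cupy. The from-import binds the importer's name to the same list object, and the helper select_backend (also hypothetical) mutates that object instead of rebinding a name. This also shows that the global xp statement in the wrapper is not strictly needed, since the wrapper never assigns to xp itself:

```python
import sys
import types

# Hypothetical in-memory stand-in for my_mod.
shared_mod = types.ModuleType("shared_mod")
shared_mod.xp = ["numpy"]      # one-element list holding the current backend
sys.modules["shared_mod"] = shared_mod

from shared_mod import xp      # binds this name to the *same* list object


def select_backend(name):
    # Mutate the shared list in place; no name is rebound, so no `global` is needed.
    shared_mod.xp[0] = name


select_backend("cupy")
print(xp[0])                   # cupy
print(xp is shared_mod.xp)     # True
```

The selection is still process-wide state, so it remains set after the call returns.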
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | |
| Solution 2 | Levi |
