How can I write a custom context manager that changes another function's behavior, like torch.no_grad()?
I want my custom context manager to change a generic function's behavior, the way torch.no_grad() does. For torch.no_grad(), sample code looks like this:
import torch
from torchvision.models import resnet18

# gradient calculation is enabled globally
model = resnet18(pretrained=True)

# inside no_grad, gradient tracking (needed for backward) is disabled,
# and the setting is thread-local
with torch.no_grad():
    model.forward(torch.randn((1, 3, 224, 224)))
In this sample, the torch.no_grad() context disables gradient tracking (and therefore tensor.backward) for every operation inside it, and those operations never refer to the context manager itself. I want to implement similar behavior for my custom function foo, like this:
class MyCustomContextManager:
    def __enter__(self):
        print("enter behavior")

    def __exit__(self, exc_type, exc_val, exc_tb):
        print("exit behavior")

def foo():
    print("Without context manager")

foo()
>> "Without context manager"

with MyCustomContextManager() as f:
    foo()
>> "With context manager"
To achieve this, Googling and searching turned up some solutions, but all of them use a global variable as a flag to indicate whether a context manager is active, and that produces interference when used under Flask or any other library that relies on multithreading. How can I achieve both goals: 1. thread-local state, and 2. automatic detection of the context manager?
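One way to meet both requirements is the standard library's contextvars module: a ContextVar is isolated per thread (and per asyncio task), so setting it inside one request handler does not leak into others. This is a sketch, not code from the original question; the names _in_context and MyCustomContextManager below are illustrative:

```python
import contextvars

# Per-thread / per-task flag (hypothetical name); each thread sees
# its own value, so there is no cross-thread interference.
_in_context = contextvars.ContextVar("in_context", default=False)

class MyCustomContextManager:
    def __enter__(self):
        # set() returns a token that remembers the previous value
        self._token = _in_context.set(True)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # restore the previous value, even if an exception was raised
        _in_context.reset(self._token)
        return False  # do not suppress exceptions

def foo():
    # foo() detects the active context without referencing the manager
    if _in_context.get():
        print("With context manager")
    else:
        print("Without context manager")

foo()                            # -> Without context manager
with MyCustomContextManager():
    foo()                        # -> With context manager
foo()                            # -> Without context manager
```

Resetting via the token in __exit__ (rather than assigning False) correctly handles nested `with` blocks, and because each thread carries its own context, this works under Flask's threaded request handling.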
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow