How can I get the associated tensor from a Torch FX Graph Node?
I want to be able to get all the operations that occur within a torch module, along with how they are parameterized. To do this, I first made a torch.fx.Tracer that disables leaf modules so that I can get the graph without call_module nodes:
class MyTracer(torch.fx.Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        return False
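For comparison, the stock torch.fx.symbolic_trace would keep self.conv as a single opaque call_module node; returning False from is_leaf_module inlines it into get_attr and call_function nodes instead. A minimal sketch (using a stripped-down module just for illustration):

```python
import torch
import torch.fx

class TinyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x)

class MyTracer(torch.fx.Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        return False

# Default tracing keeps the conv as one call_module node.
default_ops = [n.op for n in torch.fx.symbolic_trace(TinyModule()).graph.nodes]

# With leaf modules disabled, the conv is inlined into
# get_attr (weight, bias) + call_function (conv2d) nodes.
inlined_ops = [n.op for n in MyTracer().trace(TinyModule()).nodes]

print(default_ops)
print(inlined_ops)
```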
I also have a basic module that I am working with:
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        y1 = self.conv(x)
        y = torch.relu(y1)
        y = y + y1
        y = torch.relu(y)
        return y
I construct an instance of the module like so and trace it:
m = MyModule()
graph = MyTracer().trace(m)
graph.print_tabular()
which gives:
opcode name target args kwargs
------------- ----------- --------------------------------------------------------- ------------------------------------------------------ --------
placeholder x x () {}
get_attr conv_weight conv.weight () {}
get_attr conv_bias conv.bias () {}
call_function conv2d <built-in method conv2d of type object at 0x7f99b6a0a1c0> (x, conv_weight, conv_bias, (1, 1), (0, 0), (1, 1), 1) {}
call_function relu <built-in method relu of type object at 0x7f99b6a0a1c0> (conv2d,) {}
call_function add <built-in function add> (relu, conv2d) {}
call_function relu_1 <built-in method relu of type object at 0x7f99b6a0a1c0> (add,) {}
output output output (relu_1,) {}
How do I actually get the associated parameters conv_weight and conv_bias without accessing them directly in the model (via m.conv.weight or m.conv.bias)?
Solution 1:[1]
After additional searching and outside assistance, I was pointed to the Interpreter pattern: https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern This pattern lets you inspect each node, with its concrete argument values, while executing the graph. So I built this small interpreter, which prints out Conv2d information:
class MyInterpreter(torch.fx.Interpreter):
    def call_function(self, target, args, kwargs):
        if target == torch.conv2d:
            print('CONV2D')
            print('kernel', args[1].shape)
            print('bias', args[2].shape)
        return super().call_function(target, args, kwargs)
gm = torch.fx.GraphModule(m, graph)
MyInterpreter(gm).run(torch.randn((3, 3, 3, 3)))
yields:
CONV2D
kernel torch.Size([3, 3, 3, 3])
bias torch.Size([3])
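If you only need the parameter tensors and not the intermediate activations, you can also resolve the get_attr node targets statically, without running the model at all. Each get_attr target is a dotted path like "conv.weight", which you can walk down the module hierarchy with getattr; this is essentially what fx.Interpreter's fetch_attr method does internally. A minimal sketch (the standalone fetch_attr helper below is re-implemented here just for illustration):

```python
import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        y1 = self.conv(x)
        y = torch.relu(y1)
        y = y + y1
        return torch.relu(y)

class MyTracer(torch.fx.Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        return False

m = MyModule()
graph = MyTracer().trace(m)

def fetch_attr(mod, target):
    # Walk a dotted target like "conv.weight" down the module hierarchy.
    obj = mod
    for atom in target.split('.'):
        obj = getattr(obj, atom)
    return obj

# Resolve every get_attr node to its underlying tensor.
for node in graph.nodes:
    if node.op == 'get_attr':
        tensor = fetch_attr(m, node.target)
        print(node.name, tuple(tensor.shape))
```

This prints the conv_weight and conv_bias nodes with their shapes, (3, 3, 3, 3) and (3,), matching the get_attr rows in the table above.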
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | iHowell |
