How does PyTorch's Module collect learnable parameters from the modules stored in its attributes?
When I define a subclass of torch.nn.Module and assign some modules to attributes in its constructor, such as
```python
import torch
from torch import nn

class Vgg16(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_1 = nn.Sequential()
        self.classifier = nn.Sequential()
        ...
        my_weight = self.state_dict()
```
Does the my_weight variable contain a state_dict that includes the state of the nn.Sequential() modules? I believe the state_dict contains all the parameters required to reconstruct the module, but I have no idea how the module registers them when they are created.
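To make the question concrete, here is the kind of check I have in mind; the layers inside the Sequential containers are placeholders I picked just so that state_dict() has something to report:

```python
import torch
from torch import nn

# A toy version of the Vgg16 skeleton above, with placeholder layers
# inside the Sequential containers so that state_dict() is non-empty.
class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_1 = nn.Sequential(nn.Conv2d(3, 8, 3))
        self.classifier = nn.Sequential(nn.Linear(8, 2))

print(list(Tiny().state_dict().keys()))
# ['feature_1.0.weight', 'feature_1.0.bias',
#  'classifier.0.weight', 'classifier.0.bias']
```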
The constructor of the Sequential module has no way of knowing that it is being instantiated inside another module, or does it?
I would understand it if this were done through torch.nn.Module.add_module(...), but here it is not. I know that Module keeps a private dict of submodules and overrides the __getattr__() method so that I can access layers (submodules) as attributes, but how does that interact with calling state_dict()?
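My current guess is that the registration happens on the parent's side, in __setattr__, rather than in the child's constructor. Here is a deliberately simplified sketch of that idea; GuessModule is my own toy class, not PyTorch's actual implementation:

```python
from torch import nn

class GuessModule:
    """Toy illustration of my guess, not PyTorch's real nn.Module."""

    def __init__(self):
        # Bypass our own __setattr__ to create the private registry.
        object.__setattr__(self, '_modules', {})

    def __setattr__(self, name, value):
        if isinstance(value, nn.Module):
            # The parent intercepts the assignment and registers the
            # submodule by name; the child never needs to know.
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails, so submodules
        # stored in _modules are still reachable as attributes.
        modules = object.__getattribute__(self, '_modules')
        if name in modules:
            return modules[name]
        raise AttributeError(name)

    def state_dict(self):
        # Walk the registered submodules and prefix their keys with the
        # attribute name, producing nested names like 'classifier.0.weight'.
        out = {}
        for name, module in self._modules.items():
            for key, value in module.state_dict().items():
                out[f'{name}.{key}'] = value
        return out


m = GuessModule()
m.classifier = nn.Sequential(nn.Linear(8, 2))
print(list(m.state_dict().keys()))
# ['classifier.0.weight', 'classifier.0.bias']
```

Is this roughly what nn.Module does under the hood, with __setattr__ filing submodules into the private dict so that state_dict() can later walk them?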
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
