Nested classes based on nn.Module
I have several classes based on nn.Module and I want to pass data between them. Specifically, I need to take data from the first class and add it to a variable in the second class. The combined data in the second class should then be added in the third class. How can I do this using a parent/child (inheritance) relationship between the classes? Other solutions are welcome too.
A simple example is shown below:
import torch.nn as nn
class A(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(A, self).__init__()
        self.a = function(.....)
    def forward(self, x):
        x = self.a(x)
        return x
class B(nn.Module, A):
    def __init__(self, in_ch, out_ch):
        super(B, self).__init__()
        self.b = nn.Conv2d(....)
        self.stacked = Function(self.b, A.a)
    def forward(self, x):
        x = self.b(x) + self.stacked(x)
        return x
I am not sure whether this kind of parent/child class relationship is correct when working with nn.Module.
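One concrete problem with the question's code can be checked directly: because A already subclasses nn.Module, listing nn.Module *before* A in `class B(nn.Module, A)` leaves Python no consistent method resolution order, so the class statement itself raises a TypeError. The sketch below reproduces this, assuming a trivial stand-in for A (the `nn.Identity` layer is a hypothetical replacement for the elided `function(.....)`). It also shows that `A.a` is not available on the class, since `a` is only created per instance inside `__init__`.

```python
import torch.nn as nn

class A(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical stand-in for the elided `function(.....)`
        self.a = nn.Identity()

    def forward(self, x):
        return self.a(x)

# nn.Module listed before its own subclass A: Python rejects this base order.
try:
    class B(nn.Module, A):
        pass
    mro_error = None
except TypeError as exc:
    mro_error = exc

print(type(mro_error).__name__)  # TypeError

# Inheriting from A alone is enough, because A already brings in nn.Module.
class B(A):
    pass

print([cls.__name__ for cls in B.__mro__])  # ['B', 'A', 'Module', 'object']

# `a` is an instance attribute set in __init__, not a class attribute.
print(hasattr(A, "a"))    # False
print(hasattr(A(), "a"))  # True
```

In other words, `class B(A)` is the form to use, and inside B the inherited layer is reached as `self.a` rather than `A.a`.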
Solution 1:[1]
If B inherits from A, B already has A's layers (as self.a), and you can call A's forward pass with super().forward(x):
import torch
class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Conv2d(...)
    def forward(self, x):
        return self.a(x)
class B(A):
    def __init__(self):
        super().__init__()  # runs A.__init__, which creates self.a
        self.b = torch.nn.Conv2d(...)
        self.stacked = Function(self.b, self.a)  # self.a is inherited from A
    def forward(self, x):
        # super().forward(x) runs A.forward, i.e. self.a(x)
        return self.b(x) + super().forward(x) + self.stacked(x)
Then do the same for another subclass C that inherits from B.
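Extending the pattern to a third class might look like the sketch below. Assumptions: the convolution settings (`channels`, `kernel_size=3`, `padding=1`) are hypothetical values chosen so every branch produces the same shape and the outputs can be summed, and the elided `Function(...)`/`self.stacked` branch is left out. Each `forward` calls `super().forward(x)`, so C's output is its own branch added to B's combined output, which already contains A's.

```python
import torch
import torch.nn as nn

class A(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # padding=1 with kernel_size=3 keeps the spatial size unchanged
        self.a = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.a(x)


class B(A):
    def __init__(self, channels):
        super().__init__(channels)  # A.__init__ builds self.a
        self.b = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        # B's own branch plus A's output
        return self.b(x) + super().forward(x)


class C(B):
    def __init__(self, channels):
        super().__init__(channels)  # B.__init__ (and A.__init__) build self.a, self.b
        self.c = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        # C's own branch plus B's combined output (which already includes A's)
        return self.c(x) + super().forward(x)


model = C(channels=4)
x = torch.randn(2, 4, 16, 16)
out = model(x)
print(out.shape)                                     # torch.Size([2, 4, 16, 16])
print([name for name, _ in model.named_children()])  # ['a', 'b', 'c']
```

Because every layer is created along the `__init__` chain on the same object, all three (`a`, `b`, `c`) are registered as submodules of `model`, so `model.parameters()` covers the whole stack.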
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 |
