How to share weights between modules in PyTorch?

What is the correct way of sharing weights between two layers (modules) in PyTorch?
Based on my findings in the PyTorch discussion forum, there are several ways of doing this.
As an example, based on this discussion, I thought simply assigning the transposed weights would do it. That is:

 self.decoder[0].weight = self.encoder[0].weight.t()

This, however, proved to be wrong and causes an error. I then tried wrapping the above line in nn.Parameter():

self.decoder[0].weight = nn.Parameter(self.encoder[0].weight.t())

This eliminates the error, but then again, there is no sharing happening here. By doing this I just initialized a new tensor with the same values as encoder[0].weight.t().
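The break can be checked directly. The following sketch (layer names and sizes are illustrative, chosen square so the transpose shape matches) shows that after the nn.Parameter wrap, gradients from the decoder no longer reach the encoder's weight:

```python
import torch
import torch.nn as nn

# Illustrative stand-alone layers (square so the transposed shape still fits)
encoder = nn.Linear(3, 3, bias=False)
decoder = nn.Linear(3, 3, bias=False)

# Wrapping the transposed weight in nn.Parameter creates a new leaf tensor,
# cutting the autograd connection back to encoder.weight
decoder.weight = nn.Parameter(encoder.weight.t())

out = decoder(torch.randn(1, 3)).sum()
out.backward()

print(encoder.weight.grad)              # None: the encoder never sees this gradient
print(decoder.weight.grad is not None)  # True: only the new parameter gets a gradient
```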

I then found this link, which provides different ways of sharing weights. However, I'm skeptical whether all the methods given there are actually correct.
For example, one way is demonstrated like this:

# tied autoencoder using off the shelf nn modules
class TiedAutoEncoderOffTheShelf(nn.Module):
    def __init__(self, inp, out, weight):
        super().__init__()
        self.encoder = nn.Linear(inp, out, bias=False)
        self.decoder = nn.Linear(out, inp, bias=False)

        # tie the weights
        self.encoder.weight.data = weight.clone()
        self.decoder.weight.data = self.encoder.weight.data.transpose(0,1)

    def forward(self, input):
        encoded_feats = self.encoder(input)
        reconstructed_output = self.decoder(encoded_feats)
        return encoded_feats, reconstructed_output

Basically, it creates a new weight tensor using nn.Parameter() and assigns its data to each layer/module like this:

weights = nn.Parameter(torch.randn_like(self.encoder[0].weight))
self.encoder[0].weight.data = weights.clone()
self.decoder[0].weight.data = self.encoder[0].weight.data.transpose(0, 1)

This really confuses me: how is this sharing the same variable between these two layers? Is it not just cloning the 'raw' data?
When I used this approach and visualized the weights, I noticed the visualizations were different, which made me even more certain something was not right.
I'm not sure whether the different visualizations were solely due to one being the transpose of the other, or whether, as I already suspected, they were being optimized independently (i.e. the weights are not shared between layers).

Example weight initialization: (encoder and decoder weight visualizations omitted)



Solution 1:[1]

As it turns out, after further investigation (which consisted simply of re-transposing the decoder's weight and visualizing it), the weights were indeed shared.
Below is the visualization for the encoder's and decoder's weights: (visualization images omitted)
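The sharing can also be checked without visualization. Assigning through .data installs a transposed view of the encoder's storage rather than a copy, so both parameters read and write the same memory (a small sketch with illustrative sizes):

```python
import torch
import torch.nn as nn

encoder = nn.Linear(4, 2, bias=False)
decoder = nn.Linear(2, 4, bias=False)

# The .data assignment installs a transposed *view*, not a copy
decoder.weight.data = encoder.weight.data.transpose(0, 1)

print(decoder.weight.data_ptr() == encoder.weight.data_ptr())  # True: same storage

# In-place updates (which optimizers perform) therefore affect both layers:
with torch.no_grad():
    encoder.weight.add_(1.0)
print(torch.equal(decoder.weight, encoder.weight.t()))  # True
```

Note that each parameter still accumulates its own gradient; it is only because optimizer steps are applied in place that both sets of updates land in the same storage.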

Solution 2:[2]

AI questions in general tend to be misunderstood, and this one in particular. I will rephrase your question as:

Can layer A from module M1 and layer B from module M2 share the weights WA = WB, or possibly even WA = WB.transpose?

This is possible via PyTorch hooks, where you would register a forward hook on A that alters WB, and possibly freeze WB in M2's autograd.

So just use hooks.


import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 2)

    def forward(self, x):
        return self.l1(x)

model = M()
model.train()

# The forward hook fires after every forward pass; the shared
# weight would be synchronized here.
def printh(module, inp, outp):
    print("update other model parameter in here...")

h = model.register_forward_hook(printh)
for i in range(3):
    x = torch.randn(1)
    output = model(x)

h.remove()
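To make the skeleton concrete, here is one hedged sketch of what the update step could look like (layer names and sizes are illustrative): a forward pre-hook copies the encoder's transposed weight into a frozen decoder before each forward pass:

```python
import torch
import torch.nn as nn

encoder = nn.Linear(4, 2, bias=False)
decoder = nn.Linear(2, 4, bias=False)
decoder.weight.requires_grad_(False)  # freeze: the encoder holds the trainable copy

def tie_weights(module, inputs):
    # Refresh the decoder's weight from the encoder before its forward pass
    with torch.no_grad():
        module.weight.copy_(encoder.weight.t())

handle = decoder.register_forward_pre_hook(tie_weights)

x = torch.randn(3, 4)
out = decoder(encoder(x))
print(torch.equal(decoder.weight, encoder.weight.t()))  # True
handle.remove()
```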

Solution 3:[3]

Interestingly enough, you were right about your first intuition, @Rika:

This really confuses me, how is this sharing the same variable between these two layers? Is it not just cloning the 'raw' data?

A lot of people actually got this wrong in blog posts or their own repos.

Also self.decoder[0].weight = nn.Parameter(self.encoder[0].weight.t()) will simply create a new weight matrix, as you wrote.

The only viable course of action seems to be to use the linear function called by nn.Linear (torch.nn.functional.linear()):

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

# Real off-the-shelf tied linear module
class TiedLinear(nn.Module):
    def __init__(self, tied_to: nn.Linear, bias: bool = True):
        super().__init__()
        self.tied_to = tied_to
        if bias:
            self.bias = nn.Parameter(torch.empty(tied_to.in_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    # copied from nn.Linear
    def reset_parameters(self):
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.tied_to.weight.t())
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.tied_to.weight.t(), self.bias)

    # To keep module properties intuitive
    @property
    def weight(self) -> torch.Tensor:
        return self.tied_to.weight.t()

# Shared weights, different biases
encoder = nn.Linear(inp, out)  # `inp`/`out` are your layer dimensions
decoder = TiedLinear(encoder)
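A quick gradient check confirms the tie. The following self-contained sketch uses a bias-free variant of the class above and arbitrary sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedLinear(nn.Module):
    """Bias-free version of the tied layer above, for a minimal test."""
    def __init__(self, tied_to: nn.Linear):
        super().__init__()
        self.tied_to = tied_to  # the decoder owns no weight of its own

    def forward(self, x):
        return F.linear(x, self.tied_to.weight.t())

encoder = nn.Linear(4, 2, bias=False)
decoder = TiedLinear(encoder)

loss = decoder(encoder(torch.randn(8, 4))).pow(2).mean()
loss.backward()

# Both the encoder and decoder passes contribute to the single shared weight:
print(encoder.weight.grad is not None)  # True
```

One caveat: because tied_to is registered as a submodule, decoder.parameters() also yields the encoder's weight, so take care not to hand the same parameter to an optimizer twice.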

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution 1: Hossein
Solution 2: Mateen Ulhaq
Solution 3: alihwe