'Filters are not being learnt in sparse coding

Could you please take a look at the code below? I am trying to implement a simple sparse coding algorithm. I try to visualize the filters at the end, but it seems that filters are not learnt at all. in the code, phi and weights should be learnt independently. I tried ISTA algorithm to learn phi. I appreciate it if you could take a look. Thank you.

import torch 
import torch.nn.functional as F

from torchvision import datasets

from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device is:', device)

# dataset definition
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))
mnist_trainset.data = mnist_trainset.data[:10000]
mnist_testset.data = mnist_testset.data[:5000]

from torch.utils.data import DataLoader
train_dl = DataLoader(mnist_trainset, batch_size=32, shuffle=True)
test_dl = DataLoader(mnist_testset, batch_size=1024, shuffle=False)


from numpy import vstack
from sklearn.metrics import accuracy_score
from torch.optim import SGD 
from torch.nn import Module
from torch.nn import Linear
from tqdm import tqdm

class MNIST_ISTA(Module):
  
    # define model elements
    def __init__(self, n_inputs):
        self.lambda_ = 0.5e-5
        super(MNIST_ISTA, self).__init__()
        # input to first hidden layer
        # self.sc = Scattering2D(shape=(28,28), J=2)
        # self.view = Vi
        self.neurons = 400
        self.receptive_field = 10
        self.output = Linear(self.neurons, 28*28)
        self.phi = None
       

    # forward propagate input


    def ista_(self, img_batch):

        self.phi = torch.zeros((img_batch.shape[0], 400), requires_grad=True)
        converged = False
        # optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
        optimizer = torch.optim.SGD([{'params': self.phi, "lr": 0.1e-3},{'params': self.parameters(), "lr": 0.1e-3}])
        while not converged:
            phi_old = self.phi.clone().detach()
            pred = self.output(self.phi)
            loss = ((img_batch-pred)**2).sum() + torch.norm(self.phi,p=1)
            loss.backward()
            optimizer.step()
            self.zero_grad()
            self.phi.data = self.soft_thresholding_(self.phi, self.lambda_ )
            converged = torch.norm(self.phi - phi_old)/torch.norm(phi_old)<1e-1

    def soft_thresholding_(self,x, alpha):
        with torch.no_grad():
            rtn = F.relu(x-alpha)- F.relu(-x-alpha)
        return rtn.data

    def zero_grad(self):
        self.phi.grad.zero_()
        self.output.zero_grad()

    def forward(self, img_batch):
        self.ista_(img_batch)
        pred = self.output(self.phi)
        return pred



ista_model = MNIST_ISTA(400)
optim = torch.optim.SGD([{'params': ista_model.output.weight, "lr": 0.01}])

for epoch in range(100):
    running_loss = 0
    c=0
    for img_batch in tqdm(train_dl, desc='training', total=len(train_dl)):
        img_batch = img_batch[0]
        img_batch = img_batch.reshape(img_batch.shape[0], -1)
        
        pred = ista_model(img_batch)

        loss = ((img_batch - pred) ** 2).sum()
        running_loss += loss.item()
        loss.backward()

        optim.step()
        # zero grad
        ista_model.zero_grad()

weight = ista_model.output.weight.data.numpy()
print(weight.shape)

import matplotlib.pyplot as plt
fig = plt.figure(figsize=(20,20))
for i in range(20):
  for j in range(20):
    ax=fig.add_subplot(20,20,i*20+j+1)
    ax.imshow(weight[:,1].reshape((28,28)))
    ax.axis('off')
# plt.close(fig)


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source