'Filters are not being learnt in sparse coding
Could you please take a look at the code below? I am trying to implement a simple sparse coding algorithm. I try to visualize the filters at the end, but it seems that filters are not learnt at all. in the code, phi and weights should be learnt independently. I tried ISTA algorithm to learn phi. I appreciate it if you could take a look. Thank you.
import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device is:', device)
# dataset definition
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))
mnist_trainset.data = mnist_trainset.data[:10000]
mnist_testset.data = mnist_testset.data[:5000]
from torch.utils.data import DataLoader
train_dl = DataLoader(mnist_trainset, batch_size=32, shuffle=True)
test_dl = DataLoader(mnist_testset, batch_size=1024, shuffle=False)
from numpy import vstack
from sklearn.metrics import accuracy_score
from torch.optim import SGD
from torch.nn import Module
from torch.nn import Linear
from tqdm import tqdm
class MNIST_ISTA(Module):
# define model elements
def __init__(self, n_inputs):
self.lambda_ = 0.5e-5
super(MNIST_ISTA, self).__init__()
# input to first hidden layer
# self.sc = Scattering2D(shape=(28,28), J=2)
# self.view = Vi
self.neurons = 400
self.receptive_field = 10
self.output = Linear(self.neurons, 28*28)
self.phi = None
# forward propagate input
def ista_(self, img_batch):
self.phi = torch.zeros((img_batch.shape[0], 400), requires_grad=True)
converged = False
# optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
optimizer = torch.optim.SGD([{'params': self.phi, "lr": 0.1e-3},{'params': self.parameters(), "lr": 0.1e-3}])
while not converged:
phi_old = self.phi.clone().detach()
pred = self.output(self.phi)
loss = ((img_batch-pred)**2).sum() + torch.norm(self.phi,p=1)
loss.backward()
optimizer.step()
self.zero_grad()
self.phi.data = self.soft_thresholding_(self.phi, self.lambda_ )
converged = torch.norm(self.phi - phi_old)/torch.norm(phi_old)<1e-1
def soft_thresholding_(self,x, alpha):
with torch.no_grad():
rtn = F.relu(x-alpha)- F.relu(-x-alpha)
return rtn.data
def zero_grad(self):
self.phi.grad.zero_()
self.output.zero_grad()
def forward(self, img_batch):
self.ista_(img_batch)
pred = self.output(self.phi)
return pred
ista_model = MNIST_ISTA(400)
optim = torch.optim.SGD([{'params': ista_model.output.weight, "lr": 0.01}])
for epoch in range(100):
running_loss = 0
c=0
for img_batch in tqdm(train_dl, desc='training', total=len(train_dl)):
img_batch = img_batch[0]
img_batch = img_batch.reshape(img_batch.shape[0], -1)
pred = ista_model(img_batch)
loss = ((img_batch - pred) ** 2).sum()
running_loss += loss.item()
loss.backward()
optim.step()
# zero grad
ista_model.zero_grad()
weight = ista_model.output.weight.data.numpy()
print(weight.shape)
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(20,20))
for i in range(20):
for j in range(20):
ax=fig.add_subplot(20,20,i*20+j+1)
ax.imshow(weight[:,1].reshape((28,28)))
ax.axis('off')
# plt.close(fig)
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|
