Generating pap smear test cervical cancer images with WGAN-GP

I am a late starter in deep learning. I want to generate pap smear test (cervical cancer) images to augment a small dataset of 74 images. My model usually starts training normally but freezes after a few epochs. I believe the problem is that I have not found the right hyperparameters. Can someone help me with the tuning?

https://drive.google.com/drive/folders/1k0e_SwTWzjGuR33mNdeEiEdfmTYTgWNh?usp=sharing

from __future__ import print_function
# import argparse
# import os
# import random
import torch
import torch.nn as nn
import torch.nn.parallel
# import torch.nn.functional as F
# import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
# from torch.autograd import variable
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
# from torchvision.utils import save_image
from torchvision.utils import make_grid
# import matplotlib.animation as animation
# from IPython.display import HTML


def display_images(image_tensor, num_images=37, size=(3, 64, 64)):
    '''
    Function for visualizing images: Given a tensor of images,
    number of images, and size per image, plots and prints the
    images in a uniform grid.
    '''
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=8)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

dataroot = "train_data"
workers = 2
batch_size = 37
image_size = 64
nc = 3
nz = 100
ngf = 64
ndf = 64
ngpu = 0
n_samples = 74

dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5,
                                                                      0.5,
                                                                      0.5)),
                           ]))
                           

# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0)
                      else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))

# Display training set Batch image
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:37],
           padding=2, normalize=True).cpu(), (1, 2, 0)))
plt.show()


# Initialize conv and batch-norm weights from a normal distribution with mean 0 and std 0.02
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
        torch.nn.init.constant_(m.bias, val=0)

# Generator random noise
def get_noise(n_samples, nz, device='cpu'):

    return torch.randn(n_samples, nz, 1, 1, device=device)

# Create the Generator

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64

             )

    def forward(self, input):
        output = self.main(input)
        return output

# Create the Generator Instance

netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))
netG.apply(weights_init)
# Print the model
print(netG)
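
# Optional quick shape check: a single latent vector of shape (1, nz, 1, 1)
# should come out of the generator as a (1, nc, 64, 64) image, matching the
# layer comments above.
with torch.no_grad():
    print(netG(get_noise(1, nz, device=device)).shape)  # torch.Size([1, 3, 64, 64])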


# Create the Critic

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(ndf * 2, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(ndf * 4, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(ndf * 8, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),

        )

    def forward(self, input):
        output = self.main(input)
        return output


# Create the Critic Instance
netCritic = Discriminator(ngpu).to(device)


# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netCritic = nn.DataParallel(netCritic, list(range(ngpu)))
netCritic.apply(weights_init)

# Print the model
print(netCritic)


# Define the loss and optimizer for generator and critic

# Setup RMSprop optimizers for both Gen and Critic
lr = 5e-5
criterion = nn.BCEWithLogitsLoss()
optimizerCritic = optim.RMSprop(netCritic.parameters(), lr=lr)
optimizerG = optim.RMSprop(netG.parameters(), lr=lr)

# Setup Adam optimizers for both G and Critic
# lr = 1e-4
# optimizerCritic = optim.Adam(netCritic.parameters(), lr=lr)
# optimizerG = optim.Adam(netG.parameters(), lr=lr)
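
# For reference, the original WGAN-GP paper (Gulrajani et al., 2017) tuned Adam
# with lr = 1e-4 and betas = (0.0, 0.9); a commented-out sketch of that setup:
# optimizerCritic = optim.Adam(netCritic.parameters(), lr=1e-4, betas=(0.0, 0.9))
# optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.0, 0.9))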


# Declare Gradient penalty

def gradient_penalty(netCritic, real_image, fake_image, device="cpu"):

    batch_size, channel, height, width = real_image.shape

    # alpha is drawn uniformly at random between 0 and 1 (one value per
    # sample), created on the same device as the images so this also works
    # when training moves to the GPU
    alpha = torch.rand(batch_size, 1, 1, 1, device=device).repeat(1, channel,
                                                                  height,
                                                                  width)
    # interpolated image=randomly weighted average between a real and fake
    # image
    # interpolated image ← alpha *real image  + (1 − alpha) fake image
    interpolated_image = (alpha*real_image) + (1-alpha) * fake_image

    # calculate the critic score on the interpolated image
    interpolated_score = netCritic(interpolated_image)

    # take the gradient of the score wrt to the interpolated image
    gradient = torch.autograd.grad(inputs=interpolated_image,
                                   outputs=interpolated_score,
                                   retain_graph=True,
                                   create_graph=True,
                                   grad_outputs=torch.ones_like
                                   (interpolated_score)
                                   )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1)**2)
    return gradient_penalty
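
# In the WGAN-GP objective this penalty is weighted by LAMBDA_GP in the
# training loop below, i.e. the critic minimises
#   E[D(fake)] - E[D(real)] + LAMBDA_GP * E[(||grad_xhat D(xhat)||_2 - 1)^2]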

# Training the WGAN with Gradient penalty

n_epochs = 2000
cur_step = 0
LAMBDA_GP = 10
display_step = 50
CRITIC_ITERATIONS = 5
nz = 100

for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real_image, _ in tqdm(dataloader):
        cur_batch_size = real_image.shape[0]
        real_image = real_image.to(device)
        for _ in range(CRITIC_ITERATIONS):
            fake_noise = get_noise(cur_batch_size, nz, device=device)
            fake = netG(fake_noise)
            critic_fake_pred = netCritic(fake).reshape(-1)
            critic_real_pred = netCritic(real_image).reshape(-1)

            # Calculate gradient penalty on real and fake images
            # generated by generator
            gp = gradient_penalty(netCritic, real_image, fake, device)
            critic_loss = -(torch.mean(critic_real_pred)
                            - torch.mean(critic_fake_pred)) + LAMBDA_GP * gp
            netCritic.zero_grad()
            # retain_graph=True keeps the graph through `fake` alive so the
            # same fake batch can be reused for the generator update below
            critic_loss.backward(retain_graph=True)
            optimizerCritic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen_fake = netCritic(fake).reshape(-1)
        gen_loss = -torch.mean(gen_fake)
        netG.zero_grad()
        gen_loss.backward()
        # Update optimizer
        optimizerG.step()

        # Visualization code 
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Step{cur_step}: GenLoss: {gen_loss}: CLoss: {critic_loss}")
            display_images(fake)
            display_images(real_image)
            gen_loss = 0
            critic_loss = 0
        cur_step += 1

        # Save the current batch of generated fake images into a folder
        # (the same file names are overwritten every step)
        for i in range(len(fake)):
            vutils.save_image(fake[i], "GenPy_Image/gpi%d.jpg" % i,
                              normalize=True)

Similar sample code was used on the MNIST dataset with epochs = 50 and display_step = 500, but it did not work for me. I tried switching the optimizer to Adam with the recommended learning rate, with epochs = 2000 and display_step = 100 (later 500). Training often stops and freezes after several epochs. I want it to train without interruption so it can generate the desired images.

[Image: sample real training images]

[Image: fake images generated before the freeze]

Above are samples of the real images and of the fake images generated before training freezes.

[Image: training progress output]

The training output above was generated with batch_size = 74, epochs = 2000, and display_step = 50; it freezes somewhere above step 250.

I am presently training on CPU.
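
One thing I have not ruled out is the DataLoader workers: with workers = 2, PyTorch spawns worker processes, and on platforms that use the spawn start method (Windows, and macOS by default) this requires the script to be wrapped in a main-module guard; without it, or with limited CPU resources, data loading can stall. Setting workers = 0 is a quick way to test whether the freeze is worker-related. A minimal sketch of the guard, assuming the code is run as a standalone script:

def main():
    # dataset/dataloader construction, model setup and the training loop
    # above would all go inside this function
    ...


if __name__ == "__main__":
    main()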



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
