Weather prediction model outputs weird, square-shaped images

As the title suggests, I'm trying to build an AI model that predicts various features of the weather. I'm basing it on the openclimatefix implementation of Google's MetNet model, specifically MetNet2. There were a few issues getting the model implemented (typos in the code, and I had to lower certain parameters so it could run locally), but after some work I got it training. However, when I try to visualize some predictions, I get really strange results, such as the following.

[image: weird outputs]

The top images are the ground truth and the bottom images are my model's predictions. I have no idea how to fix this. I've changed various parameter sizes and the learning rate multiple times, but nothing has helped. Below is the important part of the code. I also have separate files for calculating the loss (using MS-SSIM) and for loading the data.
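For context, the `MS_SSIMLoss` in my loss file essentially turns structural similarity into a loss (1 minus the multi-scale SSIM between prediction and target). This is a rough single-scale sketch of that idea, not the actual `loss.py` (the `ssim_loss` name and the uniform averaging window are my simplification):

```python
import torch
import torch.nn.functional as F

def ssim_loss(pred, target, window=11, C1=0.01 ** 2, C2=0.03 ** 2):
    """Single-scale SSIM turned into a loss: identical images give ~0.

    pred, target: (B, C, H, W) tensors, values roughly in [0, 1].
    """
    channels = pred.shape[1]
    # uniform averaging window, applied per channel (depthwise conv)
    kernel = torch.ones(channels, 1, window, window) / (window * window)
    pad = window // 2
    mu_p = F.conv2d(pred, kernel, padding=pad, groups=channels)
    mu_t = F.conv2d(target, kernel, padding=pad, groups=channels)
    var_p = F.conv2d(pred * pred, kernel, padding=pad, groups=channels) - mu_p ** 2
    var_t = F.conv2d(target * target, kernel, padding=pad, groups=channels) - mu_t ** 2
    cov = F.conv2d(pred * target, kernel, padding=pad, groups=channels) - mu_p * mu_t
    # standard SSIM formula, averaged over all windows, channels and batch
    ssim = ((2 * mu_p * mu_t + C1) * (2 * cov + C2)) / (
        (mu_p ** 2 + mu_t ** 2 + C1) * (var_p + var_t + C2)
    )
    return 1 - ssim.mean()
```

The real multi-scale version repeats this at several downsampled resolutions and combines the results, which is what helps avoid the blurry outputs that plain MSE tends to produce.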

    import matplotlib.pyplot as plt
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import xarray as xr
    from numpy import float32
    from torch.utils.data import DataLoader

    from loss import MS_SSIMLoss

    plt.rcParams["figure.figsize"] = (20, 12)
    BATCH_SIZE = 1
    EPOCHS = 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    from metnet import MetNet2

    model = MetNet2(
        forecast_steps=24,
        upsample_method = "interp",
        input_channels=1,
        sat_channels=1,
        input_size=1024,
        num_input_timesteps=12,
        upsampler_channels=64,
        lstm_channels=64,
        encoder_channels=64,
        output_channels=1,
        center_crop_size=16
    ).to(device)  # move the model onto the GPU if one is available
    optimiser = optim.Adam(model.parameters(), lr=0.01)
    criterion = MS_SSIMLoss(channels=24) # produces less blurry images than nn.MSELoss()
    losses = []
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch + 1}")
        running_loss = 0
        i = 0
        count = 0
        for batch_coordinates, batch_features, batch_targets in ch_dataloader:  # dataloader from my separate data-loading file
            optimiser.zero_grad()
            batch_predictions = model(batch_features.to(device).unsqueeze(dim=2))

            batch_loss = criterion(batch_predictions.squeeze(dim=2), batch_targets.to(device))
            batch_loss.backward()

            optimiser.step()

            running_loss += batch_loss.item() * batch_predictions.shape[0]
            count += batch_predictions.shape[0]
            i += 1

            print(f"Completed batch {i} of epoch {epoch + 1} with loss {batch_loss.item()} -- processed {count} image sequences ({12 * count} images)")
    
        losses.append(running_loss / count)
        print(f"Loss for epoch {epoch + 1}/{EPOCHS}: {losses[-1]}")
    model.eval()  # switch to evaluation mode for visualisation
    for batch_coordinates, batch_features, batch_targets in ch_dataloader:
        print(batch_features.shape)
        with torch.no_grad():
            p = model(batch_features.to(device).unsqueeze(dim=2)).squeeze(dim=2).cpu().numpy()
        # 2 rows x 12 columns: top row ground truth, bottom row predictions
        fig, (ax1, ax2) = plt.subplots(2, 12, figsize=(20, 8))
        print(p.shape)
        for i, img in enumerate(p[0][:12]):
            ax2[i].imshow(img, cmap='viridis')
            ax2[i].get_xaxis().set_visible(False)
            ax2[i].get_yaxis().set_visible(False)
        for i, img in enumerate(batch_targets[0][:12].numpy()):
            ax1[i].imshow(img, cmap='viridis')
            ax1[i].get_xaxis().set_visible(False)
            ax1[i].get_yaxis().set_visible(False)
        fig.tight_layout()
        fig.subplots_adjust(wspace=0, hspace=0)
        print(criterion(torch.from_numpy(p), batch_targets))
        break

How should I continue? Any help would be appreciated. Thanks!



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
