Weather prediction model outputs weird, square-shaped images
As the title suggests, I'm trying to make an AI model predict various weather features. I'm basing my model on the openclimatefix implementation of Google's MetNet model; specifically, I'm using MetNet2. Implementing it took some work (there were typos in the code, and I had to lower certain parameters so it could run locally), but I got it running. However, when I visualize some predictions, I get really weird results, such as the following.
The top images are the ground truth and the bottom images are my model's predictions. I have no idea how to fix this. I've tried changing the sizes of various parameters and the learning rate several times, but nothing has helped. Below is the important part of the code for running this. I also have separate files for calculating the loss (using MS-SSIM) and for loading the data.
```python
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import xarray as xr
from numpy import float32
from torch.utils.data import DataLoader

from loss import MS_SSIMLoss
from metnet import MetNet2

plt.rcParams["figure.figsize"] = (20, 12)
BATCH_SIZE = 1
EPOCHS = 8
# torch.device expects a lowercase type name: "cpu", not "CPU"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MetNet2(
    forecast_steps=24,
    upsample_method="interp",
    input_channels=1,
    sat_channels=1,
    input_size=1024,
    num_input_timesteps=12,
    upsampler_channels=64,
    lstm_channels=64,
    encoder_channels=64,
    output_channels=1,
    center_crop_size=16,
).to(device)  # the inputs are moved to `device`, so the weights must be as well
optimiser = optim.Adam(model.parameters(), lr=.01)
criterion = MS_SSIMLoss(channels=24)  # produces less blurry images than nn.MSELoss()

# ch_dataloader is built in the separate data-loading file (not shown)
losses = []
for epoch in range(EPOCHS):
    model.train()
    print(f"Epoch {epoch + 1}")
    running_loss = 0
    i = 0
    count = 0
    for batch_coordinates, batch_features, batch_targets in ch_dataloader:
        optimiser.zero_grad()
        # add a channel dimension of size 1 at dim 2 before the forward pass
        batch_predictions = model(batch_features.to(device).unsqueeze(dim=2))
        batch_loss = criterion(batch_predictions.squeeze(dim=2), batch_targets.to(device))
        batch_loss.backward()
        optimiser.step()
        running_loss += batch_loss.item() * batch_predictions.shape[0]
        count += batch_predictions.shape[0]
        i += 1
        print(f"Completed batch {i} of epoch {epoch + 1} with loss {batch_loss.item()} -- processed {count} image sequences ({12 * count} images)")
    losses.append(running_loss / count)
    print(f"Loss for epoch {epoch + 1}/{EPOCHS}: {losses[-1]}")

# Plot one batch: ground truth on the top row, predictions on the bottom row
model.eval()
with torch.no_grad():
    for batch_coordinates, batch_features, batch_targets in ch_dataloader:
        print(batch_features.shape)
        predictions = model(batch_features.to(device).unsqueeze(dim=2)).squeeze(dim=2)
        p = predictions.cpu().numpy()  # GPU tensors must be copied to the CPU before .numpy()
        print(p.shape)
        # 2 rows x 12 columns; plt.subplots(1, 12) cannot be unpacked into (ax1, ax2)
        fig, (ax1, ax2) = plt.subplots(2, 12, figsize=(20, 8))
        for i, img in enumerate(p[0][:12]):
            ax2[i].imshow(img, cmap='viridis')
            ax2[i].get_xaxis().set_visible(False)
            ax2[i].get_yaxis().set_visible(False)
        for i, img in enumerate(batch_targets[0][:12].numpy()):
            ax1[i].imshow(img, cmap='viridis')
            ax1[i].get_xaxis().set_visible(False)
            ax1[i].get_yaxis().set_visible(False)
        fig.tight_layout()
        fig.subplots_adjust(wspace=0, hspace=0)
        print(criterion(predictions, batch_targets.to(device)))
        plt.show()
        break
```
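The post doesn't establish where the squares come from. One quick diagnostic, assuming the squares sit on a regular grid aligned to the top-left pixel, is to test whether every k x k tile of a predicted frame holds a single value, which is exactly what nearest-neighbour upsampling of a coarse grid by an integer factor k produces. The helpers `nearest_upsample`, `is_tile_constant` and `largest_block` below are hypothetical NumPy sketches, not part of the `metnet` package.

```python
import numpy as np


def nearest_upsample(coarse: np.ndarray, k: int) -> np.ndarray:
    """Nearest-neighbour upsampling by an integer factor k: each pixel becomes a k x k square."""
    return np.repeat(np.repeat(coarse, k, axis=0), k, axis=1)


def is_tile_constant(img: np.ndarray, k: int, atol: float = 1e-6) -> bool:
    """True if every aligned k x k tile of a 2-D image holds a single value."""
    h, w = img.shape
    if h % k or w % k:
        return False  # tiles would not cover the image evenly
    tiles = img.reshape(h // k, k, w // k, k)  # tiles[a, i, b, j] == img[a*k + i, b*k + j]
    spread = tiles.max(axis=(1, 3)) - tiles.min(axis=(1, 3))
    return bool(np.all(spread <= atol))


def largest_block(img: np.ndarray, candidates=(64, 32, 16, 8, 4, 2)) -> int:
    """Largest candidate tile size for which img is tile-constant, or 1 if none is."""
    for k in candidates:
        if is_tile_constant(img, k):
            return k
    return 1


rng = np.random.default_rng(0)
frame = nearest_upsample(rng.random((8, 8)), 16)  # synthetic 128 x 128 "blocky" frame
print(largest_block(frame))
```

Applied to a real frame, e.g. `largest_block(p[0][0])` inside the plotting loop, a result above 1 would be consistent with a coarse output upsampled by that factor. A result of 1 does not rule out other blocky patterns: bilinear interpolation of a coarse grid, for instance, gives tiles that vary linearly instead of being constant.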
How should I continue? Any help would be appreciated. Thanks!
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow

