'PyTorch, could not broadcast shape
I trained an auto encoder in PyTorch with some satellite images. I'm trying to extract the embedding of the images using the code below and when I execute the get_all_embeddings function, I receive the following error:
ValueError: could not broadcast input array from shape (3,256,256) into shape (256,256).
The error doesn’t seem to happen when retrieving a single embedding with get_single_embedding. I'm not sure where I'm going wrong here.
This process has worked previously doing the same task using a ResNet50 model - the only difference was I was trying to work with an auto encoder this time.
The images I'm using are 256x256 png images with 3 bands (R,G,B).
import numpy as np
import matplotlib.image as mpimg
import torch
import torch.nn as nn
import torchvision.transforms as transforms
class AutoEncoder(nn.Module):
def __init__(self, **kwargs):
super().__init__()
self.encoder_hidden_layer = nn.Linear(
in_features=kwargs["input_shape"], out_features=128
)
self.encoder_output_layer = nn.Linear(
in_features=128, out_features=128
)
self.decoder_hidden_layer = nn.Linear(
in_features=128, out_features=128
)
self.decoder_output_layer = nn.Linear(
in_features=128, out_features=kwargs["input_shape"]
)
def forward(self, features):
activation = self.encoder_hidden_layer(features)
activation = torch.relu(activation)
code = self.encoder_output_layer(activation)
code = torch.relu(code)
activation = self.decoder_hidden_layer(code)
activation = torch.relu(activation)
activation = self.decoder_output_layer(activation)
reconstructed = torch.relu(activation)
return reconstructed
model = AutoEncoder(input_shape=256).to(device)
transform = transforms.Compose([
transforms.ToPILImage(mode='RGB'),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
def get_single_embedding(image_path, model, transform):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image = mpimg.imread(image_path)
image = transform(image)
image = image.unsqueeze(0)
image = image.to(device)
model.eval()
with torch.no_grad():
embedding = model(image).to(device)
return embedding.cpu().numpy()
def get_all_embeddings(image_paths, model, transform):
image_embeddings = np.zeros(shape=(3, 256, 256))
for idx, file in enumerate(image_paths):
image_embeddings[idx] = get_single_embedding(file, model, transform)
print(image_embeddings.shape)
return image_embeddings
image_embeddings = get_all_embeddings(image_files, model, transform)
query_embedding = get_single_embedding(query_image, model, transform)
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|
