PyTorch: could not broadcast shape

I trained an autoencoder in PyTorch on some satellite images. I'm trying to extract the embeddings of the images using the code below, and when I execute the get_all_embeddings function I receive the following error:

ValueError: could not broadcast input array from shape (3,256,256) into shape (256,256).

The error doesn’t seem to happen when retrieving a single embedding with get_single_embedding. I'm not sure where I'm going wrong here.

This process has worked previously for the same task with a ResNet50 model; the only difference is that I'm working with an autoencoder this time.

The images I'm using are 256x256 PNG images with 3 bands (R, G, B).
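For reference, matplotlib loads an RGB PNG as an H x W x C float array, so each image should come back as (256, 256, 3). A quick sanity check along these lines (just a sketch; image_files is the same list passed to get_all_embeddings below):

import matplotlib.image as mpimg

img = mpimg.imread(image_files[0])  # PNGs load as float32 values in [0, 1]
print(img.shape)                    # expected: (256, 256, 3)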

import numpy as np
import matplotlib.image as mpimg
import torch
import torch.nn as nn
import torchvision.transforms as transforms

class AutoEncoder(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(
            in_features=kwargs["input_shape"], out_features=128
        )
        self.encoder_output_layer = nn.Linear(
            in_features=128, out_features=128
        )
        self.decoder_hidden_layer = nn.Linear(
            in_features=128, out_features=128
        )
        self.decoder_output_layer = nn.Linear(
            in_features=128, out_features=kwargs["input_shape"]
        )

    def forward(self, features):
        activation = self.encoder_hidden_layer(features)
        activation = torch.relu(activation)
        code = self.encoder_output_layer(activation)
        code = torch.relu(code)
        activation = self.decoder_hidden_layer(code)
        activation = torch.relu(activation)
        activation = self.decoder_output_layer(activation)
        reconstructed = torch.relu(activation)
        return reconstructed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoEncoder(input_shape=256).to(device)

transform = transforms.Compose([
    transforms.ToPILImage(mode='RGB'),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

def get_single_embedding(image_path, model, transform):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image = mpimg.imread(image_path)
    image = transform(image)
    image = image.unsqueeze(0)
    image = image.to(device)
    model.eval()

    with torch.no_grad():
        embedding = model(image).to(device)

    return embedding.cpu().numpy()

def get_all_embeddings(image_paths, model, transform):
    image_embeddings = np.zeros(shape=(3, 256, 256))
    
    for idx, file in enumerate(image_paths):
        image_embeddings[idx] = get_single_embedding(file, model, transform)
    
    print(image_embeddings.shape)

    return image_embeddings

image_embeddings = get_all_embeddings(image_files, model, transform)
query_embedding = get_single_embedding(query_image, model, transform)
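To make the mismatch visible, printing the shapes involved might help (again just a sketch, using the same objects defined above):

# Compare the shape returned by get_single_embedding with the shape of one
# slot of the preallocated array that get_all_embeddings writes into.
emb = get_single_embedding(image_files[0], model, transform)
print(emb.shape)                               # shape of one returned embedding
print(np.zeros(shape=(3, 256, 256))[0].shape)  # (256, 256), the slot being assigned to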


Source: Stack Overflow, licensed under CC BY-SA 3.0.