Image preprocessing code error in a target detection neural network based on the U-Net architecture

Recently I found code for a U-Net-based target detection network, but the image preprocessing part always fails with a missing "image_index" argument at the point where the dataset is created. I tried many ways to fix it, but all failed. This is the code:

import torch
from torch.utils.data import Dataset
import glob
from PIL import Image
from torchvision import transforms
from skimage.segmentation import mark_boundaries
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms.transforms import Grayscale, RandomHorizontalFlip, Resize, ToTensor
import numpy as np
import matplotlib.pyplot as plt
import os


class InfraredDataset(Dataset):
    def __init__(self, dataset_dir, image_index):
        super(InfraredDataset, self).__init__()
        self.dataset_dir = dataset_dir
        self.image_index = image_index
        self.transformer = transforms.Compose([
            Resize((256, 256)),
            Grayscale(),
            ToTensor(),
            RandomHorizontalFlip(0.5),
        ])

    def __getitem__(self, index):
        image_index = self.image_index[index].strip('\n')
        image_path = os.path.join(self.dataset_dir, 'images', '%s.png' % image_index)
        label_path = os.path.join(self.dataset_dir, 'masks', '%s_pixels0.png' % image_index)
        image = Image.open(image_path)
        label = Image.open(label_path)
        torch.manual_seed(1024)
        tensor_image = self.transformer(image)
        torch.manual_seed(1024)
        label = self.transformer(label)
        label[label > 0] = 1
        return tensor_image, label

    def __len__(self):
        return len(self.image_index)


if __name__ == "__main__":
    f = open('../sirst/idx_427/trainval.txt').readlines()
    ds = InfraredDataset(f)
    # Dataset test
    for i, (image, label) in enumerate(ds):
        image, label = to_pil_image(image), to_pil_image(label)
        image, label = np.array(image), np.array(label)
        print(image.shape, label.shape)
        vis = mark_boundaries(image, label, color=(1, 1, 0))
        image, label = np.stack([image] * 3, -1), np.stack([label] * 3, -1)
        plt.imsave('image_%d.png' % i, vis)

This is the error:

Traceback (most recent call last):
  File "H:/ProgramData/Infrared-detect-by-segmentation-master/Infrared-detect-by-segmentation-master/utils/dataloader.py", line 55, in <module>
    ds = InfraredDataset(f)
TypeError: __init__() missing 1 required positional argument: 'image_index'

I tried many approaches but could not find a solution. I hope someone can help, thank you!



Solution 1:[1]

As the traceback shows, "__init__" takes two required arguments, "dataset_dir" and "image_index", but your test code passes only one: InfraredDataset(f). The list "f" is bound to "dataset_dir", so "image_index" is left unfilled. Assigning the instance to "ds" does not change the constructor's signature, so to test the dataset you must pass both arguments, like this:

if __name__ == "__main__":
    dataset_dir = 'H:/ProgramData/Infrared-detect-by-segmentation-master/Infrared-detect-by-segmentation-master/sirst'
    image_index = open('H:/ProgramData/Infrared-detect-by-segmentation-master/Infrared-detect-by-segmentation-master/sirst/idx_320/val.txt').readlines()
    ds = InfraredDataset(dataset_dir, image_index)

With both arguments supplied, the TypeError goes away.
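
To see the argument binding in isolation, here is a minimal sketch that needs neither torch nor the SIRST data. The class "InfraredDatasetStub" and the index lines are hypothetical stand-ins; only the constructor signature is copied from the question.

```python
class InfraredDatasetStub:
    """Hypothetical stand-in that mirrors only the constructor signature."""

    def __init__(self, dataset_dir, image_index):
        self.dataset_dir = dataset_dir
        self.image_index = image_index

    def __len__(self):
        return len(self.image_index)


# Made-up index lines, standing in for open(...).readlines() on the index file.
index_lines = ["Misc_1\n", "Misc_2\n"]

# One positional argument: the list is bound to dataset_dir,
# and image_index is left unfilled, so Python raises TypeError.
try:
    InfraredDatasetStub(index_lines)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Both arguments supplied, by keyword so the binding is explicit.
ds = InfraredDatasetStub(dataset_dir="../sirst", image_index=index_lines)

print(error_message)
print(len(ds))
```

Passing the arguments by keyword is optional, but it makes it harder to swap the directory and the index list by accident.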

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Osvart