Image preprocessing code error in a target detection neural network based on the U-Net architecture
Recently I found code for a U-Net-based target detection network, but the image preprocessing part always fails at the line that creates the dataset, with an error saying the argument "image_index" is missing. I tried many methods to solve it, but all failed. This is the code:
```python
import torch
from torch.utils.data import Dataset
import glob
from PIL import Image
from torchvision import transforms
from skimage.segmentation import mark_boundaries
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms.transforms import Grayscale, RandomHorizontalFlip, Resize, ToTensor
import numpy as np
import matplotlib.pyplot as plt
import os


class InfraredDataset(Dataset):
    def __init__(self, dataset_dir, image_index):
        super(InfraredDataset, self).__init__()
        self.dataset_dir = dataset_dir
        self.image_index = image_index
        self.transformer = transforms.Compose([
            Resize((256, 256)),
            Grayscale(),
            ToTensor(),
            RandomHorizontalFlip(0.5),
        ])

    def __getitem__(self, index):
        image_index = self.image_index[index].strip('\n')
        image_path = os.path.join(self.dataset_dir, 'images', '%s.png' % image_index)
        label_path = os.path.join(self.dataset_dir, 'masks', '%s_pixels0.png' % image_index)
        image = Image.open(image_path)
        label = Image.open(label_path)
        torch.manual_seed(1024)
        tensor_image = self.transformer(image)
        torch.manual_seed(1024)
        label = self.transformer(label)
        label[label > 0] = 1
        return tensor_image, label

    def __len__(self):
        return len(self.image_index)


if __name__ == "__main__":
    f = open('../sirst/idx_427/trainval.txt').readlines()
    ds = InfraredDataset(f)
    # Dataset test
    for i, (image, label) in enumerate(ds):
        image, label = to_pil_image(image), to_pil_image(label)
        image, label = np.array(image), np.array(label)
        print(image.shape, label.shape)
        vis = mark_boundaries(image, label, color=(1, 1, 0))
        image, label = np.stack([image] * 3, -1), np.stack([label] * 3, -1)
        plt.imsave('image_%d.png' % i, vis)
```
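`__getitem__` does not scan any folder on its own: it needs `dataset_dir` as the root and expects each entry of `image_index` to be a file stem (one line of the index file, still carrying its trailing newline from `readlines()`). A stdlib-only sketch of that path mapping follows; the helper `build_paths` and the stem `Misc_1` are hypothetical, chosen only for illustration.

```python
import os


def build_paths(dataset_dir, raw_line):
    """Mirror the path logic of InfraredDataset.__getitem__ (hypothetical helper)."""
    # Lines returned by readlines() keep their '\n', so it is stripped first
    image_index = raw_line.strip('\n')
    image_path = os.path.join(dataset_dir, 'images', '%s.png' % image_index)
    label_path = os.path.join(dataset_dir, 'masks', '%s_pixels0.png' % image_index)
    return image_path, label_path


image_path, label_path = build_paths('sirst', 'Misc_1\n')
print(image_path)  # sirst/images/Misc_1.png on POSIX
print(label_path)  # sirst/masks/Misc_1_pixels0.png on POSIX
```

Both arguments feed into these paths, which is why the constructor cannot work with the index lines alone.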
This is the error:
```
Traceback (most recent call last):
  File "H:/ProgramData/Infrared-detect-by-segmentation-master/Infrared-detect-by-segmentation-master/utils/dataloader.py", line 55, in <module>
    ds = InfraredDataset(f)
TypeError: __init__() missing 1 required positional argument: 'image_index'
```
I tried a lot of methods but couldn't find a solution. I hope someone more experienced can help. Thank you!
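The error comes from ordinary Python argument binding rather than from PyTorch: `__init__` requires `dataset_dir` and `image_index` after `self`, and `InfraredDataset(f)` supplies a single value, which binds to `dataset_dir`, leaving `image_index` unfilled. A minimal torch-free reproduction, where the class `Demo` and the helper `try_construct` are hypothetical stand-ins:

```python
class Demo:
    # Same parameter shape as InfraredDataset.__init__
    def __init__(self, dataset_dir, image_index):
        self.dataset_dir = dataset_dir
        self.image_index = image_index


def try_construct(*args):
    """Return the TypeError message raised by Demo(*args), or None on success."""
    try:
        Demo(*args)
    except TypeError as exc:
        return str(exc)
    return None


# One argument: the list binds to dataset_dir and image_index is missing
print(try_construct(["Misc_1\n"]))  # ends with: missing 1 required positional argument: 'image_index'
# Two arguments: construction succeeds
print(try_construct("sirst", ["Misc_1\n"]))  # None
```

The exact prefix of the message (`__init__()` versus `Demo.__init__()`) depends on the Python version, but the missing-argument part is the same.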
Solution 1:[1]
As you can see, your code defines two required arguments, "dataset_dir" and "image_index". But when you test the dataset, you create "InfraredDataset" as "ds" with only one argument, while the class definition has not changed and still requires both. So if you want to test the module, you must pass both arguments, like this:
```python
if __name__ == "__main__":
    dataset_dir = 'H:/ProgramData/Infrared-detect-by-segmentation-master/Infrared-detect-by-segmentation-master/sirst'
    image_index = open('H:/ProgramData/Infrared-detect-by-segmentation-master/Infrared-detect-by-segmentation-master/sirst/idx_320/val.txt').readlines()
    ds = InfraredDataset(dataset_dir, image_index)
```
Then the problem is solved.
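A slightly tidier variant of the same fix, offered as a sketch: read the index file inside a `with` block so the handle is closed, strip each line, and skip blank lines before passing the list as `image_index`. The helper name `load_image_index` and the stems `Misc_1`/`Misc_2` are hypothetical, and a temporary file stands in for the real index file:

```python
import os
import tempfile


def load_image_index(index_file):
    """Return one file stem per non-blank line (hypothetical helper)."""
    with open(index_file, encoding="utf-8") as fh:
        return [line.strip() for line in fh if line.strip()]


# Demo with a throwaway index file; the trailing empty line is skipped
with tempfile.TemporaryDirectory() as tmp:
    index_file = os.path.join(tmp, "trainval.txt")
    with open(index_file, "w", encoding="utf-8") as fh:
        fh.write("Misc_1\nMisc_2\n\n")
    stems = load_image_index(index_file)

print(stems)  # ['Misc_1', 'Misc_2']
```

With the real data, the call would become `InfraredDataset(dataset_dir, load_image_index(index_file))`, where `dataset_dir` is the `sirst` folder containing `images/` and `masks/`. Because the stems are already stripped, the `.strip('\n')` in `__getitem__` becomes a no-op. Also note that a relative path such as `'../sirst/idx_427/trainval.txt'` is resolved against the current working directory, not the folder containing `dataloader.py`, so absolute paths like the ones above avoid surprises when the script is launched from another folder.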
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Osvart |
