'how do I solve the error: 'list' object has no attribute 'data'
I want to make inference on specific part of imagenet validation set here :
my validation dataset has 1000 classes and in each class, there are 50 images. Some images are selected with the json file above for inference. I am not interested in the entire validation set. Below is what I have tried. #class1 : dataset.py
import json
import os
import torchvision
class getImageNet(torchvision.datasets.ImageFolder):
def __init__(self, image_list="mypath/image_list.json", *args, **kwargs):
self.image_list = set(json.load(open(image_list, "r"))["images"])
super(getImageNet, self).__init__(is_valid_file=self.is_valid_file, *args, **kwargs)
def is_valid_file(self, x: str) -> bool:
return x[-38:] in self.image_list
if __name__ == '__main__':
data_path = "mypath/val"
transform = torchvision.transforms.ToTensor()
dataset = getImageNet(root=data_path, transform=transform)
#class 2:
from dataset import getImageNet
def test(image_loader, model, epsilon = .5, target_class = 6):
model.eval()
success = 0
total = 0
count =0
for batch_idx, (data, labels) in enumerate(test_loader):
data = data.to(device)
labels = labels.to(device)
data.requires_grad = True
img_output = model(data)
conf, img_pred = img_output.data.max(1, keepdim=True)
...
...
...
if __name__ == '__main__':
# Set train and test set
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
batch_size = 1
data_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
#transforms.Normalize(mean=mean, std=std)
])
test_dir = "mypath/val"
test_data = getImageNet(image_list="mypath/image_list_1k.json", root=test_dir, transform=data_transform)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=4,
pin_memory=True)
model = create_model("deit_small_patch16_224", pretrained=True)
model.to(device)
test(test_loader, model)
I am getting the error: AttributeError: 'list' object has no attribute 'data' and if i set a breakpoint at img_output = model(data) to see what's happening, I get the following error:
FileNotFoundError: Found no valid file for the classes n01440764, n01443537, n01484850, n01491361, n01494475, n01496331, n01498041, n01514668, n01514859, n01518878, n01530575, n01531178, n01532829, n01534433...(all other classes), n15075141
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|
