How do I solve the error: 'list' object has no attribute 'data'?

I want to run inference on a specific subset of the ImageNet validation set, listed here:

My validation dataset has 1000 classes with 50 images in each class. Only the images listed in the JSON file above are used for inference; I am not interested in the entire validation set. Below is what I have tried. File 1: dataset.py

import json
import os
import torchvision

class getImageNet(torchvision.datasets.ImageFolder):
    """ImageFolder that keeps only the images listed in a JSON file.

    The JSON file must be an object whose ``"images"`` key holds a list of
    image identifiers. NOTE(review): the matching in ``is_valid_file`` assumes
    each identifier is the image path relative to ``root`` written as
    ``"<class_dir>/<file_name>"`` (e.g.
    ``"n01440764/ILSVRC2012_val_00000293.JPEG"``) — confirm against the JSON.

    Args:
        image_list: Path to the JSON file listing the images to keep.
        *args, **kwargs: Forwarded unchanged to
            ``torchvision.datasets.ImageFolder`` (e.g. ``root``, ``transform``).

    Raises:
        KeyError: if the JSON object has no ``"images"`` key.
        FileNotFoundError: raised by ImageFolder ("Found no valid file for
            the classes ...") when class directories contain no listed image.
    """

    def __init__(self, image_list: str = "mypath/image_list.json", *args, **kwargs):
        # Use a context manager so the JSON file handle is closed promptly.
        with open(image_list, "r") as f:
            entries = json.load(f)["images"]
        # Normalise separators so entries written with "\\" still match.
        # This must be set before super().__init__, which scans the files
        # and calls is_valid_file during construction.
        self.image_list = {entry.replace("\\", "/") for entry in entries}
        super(getImageNet, self).__init__(is_valid_file=self.is_valid_file, *args, **kwargs)

    def is_valid_file(self, x: str) -> bool:
        """Return True if the image at path ``x`` is listed in ``image_list``."""
        # Original check, kept for backward compatibility: the last 38
        # characters are exactly "<9-char wnid>/<28-char ILSVRC file name>".
        if x[-38:] in self.image_list:
            return True
        # Robust check: compare "<class_dir>/<file_name>" built from path
        # components. This is independent of the OS path separator (ImageFolder
        # joins paths with os.path.join, i.e. "\\" on Windows) and of name lengths.
        class_dir = os.path.basename(os.path.dirname(x))
        file_name = os.path.basename(x)
        return f"{class_dir}/{file_name}" in self.image_list

if __name__ == '__main__':
    # Manual smoke test: build the filtered validation set from the default
    # JSON list, converting each selected image to a tensor.
    transform = torchvision.transforms.ToTensor()
    data_path = "mypath/val"
    dataset = getImageNet(transform=transform, root=data_path)

# File 2: evaluation script (imports getImageNet from dataset.py)

    from dataset import getImageNet

    def test(image_loader, model, epsilon=0.5, target_class=6, device=None):
        """Run ``model`` over every ``(data, labels)`` batch from ``image_loader``.

        Only the beginning of the evaluation loop is shown; the rest of the
        loop body was elided (``...``) in the original post.

        Args:
            image_loader: Iterable of ``(data, labels)`` batches, e.g. a
                ``torch.utils.data.DataLoader``.
            model: Module to evaluate; it is switched to eval mode here.
            epsilon: Not used in the visible part of the loop (presumably a
                perturbation budget for the elided code — TODO confirm).
            target_class: Not used in the visible part of the loop.
            device: Device that batches are moved to. Defaults to the device of
                the model's first parameter, or ``"cpu"`` if it has none.
        """
        model.eval()
        if device is None:
            # Follow the model instead of relying on an undefined global.
            first_param = next(model.parameters(), None)
            device = first_param.device if first_param is not None else "cpu"
        success = 0
        total = 0
        count = 0
        # Iterate the loader that was passed in (the original iterated a
        # global ``test_loader`` and ignored this parameter).
        for batch_idx, (data, labels) in enumerate(image_loader):
            data = data.to(device)
            labels = labels.to(device)
            # Input gradients are requested, presumably for the elided
            # attack step — TODO confirm.
            data.requires_grad = True
            img_output = model(data)
            # NOTE(review): this needs ``img_output`` to be a tensor. If the
            # model returns a list, this is where the reported
            # "'list' object has no attribute 'data'" error comes from.
            conf, img_pred = img_output.detach().max(1, keepdim=True)
            # ... remainder of the loop body elided in the original post ...
            ...

if __name__ == '__main__':
    # Per-channel mean/std (the usual ImageNet normalisation values). They are
    # currently unused because normalisation is disabled below.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    batch_size = 1

    # Preprocessing: fixed 224x224 resize, then conversion to a tensor.
    # Normalisation is intentionally left out; to enable it, append
    # transforms.Normalize(mean=mean, std=std) to this list.
    # NOTE(review): pretrained DeiT weights usually expect normalised
    # inputs — confirm the omission is deliberate.
    preprocessing_steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
    data_transform = transforms.Compose(preprocessing_steps)

    # Validation subset restricted to the images listed in the 1k JSON file.
    test_dir = "mypath/val"
    test_data = getImageNet(
        image_list="mypath/image_list_1k.json",
        root=test_dir,
        transform=data_transform,
    )
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )

    # NOTE(review): `device` is not defined in the visible code; presumably it
    # is set elsewhere in this file — confirm.
    model = create_model("deit_small_patch16_224", pretrained=True)
    model.to(device)

    test(test_loader, model)

I am getting the error `AttributeError: 'list' object has no attribute 'data'`. If I set a breakpoint at `img_output = model(data)` to see what is happening, I get the following error:

FileNotFoundError: Found no valid file for the classes n01440764, n01443537, n01484850, n01491361, n01494475, n01496331, n01498041, n01514668, n01514859, n01518878, n01530575, n01531178, n01532829, n01534433...(all other classes), n15075141


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source