For loops in a dictionary, PyTorch
Hi guys, I have a question. In the definition of the variable `image_datasets` there is a for loop, `for x in ['train', 'val']`, inside the dict. I have never seen a for loop used inside a dict before.
```python
import os

import torch
from torchvision import datasets, transforms

# Assumption: `mean` and `std` are not defined in the original snippet;
# these are the commonly used ImageNet normalization statistics.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]),
}

data_dir = 'data/hymenoptera_data'
# One ImageFolder per split, each using the matching transform pipeline.
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=0)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
```
Solution 1:[1]
This is called a dictionary comprehension: it iterates over the list `['train', 'val']` and creates one key-value pair for each element.
The code
```python
dataloaders = {
    x: torch.utils.data.DataLoader(
        image_datasets[x],
        batch_size=4,
        shuffle=True,
        num_workers=0
    )
    for x in ['train', 'val']
}
```
is equivalent to
```python
dataloaders = {
    'train': torch.utils.data.DataLoader(
        image_datasets['train'],
        batch_size=4,
        shuffle=True,
        num_workers=0
    ),
    'val': torch.utils.data.DataLoader(
        image_datasets['val'],
        batch_size=4,
        shuffle=True,
        num_workers=0
    )
}
```
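A dict comprehension is also shorthand for an ordinary `for` loop that fills an empty dict. Here is a minimal sketch of that equivalence that runs without PyTorch: `fake_loader` is a hypothetical stand-in for `torch.utils.data.DataLoader`, and the lists in `image_datasets` are hypothetical stand-ins for the `ImageFolder` objects.

```python
def fake_loader(dataset, batch_size):
    # Hypothetical stand-in for torch.utils.data.DataLoader: it only records its arguments.
    return (dataset, batch_size)

# Hypothetical stand-in for image_datasets: plain lists instead of ImageFolder objects.
image_datasets = {'train': [1, 2, 3], 'val': [4, 5]}

# Dict comprehension: one key/value pair per element of the list.
loaders_comp = {x: fake_loader(image_datasets[x], batch_size=4) for x in ['train', 'val']}

# The same result built with an explicit for loop.
loaders_loop = {}
for x in ['train', 'val']:
    loaders_loop[x] = fake_loader(image_datasets[x], batch_size=4)

print(loaders_comp == loaders_loop)  # True
print(loaders_comp['val'])           # ([4, 5], 4)
```

The comprehension is just the more compact form; both versions evaluate the value expression once per key, in the order of the list.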
Solution 2:[2]
It's called a dict comprehension, and there are also list comprehensions: https://book.pythontips.com/en/latest/comprehensions.html
They basically work the way they read:
Do X with x, for each x in List
and the results are then collected into a dict, a list, and so on.
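The "do X with x for x in List" pattern looks the same for lists and dicts, and an optional `if` clause filters which items are kept. A small illustrative sketch with arbitrary values (`splits` is just an example list, not part of the original code):

```python
splits = ['train', 'val', 'test']

# List comprehension: [expression for item in iterable]
upper = [s.upper() for s in splits]

# Dict comprehension: {key: value for item in iterable}
lengths = {s: len(s) for s in splits}

# Optional filter: only items for which the condition is true are kept.
short = {s: len(s) for s in splits if len(s) <= 4}

# Iterating over a dict directly yields its keys.
keys = [k for k in lengths]

# .items() yields (key, value) pairs, so one dict can be turned into another.
doubled = {k: v * 2 for k, v in lengths.items()}

print(upper)    # ['TRAIN', 'VAL', 'TEST']
print(lengths)  # {'train': 5, 'val': 3, 'test': 4}
print(short)    # {'val': 3, 'test': 4}
print(keys)     # ['train', 'val', 'test']
print(doubled)  # {'train': 10, 'val': 6, 'test': 8}
```

Because iterating over a dict gives its keys, the question's comprehensions could also loop `for x in data_transforms`, which yields `'train'` and `'val'` in insertion order.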
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Zoom |
| Solution 2 | haxor789 |
