PyTorch: Custom batch sampler exhausts after first epoch
I'm using a DataLoader with a custom batch_sampler to ensure each batch is class-balanced. How do I prevent the iterator from being exhausted after the first epoch?
import torch


class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.x = torch.rand(10, 10)
        self.y = torch.Tensor([0] * 5 + [1] * 5)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


def custom_batch_sampler():
    batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
    return iter(batch_idx)


def train(loader):
    for epoch in range(10):
        for batch, (x, y) in enumerate(loader):
            print('epoch:', epoch, 'batch:', batch)  # stops after first epoch


if __name__ == '__main__':
    my_dataset = CustomDataset()
    my_loader = torch.utils.data.DataLoader(
        dataset=my_dataset,
        batch_sampler=custom_batch_sampler()
    )
    train(my_loader)
Training stops after the first epoch, after which next(iter(loader)) raises StopIteration. The output is:
epoch: 0 batch: 0
epoch: 0 batch: 1
epoch: 0 batch: 2
epoch: 0 batch: 3
epoch: 0 batch: 4
Solution 1:[1]
The batch_sampler argument needs to be a Sampler or some other iterable, not an iterator. At the start of each epoch the DataLoader calls iter() on it to get a new iterator. A list returns a fresh iterator every time, but an iterator returns itself, so once it has been consumed in the first epoch it stays empty and raises StopIteration. This means you don't need to create an iterator manually; just provide your list. It should work if you remove the iter():
def custom_batch_sampler():
    batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
    return batch_idx
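
If the batches should be reshuffled every epoch rather than fixed, the same idea can be extended to a small class whose __iter__ builds a fresh iterator on each call. The following is only a minimal sketch, not part of the original answer: BalancedBatchSampler, its labels argument, and the seed parameter are hypothetical names chosen for illustration, and it assumes one index per class per batch, with the smallest class limiting the number of batches.

```python
import random


class BalancedBatchSampler:
    """Re-iterable batch sampler: every batch holds one index per class.

    Hypothetical helper, not part of PyTorch. Because __iter__ is a generator
    function, each call to iter() on this object starts a fresh pass.
    """

    def __init__(self, labels, seed=None):
        # Group dataset indices by their class label (labels must be hashable).
        self.indices_by_class = {}
        for idx, label in enumerate(labels):
            self.indices_by_class.setdefault(label, []).append(idx)
        # The smallest class limits how many balanced batches fit in one epoch.
        self.num_batches = min(len(v) for v in self.indices_by_class.values())
        self.rng = random.Random(seed)

    def __iter__(self):
        # Reshuffle within each class so batches differ between epochs.
        shuffled = [self.rng.sample(v, len(v))
                    for v in self.indices_by_class.values()]
        for i in range(self.num_batches):
            yield [class_indices[i] for class_indices in shuffled]

    def __len__(self):
        return self.num_batches


labels = [0] * 5 + [1] * 5  # same label layout as in the question
sampler = BalancedBatchSampler(labels, seed=0)

# Unlike iter(batch_idx), the sampler can be iterated once per epoch.
for epoch in range(2):
    print('epoch:', epoch, 'batches:', list(sampler))
```

With the dataset from the question, passing batch_sampler=BalancedBatchSampler(my_dataset.y.tolist()) to the DataLoader should behave the same way, since batch_sampler accepts any iterable of index lists. The tolist() call is there because the grouping relies on plain Python numbers as dictionary keys, which tensor elements are not.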
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | flawr |
