Train PyTorch Autoencoder with custom dataset
I am new to PyTorch. I was able to build an autoencoder model and train it using the MNIST dataset.
However, I need to train the model using a custom dataset.
I am getting the error `'ToTensor' object is not iterable` when I try to train with the custom dataset.
Below is the code of my dataset class:
class AutoEncoderDataSet(Dataset):
    """Dataset yielding ``(input, target)`` pairs for autoencoder training.

    Every entry found in ``in_dir`` is treated as an image file. Both
    elements of a returned pair are produced from the same image, each by
    its own call to the transform supplied at construction time.
    """

    def __init__(self, in_dir, transform):
        """Collect the paths of all entries in ``in_dir``.

        Args:
            in_dir: Directory whose entries are the images to serve.
            transform: Callable applied to each opened PIL image.
        """
        self._transforms = transform
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.img_paths = [
            os.path.join(in_dir, name) for name in sorted(os.listdir(in_dir))
        ]

    def __getitem__(self, index):
        """Return the transformed ``(input, target)`` pair for ``index``."""
        # Use the transform stored on the instance. Relying on a module-level
        # ``transform`` name silently ignores the constructor argument and
        # breaks as soon as that global is missing or rebound.
        # The context manager closes the underlying file once both results
        # have been computed.
        with Image.open(self.img_paths[index]) as img:
            x = self._transforms(img)
            y = self._transforms(img)
        return x, y

    def __len__(self):
        """Return the number of image paths collected from the directory."""
        return len(self.img_paths)
Here is how I am generating the dataloader
# Compose iterates over its argument, so the transforms must be passed as a
# list even when there is only one of them.
transform = transforms.Compose([torchvision.transforms.ToTensor()])
train_dataset = AutoEncoderDataSet('./datasets/train/', transform)
batch_size = 512
# Reshuffle the samples every epoch; the call was previously missing its
# closing parenthesis.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
When I try to train using the data generated with the custom dataset class, I am getting the error mentioned above.
Below is code for training the model
# Train the autoencoder: for every epoch, run one optimisation step per
# mini-batch and report the mean reconstruction loss over the epoch.
epochs = 2048
for epoch in range(epochs):
    loss = 0
    # The second element of each batch (the target copy) is unused because
    # the reconstruction target is the input itself.
    for batch_features, _ in train_loader:
        # reshape mini-batch data to an [N, 250*250] matrix and
        # load it to the active device
        # NOTE(review): assumes single-channel 250x250 images; with other
        # sizes or channel counts this view changes the batch dimension
        # or fails - confirm against the dataset.
        batch_features = batch_features.view(-1, 250 * 250).to(device)

        # reset the gradients back to zero
        # PyTorch accumulates gradients on subsequent backward passes
        optimizer.zero_grad()

        # compute autoencoder output (the reconstruction)
        outputs = model(batch_features)

        # compute training reconstruction loss
        train_loss = criterion(outputs, batch_features)

        # compute accumulated gradients
        train_loss.backward()

        # perform parameter update based on current gradients
        optimizer.step()

        # add the mini-batch training loss to epoch loss
        loss += train_loss.item()

    # compute the epoch training loss (mean over mini-batches)
    loss = loss / len(train_loader)

    # display the epoch training loss
    print(f"epoch : {epoch + 1}/{epochs}, recon loss = {loss:.8f}")
And this is the error I am getting
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_11164/1462449221.py in <module>
3 for epoch in range(epochs):
4 loss = 0
----> 5 for batch_features, _ in test_loader:
6 # reshape mini-batch data to [N, 784] matrix
7 # load it to the active device
~\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
519 if self._sampler_iter is None:
520 self._reset()
--> 521 data = self._next_data()
522 self._num_yielded += 1
523 if self._dataset_kind == _DatasetKind.Iterable and \
~\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
559 def _next_data(self):
560 index = self._next_index() # may raise StopIteration
--> 561 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
562 if self._pin_memory:
563 data = _utils.pin_memory.pin_memory(data)
~\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
47 def fetch(self, possibly_batched_index):
48 if self.auto_collation:
---> 49 data = [self.dataset[idx] for idx in possibly_batched_index]
50 else:
51 data = self.dataset[possibly_batched_index]
~\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
47 def fetch(self, possibly_batched_index):
48 if self.auto_collation:
---> 49 data = [self.dataset[idx] for idx in possibly_batched_index]
50 else:
51 data = self.dataset[possibly_batched_index]
~\AppData\Local\Programs\Python\Python38\lib\site-packages\torchvision\datasets\folder.py in __getitem__(self, index)
232 sample = self.loader(path)
233 if self.transform is not None:
--> 234 sample = self.transform(sample)
235 if self.target_transform is not None:
236 target = self.target_transform(target)
~\AppData\Local\Programs\Python\Python38\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, img)
58
59 def __call__(self, img):
---> 60 for t in self.transforms:
61 img = t(img)
62 return img
TypeError: 'ToTensor' object is not iterable
Any suggestions would be greatly appreciated.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
