How to implement the IoU metric for a multiclass semantic segmentation problem in PyTorch?
So far I have the following variables, already populated with their values:
```python
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)
predictions, targets = learner.predict(val_loader)
```
Number of classes: 5
val_loader datatype: torch.utils.data.dataloader.DataLoader
predictions datatype: torch.Tensor
Size of predictions and targets: (torch.Size([10, 5, 224, 224]), torch.Size([10, 5, 224, 224]))
The predictions variable itself contains the predictions of the model, and the targets variable is the label.
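Both tensors are laid out as [N, C, H, W], so before they can be compared pixel by pixel each one has to be reduced to a class-index map of shape [N, H, W] by taking the argmax over the class dimension. The sketch below illustrates this on small synthetic tensors, assuming the targets are one-hot (or per-class scores) along dim=1; the helper name `to_index_map` is hypothetical. Because softmax is strictly monotonic, applying it before argmax does not change the result.

```python
import torch
import torch.nn.functional as F

def to_index_map(scores: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: [N, C, H, W] class scores or one-hot -> [N, H, W] class indices."""
    return torch.argmax(scores, dim=1)

# Small synthetic example: 2 images, 5 classes, 4x4 pixels.
torch.manual_seed(0)
logits = torch.randn(2, 5, 4, 4)
labels = torch.randint(0, 5, (2, 4, 4))
# F.one_hot appends the class dim last: [N, H, W, C] -> permute to [N, C, H, W].
one_hot = F.one_hot(labels, num_classes=5).permute(0, 3, 1, 2).float()

pred_map = to_index_map(logits)
label_map = to_index_map(one_hot)

print(pred_map.shape)                  # torch.Size([2, 4, 4])
print(torch.equal(label_map, labels))  # True: argmax recovers the original class indices
# Softmax preserves the ordering of the logits, so the argmax is unchanged.
print(torch.equal(to_index_map(F.softmax(logits, dim=1)), pred_map))
```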
Before calculating the IoU, I need to extract the target tensors from val_loader. I am currently doing this with iter() and next(), but I am not satisfied with that approach. Is there a cleaner way to extract the tensors from a DataLoader? Here is the code using iter() and next():
```python
# One next() call per batch; this consumes exactly three batches.
val_iter = iter(val_loader)
val_images_f, val_target_f = next(val_iter)
val_images_s, val_targets_s = next(val_iter)
val_images_t, val_targets_t = next(val_iter)
val_targets_all = torch.cat((val_target_f, val_targets_s, val_targets_t))
```
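Instead of calling next() once per batch, which hard-codes the number of batches, the loader can simply be iterated and the target batches concatenated. A minimal sketch, assuming each batch is an (images, targets) pair and that the loader is not shuffled, so the concatenated targets line up with `predictions` (DataLoader uses shuffle=False by default). The small TensorDataset below is only a stand-in for the real `val_dataset`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the real val_dataset: 10 images with one-hot targets for 5 classes.
images = torch.randn(10, 3, 8, 8)
targets = torch.zeros(10, 5, 8, 8)
targets[:, 0] = 1.0  # every pixel labelled as class 0, just to have valid one-hot data
val_dataset = TensorDataset(images, targets)
val_loader = DataLoader(val_dataset, batch_size=4)  # shuffle=False by default

# Iterate over every batch (however many there are) and stitch the targets together.
val_targets_all = torch.cat([batch_targets for _, batch_targets in val_loader], dim=0)
print(val_targets_all.shape)  # torch.Size([10, 5, 8, 8])
```

With batch_size=4 the loader yields batches of 4, 4 and 2 samples, and the comprehension handles the uneven last batch without any change to the code.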
I then use val_targets_all as the ground-truth labels for the IoU calculation. My attempt at computing the IoU is as follows:
```python
import numpy as np
import torch

def mIOU(label, pred, num_classes=5):
    # Softmax is monotonic, so argmax over the raw scores gives the same class map.
    pred = torch.argmax(pred, dim=1)    # [N, C, H, W] -> [N, H, W]
    # Targets are one-hot over dim=1, so argmax recovers the class indices.
    label = torch.argmax(label, dim=1)  # [N, C, H, W] -> [N, H, W]
    iou_list = list()
    present_iou_list = list()
    pred = pred.view(-1)
    label = label.view(-1)
    for sem_class in range(num_classes):
        pred_inds = (pred == sem_class)
        target_inds = (label == sem_class)
        if target_inds.long().sum().item() == 0:
            # Class absent from the ground truth: IoU is undefined, so leave it out of the mean.
            iou_now = float('nan')
        else:
            intersection_now = (pred_inds[target_inds]).long().sum().item()
            union_now = pred_inds.long().sum().item() + target_inds.long().sum().item() - intersection_now
            iou_now = float(intersection_now) / float(union_now)
            present_iou_list.append(iou_now)
        iou_list.append(iou_now)
    print(iou_list)  # per-class IoU (nan = class not present in the labels)
    return np.mean(present_iou_list)
```
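The per-class loop can also be expressed with a confusion matrix built by torch.bincount, using IoU_c = TP_c / (TP_c + FP_c + FN_c) for each class c. This is a sketch of that alternative, not the code from the question; the function name `miou_confusion` is hypothetical and, like mIOU, it averages only over classes that occur in the labels.

```python
import torch
import torch.nn.functional as F

def miou_confusion(label: torch.Tensor, pred: torch.Tensor, num_classes: int = 5):
    """Hypothetical helper: mean IoU from [N, C, H, W] label and prediction tensors.

    Returns (mean IoU over classes present in the labels,
             per-class IoU tensor with nan for classes absent from the labels).
    """
    label_idx = torch.argmax(label, dim=1).reshape(-1)
    pred_idx = torch.argmax(pred, dim=1).reshape(-1)
    # conf[t, p] counts pixels whose true class is t and predicted class is p.
    conf = torch.bincount(label_idx * num_classes + pred_idx,
                          minlength=num_classes ** 2).reshape(num_classes, num_classes)
    tp = conf.diag().double()
    fp = conf.sum(dim=0).double() - tp  # predicted as c, but truly another class
    fn = conf.sum(dim=1).double() - tp  # truly c, but predicted as another class
    iou = tp / (tp + fp + fn)
    present = conf.sum(dim=1) > 0       # classes that appear in the ground truth
    iou[~present] = float('nan')
    return iou[present].mean().item(), iou

# Tiny example: one 2x2 image, true classes [[0, 0], [1, 1]], predicted [[0, 1], [1, 1]].
label = F.one_hot(torch.tensor([[[0, 0], [1, 1]]]), 5).permute(0, 3, 1, 2).float()
pred = F.one_hot(torch.tensor([[[0, 1], [1, 1]]]), 5).permute(0, 3, 1, 2).float()
mean_iou, per_class = miou_confusion(label, pred)
print(mean_iou)  # class 0: 1/2, class 1: 2/3 -> mean 7/12 (about 0.5833)
```

Accumulating `conf` over batches (summing the bincount results) would also allow computing the metric batch by batch without concatenating every target tensor first.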
Note: the model under test was trained for 6 epochs.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow