WeightedRandomSampler for a custom image DataLoader in PyTorch

I am trying to handle class imbalance by using a WeightedRandomSampler with a custom data loader, but I can't work out the best way to implement it. The images are in a folder and the labels are in a CSV file. The dataloader code without the weighted random sampler is given below.

import os

from torch.utils.data import Dataset


class CassavaDataset(Dataset):
    def __init__(self, df, data_root, transforms=None, output_label=True):
        super().__init__()
        self.df = df.reset_index(drop=True).copy()  # one row per image: image_id, label
        self.transforms = transforms
        self.data_root = data_root
        self.output_label = output_label

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index: int):
        row = self.df.iloc[index]
        path = os.path.join(self.data_root, row["image_id"])

        # get_img is the author's image-loading helper (not shown here)
        img = get_img(path)

        if self.transforms:
            # transforms are called with image=... and return a dict keyed by "image"
            img = self.transforms(image=img)["image"]

        if self.output_label:
            return img, row["label"]
        return img
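Before choosing weights it helps to measure the imbalance. Because the dataset keeps every label in `self.df`, the class counts can be read from the dataframe without opening a single image. A minimal sketch, assuming a `label` column as in the code above; the inline CSV (file names and labels) is made-up sample data standing in for the real labels file.

```python
import io

import pandas as pd

# Made-up stand-in for the labels CSV; with real data this would be
# pd.read_csv(...) on the actual file, whose path is not given here.
CSV_TEXT = """image_id,label
a.jpg,0
b.jpg,3
c.jpg,3
d.jpg,3
e.jpg,1
f.jpg,3
"""

df = pd.read_csv(io.StringIO(CSV_TEXT))

# Number of images per class, in label order.
counts = df["label"].value_counts().sort_index()
# Largest class size divided by smallest: 1.0 would mean perfectly balanced.
imbalance_ratio = counts.max() / counts.min()

print(counts)
print(f"largest/smallest class ratio: {imbalance_ratio:.1f}")
```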

What would be the best way to compute the weight of each class and feed it to the sampler before augmentation? Thanks in advance!
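One common approach, sketched here under assumptions: give each class a weight inversely proportional to its count, give each image the weight of its class, and pass those per-image weights to `torch.utils.data.WeightedRandomSampler`. The weights come from the `label` column alone, so nothing is loaded or augmented while they are computed; augmentation still happens per draw inside `__getitem__`. `make_weighted_sampler` is a hypothetical helper name, and the toy 90/10 dataframe plus `TensorDataset` stand in for the real CSV and `CassavaDataset`.

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler


def make_weighted_sampler(labels: pd.Series) -> WeightedRandomSampler:
    """Hypothetical helper: sampler that draws every class about equally often.

    Only the label column is read, so no image is loaded or augmented here.
    """
    class_counts = labels.value_counts()        # images per class
    class_weights = 1.0 / class_counts          # rarer class -> larger weight
    sample_weights = labels.map(class_weights)  # each row gets its class's weight
    return WeightedRandomSampler(
        weights=torch.as_tensor(sample_weights.to_numpy(), dtype=torch.double),
        num_samples=len(labels),  # keep the epoch length equal to the dataset size
        replacement=True,         # lets minority images be drawn more than once
    )


# Toy stand-in for the labels CSV: 90 images of class 0, 10 of class 1.
torch.manual_seed(0)
df = pd.DataFrame({"label": [0] * 90 + [1] * 10})
sampler = make_weighted_sampler(df["label"])

# TensorDataset replaces CassavaDataset so the demo runs without images.
ds = TensorDataset(torch.arange(len(df)), torch.as_tensor(df["label"].to_numpy()))
loader = DataLoader(ds, batch_size=20, sampler=sampler)  # no shuffle=True with a sampler

drawn = torch.cat([y for _, y in loader])
print(torch.bincount(drawn))  # per-class counts, expected to be roughly equal
```

With the question's dataset this would become `sampler = make_weighted_sampler(train_ds.df["label"])` followed by `DataLoader(train_ds, batch_size=..., sampler=sampler)`. `DataLoader` rejects `shuffle=True` combined with a custom `sampler`, and `replacement=True` is what allows minority images to appear several times per epoch; if the transforms are random, each repeated draw gets a different augmentation because `__getitem__` applies them on every call.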



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
