Combine multiple DataLoaders sequentially

I'm interested in how I'd go about combining multiple DataLoaders sequentially for training. I understand I can use ConcatDataset to combine datasets first, but this does not work for my use case. I have a custom collate_fn that is passed to each dataloader, and this function depends on an attribute of the underlying Dataset. So, I'll have a set of custom DataLoaders like the following:

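import functools

import torch


def custom_collate(sample, ref):
    # `sample` is the list of (data, label) pairs that make up one batch;
    # clean_sample is my own batch-wise transform (not shown here).
    data = clean_sample(torch.stack([x[0] for x in sample]), ref)
    labels = torch.tensor([x[1] for x in sample])
    return data, labels


class CollateLoader(torch.utils.data.DataLoader):
    def __init__(self, ref, *args, **kwargs):
        # Bind this loader's dataset-specific ref into its collate_fn.
        collate_fn = functools.partial(custom_collate, ref=ref)
        super().__init__(*args, collate_fn=collate_fn, **kwargs)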

Here, ref is a property of the custom Dataset class and is passed in when a CollateLoader is initialized. I know transforms can be applied in the Dataset, but in my case the transform must be applied batch-wise.
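A plain-Python sketch of why each loader needs its own collate_fn: functools.partial fixes ref when the loader is built, so two loaders over different datasets clean the same raw batch differently. A single collate_fn on a ConcatDataset cannot do this without knowing which dataset each sample came from. The clean_sample below is a hypothetical stand-in (it just subtracts ref), and lists replace tensors so the example runs without torch.

```python
import functools


def clean_sample(data, ref):
    # Hypothetical stand-in for the asker's clean_sample:
    # subtract the dataset-level reference value from every item.
    return [x - ref for x in data]


def custom_collate(sample, ref):
    # A DataLoader calls collate_fn with a list of (data, label) samples.
    data = clean_sample([x[0] for x in sample], ref)
    labels = [x[1] for x in sample]
    return data, labels


# One collate_fn per loader, each with its own dataset's ref baked in.
collate_a = functools.partial(custom_collate, ref=10)
collate_b = functools.partial(custom_collate, ref=100)

batch = [(110, 0), (120, 1)]
print(collate_a(batch))  # ([100, 110], [0, 1])
print(collate_b(batch))  # ([10, 20], [0, 1])
```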

So, how would I go about combining multiple DataLoaders? In a PyTorch Lightning LightningDataModule, we can do something like

def train_dataloader(self):
    return [data_loader_1, data_loader_2]

But then each training step receives a list of batches, one from each loader, rather than the batches from one loader after another.
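One way to get the batches sequentially is to wrap the loaders in a small iterable that exhausts each loader before moving on to the next. The sketch below (SequentialLoader is a hypothetical name, not a PyTorch or Lightning class) uses itertools.chain, so each DataLoader keeps its own collate_fn and every batch comes from a single dataset. Lists stand in for DataLoaders in the demo; anything re-iterable that supports len() should behave the same way.

```python
import itertools


class SequentialLoader:
    """Yield every batch of the first loader, then every batch of the next, etc."""

    def __init__(self, *loaders):
        self.loaders = loaders

    def __iter__(self):
        # A fresh chain per call makes the wrapper re-iterable across epochs;
        # chain only calls iter() on a loader when its turn comes.
        return itertools.chain(*self.loaders)

    def __len__(self):
        # Total number of batches over all wrapped loaders.
        return sum(len(loader) for loader in self.loaders)


loader_a = ["a0", "a1"]  # stand-ins for two CollateLoader instances
loader_b = ["b0"]
seq = SequentialLoader(loader_a, loader_b)
print(list(seq))  # ['a0', 'a1', 'b0']
print(len(seq))   # 3
```

In principle such a wrapper could be returned from train_dataloader instead of the list. Lightning may not add distributed samplers to DataLoaders hidden inside a custom wrapper, though, so under DDP it is worth checking how the data is split.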



Solution 1:[1]

I ran into the same problem and found a workaround. I overrode the epoch training loop using the Loops API from PyTorch Lightning: I defined a class CustomLoop that inherits from pytorch_lightning.loops.TrainingEpochLoop and overrides the advance() method. I copy-pasted advance() from the pytorch_lightning source code and replaced the lines that fetch the next batch with:

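# Inside the copied advance(): take each batch from one loader at a time.
if not hasattr(self, 'dataloader_idx'):
    self.dataloader_idx = 0
if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
    batch_idx = self.batch_idx + 1
    # Fetch the next batch from the current loader only...
    batch = next(data_fetcher.dataloader.loaders[self.dataloader_idx])
    # ...then move to the next loader, wrapping around after the last one.
    self.dataloader_idx += 1
    if self.dataloader_idx == len(data_fetcher.dataloader.loaders):
        self.dataloader_idx = 0
else:
    batch_idx, batch = next(data_fetcher)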

That way, instead of taking a combined batch from the CombinedLoader, each training step takes a batch from a single dataloader, cycling through the dataloaders in turn. Then, to use this custom loop, replace the default epoch loop in the Trainer:
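The dataloader_idx bookkeeping in that loop amounts to a round-robin over the loaders. The framework-free sketch below (round_robin is a hypothetical helper, not part of Lightning) produces the same order of batches; unlike the Lightning loop, it simply stops as soon as the loader whose turn it is has no batches left.

```python
from itertools import cycle


def round_robin(*loaders):
    """Yield one batch from each loader in turn: a0, b0, a1, b1, ..."""
    iterators = [iter(loader) for loader in loaders]
    # cycle(range(n)) plays the role of dataloader_idx: 0, 1, ..., n-1, 0, ...
    for idx in cycle(range(len(iterators))):
        try:
            batch = next(iterators[idx])
        except StopIteration:
            # Stop once the loader whose turn it is runs out.
            return
        yield batch


print(list(round_robin([1, 2, 3], ["a", "b"])))  # [1, 'a', 2, 'b', 3]
```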

trainer.fit_loop.replace(epoch_loop=CustomLoop)
trainer.fit(my_model)

Solution 2:[2]

You can return [train_dataloader, train_2_dataloader]; each training step then receives two batches, one from each dataloader, so you can loop over them with a for loop and sum the losses.
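A minimal sketch of that idea, assuming training_step receives a list with one batch per dataloader, as described in the question. combined_loss and its loss_fn argument are hypothetical names; loss_fn stands for your model's own per-batch loss.

```python
def combined_loss(batches, loss_fn):
    """Sum loss_fn over the batches (one per dataloader) given to training_step."""
    # sum() starts from 0, which also works when loss_fn returns torch tensors.
    return sum(loss_fn(batch) for batch in batches)


# Hypothetical usage inside a LightningModule:
#     def training_step(self, batch, batch_idx):
#         return combined_loss(batch, self.compute_loss)


def mean(batch):
    # Toy per-batch "loss" for the demo: the mean of the batch values.
    return sum(batch) / len(batch)


print(combined_loss([[1.0, 3.0], [2.0, 4.0]], mean))  # 5.0
```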

Solution 3:[3]

If you only need to move the preview panel to the right and remove the other panels, you can write a custom UI for the color chooser. Here is an example:

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Component;

import javax.swing.JColorChooser;
import javax.swing.JComponent;
import javax.swing.SwingUtilities;
import javax.swing.UIManager;
import javax.swing.plaf.ComponentUI;
import javax.swing.plaf.basic.BasicColorChooserUI;

public class ColorChooserTest {

    public static void main(String[] args) {
        SwingUtilities.invokeLater(new ColorChooserTest()::initUI);
    }
    
    private void initUI() {
        // if you need this behavior for all color choosers you shouldn't restore the old UI
        Object oldUI = UIManager.get("ColorChooserUI");
        UIManager.put("ColorChooserUI", RightSidePreviewColorUI.class.getName());
        Color c = JColorChooser.showDialog(null, "Right aligned chooser", Color.BLACK);
        UIManager.put("ColorChooserUI", oldUI);
        if (c == null) {
            System.out.println("You've pressed cancel!");
        } else {
            System.out.println("You've chosen: " + c);
        }
    }
    
    // if you use Nimbus L&F you must extend SynthColorChooserUI
    public static class RightSidePreviewColorUI extends BasicColorChooserUI {
        public static ComponentUI createUI(JComponent c) {
            return new RightSidePreviewColorUI();
        }

        @Override
        public void installUI(JComponent c) {
            super.installUI(c);
            Component comp = chooser.getPreviewPanel().getParent();
            chooser.removeAll();
            chooser.add(defaultChoosers[3]);
            chooser.add(comp, BorderLayout.EAST);

        }
    }
}

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1
Solution 2 Johnny
Solution 3