PYTORCH DISTRIBUTED: Can't save the last checkpoint

I am running distributed training code with PyTorch under SLURM, using 8 GPUs. Here is the relevant part of the training code:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

dist.init_process_group(
    backend='nccl',
    init_method='env://',
    world_size=sdenv.size,
    rank=sdenv.rank)
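
Here sdenv just exposes the SLURM process geometry (global rank, local rank, world size). A minimal sketch of such a helper, assuming the standard SLURM environment variables; my actual helper may differ slightly:

import os

class SlurmEnv:
    def __init__(self):
        self.rank = int(os.environ['SLURM_PROCID'])         # global rank across all nodes
        self.local_rank = int(os.environ['SLURM_LOCALID'])  # rank within the current node
        self.size = int(os.environ['SLURM_NTASKS'])         # world size (8 GPUs here)

sdenv = SlurmEnv()
# Note: init_method='env://' also expects MASTER_ADDR and MASTER_PORT to be set.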

...

torch.cuda.set_device(sdenv.local_rank)
gpu = torch.device("cuda")
model = SimpleCNN().to(gpu)
ddp_model = DistributedDataParallel(
    model,
    device_ids=[sdenv.local_rank])
batch_size_per_gpu = batch_size // sdenv.size  # batch_size is the "total" (global) batch size
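
For example, with a global batch_size of 256 and sdenv.size == 8, each GPU processes 256 // 8 == 32 samples per step (numbers are only illustrative).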

...

image_datasets = {x: FileNames(os.path.join(data_dir, x),
                               data_transforms[x])
                  for x in ['train']}

train_sampler = torch.utils.data.distributed.DistributedSampler(
    image_datasets['train'],
    num_replicas=sdenv.size,
    rank=sdenv.rank)

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=batch_size_per_gpu,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True,
                                              sampler=train_sampler)
               for x in ['train']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train']}
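
As an aside, the DistributedSampler is what partitions the dataset across the sdenv.size processes, which is why shuffle=False is passed to the DataLoader. If per-epoch reshuffling were wanted, the documented pattern (not used in my code above) is to call set_epoch at the start of each epoch:

for epoch in range(args.epochs):
    train_sampler.set_epoch(epoch)  # reseeds the sampler's shuffle for this epoch
    ...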

...

for epoch in range(args.epochs):
    phase = 'train'
    for i, (images, labels, paths) in enumerate(dataloaders[phase]):
        ...  # training steps

    if sdenv.rank == 0:
        torch.save(
            ddp_model.state_dict(),
            './checkpoint/{}GPU_{}epoch.checkpoint'.format(sdenv.size, epoch+1))

Following IDRIS: "Since the model is replicated on each GPU, the saving of checkpoints can be effectuated on just one GPU to limit the writing operations. By convention, we use the GPU rank 0." But the problem is that in the last epoch the checkpoint is not saved and the program hangs. To make it work, I had to change the save condition as follows (I ran for 10 epochs, so I only saved up to epoch 9):

if (sdenv.rank == 0) and ((epoch+1) < 10):
    print('------- Before check {}'.format(epoch+1))
    torch.save(
        ddp_model.state_dict(),
        './checkpoint/{}GPU_{}epoch.checkpoint'.format(sdenv.size, epoch+1))
    print('------- After check {}'.format(epoch+1))
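
A pattern sometimes used when saving under DDP (just a sketch; I have not confirmed it addresses my hang) is to synchronize all ranks around the save with dist.barrier(), so that no process starts tearing down the NCCL process group while rank 0 is still writing:

if sdenv.rank == 0:
    torch.save(
        ddp_model.state_dict(),
        './checkpoint/{}GPU_{}epoch.checkpoint'.format(sdenv.size, epoch+1))
dist.barrier()  # every rank blocks here until all ranks (including the saver) arrive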

Any ideas about this issue?

Thank you.


