PyTorch NaN error when using custom Softmax Function

I want to modify the backward pass of Softmax to reweight the gradient, but the modified gradients become NaN at some indices and I do not know why: none of the computed components is NaN, yet the result turns NaN after a few addition and multiplication operations.
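
For reference, tensors can be NaN-free and still produce NaN downstream: if an intermediate overflows to inf, later arithmetic on it yields NaN, since inf - inf and 0 * inf are both NaN. A minimal sketch illustrating this in isolation (a general example, not taken from the model code below):

import torch

a = torch.tensor([1e38])      # finite float32 value, no NaN anywhere
b = a * 10                    # overflows float32 -> tensor([inf]), still no NaN
print(b - b)                  # inf - inf -> tensor([nan])
print(torch.zeros(1) * b)     # 0 * inf  -> tensor([nan])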

Here is my code. It takes the place of the last softmax layer of a classification model during training; NaN appears after 4 epochs.

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd import gradcheck

class WeightedSoftmax(Function):

    @staticmethod
    def forward(ctx, x, weights):
        # numerically stable softmax: subtract the row-wise max before exponentiating
        x_max = torch.max(x, dim=-1, keepdim=True)[0]
        x_exp = torch.exp(x - x_max)
        x_exp_sum = torch.sum(x_exp, -1, keepdim=True)

        probs = x_exp / x_exp_sum
        # keep the probabilities and the reweighting tensor for the backward pass
        ctx.save_for_backward(probs, weights)
        return probs

    @staticmethod
    def backward(ctx, grad_output):
        probs, weights = ctx.saved_tensors
        target_prob = grad_output * probs

        # reweighted softmax Jacobian-vector product: the diagonal term is
        # scaled by (2 - weights), the off-diagonal term by (weights + 1)
        gradient = (target_prob - probs * target_prob) * (2 - weights) + \
            (probs * target_prob - probs * target_prob.sum(-1, keepdim=True)) * (weights + 1)

        if gradient.isnan().any():
            print(gradient.isnan().any())
            print(target_prob.isnan().any(), probs.isnan().any())
            print(weights.isnan().any())
            print(grad_output.isnan().any())
            print((target_prob - probs * target_prob).isnan().any())
            print((probs * target_prob - probs * target_prob.sum(-1, keepdim=True)).isnan().any())
            print(((target_prob - probs * target_prob) * (2 - weights)).isnan().any())
            print(((probs * target_prob - probs * target_prob.sum(-1, keepdim=True)) * (weights + 1)).isnan().any())
        # backward returns one gradient per forward input; assuming weights is
        # not a learned parameter, its gradient is None
        return gradient, None
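
For context, the custom layer is applied via Function.apply. A minimal sketch of the assumed usage plus a double-precision gradcheck (the shapes and the weights tensor are illustrative assumptions, not from the original post; because the backward deliberately deviates from the analytic softmax gradient, gradcheck is expected to report a mismatch rather than pass):

import torch
from torch.autograd import gradcheck

logits = torch.randn(4, 10, dtype=torch.double, requires_grad=True)
weights = torch.rand(10, dtype=torch.double)  # assumed per-class reweighting tensor

probs = WeightedSoftmax.apply(logits, weights)

# gradcheck compares the custom backward against numerical gradients in
# double precision; raise_exception=False returns a bool instead of raising
print(gradcheck(WeightedSoftmax.apply, (logits, weights),
                eps=1e-6, atol=1e-4, raise_exception=False))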

And here is the error report:

[error report not preserved in this copy of the post]
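
One way to localize where the NaN is first produced during backward is PyTorch's anomaly detection; a debugging sketch (it is also worth testing for inf, since the prints above only check isnan):

import torch

# makes autograd raise an error, with a traceback pointing at the offending
# forward op, as soon as any backward computation produces NaN
torch.autograd.set_detect_anomaly(True)

# inside backward, complement the isnan checks with isinf ones, e.g.:
# print(grad_output.isinf().any(), probs.isinf().any(), weights.isinf().any())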