PyTorch NaN error when using custom Softmax Function
I want to modify the backward pass of Softmax to reweight the gradient, but the modified gradients become NaN at some indices and I do not know why: none of the individual components is NaN, yet the result turns into NaN after a few addition and multiplication operations.
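To be clear about that last point, here is a toy example (not my real data) showing that a tensor can contain no NaN entries and still produce NaN after a single subtraction or multiplication, because inf - inf and inf * 0 are undefined:

    import torch

    a = torch.tensor([float("inf")])
    print(a.isnan().any())  # tensor(False): the operand itself is not NaN
    print(a - a)            # tensor([nan]): inf - inf is undefined
    print(a * 0.0)          # tensor([nan]): inf * 0 is undefined

So the `.isnan()` checks on the components all passing does not rule out an overflow to inf somewhere in the chain.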
Here is my code. It replaces the last softmax layer of a classification model during training, and NaN appears after 4 epochs.
    import torch
    import torch.nn as nn
    from torch.autograd import Function
    from torch.autograd import gradcheck
    import pandas as pd

    class WeightedSoftmax(Function):
        @staticmethod
        def forward(ctx, x, weights):
            # Numerically stable softmax: subtract the row-wise max before exp
            x_max = torch.max(x, dim=-1, keepdim=True)[0]
            x_exp = torch.exp(x - x_max)
            x_exp_sum = torch.sum(x_exp, -1, keepdim=True)
            probs = x_exp / x_exp_sum
            ctx.save_for_backward(probs, weights)
            return probs

        @staticmethod
        def backward(ctx, grad_output):
            probs, weights = ctx.saved_tensors
            target_prob = grad_output * probs
            # Reweighted softmax gradient: the diagonal term t*(1-p) is scaled
            # by (2 - weights), the off-diagonal term by (weights + 1); with
            # weights == 1 this reduces to the standard softmax gradient
            gradient = (target_prob - probs * target_prob) * (2 - weights) + \
                       (probs * target_prob - probs * target_prob.sum(-1, keepdim=True)) * (weights + 1)
            if gradient.isnan().any():
                # Every component checks out as non-NaN, yet the result is NaN
                print(gradient.isnan().any())
                print(target_prob.isnan().any(), probs.isnan().any())
                print(weights.isnan().any())
                print(grad_output.isnan().any())
                print((target_prob - probs * target_prob).isnan().any())
                print((probs * target_prob - probs * target_prob.sum(-1, keepdim=True)).isnan().any())
                print(((target_prob - probs * target_prob) * (2 - weights)).isnan().any())
                print(((probs * target_prob - probs * target_prob.sum(-1, keepdim=True)) * (weights + 1)).isnan().any())
            return gradient, weights
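For context, this is roughly how I call it in place of the final softmax (shapes and names below are simplified placeholders, not my real model):

    # Minimal usage sketch (placeholder shapes/names)
    batch_size, num_classes = 32, 10
    logits = torch.randn(batch_size, num_classes, requires_grad=True)
    class_weights = torch.ones(num_classes)   # per-class reweighting factors

    probs = WeightedSoftmax.apply(logits, class_weights)
    targets = torch.randint(num_classes, (batch_size,))
    loss = -torch.log(probs[torch.arange(batch_size), targets]).mean()
    loss.backward()   # invokes the custom backward above

Note that gradcheck is not useful for verifying this backward: the gradient is intentionally reweighted, so it is not the analytical gradient of the forward pass and a mismatch is expected.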
And here is the error report