How to calculate log_softmax for a list of tensors without breaking autograd in PyTorch

I'm trying to calculate the log_softmax function of a list of tensors, i.e., a list [t_1, t_2, ..., t_n] where each t_i is a torch.Tensor and each t_i can have a different, arbitrary shape. I do not want to apply log_softmax to each t_i separately, but to all of them jointly, as if they were parts of one single tensor; for example, with t_1 of shape (2,) and t_2 of shape (2, 2), all six values should be normalized together so that their exponentials sum to 1. The output of this function should be a list of tensors with the same shapes as the input. Lastly, since I will apply this function to the last layer of a neural network, it must be differentiable, i.e., gradients must flow through it.

PyTorch provides the class torch.nn.LogSoftmax, but I cannot use it because it expects a single tensor as input rather than a list of tensors. Additionally, I want to calculate the log_softmax function efficiently and in a numerically stable way. To achieve that, I want to use the log-sum-exp trick. Lastly, I want to ignore the last value of the first element of the list (see the code snippet below), i.e., not apply log_softmax to it.
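
For reference, the trick rewrites log(sum_i e^(x_i)) as c + log(sum_i e^(x_i - c)) with c = max_i x_i, so the exponentials stay bounded and never overflow. A minimal single-tensor sketch of the idea:

    import torch

    x = torch.tensor([1000.0, 1001.0, 1002.0])

    # Naive computation overflows: exp(1000) is inf in float32
    naive = torch.log(torch.sum(torch.exp(x)))           # inf

    # Log-sum-exp trick: shift by the maximum before exponentiating
    c = x.max()
    stable = c + torch.log(torch.sum(torch.exp(x - c)))  # ~1002.4076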

This is my current implementation:

import torch

def log_softmax(pred_tensors):
    minus_inf = -1000  # Constant that represents minus infinity

    # Calculate the max value c across all tensors
    c = max([preds.amax() if preds is not None else minus_inf for preds in pred_tensors])

    # Calculate log(sum(e^(x_i - c)))
    log_sum_exp = 0
    for r in range(len(pred_tensors)):
        if pred_tensors[r] is not None:
            # Arity 0 -> ignore nullary predicate corresponding to termination condition
            curr_sum = torch.sum(torch.exp(pred_tensors[r][:-1] - c)) if r == 0 else \
                       torch.sum(torch.exp(pred_tensors[r] - c))
            log_sum_exp += curr_sum

    log_sum_exp = torch.log(log_sum_exp)

    # Calculate log_softmax in place on each original tensor
    # (except on the termination condition)
    for r in range(len(pred_tensors)):
        if pred_tensors[r] is not None:
            # Arity 0 -> ignore nullary predicate corresponding to termination condition
            if r == 0:
                pred_tensors[r][:-1] -= log_sum_exp + c
            else:
                pred_tensors[r] -= log_sum_exp + c

    return pred_tensors

I have tested it and it works. However, I suspect my implementation may be breaking PyTorch's autograd, specifically in the lines c = max([preds.amax() if preds is not None else minus_inf for preds in pred_tensors]) and log_sum_exp += curr_sum.

So, my questions are: Is my implementation really breaking autograd? If it is, can you provide an alternative implementation that works with autograd?
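
A quick way to check whether gradients actually flow through a function like this is to run it on small inputs with requires_grad=True and call backward(); a minimal sketch (the shapes are arbitrary):

    import torch

    t1 = torch.randn(4, requires_grad=True)
    t2 = torch.randn(2, 3, requires_grad=True)

    # Multiply by 1.0 to get non-leaf copies, as if they came out of a
    # network layer; passing leaf tensors directly would make the in-place
    # updates above raise a RuntimeError.
    out = log_softmax([t1 * 1.0, t2 * 1.0])

    loss = sum(o.sum() for o in out)
    loss.backward()

    # If gradients flow, .grad is populated; if autograd is broken somewhere,
    # backward() raises an error or leaves .grad as None.
    print(t1.grad)
    print(t2.grad)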

Solution 1 [1]:

I also posted this question on the PyTorch Forum, where it was solved. I post the solution below:

import torch

def _log_softmax(pred_tensors):
    # Remove the nullary predicate associated with the termination condition,
    # so that it does not affect the log_softmax computation
    term_cond_value = pred_tensors[0][-1]
    pred_tensors[0] = pred_tensors[0][:-1]

    # Calculate log_sum_exp of all the values in the tensors of the list:
    # 1) flatten each tensor in the list
    # 2) concatenate them into a single tensor
    # 3) calculate log_sum_exp with torch.logsumexp
    flat = torch.cat([preds.flatten() if preds is not None
                      else torch.empty(0, dtype=torch.float32)
                      for preds in pred_tensors])
    log_sum_exp = torch.logsumexp(flat, dim=-1)

    # Use log_sum_exp to calculate the log_softmax of each tensor in the list.
    # The subtraction is done out of place so that autograd tracks it even
    # when the inputs are leaf tensors.
    for r in range(len(pred_tensors)):
        if pred_tensors[r] is not None:
            pred_tensors[r] = pred_tensors[r] - log_sum_exp

    # Append the nullary predicate corresponding to the termination condition.
    # reshape(1) turns the 0-dimensional tensor into a 1-dimensional one.
    pred_tensors[0] = torch.cat([pred_tensors[0], term_cond_value.reshape(1)])

    return pred_tensors

Basically, I first removed the element pred_tensors[0][-1] from the list of tensors, so that it did not affect the calculations, and appended it back to the final list at the end. Then, since tensors of different shapes cannot be concatenated directly, I flattened each tensor and used torch.cat to join them into a single 1-D tensor, on which torch.logsumexp computes the log_sum_exp of all the values at once. Finally, this value is subtracted from every tensor to obtain the log_softmax, yielding a list of output tensors with the same shapes as the input.
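
The function can be sanity-checked on small tensors that require gradients; this sketch (with illustrative shapes) verifies that all non-ignored probabilities sum to one and that gradients flow back to the inputs:

    import torch

    t1 = torch.randn(3, requires_grad=True)   # last entry = termination condition
    t2 = torch.randn(2, 2, requires_grad=True)

    out = _log_softmax([t1, t2])

    # The exponentials of all values except the ignored last entry of out[0]
    # sum to 1, as expected for a softmax computed jointly over all entries
    print(torch.exp(out[0][:-1]).sum() + torch.exp(out[1]).sum())  # ~1.0

    # Gradients flow: backpropagate a dummy NLL-style loss
    loss = -out[1].sum()
    loss.backward()
    print(t1.grad is not None, t2.grad is not None)  # True True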

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

[1] Solution 1: Aeryan