Implementing Multi-Label Margin Loss in TensorFlow

I want to implement the multi-label margin loss in TensorFlow, using PyTorch's definition as a reference:

https://pytorch.org/docs/stable/generated/torch.nn.MultiLabelMarginLoss.html
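
For reference, here is a transcription of the per-sample definition from that page. The notation is my own assumption rather than the docs': $P$ is the set of target classes, $C$ the number of classes, and $x$ the score vector. PyTorch fixes the margin at 1.

```latex
\ell(x, P) = \frac{1}{C} \sum_{j \in P} \sum_{i \notin P} \max\bigl(0,\; 1 - (x_j - x_i)\bigr)
```

Note that the sum is divided by the number of classes $C$, not by the number of (target, non-target) pairs.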

This is the naive solution I came up with:

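import tensorflow as tf

def naive(y_true, y_pred, mu=1.0):
    # Split each row's scores into target (pos) and non-target (neg) classes.
    # Rows may contain different numbers of each, hence the ragged tensors.
    pos = tf.ragged.boolean_mask(y_pred, tf.cast(y_true, dtype=tf.bool))
    neg = tf.ragged.boolean_mask(y_pred, tf.cast(1 - y_true, dtype=tf.bool))

    loss = 0.0
    for i in range(y_true.shape[0]):
        # (n_pos, 1) - (n_neg,) broadcasts to every positive/negative pair.
        loss += tf.reduce_mean(tf.nn.relu(mu - (tf.transpose([pos[i]]) - neg[i])))
    return loss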

The implementation above yields the expected result for the example below, but I'm having a hard time removing the loop from the function, i.e. expressing it with matrix/vector operations.

Example:

y_pred = tf.constant([[0.1, 0.2, 0.4, 0.8]], dtype=tf.float32)
print(y_pred)

y_true = tf.constant([[1, 0, 0, 1]], dtype=tf.float32)
print(y_true)

naive(y_true, y_pred)

# Expected value, worked out by hand:
# 0.25 * ((1-(0.1-0.2)) + (1-(0.1-0.4)) + (1-(0.8-0.2)) + (1-(0.8-0.4)))
# = 0.8500

# (see the PyTorch example)
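
The hand calculation in the comment can be reproduced without TensorFlow. This is only a plain-Python sketch of the same pairwise sum, and the variable names are mine:

```python
# Scores and multi-hot targets from the example above
scores = [0.1, 0.2, 0.4, 0.8]
targets = [1, 0, 0, 1]
margin = 1.0

# Separate target (positive) and non-target (negative) scores
pos = [s for s, t in zip(scores, targets) if t == 1]
neg = [s for s, t in zip(scores, targets) if t == 0]

# One hinge term per (positive, negative) pair
terms = [max(0.0, margin - (p - n)) for p in pos for n in neg]

# Mean over the pairs, as tf.reduce_mean does in naive()
loss = sum(terms) / len(terms)
print(round(loss, 4))
```

There are 2 x 2 = 4 pairs here, which is where the 0.25 factor comes from.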

Any ideas are very welcome.



Solution 1:[1]

You could try using tf.while_loop:

import tensorflow as tf

def naive(y_true, y_pred, mu = 1.0):
    pos = tf.ragged.boolean_mask(y_pred, tf.cast(y_true, dtype=tf.bool))
    neg = tf.ragged.boolean_mask(y_pred, tf.cast(1 - y_true, dtype=tf.bool))
    
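    # Accumulate in a plain tensor that is threaded through the loop variables.
    loss = tf.constant(0.0)
    i = tf.constant(0)
    # Keep looping while i is smaller than the batch size.
    while_condition = lambda i, loss, pos, neg: tf.math.less(i, tf.shape(y_true)[0])

    def body(i, loss, p, n):
        # Pairwise hinge between every positive and negative score of row i;
        # uses the mu argument rather than a hard-coded margin of 1.0.
        loss += tf.reduce_mean(tf.nn.relu(mu - (tf.transpose([p[i]]) - n[i])))
        return tf.add(i, 1), loss, p, n

    _, loss, _, _ = tf.while_loop(while_condition, body, loop_vars=(i, loss, pos, neg))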

    return loss

y_pred = tf.constant([[0.1, 0.2, 0.4, 0.8], [0.1, 0.2, 0.4, 0.8]], dtype=tf.float32)
y_true = tf.constant([[1, 0, 0, 1], [1, 0, 0, 1]], dtype=tf.float32)
naive(y_true, y_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=1.7>
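
A loop-free variant is also possible by broadcasting over all class pairs and masking out the ones that are not (target, non-target). The sketch below is an alternative to the tf.while_loop version under some assumptions, not a definitive implementation: the function name multilabel_margin_loss is made up here, y_true is assumed to be a dense multi-hot tensor of shape (batch, classes), and the question's normalisation is kept (mean over pairs per row, summed over the batch). Unlike the loop, it returns 0 instead of NaN for a row that has no targets or no non-targets.

```python
import tensorflow as tf

def multilabel_margin_loss(y_true, y_pred, mu=1.0):
    """Loop-free sketch; y_true is multi-hot with shape (batch, classes)."""
    y_true = tf.cast(y_true, y_pred.dtype)
    pos_mask = y_true          # 1.0 where the class is a target
    neg_mask = 1.0 - y_true    # 1.0 where the class is not a target

    # diff[b, j, i] = y_pred[b, j] - y_pred[b, i] for every class pair (j, i)
    diff = y_pred[:, :, None] - y_pred[:, None, :]
    hinge = tf.nn.relu(mu - diff)

    # Keep only pairs where j is a target and i is a non-target
    pair_mask = pos_mask[:, :, None] * neg_mask[:, None, :]
    pair_sum = tf.reduce_sum(hinge * pair_mask, axis=[1, 2])

    # Mean over the (target, non-target) pairs of each row, as in naive()
    n_pairs = tf.reduce_sum(pair_mask, axis=[1, 2])
    per_sample = tf.math.divide_no_nan(pair_sum, n_pairs)

    # Sum over the batch, again matching naive()
    return tf.reduce_sum(per_sample)

y_pred = tf.constant([[0.1, 0.2, 0.4, 0.8], [0.1, 0.2, 0.4, 0.8]], dtype=tf.float32)
y_true = tf.constant([[1, 0, 0, 1], [1, 0, 0, 1]], dtype=tf.float32)
print(multilabel_margin_loss(y_true, y_pred))
```

On the two-row example this gives approximately 1.7, the same as the loop. The trade-off is memory: the intermediate tensor has shape (batch, classes, classes). The PyTorch loss divides by the number of classes and averages over the batch by default, so to follow that normalisation, divide pair_sum by the class count and use tf.reduce_mean at the end.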

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1