Implementing Multi-Label Margin-Loss in TensorFlow
I want to implement the Multi-Label Margin-Loss in TensorFlow, using PyTorch's definition as a reference:
https://pytorch.org/docs/stable/generated/torch.nn.MultiLabelMarginLoss.html
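For reference, the linked PyTorch page defines the per-sample loss as follows, where x holds the C class scores, y[j] runs over the target class indices, and i runs over the non-target classes. The margin is fixed at 1 there (the PyTorch class has no margin argument):

```latex
\text{loss}(x, y) = \sum_{ij} \frac{\max\bigl(0,\; 1 - (x[y[j]] - x[i])\bigr)}{C},
\qquad C = \texttt{x.size(0)}, \quad i \ne y[j] \ \text{for all } i, j
```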
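This is the naive solution I came up with:
```python
import tensorflow as tf

def naive(y_true, y_pred, mu=1.0):
    # Ragged rows: scores of the target (positive) and non-target (negative) classes
    pos = tf.ragged.boolean_mask(y_pred, tf.cast(y_true, dtype=tf.bool))
    neg = tf.ragged.boolean_mask(y_pred, tf.cast(1 - y_true, dtype=tf.bool))
    loss = 0
    for i in range(y_true.shape[0]):
        # (n_pos, 1) - (n_neg,) broadcasts to an (n_pos, n_neg) matrix of pairwise differences
        loss += tf.reduce_mean(tf.nn.relu(mu - (tf.transpose([pos[i]]) - neg[i])))
    return loss
```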
The implementation above yields correct results (see the example below), but I'm having a hard time removing the loop from the function, i.e. expressing it with matrix/vector operations.
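A minimal loop-free sketch, assuming `y_true` is a dense multi-hot tensor of shape (batch, classes): instead of ragged rows, build every score difference with broadcasting and zero out the pairs that are not (target, non-target) with a mask. It keeps the normalization of `naive` (mean over pairs within a sample, sum over the batch); the function name `vectorized` is made up for this illustration.

```python
import tensorflow as tf

def vectorized(y_true, y_pred, mu=1.0):
    """Loop-free variant of `naive` for dense (batch, classes) inputs."""
    y_true = tf.cast(y_true, y_pred.dtype)
    # diff[b, j, i] = y_pred[b, j] - y_pred[b, i] for every pair of classes (j, i)
    diff = y_pred[:, :, tf.newaxis] - y_pred[:, tf.newaxis, :]
    # pair_mask[b, j, i] = 1 only when j is a target class and i is not
    pair_mask = y_true[:, :, tf.newaxis] * (1.0 - y_true)[:, tf.newaxis, :]
    # Hinge term for every pair; pairs outside the mask contribute 0
    hinge = tf.nn.relu(mu - diff) * pair_mask
    # Number of (target, non-target) pairs per sample, i.e. n_pos * n_neg
    n_pairs = tf.reduce_sum(pair_mask, axis=[1, 2])
    # Mean over pairs; divide_no_nan yields 0 for a sample without any valid pair,
    # where the reduce_mean in `naive` would produce NaN
    per_sample = tf.math.divide_no_nan(tf.reduce_sum(hinge, axis=[1, 2]), n_pairs)
    # Sum over the batch, as `naive` does
    return tf.reduce_sum(per_sample)
```

The intermediate tensors have shape (batch, C, C), so memory grows quadratically with the number of classes; for the four-class example in this question the function returns 0.85, the same as `naive`.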
Example:
```python
y_pred = tf.constant([[0.1, 0.2, 0.4, 0.8]], dtype=tf.float32)
print(y_pred)
y_true = tf.constant([[1, 0, 0, 1]], dtype=tf.float32)
print(y_true)
naive(y_true, y_pred)
# 0.25 * ((1-(0.1-0.2)) + (1-(0.1-0.4)) + (1-(0.8-0.2)) + (1-(0.8-0.4)))
# 0.8500
# (see the PyTorch example)
```
Any ideas are very welcome.
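One caveat when comparing against PyTorch: the linked definition divides each sample's hinge sum by the number of classes C, and the default `reduction='mean'` then averages over the batch. `naive` instead applies `tf.reduce_mean` over the n_pos × n_neg pair matrix and sums over the batch, so the two agree in the example above only because 2 × 2 = 4 = C. The sketch below follows the PyTorch normalization; it assumes a multi-hot `y_true` of shape (N, C) rather than PyTorch's index-style target padded with -1, and the name `multilabel_margin_loss` and its `reduction` argument are illustrative, not an existing TensorFlow API.

```python
import tensorflow as tf

def multilabel_margin_loss(y_true, y_pred, margin=1.0, reduction="mean"):
    """Sketch of PyTorch-style normalization for multi-hot (N, C) targets."""
    y_true = tf.cast(y_true, y_pred.dtype)
    num_classes = tf.cast(tf.shape(y_pred)[-1], y_pred.dtype)
    # Pairwise score differences and the (target, non-target) pair mask
    diff = y_pred[:, :, tf.newaxis] - y_pred[:, tf.newaxis, :]
    pair_mask = y_true[:, :, tf.newaxis] * (1.0 - y_true)[:, tf.newaxis, :]
    hinge = tf.nn.relu(margin - diff) * pair_mask
    # Divide each sample's hinge sum by C, as in the PyTorch formula
    per_sample = tf.reduce_sum(hinge, axis=[1, 2]) / num_classes
    if reduction == "mean":
        return tf.reduce_mean(per_sample)
    if reduction == "sum":
        return tf.reduce_sum(per_sample)
    if reduction == "none":
        return per_sample
    raise ValueError(f"unknown reduction: {reduction!r}")
```

With a single target (`y_true = [[1, 0, 0, 0]]` and the same scores) the two normalizations diverge: this sketch gives 4.1 / 4 = 1.025, while the mean over the three pairs is 4.1 / 3 ≈ 1.367.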
Solution 1:[1]
You could try using tf.while_loop:
```python
import tensorflow as tf

def naive(y_true, y_pred, mu=1.0):
    pos = tf.ragged.boolean_mask(y_pred, tf.cast(y_true, dtype=tf.bool))
    neg = tf.ragged.boolean_mask(y_pred, tf.cast(1 - y_true, dtype=tf.bool))
    i = tf.constant(0)
    # Plain tensor accumulator: a tf.Variable created here would break under tf.function
    loss = tf.zeros([], dtype=y_pred.dtype)

    # Keep looping while i < batch size
    def condition(i, loss):
        return tf.math.less(i, tf.shape(y_true)[0])

    def body(i, loss):
        # Add the mean hinge over all (positive, negative) pairs of sample i
        pairwise = tf.transpose([pos[i]]) - neg[i]
        return tf.add(i, 1), loss + tf.reduce_mean(tf.nn.relu(mu - pairwise))

    _, loss = tf.while_loop(condition, body, loop_vars=(i, loss))
    return loss
```
```python
# Two identical samples, so the summed loss is 2 * 0.85
y_pred = tf.constant([[0.1, 0.2, 0.4, 0.8], [0.1, 0.2, 0.4, 0.8]], dtype=tf.float32)
y_true = tf.constant([[1, 0, 0, 1], [1, 0, 0, 1]], dtype=tf.float32)
naive(y_true, y_pred)
# <tf.Tensor: shape=(), dtype=float32, numpy=1.7>
```
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | |

