TensorFlow: argmin/argmax, then slice

This is a follow-up question to "Tensorflow : tf.argmax and slicing", which is 5 years old and only applies to 2D tensors.

I have a rank-D tensor x, and I wish to compute the argmax/argmin over one of its axes, then slice the tensor using those indices. Something like:

x_min_indices = tf.math.argmin(x, axis=3)  # choosing an arbitrary axis
x_min = x[x_min_indices]

How can I accomplish this?



Solution 1:[1]

Same basic idea as the original question you mentioned. Any time you want NumPy-style fancy indexing in TensorFlow, you will need tf.gather or tf.gather_nd.
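As a rough illustration of the tf.gather route, here is a minimal sketch. It assumes eager TensorFlow 2.x and a tensor whose rank is known statically. The helper name values_at_argmin and the example shape are made up for this sketch, and this is only one possible way to do it, not necessarily what the asker needs:

```python
import tensorflow as tf


def values_at_argmin(x, axis):
    """Pick the entries of x located at its argmin along `axis`.

    Hypothetical helper for illustration; assumes the rank of x is static.
    """
    rank = len(x.shape)
    axis = axis % rank  # normalise negative axes
    # Move the reduced axis to the end so all other axes act as batch axes.
    perm = [i for i in range(rank) if i != axis] + [axis]
    x_last = tf.transpose(x, perm)
    idx = tf.math.argmin(x_last, axis=-1)  # shape: x_last.shape[:-1]
    # With batch_dims = rank - 1, tf.gather picks one element per leading position.
    picked = tf.gather(
        x_last, idx[..., tf.newaxis], axis=rank - 1, batch_dims=rank - 1
    )
    return tf.squeeze(picked, axis=-1)


x = tf.random.normal([2, 3, 4, 5, 6])  # arbitrary example shape
x_min = values_at_argmin(x, axis=3)    # shape (2, 3, 4, 6)
```

If the goal is only the minimum values of x itself, tf.reduce_min(x, axis=3) returns the same tensor directly. The gather pattern mainly pays off when the indices are needed to pick from a different tensor of the same shape.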

I'm not sure what you are trying to accomplish. Can you provide specific inputs and the expected output?

In particular, when you run tf.math.argmin(x, axis=3), your result is a rank D-1 tensor. What did you expect when you called x[x_min_indices]?
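To make the shape point concrete, here is a small check, assuming eager TensorFlow 2.x and an arbitrary rank-5 example shape chosen only for illustration:

```python
import tensorflow as tf

x = tf.random.normal([2, 3, 4, 5, 6])  # rank D = 5, arbitrary example shape
idx = tf.math.argmin(x, axis=3)        # axis 3 is reduced away

print(x.shape)    # (2, 3, 4, 5, 6)
print(idx.shape)  # (2, 3, 4, 6), i.e. rank D - 1
print(idx.dtype)  # int64 is the default output type of tf.math.argmin
```

Each entry of idx is a position along axis 3 only, so it has to be combined with the coordinates on the other axes before it can address a single element of x.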

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Yaoshiang