TensorFlow: Create the torch.gather() equivalent in TensorFlow
I want to replicate the torch.gather() function in TensorFlow 2.X.
I have a tensor A (shape [2, 3, 3]) and a corresponding index tensor I (shape [2, 2, 3]).
Using torch.gather() yields the following:
A = torch.tensor([[[10,20,30], [100,200,300], [1000,2000,3000]],
[[50,60,70], [500,600,700], [5000,6000,7000]]])
I = torch.tensor([[[0,1,0], [1,2,1]],
[[2,1,2], [1,0,1]]])
torch.gather(A, 1, I)
>
tensor([[[10, 200, 30], [100, 2000, 300]],
        [[5000, 600, 7000], [500, 60, 700]]])
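In other words, torch.gather(A, 1, I) keeps the first and last coordinates of each output element and takes only the middle coordinate from I, i.e. out[i][j][k] = A[i][I[i][j][k]][k], so the output has the shape of I. As a framework-independent reference, here is a minimal NumPy sketch of that rule as an explicit loop (gather_dim1_reference is a hypothetical helper for illustration only), cross-checked against np.take_along_axis, which applies the same indexing rule to index arrays shaped like this one:

```python
import numpy as np

A = np.array([[[10, 20, 30], [100, 200, 300], [1000, 2000, 3000]],
              [[50, 60, 70], [500, 600, 700], [5000, 6000, 7000]]])
I = np.array([[[0, 1, 0], [1, 2, 1]],
              [[2, 1, 2], [1, 0, 1]]])


def gather_dim1_reference(a, idx):
    """Hypothetical loop version of torch.gather(a, 1, idx) for 3-D arrays."""
    out = np.empty(idx.shape, dtype=a.dtype)
    for i in range(idx.shape[0]):
        for j in range(idx.shape[1]):
            for k in range(idx.shape[2]):
                # Only the coordinate along dim 1 comes from the index array.
                out[i, j, k] = a[i, idx[i, j, k], k]
    return out


loop_result = gather_dim1_reference(A, I)
numpy_result = np.take_along_axis(A, I, axis=1)
print(loop_result.tolist())
# [[[10, 200, 30], [100, 2000, 300]], [[5000, 600, 7000], [500, 60, 700]]]
```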
I have tried tf.gather(), but it did not give PyTorch-like results. I also experimented with tf.gather_nd(), but could not find a suitable solution.
I found this Stack Overflow post, but it does not seem to work for me.
Edit:
When using tf.gather_nd(A, I), I get the following result:
tf.gather_nd(A, I)
>
[[100, 6000],
[ 0, 60]]
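The result of tf.gather(A, I) is rather lengthy: tf.gather treats every entry of I as an index along axis 0 and copies the whole slice at that position, so the output shape is I.shape + A.shape[1:], i.e. [2, 2, 3] + [3, 3] = [2, 2, 3, 3, 3] for the A above.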
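For comparison, tf.gather itself can produce torch-like results through its batch_dims argument: move the gather axis to the last position, declare every other axis a batch axis, and transpose back afterwards. This is a sketch under stated assumptions, not a drop-in replacement: it assumes static shapes and an index tensor that matches x in every dimension except axis (torch.gather also accepts smaller sizes there, tf.gather with batch_dims does not). torch_gather_batch_dims is a hypothetical helper name.

```python
import tensorflow as tf


def torch_gather_batch_dims(x, indices, axis):
    """Hypothetical helper: torch.gather(x, axis, indices) via tf.gather(batch_dims=...).

    Assumes x and indices have the same rank and the same static size in
    every dimension except `axis`.
    """
    rank = len(x.shape)
    axis = axis % rank
    # Permutation that moves `axis` to the end; all other axes become batch axes.
    perm = [d for d in range(rank) if d != axis] + [axis]
    inverse_perm = [perm.index(d) for d in range(rank)]

    x_t = tf.transpose(x, perm)
    indices_t = tf.transpose(indices, perm)
    # out_t[b..., m] = x_t[b..., indices_t[b..., m]]
    out_t = tf.gather(x_t, indices_t, axis=rank - 1, batch_dims=rank - 1)
    # Move the gathered axis back to its original position.
    return tf.transpose(out_t, inverse_perm)


A = tf.constant([[[10, 20, 30], [100, 200, 300], [1000, 2000, 3000]],
                 [[50, 60, 70], [500, 600, 700], [5000, 6000, 7000]]])
I = tf.constant([[[0, 1, 0], [1, 2, 1]],
                 [[2, 1, 2], [1, 0, 1]]])
result = torch_gather_batch_dims(A, I, 1)
print(result.numpy().tolist())
# [[[10, 200, 30], [100, 2000, 300]], [[5000, 600, 7000], [500, 60, 700]]]
```

Because only transposes and a single tf.gather are involved, no explicit coordinate tensor has to be materialized.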
Solution 1:[1]
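torch.gather and tf.gather_nd work differently, so they give different results for the same indices tensor (and in some cases tf.gather_nd returns an error instead). torch.gather takes only the coordinate along the gather dimension from the indices tensor and keeps each element's own position in all other dimensions, whereas tf.gather_nd interprets every innermost vector of the indices tensor as a complete coordinate into A. To get the same result with tf.gather_nd, the indices tensor would have to look like this: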
import tensorflow as tf
A = tf.constant([[
[10,20,30], [100,200,300], [1000,2000,3000]],
[[50,60,70], [500,600,700], [5000,6000,7000]]])
I = tf.constant([[[
[0,0,0],
[0,1,1],
[0,0,2],
],[
[0,1,0],
[0,2,1],
[0,1,2],
]],
[[
[1,2,0],
[1,1,1],
[1,2,2],
],
[
[1,1,0],
[1,0,1],
[1,1,2],
]]])
print(tf.gather_nd(A, I))
tf.Tensor(
[[[ 10 200 30]
[ 100 2000 300]]
[[5000 600 7000]
[ 500 60 700]]], shape=(2, 2, 3), dtype=int32)
So the real question is how you compute your indices: are they always hard-coded? Also, check out this post on the differences between the two operations.
As for the post you linked that didn't work for you: you only need to cast the indices to tf.int64 so that they match the int64 coordinates returned by tf.where, and everything should be fine:
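If the indices arrive in torch layout, the full-coordinate tensor shown above does not have to be written by hand. A minimal sketch, assuming static shapes: tf.meshgrid supplies every element's own coordinates, and the entries of I replace the coordinate along the gather axis. to_gather_nd_indices is a hypothetical helper name.

```python
import tensorflow as tf


def to_gather_nd_indices(indices, axis):
    """Hypothetical helper: turn torch.gather-style indices into tf.gather_nd coordinates."""
    # One coordinate grid per dimension, each with the same shape as `indices`.
    grids = list(tf.meshgrid(*[tf.range(n, dtype=indices.dtype) for n in indices.shape],
                             indexing='ij'))
    # Along the gather axis, use the stored index instead of the element's position.
    grids[axis] = indices
    # Shape: indices.shape + [rank], one full (d0, ..., dn) coordinate per element.
    return tf.stack(grids, axis=-1)


A = tf.constant([[[10, 20, 30], [100, 200, 300], [1000, 2000, 3000]],
                 [[50, 60, 70], [500, 600, 700], [5000, 6000, 7000]]])
I = tf.constant([[[0, 1, 0], [1, 2, 1]],
                 [[2, 1, 2], [1, 0, 1]]])
I_full = to_gather_nd_indices(I, axis=1)
print(tuple(I_full.shape))  # (2, 2, 3, 3)
print(tf.gather_nd(A, I_full).numpy().tolist())
# [[[10, 200, 30], [100, 2000, 300]], [[5000, 600, 7000], [500, 60, 700]]]
```

I_full reproduces the hard-coded tensor above element for element.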
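import tensorflow as tf

def torch_gather(x, indices, gather_axis):
    # Coordinates of every element of `indices`, in row-major order (int64).
    all_indices = tf.where(tf.fill(indices.shape, True))
    # The index values themselves, flattened in the same order.
    gather_locations = tf.reshape(indices, [indices.shape.num_elements()])

    gather_indices = []
    for axis in range(len(indices.shape)):
        if axis == gather_axis:
            # Along the gather axis, use the values stored in `indices`.
            gather_indices.append(tf.cast(gather_locations, dtype=tf.int64))
        else:
            # Along every other axis, keep the element's own coordinate.
            gather_indices.append(tf.cast(all_indices[:, axis], dtype=tf.int64))

    # Full coordinates for tf.gather_nd, shape [num_elements, rank].
    gather_indices = tf.stack(gather_indices, axis=-1)
    gathered = tf.gather_nd(x, gather_indices)
    reshaped = tf.reshape(gathered, indices.shape)
    return reshaped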
I = tf.constant([[[0,1,0], [1,2,1]],
[[2,1,2], [1,0,1]]])
A = tf.constant([[
[10,20,30], [100,200,300], [1000,2000,3000]],
[[50,60,70], [500,600,700], [5000,6000,7000]]])
print(torch_gather(A, I, 1))
tf.Tensor(
[[[ 10 200 30]
[ 100 2000 300]]
[[5000 600 7000]
[ 500 60 700]]], shape=(2, 2, 3), dtype=int32)
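The cast is needed because tf.where always returns int64 coordinates, while the example indices created with tf.constant are int32, and tf.stack expects all of its inputs to share one dtype. A quick check of the dtypes involved (variable names are illustrative):

```python
import tensorflow as tf

I = tf.constant([[[0, 1, 0], [1, 2, 1]],
                 [[2, 1, 2], [1, 0, 1]]])
# Coordinates of every element of I, as returned by tf.where.
coords = tf.where(tf.fill(tf.shape(I), True))
print(I.dtype.name, coords.dtype.name)  # int32 int64

# After casting the flattened indices to int64, both pieces can be stacked.
stacked = tf.stack([tf.cast(tf.reshape(I, [-1]), tf.int64), coords[:, 0]], axis=-1)
print(tuple(stacked.shape))  # (12, 2)
```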
Solution 2:[2]
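You could also try the following as an equivalent to torch.gather. It computes flat offsets with NumPy and gathers from the flattened tensor, so it requires eager execution: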
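import numpy as np
import tensorflow as tf
import torch

# torch.gather equivalent (eager mode only: the indices are processed with NumPy)
def tf_gather(x: tf.Tensor, indices: tf.Tensor, axis: int) -> tf.Tensor:
    # Coordinates of every element of `indices` (the condition is always true),
    # returned as an array of shape [rank, num_elements].
    complete_indices = np.array(np.where(indices > -1))
    # Along `axis`, replace each element's own coordinate with its index value.
    complete_indices[axis] = tf.reshape(indices, [-1])
    # Convert the full coordinates into offsets into the flattened `x`.
    flat_ind = np.ravel_multi_index(tuple(complete_indices), tuple(x.shape))
    return tf.reshape(tf.gather(tf.reshape(x, [-1]), flat_ind), indices.shape)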
# ======= test program ========
if __name__ == '__main__':
    a = np.random.rand(2, 5, 3, 4)
    dim = 2  # 0 <= dim < len(a.shape)
    ind = np.expand_dims(np.argmax(a, axis=dim), axis=dim)

    # ========== np: groundtruth ==========
    np_max = np.expand_dims(np.max(a, axis=dim), axis=dim)
    # ========= torch: gather =========
    torch_max = torch.gather(torch.tensor(a), dim=dim, index=torch.tensor(ind))
    # ========= tensorflow: torch-like gather =========
    tf_max = tf_gather(tf.convert_to_tensor(a), axis=dim, indices=tf.convert_to_tensor(ind))

    keepdim = False
    if not keepdim:
        np_max = np.squeeze(np_max, axis=dim)
        torch_max = torch.squeeze(torch_max, dim=dim)
        tf_max = tf.squeeze(tf_max, axis=dim)

    # print('np_max:\n', np_max)
    # print('torch_max:\n', torch_max)
    # print('tf_max:\n', tf_max)
    assert np.allclose(np_max, torch_max.numpy()), '\33[1m\33[31mError with torch\33[0m'
    assert np.allclose(np_max, tf_max.numpy()), '\33[1m\33[31mError with tensorflow\33[0m'
    print('\33[1m\33[32mSuccess!\33[0m')
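tf_gather relies on np.ravel_multi_index turning every (i, j, k) coordinate into an offset into the flattened tensor; for a row-major shape (d0, d1, d2) the offset is i*d1*d2 + j*d2 + k. As a worked example on the question's A of shape (2, 3, 3): the second sample's first output row uses the indices [2, 1, 2] along dim 1, which gives the coordinates (1, 2, 0), (1, 1, 1) and (1, 2, 2):

```python
import numpy as np

A = np.array([[[10, 20, 30], [100, 200, 300], [1000, 2000, 3000]],
              [[50, 60, 70], [500, 600, 700], [5000, 6000, 7000]]])
shape = A.shape  # (2, 3, 3)

# Coordinates as three parallel arrays (i, j, k); j comes from the index tensor.
coords = (np.array([1, 1, 1]), np.array([2, 1, 2]), np.array([0, 1, 2]))
flat = np.ravel_multi_index(coords, shape)
print(flat.tolist())                 # [15, 13, 17]
print(A.reshape(-1)[flat].tolist())  # [5000, 600, 7000]

# The same offsets by hand: i * 3 * 3 + j * 3 + k
manual = coords[0] * shape[1] * shape[2] + coords[1] * shape[2] + coords[2]
```

The gathered values match the row [5000, 600, 7000] returned by torch.gather.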
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | |
| Solution 2 | Slifer |
