Issue
I am trying to translate PyTorch code to TensorFlow. Is there a TensorFlow equivalent of PyTorch's index_select function?
Solution
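I haven't found an API that achieves this directly, but it can be implemented with tf.slice: take a one-element slice at each selected index along dim, then concatenate the slices back together along the same dim.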
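import tensorflow as tf

def tf_index_select(input_, dim, indices):
    """
    input_ (tensor): input tensor (its rank must be known)
    dim (int): dimension to select along; negative values count from the end
    indices (list of int): indices to select along dim
    """
    rank = len(input_.get_shape().as_list())
    if dim < 0:
        # Handle any negative dim, not only -1.
        dim += rank
    # Size -1 takes everything remaining on an axis (also works for unknown
    # dimensions); on dim we take exactly one element per slice.
    size = [-1] * rank
    size[dim] = 1
    tmp = []
    for idx in indices:
        begin = [0] * rank
        begin[dim] = idx
        tmp.append(tf.slice(input_, begin, size))
    # Join the single-index slices back together along dim.
    res = tf.concat(tmp, axis=dim)
    return res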
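To see what a single iteration of the loop produces: tf.slice(input_, begin, size) starts at begin and takes size elements on each axis, with -1 meaning "everything remaining on that axis". A small sketch, assuming TensorFlow 2.x with eager execution, that extracts index 2 along axis 1 and compares it with the equivalent NumPy slice:

```python
import numpy as np
import tensorflow as tf

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
x = tf.constant(a)

# begin picks index 2 on axis 1; size -1 keeps the full extent of the other axes,
# size 1 on axis 1 takes exactly one element there.
piece = tf.slice(x, begin=[0, 2, 0], size=[-1, 1, -1])

# Same selection as the NumPy slice a[:, 2:3, :]; the length-1 axis is kept.
print(np.array_equal(piece.numpy(), a[:, 2:3, :]))  # True
print(piece.numpy().shape)  # (2, 1, 4)
```

Because each piece keeps a length-1 axis at dim, the pieces can be joined with tf.concat along dim to form the final result.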
Here is an example to show the equivalence.
import tensorflow as tf
import torch
import numpy as np

a = np.arange(2*3*4).reshape(2, 3, 4)
dim = 1
indices = [0, 2]
# a:
# array([[[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]],
#
#        [[12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23]]])

# pytorch
res_torch = torch.tensor(a).index_select(dim, torch.tensor(indices))
# tensor([[[ 0,  1,  2,  3],
#          [ 8,  9, 10, 11]],
#
#         [[12, 13, 14, 15],
#          [20, 21, 22, 23]]])

# tensorflow
res_tf = tf_index_select(tf.constant(a), dim, indices)
# a tf.Tensor of shape (2, 2, 4) with the same values as the PyTorch result

# check that both frameworks agree element-wise
np.testing.assert_array_equal(res_tf.numpy(), res_torch.numpy())
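Beyond the original answer: TensorFlow also provides tf.gather(params, indices, axis=...), which selects the given positions along one axis in a single op and, as far as we can tell, matches index_select for 1-D integer indices. A minimal sketch, assuming TensorFlow 2.x with eager execution; NumPy's np.take is used only as an independent reference:

```python
import numpy as np
import tensorflow as tf

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
dim = 1
indices = [0, 2]

# tf.gather picks the positions in `indices` along `axis`, like index_select(dim, ...)
res_gather = tf.gather(tf.constant(a), indices, axis=dim)

# Compare with NumPy's take along the same axis.
print(np.array_equal(res_gather.numpy(), np.take(a, indices, axis=dim)))  # True
```

Since tf.gather also accepts its indices as a tensor, this route should also suit cases where the indices are only known at run time, whereas the slice-based helper as written iterates over a Python list.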
Answered By - zihaozhihao