Issue
I'm trying to extract all possible permutations from a tensor along a specific axis. My input is a [B, S, L] tensor (B batches of S vectors of length L), and I want to extract all possible permutations of these vectors (the S! permutations), i.e. a [B, S!, S, L] tensor as output.
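To make the target concrete, here is a small NumPy reference for the desired output. It is not part of the original question, and the toy sizes are arbitrary assumptions. Indexing axis 1 with the [S!, S] array of permutations gives out[b, p] = x[b, perms[p]], which is the [B, S!, S, L] result being asked for.

```python
import numpy as np
from itertools import permutations

# Toy sizes, chosen arbitrarily for illustration: B batches of S vectors of length L.
B, S, L = 2, 3, 4
x = np.random.randn(B, S, L)

# All S! orderings of 0..S-1 as an int array of shape [S!, S].
perms = np.array(list(permutations(range(S))))

# Fancy indexing on axis 1: axis 1 (size S) is replaced by the shape of perms,
# so out[b, p] = x[b, perms[p]] and out has shape [B, S!, S, L].
out = x[:, perms]
print(out.shape)  # (2, 6, 3, 4)
```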
This is what I have tried so far, but I'm struggling to get the right output shape. I think my mistake might be that I'm creating a batch_range but should create a permutation_range as well.
import tensorflow as tf
import numpy as np
from itertools import permutations
S = 3
B = 5
L = 10
input = tf.constant(np.random.randn(B, S, L))
perms = list(permutations(range(S)))  # with S = 3: (0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)
length_perm = len(perms)
perms = tf.reshape(tf.constant(perms), [1, length_perm, S, 1])
perms = tf.tile(perms, [B, 1, 1, 1])
batch_range = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), shape=[B, 1, 1, 1]), [1, length_perm, S, 1])
indices = tf.concat([batch_range, perms], axis=3)  # shape [B, P, S, 2]
permuted = tf.gather_nd(tf.tile(tf.reshape(input, [B, 1, S, L]), [1, length_perm, 1, 1]), indices)
# I get a [B, P, S, S, L] tensor instead of the desired [B, P, S, L]
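The extra S axis comes from how tf.gather_nd reads its indices: the last axis of `indices` is a coordinate into the leading axes of `params`, and any axes that are not indexed are copied whole, so the result shape is indices.shape[:-1] + params.shape[indices.shape[-1]:]. With depth-2 indices of the form (batch, perm entry), the perm entry lands on the permutation axis of the tiled [B, P, S, L] tensor and each lookup returns a full [S, L] slice. The sketch below uses a hypothetical NumPy helper, `gather_nd_np`, that mimics this rule for the default batch_dims=0 case, only to compare the shapes produced by depth-2 and depth-3 indices.

```python
import numpy as np

def gather_nd_np(params, indices):
    """Hypothetical NumPy emulation of tf.gather_nd with batch_dims=0 (a sketch, not TF's code).

    The last axis of `indices` holds coordinates into the leading axes of `params`;
    the result shape is indices.shape[:-1] + params.shape[indices.shape[-1]:].
    """
    coords = tuple(np.moveaxis(indices, -1, 0))  # one index array per indexed axis
    return params[coords]

B, P, S, L = 5, 6, 3, 10
params = np.zeros((B, P, S, L))                 # stands in for the tiled input [B, P, S, L]
two_deep = np.zeros((B, P, S, 2), dtype=int)    # (batch, perm entry): only the shape matters
three_deep = np.zeros((B, P, S, 3), dtype=int)  # (batch, perm, perm entry)

print(gather_nd_np(params, two_deep).shape)    # (5, 6, 3, 3, 10)
print(gather_nd_np(params, three_deep).shape)  # (5, 6, 3, 10)
```

Adding the permutation index as the middle coordinate makes each lookup address a single length-L vector, which is what removes the extra axis.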
I posted one possible "solution" below, but I think there is still a problem with it: when I tested it with B > 1, it did not seem to work properly.
Solution
I think I have just found an answer; please correct me if I'm wrong or if there is an easier way to do this:
import tensorflow as tf
import numpy as np
from itertools import permutations
S = 3
B = 5
L = 10
input = tf.constant(np.random.randn(B, S, L))
perms = list(permutations(range(S)))  # with S = 3: (0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)
length_perm = len(perms)
perms = tf.reshape(tf.constant(perms), [1, length_perm, S, 1])
perms = tf.tile(perms, [B, 1, 1, 1])
batch_range = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), shape=[B, 1, 1, 1]), [1, length_perm, S, 1])
perm_range = tf.tile(tf.reshape(tf.range(length_perm, dtype=tf.int32), shape=[1, length_perm, 1, 1]), [B, 1, S, 1])
indices = tf.concat([batch_range, perm_range, perms], axis=3)  # shape [B, P, S, 3]
permuted = tf.gather_nd(tf.tile(tf.reshape(input, [B, 1, S, L]), [1, length_perm, 1, 1]), indices)  # shape [B, P, S, L]
print(permuted)
Answered By - Anthony D'Amato
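As a possibly simpler alternative (not from the original answer), tf.gather with axis=1 can take the whole [S!, S] permutation table as its indices, which avoids building the batch and permutation ranges by hand. This sketch assumes TensorFlow 2 with eager execution (needed for `.numpy()`) and cross-checks the result against the NumPy fancy-indexing equivalent.

```python
import numpy as np
import tensorflow as tf
from itertools import permutations

B, S, L = 5, 3, 10
x = tf.constant(np.random.randn(B, S, L))

# [S!, S] table of index orderings, int32 for tf.gather.
perms = tf.constant(list(permutations(range(S))), dtype=tf.int32)

# tf.gather along axis=1 replaces axis 1 of x (size S) with the full shape of
# `perms`, so the output shape is [B] + [S!, S] + [L] = [B, S!, S, L].
permuted = tf.gather(x, perms, axis=1)
print(permuted.shape.as_list())  # [5, 6, 3, 10]

# Cross-check against NumPy: x[:, perms] performs the same gather.
expected = x.numpy()[:, perms.numpy()]
print(np.array_equal(permuted.numpy(), expected))  # True
```

With the default batch_dims=0, the same permutation table is applied to every batch element, so B > 1 needs no special handling here.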