Issue
I want to use a DenseHashTable to look up string tensors, just like in this answer. The keys have type tf.string and the values are embeddings with dtype tf.float32. But when the keys are multi-dimensional, an error occurs.
import tensorflow as tf

keys = ["Fritz", "Franz", "Fred"]
values = [[1, 2, 3, -1], [4, 5, -1, -1], [6, 7, 8, 9]]
table = tf.lookup.experimental.DenseHashTable(key_dtype=tf.string, value_dtype=tf.float32, empty_key="0", deleted_key="-1", default_value=[-1, -1, -1, -1])
table.insert(keys, values)
table.lookup(['Franz', 'Emil'])  # keys of shape (2,): works, result has shape (2, 4)
table.lookup([['Franz', 'Emil'], ['Emil', 'Fred']])  # keys of shape (2, 2), i.e. (batch_size, 2): raises an error
How can I make it work just like tf.nn.embedding_lookup, where the keys are not array indices but tf.string values?
Solution
The problem is that DenseHashTable.lookup expects a flat (1-D) list of keys, not a nested list of keys. Granted, the phrase "Can be a tensor of any shape." in the keys description in the docs is a bit confusing.
What you can do is flatten your keys, look them up, and reshape the result afterwards:
keys = [['Franz', 'Emil'], ['Emil', 'Fred']]
keys = tf.convert_to_tensor(keys)  # convert to a tensor to get its shape
key_shape = tf.shape(keys)  # dynamic shape: [2, 2]
x = table.lookup(tf.reshape(keys, [-1]))  # flatten to shape (4,); result has shape (4, 4)
x = tf.reshape(x, tf.concat([key_shape, tf.shape(x)[-1:]], axis=0))  # shape: (2, 2, 4)
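The same flatten, look up, reshape steps can be wrapped in a small reusable function. This is a minimal sketch: the name `lookup_nd` is hypothetical (not a TensorFlow API), and it assumes every value in the table has the same fixed shape (here, length-4 vectors). It builds the output shape with `tf.shape` instead of the static `.shape`, so it should also work when the batch size is unknown, for example inside a `tf.function`.

```python
import tensorflow as tf


def lookup_nd(table, keys):
    """Hypothetical helper: look up string keys of any rank in a table whose
    lookup only accepts 1-D keys. Returns shape keys.shape + value_shape."""
    keys = tf.convert_to_tensor(keys, dtype=tf.string)
    # Flatten to 1-D, which DenseHashTable.lookup accepts.
    flat_values = table.lookup(tf.reshape(keys, [-1]))  # (num_keys, value_dim)
    # Original key shape followed by the per-key value shape.
    out_shape = tf.concat([tf.shape(keys), tf.shape(flat_values)[1:]], axis=0)
    return tf.reshape(flat_values, out_shape)


table = tf.lookup.experimental.DenseHashTable(
    key_dtype=tf.string, value_dtype=tf.float32,
    empty_key="0", deleted_key="-1", default_value=[-1, -1, -1, -1])
table.insert(["Fritz", "Franz", "Fred"],
             [[1, 2, 3, -1], [4, 5, -1, -1], [6, 7, 8, 9]])

emb = lookup_nd(table, [["Franz", "Emil"], ["Emil", "Fred"]])
print(emb.shape)  # (2, 2, 4); "Emil" gets the default value
```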
Answered By - mhenning
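A different approach, not the one from the answer above, gives `tf.nn.embedding_lookup` semantics directly: map each string to a row index with `tf.lookup.StaticHashTable`, whose lookup returns scalar ids with the same shape as the keys, then gather rows from an embedding matrix. This is a sketch under assumptions: the vocabulary is known up front, and an extra last row of the matrix (the hypothetical "unknown" row) serves as the default for missing keys. Unlike `DenseHashTable`, a `StaticHashTable` cannot be modified after it is built.

```python
import tensorflow as tf

vocab = ["Fritz", "Franz", "Fred"]
# Row i is the embedding of vocab[i]; the extra last row is an assumed
# "unknown key" embedding used as the default.
embeddings = tf.constant(
    [[1, 2, 3, -1], [4, 5, -1, -1], [6, 7, 8, 9], [-1, -1, -1, -1]],
    dtype=tf.float32)

# string -> row index; strings not in vocab map to the last row.
key_to_id = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant(vocab), tf.range(len(vocab), dtype=tf.int64)),
    default_value=len(vocab))

query = tf.constant([["Franz", "Emil"], ["Emil", "Fred"]])
ids = key_to_id.lookup(query)                  # int64 ids, shape (2, 2)
out = tf.nn.embedding_lookup(embeddings, ids)  # shape (2, 2, 4)
```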