Issue
tf.math.argmax returns the index of the maximum value in a tensor.
import tensorflow as tf
a = tf.constant([1,2,3])
print(a)
print(tf.math.argmax(input = a))
output:
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
<tf.Tensor: shape=(), dtype=int64, numpy=2>
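For a 2-D tensor the reduction axis matters. A minimal sketch, where the matrix `m` is a made-up example value: with the default axis (None, treated as 0) you get one index per column, and with axis=1 you get one index per row. The output_type parameter switches the index dtype from the default int64 to int32.

```python
import tensorflow as tf

# Made-up 2-D example input: two rows of three elements.
m = tf.constant([[1, 2, 3],
                 [4, 5, 6]])

# Default axis (None is treated as 0): index of the max in each column.
per_column = tf.math.argmax(m)
# axis=1: index of the max in each row.
per_row = tf.math.argmax(m, axis=1)
# output_type controls the dtype of the returned indices (int64 by default).
per_row_int32 = tf.math.argmax(m, axis=1, output_type=tf.int32)

print(per_column.numpy())    # [1 1 1]
print(per_row.numpy())       # [2 2]
print(per_row_int32.dtype)   # <dtype: 'int32'>
```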
I want to apply the tf.math.argmax function to each tensor in a list of tensors. How can I do it?
input = tf.constant([1,2,3,4,5,6])
split_sequence = tf.split(input, num_or_size_splits=2, axis=-1)
print(split_sequence)
print(input)
tf.math.argmax(input = split_sequence)
output:
[<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>]
tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32)
<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 1, 1])>
It gives the wrong indices: numpy=array([1, 1, 1])
desired output:
numpy=array([[2], [2]])
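Why the call above returns [1, 1, 1]: when tf.math.argmax receives a Python list of equal-length tensors, it first packs them into a single tensor of shape (2, 3), the same result as tf.stack along axis 0, and then reduces along its default axis 0. Each of the three output positions therefore says which of the two chunks holds the larger value. A small check using the question's values, with the variable renamed to inp:

```python
import tensorflow as tf

inp = tf.constant([1, 2, 3, 4, 5, 6])
split_sequence = tf.split(inp, num_or_size_splits=2, axis=-1)

# Passing the list directly: TensorFlow packs it into a (2, 3) tensor first.
from_list = tf.math.argmax(split_sequence)

# The same computation with the packing made explicit.
stacked = tf.stack(split_sequence, axis=0)    # shape (2, 3)
from_stack = tf.math.argmax(stacked, axis=0)  # reduce across the two chunks

print(stacked.shape)        # (2, 3)
print(from_list.numpy())    # [1 1 1]
print(from_stack.numpy())   # [1 1 1]
```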
Solution
You can use map to apply a function to each value in the list. (It's better not to use a Python built-in function name as a variable, so I changed input to inp.)
import tensorflow as tf
inp = tf.constant([1,2,3,4,5,6])
split_sequence = tf.split(inp, num_or_size_splits=2, axis=-1)
print(split_sequence)
# Apply argmax to each tensor and wrap each index in a one-element list
result = list(map(lambda x: [tf.math.argmax(x).numpy()], split_sequence))
print(result)
Or, thanks to @jkr, we can use a list comprehension too. (For which one is better, see the discussion of map vs. list comprehension.)
>>> [[tf.math.argmax(item).numpy()] for item in split_sequence]
[[2], [2]]
Output of the map version above:
[
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>,
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>
]
[[2], [2]]
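One advantage of applying argmax to each element, as map and the list comprehension do, is that it also works when the pieces have different lengths. Such pieces cannot be packed into one rectangular tensor, so passing the whole list to tf.math.argmax directly would fail. A sketch with hypothetical split sizes [2, 4]:

```python
import tensorflow as tf

inp = tf.constant([1, 2, 3, 4, 5, 6])
# Hypothetical uneven split: pieces of length 2 and 4.
uneven = tf.split(inp, num_or_size_splits=[2, 4], axis=-1)

# Per-piece argmax works regardless of each piece's length;
# int() turns the NumPy scalar into a plain Python int.
result = [[int(tf.math.argmax(item).numpy())] for item in uneven]
print(result)  # [[1], [3]]
```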
Benchmark (on Colab):
import tensorflow as tf
# 6,000,000 elements, split into 20 tensors of 300,000 elements each
inp = tf.constant([1,2,3,4,5,6]*1_000_000)
split_sequence = tf.split(inp, num_or_size_splits=20, axis=-1)
%timeit tf.math.top_k(split_sequence, k=1).indices
# 13.5 ms ± 394 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit list(map(lambda x: [tf.math.argmax(x).numpy()] , split_sequence))
# 14 ms ± 2.39 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit [[tf.math.argmax(item).numpy()] for item in split_sequence]
# 8.77 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
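The first benchmarked line, tf.math.top_k(split_sequence, k=1).indices, is a vectorized alternative that the answer does not otherwise explain: the list is packed into one tensor, top_k works along the last axis, and with k=1 it returns the position of the largest value in each piece as an int32 tensor of shape (number_of_pieces, 1), which already matches the desired [[2], [2]] layout. A sketch on the question's small input; the argmax(axis=1) plus tf.expand_dims variant is added here only for comparison and was not part of the benchmark:

```python
import tensorflow as tf

inp = tf.constant([1, 2, 3, 4, 5, 6])
split_sequence = tf.split(inp, num_or_size_splits=2, axis=-1)

# top_k reduces along the last axis of the packed (2, 3) tensor;
# k=1 keeps only the position of the largest value in each row.
top1 = tf.math.top_k(split_sequence, k=1).indices  # int32, shape (2, 1)

# Comparison variant: argmax per row, then add a trailing axis of size 1.
per_row = tf.math.argmax(tf.stack(split_sequence), axis=1, output_type=tf.int32)
per_row = tf.expand_dims(per_row, axis=-1)         # int32, shape (2, 1)

print(top1.numpy().tolist())     # [[2], [2]]
print(per_row.numpy().tolist())  # [[2], [2]]
```

Both forms require equal-length pieces. On ties (as in the benchmark data, where the maximum repeats), top_k documents that the lower index comes first, while tf.math.argmax does not guarantee which tied index it returns.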
Answered By - I'mahdi