Issue
I would like tf.math.bincount to give the max/min of the weights in each bin instead of their sum. Currently it works like this:
import tensorflow as tf
values = tf.constant([1, 1, 2, 3, 2, 4, 4, 5])
weights = tf.constant([1, 5, 0, 1, 0, 5, 4, 5])
tf.math.bincount(values, weights=weights)  # [0 6 0 1 9 5]
However, I would like the max/min of the weights that fall into the same bin instead, e.g. for max it should return:
[0 5 0 1 5 5]
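To pin down the intended semantics independently of TensorFlow, here is a minimal pure-Python reference. The function name bincount_reduce_reference is hypothetical; it assumes non-negative integer values and, like the expected output above, reports empty bins as 0. Passing sum reproduces the current tf.math.bincount result.

```python
from typing import Callable, List, Sequence


def bincount_reduce_reference(
    values: Sequence[int],
    weights: Sequence[int],
    reduce: Callable[[List[int]], int] = max,
) -> List[int]:
    """Hypothetical reference: reduce the weights that land in each bin.

    Produces bins 0..max(values), like tf.math.bincount without minlength.
    Empty bins are reported as 0, matching the expected output in the question.
    """
    num_bins = max(values) + 1
    # Collect the weights that fall into each bin index.
    buckets: List[List[int]] = [[] for _ in range(num_bins)]
    for v, w in zip(values, weights):
        buckets[v].append(w)
    # Reduce non-empty bins; empty bins default to 0.
    return [reduce(b) if b else 0 for b in buckets]


values = [1, 1, 2, 3, 2, 4, 4, 5]
weights = [1, 5, 0, 1, 0, 5, 4, 5]
print(bincount_reduce_reference(values, weights, sum))  # [0, 6, 0, 1, 9, 5]
print(bincount_reduce_reference(values, weights, max))  # [0, 5, 0, 1, 5, 5]
print(bincount_reduce_reference(values, weights, min))  # [0, 1, 0, 1, 4, 5]
```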
Solution
It requires some finessing, but you can accomplish this as follows:
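import tensorflow as tf
def bincount_with_max_weight(values: tf.Tensor, weights: tf.Tensor) -> tf.Tensor:
    # One bin per integer from 0 to max(values), as in tf.math.bincount.
    _range = tf.range(tf.reduce_max(values) + 1)
    # Per bin: max of the weights where values == x, clamped to 0 for empty bins.
    return tf.map_fn(lambda x: tf.maximum(
        tf.reduce_max(tf.gather(weights, tf.where(tf.equal(values, x)))), 0), _range)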
The output for the example case is:
[0 5 0 1 5 5]
Breaking it down, the first line of the function body computes the range of bin indices, from 0 up to the largest entry in values:
_range = tf.range(tf.reduce_max(values) + 1)
The second line then computes the maximum of weights for each element of _range using tf.map_fn, together with tf.where, which retrieves the indices where a condition is true, and tf.gather, which retrieves the values at the supplied indices.
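The per-bin step can be mirrored in NumPy to make each piece visible. This is an analogy, not the TensorFlow code itself: np.flatnonzero stands in for tf.where, fancy indexing for tf.gather, and the list comprehension for tf.map_fn. The helper name max_weight_for is hypothetical, and the explicit INT32_MIN starting value is an assumption chosen to mirror what tf.reduce_max returns for an empty int32 tensor.

```python
import numpy as np

values = np.array([1, 1, 2, 3, 2, 4, 4, 5], dtype=np.int32)
weights = np.array([1, 5, 0, 1, 0, 5, 4, 5], dtype=np.int32)
INT32_MIN = np.iinfo(np.int32).min


def max_weight_for(x: int) -> int:
    # Stand-in for tf.where: positions where values == x.
    idx = np.flatnonzero(values == x)
    # Stand-in for tf.gather: the weights at those positions.
    gathered = weights[idx]
    # np.max raises on an empty array, so start from INT32_MIN,
    # mirroring tf.reduce_max on an empty int32 tensor.
    return int(np.max(gathered, initial=INT32_MIN))


# The loop that tf.map_fn performs over _range.
per_bin = [max_weight_for(x) for x in range(int(values.max()) + 1)]
print(max_weight_for(1))  # 5  (positions [0, 1] -> weights [1, 5])
print(per_bin)            # [-2147483648, 5, 0, 1, 5, 5]
```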
The tf.maximum wraps the output to handle elements that do not exist in values. In the example, 0 does not exist in values, so tf.reduce_max runs over an empty tensor and returns INT_MIN for that bin; without tf.maximum the output would be:
[-2147483648 5 0 1 5 5]
The tf.maximum can also be applied once to the final result tensor instead of per element:
def bincount_with_max_weight(values: tf.Tensor, weights: tf.Tensor) -> tf.Tensor:
    _range = tf.range(tf.reduce_max(values) + 1)
    result = tf.map_fn(lambda x:
        tf.reduce_max(tf.gather(weights, tf.where(tf.equal(values, x)))), _range)
    return tf.maximum(result, 0)
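Note that this does not work if weights can be negative: tf.maximum(result, 0) would also replace genuine negative maxima with 0. In that case, use tf.where to compare against the minimum value of the dtype instead of applying tf.maximum (tf.int32.min in this example; other integer dtypes have an equivalent .min, while for floating-point weights tf.reduce_max of an empty tensor returns -inf):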
def bincount_with_max_weight(values: tf.Tensor, weights: tf.Tensor) -> tf.Tensor:
    _range = tf.range(tf.reduce_max(values) + 1)
    result = tf.map_fn(lambda x:
        tf.reduce_max(tf.gather(weights, tf.where(tf.equal(values, x)))), _range)
    return tf.where(tf.equal(result, tf.int32.min), 0, result)
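To see why clamping breaks with negative weights, here is a small NumPy sketch. The helper raw_bin_max and the sample data are hypothetical; raw_bin_max builds the unclamped per-bin maxima with np.maximum.at (leaving the int32 minimum in empty bins, like the map_fn result), and the two fixes are then compared.

```python
import numpy as np

INT32_MIN = np.iinfo(np.int32).min


def raw_bin_max(values: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Per-bin max that leaves INT32_MIN in empty bins, like the map_fn result."""
    out = np.full(int(values.max()) + 1, INT32_MIN, dtype=np.int32)
    # np.maximum.at is unbuffered, so repeated bin indices all contribute.
    np.maximum.at(out, values, weights)
    return out


values = np.array([1, 1, 2], dtype=np.int32)
weights = np.array([-3, -1, -2], dtype=np.int32)
raw = raw_bin_max(values, weights)           # [INT32_MIN, -1, -2]
clipped = np.maximum(raw, 0)                 # clamping loses the negative maxima
masked = np.where(raw == INT32_MIN, 0, raw)  # sentinel check only zeroes the empty bin
print(clipped.tolist(), masked.tolist())     # [0, 0, 0] [0, -1, -2]
```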
Update
To handle the 2D tensor case, we can use tf.map_fn to apply the maximum-weight function to each (values, weights) row pair in the batch:
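from typing import Optional
import tensorflow as tf

def bincount_with_max_weight(values: tf.Tensor, weights: tf.Tensor, axis: Optional[int] = None) -> tf.Tensor:
    # Bins are shared across the batch: 0..max(values) over the whole tensor.
    _range = tf.range(tf.reduce_max(values) + 1)
    def mapping_function(x: int, _values: tf.Tensor, _weights: tf.Tensor) -> tf.Tensor:
        # Max of the weights whose value equals x (int32 minimum if there are none).
        return tf.reduce_max(tf.gather(_weights, tf.where(tf.equal(_values, x))))
    if axis == -1:
        # Outer map_fn walks the rows; inner map_fn walks the bins of each row.
        result = tf.map_fn(lambda pair: tf.map_fn(lambda x: mapping_function(x, *pair), _range),
                           (values, weights), fn_output_signature=tf.int32)
    else:
        result = tf.map_fn(lambda x: mapping_function(x, values, weights), _range)
    # Replace the empty-bin sentinel with 0.
    return tf.where(tf.equal(result, tf.int32.min), 0, result)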
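As a cross-check for the axis=-1 branch, here is a NumPy reference sketch. The name batched_bin_max_reference is hypothetical, and it assumes integer weights (it uses np.iinfo for the starting value). Like the shared _range above, every row gets max(values) + 1 bins with the max taken over the whole batch; empty bins are detected with a per-bin count, so negative weights are kept.

```python
import numpy as np


def batched_bin_max_reference(values: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Per-row bin maxima for a 2-D batch of integer weights; empty bins are 0."""
    num_rows = values.shape[0]
    num_bins = int(values.max()) + 1
    # Row index of every entry, broadcast to the shape of values.
    rows = np.broadcast_to(np.arange(num_rows)[:, None], values.shape)
    maxima = np.full((num_rows, num_bins), np.iinfo(weights.dtype).min, dtype=weights.dtype)
    counts = np.zeros((num_rows, num_bins), dtype=np.int64)
    # Unbuffered scatter: repeated (row, bin) pairs all take part.
    np.maximum.at(maxima, (rows, values), weights)
    np.add.at(counts, (rows, values), 1)
    return np.where(counts > 0, maxima, 0)


values = np.array([[1, 1, 2, 3], [2, 1, 4, 5]], dtype=np.int32)
weights = np.array([[1, 5, 0, 1], [0, 5, 4, 5]], dtype=np.int32)
print(batched_bin_max_reference(values, weights).tolist())
# [[0, 5, 0, 1, 0, 0], [0, 5, 0, 0, 4, 5]]
```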
For the 2D example provided:
values = tf.constant([[1, 1, 2, 3], [2, 1, 4, 5]])
weights = tf.constant([[1, 5, 0, 1], [0, 5, 4, 5]])
print(bincount_with_max_weight(values, weights, axis=-1))
The output is:
tf.Tensor(
[[0 5 0 1 0 0]
[0 5 0 0 4 5]], shape=(2, 6), dtype=int32)
This implementation generalizes the approach described originally: if axis is omitted, it computes the 1D result. Only axis=-1 selects the per-row branch; any other value falls back to the 1D computation.
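As an alternative not used in the answer above, TensorFlow's segment reductions can fill every bin in one op instead of looping with tf.map_fn. The sketch below (the name bincount_segment_max is hypothetical) assumes non-negative integer values. It uses tf.math.unsorted_segment_max, shifts each row of a 2D batch into its own block of segment ids, and masks empty bins with a per-bin count from tf.math.unsorted_segment_sum rather than a sentinel, so negative weights survive. For the minimum, swapping in tf.math.unsorted_segment_min should work the same way.

```python
import tensorflow as tf


def bincount_segment_max(values: tf.Tensor, weights: tf.Tensor) -> tf.Tensor:
    """Hypothetical vectorized per-bin max of weights for 1-D or 2-D (batched) input."""
    if values.shape.rank == 1:
        # Treat a 1-D input as a batch containing a single row.
        return bincount_segment_max(values[tf.newaxis], weights[tf.newaxis])[0]
    num_bins = tf.reduce_max(values) + 1
    num_rows = tf.shape(values)[0]
    # Row r uses segment ids [r * num_bins, (r + 1) * num_bins), so rows never share a bin.
    ids = tf.reshape(values + tf.range(num_rows)[:, tf.newaxis] * num_bins, [-1])
    flat_weights = tf.reshape(weights, [-1])
    num_segments = num_rows * num_bins
    # One op reduces all weights that share a segment id.
    maxima = tf.math.unsorted_segment_max(flat_weights, ids, num_segments)
    # Count entries per segment so empty bins become 0 without touching negative maxima.
    counts = tf.math.unsorted_segment_sum(tf.ones_like(flat_weights), ids, num_segments)
    result = tf.where(counts > 0, maxima, tf.zeros_like(maxima))
    return tf.reshape(result, [num_rows, num_bins])


print(bincount_segment_max(tf.constant([1, 1, 2, 3, 2, 4, 4, 5]),
                           tf.constant([1, 5, 0, 1, 0, 5, 4, 5])))
print(bincount_segment_max(tf.constant([[1, 1, 2, 3], [2, 1, 4, 5]]),
                           tf.constant([[1, 5, 0, 1], [0, 5, 4, 5]])))
```

Compared with nested tf.map_fn calls, this keeps the work in a few batched ops, at the cost of building num_rows * num_bins segments.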
Answered By - danielcahall