Issue
I have two tensors with the same size:
a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b = [0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
Tensor a has three regions, demarcated by runs of consecutive values: region 1 is [1, 2, 3, 4, 5], region 2 is [10, 11, 12, 13], and region 3 is [20, 21, 22, 23, 24, 25, 26, 27, 28].
For each of those regions, I want to apply the following logic: whenever a value of b is 1, the following i values in the same region are set to 0 (values that are already 0 stay 0). Once those i values have been processed, nothing happens until another value of b is 1, at which point the next i values are forced to 0, and so on.
Some examples:
# i = 1
a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1]
# i = 2
a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]
# i = 4
a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]
Not sure if this would help, but I was able to separate the regions into segments by doing:
a_shifted = tf.roll(a - 1, shift=-1, axis=0)
a_shifted_segs = tf.math.cumsum(tf.cast(a_shifted != a, dtype=tf.int64), exclusive=True)
# a_shifted_segs = [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2]
Do you know any way of doing this efficiently?
Solution
Here is a pure TensorFlow approach, which will work in both Eager Execution and Graph mode:
# copy, paste, acknowledge
import tensorflow as tf
def split_regions_and_modify(a, b, i):
    # Start index of every region after the first (where the run of consecutive values breaks).
    indices = tf.squeeze(tf.where(a[:-1] != a[1:] - 1), axis=-1) + 1
    # Exclusive end index of every region; the last one is always len(a).
    # Casting before concatenating keeps the dtypes consistent (tf.where returns int64).
    row_splits = tf.concat([tf.cast(indices, tf.int32), tf.shape(a)[:1]], axis=0)

    def body(i, j, k, tensor, row_splits):
        # Move on to the next region once j reaches the end of the current one.
        k = tf.cond(tf.equal(row_splits[k], j), lambda: tf.add(k, 1), lambda: k)
        # Up to i positions after j, clipped at the end of the current region.
        current_indices = tf.range(j + 1, tf.minimum(j + 1 + i, row_splits[k]), dtype=tf.int32)
        # If position j still holds a 1, zero out the positions that follow it.
        tensor = tf.cond(
            tf.logical_and(tf.equal(tensor[j], 1), tf.not_equal(j, row_splits[k])),
            lambda: tf.tensor_scatter_nd_update(tensor, current_indices[..., None], tf.zeros_like(current_indices)),
            lambda: tensor)
        return i, tf.add(j, 1), k, tensor, row_splits

    j0 = tf.constant(0)
    k0 = tf.constant(0)
    c = lambda i, j0, k0, b, row_splits: tf.logical_and(tf.less(j0, tf.shape(b)[0]), tf.less(k0, tf.shape(row_splits)[0]))
    _, _, _, output, _ = tf.while_loop(c, body, loop_vars=[i, j0, k0, b, row_splits])
    return output
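To make the first step of the function concrete: row_splits holds the exclusive end index of each region of a. The plain-Python sketch below computes the same boundaries without TensorFlow. The helper name region_ends is hypothetical and not part of the answer; it is only meant to illustrate what the loop compares j against.

```python
def region_ends(a):
    """Return the exclusive end index of every run of consecutive values in a."""
    # A region ends wherever the next value is not "previous value + 1".
    ends = [j for j in range(1, len(a)) if a[j] != a[j - 1] + 1]
    # The final region always ends at len(a).
    ends.append(len(a))
    return ends


a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
print(region_ends(a))  # [5, 9, 18]
```

For the example tensor this gives three regions covering indices 0-4, 5-8 and 9-17.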
Usage:
a = tf.constant([1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28])
b = tf.constant([0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1])
split_regions_and_modify(a, b, 1)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1], dtype=int32)>
split_regions_and_modify(a, b, 2)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1], dtype=int32)>
split_regions_and_modify(a, b, 4)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1], dtype=int32)>
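For cross-checking these outputs, the rule from the question can also be written as a short plain-Python reference. This is only a sketch for testing, not the TensorFlow solution itself: the function name suppress_after_ones and the skip counter are assumptions made here for illustration, and it runs on ordinary Python lists.

```python
def suppress_after_ones(a, b, i):
    """After every surviving 1 in b, force the next i entries of the same region to 0."""
    out = list(b)
    skip = 0  # number of upcoming entries that must still be forced to 0
    for j in range(len(a)):
        # A break in the consecutive values of a starts a new region,
        # and suppression does not carry over into it.
        if j > 0 and a[j] != a[j - 1] + 1:
            skip = 0
        if skip > 0:
            out[j] = 0
            skip -= 1
        elif out[j] == 1:
            skip = i
    return out


a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b = [0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
for i in (1, 2, 4):
    print(i, suppress_after_ones(a, b, i))
```

The three printed lists should match the b_new examples from the question for i = 1, 2 and 4, which makes this a convenient oracle when testing the TensorFlow version on other inputs.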
Answered By - AloneTogether