Issue
- I have a 3D tensor X of shape [7, 240, 768].
- I have another tensor mask_idx of shape [7, 240] that contains 0/False and 1/True, where 0/False means I don't want to update the value in X[i][j], and 1/True means I want to set X[i][j] = tf.zeros([768]).
I have tried using tf.where(mask_idx, tf.zeros([7, 240, 768]), X), but I get this error:
*** tensorflow.python.framework.errors_impl.InvalidArgumentError: condition [7,240], then [7,240,768], and else [7,240,768] must be broadcastable [Op:SelectV2]
Can anyone suggest the correct approach?
Solution
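TensorFlow checks whether shapes are broadcastable by aligning dimensions from right to left, so the mask's [7, 240] is compared against X's trailing [240, 768] and the check fails. One simple fix is to add a trailing dimension to the mask, i.e., make its shape (7, 240, 1), so that the size-1 axis broadcasts against the last axis of size 768: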
tf.where(mask_idx[...,None], 0, X)
Answered By - bui
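A minimal, self-contained sketch of the fix above, using random data in place of the real X and mask_idx. It assumes the mask might be stored as integers 0/1, so it is cast to bool first, because tf.where requires a boolean condition. tf.expand_dims(mask_idx, -1) is an equivalent way to add the trailing axis.

```python
import tensorflow as tf

tf.random.set_seed(0)

# Random stand-ins with the shapes from the question.
X = tf.random.normal([7, 240, 768])
# Assumption: the mask may arrive as int 0/1; tf.where needs a bool condition.
mask_idx = tf.cast(tf.random.uniform([7, 240], maxval=2, dtype=tf.int32), tf.bool)

# [7, 240] -> [7, 240, 1]; aligned from the right, the trailing 1 broadcasts against 768.
cond = mask_idx[..., None]  # same as tf.expand_dims(mask_idx, -1)

# Where cond is True, take the zero vector; otherwise keep the original row of X.
result = tf.where(cond, tf.zeros_like(X), X)

# Rows selected by the mask are now all zeros; the remaining rows are unchanged.
zeroed_rows = tf.boolean_mask(result, mask_idx)                  # shape [num_true, 768]
kept_rows = tf.boolean_mask(result, tf.logical_not(mask_idx))    # shape [num_false, 768]
print(result.shape, zeroed_rows.shape, kept_rows.shape)
```

Passing a scalar such as 0.0 instead of tf.zeros_like(X) also works, since it broadcasts to the full shape as well; zeros_like just keeps the dtype explicit.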