Issue
I have a PyTorch tensor t of shape (n, x, y), and I'd like to apply a mask such that t[n, x, y] = -inf for all y > x + k (with k being a constant).
I believe I can do this with advanced indexing, but can't figure out how.
If not, a simple way to do this is to construct such a mask with loops, holding -inf where y > x + k and 0 elsewhere (slow, but it only has to be done once and can be cached), and then do t += mask, since -inf + z == -inf for any finite z.
Is there a better way to do this?
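For reference, the loop-based approach described above might look like the following sketch. The helper name `build_additive_mask` and the sizes are hypothetical; the mask has shape (x, y) and is broadcast over the batch dimension when added to t.

```python
import torch

def build_additive_mask(x: int, y: int, k: int) -> torch.Tensor:
    """Hypothetical helper: 0 where an entry is kept, -inf where col > row + k."""
    mask = torch.zeros(x, y)
    for row in range(x):
        for col in range(y):
            if col > row + k:
                mask[row, col] = float("-inf")
    return mask

n, x, y, k = 2, 3, 4, 0
t = torch.randn(n, x, y)
mask = build_additive_mask(x, y, k)  # build once, reuse on every call
t += mask  # the (x, y) mask broadcasts over the n batch elements
```

This works, but the Python double loop costs O(x * y) interpreter steps, which is why caching the result matters in this version.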
Solution
Notice that the condition y ≥ x corresponds to the upper triangle (diagonal included), while y > x is the strict upper triangle. Therefore y > x + k, which for integer indices is the same as y - x ≥ k + 1, is the upper triangle shifted by 1 + k diagonals.
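One way to check the offset is to compare `torch.triu` against the condition written out directly with broadcasting. This sketch (sizes chosen arbitrarily) builds the mask both ways and checks that they agree for several values of k, including negative ones.

```python
import torch

x, y = 4, 6
rows = torch.arange(x).unsqueeze(1)  # shape (x, 1): the x index
cols = torch.arange(y).unsqueeze(0)  # shape (1, y): the y index

for k in range(-3, 4):
    direct = cols > rows + k  # y > x + k, evaluated elementwise via broadcasting
    via_triu = torch.ones(x, y, dtype=torch.bool).triu(diagonal=k + 1)
    assert torch.equal(direct, via_triu), k
```

The `cols > rows + k` form is itself a loop-free way to build the same boolean mask, and it reads directly as the condition from the question.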
You can construct a triangular mask using torch.triu, which accepts a shift argument named diagonal: it keeps the elements with y - x ≥ diagonal and zeroes out the rest. It operates on the last two dimensions, so a batch of shape (n, x, y) is handled directly. Assigning the desired value, here -torch.inf, through this mask gives the desired result.
Overall it comes down to:
>>> m = torch.ones_like(t, dtype=bool).triu(1+k)
>>> t[m] = -torch.inf
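The two lines above can be wrapped into a self-contained function and tried on a small tensor with known values. The function name `mask_above_diagonal_` is hypothetical (the trailing underscore follows the PyTorch convention for in-place operations); k = 1 is used here to make the shift visible.

```python
import torch

def mask_above_diagonal_(t: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: set t[..., i, j] = -inf wherever j > i + k, in place."""
    m = torch.ones_like(t, dtype=torch.bool).triu(diagonal=k + 1)
    t[m] = float("-inf")  # boolean-mask assignment modifies t itself
    return t

t = torch.arange(12.0).reshape(1, 3, 4)  # n=1, x=3, y=4
mask_above_diagonal_(t, k=1)
# row 0 keeps columns 0-1, row 1 keeps columns 0-2, row 2 keeps all four
print(t)
```

Note that t must be a floating-point tensor for -inf to be representable.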
Alternatively, a one-liner is possible using torch.where, which returns a new tensor instead of modifying t in place:
>>> torch.where(torch.ones_like(t).bool().triu(1+k), -torch.inf, t)
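The out-of-place form matters when t is a leaf tensor that requires gradients, since autograd rejects in-place writes to such tensors. The sketch below assumes that setting and also shows `Tensor.masked_fill`, a standard PyTorch method that was not part of the answer above, as an equivalent alternative. A 0-dim tensor is passed to `torch.where` instead of a Python scalar, as a precaution for older PyTorch versions.

```python
import torch

n, x, y, k = 2, 3, 4, 0
t = torch.randn(n, x, y, requires_grad=True)  # assumed: values that need gradients
m = torch.ones(x, y, dtype=torch.bool).triu(diagonal=k + 1)  # (x, y), broadcast over n

# Out-of-place: t is left untouched, so autograd is fine with it. An in-place
# assignment such as t[:, m] = ... would raise a RuntimeError on this leaf tensor.
out_where = torch.where(m, torch.tensor(float("-inf")), t)

# Equivalent alternative (not from the answer above): Tensor.masked_fill
out_fill = t.masked_fill(m, float("-inf"))

assert torch.equal(out_where, out_fill)
```

Gradients flow back to the unmasked positions of t and are zero at the masked ones.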
Since the mask is the same for every batch element, you can get away with creating a single 2D mask of shape (x, y) and using it to index t on its 2nd and 3rd axes:
>>> m = torch.ones_like(t[0], dtype=bool).triu(1+k)
>>> t[:,m] = -torch.inf
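Combining the 2D mask with the question's idea of building the mask once and caching it could look like the sketch below. Both helper names, `upper_mask` and `apply_mask_`, are hypothetical; `functools.lru_cache` keys the cache on (x, y, k, device), with the device passed as a string.

```python
from functools import lru_cache

import torch

@lru_cache(maxsize=None)
def upper_mask(x: int, y: int, k: int, device: str = "cpu") -> torch.Tensor:
    """Hypothetical helper: cached (x, y) bool mask, True where y-index > x-index + k."""
    return torch.ones(x, y, dtype=torch.bool, device=device).triu(diagonal=k + 1)

def apply_mask_(t: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: set t[:, i, j] = -inf wherever j > i + k, in place."""
    _, x, y = t.shape
    m = upper_mask(x, y, k, str(t.device))
    t[:, m] = float("-inf")  # the 2D mask indexes axes 1 and 2 of every batch element
    return t

t = torch.randn(5, 3, 4)
apply_mask_(t, k=0)
assert upper_mask(3, 4, 0, "cpu") is upper_mask(3, 4, 0, "cpu")  # built only once
```

Because the cached mask is shared between calls, it should be treated as read-only; modifying it in place would silently change every later result.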
Answered By - Ivan