Issue
I have a batch tensor of size (4, 100, 56, 56), where some channels contain nonzero values and others are all zeros. I want every element of a channel that has any value greater than 0 to become 100, and every element of an all-zero channel to become 1. Any idea how to achieve this without looping?
import torch

t = torch.zeros((4, 100, 56, 56))
t[:, 5, 15:20, 15:20] = 0.07

new_t = torch.ones((4, 100, 56, 56))
for b in range(t.size(0)):
    for c in range(t.size(1)):
        if t[b, c, :, :].max() > 0:
            new_t[b, c, :, :] = 100
My code above is inefficient for large batches and many channels, and it creates memory overhead because of new_t. Is there a way to use view() or similar functions to achieve this?
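On the view() part of the question, here is a minimal sketch, assuming t is contiguous (true for a tensor freshly created with torch.zeros). It shows that view() can merge the two spatial dimensions so the loop's per-channel `max() > 0` test runs for every (batch, channel) pair in a single call. The names channel_max and has_positive are illustrative only.

```python
import torch

# Same toy input as in the question.
t = torch.zeros((4, 100, 56, 56))
t[:, 5, 15:20, 15:20] = 0.07

# view() merges the two spatial dims into one (no copy for a contiguous tensor);
# max over that merged dim gives each (batch, channel) maximum at once.
channel_max = t.view(t.size(0), t.size(1), -1).max(dim=2).values  # shape (4, 100)
has_positive = channel_max > 0  # same test as `t[b, c].max() > 0` in the loop

print(has_positive.shape)         # torch.Size([4, 100])
print(has_positive.sum().item())  # 4: channel 5 in each of the 4 batch items
```

The resulting boolean tensor plays the same role as the mask in the answer below.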
Solution
You can build a per-channel mask and assign through it:
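mask = torch.any(t.flatten(2, 3) > 0., dim=2)  # shape (4, 100): True if the channel has any positive value
t[mask] = 100.  # channels with a positive value -> 100
t[~mask] = 1.   # all-zero channels -> 1
# t[mask] *= 100. keeps a gradient path, but it scales the values instead of setting them to 100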
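A self-contained check, as a sketch: loop_reference is a hypothetical helper that wraps the question's loop, and the torch.where variant is an alternative to the in-place assignment, not part of the original answer. torch.where with the mask reshaped to (4, 100, 1, 1) yields a compact (4, 100, 1, 1) result, and expand_as(t) turns it into a full-size view without allocating a new 4 x 100 x 56 x 56 tensor.

```python
import torch

def loop_reference(t):
    # Hypothetical helper: the question's loop, used only to check the results.
    new_t = torch.ones_like(t)
    for b in range(t.size(0)):
        for c in range(t.size(1)):
            if t[b, c].max() > 0:
                new_t[b, c] = 100
    return new_t

t = torch.zeros((4, 100, 56, 56))
t[:, 5, 15:20, 15:20] = 0.07
expected = loop_reference(t)

mask = torch.any(t.flatten(2, 3) > 0., dim=2)  # shape (4, 100)

# Out-of-place alternative: compact (4, 100, 1, 1) result, then a broadcast view.
compact = torch.where(mask[:, :, None, None], torch.tensor(100.), torch.tensor(1.))
out_of_place = compact.expand_as(t)  # view with t's shape, no new full-size buffer

# In-place version from the answer (with the all-zero channels set to 1).
t[mask] = 100.
t[~mask] = 1.

print(torch.equal(out_of_place, expected), torch.equal(t, expected))  # True True
```

The expanded view reuses the same memory across the spatial dimensions, so treat it as read-only; if you need to write into it, call .clone() first or use the in-place mask assignment instead.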
Answered By - aretor