Issue
I have a 2D tensor containing both even and odd numbers, and I would like to apply a different arithmetic operation to each element depending on whether it is even or odd. I've created the tensor and a corresponding true/false (even/odd) mask, but I'm not sure how to proceed.
import torch
list1 = [
[10, 25, 75, 10, 50],
[25, 30, 35, 40, 30],
[45, 50, 55, 60, 20],
[50, 20, 15, 20, 10],
[10, 25, 40, 50, 35]]
tensor2 = torch.tensor(list1)
tensor3=tensor2 % 2
print(tensor3)
print(torch.eq(tensor3, 0)) #even numbers
print(torch.eq(tensor3, 1)) #odd numbers
# do 3x + 1 for odd numbers, i.e. where the "even" mask is False
# do x / 2 for even numbers, i.e. where the "even" mask is True
Solution
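You're looking for torch.where, which builds a new tensor by choosing each element from one of two tensors according to a boolean condition.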
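A minimal sketch of how the three arguments line up, using a small made-up tensor: where the condition is True the value comes from the second argument, and everywhere else it comes from the third.

```python
import torch

x = torch.tensor([1, 2, 3, 4])
is_even = x % 2 == 0  # tensor([False,  True, False,  True])

# Where is_even is True take x * 10, otherwise take -x.
result = torch.where(is_even, x * 10, -x)
print(result)  # tensor([-1, 20, -3, 40])
```

Note that both candidate tensors (x * 10 and -x) are computed for every element before the selection happens.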
In your case the code would be:
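import torch

list1 = [
[10, 25, 75, 10, 50],
[25, 30, 35, 40, 30],
[45, 50, 55, 60, 20],
[50, 20, 15, 20, 10],
[10, 25, 40, 50, 35]]
# float64 so that x / 2 and 3x + 1 both produce the same dtype
tensor2 = torch.tensor(list1, dtype=torch.float64)
# torch.where(condition, value_if_true, value_if_false)
res = torch.where(tensor2 % 2 == 0, tensor2 / 2, tensor2 * 3 + 1)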
>>> res
tensor([[ 5., 76., 226., 5., 25.],
[ 76., 15., 106., 20., 15.],
[136., 25., 166., 30., 10.],
[ 25., 10., 46., 10., 5.],
[ 5., 76., 20., 25., 106.]], dtype=torch.float64)
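If you want to keep integer values, or would rather use the true/false mask from the question directly, the sketch below shows two alternatives that are not part of the original answer: torch.where with floor division (exact here, since it is only selected for even numbers, so the result stays int64), and boolean-mask assignment on a copy of the tensor. It assumes a PyTorch version that supports the rounding_mode argument of torch.div.

```python
import torch

list1 = [
    [10, 25, 75, 10, 50],
    [25, 30, 35, 40, 30],
    [45, 50, 55, 60, 20],
    [50, 20, 15, 20, 10],
    [10, 25, 40, 50, 35]]
tensor2 = torch.tensor(list1)  # int64 by default

even = tensor2 % 2 == 0  # True where the element is even

# Option 1: torch.where with floor division keeps everything as int64.
halved = torch.div(tensor2, 2, rounding_mode="floor")
res_int = torch.where(even, halved, tensor2 * 3 + 1)

# Option 2: boolean-mask assignment on a copy, using the even/odd mask directly.
res_mask = tensor2.clone()
res_mask[even] = halved[even]               # x / 2 for even elements
res_mask[~even] = tensor2[~even] * 3 + 1    # 3x + 1 for odd elements

print(res_int)
print(res_int.dtype)  # torch.int64
```

Both options give the same values as the float64 version above, just without the trailing decimal points.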
Answered By - user15270287