Issue
Is there a batched way to replace several particular values in a PyTorch tensor at once, without a for loop?
Example:
old_values = torch.Tensor([1, 2, 3, 4, 5, 5, 2, 3, 3, 2])
old_new_value = [[2,22], [3,33], [6, 66]]
This means that 2 should be replaced by 22, 3 by 33, and 6 by 66.
Is there an efficient way to obtain the following end_result?
end_result = torch.Tensor([1, 22, 33, 4, 5, 5, 22, 33, 33, 22])
Note that the elements of old_values are not unique. Also, old_new_value may contain a pair, here (6, 66), whose old value does not appear in old_values.
Also, the rows of old_new_value are unique.
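For reference, the intended semantics can be pinned down with the slow loop the question wants to avoid. This is a minimal sketch: `replace_with_loop` is a hypothetical helper name, and it uses `torch.tensor` (lowercase) so the data is an integer tensor rather than the float tensor that `torch.Tensor` creates. It loops over the pairs, not over the elements, and compares against the original values so that one replacement cannot feed into another.

```python
import torch

# Inputs from the question; torch.tensor infers int64 from the int literals.
old_values = torch.tensor([1, 2, 3, 4, 5, 5, 2, 3, 3, 2])
old_new_value = [[2, 22], [3, 33], [6, 66]]

def replace_with_loop(values, pairs):
    """Hypothetical reference version: one masked assignment per (old, new) pair."""
    result = values.clone()
    for old, new in pairs:
        # Build the mask from the *original* values, so a pair such as (22, 7)
        # would not chain onto an earlier replacement.
        result[values == old] = new  # a pair with no match, like (6, 66), is a no-op
    return result

end_result = replace_with_loop(old_values, old_new_value)
print(end_result.tolist())  # [1, 22, 33, 4, 5, 5, 22, 33, 33, 22]
```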
Solution
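If each value to be replaced occurs exactly once in your input tensor, and these occurrences appear in the same order as the rows of old_new_value, here's one straightforward way using a boolean mask and in-place assignment through boolean-mask indexing. (I'll assume that the data type of the input tensor is int, but the code adapts directly to other dtypes.) Note that this does not cover duplicates in old_values or a pair such as (6, 66) that has no match, because the number of masked positions must equal the number of new values. Below is a reproducible illustration, with explanations interspersed in inline comments.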
# input tensors to work with
In [75]: old_values
Out[75]: tensor([1, 2, 3, 4, 5], dtype=torch.int32)
In [77]: old_new_value
Out[77]:
tensor([[ 2, 22],
[ 3, 33]], dtype=torch.int32)
# generate a boolean mask using the values that need to be replaced (i.e. 2 & 3)
In [78]: boolean_mask = (old_values == old_new_value[:, :1]).sum(dim=0).bool()
In [79]: boolean_mask
Out[79]: tensor([False, True, True, False, False])
# assign the new values in place via boolean-mask indexing
In [80]: old_values[boolean_mask] = old_new_value[:, 1:].squeeze()
# sanity check!
In [81]: old_values
Out[81]: tensor([ 1, 22, 33, 4, 5], dtype=torch.int32)
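The session above can be reproduced as a standalone script. This sketch repeats the same steps and then tries the same masked assignment on the question's own tensor (with duplicates and the absent pair (6, 66)); there the mask selects 6 positions but only 3 new values are supplied, so PyTorch raises an error instead of producing end_result. The `dup_*` names are introduced here for illustration only.

```python
import torch

# Same inputs as the session above: int32, each old value present exactly once.
old_values = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32)
old_new_value = torch.tensor([[2, 22], [3, 33]], dtype=torch.int32)

# (k, 1) column of old values broadcast against (n,) gives a (k, n) comparison;
# summing over dim 0 marks every position equal to any old value.
boolean_mask = (old_values == old_new_value[:, :1]).sum(dim=0).bool()

# The i-th True position receives the i-th new value, so this only lines up
# when each old value occurs once and in the same order as the rows.
old_values[boolean_mask] = old_new_value[:, 1:].squeeze()
print(old_values.tolist())  # [1, 22, 33, 4, 5]

# Question's data: 2 and 3 each occur three times, so 6 positions are masked,
# but there are 3 new values (including 66), which cannot be broadcast to 6.
dup_values = torch.tensor([1, 2, 3, 4, 5, 5, 2, 3, 3, 2], dtype=torch.int32)
dup_pairs = torch.tensor([[2, 22], [3, 33], [6, 66]], dtype=torch.int32)
dup_mask = (dup_values == dup_pairs[:, :1]).sum(dim=0).bool()
try:
    dup_values[dup_mask] = dup_pairs[:, 1:].squeeze()
    failed = False
except RuntimeError:
    failed = True
print(failed)  # True
```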
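A small note on efficiency: the whole computation is vectorized, with no Python-level loop, and the final assignment writes into old_values in place. It is not entirely copy-free, though: the comparison allocates an intermediate boolean tensor of shape (len(old_new_value), len(old_values)), and the summed mask is a new tensor as well. Only the slicing and squeeze of old_new_value are pure views. For moderately sized inputs this is still fast.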
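Because the masking approach above does not handle the question's duplicates or the unmatched pair (6, 66), here is a sketch of a different, still loop-free technique (a broadcasted comparison followed by `torch.where`); it is not the method from this answer. `replace_values` is a hypothetical helper name, and it assumes, as the question states, that the old values in old_new_value are unique.

```python
import torch

def replace_values(values: torch.Tensor, old_new: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map values equal to old_new[j, 0] to old_new[j, 1].

    Assumes the old values (column 0) are unique. Returns a new tensor and
    leaves `values` unmodified.
    """
    old = old_new[:, 0]
    new = old_new[:, 1].to(values.dtype)
    # (n, k) comparison: matches[i, j] is True where values[i] == old[j].
    matches = values.unsqueeze(1) == old.unsqueeze(0)
    # Positions that equal some old value (duplicates in `values` are fine).
    hit = matches.any(dim=1)
    # Column index of the match; with unique old values there is at most one
    # True per row. Unmatched rows get an arbitrary valid index, masked by `hit`.
    idx = matches.int().argmax(dim=1)
    return torch.where(hit, new[idx], values)

old_values = torch.tensor([1, 2, 3, 4, 5, 5, 2, 3, 3, 2])
old_new_value = torch.tensor([[2, 22], [3, 33], [6, 66]])
end_result = replace_values(old_values, old_new_value)
print(end_result.tolist())  # [1, 22, 33, 4, 5, 5, 22, 33, 33, 22]
```

The (n, k) comparison matrix costs O(n·k) memory, which is fine for a handful of pairs; with very many pairs, sorting the old values and looking them up with `torch.searchsorted` would scale better.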
Answered By - kmario23