Issue
I want to do an operation similar to matrix multiplication, except that instead of multiplying and summing, I want to count how many positions are equal between each row of a and each row of b. The effect I want to achieve is similar to the following:
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.uint8)
b = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.uint8)
# result[i][j] = number of positions where row a[i] equals row b[j]
result = [[int((a[i] == b[j]).sum()) for j in range(len(b))] for i in range(len(a))]
Is there a way to use einsum, or any other PyTorch function, to achieve this efficiently?
Solution
You can use broadcasting to achieve the same result, for instance:
result = (a[:, None, :] == b[None, :, :]).sum(dim=2)
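A minimal, self-contained sketch of the broadcasting approach, using the tensors from the question and checking the result against the loop version. The function name pairwise_equal_counts is hypothetical and only used for illustration. With a of shape (n, d) viewed as (n, 1, d) and b of shape (m, d) viewed as (1, m, d), the comparison broadcasts to an (n, m, d) boolean tensor, which is then summed over the last dimension to give an (n, m) count matrix.

```python
import torch

def pairwise_equal_counts(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. a: (n, d), b: (m, d); rows must have the same length d
    assert a.shape[1] == b.shape[1], "rows of a and b must have equal length"
    # (n, 1, d) == (1, m, d) broadcasts to an (n, m, d) boolean tensor
    eq = a[:, None, :] == b[None, :, :]
    # Count matching positions for every (i, j) pair -> shape (n, m), dtype int64
    return eq.sum(dim=2)

a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.uint8)
b = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.uint8)

result = pairwise_equal_counts(a, b)
print(result)
# tensor([[3, 0, 0],
#         [0, 3, 0]])

# Reference: the loop from the question, producing plain Python ints
loop = [[int((a[i] == b[j]).sum()) for j in range(len(b))] for i in range(len(a))]
assert result.tolist() == loop
```

Keep in mind that the intermediate boolean tensor has n * m * d elements, so for very large inputs you may want to process the rows of a in chunks.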
Here None just introduces a dummy dimension of size 1; alternatively, you can use the less visual .unsqueeze() instead.
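The .unsqueeze() alternative might look like the sketch below, which assumes the same a and b as in the question. unsqueeze(1) inserts the size-1 dimension at position 1 (like a[:, None, :]) and unsqueeze(0) at position 0 (like b[None, :, :]), so both forms should produce identical results.

```python
import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.uint8)
b = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.uint8)

# a.unsqueeze(1): (2, 3) -> (2, 1, 3), same as a[:, None, :]
# b.unsqueeze(0): (3, 3) -> (1, 3, 3), same as b[None, :, :]
result_unsqueeze = (a.unsqueeze(1) == b.unsqueeze(0)).sum(dim=2)
result_none = (a[:, None, :] == b[None, :, :]).sum(dim=2)

# Both indexing styles give the same shapes and the same counts
assert a.unsqueeze(1).shape == a[:, None, :].shape == (2, 1, 3)
assert torch.equal(result_unsqueeze, result_none)
```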
Answered By - flawr