Issue
I would like to find a way to apply a function that calculates pairwise distances, call it dists(A, B), row-wise to every input element in a batch. The shapes are:
(100, 16, 3) -- input: 100 is the batch size (100 instances), 16 is, say, the image size, and 3 is the number of filters (as in a Conv2D output)
(5, 3) -- the tensor against which I want to calculate the row-wise distances (assume it is A in dists(A, B) and is fixed)
For every instance I should get back a matrix of shape (5, 16). Naturally, I could use a for loop over the batch and stack the results into a final (100, 5, 16) tensor. However, I would love to know if there is an easier way to apply my function row-wise, in parallel, on the GPU.
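For reference, the per-instance loop described above might look like the sketch below. The body of dists is a hypothetical L1 stand-in (the question does not say which distance is used), and the tensors are random placeholders with the shapes from the question.

```python
import torch

def dists(A, B):
    # Hypothetical L1 pairwise distance: A is (5, 3), B is (16, 3) -> (5, 16)
    return (A[:, None, :] - B[None, :, :]).abs().sum(dim=-1)

a = torch.randn(100, 16, 3)  # batch of 100 instances, 16 positions, 3 filters
b = torch.randn(5, 3)        # fixed tensor A in dists(A, B)

# Naive approach: call dists once per instance in Python, then stack the results
out = torch.stack([dists(b, a[n]) for n in range(a.shape[0])])
print(out.shape)  # torch.Size([100, 5, 16])
```

This works, but it launches one small computation per instance, which is what the question hopes to avoid.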
Thank you very much for your time.
Solution
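Suppose we are using the L1 distance. Adding singleton dimensions lets broadcasting compute all 100 x 5 x 16 distances in one operation, with no Python loop: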
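```python
import torch

# data (a batch of 100 instances) and target (the fixed 5 x 3 tensor)
a = torch.randn(100, 16, 3)
b = torch.randn(5, 3)

# Insert singleton dimensions so the two tensors broadcast against each other:
# a: (100, 16, 3) -> (100, 1, 16, 3)
# b: (5, 3)       -> (1, 5, 1, 3)
a = a.unsqueeze(1)
b = b.unsqueeze(0).unsqueeze(2)
print(a.shape, b.shape)

# a - b broadcasts to (100, 5, 16, 3); summing |a - b| over the last
# (filter) dimension gives the L1 distances with shape (100, 5, 16)
dist = (a - b).abs().sum(3)
print(dist.shape)
```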
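The broadcasting code runs unchanged on the GPU once the tensors live there. If dists is some other function than L1, torch.vmap can map an unmodified per-instance function over the batch; this assumes PyTorch 2.0 or later, where torch.vmap is available. torch.cdist with p=1.0 is another built-in route for p-norm distances. The sketch below compares all three; the body of dists is a hypothetical stand-in and the tensors are random placeholders.

```python
import torch

def dists(A, B):
    # Hypothetical per-instance distance function (L1 here): A is (P, C),
    # B is (R, C), result is (P, R). Any function built from PyTorch ops
    # could stand in its place.
    return (A[:, None, :] - B[None, :, :]).abs().sum(dim=-1)

# Use the GPU when one is available; nothing else in the code changes.
device = "cuda" if torch.cuda.is_available() else "cpu"

a = torch.randn(100, 16, 3, device=device)  # batch of inputs
b = torch.randn(5, 3, device=device)        # fixed tensor A in dists(A, B)

# 1) Broadcasting, as in the answer: (100, 1, 16, 3) - (1, 5, 1, 3) -> (100, 5, 16, 3)
dist = (a.unsqueeze(1) - b.unsqueeze(0).unsqueeze(2)).abs().sum(3)

# 2) torch.vmap (PyTorch 2.0+): b is shared (in_dims=None) and a is mapped over
#    dim 0, so dists sees one (16, 3) instance while vmap vectorizes the call
#    instead of looping in Python.
dist_vmap = torch.vmap(dists, in_dims=(None, 0))(b, a)

# 3) torch.cdist with p=1.0: x1 is (100, 5, 3), x2 is (100, 16, 3) -> (100, 5, 16)
dist_cdist = torch.cdist(b.expand(a.shape[0], -1, -1), a, p=1.0)

print(dist.shape, dist_vmap.shape, dist_cdist.shape)
```

Broadcasting is the most direct option when the distance can be written with tensor ops, vmap is convenient when dists already exists as a per-instance function, and cdist covers p-norm distances. Note that the broadcasting version materializes a (100, 5, 16, 3) intermediate, which can become large for bigger inputs.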
Answered By - hkchengrex