Issue
I am trying to multiply the following:
A batch of matrices N x M x D
A batch of vectors N x D x 1
To get a result: N x M x 1
as if I were doing N separate products of an M x D matrix with a D x 1 vector.
I can't seem to find the correct function in PyTorch. As far as I can tell, torch.bmm only works for a batch of vectors and a single matrix. If I have to use torch.einsum, then so be it, but I'd rather not!
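For reference, the operation described above can be written as a plain loop that does the N matrix-vector products one at a time and stacks them. This is a minimal sketch: the sizes N=4, M=3, D=5 and the random tensors `mats` and `vecs` are hypothetical stand-ins for the asker's data.

```python
import torch

# Hypothetical sizes and random data, for illustration only.
N, M, D = 4, 3, 5
mats = torch.randn(N, M, D)  # batch of N matrices, each M x D
vecs = torch.randn(N, D, 1)  # batch of N column vectors, each D x 1

# Slow reference: one (M x D) @ (D x 1) product per batch entry, then stack.
result = torch.stack([mats[i] @ vecs[i] for i in range(N)])
print(result.shape)  # torch.Size([4, 3, 1])
```

The loop is only useful as a correctness baseline; the vectorized forms in the answer avoid the Python-level iteration.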
Solution
It's pretty straightforward and intuitive with einsum:
torch.einsum('ijk, ikl->ijl', mats, vecs)
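A small runnable sketch of the einsum call, cross-checked against an explicit per-batch loop. The tensor sizes and random inputs are hypothetical; the subscript string is the one from the answer with the optional space removed.

```python
import torch

# Hypothetical sizes and random data, for illustration only.
N, M, D = 4, 3, 5
mats = torch.randn(N, M, D)
vecs = torch.randn(N, D, 1)

# i: batch index (N), j: matrix row (M),
# k: shared dimension (D), summed because it is absent from the output,
# l: the trailing size-1 column of each vector.
out = torch.einsum('ijk,ikl->ijl', mats, vecs)
print(out.shape)  # torch.Size([4, 3, 1])

# Cross-check against doing the N products one at a time.
expected = torch.stack([mats[i] @ vecs[i] for i in range(N)])
print(torch.allclose(out, expected, atol=1e-5))
```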
But your operation is just a batched matrix multiplication, which the @ operator (torch.matmul) handles directly:
mats @ vecs
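A minimal sketch showing that `mats @ vecs` produces the N x M x 1 result, and that torch.bmm, which accepts two 3-D tensors of shapes (b, n, m) and (b, m, p), gives the same values here. The sizes and random data are hypothetical.

```python
import torch

# Hypothetical sizes and random data, for illustration only.
N, M, D = 4, 3, 5
mats = torch.randn(N, M, D)
vecs = torch.randn(N, D, 1)

# For two 3-D tensors, @ treats dim 0 as the batch dimension and
# multiplies each (M x D) matrix by its matching (D x 1) vector.
out = mats @ vecs
print(out.shape)  # torch.Size([4, 3, 1])

# torch.bmm computes the same batched product for 3-D inputs.
same = torch.bmm(mats, vecs)
print(torch.allclose(out, same))

# Drop the trailing size-1 dimension if an N x M result is preferred.
flat = out.squeeze(-1)
print(flat.shape)  # torch.Size([4, 3])
```

Of the two, @ is the more general choice: torch.matmul also broadcasts batch dimensions, while torch.bmm requires both inputs to be exactly 3-D with the same batch size.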
Answered By - Quang Hoang