Issue
I am trying to multiply a chain of several matrices in PyTorch and was wondering: what is the PyTorch equivalent of numpy.linalg.multi_dot()?
If there isn't one, what is the next best way (in terms of speed and memory) to do this in PyTorch?
Code:
import numpy as np
import torch
A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
C = np.random.rand(3, 3)
results = np.linalg.multi_dot([A, B, C])  # multi_dot takes a single sequence of arrays
A_tsr = torch.tensor(A)
B_tsr = torch.tensor(B)
C_tsr = torch.tensor(C)
# What is the PyTorch equivalent of np.linalg.multi_dot()?
Many thanks!
Solution
It looks like the NumPy implementation converts every operand to a NumPy array. Passing tensors in directly therefore works only if they are on the CPU and do not require gradients (or have been detached); otherwise the conversion to NumPy fails.
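A minimal sketch of that round trip, with the conversion done explicitly so it also accepts GPU tensors or tensors that require gradients. The helper name multi_dot_via_numpy is hypothetical. Note the costs: .detach() cuts the result off from autograd, and GPU inputs are copied to host memory and back, so this is a convenience rather than a fast path.

```python
import numpy as np
import torch


def multi_dot_via_numpy(tensors):
    """Hypothetical helper: route a chain of tensors through np.linalg.multi_dot.

    Gradients are NOT tracked: .detach() removes the tensors from autograd,
    and .cpu() copies any GPU tensor to host memory before the NumPy call.
    """
    arrays = [t.detach().cpu().numpy() for t in tensors]
    result = np.linalg.multi_dot(arrays)
    # Wrap the ndarray as a tensor and move it to the first input's device.
    return torch.from_numpy(result).to(tensors[0].device)
```

For example, multi_dot_via_numpy([A_tsr, B_tsr, C_tsr]) returns the same values as A_tsr @ B_tsr @ C_tsr, as a tensor.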
So in general, there likely isn't a drop-in alternative. I think your best shot is to take the multi_dot implementation (e.g. the one in NumPy v1.19.0) and adjust it to handle tensors directly, skipping the cast to NumPy. Given the similar interface and the simplicity of the code, this should be pretty straightforward.
Answered By - Yuri Feldman