Issue
With a 3 by n by k tensor A and a 1 by k by m tensor x, we can have Ax = B, where B has shape [3, n, m]. torch.linalg.lstsq(A, B) returns a 3 by k by m tensor as the solution. Is there a way to find the 1 by k by m tensor x instead?
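A minimal sketch of the setup in the question, assuming random inputs and hypothetical sizes n=5, k=4, m=2. It shows that torch.matmul broadcasts one x over the batch, while torch.linalg.lstsq returns one solution per batch element:

```python
import torch

torch.manual_seed(0)
n, k, m = 5, 4, 2  # hypothetical sizes for illustration

A = torch.randn(3, n, k)
x = torch.randn(1, k, m)

# matmul broadcasts the single (1, k, m) x across the batch of 3
B = torch.matmul(A, x)
print(B.shape)  # torch.Size([3, 5, 2])

# lstsq solves each of the 3 systems independently: one solution per batch element
sol = torch.linalg.lstsq(A, B).solution
print(sol.shape)  # torch.Size([3, 4, 2])
```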
Solution
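The difference between torch.linalg.lstsq and torch.matmul here is that torch.linalg.lstsq solves a separate least-squares problem for each batch element, so it returns one k by m solution per batch (3 by k by m in total), while torch.matmul simply broadcasts a single 1 by k by m x across the whole batch.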
Your 1 by k by m x is therefore not a batch-wise solution but a global one that is shared by the whole batch. In this case, you can fold the batch dimension into the row dimension, stacking the three n by k systems into one 3n by k system, and solve a single least-squares problem.
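Why folding the batch works, written out with A_i (n by k) and B_i (n by m) denoting the i-th batch slices: the global x minimizes the total squared residual over all three systems, and that total is exactly the residual of one least-squares problem on the row-stacked matrices (3n by k and 3n by m), which is what A.reshape(1, -1, k) and B.reshape(1, -1, m) produce for contiguous tensors:

```latex
\min_{x}\sum_{i=1}^{3}\lVert A_i x - B_i\rVert_F^2
= \min_{x}\left\lVert
\begin{bmatrix} A_1 \\ A_2 \\ A_3 \end{bmatrix} x
- \begin{bmatrix} B_1 \\ B_2 \\ B_3 \end{bmatrix}
\right\rVert_F^2
```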
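import torch
k, m = A.shape[-1], B.shape[-1]
A_re = A.reshape(1, -1, k)  # (1, 3n, k): the 3 batch systems stacked row-wise
B_re = B.reshape(1, -1, m)  # (1, 3n, m); equals torch.matmul(A_re, x) for an exact x
x = torch.linalg.lstsq(A_re, B_re).solution  # lstsq returns a named tuple; take .solution
x.size()
> torch.Size([1, k, m])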
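A self-contained sketch that checks the reshape approach, assuming random float64 data and hypothetical sizes n=6, k=4, m=2. With noise-free B it recovers the original x; with noisy B it cross-checks the folded solution against the normal equations of the summed residual:

```python
import torch

torch.manual_seed(0)
n, k, m = 6, 4, 2  # hypothetical sizes; n >= k so each system is overdetermined

A = torch.randn(3, n, k, dtype=torch.float64)
x_true = torch.randn(1, k, m, dtype=torch.float64)
B = torch.matmul(A, x_true)  # shape (3, n, m)

# Fold the batch into the rows: (3, n, k) -> (1, 3n, k), (3, n, m) -> (1, 3n, m)
A_re = A.reshape(1, -1, k)
B_re = B.reshape(1, -1, m)

x = torch.linalg.lstsq(A_re, B_re).solution  # shape (1, k, m)
print(x.shape)  # torch.Size([1, 4, 2])
print(torch.allclose(x, x_true))  # True: noise-free data is recovered

# With noisy B, x is still one global least-squares fit over all 3 batches
B_noisy = B + 0.1 * torch.randn_like(B)
x_global = torch.linalg.lstsq(A_re, B_noisy.reshape(1, -1, m)).solution

# Cross-check against the normal equations (A^T A) x = A^T B of the stacked system
A2, B2 = A.reshape(-1, k), B_noisy.reshape(-1, m)
x_ne = torch.linalg.solve(A2.T @ A2, A2.T @ B2)
print(x_global.shape)  # torch.Size([1, 4, 2])
```

The normal-equations check is only for verification; lstsq is the numerically safer way to solve the stacked system.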
Answered By - won5830