Issue
I am studying some code and I came across a usage of PyTorch's einsum function that I do not understand. The docs are here.
The snippet looks like (slightly modified from the original):
import torch
x = torch.rand(64, 64, 25, 25)
y = torch.rand(64, 64, 64, 25)
result = torch.einsum('ncuv,nctv->nctu', x, y)
print(result.shape)
>> torch.Size([64, 64, 64, 25])
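For this particular pattern, einsum should be equivalent to a batched matrix multiplication (summing over `v`, the last dimension of both tensors). A quick sanity check, assuming the same `x` and `y` as above:

```python
import torch

x = torch.rand(64, 64, 25, 25)  # (n, c, u, v)
y = torch.rand(64, 64, 64, 25)  # (n, c, t, v)

# 'ncuv,nctv->nctu' sums over v; that matches a batched matmul of
# y (n, c, t, v) with x transposed in its last two dims (n, c, v, u)
a = torch.einsum('ncuv,nctv->nctu', x, y)
b = torch.matmul(y, x.transpose(-1, -2))

print(a.shape)                               # torch.Size([64, 64, 64, 25])
print(torch.allclose(a, b, atol=1e-5))       # the two should agree up to float rounding
```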
So the notation is such that n=64, c=64, u=25, v=25, t=64.
I'm not too sure what's happening. I think that for each 25 dimensional vector in t (64 of them), each one is being multiplied with each of the u=25 vectors of size 25 elementwise and then the results summed, or rather 25 dot products of 25 dimensional vectors?
Any insights appreciated.
Solution
Basically, you can think of it as taking dot products over certain dimensions, and reorganizing the rest.
For simplicity, let's ignore the batch dimensions n and c (since they are consistent before and after in ncuv,nctv->nctu), and discuss:
import torch
x = torch.rand(25, 25)
y = torch.rand(64, 25)
result = torch.einsum('uv,tv->tu', x, y)
print(result.shape)
>> torch.Size([64, 25])
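In this 2-D case the pattern reduces to an ordinary matrix product with a transpose; a small check of that equivalence:

```python
import torch

x = torch.rand(25, 25)  # (u, v)
y = torch.rand(64, 25)  # (t, v)

res = torch.einsum('uv,tv->tu', x, y)

# summing over v (the shared last dim) is the same as y @ x.T
same = torch.allclose(res, y @ x.T, atol=1e-5)
print(same)
```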
Note that v vanishes after the einsum, meaning v is the dimension being summed over, while t and u are not. You can interpret it this way: x is a collection of 25 25-dimensional vectors; y is a collection of 64 25-dimensional vectors. The dot product of the t-th vector in y and the u-th vector in x is computed and placed in the t-th row and u-th column of result.
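That dot-product reading can be written out as an explicit loop (slow, for illustration only):

```python
import torch

x = torch.rand(25, 25)  # (u, v)
y = torch.rand(64, 25)  # (t, v)
expected = torch.einsum('uv,tv->tu', x, y)

result = torch.empty(64, 25)
for t in range(64):
    for u in range(25):
        # dot product of the t-th vector in y and the u-th vector in x
        result[t, u] = torch.dot(y[t], x[u])
```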
You can also rewrite into a math equation:
result[n,c,t,u] = \sum_{v} x[n,c,u,v] * y[n,c,t,v], for each n, c, t, u
Note two things:
- the summation is over the index that vanishes in the pattern ncuv,nctv->nctu (here, v)
- the indices appearing on the right of the pattern (after ->) are the indices of the resulting tensor
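The full batched equation can be verified the same way with an explicit loop; the dimension sizes below are deliberately small (a hypothetical choice, not the original 64/25 sizes) so the loop stays fast:

```python
import torch

N, C, T, U, V = 2, 3, 4, 5, 6
x = torch.rand(N, C, U, V)
y = torch.rand(N, C, T, V)
expected = torch.einsum('ncuv,nctv->nctu', x, y)

# result[n,c,t,u] = sum_v x[n,c,u,v] * y[n,c,t,v]
result = torch.zeros(N, C, T, U)
for n in range(N):
    for c in range(C):
        for t in range(T):
            for u in range(U):
                result[n, c, t, u] = (x[n, c, u] * y[n, c, t]).sum()
```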
Answered By - ihdv