Issue
I'm revising a baseline method in PyTorch. When I add a custom function to the training phase, the backward pass takes about 4x longer on a single V100. Here is an example of the custom function:
import torch

def batch_function(M, kernel_size=21, sf=2):
    '''
    Input:
        M: b x (h*w) x 2 x 2 torch tensor
        sf: scale factor
    Output:
        kernel: b x (h*w) x k x k torch tensor
    '''
    M_t = M.permute(0,1,3,2) # b x (h*w) x 2 x 2
    INV_SIGMA = torch.matmul(M_t, M).unsqueeze(2).unsqueeze(2) # b x (h*w) x 1 x 1 x 2 x 2
    X, Y = torch.meshgrid(torch.arange(kernel_size), torch.arange(kernel_size))
    # Match M's device and dtype: arange yields integers, and matmul does not promote int to float
    Z = torch.stack((Y, X), dim=2).unsqueeze(3).to(M.device, M.dtype) # k x k x 2 x 1
    Z = Z.unsqueeze(0).unsqueeze(0) # 1 x 1 x k x k x 2 x 1
    Z_t = Z.permute(0,1,2,3,5,4) # 1 x 1 x k x k x 1 x 2
    raw_kernel = torch.exp(-0.5 * torch.squeeze(Z_t.matmul(INV_SIGMA).matmul(Z))) # b x (h*w) x k x k
    # Normalize
    kernel = raw_kernel / torch.sum(raw_kernel, dim=(2,3)).unsqueeze(-1).unsqueeze(-1) # b x (h*w) x k x k
    return kernel
where b is the batch size (16), h and w are the spatial dimensions (100 each), and k is 21. I'm not sure whether the large size of M is what makes the backward pass slower. Why does it take so much longer? And is there another way to write this code that would be faster? I'm new here, so if the problem is not clearly described, please let me know!
Solution
You might be able to get a performance boost on the double tensor multiplication by using torch.einsum:
>>> o = torch.einsum('acdefg,bshigj,kldejm->bsdefm', ZZ_t, INV_SIGMA, ZZ)
The resulting tensor o will be shaped (b, h*w, k, k, 1, 1).
For details on the subscript notation:
b: batch dimension.
s: 's' for spatial, i.e. the h*w dimension.
d and e: the two k dimensions, which are paired across ZZ_t and ZZ.
A simple 2D matrix multiplication with matmul corresponds to the subscripts ij,jk->ik.
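To make the ij,jk->ik notation concrete, here is a minimal sketch comparing torch.einsum with the @ operator on two small random matrices. The names A, B and the sizes 3, 4, 5 are made up for illustration.

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 4)  # indices i, j
B = torch.randn(4, 5)  # indices j, k

# 'ij,jk->ik': the shared index j is summed over, i and k are kept.
C_einsum = torch.einsum('ij,jk->ik', A, B)
C_matmul = A @ B

same_2d = torch.allclose(C_einsum, C_matmul, atol=1e-6)
print(tuple(C_einsum.shape), same_2d)
```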
Keeping that in mind, we have in your case:
A first matrix multiplication: r = ZZ_t@INV_SIGMA, which does something like *fg,*gj->*fj, where the asterisk sign * refers to the leading dimensions.
A second matrix multiplication: r@ZZ, which comes down to *fj,*jm->*fm.
Overall, if we combine both, we get directly: *fg,*gj,*jm->*fm.
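The combined pattern can be tried out with einsum's ellipsis, which plays the role of the asterisk when all operands share the same leading dimensions. This is only a sketch: P, Q, R and the leading shape (4, 6) are hypothetical stand-ins, not the tensors from the question.

```python
import torch

torch.manual_seed(0)
lead = (4, 6)                 # arbitrary leading dimensions (the '*')
P = torch.randn(*lead, 1, 2)  # *fg
Q = torch.randn(*lead, 2, 2)  # *gj
R = torch.randn(*lead, 2, 1)  # *jm

# Two chained matrix multiplications ...
chained = (P @ Q) @ R
# ... versus a single einsum call that contracts g and j in one expression.
fused = torch.einsum('...fg,...gj,...jm->...fm', P, Q, R)

same_chain = torch.allclose(chained, fused, atol=1e-5)
print(tuple(fused.shape), same_chain)
```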
Finally, I have assigned all other dimensions (the singleton ones) to random but different subscript letters:
a, c, h, i, k, l
Replacing the asterisk above with those notations, we get the following subscript input:
#    *fg,    *gj,    *jm ->    *fm
# acdefg, bshigj, kldejm -> bsdefm
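If you want to check the full subscript string before using it in training, a sketch along these lines compares it against the two chained matmul calls. It assumes ZZ and ZZ_t are the Z and Z_t tensors from the question, and it uses small made-up sizes (b=2, h*w=5, k=7) and float64 so the comparison is cheap and numerically tight. The grid is built by broadcasting rather than torch.meshgrid, but it holds the same values.

```python
import torch

torch.manual_seed(0)
b, hw, k = 2, 5, 7  # illustrative sizes, not the 16 / 100*100 / 21 from the question

M = torch.randn(b, hw, 2, 2, dtype=torch.float64)
INV_SIGMA = (M.transpose(-1, -2) @ M).unsqueeze(2).unsqueeze(2)  # b x hw x 1 x 1 x 2 x 2

# Coordinate grid equivalent to the question's meshgrid: X[i, j] = i, Y[i, j] = j
idx = torch.arange(k, dtype=M.dtype)
X = idx.view(k, 1).expand(k, k)
Y = idx.view(1, k).expand(k, k)
ZZ = torch.stack((Y, X), dim=2).unsqueeze(3)  # k x k x 2 x 1
ZZ = ZZ.unsqueeze(0).unsqueeze(0)             # 1 x 1 x k x k x 2 x 1
ZZ_t = ZZ.transpose(-1, -2)                   # 1 x 1 x k x k x 1 x 2

# Reference: the two broadcast matrix multiplications from the question
reference = ZZ_t @ INV_SIGMA @ ZZ             # b x hw x k x k x 1 x 1

# Single einsum call with the subscripts derived above
o = torch.einsum('acdefg,bshigj,kldejm->bsdefm', ZZ_t, INV_SIGMA, ZZ)

same_full = torch.allclose(o, reference)
print(tuple(o.shape), same_full)
```

This only checks that the two formulations agree. Whether the einsum version is actually faster on your hardware is something you would need to time with your real sizes.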
Answered By - Ivan