Issue
Let's call the function I'm looking for "magic_combine", which combines consecutive dimensions of the tensor I give it. More specifically, I want it to do the following:
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.magic_combine(2, 5) # combine dimensions 2, 3, 4
print(b.size()) # should be (1, 2, 60, 6)
I know that Tensor.view() can do something similar, but I'm wondering whether there is a more elegant way to achieve this.
Solution
I am not sure what you have in mind with "a more elegant way", but Tensor.view() has the advantage of not re-allocating data for the view (the original tensor and the view share the same data), which makes this operation quite lightweight.
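A quick way to check the shared-data claim is to compare storage pointers and to write through the view. This is a small sketch reusing the shapes from the question; the specific indices are only for illustration.

```python
import torch

a = torch.zeros(1, 2, 3, 4, 5, 6)
# Merge dimensions 2, 3 and 4 (3 * 4 * 5 = 60) into a single dimension.
b = a.view(1, 2, 60, 6)

# Both tensors start at the same address in the same underlying storage.
print(a.data_ptr() == b.data_ptr())  # True

# Writing through the view is therefore visible in the original tensor:
# index 0 of the merged dimension corresponds to (0, 0, 0) in dims 2..4.
b[0, 0, 0, 0] = 1.0
print(a[0, 0, 0, 0, 0, 0].item())  # 1.0
```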
As mentioned by @UmangGupta, however, it is rather straightforward to wrap this function to achieve what you want, e.g.:
import torch

def magic_combine(x, dim_begin, dim_end):
    # Keep the dims before dim_begin and from dim_end on; -1 lets view()
    # infer the size of the merged block (dims dim_begin .. dim_end - 1).
    combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
    return x.view(combined_shape)

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = magic_combine(a, 2, 5)  # combine dimensions 2, 3, 4
print(b.size())
# torch.Size([1, 2, 60, 6])
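Two related points, sketched below as a complement to the answer rather than as part of it. First, PyTorch ships torch.flatten(input, start_dim, end_dim), which does the same merge, but its end_dim is inclusive, so dimensions 2, 3, 4 are flattened with (2, 4) rather than (2, 5). Second, view() only works when the memory layout allows the dimensions to be merged without copying; on a non-contiguous tensor (here produced by a transpose chosen purely for illustration) it raises a RuntimeError. The variant of magic_combine below swaps view() for reshape(), which is my substitution: it returns a view when possible and copies otherwise.

```python
import torch

def magic_combine(x, dim_begin, dim_end):
    # Variant of the wrapper above using reshape() instead of view():
    # reshape() returns a view when the layout allows it and copies otherwise.
    combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
    return x.reshape(combined_shape)

a = torch.zeros(1, 2, 3, 4, 5, 6)

# Built-in alternative: end_dim is inclusive, so (2, 4) merges dims 2, 3, 4.
c = torch.flatten(a, 2, 4)
print(c.size())  # torch.Size([1, 2, 60, 6])

# Transposing dims 2 and 4 gives shape (1, 2, 5, 4, 3, 6) with a
# non-contiguous layout in which dims 2..4 cannot be merged without a copy.
t = a.transpose(2, 4)
try:
    t.view(1, 2, -1, 6)
except RuntimeError as err:
    print("view failed:", type(err).__name__)

# The reshape-based wrapper still works (it copies the data here).
print(magic_combine(t, 2, 5).size())  # torch.Size([1, 2, 60, 6])
```

For contiguous tensors like the one in the question, both the view() wrapper and torch.flatten() avoid copying, so the choice mostly comes down to which indexing convention (exclusive vs. inclusive end) reads better to you.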
Answered By - benjaminplanche