Issue
I have a tensor a with shape torch.Size([2, 2, 1, 2]). How can I expand it to shape torch.Size([2, 4, 1, 2]) without using a for loop, which is slow?
For example, given tensor a:
>>> print(a, a.shape)
tensor([[[[0.2955, 0.8836]],
         [[0.7607, 0.6657]]],
        [[[0.6779, 0.5109]],
         [[0.0785, 0.6564]]]]) torch.Size([2, 2, 1, 2])
I want to expand it into tensor b, where each slice along dimension 1 is repeated twice:
>>> print(b, b.shape)
tensor([[[[0.2955, 0.8836]],
         [[0.2955, 0.8836]],
         [[0.7607, 0.6657]],
         [[0.7607, 0.6657]]],
        [[[0.6779, 0.5109]],
         [[0.6779, 0.5109]],
         [[0.0785, 0.6564]],
         [[0.0785, 0.6564]]]]) torch.Size([2, 4, 1, 2])
I tried Tensor.expand(), but it only expands dimensions of size 1.
How can I achieve this? Thank you.
Solution
PyTorch provides several functions for repeating tensors, each with different semantics. The one you are looking for is repeat_interleave, which repeats each element along the given dimension a fixed number of times. A simple example:
>>> a.repeat_interleave(2, dim=1)
tensor([[[[0.2955, 0.8836]],
         [[0.2955, 0.8836]],
         [[0.7607, 0.6657]],
         [[0.7607, 0.6657]]],
        [[[0.6779, 0.5109]],
         [[0.6779, 0.5109]],
         [[0.0785, 0.6564]],
         [[0.0785, 0.6564]]]])
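To make the difference from the other repeat-style functions concrete, here is a minimal sketch. It uses a stand-in tensor built with torch.arange rather than the random values from the question, and the names tiled and via_expand are only illustrative. It shows that repeat_interleave duplicates each slice next to itself, that Tensor.repeat tiles the whole dimension instead, and that the same result can be reached with expand if a size-1 dimension is inserted first:

```python
import torch

# Stand-in for the tensor in the question (same shape, deterministic values).
a = torch.arange(8, dtype=torch.float32).reshape(2, 2, 1, 2)

# repeat_interleave duplicates each slice along dim=1 in place: [x0, x0, x1, x1]
b = a.repeat_interleave(2, dim=1)

# repeat tiles the whole dimension instead: [x0, x1, x0, x1]
tiled = a.repeat(1, 2, 1, 1)

# Same result via expand: add a size-1 dim after dim 1, expand it to 2,
# then merge it back into dim 1 with reshape.
via_expand = a.unsqueeze(2).expand(-1, -1, 2, -1, -1).reshape(2, 4, 1, 2)

print(b.shape)                      # shape of the interleaved result
print(torch.equal(b, via_expand))   # do the two approaches agree?
print(torch.equal(b, tiled))        # does tiling give the same ordering?
```

If the ordering [x0, x1, x0, x1] is what you need instead, Tensor.repeat is the simpler choice; for the layout asked about in the question, repeat_interleave is the most direct call.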
Answered By - CuCaRot