Issue
Question:
I need a quick and simple way to transform a PyTorch tensor of shape (D, M, M) into a tensor of shape (D*4, M//2, M//2) manually, without convolutions. I want a pooling-like approach, but instead of reducing each window to one value I want to flatten it and concatenate the values along the channel dimension. The kernel size is always 2 and the stride is also 2, so the spatial size is halved. It's essential that the gradients are preserved.
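This "flatten and concatenate" variant of 2x2, stride-2 pooling can be read as four strided slices, one per position inside the window, stacked along a new axis. A minimal sketch, assuming an input of shape (D, M, M) with even M; the helper name `space_to_depth_slices` is hypothetical:

```python
import torch


def space_to_depth_slices(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (D, M, M) -> (D*4, M//2, M//2) via strided slices.

    Assumes M is even. Slicing, stacking and reshaping are all tracked by
    autograd, so gradients flow back to x.
    """
    d, m, _ = x.shape
    # The four positions inside each 2x2 window, in row-major order:
    # top-left, top-right, bottom-left, bottom-right.
    parts = [x[:, 0::2, 0::2], x[:, 0::2, 1::2], x[:, 1::2, 0::2], x[:, 1::2, 1::2]]
    # Stack to (D, 4, M//2, M//2) so each channel's four values stay together,
    # then merge the first two axes into D*4 channels.
    return torch.stack(parts, dim=1).reshape(d * 4, m // 2, m // 2)


x = torch.arange(48.0).reshape(3, 4, 4).requires_grad_()
y = space_to_depth_slices(x)
print(y.shape)  # torch.Size([12, 2, 2])
print(y[:, 0, 0].tolist())
```

Because only slicing, stacking and reshaping are involved, no values are computed, only moved, and the gradient of each output element flows back to exactly one input element.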
Example Input:
Transform (3, 4, 4) to (12, 2, 2):
[[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11], [12, 13, 14, 15]],
[[16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31]],
[[32, 33, 34, 35], [36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47]]]
Desired Output (each row below lists the 12 values of output_tensor[:, i, j] for one spatial position (i, j)):
[[[ 0, 1, 4, 5, 16, 17, 20, 21, 32, 33, 36, 37],
[ 2, 3, 6, 7, 18, 19, 22, 23, 34, 35, 38, 39]],
[[ 8, 9, 12, 13, 24, 25, 28, 29, 40, 41, 44, 45],
[10, 11, 14, 15, 26, 27, 30, 31, 42, 43, 46, 47]]]
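Read together with the `output_tensor[:, 0, 0]` check in the test code, the desired output corresponds to the index mapping sketched here, where c is the input channel, (i, j) the output position, and (k, l) the row and column offset inside each 2x2 window. This is an interpretation of the example rather than something stated explicitly in the question:

```latex
\mathrm{out}[4c + 2k + l,\; i,\; j] = \mathrm{in}[c,\; 2i + k,\; 2j + l],
\qquad c \in \{0, \dots, D-1\},\quad i, j \in \{0, \dots, M/2 - 1\},\quad k, l \in \{0, 1\}
```

For example, c = 1, k = 1, l = 0 at position (0, 0) gives out[6, 0, 0] = in[1, 1, 0] = 20, which is the seventh entry of the first row of the desired output.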
Test Code:
import torch
# Generate the input tensor
input_tensor = torch.arange(48).reshape(3, 4, 4)
# Get Shape
n, m, _ = input_tensor.shape
# TODO: the transformation goes here
# Check: output_tensor[:, 0, 0] should equal [0, 1, 4, 5, 16, 17, 20, 21, 32, 33, 36, 37]
.....
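To have something concrete to compare a vectorized version against, here is a deliberately slow reference written with explicit loops. It is a sketch for testing only; the name `reference_space_to_depth` is hypothetical, and it assumes even M:

```python
import torch


def reference_space_to_depth(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical slow reference: copies each 2x2 window element by element.

    Intended only as ground truth in tests; assumes a (D, M, M) tensor with even M.
    """
    d, m, _ = x.shape
    out = x.new_empty(d * 4, m // 2, m // 2)
    for c in range(d):
        for i in range(m // 2):
            for j in range(m // 2):
                for di in range(2):  # row offset inside the window
                    for dj in range(2):  # column offset inside the window
                        out[4 * c + 2 * di + dj, i, j] = x[c, 2 * i + di, 2 * j + dj]
    return out


input_tensor = torch.arange(48).reshape(3, 4, 4)
expected = reference_space_to_depth(input_tensor)
print(expected[:, 0, 0].tolist())
# [0, 1, 4, 5, 16, 17, 20, 21, 32, 33, 36, 37]
```

A faster implementation can then be checked with `torch.equal(fast_output, expected)`.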
In my attempt, I create an intermediate step toward the desired output:
patches = input_tensor.unfold(1, 2, 2).unfold(2, 2, 2).reshape(n, m//2, m//2, 4)
Output:
output: tensor([[[[ 0, 1, 4, 5],
[ 2, 3, 6, 7]],
[[ 8, 9, 12, 13],
[10, 11, 14, 15]]],
[[[16, 17, 20, 21],
[18, 19, 22, 23]],
[[24, 25, 28, 29],
[26, 27, 30, 31]]],
[[[32, 33, 36, 37],
[34, 35, 38, 39]],
[[40, 41, 44, 45],
[42, 43, 46, 47]]]])
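To see why this intermediate is already in the right order: `Tensor.unfold(dim, size, step)` appends the window axis at the end, so after both calls each `patches[c, i, j]` is the 2x2 block `input_tensor[c, 2i:2i+2, 2j:2j+2]` flattened row by row. A small check of that claim:

```python
import torch

input_tensor = torch.arange(48).reshape(3, 4, 4)
n, m, _ = input_tensor.shape

# After the two unfold calls the shape is (n, m//2, m//2, 2, 2);
# the last two axes are the row and column offset inside each window.
windows = input_tensor.unfold(1, 2, 2).unfold(2, 2, 2)
print(windows.shape)  # torch.Size([3, 2, 2, 2, 2])

# Merging the last two axes gives one flattened 4-vector per window.
patches = windows.reshape(n, m // 2, m // 2, 4)

# Every patch equals the corresponding 2x2 block, flattened in row-major order.
for c in range(n):
    for i in range(m // 2):
        for j in range(m // 2):
            block = input_tensor[c, 2 * i:2 * i + 2, 2 * j:2 * j + 2]
            assert torch.equal(patches[c, i, j], block.flatten())
print("all patches match")
```

So the only remaining work is to move the channel axis next to the patch axis and merge the two.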
But I still need to convert these patches into a tensor of shape (12, 2, 2) while keeping the correct order.
Edit: corrected the target shape from (D*2, M//2, M//2) to (D*4, M//2, M//2).
Solution
You're almost there. Once you have the patches with shape (n, m//2, m//2, 4), you need to permute the tensor into the correct order and then reshape it so that the channel and patch dimensions are merged; torch.permute is the way to go here. Here's the complete code:
import torch
# Generate the input tensor
input_tensor = torch.arange(48).reshape(3, 4, 4)
# Get Shape
n, m, _ = input_tensor.shape
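# Create patches: shape (n, m//2, m//2, 4), one flattened 2x2 window per position
patches = input_tensor.unfold(1, 2, 2).unfold(2, 2, 2).reshape(n, m//2, m//2, 4)
# Permute to (m//2, m//2, n, 4), then merge the last two axes.
# The result has shape (m//2, m//2, n*4), the layout used to print the desired output.
output_tensor = patches.permute(1, 2, 0, 3).reshape(m//2, m//2, n*4)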
print(output_tensor)
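The code above returns shape (m//2, m//2, n*4), which matches how the desired output is printed in the question. If the (D*4, M//2, M//2) layout implied by the `output_tensor[:, 0, 0]` check is needed instead, a minimal variant is to keep the channel axis first when permuting. The sketch also compares the result with `torch.nn.functional.pixel_unshuffle` (available since PyTorch 1.8), which as far as I know performs the same space-to-depth rearrangement; the assertion there is what actually confirms it. A float tensor is used so that gradients can be checked:

```python
import torch
import torch.nn.functional as F

# Float input so autograd can be exercised.
input_tensor = torch.arange(48, dtype=torch.float32).reshape(3, 4, 4).requires_grad_()
n, m, _ = input_tensor.shape

patches = input_tensor.unfold(1, 2, 2).unfold(2, 2, 2).reshape(n, m // 2, m // 2, 4)

# Keep the channel axis first and put the patch axis right after it:
# (n, m//2, m//2, 4) -> (n, 4, m//2, m//2) -> (n*4, m//2, m//2)
output_tensor = patches.permute(0, 3, 1, 2).reshape(n * 4, m // 2, m // 2)
print(output_tensor.shape)  # torch.Size([12, 2, 2])
print(output_tensor[:, 0, 0].tolist())

# Built-in space-to-depth for comparison; accepts (C, H, W) or batched input.
reference = F.pixel_unshuffle(input_tensor, downscale_factor=2)
print(torch.equal(output_tensor, reference))

# Each input element appears exactly once in the output, so the gradient of
# the sum with respect to the input is all ones.
output_tensor.sum().backward()
print(torch.equal(input_tensor.grad, torch.ones_like(input_tensor)))
```

The only difference from the answer's code is the permutation order (0, 3, 1, 2) instead of (1, 2, 0, 3); calling `.permute(2, 0, 1)` on the answer's result gives the same values.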
Answered By - Giovanni Amorim