Issue
I have a flattened PyTorch tensor, built from the MaxPool2d indices after getting the gradients for the backward pass, that I use to revert the pooling operation. The issue is that its layout depends on kernel_size and on the original input height/width before the MaxPool forward pass. For example, let's say I have an input of size 2x4 with a kernel_size of 2x2 (the kernel is always square, and the stride is always the same as the kernel size):
input1 = tensor([[1,2,5,6], [3,4,7,8]])
For input1, after the forward pass, the indices will be [3,3] (indexing from 0), i.e. the position of the maximum (4 and 8) inside each flattened 2x2 window.
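A minimal sketch of how such per-window indices can be computed, assuming windows are visited row by row and each window is flattened row by row (the names `x`, `k` and `windows` are illustrative, not from the question). For comparison, PyTorch's built-in `F.max_pool2d(..., return_indices=True)` reports positions in the flattened input plane instead, which are 5 and 7 for input1.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[1, 2, 5, 6], [3, 4, 7, 8]])  # input1
k = 2  # kernel_size, also used as the stride
h, w = x.shape

# (h, w) -> (h//k, k, w//k, k) -> (h//k, w//k, k, k): one (k, k) block per window
windows = x.view(h // k, k, w // k, k).permute(0, 2, 1, 3)
# Flatten each window row by row and take the argmax inside it
per_window_idx = windows.reshape(-1, k * k).argmax(dim=1)
print(per_window_idx)  # tensor([3, 3])

# The built-in pooling op indexes into the flattened h*w plane instead
_, plane_idx = F.max_pool2d(x.view(1, 1, h, w).float(), k, return_indices=True)
print(plane_idx.flatten())  # tensor([5, 7])
```

If the forward pass uses the built-in op, `nn.MaxUnpool2d` consumes those plane indices directly; the per-window convention is what makes the rearrangement below necessary.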
In the backward pass, after flattening and inserting the gradients at those indices, I end up with a tensor like this:
tensor([0,0,0,(gradient of 4),0,0,0,(gradient of 8)])
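A sketch of how that window-ordered gradient tensor can be assembled from the indices. The upstream gradients 4.0 and 8.0 are placeholder values chosen to match the tensor above, and `indices` and `grad_out` are hypothetical names:

```python
import torch

k = 2
indices = torch.tensor([3, 3])       # argmax position inside each flattened window
grad_out = torch.tensor([4.0, 8.0])  # placeholder upstream gradients, one per window

num_windows = indices.numel()
flat = torch.zeros(num_windows * k * k)
# Window i occupies flat[i*k*k : (i+1)*k*k]; its gradient goes at offset indices[i]
positions = torch.arange(num_windows) * k * k + indices
flat[positions] = grad_out
print(flat)  # tensor([0., 0., 0., 4., 0., 0., 0., 8.])
```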
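The issue is that I now need to transform the tensor back into the original shape while preserving the element order. Using something simple like tensor.view(input_height, input_width) doesn't work because the order of the elements gets scrambled. If input1 is flattened window by window (tensor([1,2,3,4,5,6,7,8])), viewing it as 2x4 gives:
input1.view(2,4) = tensor([[1,2,3,4], [5,6,7,8]])
rather than the original [[1,2,5,6], [3,4,7,8]].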
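A small reproduction of the mismatch, assuming the same row-by-row window flattening: `view` fills each output row from consecutive values, so a whole window lands in a single row instead of a 2x2 block.

```python
import torch

x = torch.tensor([[1, 2, 5, 6], [3, 4, 7, 8]])  # input1
k = 2
h, w = x.shape

# Flatten window by window: (h, w) -> (h//k, w//k, k, k) -> 1-D
window_flat = x.view(h // k, k, w // k, k).permute(0, 2, 1, 3).reshape(-1)
print(window_flat)  # tensor([1, 2, 3, 4, 5, 6, 7, 8])

naive = window_flat.view(h, w)
print(naive)                  # rows [1, 2, 3, 4] and [5, 6, 7, 8]
print(torch.equal(naive, x))  # False
```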
I've tried things like chunking it into groups of input_height or input_width, but then I can't make it work for different sizes. I think there is an easy solution and I just lack the PyTorch or NumPy skills to figure it out :(
Edit:
OK, so I've tried input1.view(2,4),
which, as I've explained, doesn't give the right order, and I haven't been able to fix the order with permute or similar functions either. I've also tried chunking it this way:
input1.view(-1, 2)
which splits the tensor into chunks of 2, but then I can't stack them back up properly into multiple columns. I've searched everywhere and I don't know how to progress. I've even (to no avail) asked ChatGPT lol.
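For illustration, the chunks produced by view(-1, 2) can still be placed correctly if each one is written to its own destination. This sketch assumes windows are flattened row by row: chunk `i` is row `i % k` of window `i // k`, and window `n` sits at window-row `n // out_w`, window-column `n % out_w`. The helper name `unchunk` is hypothetical.

```python
import torch

def unchunk(flat: torch.Tensor, k: int, h: int, w: int) -> torch.Tensor:
    """Rebuild an (h, w) tensor from window-by-window values, one k-wide row chunk at a time."""
    out_w = w // k              # pooling windows per output row
    chunks = flat.view(-1, k)   # each chunk is one row of one window
    result = flat.new_zeros((h, w))
    for i, chunk in enumerate(chunks):
        window, row_in_window = divmod(i, k)
        win_r, win_c = divmod(window, out_w)
        r, c = win_r * k + row_in_window, win_c * k
        result[r, c:c + k] = chunk
    return result

print(unchunk(torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), 2, 2, 4))
```

For input1's window-flattened values this returns [[1, 2, 5, 6], [3, 4, 7, 8]], i.e. input1 itself. The explicit index arithmetic is easy to check but loops over every chunk in Python.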
Solution
Here is one way to rearrange the flattened tensor back into the original shape while keeping the order:
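import torch

flattened = torch.tensor([0, 0, 0, 4, 0, 0, 0, 8])  # gradients, flattened window by window
kernel_size = 2
input_h = 2
input_w = 4
out_w = input_w // kernel_size  # number of pooling windows per output row

# Every kernel_size*kernel_size consecutive values form one window; view each as a (k, k) block
windows = flattened.view(-1, kernel_size, kernel_size)

rows = []
for i in range(0, len(windows), out_w):
    # Put the out_w windows of one output row side by side along the width
    rows.append(torch.cat(list(windows[i:i + out_w]), dim=1))
# Stack those bands of kernel_size rows on top of each other
result = torch.cat(rows, dim=0)
assert result.shape == (input_h, input_w)
print(result)
The key ideas:
- The flattened tensor is ordered window by window, so view it as one kernel_size x kernel_size block per window (this assumes windows are visited row by row, left to right, as in the example)
- Concatenate each group of input_w // kernel_size consecutive windows along the width to rebuild a band of kernel_size rows
- Concatenate those bands along the height to get back an input_h x input_w tensor
This takes advantage of slicing the flattened tensor and stitching together the pieces.
The output with the example data is:
tensor([[0, 0, 0, 0],
        [0, 4, 0, 8]])
The gradient of 4 ends up at row 1, column 1 and the gradient of 8 at row 1, column 3, which are exactly where 4 and 8 sit in input1.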
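The same rearrangement can also be done without a Python loop. Two sketches, assuming the input height and width are multiples of kernel_size and that windows are flattened row by row: the first views the tensor as (out_h, out_w, k, k) and swaps the middle axes with `permute`, which undoes the window flattening; the second uses `torch.nn.functional.fold`, which writes columns of k*k values back into k x k blocks (non-overlapping here because stride equals kernel_size). The function names are hypothetical, and the example uses a float tensor, as gradients normally are.

```python
import torch
import torch.nn.functional as F

def unflatten_permute(flat, k, h, w):
    # (out_h*out_w*k*k,) -> (out_h, out_w, k, k) -> (out_h, k, out_w, k) -> (h, w)
    return flat.view(h // k, w // k, k, k).permute(0, 2, 1, 3).reshape(h, w)

def unflatten_fold(flat, k, h, w):
    # fold expects (N, C*k*k, L): one column of k*k values per block, blocks in row-major order
    cols = flat.view(-1, k * k).t().unsqueeze(0)
    return F.fold(cols, output_size=(h, w), kernel_size=k, stride=k).view(h, w)

grad = torch.tensor([0., 0., 0., 4., 0., 0., 0., 8.])
print(unflatten_permute(grad, 2, 2, 4))
print(unflatten_fold(grad, 2, 2, 4))
```

Both return [[0., 0., 0., 0.], [0., 4., 0., 8.]] for the example data. The permute version needs nothing beyond view/permute/reshape; fold follows the same block layout as `F.unfold`, which may be convenient if the forward pass is written with unfold.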
Answered By - ops_sujan