Issue
Assume an input data set contains CT scans of 100 patients, each scan containing 16 layers and each layer containing 512 x 512 pixels. I want to apply eight 3x3 convolution filters to each layer in every CT scan. So, the input array has shape [100, 16, 512, 512] and the kernels array has shape [8, 3, 3]. After the convolutions are applied, the goal is an output array with shape [100, 16, 8, 512, 512]. The following code uses the PyTorch conv2d function to achieve this; however, I want to know if the groups parameter (and/or other means) can somehow eliminate the need for the loop.
for layer_index in range(number_of_layers):
    # Get the current CT scan layer for all patients
    # ct_scans dimensions are: [patient, scan layer, pixel row, pixel column]
    # ct_scans shape: [100, 16, 512, 512]
    image_stack = ct_scans[:, layer_index, :, :]
    # Convert from numpy to tensor format, adding a channel dimension
    image_stack_t = torch.from_numpy(image_stack[:, None, :, :])
    # Apply convolution to create 8 filtered versions of the current scan layer across all patients
    # kernels tensor shape: [8, 1, 3, 3] (8 filters, 1 input channel, 3x3), as conv2d expects a 4D weight
    filtered_image_stack_t = F.conv2d(image_stack_t, kernels, padding=1, groups=1)
    # Convert from tensor format back to numpy format
    filtered_image_stack = filtered_image_stack_t.numpy()
    # Gather the filtered scan layers for all patients back into one array
    # filtered_ct_scans dimensions are: [patient, scan layer, filter number, pixel row, pixel column]
    # filtered_ct_scans shape: [100, 16, 8, 512, 512]
    filtered_ct_scans[:, layer_index, :, :, :] = filtered_image_stack
So far, my attempts to use anything other than groups=1 lead to errors. I also found the following similar posts; however, they don't address my specific question.
How to use groups parameter in PyTorch conv2d function with batch?
How to use groups parameter in PyTorch conv2d function
Solution
You do not need grouped convolutions here. Reshaping your input appropriately is all that is needed: fold the layer dimension into the batch dimension, run a single conv2d call, then split the batch back out.
import torch
import torch.nn.functional as F

ct_scans = torch.randn((100, 16, 512, 512))
kernels = torch.randn((8, 1, 3, 3))  # 8 filters, 1 input channel, 3x3

B, L, H, W = ct_scans.shape  # (batch, layers, height, width)
# Fold the layer dimension into the batch and add a channel dimension
ct_scans = ct_scans.view(-1, H, W)
ct_scans.unsqueeze_(1)  # [B*L, 1, H, W]
# padding=1 keeps the 512x512 spatial size
out = F.conv2d(ct_scans, kernels, padding=1)
# Restore the batch/layer split: [B, L, 8, H, W]
out = out.view(B, L, *out.shape[1:])
print(out.shape)  # torch.Size([100, 16, 8, 512, 512])
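As a quick sanity check, the single reshaped conv2d call matches the per-layer loop. This is a sketch using small random tensors (2 patients, 3 layers, 8x8 pixels stand in for the full-size data so it runs quickly); padding=1 preserves the spatial size, as in the question:

```python
import torch
import torch.nn.functional as F

# Small stand-ins for the real data: 2 patients, 3 layers, 8x8 pixels
B, L, H, W = 2, 3, 8, 8
ct_scans = torch.randn(B, L, H, W)
kernels = torch.randn(8, 1, 3, 3)  # 8 filters, 1 input channel, 3x3

# Reshaped single-call version: fold layers into the batch, add a channel axis
out = F.conv2d(ct_scans.reshape(-1, 1, H, W), kernels, padding=1)
out = out.view(B, L, *out.shape[1:])  # [B, L, 8, H, W]

# Per-layer loop version, for comparison
expected = torch.empty(B, L, 8, H, W)
for layer_index in range(L):
    image_stack = ct_scans[:, layer_index, :, :].unsqueeze(1)  # [B, 1, H, W]
    expected[:, layer_index] = F.conv2d(image_stack, kernels, padding=1)

print(out.shape)                      # torch.Size([2, 3, 8, 8, 8])
print(torch.allclose(out, expected))  # True
```

Both versions apply the same filters to the same pixels; the reshape only changes which axis conv2d treats as the batch, so the results agree.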
Answered By - gateway2745