Issue
I have an encoder which outputs a tensor with shape (bn, c * k, 32, 32). I now want to produce k means with shape (bn, k, 1, 2), so each mean is a 2-dim coordinate. To do so, I want to use k FC layers, where each mean k_i only uses c channels.
My idea is to reshape the encoder output out to a 5d tensor with shape (bn, k, c, 32, 32). Then I can use the flattened out[:, 0] ... out[:, k] as input for the k linear layers.
The trivial solution would be to define the linear layers manually:
self.fc0 = nn.Linear(c * 32 * 32, 2)
...
self.fck = nn.Linear(c * 32 * 32, 2)
Then I could define the forward pass for each mean as follows:
mean_0 = self.fc0(out[:, 0].reshape(bn, -1))
...
mean_k = self.fck(out[:, k].reshape(bn, -1))
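The manual approach above can be made a little more compact with an nn.ModuleList, which keeps the k layers registered as parameters without naming each one by hand. A minimal sketch, with hypothetical sizes bn=1, k=5, c=3:

```python
import torch
import torch.nn as nn

bn, k, c = 1, 5, 3  # hypothetical sizes for illustration

# k independent linear layers, one per mean
fcs = nn.ModuleList([nn.Linear(c * 32 * 32, 2) for _ in range(k)])

out = torch.rand(bn, k, c, 32, 32)  # reshaped encoder output
# apply layer i to the flattened i-th group of c maps, then stack the k means
means = torch.stack(
    [fc(out[:, i].reshape(bn, -1)) for i, fc in enumerate(fcs)], dim=1
)
print(means.shape)  # torch.Size([1, 5, 2])
```

This still runs k separate matrix multiplies in a Python loop, which is what the answer below avoids.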
Is there a more efficient way to do that?
Solution
I believe you are looking for a grouped convolution. Keep the k*c maps on axis=1, so the input shape is (bn, k*c, 32, 32). Then use an nn.Conv2d layer with 2*k filters and groups=k, so it is not fully connected channel-wise: each of the k groups of c maps is convolved separately, c channels at a time.
>>> bn = 1; k = 5; c = 3
>>> x = torch.rand(bn, k*c, 32, 32)
>>> m = nn.Conv2d(in_channels=c*k, out_channels=2*k, kernel_size=32, groups=k)
>>> m(x).shape
torch.Size([1, 10, 1, 1])
Which you can then reshape to your liking.
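For instance, the (bn, 2*k, 1, 1) output can be viewed as k 2-dim means with the target shape (bn, k, 1, 2). A minimal sketch with the same sizes as above:

```python
import torch
import torch.nn as nn

bn, k, c = 1, 5, 3
x = torch.rand(bn, k * c, 32, 32)
m = nn.Conv2d(in_channels=c * k, out_channels=2 * k, kernel_size=32, groups=k)

# (bn, 2*k, 1, 1) -> (bn, k, 1, 2): two output channels per group form one mean
means = m(x).reshape(bn, k, 1, 2)
print(means.shape)  # torch.Size([1, 5, 1, 2])
```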
In terms of the number of parameters, a typical nn.Conv2d usage would be:
>>> m = nn.Conv2d(in_channels=c*k, out_channels=2*k, kernel_size=32)
>>> sum(layer.numel() for layer in m.parameters())
153610
Which is exactly c*k*2*k*32*32
weights, plus 2*k
biases.
In your case, you would have
>>> m = nn.Conv2d(in_channels=c*k, out_channels=2*k, kernel_size=32, groups=k)
>>> sum(layer.numel() for layer in m.parameters())
30730
Which is exactly c*2*k*32*32 weights, plus 2*k biases, i.e. k times fewer than the previous layer. A given filter has only c channels (not k*c), which means it receives an input with c channels (i.e. one of the k groups containing c maps).
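To see that this grouped convolution really is the same computation as k independent linear layers, one can copy each group's conv filters into a nn.Linear and compare the outputs. A sketch under the same hypothetical sizes (conv.weight has shape (2*k, c, 32, 32); group i owns output channels 2*i and 2*i+1):

```python
import torch
import torch.nn as nn

bn, k, c = 1, 5, 3
x = torch.rand(bn, k * c, 32, 32)
conv = nn.Conv2d(in_channels=c * k, out_channels=2 * k, kernel_size=32, groups=k)

# Rebuild the k linear layers from the k conv groups and apply them
outs = []
for i in range(k):
    fc = nn.Linear(c * 32 * 32, 2)
    with torch.no_grad():
        # group i's two filters, flattened to a (2, c*32*32) weight matrix
        fc.weight.copy_(conv.weight[2 * i : 2 * i + 2].reshape(2, -1))
        fc.bias.copy_(conv.bias[2 * i : 2 * i + 2])
    outs.append(fc(x[:, i * c : (i + 1) * c].reshape(bn, -1)))

linear_means = torch.stack(outs, dim=1)  # (bn, k, 2)
conv_means = conv(x).reshape(bn, k, 2)   # same values, computed in one call
print(torch.allclose(linear_means, conv_means, atol=1e-4))  # True
```

The grouped conv is therefore a drop-in replacement for the k-layer loop, with a single kernel launch instead of k.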
Answered By - Ivan