Issue
So, I've read about half of the original ResNet paper, and am trying to figure out how to make my own version for tabular data.
I've read a few blog posts on how it works in PyTorch, and I see heavy use of nn.Identity(). Now, the paper also frequently uses the term identity mapping. However, it just refers to adding the input of a stack of layers to the output of that same stack in an element-wise fashion. If the input and output dimensions are different, then the paper talks about padding the input with zeros or using a matrix W_s to project the input to the required dimension.
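For concreteness, here is a minimal sketch of those two options for a 1-D input; the shapes and names below are just illustrative, not taken from the paper:

import torch
import torch.nn.functional as F

x = torch.randn(1, 4)          # input with 4 features
y = torch.randn(1, 6)          # pretend output of some stack of layers, 6 features

# Option A: pad the input with zeros to match the output dimension
shortcut_a = F.pad(x, (0, 2))  # (1, 4) -> (1, 6)

# Option B: project the input with a learned matrix W_s (here a Linear layer)
W_s = torch.nn.Linear(4, 6, bias=False)
shortcut_b = W_s(x)

out_a = y + shortcut_a
out_b = y + shortcut_b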
Here is an abstraction of a residual block I found in a blog post:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, activation='relu'):
        super().__init__()
        self.in_channels, self.out_channels, self.activation = in_channels, out_channels, activation
        self.blocks = nn.Identity()
        self.shortcut = nn.Identity()

    def forward(self, x):
        residual = x
        if self.should_apply_shortcut: residual = self.shortcut(x)
        x = self.blocks(x)
        x += residual
        return x

    @property
    def should_apply_shortcut(self):
        return self.in_channels != self.out_channels
And my own application to a dummy tensor:
x = torch.tensor([1, 1, 2, 2])
block1 = ResidualBlock(4, 4)
block2 = ResidualBlock(4, 6)
x = block1(x)
print(x)
x = block2(x)
print(x)
>>> tensor([2, 2, 4, 4])
>>> tensor([4, 4, 8, 8])
So at the end of it, x = nn.Identity()(x), and I'm not sure of the point of its use except to mimic the math lingo found in the original paper. I'm sure that's not the case though, and that it has some hidden use that I'm just not seeing yet. What could it be?
EDIT: Here is another example of implementing residual learning, this time in Keras. It does just what I suggested above and simply keeps a copy of the input for adding to the output:
def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:
    y = Conv2D(kernel_size=kernel_size,
               strides=(1 if not downsample else 2),
               filters=filters,
               padding="same")(x)
    y = relu_bn(y)
    y = Conv2D(kernel_size=kernel_size,
               strides=1,
               filters=filters,
               padding="same")(y)
    if downsample:
        x = Conv2D(kernel_size=1,
                   strides=2,
                   filters=filters,
                   padding="same")(x)
    out = Add()([x, y])
    out = relu_bn(out)
    return out
Solution
What is the idea behind using nn.Identity for residual learning?
There is none (almost; see the end of the post). All nn.Identity does is forward the input given to it (it is basically a no-op).
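A quick check of that (a trivial sketch just to show the pass-through behaviour):

import torch
from torch import nn

x = torch.randn(2, 3)
identity = nn.Identity()
print(torch.equal(identity(x), x))  # True: the input comes back unchanged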
As shown in the PyTorch repo issue you linked in a comment, this idea was first rejected, and later merged into PyTorch due to other uses (see the rationale in that PR). That rationale is not connected to the ResNet block itself; see the end of the answer.
ResNet implementation
The easiest generic version I can think of, with projection, would be something along these lines:
import torch


class Residual(torch.nn.Module):
    def __init__(self, module: torch.nn.Module, projection: torch.nn.Module = None):
        super().__init__()
        self.module = module
        self.projection = projection

    def forward(self, inputs):
        output = self.module(inputs)
        if self.projection is not None:
            inputs = self.projection(inputs)
        return output + inputs
You can pass things like two stacked convolutions as module, and add a 1x1 convolution (with padding or strides as needed) as the projection module.
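For example, a minimal convolutional sketch (the channel counts and image size are arbitrary, chosen just for illustration; it assumes the Residual class and import above):

residual_conv = Residual(
    torch.nn.Sequential(
        torch.nn.Conv2d(64, 128, kernel_size=3, padding=1),
        torch.nn.ReLU(),
        torch.nn.Conv2d(128, 128, kernel_size=3, padding=1),
    ),
    projection=torch.nn.Conv2d(64, 128, kernel_size=1),
)

images = torch.randn(8, 64, 32, 32)
print(residual_conv(images).shape)  # torch.Size([8, 128, 32, 32])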
For tabular data you could use this as module (assuming your input has 50 features):
torch.nn.Sequential(
    torch.nn.Linear(50, 50),
    torch.nn.ReLU(),
    torch.nn.Linear(50, 50),
    torch.nn.ReLU(),
    torch.nn.Linear(50, 50),
)
Basically, all you have to do is add the input of some module to its output, and that is it.
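Putting the two pieces together for the tabular case (a minimal sketch, assuming the Residual class above; the batch size of 32 is arbitrary):

tabular_block = Residual(
    torch.nn.Sequential(
        torch.nn.Linear(50, 50),
        torch.nn.ReLU(),
        torch.nn.Linear(50, 50),
        torch.nn.ReLU(),
        torch.nn.Linear(50, 50),
    )
)  # no projection needed: input and output both have 50 features

batch = torch.randn(32, 50)
print(tabular_block(batch).shape)  # torch.Size([32, 50])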
Rationale behind nn.Identity
It might make it easier to construct neural networks (and read them afterwards); here is an example for batch norm (taken from the aforementioned PR):
batch_norm = nn.BatchNorm2d
if dont_use_batch_norm:
    batch_norm = Identity
Now you can use it with nn.Sequential easily:
nn.Sequential(
    ...
    batch_norm(N, momentum=0.05),
    ...
)
And when printing the network it always has the same number of submodules (with either BatchNorm or Identity), which also makes the whole thing a little smoother, IMO.
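A fuller sketch of that pattern (the dont_use_batch_norm flag and the layer sizes below are placeholders of my own):

from torch import nn

dont_use_batch_norm = True

batch_norm = nn.BatchNorm2d
if dont_use_batch_norm:
    batch_norm = nn.Identity  # nn.Identity ignores any constructor arguments

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    batch_norm(16, momentum=0.05),
    nn.ReLU(),
)
print(model)  # the batch norm slot is occupied by Identity, but the structure is unchanged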
Another use case, mentioned here, might be removing parts of existing neural networks:
import torchvision as tv
from torch.nn import Identity

net = tv.models.alexnet(pretrained=True)
# Assume net has two parts: features and classifier
net.classifier = Identity()
Now, instead of running net.features(input), you can run net(input), which might be easier for others to read as well.
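For instance (a small sketch continuing the snippet above; 224x224 is the usual AlexNet input size, and the printed shape depends on torchvision's AlexNet definition):

import torch

images = torch.randn(1, 3, 224, 224)
features = net(images)  # now returns flattened convolutional features instead of class scores
print(features.shape)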
Answered By - Szymon Maszke