Issue
I have some nets, such as the following (augmented) resnet18:
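import torch.nn as nn
from torchvision import models

num_classes = 10
resnet = models.resnet18(pretrained=True)  # newer torchvision: weights=models.ResNet18_Weights.DEFAULT
for param in resnet.parameters():
    param.requires_grad = True
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, num_classes)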
I want to use them inside a LightningModule and have Lightning handle all optimization, device placement (to_device), training stages, and so on. In other words, I want to register those modules with my LightningModule. I also want to be able to access their public members.
from pytorch_lightning import LightningModule

class MyLightning(LightningModule):
    def __init__(self, resnet):
        super().__init__()
        self._resnet = resnet
        self._criterion = lambda x: 1.0  # placeholder criterion

    def forward(self, x):
        resnet_out = self._resnet(x)
        loss = self._criterion(resnet_out)
        return loss

my_lightning = MyLightning(resnet)
The above doesn't optimize any parameters.
I also tried the following:
def __init__(self, resnet):
    ...
    _layers = list(resnet.children())[:-1]
    self._resnet = nn.Sequential(*_layers)
but this drops resnet.fc, since the [:-1] slice removes the last child, which is exactly the fc layer. It also doesn't seem like the intended way of nesting models inside PyTorch Lightning.
How do I nest models in PyTorch Lightning and have them fully accessible and handled by the framework?
Solution
The training loop and optimization process are handled by the Trainer class. To use it, create a new instance:
>>> trainer = Trainer()
Then pass your LightningModule to it. The trainer can fit, tune, validate, and test your module, provided you give it DataLoaders or a LightningDataModule:
>>> trainer.fit(my_lightning, train_dataloader, val_dataloader)
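trainer.fit expects its data as DataLoaders (or a LightningDataModule). A minimal sketch of building train and validation loaders, assuming random tensors stand in for a real image dataset; the helper name make_loaders and the shapes (3×224×224 images, 10 classes) are hypothetical choices for illustration:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loaders(num_samples=32, num_classes=10, batch_size=8, image_size=224):
    """Hypothetical helper: random image-like tensors in place of a real dataset."""
    # Fake RGB images (N, 3, H, W) and integer class labels in [0, num_classes)
    images = torch.randn(num_samples, 3, image_size, image_size)
    labels = torch.randint(0, num_classes, (num_samples,))
    dataset = TensorDataset(images, labels)
    # Shuffle only the training data; each batch is an (images, labels) pair
    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(dataset, batch_size=batch_size)
    return train_dataloader, val_dataloader


train_dataloader, val_dataloader = make_loaders()
```

In a real project these would wrap an actual dataset (for example one from torchvision.datasets) rather than random tensors.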
You implement the following methods on your LightningModule (in your case, MyLightning); trainer.fit needs at least training_step and configure_optimizers:
Name | Description
---|---
__init__ | Define computations here
forward | Use for inference only (separate from training_step)
training_step | The complete training loop
validation_step | The complete validation loop
test_step | The complete test loop
predict_step | The complete prediction loop
configure_optimizers | Define optimizers and LR schedulers

Source: LightningModule documentation page.
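Keep in mind that a LightningModule is an nn.Module, so whenever you assign an nn.Module as an attribute of a LightningModule in its __init__ method, that module is registered as a submodule of the parent LightningModule. Its parameters are then returned by the parent's parameters() and moved to the right device along with it, and you can still reach its public members, e.g. self._resnet.fc.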
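The registration rule can be checked with plain PyTorch, since LightningModule inherits this behavior from nn.Module. A minimal sketch, assuming a hypothetical TinyNet as a stand-in for resnet18 (it only mimics having a public fc head) and a hypothetical Wrapper in place of the LightningModule:

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for resnet18: a feature layer plus a public `fc` head."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Linear(8, 16)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.fc(torch.relu(self.features(x)))


class Wrapper(nn.Module):
    # A LightningModule subclasses nn.Module, so attribute registration works the same way
    def __init__(self, model):
        super().__init__()
        self._resnet = model             # nn.Module attribute: registered as a submodule
        self._extra = [nn.Linear(4, 4)]  # plain Python list: NOT registered


wrapper = Wrapper(TinyNet())
names = [name for name, _ in wrapper.named_parameters()]
print(names)
# ['_resnet.features.weight', '_resnet.features.bias', '_resnet.fc.weight', '_resnet.fc.bias']
# Public members of the wrapped model remain accessible
print(wrapper._resnet.fc.out_features)  # 10
```

Note that the Linear layer inside the plain list is missing from named_parameters(), so an optimizer built from parameters() would never update it; if you need a list of submodules, wrap it in nn.ModuleList so it gets registered.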
Answered By - Ivan