Issue
I am training a model in pytorch for which I have made a class like so:
from torch import nn

class myNN(nn.Module):
    def __init__(self, dense1=128, dense2=64, dense3=32, ...):
        super().__init__()
        self.MLP = nn.Sequential(
            nn.Linear(dense1, dense2),
            nn.ReLU(),
            nn.Linear(dense2, dense3),
            nn.ReLU(),
            nn.Linear(dense3, 1)
        )
        ...
In order to save it I am using:
torch.save(model.state_dict(), checkpoint_model_path)
and to load it I am using:
model = myNN() # or with specified parameters
model.load_state_dict(torch.load(model_file))
However, in order for this method to work I have to use the right values in myNN()'s constructor. That means that I would need to somehow remember or store which parameters (layer sizes) I have used in each case in order to properly load different models.
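Here is a minimal sketch of the problem. It assumes the elided `...` arguments can be dropped and adds a `forward` method that the question does not show. A checkpoint saved from a model with non-default layer sizes cannot be loaded into `myNN()` built with the defaults, because `load_state_dict` rejects tensors whose shapes do not match. An in-memory `io.BytesIO` buffer stands in for the checkpoint file.

```python
import io

import torch
from torch import nn


class myNN(nn.Module):
    # Same layers as in the question; the elided "..." arguments are omitted
    def __init__(self, dense1=128, dense2=64, dense3=32):
        super().__init__()
        self.MLP = nn.Sequential(
            nn.Linear(dense1, dense2),
            nn.ReLU(),
            nn.Linear(dense2, dense3),
            nn.ReLU(),
            nn.Linear(dense3, 1),
        )

    def forward(self, x):
        # Assumed forward pass, not shown in the question
        return self.MLP(x)


# A model trained with non-default layer sizes
trained = myNN(dense1=128, dense2=16, dense3=8)
buffer = io.BytesIO()  # in-memory stand-in for checkpoint_model_path
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Rebuilding with the default sizes produces parameters of the wrong shape
model = myNN()
try:
    model.load_state_dict(torch.load(buffer, weights_only=True))
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = str(err)

print(mismatch_error)
```

The error message names each mismatched parameter, such as `MLP.0.weight`. This is why the constructor arguments must be known before loading.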
Is there a flexible way to save/load models in pytorch where I would also read the size of the layers?
E.g. by loading a myNN() object directly or somehow reading the layer sizes from the saved pickle file?
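For this particular architecture, the second idea works without any extra bookkeeping, because every layer size appears in some weight shape. `nn.Linear` stores `weight` as `(out_features, in_features)`. The sketch below uses `infer_kwargs`, a hypothetical helper. Its key names assume the `self.MLP` layout above, with the Linear layers at indices 0, 2 and 4 of the `Sequential`.

```python
import torch
from torch import nn


class myNN(nn.Module):
    def __init__(self, dense1=128, dense2=64, dense3=32):
        super().__init__()
        self.MLP = nn.Sequential(
            nn.Linear(dense1, dense2),
            nn.ReLU(),
            nn.Linear(dense2, dense3),
            nn.ReLU(),
            nn.Linear(dense3, 1),
        )


def infer_kwargs(state_dict):
    """Hypothetical helper: recover constructor arguments from weight shapes.

    nn.Linear weights have shape (out_features, in_features); indices 1 and 3
    of the Sequential are ReLUs and have no parameters.
    """
    w0 = state_dict["MLP.0.weight"]  # shape (dense2, dense1)
    w2 = state_dict["MLP.2.weight"]  # shape (dense3, dense2)
    return {"dense1": w0.shape[1], "dense2": w0.shape[0], "dense3": w2.shape[0]}


state = myNN(dense1=100, dense2=50, dense3=10).state_dict()
kwargs = infer_kwargs(state)
print(kwargs)  # {'dense1': 100, 'dense2': 50, 'dense3': 10}

# The inferred sizes rebuild a model the state_dict fits into
model = myNN(**kwargs)
model.load_state_dict(state)
```

This approach is brittle. It hard-codes the layer layout and cannot recover arguments that leave no trace in tensor shapes, such as anything hidden in the `...`.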
I am hesitant to try the second method in "Best way to save a trained model in PyTorch?" due to the warnings mentioned there. Is there a better way to achieve what I want?
Solution
Indeed, serializing the whole Python object is quite a drastic move. Instead, you can always add user-defined items to the saved file: save the model's state along with its constructor arguments. Something like this would work:
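`torch.save` accepts any picklable container, so a checkpoint can be a dictionary with named entries instead of a bare `state_dict`. The minimal sketch below uses a stand-in `nn.Linear` model. The key names `kwargs`, `state_dict` and `epoch` are arbitrary choices for illustration. Passing `weights_only=True` (supported by recent PyTorch releases) restricts loading to tensors and plain containers rather than arbitrary pickled objects.

```python
import io

import torch
from torch import nn

# Stand-in model; any nn.Module is saved the same way
model = nn.Linear(4, 2)

# Alongside the state_dict, store any extra user-defined entries
checkpoint = {
    "kwargs": {"in_features": 4, "out_features": 2},
    "state_dict": model.state_dict(),
    "epoch": 10,
}

buffer = io.BytesIO()  # in-memory stand-in for a file path
torch.save(checkpoint, buffer)
buffer.seek(0)

loaded = torch.load(buffer, weights_only=True)
restored = nn.Linear(**loaded["kwargs"])  # rebuild from the stored arguments
restored.load_state_dict(loaded["state_dict"])
print(loaded["epoch"])  # 10
```

Named keys make the file self-describing and easy to extend, for example with an optimizer state. The trade-off is a slightly longer load step than unpacking a list.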
First, store the constructor arguments on the instance so that they can be serialized when saving the model:
class myNN(nn.Module):
    def __init__(self, dense1=128, dense2=64, dense3=32):
        super().__init__()
        self.kwargs = {'dense1': dense1, 'dense2': dense2, 'dense3': dense3}
        self.MLP = nn.Sequential(
            nn.Linear(dense1, dense2),
            nn.ReLU(),
            nn.Linear(dense2, dense3),
            nn.ReLU(),
            nn.Linear(dense3, 1))
We can save the parameters of the model along with its initializer arguments:
>>> torch.save([model.kwargs, model.state_dict()], path)
Then load it:
>>> kwargs, state = torch.load(path)
>>> model = myNN(**kwargs)
>>> model.load_state_dict(state)
<All keys matched successfully>
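The sketch below runs the same round trip end to end and checks that the reloaded model reproduces the original outputs. The `forward` method and the `load_model` helper are assumptions added for illustration. An `io.BytesIO` buffer stands in for `path`.

```python
import io

import torch
from torch import nn


class myNN(nn.Module):
    def __init__(self, dense1=128, dense2=64, dense3=32):
        super().__init__()
        # Remember the constructor arguments so they can be saved with the weights
        self.kwargs = {"dense1": dense1, "dense2": dense2, "dense3": dense3}
        self.MLP = nn.Sequential(
            nn.Linear(dense1, dense2),
            nn.ReLU(),
            nn.Linear(dense2, dense3),
            nn.ReLU(),
            nn.Linear(dense3, 1),
        )

    def forward(self, x):
        # Assumed forward pass; the original snippets do not show one
        return self.MLP(x)


def load_model(f):
    """Hypothetical helper: rebuild myNN from a [kwargs, state_dict] checkpoint."""
    kwargs, state = torch.load(f, weights_only=True)
    model = myNN(**kwargs)
    model.load_state_dict(state)
    return model


original = myNN(dense1=20, dense2=10, dense3=5)
buffer = io.BytesIO()  # in-memory stand-in for `path`
torch.save([original.kwargs, original.state_dict()], buffer)
buffer.seek(0)

reloaded = load_model(buffer)
original.eval()
reloaded.eval()

# Same weights and same input should give the same output
x = torch.randn(3, 20)
with torch.no_grad():
    same_output = torch.allclose(original(x), reloaded(x))
print(reloaded.kwargs, same_output)
```

The checkpoint holds only plain dictionaries and tensors, so it loads with `weights_only=True`, which restricts unpickling to tensors and basic containers. The `myNN` class definition must still be importable at load time.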
Answered By - Ivan