Issue
I am training a multi-label classification model with Hugging Face transformers, using PyTorch Lightning for training.
Here is the code:
Early stopping triggers when the validation loss hasn't improved for the last 2 epochs:
early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)
We then define a checkpoint callback and start the training process:
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min"
)
trainer = pl.Trainer(
    logger=logger,
    callbacks=[early_stopping_callback],
    max_epochs=N_EPOCHS,
    checkpoint_callback=checkpoint_callback,
    gpus=1,
    progress_bar_refresh_rate=30
)
As soon as I run this, I get this error:
~/.local/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py in _configure_checkpoint_callbacks(self, checkpoint_callback)
75 if isinstance(checkpoint_callback, Callback):
76 error_msg += " Pass callback instances to the `callbacks` argument in the Trainer constructor instead."
---> 77 raise MisconfigurationException(error_msg)
78 if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
79 raise MisconfigurationException(
MisconfigurationException: Invalid type provided for checkpoint_callback: Expected bool but received <class 'pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'>. Pass callback instances to the `callbacks` argument in the Trainer constructor instead.
How can I fix this issue?
Solution
You can look up the description of the checkpoint_callback argument in the documentation page of pl.Trainer:
checkpoint_callback (bool) – If True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in callbacks.
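In other words, this argument only accepts a boolean that toggles the default checkpointing behaviour. As a minimal sketch (assuming the same pre-1.4 Lightning API shown in your traceback):

trainer = pl.Trainer(
    checkpoint_callback=True,  # enables a default ModelCheckpoint
    max_epochs=N_EPOCHS,
)

Passing False instead disables checkpointing entirely.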
You shouldn't pass your custom ModelCheckpoint to this argument. I believe what you are looking to do is to pass both the EarlyStopping and the ModelCheckpoint instances in the callbacks list:
early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min"
)

trainer = pl.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=N_EPOCHS,
    gpus=1,
    progress_bar_refresh_rate=30
)
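After training, the ModelCheckpoint callback keeps track of the best scoring checkpoint on disk, which you can reload for evaluation. A minimal sketch, assuming your LightningModule instance is called model, its class is MultiLabelClassifier (a hypothetical name), and your data lives in data_module:

# Run training (model and data_module are assumed to be defined elsewhere)
trainer.fit(model, datamodule=data_module)

# Path of the best checkpoint written by the callback,
# e.g. checkpoints/best-checkpoint.ckpt
print(checkpoint_callback.best_model_path)

# Reload the best weights for evaluation or inference
best_model = MultiLabelClassifier.load_from_checkpoint(
    checkpoint_callback.best_model_path
)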
Answered By - Ivan