Issue
Can't handle that problem for several days I'm new to NLP and the solution is probably very simple
class QAModel(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
def forward(self, input_ids, attention_mask, labels=None):
output = model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
return output.loss, output.logits
def training_step(self, batch, batch_idx):
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
labels = batch['labels']
loss, outputs = self(input_ids, attention_mask, labels)
self.log('train_loss', loss, prog_bar=True, logger=True)
return loss
def validation_step(self, batch, batch_idx):
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
labels = batch['labels']
loss, outputs = self(input_ids, attention_mask, labels)
self.log('val_loss', loss, prog_bar=True, logger=True)
return loss
def test_step(self, batch, batch_idx):
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
labels = batch['labels']
loss, outputs = self(input_ids, attention_mask, labels)
self.log('test_loss', loss, prog_bar=True, logger=True)
return loss
def configure_optimizers(self):
return AdamW(self.parameters(), lr=0.0001)
model = QAModel()
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(
dirpath='/content/checkpoints',
filename='best-checkpoint',
save_top_k=1,
verbose=True,
monitor='val_loss',
mode='min'
)
trainer = pl.Trainer(
checkpoint_callback=checkpoint_callback,
max_epochs=N_EPOCHS,
gpus=1,
progress_bar_refresh_rate=30
)
trainer.fit(model, data_module)
Running this code gives me AttributeError: 'QAModel' object has no attribute 'automatic_optimization' after fit() function Probably, the problem is in MT5ForConditionalGeneration, as after passing it to funtion() we've got the same error
Solution
Try inheriting pl.LightingModule
instead of pl.LightningDataModule
. It is the right choice for defining a model class.
Answered By - Shiv
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.