Issue
I'm trying to create a cross-encoder model starting from "distilbert-base-uncased" using Huggingface transformers with PyTorch. The architecture is simple: get the CLS embedding from the concatenated input strings (this is handled by the Huggingface tokenizer), then pass it through a final FC linear layer to 1 output logit. The loss function is the built-in torch.nn.BCEWithLogitsLoss function.
This model failed to learn correctly on my dataset. Instead, the CLS embedding (prior to the final FC linear layer) quickly converged to the same embedding for every input (in fact, the other tokens also converge to the same embedding). The last layer then maps this embedding to the ratio of positives in the training sample, which is the expected behavior assuming a constant embedding function.
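The "ratio of positives" behavior follows from the loss itself: if every input gets the same logit z, the mean binary cross-entropy is minimized when sigmoid(z) equals the positive rate p, i.e. z = log(p / (1 - p)). A minimal sketch in plain Python, assuming an illustrative sample of 1 positive and 2 negatives (the function names are made up for this example):

```python
import math
from typing import Sequence

def mean_bce_with_logits(z: float, labels: Sequence[float]) -> float:
    """Mean binary cross-entropy when every example receives the same logit z."""
    p = 1.0 / (1.0 + math.exp(-z))
    total = sum(y * math.log(p) + (1.0 - y) * math.log(1.0 - p) for y in labels)
    return -total / len(labels)

def best_constant_logit(labels: Sequence[float]) -> float:
    """Logit minimizing the loss above: the log-odds of the positive rate."""
    rate = sum(labels) / len(labels)
    return math.log(rate / (1.0 - rate))

labels = [1.0, 0.0, 0.0]  # illustrative: 1 positive, 2 negatives
z_star = best_constant_logit(labels)
prob = 1.0 / (1.0 + math.exp(-z_star))  # predicted probability at the optimum
print(z_star, prob)
```

So a model whose embedding has collapsed can still reduce the loss, but only down to the value reached at this constant prediction.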
For debugging purposes I simply fed it a dummy dataset consisting of the same 3 sentence pairs over and over (1 labelled positive, the other 2 negative). The same behavior persisted, but when I froze the transformer parameters (so that only the final FC was being trained), the model correctly overfit on the data point as expected.
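A minimal sketch of that freezing step, under stated assumptions: `TinyCrossEncoder` is a hypothetical stand-in (an embedding layer in place of the real transformer) so the snippet runs without downloading weights, and the learning rate is an illustrative value rather than the one used in the post.

```python
import torch
from torch import nn

class TinyCrossEncoder(nn.Module):
    """Hypothetical stand-in: a placeholder 'encoder' plus the final FC layer."""
    def __init__(self, vocab_size: int = 100, hidden_size: int = 16) -> None:
        super().__init__()
        self.transformer = nn.Embedding(vocab_size, hidden_size)  # placeholder encoder
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        cls_embedding = self.transformer(input_ids)[:, 0, :]  # first-token vector
        return self.fc(cls_embedding).squeeze(-1)

model = TinyCrossEncoder()

# Freeze the encoder so only the final FC layer receives gradient updates.
for param in model.transformer.parameters():
    param.requires_grad = False

# Hand only the still-trainable parameters (fc weight and bias) to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)  # lr is an assumed value

frozen_before = model.transformer.weight.clone()
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # dummy token ids
labels = torch.tensor([1.0, 0.0, 0.0])

loss = nn.BCEWithLogitsLoss()(model(input_ids), labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

With the real model the same loop over `model.transformer.parameters()` applies; comparing the encoder weights before and after a step is a quick way to confirm the freeze took effect.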
My model architecture:
from torch import Tensor, nn
from transformers import AutoModel, PreTrainedTokenizerFast


class CrossEncoderModel(nn.Module):
    """
    Architecture:
    - Transformer
    - Final FC linear layer to one output for binary classification
    """

    def __init__(
        self, transformer_model: str, tokenizer: PreTrainedTokenizerFast
    ) -> None:
        super(CrossEncoderModel, self).__init__()
        self.transformer = AutoModel.from_pretrained(transformer_model)
        print(type(self.transformer))
        # Match the embedding matrix to the tokenizer's vocabulary size
        self.transformer.resize_token_embeddings(len(tokenizer))
        self.fc = nn.Linear(self.transformer.config.hidden_size, 1)

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Hidden state of the first ([CLS]) token of each sequence
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        logits = self.fc(cls_embedding).squeeze(-1)
        return logits
My loss/update (using Adam optimizer):
loss_fn = torch.nn.BCEWithLogitsLoss()
loss = loss_fn(logits, labels.float())  # targets must be floats
loss.backward()
optimizer.step()
optimizer.zero_grad()
I tried varying the learning rate and batch size, neither of which changed the convergence to the same CLS embedding. I suspect something is wrong with my model architecture, but I'm having trouble finding what exactly. The behavior also persists when I replace the loss function with a manual target:
class TestLoss(nn.Module):
    def __init__(self):
        super(TestLoss, self).__init__()

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        # Ignore the labels and pull the three logits toward fixed targets
        target = torch.tensor([10.0, -10.0, -10.0], device=logits.device)
        return torch.sum(torch.abs(logits - target))

# still converges to all embeddings being the same
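One way to put a number on the collapse is to track the average pairwise cosine similarity of the CLS embeddings in a batch: values near 1 mean the inputs have become indistinguishable. This is only a sketch; `mean_pairwise_cosine` is a hypothetical helper, and the two tensors are synthetic examples rather than real model outputs.

```python
import torch
import torch.nn.functional as F

def mean_pairwise_cosine(embeddings: torch.Tensor) -> float:
    """Average cosine similarity over all distinct row pairs of a (batch, hidden) tensor."""
    normed = F.normalize(embeddings, dim=-1)
    sims = normed @ normed.T                      # (batch, batch) cosine matrix
    n = sims.shape[0]
    off_diag = sims.sum() - sims.diagonal().sum() # drop each row's similarity with itself
    return (off_diag / (n * (n - 1))).item()

# Synthetic examples: nearly identical rows vs. mutually orthogonal rows.
collapsed = torch.ones(3, 8) + 1e-4 * torch.randn(3, 8)
spread = torch.eye(3, 8)

print(mean_pairwise_cosine(collapsed))  # close to 1
print(mean_pairwise_cosine(spread))     # close to 0
```

In the real training loop, the input would be `outputs.last_hidden_state[:, 0, :]` detached from the graph, logged every few steps.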
Solution
Okay, after a lot of debugging I tried changing my optimizer. I was using Adam which worked well when I was using a dual-encoder architecture. Changing to SGD fixed the issue and the model learns correctly now.
Not super sure why Adam wasn't working, will update if I figure it out.
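For reference, the change amounts to swapping the optimizer constructor and leaving the update loop as it was. A minimal sketch, assuming a small linear layer as a stand-in for the cross-encoder and an illustrative learning rate (neither is taken from the original training script):

```python
import torch
from torch import nn

model = nn.Linear(16, 1)  # hypothetical stand-in for CrossEncoderModel

# Before: optimizer = torch.optim.Adam(model.parameters(), lr=...)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)  # lr is an assumed value

loss_fn = nn.BCEWithLogitsLoss()
features = torch.randn(3, 16)           # stand-in for the CLS embeddings
labels = torch.tensor([1.0, 0.0, 0.0])  # 1 positive, 2 negatives

weights_before = model.weight.detach().clone()
loss = loss_fn(model(features).squeeze(-1), labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

The learning rate that works for SGD is usually not the one that worked for Adam, so it likely needs retuning after the swap.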
Answered By - Ben Chen