Issue
I am implementing the same (simple) model in tf.keras
and PyTorch. My Keras model does fine, but my Torch model always predicts zeros and does poorly as a result.
My Keras model is defined as:
model = Sequential([
    Dense(5, activation='relu', name='layer1'),
    Dense(5, activation='relu', name='layer2'),
    Dense(1, activation='sigmoid', name='output')
])
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)
history = model.fit(
    data.x_train,
    data.y_train,
    batch_size=128,
    epochs=30,
    validation_data=(data.x_test, data.y_test)
)
Meanwhile, my Torch model is defined like so:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(19, 5),
            nn.ReLU(),
            nn.Linear(5, 5),
            nn.ReLU(),
            nn.Linear(5, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)
def train(train_ds, test_ds, model):
    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True)
    test_dl = DataLoader(test_ds, batch_size=128, shuffle=True)
    opt = optim.Adam(model.parameters())
    loss_func = F.binary_cross_entropy

    # 30 epochs
    for epoch in range(30):
        for xb, yb in train_dl:
            pred = model(xb)
            loss = loss_func(pred.reshape((pred.shape[0])), yb)
            loss.backward()
            opt.step()
            opt.zero_grad()
Both use their frameworks' default initialization, which I believe is Glorot uniform; the data passed to both is identical, too. As far as I can tell, I've implemented the same model with the same optimizer and hyperparameters in both frameworks, but the PyTorch model's predictions are all zeros when I print them out. What am I doing wrong?
Solution
Fixing the optimizer steps
For the optimizer to work correctly in your PyTorch training loop, these are the steps you should take on each iteration. It may feel counterintuitive at first, but you need to zero the gradients first:
opt.zero_grad()
loss.backward()
opt.step()
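PyTorch accumulates gradients across backward() calls by default, so the essential point is that the gradients are cleared before each new backward pass. Below is a minimal sketch of the corrected training loop under that ordering, reusing the names from the question; the .float() cast on the targets is an assumption (F.binary_cross_entropy requires float targets), so drop it if your labels are already floats:

import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader

def train(train_ds, model, epochs=30):
    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True)
    opt = optim.Adam(model.parameters())
    for epoch in range(epochs):
        for xb, yb in train_dl:
            opt.zero_grad()  # clear gradients accumulated by the previous step
            pred = model(xb)
            # flatten (N, 1) predictions to (N,) to match the target shape
            loss = F.binary_cross_entropy(pred.reshape(-1), yb.float())
            loss.backward()  # compute fresh gradients for this batch
            opt.step()       # apply the Adam update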
Making the two models match
As a side note: if you want the two models to behave identically, verify that the learning rate and the other optimizer parameters are the same in both frameworks. The same goes for the weight-initialization parameters; the method may nominally be the same (Glorot), but it can depend on internal parameters that you may need to set explicitly in your initialization.
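For example, Keras's Dense layers default to a glorot_uniform kernel with zero biases, while PyTorch's nn.Linear uses its own uniform scheme, and the two frameworks' Adam defaults differ slightly in epsilon (1e-7 in tf.keras vs. 1e-8 in PyTorch). A sketch of how you might align the PyTorch side with the Keras defaults, assuming the Model class from the question:

import torch.nn as nn
from torch import optim

def init_like_keras(m):
    # Keras Dense defaults: glorot_uniform kernel, zero bias
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)

model = Model()  # the Model class defined in the question
model.apply(init_like_keras)

# tf.keras Adam defaults: lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7
opt = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-7)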
Answered By - elkbrs