Issue
I'm trying to calculate the RMSE between two torch tensors. I would like to ignore/mask the rows where the labels are 0 (missing values). How could I modify this line to take that restriction into account?
torch.sqrt(((preds.detach() - labels) ** 2).mean()).item()
Thank you in advance.
Solution
This can be solved by defining a custom MSE loss function* that masks out the entries where the target is missing (0 in your case) from both the input and the target tensors before averaging:
def mse_loss_with_nans(input, target):
    # Missing data are nan's
    # mask = torch.isnan(target)
    # Missing data are 0's
    mask = target == 0
    out = (input[~mask] - target[~mask]) ** 2
    loss = out.mean()
    return loss
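For the metric in the question itself, the same masking idea can be applied directly to the original RMSE line. The sketch below assumes that 0 is never a legitimate label value; the function names `masked_rmse` and `masked_rmse_rows` are hypothetical. The first masks element-wise, like the function above. The second is an assumed variant for 2-D labels of shape (N, D), where a whole row is treated as missing only if all of its labels are 0.

```python
import torch


def masked_rmse(preds: torch.Tensor, labels: torch.Tensor) -> float:
    """RMSE over the entries whose label is non-zero (0 marks a missing value)."""
    # True where the label is observed; assumes 0 is never a real label
    keep = labels != 0
    diff = preds.detach()[keep] - labels[keep]
    # If every label is missing, diff is empty and mean() returns nan
    return torch.sqrt((diff ** 2).mean()).item()


def masked_rmse_rows(preds: torch.Tensor, labels: torch.Tensor) -> float:
    """Hypothetical row-wise variant for 2-D labels of shape (N, D)."""
    # A row is kept if at least one of its labels is non-zero;
    # zeros inside a kept row are then treated as real labels
    keep = (labels != 0).any(dim=1)
    diff = preds.detach()[keep] - labels[keep]
    return torch.sqrt((diff ** 2).mean()).item()


preds = torch.tensor([1.0, 2.0, 3.0, 4.0])
labels = torch.tensor([1.0, 0.0, 5.0, 0.0])
# Only positions 0 and 2 are used: errors 0 and -2, so RMSE = sqrt(2)
print(masked_rmse(preds, labels))

preds2 = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
labels2 = torch.tensor([[1.0, 2.0], [0.0, 0.0], [3.0, 0.0]])
# The middle row is dropped; errors 0, -1, 0, 3 give RMSE = sqrt(2.5)
print(masked_rmse_rows(preds2, labels2))
```

Which variant fits depends on what "missing" means in your data: element-wise masking discards every 0 label, while the row-wise version only discards rows that are entirely 0.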
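(*) From an optimisation point of view, minimising MSE is equivalent to minimising RMSE: the square root is strictly increasing, so both losses have the same minimiser. MSE has the advantage of being slightly cheaper to compute, since it skips the square root. To report the RMSE itself, take the square root of the result, e.g. torch.sqrt(mse_loss_with_nans(preds.detach(), labels)).item().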
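The equivalence claim can be made a little more precise. Writing m(θ) for the masked MSE and r(θ) for the corresponding RMSE (notation introduced here for illustration), the chain rule shows that the two gradients point in the same direction and differ only by a positive, loss-dependent scale factor:

```latex
r(\theta) = \sqrt{m(\theta)}, \qquad
\nabla_\theta\, r(\theta) = \frac{1}{2\sqrt{m(\theta)}}\,\nabla_\theta\, m(\theta) \quad (m(\theta) > 0),
\qquad
\arg\min_\theta r(\theta) = \arg\min_\theta m(\theta).
```

So the optimum is the same, although with a gradient-based optimiser the effective step size differs between the two losses.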
Answered By - prl900