Issue
I am fairly new to PyTorch, since I have been using Keras for some years. Now I want to run a neural architecture search (NAS) based on DARTS: Differentiable Architecture Search (see https://nni.readthedocs.io/en/stable/NAS/DARTS.html), which is built on PyTorch.
All the available examples use accuracy as the metric, but I need to calculate MSE. This is one of the examples:
DartsTrainer(model,
             loss=criterion,
             metrics=lambda output, target: accuracy(output, target, topk=(1,)),
             optimizer=optim,
             num_epochs=args.epochs,
             dataset_train=dataset_train,
             dataset_valid=dataset_valid,
             batch_size=args.batch_size,
             log_frequency=args.log_frequency,
             unrolled=args.unrolled,
             callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
# where the accuracy is defined in a separate function:
def accuracy(output, target, topk=(1,)):
    # Computes the precision@k for the specified values of k
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = dict()
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res
As far as I can see, calculating metrics is more complicated in PyTorch than in Keras. Can someone help, please?
As a trial, I wrote this code:
def accuracy_mse(output, target):
    batch_size = target.size(0)
    diff = torch.square(output.t() - target) / batch_size
    diff = diff.sum()
    res = dict()
    res["acc_mse"] = diff
    return res
It seems to work, but I am not 100% sure about it.
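One way to check the trial version is to run it on a tiny batch where the answer is known and compare it with a direct element-wise computation. This is a minimal sketch, assuming that output and target both have shape (batch_size, 1); the toy tensors are made up for illustration.

```python
import torch

def accuracy_mse(output, target):
    # Trial version from the question, transpose included
    batch_size = target.size(0)
    diff = torch.square(output.t() - target) / batch_size
    return {"acc_mse": diff.sum()}

# Hypothetical regression batch: 4 samples, one output value each
output = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
target = torch.tensor([[1.5], [2.0], [2.0], [4.0]])

# (1, 4) minus (4, 1) broadcasts to a (4, 4) matrix of all pairwise differences
print((output.t() - target).shape)  # torch.Size([4, 4])

trial = accuracy_mse(output, target)["acc_mse"].item()
reference = torch.mean((output - target) ** 2).item()  # element-wise MSE
print(trial, reference)
```

If the two printed values disagree, the metric is not computing an element-wise MSE. Here the pairwise matrix makes the trial value far larger than the reference.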
Solution
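Finally I figured out that the transpose (.t()) was causing the problem. If output and target both have shape (batch_size, 1), then output.t() has shape (1, batch_size), and the subtraction broadcasts to a (batch_size, batch_size) matrix instead of an element-wise difference. So the final code is: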
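import torch

def accuracy_mse(output, target):
    """ Computes the mse """
    # output and target must have the same shape, e.g. (batch_size, 1)
    batch_size = target.size(0)
    # sum / batch_size equals the mean only when each sample has one output value
    diff = torch.square(output - target) / batch_size
    diff = diff.sum()
    res = dict()
    # .item() returns a plain Python float, as accuracy() does, so no autograd graph is kept
    res["mse"] = diff.item()
    return res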
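A variant of the same metric can lean on PyTorch's built-in torch.nn.functional.mse_loss and fail loudly on a shape mismatch, which would have surfaced the transpose problem right away. This is a sketch, not the NNI API: the name mse_metric is hypothetical, and the assumption that the trainer's metrics callable should return a dict of name to float is inferred from the accuracy() example above.

```python
import torch
import torch.nn.functional as F

def mse_metric(output, target):
    """Hypothetical MSE metric returning a dict, like accuracy() in the example."""
    # Refuse mismatched shapes instead of letting broadcasting hide the bug
    if output.shape != target.shape:
        raise ValueError(f"shape mismatch: {tuple(output.shape)} vs {tuple(target.shape)}")
    # reduction='mean' (the default) averages over every element of the tensors;
    # detach() keeps the metric out of the autograd graph
    return {"mse": F.mse_loss(output.detach(), target).item()}

# Hypothetical batch: 4 samples, one output value each
output = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
target = torch.tensor([[1.5], [2.0], [2.0], [4.0]])
print(mse_metric(output, target))  # {'mse': 0.3125}
```

It would be passed to the trainer the same way as accuracy, for example metrics=lambda output, target: mse_metric(output, target). Because mse_loss averages over all elements, it matches the hand-written sum / batch_size version only when each sample has a single output value.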
Answered By - user898160