Issue
I'm trying to use the Dice metric from the PyTorch library "torchmetrics". I found an example that uses the accuracy metric, like the one below:
from torchmetrics.classification import Accuracy

train_accuracy = Accuracy()
valid_accuracy = Accuracy()

for epoch in range(epochs):
    for i, (x, y) in enumerate(train_data):
        y_hat = model(x)
        # training step accuracy: calling the metric returns this batch's value and accumulates state
        batch_acc = train_accuracy(y_hat, y)
        print(f"Accuracy of batch {i} is {batch_acc}")

    for x, y in valid_data:
        y_hat = model(x)
        # only accumulate state for the validation set
        valid_accuracy.update(y_hat, y)

    # total accuracy over all training batches
    total_train_accuracy = train_accuracy.compute()
    # total accuracy over all validation batches
    total_valid_accuracy = valid_accuracy.compute()
    print(f"Training acc for epoch {epoch}: {total_train_accuracy}")
    print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}")

    # Reset metric states after each epoch
    train_accuracy.reset()
    valid_accuracy.reset()
However, when I replaced "Accuracy()" with "dice_score()", like this:
from torchmetrics.functional import dice_score
train_accuracy = dice_score()
valid_accuracy = dice_score()
I got the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-43-726045592283> in <module>
3 from torchmetrics.functional import dice_score
4
----> 5 train_accuracy_2 =dice_score()# Accuracy()
6 valid_accuracy_2 =dice_score()# Accuracy()
7
TypeError: dice_score() missing 2 required positional arguments: 'preds' and 'target'
Is there an example of using the Dice metric from "torchmetrics"?
Solution
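torchmetrics.functional.dice_score is the functional interface to the Dice score. That means it is a stateless function that must be given the predictions and the ground truth (its preds and target arguments) on every call, which is why calling dice_score() with no arguments raises the TypeError above. There doesn't seem to be a module interface to the Dice score, like there is with accuracy.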
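To make "stateless" concrete, here is a framework-free sketch of what a functional Dice score computes for binary labels, using the standard definition Dice = 2*TP / (2*TP + FP + FN). It is an illustration under that assumption, not the torchmetrics implementation (which also handles tensors, multiple classes and other options); binary_dice is a hypothetical name.

```python
def binary_dice(preds, target):
    """Stateless Dice score for flat sequences of 0/1 labels.

    Everything the function needs arrives through its arguments, so there is
    nothing to construct up front and nothing is remembered between calls.
    """
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    tp = sum(1 for p, t in zip(preds, target) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(preds, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(preds, target) if p == 0 and t == 1)
    denom = 2 * tp + fp + fn
    # Assumed convention: if neither input contains a positive, report a perfect score.
    return 1.0 if denom == 0 else 2 * tp / denom


# TP = 1, FP = 1, FN = 1  ->  2 / (2 + 1 + 1)
print(binary_dice([1, 1, 0, 0], [1, 0, 1, 0]))  # 0.5
```

In a real training loop, model outputs would normally be thresholded or argmax-ed into labels before a comparison like this.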
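torchmetrics.classification.Accuracy is a class that maintains state across calls: update() (or calling the object) accumulates results, compute() reads the total, and reset() clears it. Under the hood, it uses the functional interface, torchmetrics.functional.accuracy.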
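Since there is no module counterpart for Dice, one workaround is to keep the running state yourself. The sketch below is a framework-free, hypothetical BinaryDice class (not part of torchmetrics) that mirrors the update/compute/reset pattern used with Accuracy in the question, for binary 0/1 labels. Within torchmetrics itself, the analogous route would be subclassing torchmetrics.Metric and registering the counters with add_state.

```python
def _confusion_counts(preds, target):
    """Count true positives, false positives and false negatives for 0/1 labels."""
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    tp = sum(1 for p, t in zip(preds, target) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(preds, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(preds, target) if p == 0 and t == 1)
    return tp, fp, fn


def _dice_from_counts(tp, fp, fn):
    """Dice = 2*TP / (2*TP + FP + FN); assumed to be 1.0 when there are no positives at all."""
    denom = 2 * tp + fp + fn
    return 1.0 if denom == 0 else 2 * tp / denom


class BinaryDice:
    """Hypothetical stateful Dice metric with an Accuracy-style update/compute/reset API."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the running totals (e.g. once per epoch).
        self.tp = self.fp = self.fn = 0

    def _add(self, tp, fp, fn):
        self.tp += tp
        self.fp += fp
        self.fn += fn

    def update(self, preds, target):
        # Accumulate this batch's counts without returning anything.
        self._add(*_confusion_counts(preds, target))

    def compute(self):
        # Dice over every sample seen since the last reset.
        return _dice_from_counts(self.tp, self.fp, self.fn)

    def __call__(self, preds, target):
        # Accumulate the batch and return the Dice of this batch alone,
        # like calling the Accuracy object in the question's training loop.
        batch_counts = _confusion_counts(preds, target)
        self._add(*batch_counts)
        return _dice_from_counts(*batch_counts)


train_dice = BinaryDice()
for preds, target in [([1, 1, 0, 0], [1, 0, 1, 0]), ([1, 0], [1, 0])]:
    print(train_dice(preds, target))  # 0.5, then 1.0
print(train_dice.compute())  # (2*2) / (2*2 + 1 + 1) = 4/6
train_dice.reset()
```

Accumulating TP, FP and FN and dividing once in compute() gives the Dice over all samples seen, which generally differs from averaging the per-batch scores.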
This is not enforced in any way, but classes are typically named in CamelCase and functions in snake_case, so Accuracy is a class you instantiate, while dice_score is a function you call directly with your data.
Answered By - jakub