Using Dice metric in PyTorch "torchmetrics": dice_score() missing 2 required positional arguments: 'preds' and 'target'
I'm trying to use the Dice metric from PyTorch's "torchmetrics". I found an example that uses the accuracy metric, shown below:
from torchmetrics.classification import Accuracy
# model, train_data, valid_data and epochs are defined elsewhere
train_accuracy = Accuracy()
valid_accuracy = Accuracy()
for epoch in range(epochs):
    # enumerate supplies the batch index i used in the print below
    for i, (x, y) in enumerate(train_data):
        y_hat = model(x)
        # training step accuracy (also updates the running state)
        batch_acc = train_accuracy(y_hat, y)
        print(f"Accuracy of batch {i} is {batch_acc}")
    for x, y in valid_data:
        y_hat = model(x)
        valid_accuracy.update(y_hat, y)
    # total accuracy over all training batches
    total_train_accuracy = train_accuracy.compute()
    # total accuracy over all validation batches
    total_valid_accuracy = valid_accuracy.compute()
    print(f"Training acc for epoch {epoch}: {total_train_accuracy}")
    print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}")
    # Reset metric states after each epoch
    train_accuracy.reset()
    valid_accuracy.reset()
However, when I replaced "Accuracy()" with "dice_score()", as below:
from torchmetrics.functional import dice_score
train_accuracy = dice_score()
valid_accuracy = dice_score()
I got the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-43-726045592283> in <module>
3 from torchmetrics.functional import dice_score
4
----> 5 train_accuracy_2 =dice_score()# Accuracy()
6 valid_accuracy_2 =dice_score()# Accuracy()
7
TypeError: dice_score() missing 2 required positional arguments: 'preds' and 'target'
Is there an example of using the Dice metric from "torchmetrics"?
Solution 1:[1]
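torchmetrics.functional.dice_score is the functional interface to the Dice score. That means it is a stateless function: every call must be given the predictions and the ground truth (preds and target), and it returns the score for that input alone. Calling it with no arguments, as in the question, therefore raises the TypeError. There doesn't seem to be a module interface to the Dice score, like there is with accuracy.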
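To see why the call fails, here is a minimal, framework-free sketch of a stateless Dice function for binary labels. The name dice_score and its two-argument signature are hypothetical stand-ins chosen to mirror the error message; the real torchmetrics function works on tensors and takes additional keyword arguments. Calling the sketch with no arguments reproduces the same TypeError, while calling it with one batch of predictions and targets returns that batch's score.

```python
# Hypothetical stand-in for a functional metric (not the torchmetrics implementation):
# it keeps no state, so every call needs both predictions and ground truth.
def dice_score(preds, target):
    """Dice coefficient 2*|P & T| / (|P| + |T|) for two equal-length 0/1 sequences."""
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    intersection = sum(p * t for p, t in zip(preds, target))
    total = sum(preds) + sum(target)
    # Assumed convention: two empty masks count as a perfect match.
    return 1.0 if total == 0 else 2.0 * intersection / total


# Calling it like a class constructor fails the same way as in the question.
try:
    dice_score()
except TypeError as err:
    print(err)  # dice_score() missing 2 required positional arguments: 'preds' and 'target'

# Called with data, it returns the score for that one batch: 2 * 1 / (2 + 1).
print(dice_score([1, 1, 0, 0], [1, 0, 0, 0]))
```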
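torchmetrics.classification.Accuracy is a class that maintains state: each update() (or call on the object) adds a batch to its running statistics, compute() returns the metric over everything seen so far, and reset() clears the state. Under the hood, it uses the functional interface, which is torchmetrics.functional.accuracy.
This is not enforced in any way, but by Python convention classes are named in CamelCase and functions in snake_case, so Accuracy versus dice_score already hints at which one is a class you instantiate and which one is a function you call with data.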
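To get the update()/compute()/reset() pattern the question's loop relies on, one option is to wrap a stateless Dice computation in a small class that keeps running counts. The sketch below is framework-free and hypothetical: DiceAccumulator, _dice_counts and _dice_from_counts are illustrative names, not torchmetrics APIs, and it works on plain 0/1 lists rather than tensors.

```python
def _dice_counts(preds, target):
    """Return (|P & T|, |P| + |T|) for two equal-length 0/1 sequences."""
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    return sum(p * t for p, t in zip(preds, target)), sum(preds) + sum(target)


def _dice_from_counts(intersection, total):
    # Assumed convention: two empty masks count as a perfect match.
    return 1.0 if total == 0 else 2.0 * intersection / total


class DiceAccumulator:
    """Hypothetical stateful Dice metric mirroring how Accuracy is used above."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.intersection = 0  # running sum of |P & T| since the last reset
        self.total = 0         # running sum of |P| + |T| since the last reset

    def update(self, preds, target):
        inter, total = _dice_counts(preds, target)
        self.intersection += inter
        self.total += total

    def compute(self):
        # Dice over all batches seen since the last reset, from pooled counts.
        return _dice_from_counts(self.intersection, self.total)

    def __call__(self, preds, target):
        # Update the running state, but return the score of this batch alone.
        inter, total = _dice_counts(preds, target)
        self.intersection += inter
        self.total += total
        return _dice_from_counts(inter, total)


# Usage with two toy batches of (preds, target).
train_dice = DiceAccumulator()
batches = [([1, 1, 0, 0], [1, 0, 0, 0]), ([0, 1, 1, 0], [0, 1, 1, 1])]
for i, (preds, target) in enumerate(batches):
    print(f"Dice of batch {i}: {train_dice(preds, target):.3f}")
print(f"Dice over the epoch: {train_dice.compute():.3f}")
train_dice.reset()
```

Because compute() pools the intersection and size counts across batches, the epoch score in the usage example is 2 * 3 / 8 = 0.75, not the mean of the per-batch scores (about 0.733). Pooling is one reasonable choice; the sketch does not claim that any torchmetrics version averages Dice this way.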
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | jakub |
