Using the Dice metric in PyTorch "torchmetrics": dice_score() missing 2 required positional arguments: 'preds' and 'target'

I'm trying to use the Dice metric from PyTorch "torchmetrics". I found an example of using the accuracy metric, like the one below:

from torchmetrics.classification import Accuracy

# model, epochs, train_data and valid_data are assumed to be defined elsewhere.
# Newer torchmetrics releases also require a task argument here,
# e.g. Accuracy(task="multiclass", num_classes=...).
train_accuracy = Accuracy()
valid_accuracy = Accuracy()

for epoch in range(epochs):
    for i, (x, y) in enumerate(train_data):
        y_hat = model(x)

        # training step accuracy: calling the metric updates its state
        # and returns the accuracy of this batch only
        batch_acc = train_accuracy(y_hat, y)
        print(f"Accuracy of batch {i} is {batch_acc}")

    for x, y in valid_data:
        y_hat = model(x)
        valid_accuracy.update(y_hat, y)

    # total accuracy over all training batches
    total_train_accuracy = train_accuracy.compute()

    # total accuracy over all validation batches
    total_valid_accuracy = valid_accuracy.compute()

    print(f"Training acc for epoch {epoch}: {total_train_accuracy}")
    print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}")

    # Reset metric states after each epoch
    train_accuracy.reset()
    valid_accuracy.reset()

However, when I replaced "Accuracy()" with "dice_score()", like below:

from torchmetrics.functional import dice_score

train_accuracy = dice_score()
valid_accuracy = dice_score()

I got the error below:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-43-726045592283> in <module>
      3 from torchmetrics.functional import dice_score
      4 
----> 5 train_accuracy_2 =dice_score()# Accuracy()
      6 valid_accuracy_2 =dice_score()# Accuracy()
      7 

TypeError: dice_score() missing 2 required positional arguments: 'preds' and 'target' 

Is there an example of using the Dice metric from "torchmetrics"?



Solution 1:[1]

torchmetrics.functional.dice_score is the functional interface to the Dice score. That means it is a stateless function that expects the predictions and the ground truth (preds and target) on every call, which is why calling it with no arguments raises the TypeError above. There doesn't seem to be a module (class) interface to the Dice score, like there is with accuracy.
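To make the stateless idea concrete without depending on a particular torchmetrics version, here is a minimal pure-Python sketch. binary_dice is a hypothetical helper, not the torchmetrics implementation: it takes flat lists of 0/1 labels and computes Dice = 2|P ∩ T| / (|P| + |T|). Like dice_score, it needs preds and target on every call, so calling it with no arguments fails with the same kind of TypeError as in the question.

```python
def binary_dice(preds, target):
    """Hypothetical stateless Dice for flat sequences of 0/1 labels.

    Dice = 2 * |P ∩ T| / (|P| + |T|). Not the torchmetrics implementation.
    """
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    intersection = sum(p * t for p, t in zip(preds, target))
    total = sum(preds) + sum(target)
    # Assumed convention: two empty masks count as a perfect match.
    return 1.0 if total == 0 else 2.0 * intersection / total


# A stateless function has to be given the data on every call...
print(binary_dice([1, 1, 0, 0], [1, 0, 0, 0]))  # 2 * 1 / (2 + 1)

# ...so calling it with no arguments fails just like in the question.
try:
    binary_dice()
except TypeError as err:
    print(err)
```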

torchmetrics.classification.Accuracy, by contrast, is a class that maintains state across calls (update(), compute(), reset()). Under the hood, it uses the functional interface, torchmetrics.functional.accuracy.
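A stateful metric can be sketched the same way: a class keeps running totals between calls and delegates the per-batch arithmetic to stateless helpers. Everything here (dice_stats, dice_from_stats, DiceAccumulator) is hypothetical and only mimics the update()/compute()/reset() pattern that the question uses with Accuracy; it is not how torchmetrics implements it, and it works on flat lists of 0/1 labels rather than tensors.

```python
def dice_stats(preds, target):
    """Hypothetical stateless helper: (|P ∩ T|, |P| + |T|) for 0/1 label lists."""
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    intersection = sum(p * t for p, t in zip(preds, target))
    return intersection, sum(preds) + sum(target)


def dice_from_stats(intersection, total):
    # Assumed convention: two empty masks count as a perfect match.
    return 1.0 if total == 0 else 2.0 * intersection / total


class DiceAccumulator:
    """Hypothetical stateful Dice metric built on the stateless helpers.

    Mimics the update()/compute()/reset() pattern of torchmetrics classes
    such as Accuracy, but it is not torchmetrics code.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        # Running sums over every batch seen since the last reset.
        self.intersection = 0
        self.total = 0

    def update(self, preds, target):
        inter, total = dice_stats(preds, target)
        self.intersection += inter
        self.total += total

    def compute(self):
        # Dice pooled over all accumulated batches.
        return dice_from_stats(self.intersection, self.total)

    def __call__(self, preds, target):
        # Accumulate this batch, then return the Dice of this batch alone.
        inter, total = dice_stats(preds, target)
        self.intersection += inter
        self.total += total
        return dice_from_stats(inter, total)


metric = DiceAccumulator()
batch_1 = metric([1, 1, 0, 0], [1, 0, 0, 0])  # per-batch score
batch_2 = metric([0, 1, 1, 0], [0, 1, 1, 0])  # per-batch score
epoch_score = metric.compute()                # pooled over both batches
metric.reset()                                # start fresh for the next epoch
```

The sketch accumulates intersection and total counts instead of averaging per-batch scores: for the two example batches the pooled score is 6/7, whereas the mean of the batch scores (2/3 and 1.0) would be 5/6.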

This is not enforced in any way, but by convention (PEP 8) classes are named in CamelCase and functions in snake_case, so Accuracy is a class and dice_score is a function.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 jakub