# Binary classification metrics.
#
# Source code for torchtraining.functional.metrics.classification.binary

import typing

import torch

from ... import utils
from . import utils as binary_utils


# Fraction of elements whose binarised prediction equals the binarised target.
#
# Args:
#     output: predictions; binarised against ``threshold`` by
#         ``binary_utils.binary.threshold`` (assumed to yield boolean
#         tensors -- confirm against that helper).
#     target: ground-truth labels, validated with ``output`` by
#         ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#     reduction: applied to the per-element float tensor of matches
#         (1.0 = match, 0.0 = mismatch); ``torch.mean`` by default.
#
# Returns:
#     The result of ``reduction`` (a scalar mean by default).
@utils.docs
def accuracy(
    output: torch.Tensor,
    target: torch.Tensor,
    threshold: float = 0.0,
    reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.mean,
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    matches = predicted.eq(expected).float()
    return reduction(matches)
# Jaccard index (intersection over union) computed along the last dimension.
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#     reduction: applied to the per-row IoU values; ``torch.mean`` by default.
#
# Returns:
#     ``reduction`` of the per-row intersection / union ratios. Rows whose
#     union is empty (nothing predicted and nothing present) score 1.0.
@utils.docs
def jaccard(
    output: torch.Tensor,
    target: torch.Tensor,
    threshold: float = 0.0,
    reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.mean,
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    intersection = (expected & predicted).sum(dim=-1)
    union = (predicted | expected).sum(dim=-1)
    # An empty union would give 0/0; treat it as a perfect 1/1 overlap.
    no_union = union <= 0
    union = torch.where(no_union, torch.ones_like(union), union)
    intersection = torch.where(no_union, torch.ones_like(intersection), intersection)
    return reduction(intersection.float() / union)
# Basic cases
# Count of elements predicted positive that are actually positive.
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#     reduction: applied to the per-element float hit mask; ``torch.sum``
#         by default.
#
# Returns:
#     ``reduction`` of the true-positive mask (a scalar count by default).
@utils.docs
def true_positive(
    output: torch.Tensor,
    target: torch.Tensor,
    threshold: float = 0.0,
    reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.sum,
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    hits = predicted & expected
    return reduction(hits.float())
# Count of elements predicted positive whose target is negative.
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#     reduction: applied to the per-element float mask; ``torch.sum`` by
#         default.
#
# Returns:
#     ``reduction`` of the false-positive mask (a scalar count by default).
@utils.docs
def false_positive(
    output: torch.Tensor,
    target: torch.Tensor,
    threshold: float = 0.0,
    reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.sum,
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    false_alarms = predicted & ~expected
    return reduction(false_alarms.float())
# Count of elements predicted negative whose target is also negative.
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#     reduction: applied to the per-element float mask; ``torch.sum`` by
#         default.
#
# Returns:
#     ``reduction`` of the true-negative mask (a scalar count by default).
@utils.docs
def true_negative(
    output: torch.Tensor,
    target: torch.Tensor,
    threshold: float = 0.0,
    reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.sum,
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    correct_rejections = ~predicted & ~expected
    return reduction(correct_rejections.float())
# Count of elements predicted negative whose target is positive.
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#     reduction: applied to the per-element float mask; ``torch.sum`` by
#         default.
#
# Returns:
#     ``reduction`` of the false-negative mask (a scalar count by default).
@utils.docs
def false_negative(
    output: torch.Tensor,
    target: torch.Tensor,
    threshold: float = 0.0,
    reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.sum,
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    misses = ~predicted & expected
    return reduction(misses.float())
# Confusion matrix
# 2x2 confusion matrix laid out as [[TP, FN], [FP, TN]]
# (rows: actual positive / negative, columns: predicted positive / negative).
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#     reduction: applied to each per-element float mask; ``torch.sum`` by
#         default. It may also return a non-scalar tensor (e.g. per-sample
#         counts). In that case the result has shape (2, 2, ...), with
#         singleton dimensions squeezed.
#
# Returns:
#     Tensor of shape (2, 2) for scalar reductions, otherwise
#     (2, 2, <reduced shape flattened>) after ``squeeze``.
@utils.docs
def confusion_matrix(
    output: torch.Tensor,
    target: torch.Tensor,
    threshold: float = 0.0,
    reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.sum,
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    output, target = binary_utils.binary.threshold(output, target, threshold)
    tp = reduction((output & target).float())
    fp = reduction((output & ~target).float())
    tn = reduction((~output & ~target).float())
    fn = reduction((~output & target).float())
    # Use torch.stack rather than torch.tensor. torch.tensor copies into a new
    # CPU tensor, drops autograd history and rejects multi-element reductions.
    # torch.as_tensor keeps compatibility with reductions returning Python numbers.
    counts = torch.stack([torch.as_tensor(value) for value in (tp, fn, fp, tn)])
    return counts.reshape(2, 2, -1).squeeze()
# Rate metrics
# Recall (sensitivity, true positive rate): TP / (TP + FN).
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#
# Returns:
#     Scalar float tensor. This is 0/0 = NaN when the target has no positives.
@utils.docs
def recall(
    output: torch.Tensor, target: torch.Tensor, threshold: float = 0.0
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    true_positives = (predicted & expected).sum().float()
    actual_positives = expected.sum()
    return true_positives / actual_positives
# Specificity (true negative rate): TN / (TN + FP).
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#
# Returns:
#     Scalar float tensor. This is 0/0 = NaN when the target has no negatives.
@utils.docs
def specificity(
    output: torch.Tensor, target: torch.Tensor, threshold: float = 0.0
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    actual_negatives = ~expected
    true_negatives = (~predicted & actual_negatives).sum().float()
    return true_negatives / actual_negatives.sum()
# Precision (positive predictive value): TP / (TP + FP).
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#
# Returns:
#     Scalar float tensor. This is 0/0 = NaN when nothing is predicted positive.
@utils.docs
def precision(
    output: torch.Tensor, target: torch.Tensor, threshold: float = 0.0
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    true_positives = (predicted & expected).sum().float()
    predicted_positives = predicted.sum()
    return true_positives / predicted_positives
# Negative predictive value: TN / (TN + FN).
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#
# Returns:
#     Scalar float tensor. This is 0/0 = NaN when nothing is predicted negative.
@utils.docs
def negative_predictive_value(
    output: torch.Tensor, target: torch.Tensor, threshold: float = 0.0
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    predicted_negatives = ~predicted
    true_negatives = (predicted_negatives & ~expected).sum().float()
    return true_negatives / predicted_negatives.sum()
# False negative rate (miss rate): FN / (FN + TP).
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#
# Returns:
#     Scalar float tensor. This is 0/0 = NaN when the target has no positives.
@utils.docs
def false_negative_rate(
    output: torch.Tensor, target: torch.Tensor, threshold: float = 0.0
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    misses = (~predicted & expected).sum().float()
    actual_positives = expected.sum()
    return misses / actual_positives
# False positive rate (fall-out): FP / (FP + TN).
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#
# Returns:
#     Scalar float tensor. This is 0/0 = NaN when the target has no negatives.
@utils.docs
def false_positive_rate(
    output: torch.Tensor, target: torch.Tensor, threshold: float = 0.0
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    actual_negatives = ~expected
    false_alarms = (predicted & actual_negatives).sum().float()
    return false_alarms / actual_negatives.sum()
# False discovery rate: FP / (FP + TP).
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#
# Returns:
#     Scalar float tensor. This is 0/0 = NaN when nothing is predicted positive.
@utils.docs
def false_discovery_rate(
    output: torch.Tensor, target: torch.Tensor, threshold: float = 0.0
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    false_alarms = (predicted & ~expected).sum().float()
    predicted_positives = predicted.sum()
    return false_alarms / predicted_positives
# False omission rate: FN / (FN + TN).
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#
# Returns:
#     Scalar float tensor. This is 0/0 = NaN when nothing is predicted negative.
@utils.docs
def false_omission_rate(
    output: torch.Tensor, target: torch.Tensor, threshold: float = 0.0
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    predicted_negatives = ~predicted
    misses = (predicted_negatives & expected).sum().float()
    return misses / predicted_negatives.sum()
# Critical success index (threat score): TP / (TP + FN + FP).
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#
# Returns:
#     Scalar float tensor in [0, 1]. This is 0/0 = NaN when there are no true
#     positives and no errors.
@utils.docs
def critical_success_index(
    output: torch.Tensor, target: torch.Tensor, threshold: float = 0.0
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    output, target = binary_utils.binary.threshold(output, target, threshold)
    tp = (output & target).sum().float()
    # Elements where prediction and target disagree are exactly FP + FN.
    errors = (output != target).sum()
    # Parenthesised denominator: the previous ``tp / tp + errors`` evaluated
    # to 1 + FP + FN because of operator precedence.
    return tp / (tp + errors)
# Balanced accuracy: mean of recall (TPR) and specificity (TNR).
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#
# Returns:
#     Scalar float tensor. It is NaN if the target lacks either class, because
#     one of the two rates is then 0/0.
@utils.docs
def balanced_accuracy(
    output: torch.Tensor, target: torch.Tensor, threshold: float = 0.0
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    actual_negatives = ~expected
    sensitivity = (predicted & expected).sum().float() / expected.sum()
    selectivity = (~predicted & actual_negatives).sum().float() / actual_negatives.sum()
    return (sensitivity + selectivity) / 2
# F1 score: 2TP / (2TP + FP + FN).
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#
# Returns:
#     Scalar float tensor. This is 0/0 = NaN when there are no true positives
#     and no errors.
@utils.docs
def f1(
    output: torch.Tensor, target: torch.Tensor, threshold: float = 0.0
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    predicted, expected = binary_utils.binary.threshold(output, target, threshold)
    doubled_hits = 2 * (predicted & expected).sum().float()
    # Disagreements between prediction and target count FP + FN together.
    mismatches = (predicted != expected).sum()
    return doubled_hits / (doubled_hits + mismatches)
# F-beta score: (1 + b^2) TP / ((1 + b^2) TP + b^2 FN + FP).
#
# Recall is weighted ``beta`` times as much as precision, and beta == 1
# reduces to F1.
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     beta: relative weight of recall versus precision.
#     threshold: cut-off forwarded to the thresholding helper.
#
# Returns:
#     Scalar float tensor. This is 0/0 = NaN when there are no true positives
#     and no errors.
@utils.docs
def f_beta(
    output: torch.Tensor, target: torch.Tensor, beta: float, threshold: float = 0.0,
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    output, target = binary_utils.binary.threshold(output, target, threshold)
    beta_squared = beta ** 2
    tp = (output & target).sum().float()
    fp = (output & ~target).sum().float()
    fn = (~output & target).sum().float()
    # The weight is 1 + beta^2, not (1 + beta)^2. Only false negatives carry
    # the beta^2 factor in the denominator; false positives are unweighted.
    weighted_tp = (1 + beta_squared) * tp
    return weighted_tp / (weighted_tp + beta_squared * fn + fp)
# Matthews correlation coefficient over all elements:
#     (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
#
# Args:
#     output: predictions, binarised by ``binary_utils.binary.threshold``.
#     target: ground-truth labels, validated by ``binary_utils.binary.check``.
#     threshold: cut-off forwarded to the thresholding helper.
#
# Returns:
#     Scalar float tensor in [-1, 1]. When any marginal count is zero, the
#     denominator is zero, so the numerator is also zero, and 0 is returned.
@utils.docs
def matthews_correlation_coefficient(
    output: torch.Tensor, target: torch.Tensor, threshold: float = 0.0
) -> torch.Tensor:
    binary_utils.binary.check(output, target)
    output, target = binary_utils.binary.threshold(output, target, threshold)
    inverse_output = ~output
    inverse_target = ~target
    # Global counts in floating point. This works for any input rank (the
    # old per-sample torch.dot crashed on 1-D input) and avoids int64
    # overflow in the product of the four marginals.
    tp = (output & target).sum().float()
    tn = (inverse_output & inverse_target).sum().float()
    fp = (output & inverse_target).sum().float()
    fn = (inverse_output & target).sum().float()
    numerator = tp * tn - fp * fn
    # Take square roots pairwise, so intermediate products stay within the
    # float32 range.
    denominator = torch.sqrt((tp + fp) * (tp + fn)) * torch.sqrt((tn + fp) * (tn + fn))
    if denominator == 0.0:
        # A zero marginal forces TP*TN - FP*FN == 0, so this returns 0.
        return numerator
    return numerator / denominator