Source code for torchtraining.metrics.classification.binary
import abc
import collections
import typing
import torch
from ... import _base, functional
from . import utils
###############################################################################
#
# COMMON BASE CLASSES
#
###############################################################################
class _Threshold(_base.Operation):
def __init__(self, threshold: float = 0.0):
super().__init__()
self.threshold = threshold
@abc.abstractmethod
def forward(self, data):
pass
class _ThresholdReductionMean(_base.Operation):
def __init__(self, threshold: float = 0.0, reduction=torch.mean):
super().__init__()
self.threshold = threshold
self.reduction = reduction
@abc.abstractmethod
def forward(self, data):
pass
class _ThresholdReductionSum(_base.Operation):
def __init__(self, threshold: float = 0.0, reduction=torch.sum):
super().__init__()
self.threshold = threshold
self.reduction = reduction
@abc.abstractmethod
def forward(self, data):
pass
###############################################################################
#
# CONCRETE IMPLEMENTATIONS
#
###############################################################################
###############################################################################
#
# MEAN DEFAULT REDUCTION
#
###############################################################################
[docs]@utils.binary.docs(
header="""Calculate accuracy score between `output` and `target`.""",
reduction="mean",
)
class Accuracy(_ThresholdReductionMean):
[docs] def forward(self, data):
return functional.metrics.classification.binary.accuracy(
*data, self.threshold, self.reduction,
)
[docs]@utils.binary.docs(
header="""Calculate jaccard score between `output` and `target`.""",
reduction="mean",
)
class Jaccard(_ThresholdReductionMean):
[docs] def forward(self, data):
return functional.metrics.classification.binary.jaccard(
*data, self.threshold, self.reduction,
)
###############################################################################
#
# SUM DEFAULT REDUCTION
#
###############################################################################
[docs]@utils.binary.docs(
header="""Number of true positives between `output` and `target`.""",
reduction="sum",
)
class TruePositive(_ThresholdReductionSum):
[docs] def forward(self, data):
return functional.metrics.classification.binary.true_positive(
*data, self.threshold, self.reduction
)
[docs]@utils.binary.docs(
header="""Number of false positives between `output` and `target`.""",
reduction="sum",
)
class FalsePositive(_ThresholdReductionSum):
[docs] def forward(self, data):
return functional.metrics.classification.binary.false_positive(
*data, self.threshold, self.reduction
)
[docs]@utils.binary.docs(
header="""Number of true negatives between `output` and `target`.""",
reduction="sum",
)
class TrueNegative(_ThresholdReductionSum):
[docs] def forward(self, data):
return functional.metrics.classification.binary.true_negative(
*data, self.threshold, self.reduction
)
[docs]@utils.binary.docs(
header="""Number of false negatives between `output` and `target`.""",
reduction="sum",
)
class FalseNegative(_ThresholdReductionSum):
[docs] def forward(self, data):
return functional.metrics.classification.binary.false_negative(
*data, self.threshold, self.reduction
)
[docs]@utils.binary.docs(
header="""Confusion matrix between `output` and `target`.""", reduction="sum",
)
class ConfusionMatrix(_ThresholdReductionSum):
[docs] def forward(self, data):
return functional.metrics.classification.binary.confusion_matrix(
*data, self.threshold, self.reduction
)
###############################################################################
#
# NO REDUCTION
#
###############################################################################
[docs]@utils.binary.docs(header="""Recall between `output` and `target`.""",)
class Recall(_Threshold):
[docs] def forward(self, data):
return functional.metrics.classification.binary.recall(*data, self.threshold)
[docs]@utils.binary.docs(header="""Specificity between `output` and `target`.""",)
class Specificity(_Threshold):
[docs] def forward(self, data):
return functional.metrics.classification.binary.specificity(
*data, self.threshold
)
[docs]@utils.binary.docs(header="""Precision between `output` and `target`.""",)
class Precision(_Threshold):
[docs] def forward(self, data):
return functional.metrics.classification.binary.precision(*data, self.threshold)
[docs]@utils.binary.docs(
header="""Negative predictive value between `output` and `target`.""",
)
class NegativePredictiveValue(_Threshold):
[docs] def forward(self, data):
return functional.metrics.classification.binary.negative_predictive_value(
*data, self.threshold
)
[docs]@utils.binary.docs(header="""False negative rate between `output` and `target`.""",)
class FalseNegativeRate(_Threshold):
[docs] def forward(self, data):
return functional.metrics.classification.binary.false_negative_rate(
*data, self.threshold
)
[docs]@utils.binary.docs(header="""False positive rate between `output` and `target`.""",)
class FalsePositiveRate(_Threshold):
[docs] def forward(self, data):
return functional.metrics.classification.binary.false_positive_rate(
*data, self.threshold
)
[docs]@utils.binary.docs(header="""False discovery rate between `output` and `target`.""",)
class FalseDiscoveryRate(_Threshold):
[docs] def forward(self, data):
return functional.metrics.classification.binary.false_discovery_rate(
*data, self.threshold
)
[docs]@utils.binary.docs(header="""False omission rate between `output` and `target`.""",)
class FalseOmissionRate(_Threshold):
[docs] def forward(self, data):
return functional.metrics.classification.binary.false_omission_rate(
*data, self.threshold
)
[docs]@utils.binary.docs(header="""Critical success index between `output` and `target`.""",)
class CriticalSuccessIndex(_Threshold):
[docs] def forward(self, data):
return functional.metrics.classification.binary.critical_success_index(
*data, self.threshold
)
[docs]@utils.binary.docs(header="""Critical success index between `output` and `target`.""",)
class BalancedAccuracy(_Threshold):
[docs] def forward(self, data):
return functional.metrics.classification.binary.balanced_accuracy(
*data, self.threshold
)
[docs]@utils.binary.docs(header="""F1 score between `output` and `target`.""",)
class F1(_Threshold):
[docs] def forward(self, data):
return functional.metrics.classification.binary.f1(*data, self.threshold)
[docs]@utils.binary.docs(
header="""Matthews correlation coefficient between `output` and `target`.""",
)
class MatthewsCorrelationCoefficient(_Threshold):
[docs] def forward(self, data):
return functional.metrics.classification.binary.matthews_correlation_coefficient(
*data, self.threshold
)
###############################################################################
#
# OTHER METRICS
#
###############################################################################
[docs]class FBeta(_base.Operation):
r"""Get f-beta score between `outputs` and `targets`.
Works for both logits and probabilities of `output`.
If `output` is tensor after `sigmoid` activation
user should change `threshold` to `0.5` for
correct results (default `0.0` corresponds to unnormalized probability a.k.a logits).
Parameters
----------
beta: float
Beta coefficient of `f-beta` score.
threshold : float, optional
Threshold above which prediction is considered to be positive.
Default: `0.0`
Arguments
---------
data: Tuple[torch.Tensor, torch.Tensor]
Tuple containing `outputs` from neural network and `targets` (ground truths).
`outputs` should be of shape :math:`(N, *)` and contain `logits` or `probabilities`.
`targets` should be of shape :math:`(N, *)` as well and contain `boolean` values
(or integers from set :math:`{0, 1}`).
Returns
-------
torch.Tensor
Scalar `tensor`
"""
def __init__(self, beta: float, threshold: float = 0.0):
super().__init__()
self.beta = beta
self.threshold = threshold
[docs] def forward(self, data):
return functional.metrics.classification.binary.f_beta(
*data, self.beta, self.threshold
)