gnn_tracking.metrics.binary_classification

This module collects figures of merit for binary classification.

Module Contents

Classes

BinaryClassificationStats

Calculator for binary classification metrics.

Functions

zero_divide(→ float)

Return a/b, or 0 if b is 0.

get_maximized_bcs(→ dict[str, float])

Calculate the best possible binary classification stats for a given output and y.

roc_auc_score(→ float)

Wrapper that ignores exceptions, e.g., those raised when only one label is present.

get_roc_auc_scores(true, predicted, max_fprs)

Calculate ROC AUC scores for a given set of maximum FPRs.

class gnn_tracking.metrics.binary_classification.BinaryClassificationStats(output: torch.Tensor, y: torch.Tensor, thld: torch.Tensor | float)

Calculator for binary classification metrics. All properties are cached, so they are only calculated once.

Parameters:
  • output – Output weights (classifier scores)

  • y – True labels

  • thld – Classification threshold applied to output to decide which entries are predicted true

Returns:

accuracy, TPR, TNR
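To illustrate the "all properties are cached" design, here is a minimal sketch using `functools.cached_property` from the standard library. It is not the library's implementation: the class name `ConfusionCounts` is hypothetical, plain Python lists stand in for `torch.Tensor`, the counts are integers (the library annotates `TP` etc. as `float`), and treating `output > thld` (strictly greater) as "predicted true" is an assumption.

```python
from functools import cached_property
from typing import Sequence


class ConfusionCounts:
    """Hypothetical stdlib stand-in for the counting part of BinaryClassificationStats."""

    def __init__(self, output: Sequence[float], y: Sequence[int], thld: float) -> None:
        self.output = list(output)
        self.y = [bool(label) for label in y]
        self.thld = thld

    @cached_property
    def _predicted_true(self) -> list[bool]:
        # Assumption: strictly above the threshold means "predicted true".
        # Computed on first access, then stored in the instance __dict__.
        return [score > self.thld for score in self.output]

    @cached_property
    def TP(self) -> int:
        return sum(p and t for p, t in zip(self._predicted_true, self.y))

    @cached_property
    def FP(self) -> int:
        return sum(p and not t for p, t in zip(self._predicted_true, self.y))

    @cached_property
    def TN(self) -> int:
        return sum(not p and not t for p, t in zip(self._predicted_true, self.y))

    @cached_property
    def FN(self) -> int:
        return sum(not p and t for p, t in zip(self._predicted_true, self.y))
```

For example, `ConfusionCounts([0.2, 0.7, 0.9, 0.4], [0, 1, 0, 1], thld=0.5)` yields one of each: `TP == FP == TN == FN == 1`. Because every property is cached, the thresholded predictions are computed once and reused by all four counts.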

_true() → torch.Tensor
n_true() → int
_false() → torch.Tensor
n_false() → int
_predicted_false() → torch.Tensor
n_predicted_false() → int
_predicted_true() → torch.Tensor
n_predicted_true() → int
TP() → float
TN() → float
FP() → float
FN() → float
acc() → float
TPR() → float
TNR() → float
FPR() → float
FNR() → float
balanced_acc() → float
F1() → float
MCC() → float
get_all() → dict[str, float]
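The rate-type properties follow from the four confusion-matrix entries by the standard textbook formulas. The sketch below is a hypothetical helper (`derived_metrics`, `_safe_div`), not the library's code. Its dictionary keys mirror the property names above, but whether `get_all()` uses exactly these keys is an assumption. All formulas are scale-invariant, so they give the same result whether the entries are raw counts or fractions of all samples.

```python
import math


def _safe_div(a: float, b: float) -> float:
    # Return 0 instead of raising when the denominator is 0 (e.g. no true samples).
    return a / b if b != 0 else 0.0


def derived_metrics(tp: float, tn: float, fp: float, fn: float) -> dict[str, float]:
    """Standard binary classification metrics from confusion-matrix entries (hypothetical helper)."""
    tpr = _safe_div(tp, tp + fn)  # sensitivity / recall
    tnr = _safe_div(tn, tn + fp)  # specificity
    mcc_denominator = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return {
        "acc": _safe_div(tp + tn, tp + tn + fp + fn),
        "TPR": tpr,
        "TNR": tnr,
        "FPR": _safe_div(fp, fp + tn),
        "FNR": _safe_div(fn, fn + tp),
        "balanced_acc": (tpr + tnr) / 2,
        "F1": _safe_div(2 * tp, 2 * tp + fp + fn),
        "MCC": _safe_div(tp * tn - fp * fn, mcc_denominator),
    }
```

For example, `derived_metrics(tp=3, tn=5, fp=1, fn=1)["MCC"]` is (3·5 − 1·1)/√(4·4·6·6) = 14/24 ≈ 0.583.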

gnn_tracking.metrics.binary_classification.zero_divide(a: float, b: float) → float

Return a/b, or 0 if b is 0.
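A minimal sketch of the documented behavior for plain Python numbers (the library's own body may differ, e.g. in how it handles tensor inputs):

```python
def zero_divide(a: float, b: float) -> float:
    """Return a / b, or 0 when b == 0 (sketch of the documented behavior)."""
    if b == 0:
        # E.g. a rate whose denominator class is empty: report 0 instead of raising.
        return 0.0
    return a / b
```

This keeps metrics such as TPR defined when a batch happens to contain no true samples.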

gnn_tracking.metrics.binary_classification.get_maximized_bcs(*, output: torch.Tensor, y: torch.Tensor, n_samples=200) → dict[str, float]

Calculate the best possible binary classification stats for a given output and y.

Parameters:
  • output – Output weights (classifier scores)

  • y – True labels

  • n_samples – Number of thresholds to sample

Returns:

Dictionary of metrics
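The idea of sampling thresholds and keeping the best value can be sketched as follows. Several details are assumptions rather than documented behavior: thresholds are spaced evenly in [0, 1] (so outputs are assumed to lie in that range), "predicted true" means `output > thld`, only MCC and balanced accuracy are maximized, and the function name `maximized_bcs_sketch` and its result keys are hypothetical. Plain lists stand in for tensors.

```python
import math
from typing import Sequence


def _mcc_and_balanced_acc(output: Sequence[float], y: Sequence[int], thld: float) -> tuple[float, float]:
    tp = fp = tn = fn = 0
    for score, label in zip(output, y):
        predicted = score > thld  # assumption: strict comparison
        if predicted and label:
            tp += 1
        elif predicted:
            fp += 1
        elif label:
            fn += 1
        else:
            tn += 1
    tpr = tp / (tp + fn) if tp + fn else 0.0
    tnr = tn / (tn + fp) if tn + fp else 0.0
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    return mcc, (tpr + tnr) / 2


def maximized_bcs_sketch(*, output: Sequence[float], y: Sequence[int], n_samples: int = 200) -> dict[str, float]:
    """Scan n_samples evenly spaced thresholds and keep the best metric values (hypothetical)."""
    if n_samples < 2:
        raise ValueError("need at least two thresholds")
    best = {"max_MCC": -math.inf, "max_MCC_thld": math.nan, "max_balanced_acc": -math.inf}
    for i in range(n_samples):
        thld = i / (n_samples - 1)
        mcc, balanced_acc = _mcc_and_balanced_acc(output, y, thld)
        if mcc > best["max_MCC"]:  # keep the first threshold reaching the maximum
            best["max_MCC"] = mcc
            best["max_MCC_thld"] = thld
        best["max_balanced_acc"] = max(best["max_balanced_acc"], balanced_acc)
    return best
```

Maximizing TPR or TNR on their own would be uninformative (a threshold of 0 or 1 trivially maximizes one of them), which is why the sketch only maximizes metrics that balance both classes.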

gnn_tracking.metrics.binary_classification.roc_auc_score(*, y_true: torch.Tensor, y_score: torch.Tensor, max_fpr: float | None = None, device=None) → float

Wrapper that ignores exceptions, e.g., those raised when only one label is present.
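The exception-swallowing pattern can be sketched with a self-contained ROC AUC based on the rank statistic (the probability that a random positive scores higher than a random negative, ties counting one half). This is not the library's implementation: the names `_rank_auc` and `safe_roc_auc` are hypothetical, `max_fpr` is not supported here, only `ValueError` is caught, and returning `nan` on failure is an assumption about what "ignores exceptions" yields.

```python
import math
from typing import Sequence


def _rank_auc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """ROC AUC as P(score of a random positive > score of a random negative)."""
    pos = [s for s, t in zip(y_score, y_true) if t]
    neg = [s for s, t in zip(y_score, y_true) if not t]
    if not pos or not neg:
        raise ValueError("ROC AUC is undefined when only one class is present")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count half
    return wins / (len(pos) * len(neg))


def safe_roc_auc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """Hypothetical wrapper: return nan instead of raising for degenerate inputs."""
    try:
        return _rank_auc(y_true, y_score)
    except ValueError:
        return math.nan
```

With this, `safe_roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])` returns 0.75, while a batch containing only positives returns `nan` instead of aborting, e.g., a validation loop.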

gnn_tracking.metrics.binary_classification.get_roc_auc_scores(true, predicted, max_fprs: Iterable[float | None])

Calculate ROC AUC scores for a given set of maximum FPRs.