gnn_tracking.metrics.binary_classification

This module collects figures of merit for binary classification.

Classes

BinaryClassificationStats

Calculator for binary classification metrics.

Functions

zero_divide(→ float)

Return a/b, or 0 if b is 0.

get_maximized_bcs(→ dict[str, float])

Calculate the best possible binary classification stats for a given output and y.

roc_auc_score(→ float)

Wrapper that ignores exceptions (e.g., when only one label is present).

get_roc_auc_scores(true, predicted, max_fprs)

Calculate ROC AUC scores for a given set of maximum FPRs.

Module Contents

class gnn_tracking.metrics.binary_classification.BinaryClassificationStats(output: torch.Tensor, y: torch.Tensor, thld: torch.Tensor | float)

Calculator for binary classification metrics. All properties are cached, so they are only calculated once.

Parameters:
  • output – Output weights

  • y – True labels

  • thld – Threshold for classifying an output as true

Returns:

accuracy, TPR, TNR
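The confusion-matrix bookkeeping behind the properties listed below can be sketched as follows. This is a minimal sketch under stated assumptions, not the library's implementation: it assumes an output counts as predicted true when it is at or above `thld`, uses plain Python lists in place of `torch.Tensor`, and uses `functools.cached_property` to mimic the "calculated only once" caching described above. The class name `BCSSketch` is hypothetical.

```python
from __future__ import annotations

from functools import cached_property


class BCSSketch:
    """Hypothetical stand-in for BinaryClassificationStats (lists instead of tensors)."""

    def __init__(self, output: list[float], y: list[int], thld: float) -> None:
        self._output = output
        self._y = y
        self._thld = thld

    @cached_property
    def _predicted_true(self) -> list[bool]:
        # Assumption: output >= thld means "predicted true".
        return [o >= self._thld for o in self._output]

    @cached_property
    def _true(self) -> list[bool]:
        return [label == 1 for label in self._y]

    # Confusion-matrix counts, each computed once and then cached on the instance.
    @cached_property
    def TP(self) -> float:
        return float(sum(p and t for p, t in zip(self._predicted_true, self._true)))

    @cached_property
    def TN(self) -> float:
        return float(sum(not p and not t for p, t in zip(self._predicted_true, self._true)))

    @cached_property
    def FP(self) -> float:
        return float(sum(p and not t for p, t in zip(self._predicted_true, self._true)))

    @cached_property
    def FN(self) -> float:
        return float(sum(not p and t for p, t in zip(self._predicted_true, self._true)))

    # Derived rates; a zero denominator yields 0.0 instead of raising.
    @cached_property
    def acc(self) -> float:
        n = len(self._y)
        return (self.TP + self.TN) / n if n else 0.0

    @cached_property
    def TPR(self) -> float:
        denom = self.TP + self.FN
        return self.TP / denom if denom else 0.0

    @cached_property
    def TNR(self) -> float:
        denom = self.TN + self.FP
        return self.TN / denom if denom else 0.0
```

For example, `BCSSketch([0.9, 0.8, 0.2, 0.7, 0.1], [1, 1, 0, 0, 1], thld=0.5)` yields TP = 2, TN = 1, FP = 1, FN = 1, so accuracy is 0.6 and TPR is 2/3. Because each metric is a cached property, the counts are computed on first access and then reused by `acc`, `TPR`, and `TNR`.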

_output
_y
_thld
property _true: torch.Tensor
property n_true: int
property _false: torch.Tensor
property n_false: int
property _predicted_false: torch.Tensor
property n_predicted_false: int
property _predicted_true: torch.Tensor
property n_predicted_true: int
property TP: float
property TN: float
property FP: float
property FN: float
property acc: float
property TPR: float
property TNR: float
property FPR: float
property FNR: float
property balanced_acc: float
property F1: float
property MCC: float
get_all() dict[str, float]
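For reference, the standard definitions of the derived metrics in terms of the confusion-matrix entries TP, TN, FP and FN are given below. These are the conventional textbook formulas; that the properties above follow exactly these conventions (including how a zero denominator is handled, presumably via `zero_divide`) is an assumption, not something this page states.

```latex
\begin{aligned}
\mathrm{acc} &= \frac{TP + TN}{TP + TN + FP + FN} \\
\mathrm{TPR} &= \frac{TP}{TP + FN}, \qquad \mathrm{FNR} = \frac{FN}{TP + FN} = 1 - \mathrm{TPR} \\
\mathrm{TNR} &= \frac{TN}{TN + FP}, \qquad \mathrm{FPR} = \frac{FP}{TN + FP} = 1 - \mathrm{TNR} \\
\mathrm{balanced\_acc} &= \tfrac{1}{2}\left(\mathrm{TPR} + \mathrm{TNR}\right) \\
F_1 &= \frac{2\,TP}{2\,TP + FP + FN} \\
\mathrm{MCC} &= \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}
\end{aligned}
```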
gnn_tracking.metrics.binary_classification.zero_divide(a: float, b: float) float

Return a/b, or 0 if b is 0.
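A minimal sketch of the documented behavior; the library's own source may differ in details (for example, whether it returns `0` or `0.0`).

```python
def zero_divide(a: float, b: float) -> float:
    """Return a / b, or 0.0 when the denominator is zero (sketch of the documented behavior)."""
    if b == 0:
        return 0.0
    return a / b
```

This guard is convenient for rates such as TP / (TP + FN), whose denominator is zero when a batch contains no positive labels.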

gnn_tracking.metrics.binary_classification.get_maximized_bcs(*, output: torch.Tensor, y: torch.Tensor, n_samples=200) dict[str, float]

Calculate the best possible binary classification stats for a given output and y.

Parameters:
  • output – Output weights

  • y – True labels

  • n_samples – Number of thresholds to sample

Returns:

Dictionary of metrics

gnn_tracking.metrics.binary_classification.roc_auc_score(*, y_true: torch.Tensor, y_score: torch.Tensor, max_fpr: float | None = None, device=None) float

Wrapper that ignores exceptions, such as those raised when only one label is present.
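The wrapper pattern can be sketched as below. Several parts are assumptions rather than documented behavior: the fallback value on failure (NaN here), catching only `ValueError`, the pairwise rank-based AUC used in place of whatever backend the library actually calls, and the omission of `max_fpr` and `device`. `_roc_auc` and `safe_roc_auc` are hypothetical names.

```python
import math


def _roc_auc(y_true, y_score):
    """Rank-based ROC AUC: chance that a random positive outscores a random negative (ties count 1/2)."""
    pos = [s for s, t in zip(y_score, y_true) if t == 1]
    neg = [s for s, t in zip(y_score, y_true) if t != 1]
    if not pos or not neg:
        # ROC AUC is undefined when only one label is present.
        raise ValueError("Only one class present in y_true.")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def safe_roc_auc(y_true, y_score) -> float:
    """Return the ROC AUC, or NaN (assumed fallback) if it cannot be computed."""
    try:
        return _roc_auc(y_true, y_score)
    except ValueError:
        return math.nan
```

With `y_true = [0, 0, 1, 1]` and `y_score = [0.1, 0.4, 0.35, 0.8]`, three of the four positive/negative pairs are ordered correctly, giving an AUC of 0.75. With a single-label input the sketch returns NaN instead of raising. The pairwise loop is O(n_pos · n_neg) and is meant only to illustrate the quantity, not to be efficient.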

gnn_tracking.metrics.binary_classification.get_roc_auc_scores(true, predicted, max_fprs: Iterable[float | None])

Calculate ROC AUC scores for a given set of maximum FPRs.
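The looping structure can be sketched as follows. In the library this function presumably calls `roc_auc_score` once per entry of `max_fprs`; to keep the sketch self-contained, the scorer is passed in as a hypothetical `score_fn` parameter, and the `roc_auc` / `roc_auc_<max_fpr>FPR` key scheme is likewise an illustrative assumption.

```python
from typing import Callable, Iterable, Optional


def get_roc_auc_scores_sketch(
    true,
    predicted,
    max_fprs: Iterable[Optional[float]],
    score_fn: Callable[..., float],
) -> dict:
    """Call score_fn once per max_fpr and collect the results in a dict."""
    scores = {}
    for max_fpr in max_fprs:
        # Hypothetical keys: "roc_auc" for the full curve, "roc_auc_<max_fpr>FPR" otherwise.
        key = "roc_auc" if max_fpr is None else f"roc_auc_{max_fpr}FPR"
        scores[key] = score_fn(y_true=true, y_score=predicted, max_fpr=max_fpr)
    return scores


# Illustration only: a stand-in scorer that echoes max_fpr instead of computing an AUC.
def echo_scorer(*, y_true, y_score, max_fpr):
    return 1.0 if max_fpr is None else max_fpr


print(get_roc_auc_scores_sketch([0, 1], [0.2, 0.8], [None, 0.01], echo_scorer))
# prints {'roc_auc': 1.0, 'roc_auc_0.01FPR': 0.01}
```

Passing `None` as one of the `max_fprs` requests the AUC over the full ROC curve, matching the `max_fpr: float | None = None` default in the signature of `roc_auc_score`.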