gnn_tracking.metrics.binary_classification
This module collects figures of merit for binary classification.
Classes
BinaryClassificationStats | Calculator for binary classification metrics.
Functions
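zero_divide(a, b) | Normal division a/b, but return 0 if b is 0.
get_maximized_bcs(*, output, y[, n_samples]) | Calculate the best possible binary classification stats for a given output and y.
roc_auc_score(*, y_true, y_score[, max_fpr, device]) | Wrapper that ignores exceptions that can be raised, e.g., if only one label is present.
get_roc_auc_scores(true, predicted, max_fprs) | Calculate ROC AUC scores for a given set of maximum FPRs.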
Module Contents
- class gnn_tracking.metrics.binary_classification.BinaryClassificationStats(output: torch.Tensor, y: torch.Tensor, thld: torch.Tensor | float)
Calculator for binary classification metrics. All properties are cached, so they are only calculated once.
- Parameters:
output – Output weights
y – True labels
thld – Threshold for classifying an output as true (a positive prediction)
- Returns:
accuracy, TPR, TNR
- _output
- _y
- _thld
- property _true: torch.Tensor
- property n_true: int
- property _false: torch.Tensor
- property n_false: int
- property _predicted_false: torch.Tensor
- property n_predicted_false: int
- property _predicted_true: torch.Tensor
- property n_predicted_true: int
- property TP: float
- property TN: float
- property FP: float
- property FN: float
- property acc: float
- property TPR: float
- property TNR: float
- property FPR: float
- property FNR: float
- property balanced_acc: float
- property F1: float
- property MCC: float
- get_all() -> dict[str, float]
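The property list above names the metrics but does not spell out their formulas. The following is a minimal, pure-Python sketch of what these properties conventionally compute, written with plain lists instead of torch tensors so that it runs on its own. The class name BinaryClassificationStatsSketch, the helper _zero_divide and the get_all key names are hypothetical; treating an output as predicted true when it is strictly greater than thld is an assumption; and functools.cached_property stands in for the caching described above. The library's actual implementation may differ.

```python
import math
from functools import cached_property


def _zero_divide(a: float, b: float) -> float:
    # a / b, but 0.0 when the denominator is zero
    return a / b if b != 0 else 0.0


class BinaryClassificationStatsSketch:
    """Hypothetical pure-Python analogue of BinaryClassificationStats.

    Assumption: an output counts as predicted true when output > thld.
    """

    def __init__(self, output: list[float], y: list[int], thld: float):
        self._output = output
        self._y = y
        self._thld = thld

    @cached_property
    def _pairs(self) -> list[tuple[bool, bool]]:
        # (predicted_true, actually_true) per sample; computed once, then cached
        return [(o > self._thld, bool(t)) for o, t in zip(self._output, self._y)]

    def _count(self, predicted: bool, actual: bool) -> float:
        return float(sum(1 for p, t in self._pairs if p == predicted and t == actual))

    @cached_property
    def TP(self) -> float:
        return self._count(True, True)

    @cached_property
    def TN(self) -> float:
        return self._count(False, False)

    @cached_property
    def FP(self) -> float:
        return self._count(True, False)

    @cached_property
    def FN(self) -> float:
        return self._count(False, True)

    @cached_property
    def acc(self) -> float:
        return _zero_divide(self.TP + self.TN, len(self._pairs))

    @cached_property
    def TPR(self) -> float:
        return _zero_divide(self.TP, self.TP + self.FN)

    @cached_property
    def TNR(self) -> float:
        return _zero_divide(self.TN, self.TN + self.FP)

    @cached_property
    def FPR(self) -> float:
        return _zero_divide(self.FP, self.FP + self.TN)

    @cached_property
    def FNR(self) -> float:
        return _zero_divide(self.FN, self.FN + self.TP)

    @cached_property
    def balanced_acc(self) -> float:
        return (self.TPR + self.TNR) / 2

    @cached_property
    def F1(self) -> float:
        return _zero_divide(2 * self.TP, 2 * self.TP + self.FP + self.FN)

    @cached_property
    def MCC(self) -> float:
        # Matthews correlation coefficient; 0.0 if any marginal count is zero
        denom = math.sqrt(
            (self.TP + self.FP) * (self.TP + self.FN) * (self.TN + self.FP) * (self.TN + self.FN)
        )
        return _zero_divide(self.TP * self.TN - self.FP * self.FN, denom)

    def get_all(self) -> dict[str, float]:
        # Hypothetical key names that simply mirror the property names
        names = ["acc", "TPR", "TNR", "FPR", "FNR", "balanced_acc", "F1", "MCC"]
        return {name: getattr(self, name) for name in names}


stats = BinaryClassificationStatsSketch(
    output=[0.9, 0.8, 0.3, 0.6, 0.1, 0.2], y=[1, 1, 1, 0, 0, 0], thld=0.5
)
print(stats.get_all())
```

With this sample data, two of the three true samples and two of the three false samples are classified correctly (TP = TN = 2, FP = FN = 1), so acc, TPR, TNR, balanced_acc and F1 are all 2/3, FPR and FNR are 1/3, and MCC is 3/9 = 1/3.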
- gnn_tracking.metrics.binary_classification.zero_divide(a: float, b: float) -> float
Normal division a/b, but return 0 if b is 0.
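A minimal sketch of the documented behavior, under the hypothetical name zero_divide_sketch; returning 0.0 as a float is an assumption. A guard like this is what lets rate metrics such as TPR stay finite when, for example, a batch happens to contain no true samples.

```python
def zero_divide_sketch(a: float, b: float) -> float:
    # Ordinary division, except that a zero denominator yields 0.0 instead of raising
    if b == 0:
        return 0.0
    return a / b


print(zero_divide_sketch(3.0, 4.0))  # 0.75
print(zero_divide_sketch(1.0, 0.0))  # 0.0
```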
- gnn_tracking.metrics.binary_classification.get_maximized_bcs(*, output: torch.Tensor, y: torch.Tensor, n_samples=200) -> dict[str, float]
Calculate the best possible binary classification stats for a given output and y.
- Parameters:
output – Output weights
y – True labels
n_samples – Number of thresholds to sample
- Returns:
Dictionary of metrics
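The docstring does not define "best possible". One plausible reading, sketched below in pure Python, is a threshold sweep: evaluate the stats at n_samples evenly spaced thresholds in [0, 1] (which assumes the outputs are scores in that range) and keep, for each metric, the maximum value and the threshold that first produced it. The function names, the "max_<metric>" and "max_<metric>_thld" keys, and the reduced set of metrics are all hypothetical.

```python
import math


def _stats(output: list[float], y: list[int], thld: float) -> dict[str, float]:
    # Confusion counts, treating output > thld as a positive prediction (assumption)
    tp = sum(1 for o, t in zip(output, y) if o > thld and t)
    fp = sum(1 for o, t in zip(output, y) if o > thld and not t)
    fn = sum(1 for o, t in zip(output, y) if o <= thld and t)
    tn = len(y) - tp - fp - fn
    tpr = tp / (tp + fn) if tp + fn else 0.0
    tnr = tn / (tn + fp) if tn + fp else 0.0
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return {
        "acc": (tp + tn) / len(y) if y else 0.0,
        "balanced_acc": (tpr + tnr) / 2,
        "F1": 2 * tp / (2 * tp + fp + fn) if tp else 0.0,
        "MCC": (tp * tn - fp * fn) / denom if denom else 0.0,
    }


def get_maximized_bcs_sketch(
    *, output: list[float], y: list[int], n_samples: int = 200
) -> dict[str, float]:
    # Hypothetical sweep over n_samples evenly spaced thresholds in [0, 1]
    thresholds = [i / max(n_samples - 1, 1) for i in range(n_samples)]
    best: dict[str, float] = {}
    for thld in thresholds:
        for name, value in _stats(output, y, thld).items():
            # Strict ">" keeps the first threshold that reaches the maximum
            if value > best.get(f"max_{name}", -math.inf):
                best[f"max_{name}"] = value
                best[f"max_{name}_thld"] = thld
    return best


print(get_maximized_bcs_sketch(output=[0.9, 0.8, 0.3, 0.6, 0.1, 0.2], y=[1, 1, 1, 0, 0, 0]))
```

Because each metric is maximized independently in this sketch, the thresholds that maximize, say, accuracy and MCC need not coincide.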
- gnn_tracking.metrics.binary_classification.roc_auc_score(*, y_true: torch.Tensor, y_score: torch.Tensor, max_fpr: float | None = None, device=None) -> float
Wrapper that ignores exceptions that can be raised, e.g., if only one label is present.
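As a self-contained illustration of this wrapper pattern, the sketch below computes ROC AUC as the probability that a randomly chosen positive scores higher than a randomly chosen negative (ties count one half), raises ValueError when only one class is present (as scikit-learn's roc_auc_score does), and has the wrapper return NaN in that case. The names are hypothetical, returning NaN is an assumption about what "ignores exceptions" yields, and the max_fpr and device arguments are omitted.

```python
import math


def _auc(y_true: list[int], y_score: list[float]) -> float:
    # ROC AUC as the probability that a random positive outranks a random negative
    pos = [s for s, t in zip(y_score, y_true) if t]
    neg = [s for s, t in zip(y_score, y_true) if not t]
    if not pos or not neg:
        # Like scikit-learn, refuse to compute AUC when only one class is present
        raise ValueError("Only one class present in y_true.")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def roc_auc_score_sketch(*, y_true: list[int], y_score: list[float]) -> float:
    # Hypothetical wrapper: swallow the ValueError and return NaN instead
    try:
        return _auc(y_true, y_score)
    except ValueError:
        return math.nan


print(roc_auc_score_sketch(y_true=[1, 1, 1, 0, 0, 0], y_score=[0.9, 0.8, 0.3, 0.6, 0.1, 0.2]))  # 8/9
print(roc_auc_score_sketch(y_true=[1, 1], y_score=[0.2, 0.4]))  # nan
```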
- gnn_tracking.metrics.binary_classification.get_roc_auc_scores(true, predicted, max_fprs: Iterable[float | None])
Calculate ROC AUC scores for a given set of maximum FPRs.
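One way to assemble such a collection, sketched under assumptions: the scoring function is passed in as score_fn (any callable that accepts y_true, y_score and max_fpr as keyword arguments, for example sklearn.metrics.roc_auc_score), None stands for the unrestricted ROC AUC, and the result keys follow a hypothetical naming scheme. The demo scorer fake_score is a stand-in that only echoes its max_fpr argument.

```python
from typing import Callable, Iterable, Optional


def get_roc_auc_scores_sketch(
    true,
    predicted,
    max_fprs: Iterable[Optional[float]],
    score_fn: Callable[..., float],
) -> dict[str, float]:
    # One score per requested max_fpr; None means the full (unrestricted) ROC AUC
    scores: dict[str, float] = {}
    for max_fpr in max_fprs:
        # Hypothetical key naming scheme
        key = "roc_auc" if max_fpr is None else f"roc_auc_{max_fpr}FPR"
        scores[key] = score_fn(y_true=true, y_score=predicted, max_fpr=max_fpr)
    return scores


def fake_score(*, y_true, y_score, max_fpr):
    # Demo-only scorer: returns 1.0 for the full AUC, otherwise echoes max_fpr
    return 1.0 if max_fpr is None else max_fpr


print(get_roc_auc_scores_sketch([0, 1], [0.1, 0.9], [None, 0.01], fake_score))
# {'roc_auc': 1.0, 'roc_auc_0.01FPR': 0.01}
```

Since roc_auc_score above takes y_true, y_score and max_fpr as keyword arguments, it fits the score_fn slot, and its exception tolerance would keep one degenerate batch from aborting the whole collection.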