gnn_tracking.metrics.binary_classification
This module collects figures of merit for binary classification.
Module Contents
Classes
BinaryClassificationStats | Calculator for binary classification metrics.
Functions
zero_divide | Normal division a/b, but return 0 for x/0.
get_maximized_bcs | Calculate the best possible binary classification stats for a given output and y.
roc_auc_score | Wrapper that ignores exceptions, which can be raised, e.g., if only one label is present.
get_roc_auc_scores | Calculate ROC AUC scores for a given set of maximum FPRs.
- class gnn_tracking.metrics.binary_classification.BinaryClassificationStats(output: torch.Tensor, y: torch.Tensor, thld: torch.Tensor | float)
Calculator for binary classification metrics. All properties are cached, so they are only calculated once.
- Parameters:
output – Output weights
y – True labels
thld – Threshold for classifying an output as true
- Returns:
accuracy, TPR, TNR
- _true() torch.Tensor
- n_true() int
- _false() torch.Tensor
- n_false() int
- _predicted_false() torch.Tensor
- n_predicted_false() int
- _predicted_true() torch.Tensor
- n_predicted_true() int
- TP() float
- TN() float
- FP() float
- FN() float
- acc() float
- TPR() float
- TNR() float
- FPR() float
- FNR() float
- balanced_acc() float
- F1() float
- MCC() float
- get_all() dict[str, float]
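The metrics above follow the standard confusion-matrix definitions. Below is a minimal sketch of how they could be computed, assuming plain Python lists instead of `torch.Tensor` and assuming a sample counts as predicted true when its output is strictly greater than `thld` (the library's actual comparison may differ). `BinaryStatsSketch` and `_safe_div` are hypothetical names; `functools.cached_property` mirrors the caching described in the class docstring.

```python
import math
from functools import cached_property


def _safe_div(a: float, b: float) -> float:
    # Same convention as zero_divide: x / 0 evaluates to 0
    return a / b if b else 0.0


class BinaryStatsSketch:
    """Hypothetical list-based stand-in for BinaryClassificationStats."""

    def __init__(self, output: list[float], y: list[int], thld: float):
        # Assumption: a sample is predicted true when its output exceeds thld
        self._pred = [o > thld for o in output]
        self._true = [bool(t) for t in y]

    def _count(self, pred: bool, true: bool) -> float:
        # Number of samples with the given (predicted, true) label pair
        pairs = zip(self._pred, self._true)
        return float(sum(p == pred and t == true for p, t in pairs))

    @cached_property
    def TP(self) -> float:
        return self._count(True, True)

    @cached_property
    def TN(self) -> float:
        return self._count(False, False)

    @cached_property
    def FP(self) -> float:
        return self._count(True, False)

    @cached_property
    def FN(self) -> float:
        return self._count(False, True)

    @cached_property
    def acc(self) -> float:
        return _safe_div(self.TP + self.TN, len(self._true))

    @cached_property
    def TPR(self) -> float:
        return _safe_div(self.TP, self.TP + self.FN)

    @cached_property
    def TNR(self) -> float:
        return _safe_div(self.TN, self.TN + self.FP)

    @cached_property
    def FPR(self) -> float:
        return _safe_div(self.FP, self.FP + self.TN)

    @cached_property
    def FNR(self) -> float:
        return _safe_div(self.FN, self.FN + self.TP)

    @cached_property
    def balanced_acc(self) -> float:
        return (self.TPR + self.TNR) / 2

    @cached_property
    def F1(self) -> float:
        # Harmonic mean of precision and recall, written in terms of counts
        return _safe_div(2 * self.TP, 2 * self.TP + self.FP + self.FN)

    @cached_property
    def MCC(self) -> float:
        denom = math.sqrt(
            (self.TP + self.FP) * (self.TP + self.FN)
            * (self.TN + self.FP) * (self.TN + self.FN)
        )
        return _safe_div(self.TP * self.TN - self.FP * self.FN, denom)

    def get_all(self) -> dict[str, float]:
        names = ["acc", "TPR", "TNR", "FPR", "FNR", "balanced_acc", "F1", "MCC"]
        return {name: getattr(self, name) for name in names}
```

Each metric is computed on first access and then stored on the instance, so, for example, `MCC` reuses the already cached `TP`, `TN`, `FP`, and `FN` counts.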
- gnn_tracking.metrics.binary_classification.zero_divide(a: float, b: float) → float
Normal division a/b, except that x/0 returns 0 instead of raising an error.
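A minimal sketch of the behavior the docstring describes; the name matches the documented function, but this is an illustration rather than the library's source.

```python
def zero_divide(a: float, b: float) -> float:
    # Plain division, except that any x / 0 yields 0 instead of raising
    if b == 0:
        return 0.0
    return a / b
```

Returning 0 keeps metrics such as TPR well defined for degenerate inputs (e.g., a batch without any true samples) at the cost of hiding that the ratio was undefined.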
- gnn_tracking.metrics.binary_classification.get_maximized_bcs(*, output: torch.Tensor, y: torch.Tensor, n_samples=200) → dict[str, float]
Calculate the best possible binary classification stats for a given output and y.
- Parameters:
output – Output weights
y – True labels
n_samples – Number of thresholds to sample
- Returns:
Dictionary of metrics
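One plausible way to find the "best possible" stats is a threshold scan. This sketch assumes outputs lie in [0, 1], samples `n_samples` evenly spaced thresholds there, and keeps the threshold that maximizes MCC. The choice of MCC as the objective, the helper names `mcc_at` and `maximize_mcc`, and the returned keys are all hypothetical, not the library's actual interface.

```python
import math


def mcc_at(output, y, thld):
    # Assumption: a sample is predicted true when its output exceeds thld
    tp = tn = fp = fn = 0
    for o, t in zip(output, y):
        p = o > thld
        if p and t:
            tp += 1
        elif p:
            fp += 1
        elif t:
            fn += 1
        else:
            tn += 1
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0


def maximize_mcc(output, y, n_samples=200):
    # Evenly spaced thresholds covering [0, 1], endpoints included
    step = max(n_samples - 1, 1)
    thlds = [i / step for i in range(n_samples)]
    # max() returns the first (lowest) threshold that reaches the best MCC
    best = max(thlds, key=lambda t: mcc_at(output, y, t))
    return {"max_mcc": mcc_at(output, y, best), "max_mcc_thld": best}
```

A finer grid (larger `n_samples`) locates the optimum more precisely but costs one full pass over the data per threshold.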
- gnn_tracking.metrics.binary_classification.roc_auc_score(*, y_true: torch.Tensor, y_score: torch.Tensor, max_fpr: float | None = None, device=None) → float
Wrapper that ignores exceptions, which can be raised, e.g., if only one label is present.
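A sketch of the exception-ignoring pattern, assuming the failure is reported as NaN (the value returned on failure is not documented above). The hypothetical `_roc_auc` helper computes the full ROC AUC as the probability that a random positive outscores a random negative, with ties counted as one half, and raises `ValueError` when only one class is present. `max_fpr` and `device` are omitted for brevity.

```python
import math


def _roc_auc(y_true, y_score):
    pos = [s for s, t in zip(y_score, y_true) if t]
    neg = [s for s, t in zip(y_score, y_true) if not t]
    if not pos or not neg:
        raise ValueError("ROC AUC is undefined when only one class is present")
    # Fraction of (positive, negative) pairs ranked correctly; ties count 0.5
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def safe_roc_auc(y_true, y_score):
    try:
        return _roc_auc(y_true, y_score)
    except ValueError:
        # Assumption: signal the failure with NaN instead of propagating it
        return math.nan
```

The pairwise comparison is quadratic in the number of samples; it is chosen here for clarity, while a sort-based ROC curve scales better.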
- gnn_tracking.metrics.binary_classification.get_roc_auc_scores(true, predicted, max_fprs: Iterable[float | None])
Calculate ROC AUC scores for a given set of maximum FPRs.
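A sketch of computing scores for several maximum FPRs from a single ROC curve, using plain Python lists. For a finite `max_fpr`, it integrates the ROC curve up to `max_fpr` and applies the McClish standardization that scikit-learn's `roc_auc_score` uses for its `max_fpr` argument; `None` yields the full AUC. The function names and the dictionary keys (`"roc_auc"`, `f"roc_auc_{max_fpr}FPR"`) are hypothetical, both classes are assumed to be present, and `max_fpr` is assumed to lie in (0, 1].

```python
def roc_curve_points(y_true, y_score):
    n_pos = sum(1 for t in y_true if t)
    n_neg = len(y_true) - n_pos
    pairs = sorted(zip(y_score, y_true), key=lambda p: p[0], reverse=True)
    fpr, tpr = [0.0], [0.0]
    tp = fp = i = 0
    while i < len(pairs):
        score = pairs[i][0]
        # Consume all samples sharing this score so ties form a single step
        while i < len(pairs) and pairs[i][0] == score:
            if pairs[i][1]:
                tp += 1
            else:
                fp += 1
            i += 1
        fpr.append(fp / n_neg)
        tpr.append(tp / n_pos)
    return fpr, tpr


def auc_up_to(fpr, tpr, max_fpr):
    # Trapezoidal area under the ROC curve for FPR in [0, max_fpr]
    area = 0.0
    for k in range(1, len(fpr)):
        x0, x1, y0, y1 = fpr[k - 1], fpr[k], tpr[k - 1], tpr[k]
        if x0 >= max_fpr:
            break
        if x1 > max_fpr:
            # Cut the last trapezoid at max_fpr by linear interpolation
            y1 = y0 + (y1 - y0) * (max_fpr - x0) / (x1 - x0)
            x1 = max_fpr
        area += (x1 - x0) * (y0 + y1) / 2
    return area


def roc_auc_scores(y_true, y_score, max_fprs):
    fpr, tpr = roc_curve_points(y_true, y_score)
    scores = {}
    for max_fpr in max_fprs:
        if max_fpr is None:
            scores["roc_auc"] = auc_up_to(fpr, tpr, 1.0)
            continue
        pauc = auc_up_to(fpr, tpr, max_fpr)
        # McClish standardization: 0.5 for a random classifier, 1 for a perfect one
        lo, hi = 0.5 * max_fpr**2, max_fpr
        scores[f"roc_auc_{max_fpr}FPR"] = 0.5 * (1 + (pauc - lo) / (hi - lo))
    return scores
```

Building the ROC curve once and reusing it for every `max_fpr` avoids re-sorting the scores per requested limit.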