:py:mod:`gnn_tracking.metrics.binary_classification`
====================================================

.. py:module:: gnn_tracking.metrics.binary_classification

.. autoapi-nested-parse::

   This module collects figures of merit for binary classification.



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.metrics.binary_classification.BinaryClassificationStats



Functions
~~~~~~~~~

.. autoapisummary::

   gnn_tracking.metrics.binary_classification.zero_divide
   gnn_tracking.metrics.binary_classification.get_maximized_bcs
   gnn_tracking.metrics.binary_classification.roc_auc_score
   gnn_tracking.metrics.binary_classification.get_roc_auc_scores



.. py:class:: BinaryClassificationStats(output: torch.Tensor, y: torch.Tensor, thld: torch.Tensor | float)


   Calculator for binary classification metrics.
   All properties are cached, so they are only calculated once.

   :param output: Output weights
   :param y: True labels
   :param thld: Threshold for classifying an output as true

   Metrics such as accuracy, TPR, and TNR are exposed as properties;
   :py:meth:`get_all` collects them in a dictionary.
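
   The four counts underlying these metrics can be illustrated with a minimal
   plain-Python sketch. It is not the library's implementation: it uses lists
   instead of tensors, and the helper ``confusion_counts`` is hypothetical. It
   also assumes that an output is predicted true when it is strictly greater
   than ``thld``. Whether the class uses ``>`` or ``>=`` is not stated here.

   .. code-block:: python

      def confusion_counts(output, y, thld):
          """Return (TP, TN, FP, FN) for scores ``output`` and boolean labels ``y``.

          Assumption: a score counts as a positive prediction if it is > thld.
          """
          predicted = [o > thld for o in output]
          pairs = list(zip(predicted, y))
          tp = sum(p and t for p, t in pairs)  # predicted true, actually true
          tn = sum(not p and not t for p, t in pairs)  # predicted false, actually false
          fp = sum(p and not t for p, t in pairs)  # predicted true, actually false
          fn = sum(not p and t for p, t in pairs)  # predicted false, actually true
          return tp, tn, fp, fn


      # One example of each outcome: 0.9 -> TP, 0.2 -> TN, 0.7 -> FP, 0.4 -> FN
      counts = confusion_counts([0.9, 0.2, 0.7, 0.4], [True, False, False, True], 0.5)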

   .. py:method:: _true() -> torch.Tensor


   .. py:method:: n_true() -> int


   .. py:method:: _false() -> torch.Tensor


   .. py:method:: n_false() -> int


   .. py:method:: _predicted_false() -> torch.Tensor


   .. py:method:: n_predicted_false() -> int


   .. py:method:: _predicted_true() -> torch.Tensor


   .. py:method:: n_predicted_true() -> int


   .. py:method:: TP() -> float


   .. py:method:: TN() -> float


   .. py:method:: FP() -> float


   .. py:method:: FN() -> float


   .. py:method:: acc() -> float


   .. py:method:: TPR() -> float


   .. py:method:: TNR() -> float


   .. py:method:: FPR() -> float


   .. py:method:: FNR() -> float


   .. py:method:: balanced_acc() -> float


   .. py:method:: F1() -> float


   .. py:method:: MCC() -> float


   .. py:method:: get_all() -> dict[str, float]
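
   The method names refer to standard metrics, which can be written in terms
   of the four confusion-matrix counts. The sketch below spells out these
   textbook formulas in plain Python as an illustration, not as the library's
   code. Returning 0 for undefined ratios is an assumption modeled on
   :py:func:`zero_divide`, and the dictionary keys merely mirror the method
   names; the keys actually returned by :py:meth:`get_all` are not documented
   here.

   .. code-block:: python

      import math


      def _div(a, b):
          # Zero-safe division: return 0 when the denominator is 0 (assumption).
          return a / b if b != 0 else 0.0


      def metrics_from_counts(tp, tn, fp, fn):
          """Standard binary classification metrics from confusion-matrix counts."""
          tpr = _div(tp, tp + fn)  # sensitivity / recall
          tnr = _div(tn, tn + fp)  # specificity
          return {
              "acc": _div(tp + tn, tp + tn + fp + fn),
              "TPR": tpr,
              "TNR": tnr,
              "FPR": _div(fp, fp + tn),  # equals 1 - TNR when defined
              "FNR": _div(fn, fn + tp),  # equals 1 - TPR when defined
              "balanced_acc": (tpr + tnr) / 2,
              "F1": _div(2 * tp, 2 * tp + fp + fn),
              # Matthews correlation coefficient
              "MCC": _div(
                  tp * tn - fp * fn,
                  math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)),
              ),
          }


      m = metrics_from_counts(tp=40, tn=50, fp=5, fn=5)
      # m["acc"] is 0.9 and m["F1"] is 80 / 90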



.. py:function:: zero_divide(a: float, b: float) -> float

   Return ``a / b``, or 0 if ``b`` is 0.
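
   A minimal sketch of the documented behavior, not necessarily the exact
   library code (returning the float ``0.0`` rather than the integer ``0`` is
   an assumption):

   .. code-block:: python

      def zero_divide(a: float, b: float) -> float:
          """Return a / b, or 0 if b is 0."""
          if b == 0:
              return 0.0
          return a / b


      zero_divide(1.0, 4.0)  # 0.25
      zero_divide(1.0, 0.0)  # 0.0 instead of raising ZeroDivisionError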


.. py:function:: get_maximized_bcs(*, output: torch.Tensor, y: torch.Tensor, n_samples=200) -> dict[str, float]

   Calculate the best possible binary classification stats for the given
   ``output`` and ``y`` by evaluating them at a range of classification
   thresholds.

   :param output: Output weights
   :param y: True labels
   :param n_samples: Number of thresholds to sample

   :returns: Dictionary of metrics
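
   The threshold scan can be sketched as follows. This is a hypothetical
   illustration with made-up helpers ``scan_thresholds`` and ``accuracy``. It
   assumes ``n_samples`` evenly spaced thresholds on [0, 1], a strict ``>``
   comparison, and that the best value of a single metric is reported together
   with the threshold where it occurs. Which metrics the function actually
   maximizes, and the keys of the returned dictionary, are not stated here.

   .. code-block:: python

      def accuracy(predicted, y):
          """Fraction of predictions that match the labels."""
          return sum(p == t for p, t in zip(predicted, y)) / len(y)


      def scan_thresholds(output, y, metric, n_samples=200):
          """Evaluate ``metric`` at evenly spaced thresholds in [0, 1].

          Returns (best_value, best_threshold); the first maximum wins on ties.
          """
          if n_samples < 2:
              raise ValueError("n_samples must be at least 2")
          best_value, best_thld = float("-inf"), None
          for i in range(n_samples):
              thld = i / (n_samples - 1)
              predicted = [o > thld for o in output]  # assumed strict comparison
              value = metric(predicted, y)
              if value > best_value:
                  best_value, best_thld = value, thld
          return best_value, best_thld


      best_acc, best_thld = scan_thresholds(
          [0.9, 0.8, 0.3, 0.1], [True, True, False, False], accuracy, n_samples=11
      )

   Any threshold between 0.3 and 0.8 separates the two classes perfectly in
   this toy example, so the scan finds an accuracy of 1.0.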


.. py:function:: roc_auc_score(*, y_true: torch.Tensor, y_score: torch.Tensor, max_fpr: float | None = None, device=None) -> float

   Wrapper around the ROC AUC computation that ignores exceptions, such as
   those raised when only one label is present.
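
   To illustrate the failure mode, the sketch below computes ROC AUC in plain
   Python as the probability that a randomly chosen positive scores higher than
   a randomly chosen negative (ties count one half). This quantity is undefined
   when only one label is present. Both helpers are hypothetical: returning
   ``nan`` on failure is an assumption, because the value this wrapper returns
   when it ignores an exception is not documented here, and ``max_fpr`` is not
   handled.

   .. code-block:: python

      import math


      def auc_pairwise(y_true, y_score):
          """ROC AUC via pairwise comparison of positive and negative scores.

          Raises ValueError if only one class is present.
          """
          pos = [s for s, t in zip(y_score, y_true) if t]
          neg = [s for s, t in zip(y_score, y_true) if not t]
          if not pos or not neg:
              raise ValueError("ROC AUC is undefined when only one class is present")
          wins = 0.0
          for p in pos:
              for n in neg:
                  wins += 1.0 if p > n else 0.5 if p == n else 0.0
          return wins / (len(pos) * len(neg))


      def safe_auc(y_true, y_score):
          """Hypothetical wrapper: return NaN instead of raising."""
          try:
              return auc_pairwise(y_true, y_score)
          except ValueError:
              return math.nan


      safe_auc([1, 0, 1, 0], [0.9, 0.8, 0.3, 0.1])  # 0.75
      safe_auc([1, 1], [0.2, 0.4])  # nan: only one label present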


.. py:function:: get_roc_auc_scores(true, predicted, max_fprs: Iterable[float | None])

   Calculate ROC AUC scores for a given set of maximum FPRs.
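
   Restricting the maximum FPR means integrating the ROC curve only over
   FPR values in [0, ``max_fpr``]. The plain-Python sketch below computes the
   raw trapezoidal partial AUC for each requested limit, with ``None`` meaning
   the full curve. All names are hypothetical, and keying the result by the
   ``max_fpr`` value is an assumption. Note that scikit-learn's
   ``roc_auc_score`` with ``max_fpr`` returns a McClish-standardized partial
   AUC instead, so its values differ from this raw area; which variant this
   module reports is not stated here.

   .. code-block:: python

      def roc_points(y_true, y_score):
          """(FPR, TPR) points, sweeping the threshold over the distinct scores.

          Requires both classes to be present.
          """
          n_pos = sum(1 for t in y_true if t)
          n_neg = len(y_true) - n_pos
          points = [(0.0, 0.0)]
          for thld in sorted(set(y_score), reverse=True):
              tp = sum(1 for s, t in zip(y_score, y_true) if s >= thld and t)
              fp = sum(1 for s, t in zip(y_score, y_true) if s >= thld and not t)
              points.append((fp / n_neg, tp / n_pos))
          return points


      def partial_auc(y_true, y_score, max_fpr=None):
          """Trapezoidal area under the ROC curve for FPR in [0, max_fpr]."""
          limit = 1.0 if max_fpr is None else max_fpr
          pts = roc_points(y_true, y_score)
          area = 0.0
          for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
              if x0 >= limit:
                  break
              if x1 > limit:
                  # Interpolate TPR linearly at FPR == limit, then stop there.
                  y1 = y0 + (y1 - y0) * (limit - x0) / (x1 - x0)
                  x1 = limit
              area += (x1 - x0) * (y0 + y1) / 2
          return area


      def get_roc_auc_scores_sketch(true, predicted, max_fprs):
          return {max_fpr: partial_auc(true, predicted, max_fpr) for max_fpr in max_fprs}


      scores = get_roc_auc_scores_sketch(
          [1, 0, 1, 0], [0.9, 0.8, 0.3, 0.1], [None, 0.5, 0.25]
      )
      # {None: 0.75, 0.5: 0.25, 0.25: 0.125}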


