:py:mod:`gnn_tracking.analysis.edge_classification`
===================================================

.. py:module:: gnn_tracking.analysis.edge_classification


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.analysis.edge_classification.ThresholdTrackInfoPlot



Functions
~~~~~~~~~

.. autoapisummary::

   gnn_tracking.analysis.edge_classification.get_all_ec_stats
   gnn_tracking.analysis.edge_classification.collect_all_ec_stats



.. py:function:: get_all_ec_stats(threshold: float, w: torch.Tensor, data: torch_geometric.data.Data, *, pt_thld=0.9, max_eta=4) -> dict[str, float]

   Evaluate edge classification/graph construction performance for a single batch.
   Similar to `get_all_graph_construction_stats`, but additionally includes edge
   classification metrics.

   :param threshold: Edge classification threshold
   :param w: Edge classification output
   :param data: Graph data of the batch that ``w`` belongs to
   :param pt_thld: pt threshold for particle IDs to consider. For edge classification
                   stats (TPR, etc.), two versions are calculated: the stats ending in
                   `_thld` only consider edges with pt > pt_thld, the others consider
                   all edges.

   :returns: Dictionary of metrics
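
   The pure-Python sketch below illustrates the two versions of the edge
   classification stats described above. It is not the library's implementation:
   the function name ``edge_metrics``, the key names (``TPR``, ``FPR``, ``MCC``),
   the per-edge truth labels ``y`` and the per-edge ``pt`` values are assumptions
   made for illustration, and an edge is taken to be predicted true when its
   score exceeds the threshold.

   .. code-block:: python

      import math


      def edge_metrics(w, y, pt, threshold, pt_thld=0.9):
          """Hypothetical sketch: TPR, FPR and MCC of thresholded edge scores.

          ``w``: classifier score per edge, ``y``: true label per edge (0 or 1),
          ``pt``: pt value assigned to each edge (an assumption for this sketch).
          """

          def confusion(indices):
              # Count true/false positives/negatives among the selected edges
              tp = fp = tn = fn = 0
              for i in indices:
                  predicted = w[i] > threshold
                  if predicted and y[i]:
                      tp += 1
                  elif predicted:
                      fp += 1
                  elif y[i]:
                      fn += 1
                  else:
                      tn += 1
              return tp, fp, tn, fn

          def rates(tp, fp, tn, fn):
              # NaN signals an undefined ratio (zero denominator)
              tpr = tp / (tp + fn) if tp + fn else math.nan
              fpr = fp / (fp + tn) if fp + tn else math.nan
              denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
              mcc = (tp * tn - fp * fn) / denom if denom else math.nan
              return {"TPR": tpr, "FPR": fpr, "MCC": mcc}

          # Version 1: all edges
          stats = rates(*confusion(range(len(w))))
          # Version 2: only edges above the pt threshold, keys suffixed with _thld
          high_pt = [i for i in range(len(w)) if pt[i] > pt_thld]
          for key, value in rates(*confusion(high_pt)).items():
              stats[f"{key}_thld"] = value
          return stats

   To scan several thresholds, call it once per value, e.g.
   ``[edge_metrics(w, y, pt, t) for t in (0.3, 0.5, 0.7)]``.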


.. py:function:: collect_all_ec_stats(model: torch.nn.Module, data_loader: torch_geometric.data.DataLoader, thresholds: Sequence[float], n_batches: int | None = None, max_workers=6, pt_thld=0.9) -> pandas.DataFrame

   Collect edge classification statistics for a model and a data loader by mapping
   `get_all_ec_stats` over the data loader with multiprocessing.

   :param model: Edge classifier model
   :param data_loader: Data loader
   :param thresholds: List of EC thresholds to evaluate
   :param n_batches: Number of batches to evaluate
   :param max_workers: Number of workers for multiprocessing
   :param pt_thld: pt threshold, see `get_all_ec_stats`

   :returns: DataFrame with columns as in `get_all_ec_stats`
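
   The sketch below shows the general pattern of mapping a per-batch statistics
   function over the first ``n_batches`` batches and all thresholds. The function
   ``collect_stats`` and its ``stats_fn`` argument are hypothetical; ``stats_fn``
   is assumed to run the model and compute the metrics itself. A thread pool is
   swapped in for the multiprocessing used by `collect_all_ec_stats` so that the
   example runs anywhere, and it returns a list of records rather than a DataFrame.

   .. code-block:: python

      from concurrent.futures import ThreadPoolExecutor
      from itertools import islice


      def collect_stats(stats_fn, batches, thresholds, n_batches=None, max_workers=6):
          """Hypothetical sketch: evaluate ``stats_fn(threshold, batch)`` for every
          (batch, threshold) pair and return one record per pair."""
          # Keep only the first n_batches batches (all of them if n_batches is None)
          selected = list(islice(batches, n_batches))
          jobs = [(i, t) for i in range(len(selected)) for t in thresholds]

          def run(job):
              i, t = job
              return {"batch": i, "threshold": t, **stats_fn(t, selected[i])}

          # Threads for portability; pool.map preserves the order of the jobs
          with ThreadPoolExecutor(max_workers=max_workers) as pool:
              return list(pool.map(run, jobs))

   The records can be converted with ``pandas.DataFrame(rows)``. Keeping one row
   per batch and threshold (an assumption of this sketch) preserves the
   batch-to-batch spread for later plotting.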


.. py:class:: ThresholdTrackInfoPlot(df: pandas.DataFrame)


   Plot track info as a function of EC threshold.

   To get the plot in one go, simply call the `plot` method. Alternatively,
   use the individual methods to plot the different components separately.

   :param df: DataFrame with columns as in `get_all_ec_stats`
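
   The presence of ``plot_errorline`` suggests that metrics are shown with an
   uncertainty across batches. Assuming one row per batch and threshold with a
   hypothetical ``threshold`` column, the per-threshold mean and standard error of
   a metric could be computed as in this sketch. ``aggregate_by_threshold`` is not
   part of the library, and it operates on a list of records instead of a
   DataFrame to stay dependency-free.

   .. code-block:: python

      import math
      import statistics
      from collections import defaultdict


      def aggregate_by_threshold(rows, key):
          """Hypothetical sketch: mean and standard error of ``key`` per threshold."""
          by_threshold = defaultdict(list)
          for row in rows:
              by_threshold[row["threshold"]].append(row[key])
          result = {}
          for threshold in sorted(by_threshold):
              values = by_threshold[threshold]
              mean = statistics.fmean(values)
              # Standard error of the mean; undefined (NaN) for a single batch
              if len(values) > 1:
                  sem = statistics.stdev(values) / math.sqrt(len(values))
              else:
                  sem = math.nan
              result[threshold] = (mean, sem)
          return result

   With pandas, ``df.groupby("threshold")[key].agg(["mean", "sem"])`` yields the
   same numbers.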

   .. py:method:: plot()

      Plot all components in one go.


   .. py:method:: setup_axes()


   .. py:method:: plot_line(key, **kwargs)


   .. py:method:: plot_errorline(key, **kwargs)


   .. py:method:: plot_100()


   .. py:method:: plot_50()


   .. py:method:: plot_75()


   .. py:method:: plot_tpr_fpr()


   .. py:method:: plot_mcc()


   .. py:method:: plot_hlines()


   .. py:method:: add_legend()



