:py:mod:`gnn_tracking.metrics.cluster_metrics`
==============================================

.. py:module:: gnn_tracking.metrics.cluster_metrics

.. autoapi-nested-parse::

   Metrics evaluating the quality of clustering, i.e., the usefulness of the
   algorithm for tracking.



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.metrics.cluster_metrics.ClusterMetricType
   gnn_tracking.metrics.cluster_metrics.TrackingMetrics



Functions
~~~~~~~~~

.. autoapisummary::

   gnn_tracking.metrics.cluster_metrics.tracking_metric_df
   gnn_tracking.metrics.cluster_metrics.count_tracking_metrics
   gnn_tracking.metrics.cluster_metrics.tracking_metrics
   gnn_tracking.metrics.cluster_metrics.tracking_metrics_data
   gnn_tracking.metrics.cluster_metrics.tracking_metrics_vs_pt
   gnn_tracking.metrics.cluster_metrics.tracking_metrics_vs_eta
   gnn_tracking.metrics.cluster_metrics.flatten_track_metrics
   gnn_tracking.metrics.cluster_metrics.count_hits_per_cluster
   gnn_tracking.metrics.cluster_metrics.hits_per_cluster_count_to_flat_dict
   gnn_tracking.metrics.cluster_metrics._sklearn_signature_wrap



Attributes
~~~~~~~~~~

.. autoapisummary::

   gnn_tracking.metrics.cluster_metrics._tracking_metrics_nan_results
   gnn_tracking.metrics.cluster_metrics.common_metrics


.. py:class:: ClusterMetricType


   Bases: :py:obj:`Protocol`

   Function type that calculates a clustering metric.

   .. py:method:: __call__(*, truth: numpy.ndarray, predicted: numpy.ndarray, pts: numpy.ndarray, reconstructable: numpy.ndarray, pt_thlds: list[float]) -> float | dict[str, float]
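
Any callable with this keyword-only signature satisfies the protocol. The sketch below re-declares the protocol locally, so it runs without ``gnn_tracking``, and implements a hypothetical metric, ``noise_fraction``, which is not part of this module. It only illustrates the expected call shape.

```python
from __future__ import annotations

from typing import Protocol

import numpy as np


class ClusterMetricType(Protocol):
    """Local copy of the documented protocol, so this sketch is self-contained."""

    def __call__(
        self,
        *,
        truth: np.ndarray,
        predicted: np.ndarray,
        pts: np.ndarray,
        reconstructable: np.ndarray,
        pt_thlds: list[float],
    ) -> float | dict[str, float]: ...


def noise_fraction(
    *,
    truth: np.ndarray,
    predicted: np.ndarray,
    pts: np.ndarray,
    reconstructable: np.ndarray,
    pt_thlds: list[float],
) -> float:
    """Hypothetical metric: fraction of hits with a negative (noise) label."""
    # truth, pts, reconstructable and pt_thlds are accepted only to match the
    # protocol; this toy metric looks at the predicted labels alone.
    if predicted.size == 0:
        return float("nan")
    return float(np.mean(predicted < 0))


# Static type checkers accept this assignment because the signatures match.
metric: ClusterMetricType = noise_fraction
```

Because all arguments are keyword-only, callers must pass ``truth=...``, ``predicted=...`` and so on by name, which keeps heterogeneous metrics interchangeable.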



.. py:class:: TrackingMetrics


   Bases: :py:obj:`TypedDict`

   Tracking metrics computed for a single pt threshold (see `tracking_metrics`).

   .. py:attribute:: n_particles
      :type: int

      

   .. py:attribute:: n_cleaned_clusters
      :type: int

      

   .. py:attribute:: perfect
      :type: float

      

   .. py:attribute:: double_majority
      :type: float

      

   .. py:attribute:: lhc
      :type: float

      

   .. py:attribute:: fake_perfect
      :type: float

      

   .. py:attribute:: fake_double_majority
      :type: float

      

   .. py:attribute:: fake_lhc
      :type: float

      


.. py:data:: _tracking_metrics_nan_results
   :type: TrackingMetrics

   

.. py:function:: tracking_metric_df(h_df: pandas.DataFrame, predicted_count_thld=3) -> pandas.DataFrame

   Label clusters as double majority/perfect/LHC.

   :param h_df: Hit information dataframe
   :param predicted_count_thld: Number of hits a cluster must have to be considered a
                                valid cluster

   :returns: Cluster dataframe with columns such as ``double_majority``, etc.
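
The labelling can be illustrated with the matching definitions commonly used in GNN tracking studies. The sketch below assumes them: a cluster is *perfect* if it contains exactly the hits of one particle, *double majority* if more than 50% of its hits come from one particle and that particle has more than 50% of its hits in the cluster, and *LHC* if more than 75% of its hits come from one particle. It works on plain NumPy arrays instead of ``h_df``, and ``label_clusters`` is a hypothetical helper, not this module's implementation.

```python
from __future__ import annotations

from collections import Counter

import numpy as np


def label_clusters(
    truth: np.ndarray, predicted: np.ndarray, predicted_count_thld: int = 3
) -> dict[int, dict[str, bool]]:
    """Label each valid cluster as perfect / double majority / LHC (sketch)."""
    # Total number of hits per particle, including hits outside any cluster.
    particle_sizes = Counter(truth.tolist())
    labels: dict[int, dict[str, bool]] = {}
    for cluster in np.unique(predicted):
        if cluster < 0:
            continue  # negative labels are treated as noise
        members = truth[predicted == cluster]
        size = len(members)
        if size < predicted_count_thld:
            continue  # too few hits to be a valid cluster
        # Majority particle of the cluster and its number of hits in the cluster.
        pid, n_shared = Counter(members.tolist()).most_common(1)[0]
        frac_cluster = n_shared / size
        frac_particle = n_shared / particle_sizes[pid]
        labels[int(cluster)] = {
            "perfect": n_shared == size == particle_sizes[pid],
            "double_majority": frac_cluster > 0.5 and frac_particle > 0.5,
            "lhc": frac_cluster > 0.75,
        }
    return labels
```

For ``truth = [1, 1, 1, 2, 2, 2, 2, 3]`` and ``predicted = [0, 0, 0, 1, 1, 1, 2, 2]``, cluster 0 is perfect, cluster 1 is double majority and LHC but not perfect (one hit of particle 2 lies elsewhere), and cluster 2 is rejected for having only two hits.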


.. py:function:: count_tracking_metrics(c_df: pandas.DataFrame, h_df: pandas.DataFrame, c_mask: numpy.ndarray, h_mask: numpy.ndarray) -> TrackingMetrics

   Calculate TrackingMetrics from cluster and hit information.

   :param c_df: Output dataframe from `tracking_metric_df`
   :param h_df: Hit information dataframe
   :param c_mask: Cluster mask
   :param h_mask: Hit mask

   :returns: `TrackingMetrics` dictionary.
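
Once each cluster carries a boolean label, efficiency-like and fake-rate-like numbers follow from simple counts. The sketch below assumes that the efficiency is the number of matched clusters divided by the number of particles (valid for the perfect and double-majority definitions, where a particle can be matched by at most one cluster) and that the fake rate is one minus the fraction of matched clusters. ``rates_from_labels`` is hypothetical, and it takes the particle count directly, whereas `count_tracking_metrics` works from the hit dataframe and masks.

```python
from __future__ import annotations

import numpy as np


def rates_from_labels(
    matched: np.ndarray, c_mask: np.ndarray, n_particles: int
) -> tuple[float, float]:
    """Return (efficiency, fake rate) for one matching definition.

    matched:     boolean array with one entry per cluster, e.g. a
                 "double_majority" column
    c_mask:      boolean array selecting the clusters that are counted
    n_particles: number of particles in the efficiency denominator
    """
    n_clusters = int(np.count_nonzero(c_mask))
    n_matched = int(np.count_nonzero(matched & c_mask))
    # Return NaN instead of dividing by zero for empty selections.
    efficiency = n_matched / n_particles if n_particles else float("nan")
    fake_rate = 1.0 - n_matched / n_clusters if n_clusters else float("nan")
    return efficiency, fake_rate
```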


.. py:function:: tracking_metrics(*, truth: numpy.ndarray, predicted: numpy.ndarray, pts: numpy.ndarray, reconstructable: numpy.ndarray, eta: numpy.ndarray, pt_thlds: Iterable[float], predicted_count_thld=3, max_eta=4) -> dict[float, TrackingMetrics]

   Calculate 'custom' metrics for matching tracks and hits.

   :param truth: Truth labels/PIDs for each hit
   :param predicted: Predicted labels/cluster index for each hit. Negative labels are
                     interpreted as noise (because this is how DBSCAN outputs it) and are
                     ignored
   :param pts: True pt value of the particle belonging to each hit
   :param reconstructable: Whether the hit belongs to a "reconstructable track" (this
                           usually implies a cut on the number of layers that are hit,
                           etc.)
   :param eta: True pseudorapidity of the particle belonging to each hit
   :param pt_thlds: pt thresholds to calculate the metrics for
   :param predicted_count_thld: Minimal number of hits in a cluster for it to not be
                                rejected.
   :param max_eta: Maximum eta value to count

   :returns: See `TrackingMetrics`
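
Two of the steps described above, ignoring negative (noise) labels and rejecting clusters with fewer than ``predicted_count_thld`` hits, can be expressed as a per-hit mask. The sketch also builds one particle selection per pt threshold. The use of ``>=`` for the pt cut and of ``abs(eta) <= max_eta`` for the eta cut are assumptions, as are the helper names ``valid_cluster_mask`` and ``particle_hit_masks``.

```python
from __future__ import annotations

from collections.abc import Iterable

import numpy as np


def valid_cluster_mask(predicted: np.ndarray, predicted_count_thld: int = 3) -> np.ndarray:
    """Per-hit mask: the hit is not noise and its cluster is large enough."""
    not_noise = predicted >= 0
    clusters, sizes = np.unique(predicted[not_noise], return_counts=True)
    large_clusters = clusters[sizes >= predicted_count_thld]
    return not_noise & np.isin(predicted, large_clusters)


def particle_hit_masks(
    pts: np.ndarray,
    reconstructable: np.ndarray,
    eta: np.ndarray,
    pt_thlds: Iterable[float],
    max_eta: float = 4.0,
) -> dict[float, np.ndarray]:
    """One per-hit mask per pt threshold (assumed cut directions, see above)."""
    eta_ok = np.abs(eta) <= max_eta
    return {thld: (pts >= thld) & reconstructable & eta_ok for thld in pt_thlds}
```

Keying the result by threshold mirrors the ``dict[float, TrackingMetrics]`` return type, where each pt threshold gets its own set of metrics.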


.. py:function:: tracking_metrics_data(data: torch_geometric.data.Data, labels, pt_thlds: Iterable[float], predicted_count_thld=3, max_eta=4) -> dict[float, TrackingMetrics]

   Convenience function to apply `tracking_metrics` to a Data object.

   :param data: Data object
   :param labels: Predicted labels/cluster index for each hit. Negative labels are
                  treated as noise
   :param pt_thlds: pt thresholds to calculate the metrics for
   :param predicted_count_thld: Minimal number of hits in a cluster for it to not be
                                rejected.
   :param max_eta: Maximum eta value to count

   :returns: Dictionary mapping each pt threshold to its `TrackingMetrics`


.. py:function:: tracking_metrics_vs_pt(h_dfs: list[pandas.DataFrame], c_dfs: list[pandas.DataFrame], pts: list[float], *, max_eta: float = 4.0) -> pandas.DataFrame

   Calculate tracking metrics for pt slices.

   :param h_dfs: List of hit dataframes for different batches (see `tracking_metric_df`)
   :param c_dfs: List of cluster dataframes for different batches (see
                 `tracking_metric_df`)
   :param pts: List of pt points to calculate the metrics for
   :param max_eta: Maximum eta value to count

   :returns: Dataframe with tracking metrics for each pt slice


.. py:function:: tracking_metrics_vs_eta(h_dfs: list[pandas.DataFrame], c_dfs: list[pandas.DataFrame], etas: list[float], pt_thld: float = 0.9) -> pandas.DataFrame

   Calculate tracking metrics for eta slices.

   :param h_dfs: List of hit dataframes for different batches (see `tracking_metric_df`)
   :param c_dfs: List of cluster dataframes for different batches (see
                 `tracking_metric_df`)
   :param etas: Eta points to calculate the metrics for
   :param pt_thld: pt threshold to calculate the metrics for

   :returns: Dataframe with tracking metrics for each eta slice


.. py:function:: flatten_track_metrics(custom_metrics_result: dict[float, dict[str, float]]) -> dict[str, float]

   Flatten the nested result of `tracking_metrics` into a flat dictionary by
   adding pt suffixes to the metric names.
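
A minimal sketch of such a flattening, assuming a key format of ``<metric>_pt<threshold>``. The suffix format actually used by `flatten_track_metrics` may differ, and ``flatten_by_pt`` is a hypothetical name.

```python
from __future__ import annotations


def flatten_by_pt(nested: dict[float, dict[str, float]]) -> dict[str, float]:
    """Turn {pt: {name: value}} into {"<name>_pt<pt>": value}."""
    flat: dict[str, float] = {}
    for pt, metrics in nested.items():
        for name, value in metrics.items():
            # The pt threshold becomes a suffix of the metric name.
            flat[f"{name}_pt{pt}"] = value
    return flat
```

For example, ``flatten_by_pt({0.9: {"perfect": 0.8}, 1.5: {"perfect": 0.85}})`` gives ``{"perfect_pt0.9": 0.8, "perfect_pt1.5": 0.85}``. A flat dictionary is convenient for loggers that expect scalar values under string keys.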


.. py:function:: count_hits_per_cluster(predicted: numpy.ndarray) -> numpy.ndarray

   Count the number of hits per cluster.
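
The padding step of `hits_per_cluster_count_to_flat_dict` suggests that the result is a histogram indexed by cluster size. The sketch below assumes exactly that (``counts[n]`` is the number of clusters with ``n`` hits) and additionally assumes that noise hits (negative labels) are excluded. ``cluster_size_histogram`` is a hypothetical name.

```python
import numpy as np


def cluster_size_histogram(predicted: np.ndarray) -> np.ndarray:
    """Return counts with counts[n] = number of clusters that have n hits."""
    # Drop noise hits, then count the hits in each remaining cluster.
    _, sizes = np.unique(predicted[predicted >= 0], return_counts=True)
    # Histogram of cluster sizes.
    return np.bincount(sizes)
```

For ``predicted = [0, 0, 0, 1, 1, 2, -1]`` the cluster sizes are 3, 2 and 1, so the histogram is ``[0, 1, 1, 1]``.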


.. py:function:: hits_per_cluster_count_to_flat_dict(counts: numpy.ndarray, min_max=10) -> dict[str, float]

   Turn result array from `count_hits_per_cluster` into a dictionary
   with cumulative counts.

   :param counts: Result from `count_hits_per_cluster`
   :param min_max: Pad the counts with zeros to at least this length
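
A minimal sketch, assuming ``counts[n]`` is the number of clusters with ``n`` hits and that "cumulative" means the number of clusters with at most ``n`` hits. The key names and the helper name ``cumulative_size_dict`` are hypothetical.

```python
from __future__ import annotations

import numpy as np


def cumulative_size_dict(counts: np.ndarray, min_max: int = 10) -> dict[str, float]:
    """Map "hits_leq_<n>" to the number of clusters with at most n hits."""
    # Pad with zeros so the dictionary always has at least min_max entries.
    padded = np.pad(counts, (0, max(0, min_max - len(counts))))
    cumulative = np.cumsum(padded)
    return {f"hits_leq_{n}": float(c) for n, c in enumerate(cumulative)}
```

Padding to a fixed minimum length keeps the set of keys stable across events, which makes the values easy to average or plot.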


.. py:function:: _sklearn_signature_wrap(func: Callable) -> ClusterMetricType

   A decorator that makes an sklearn cluster metric function accept the
   arguments of ``ClusterMetricType``.
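
A minimal sketch of such a decorator, assuming it forwards only ``truth`` and ``predicted`` (as ``labels_true`` and ``labels_pred``) and ignores the remaining arguments; the real wrapper may filter hits first. ``wrap_label_metric`` and the toy metric ``cluster_count_ratio`` are hypothetical.

```python
from __future__ import annotations

import functools
from typing import Callable

import numpy as np


def wrap_label_metric(func: Callable[[np.ndarray, np.ndarray], float]):
    """Adapt func(labels_true, labels_pred) to the keyword-only metric signature."""

    @functools.wraps(func)
    def wrapped(
        *,
        truth: np.ndarray,
        predicted: np.ndarray,
        pts: np.ndarray,
        reconstructable: np.ndarray,
        pt_thlds: list[float],
    ) -> float:
        # pts, reconstructable and pt_thlds are accepted but unused here.
        return float(func(truth, predicted))

    return wrapped


@wrap_label_metric
def cluster_count_ratio(labels_true: np.ndarray, labels_pred: np.ndarray) -> float:
    """Toy metric: number of predicted clusters per true particle."""
    return len(np.unique(labels_pred)) / len(np.unique(labels_true))
```

With scikit-learn installed, ``wrap_label_metric(sklearn.metrics.adjusted_rand_score)`` would produce a metric with the same call shape, since ``adjusted_rand_score`` takes ``(labels_true, labels_pred)``.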


.. py:data:: common_metrics
   :type: dict[str, ClusterMetricType]

   

