:py:mod:`gnn_tracking.analysis.graphs`
======================================

.. py:module:: gnn_tracking.analysis.graphs


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.analysis.graphs.TrackGraphInfo
   gnn_tracking.analysis.graphs.OrphanCount



Functions
~~~~~~~~~

.. autoapisummary::

   gnn_tracking.analysis.graphs.shortest_path_length_catch_no_path
   gnn_tracking.analysis.graphs.shortest_path_length_multi
   gnn_tracking.analysis.graphs.get_n_reachable
   gnn_tracking.analysis.graphs.get_track_graph_info
   gnn_tracking.analysis.graphs.get_track_graph_info_from_data
   gnn_tracking.analysis.graphs.summarize_track_graph_info
   gnn_tracking.analysis.graphs.get_orphan_counts
   gnn_tracking.analysis.graphs.get_basic_counts
   gnn_tracking.analysis.graphs.get_all_graph_construction_stats
   gnn_tracking.analysis.graphs.get_largest_segment_fracs
   gnn_tracking.analysis.graphs.get_cc_labels



.. py:function:: shortest_path_length_catch_no_path(graph: networkx.Graph, source, target) -> int | float

   Same as ``nx.shortest_path_length``, but returns ``inf`` if no path exists.

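   The following is a minimal, dependency-free sketch of this behavior for an unweighted
   graph. It assumes the graph is given as a plain adjacency dict (``node -> set of
   neighbors``) rather than a ``networkx.Graph``. In NetworkX itself,
   ``nx.shortest_path_length`` raises ``nx.NetworkXNoPath`` when the target is
   unreachable, and that exception is presumably what this function catches. The helper
   name below is hypothetical. Like the documented signature, it returns an ``int`` or
   ``math.inf``.

   .. code-block:: python

      import math
      from collections import deque


      def shortest_path_length_or_inf(adj, source, target):
          """Breadth-first hop count from source to target; math.inf if unreachable."""
          if source == target:
              return 0
          seen = {source}
          queue = deque([(source, 0)])
          while queue:
              node, dist = queue.popleft()
              for nb in adj.get(node, ()):
                  if nb == target:
                      return dist + 1
                  if nb not in seen:
                      seen.add(nb)
                      queue.append((nb, dist + 1))
          return math.inf  # target never reached


      # Undirected toy graph: path 0-1-2 plus a separate edge 3-4
      adj = {0: {1}, 1: {0, 2}, 2: {1}, 3: {4}, 4: {3}}
      print(shortest_path_length_or_inf(adj, 0, 2))  # 2
      print(shortest_path_length_or_inf(adj, 0, 4))  # inf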

.. py:function:: shortest_path_length_multi(graph: networkx.Graph, sources: Iterable[int], targets: Iterable[int])

   Length of the shortest path from any of the sources to any of the targets.
   If no connection exists, returns ``inf``. If the only target is the source itself,
   returns 0.

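   This can be sketched as a multi-source breadth-first search. The sketch assumes an
   unweighted graph stored as an adjacency dict, and the function name is hypothetical.
   All sources start in the queue at distance 0, so the first target reached lies at the
   minimum distance over all source/target pairs. The sketch returns 0 as soon as any
   source is also a target, which is one reading of the last sentence above.

   .. code-block:: python

      import math
      from collections import deque


      def shortest_path_length_multi_sketch(adj, sources, targets):
          """Hop count of the shortest path from any source to any target."""
          targets = set(targets)
          seen = set()
          queue = deque()
          for s in sources:
              if s in targets:
                  return 0  # a source that is also a target is at distance 0
              if s not in seen:
                  seen.add(s)
                  queue.append((s, 0))
          while queue:
              node, dist = queue.popleft()
              for nb in adj.get(node, ()):
                  if nb in targets:
                      return dist + 1
                  if nb not in seen:
                      seen.add(nb)
                      queue.append((nb, dist + 1))
          return math.inf  # no source is connected to any target


      # Path 0-1-2-3-4 plus an isolated node 5
      adj = {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2, 4}, 4: {3}, 5: set()}
      print(shortest_path_length_multi_sketch(adj, [0, 1], [4]))  # 3
      print(shortest_path_length_multi_sketch(adj, [0], [5]))  # inf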

.. py:function:: get_n_reachable(graph: networkx.Graph, source: int, targets: Sequence[int]) -> int

   Get the number of targets that are reachable from ``source``. The source node
   itself is not counted.

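   Here is a minimal sketch. It assumes an adjacency dict and treats ``targets`` as a
   set, so duplicates are counted once. The function name is hypothetical. It collects
   every node reachable from ``source`` by breadth-first search and removes the source
   before counting.

   .. code-block:: python

      from collections import deque


      def n_reachable_sketch(adj, source, targets):
          """Number of distinct targets reachable from source, excluding source."""
          reachable = {source}
          queue = deque([source])
          while queue:
              for nb in adj.get(queue.popleft(), ()):
                  if nb not in reachable:
                      reachable.add(nb)
                      queue.append(nb)
          reachable.discard(source)  # the source itself is never counted
          return len(reachable & set(targets))


      adj = {0: {1}, 1: {0, 2}, 2: {1}, 3: set()}
      print(n_reachable_sketch(adj, 0, [0, 1, 2, 3]))  # 2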

.. py:class:: TrackGraphInfo


   Bases: :py:obj:`NamedTuple`

   Information about how well connected the hits of a track are in the graph.

   Here, "component" means a connected component of the graph, and "segment" means a
   connected component of the subgraph that contains only the hits of the track with
   the given particle ID.

   .. attribute:: pid

      The particle ID of the track.

   .. attribute:: n_hits

      The number of hits in the track.

   .. attribute:: n_segments

      The number of segments of the track.

   .. attribute:: n_hits_largest_segment

      The number of hits in the largest segment of the track.

   .. attribute:: distance_largest_segments

      The shortest path length between the two largest segments.

   .. attribute:: n_hits_largest_component

      The number of hits of the track in the biggest component of the track.

   .. py:attribute:: pid
      :type: int

      

   .. py:attribute:: n_hits
      :type: int

      

   .. py:attribute:: n_segments
      :type: int

      

   .. py:attribute:: n_hits_largest_segment
      :type: int

      

   .. py:attribute:: distance_largest_segments
      :type: int

      

   .. py:attribute:: n_hits_largest_component
      :type: int

      


.. py:function:: get_track_graph_info(graph: networkx.Graph, particle_ids: numpy.ndarray, pid: int) -> TrackGraphInfo

   Get information about how well connected the hits of a single particle are in the
   graph.

   :param graph: NetworkX graph of the data.
   :param particle_ids: The particle IDs of the hits.
   :param pid: Particle ID of the true track.

   :returns: `TrackGraphInfo`

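   The definitions in `TrackGraphInfo` can be made concrete with a dependency-free
   sketch. It assumes the following:

   * The graph is an adjacency dict over nodes ``0 .. len(particle_ids) - 1``.
   * Ties between equally large segments are broken by smallest node index.
   * ``distance_largest_segments`` is 0 when the track has fewer than two segments.
   * ``n_hits_largest_component`` is the maximum number of the track's hits contained
     in a single connected component.

   All of these are assumptions. ``TrackGraphInfoSketch`` only mirrors the fields of
   `TrackGraphInfo` and is hypothetical.

   .. code-block:: python

      import math
      from collections import deque
      from typing import NamedTuple


      class TrackGraphInfoSketch(NamedTuple):
          pid: int
          n_hits: int
          n_segments: int
          n_hits_largest_segment: int
          distance_largest_segments: float
          n_hits_largest_component: int


      def _components(adj, nodes):
          """Connected components of the subgraph induced by `nodes`."""
          nodes = set(nodes)
          seen, comps = set(), []
          for start in sorted(nodes):
              if start in seen:
                  continue
              seen.add(start)
              comp, queue = [start], deque([start])
              while queue:
                  for nb in adj.get(queue.popleft(), ()):
                      if nb in nodes and nb not in seen:
                          seen.add(nb)
                          comp.append(nb)
                          queue.append(nb)
              comps.append(comp)
          return comps


      def _multi_source_distance(adj, sources, targets):
          """Shortest hop count from any node in sources to any node in targets."""
          targets = set(targets)
          seen = set(sources)
          if seen & targets:
              return 0
          queue = deque((s, 0) for s in seen)
          while queue:
              node, dist = queue.popleft()
              for nb in adj.get(node, ()):
                  if nb in targets:
                      return dist + 1
                  if nb not in seen:
                      seen.add(nb)
                      queue.append((nb, dist + 1))
          return math.inf


      def track_graph_info_sketch(adj, particle_ids, pid):
          hits = [i for i, p in enumerate(particle_ids) if p == pid]
          hit_set = set(hits)
          # Segments: components of the subgraph induced by this track's hits
          segments = sorted(_components(adj, hits), key=len, reverse=True)
          if len(segments) >= 2:
              distance = _multi_source_distance(adj, segments[0], segments[1])
          else:
              distance = 0  # assumption: single-segment tracks get distance 0
          # Components of the full graph, scored by how many of the track's hits they hold
          full_components = _components(adj, range(len(particle_ids)))
          n_hits_largest_component = max(
              (len(hit_set.intersection(c)) for c in full_components), default=0
          )
          return TrackGraphInfoSketch(
              pid=pid,
              n_hits=len(hits),
              n_segments=len(segments),
              n_hits_largest_segment=len(segments[0]) if segments else 0,
              distance_largest_segments=distance,
              n_hits_largest_component=n_hits_largest_component,
          )


      edges = [(0, 1), (1, 2), (2, 3), (4, 5)]
      adj = {i: set() for i in range(6)}
      for u, v in edges:
          adj[u].add(v)
          adj[v].add(u)
      particle_ids = [1, 1, 2, 1, 1, 0]
      info = track_graph_info_sketch(adj, particle_ids, pid=1)
      # Segments of pid 1: {0, 1}, {3}, {4}; hits 0, 1, 3 share one component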

.. py:function:: get_track_graph_info_from_data(data: torch_geometric.data.Data, *, w: torch.Tensor | None = None, pt_thld=0.9, threshold: float | None = None, max_eta: float = 4.0) -> pandas.DataFrame

   Get DataFrame of track graph information for every particle ID in the data.
   This function essentially applies `get_track_graph_info` to every particle ID.

   :param data: Graph data (a ``torch_geometric.data.Data`` object).
   :param w: Edge weights. If None, no cut on edge weights is applied.
   :param pt_thld: pt threshold for particle IDs to consider.
   :param threshold: Edge classification cutoff (only used if ``w`` is given).
   :param max_eta: Maximum pseudorapidity.

   :returns: DataFrame with columns as in `TrackGraphInfo`

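   The selection steps that precede the per-particle call can be sketched without
   PyTorch or pandas. This is a hypothetical illustration that assumes the following:

   * Edges are ``(u, v)`` pairs.
   * The particle ID, particle pt and pseudorapidity are available as per-hit lists.
   * An edge survives the cut if its weight is strictly greater than ``threshold``.
   * Particle ID 0 marks noise.

   The real function may apply these cuts differently.

   .. code-block:: python

      def select_edges(edges, weights=None, threshold=None):
          """Keep edges whose weight exceeds threshold; no cut if weights is None."""
          if weights is None:
              return list(edges)
          if threshold is None:
              raise ValueError("threshold is required when weights are given")
          # Assumption: strict inequality for the edge classification cut
          return [edge for edge, w in zip(edges, weights) if w > threshold]


      def select_pids(particle_ids, pt, eta, pt_thld=0.9, max_eta=4.0):
          """Sorted particle IDs of hits passing the pt and |eta| cuts."""
          return sorted(
              {
                  pid
                  for pid, hit_pt, hit_eta in zip(particle_ids, pt, eta)
                  # Assumption: pid 0 marks noise hits, which are never selected
                  if pid != 0 and hit_pt >= pt_thld and abs(hit_eta) <= max_eta
              }
          )


      edges = [(0, 1), (1, 2), (2, 3)]
      print(select_edges(edges, [0.9, 0.2, 0.6], threshold=0.5))  # [(0, 1), (2, 3)]
      pids = select_pids(
          [1, 1, 2, 2, 0], [1.5, 1.5, 0.5, 0.5, 2.0], [0.1, 0.1, 1.0, 1.0, 0.0]
      )
      print(pids)  # [1]

   Each selected particle ID would then be passed to a per-track routine such as
   `get_track_graph_info`. ``pandas.DataFrame`` uses the field names of a list of named
   tuples as column names. Passing it a list of `TrackGraphInfo` records therefore
   yields one row per particle, with the columns named after the `TrackGraphInfo`
   fields.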

.. py:function:: summarize_track_graph_info(tgi: pandas.DataFrame) -> dict[str, float]

   Summarize track graph information returned by
   `get_track_graph_info_from_data`.


.. py:class:: OrphanCount


   Bases: :py:obj:`NamedTuple`

   Stats about the number of orphan nodes in a graph.

   .. attribute:: n_orphan_correct

      Number of orphan nodes that are actually bad nodes (low pt or noise).

   .. attribute:: n_orphan_incorrect

      Number of orphan nodes that are actually good nodes

   .. attribute:: n_orphan_total

      Total number of orphan nodes

   .. py:attribute:: n_orphan_correct
      :type: int

      

   .. py:attribute:: n_orphan_incorrect
      :type: int

      

   .. py:attribute:: n_orphan_total
      :type: int

      


.. py:function:: get_orphan_counts(data: torch_geometric.data.Data, *, pt_thld=0.9, max_eta: float = 4.0) -> OrphanCount

   Count the number of orphan nodes in a graph. See `OrphanCount` for details.

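   Here is a minimal sketch of how the three counts in `OrphanCount` relate. It assumes
   the following:

   * An orphan is a node that appears in no edge.
   * A node is "good" if it belongs to a non-noise particle with pt at or above
     ``pt_thld`` and ``|eta|`` at most ``max_eta``.
   * Particle ID 0 marks noise.

   The eta condition is an added assumption based on the ``max_eta`` parameter, and the
   names below are hypothetical.

   .. code-block:: python

      from typing import NamedTuple


      class OrphanCountSketch(NamedTuple):
          n_orphan_correct: int
          n_orphan_incorrect: int
          n_orphan_total: int


      def orphan_counts_sketch(num_nodes, edges, particle_ids, pt, eta, pt_thld=0.9, max_eta=4.0):
          connected = {node for edge in edges for node in edge}
          n_correct = n_incorrect = 0
          for node in range(num_nodes):
              if node in connected:
                  continue  # not an orphan
              # Assumption: "good" = non-noise (pid != 0), pt >= pt_thld, |eta| <= max_eta
              is_good = (
                  particle_ids[node] != 0
                  and pt[node] >= pt_thld
                  and abs(eta[node]) <= max_eta
              )
              if is_good:
                  n_incorrect += 1  # a good node was left unconnected
              else:
                  n_correct += 1  # a bad node (low pt or noise) was dropped
          return OrphanCountSketch(n_correct, n_incorrect, n_correct + n_incorrect)


      counts = orphan_counts_sketch(
          5, [(0, 1)], [1, 1, 2, 0, 3], [1.0, 1.0, 2.0, 5.0, 0.3], [0.0] * 5
      )
      print(counts)
      # OrphanCountSketch(n_orphan_correct=2, n_orphan_incorrect=1, n_orphan_total=3)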

.. py:function:: get_basic_counts(data: torch_geometric.data.Data, *, pt_thld: float = 0.9, max_eta: float = 4.0) -> dict[str, int]

   Get basic counts of edges and nodes.


.. py:function:: get_all_graph_construction_stats(data: torch_geometric.data.Data, pt_thld=0.9, max_eta: float = 4.0) -> dict[str, float]

   Evaluate graph construction performance for a single batch.


.. py:function:: get_largest_segment_fracs(data: torch_geometric.data.Data, *, pt_thld=0.9, n_particles_sampled=None, max_eta=4) -> numpy.ndarray

   A fast way to get the fraction of hits in the largest segment for each track.

   :param data: Graph data (a ``torch_geometric.data.Data`` object).
   :param pt_thld: pt threshold for particle IDs to consider.
   :param n_particles_sampled: If not None, only consider a subsample of the particles.
                               This speeds up calculation but introduces statistical fluctuations.
   :param max_eta: Maximum pseudorapidity

   :returns: Array of fractions.

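   Per track, the quantity returned here is ``n_hits_largest_segment / n_hits`` in the
   notation of `TrackGraphInfo`. Below is a dependency-free sketch. It assumes an
   adjacency dict, per-hit particle IDs, and a list of already selected particle IDs,
   so the pt and eta cuts are left to the caller. It returns a plain list rather than a
   NumPy array. The ``seed`` parameter is a hypothetical addition for reproducible
   subsampling.

   .. code-block:: python

      import random
      from collections import deque


      def largest_segment_fracs_sketch(adj, particle_ids, pids, n_particles_sampled=None, seed=None):
          pids = list(pids)
          if n_particles_sampled is not None and n_particles_sampled < len(pids):
              # Subsampling trades accuracy for speed, as described above
              pids = random.Random(seed).sample(pids, n_particles_sampled)
          hits_by_pid = {}
          for node, p in enumerate(particle_ids):
              hits_by_pid.setdefault(p, []).append(node)
          fracs = []
          for pid in pids:
              hits = set(hits_by_pid.get(pid, []))
              if not hits:
                  continue
              seen, largest = set(), 0
              for start in hits:
                  if start in seen:
                      continue
                  # Grow one segment: BFS restricted to this track's hits
                  seen.add(start)
                  size, queue = 1, deque([start])
                  while queue:
                      for nb in adj.get(queue.popleft(), ()):
                          if nb in hits and nb not in seen:
                              seen.add(nb)
                              size += 1
                              queue.append(nb)
                  largest = max(largest, size)
              fracs.append(largest / len(hits))
          return fracs


      adj = {i: set() for i in range(5)}
      for u, v in [(0, 1), (3, 4)]:
          adj[u].add(v)
          adj[v].add(u)
      fracs = largest_segment_fracs_sketch(adj, [1, 1, 1, 2, 2], [1, 2])
      # pid 1: segments {0, 1} and {2} -> 2/3; pid 2: one segment -> 1.0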

.. py:function:: get_cc_labels(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor

   Get labels for connected components of a graph.

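   One way to compute such labels is union-find over the edge list. The sketch below
   works on a pair of Python sequences that stand in for the ``(2, E)`` ``edge_index``
   tensor. It numbers components consecutively in order of their smallest node index.
   This labeling convention is an assumption, and the actual implementation may instead
   rely on a library routine such as ``scipy.sparse.csgraph.connected_components``. The
   function name is hypothetical.

   .. code-block:: python

      def cc_labels_sketch(edge_index, num_nodes):
          """Component label per node; edge_index is (sources, targets)."""
          parent = list(range(num_nodes))

          def find(x):
              while parent[x] != x:
                  parent[x] = parent[parent[x]]  # path halving
                  x = parent[x]
              return x

          # Merge the components of both endpoints of every edge
          for u, v in zip(edge_index[0], edge_index[1]):
              ru, rv = find(u), find(v)
              if ru != rv:
                  parent[ru] = rv
          # Scanning nodes in ascending order labels components by their smallest node
          labels, root_to_label = [], {}
          for node in range(num_nodes):
              root = find(node)
              labels.append(root_to_label.setdefault(root, len(root_to_label)))
          return labels


      print(cc_labels_sketch([[0, 2], [1, 3]], 5))  # [0, 0, 1, 1, 2]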

