gnn_tracking.analysis.graphs#

Classes#

TrackGraphInfo

Information about how well connected the hits of a track are in the graph.

OrphanCount

Stats about the number of orphan nodes in a graph

Functions#

get_good_node_mask(→ torch.Tensor)

Get a mask for nodes that are included in metrics and more.

shortest_path_length_catch_no_path(→ int | float)

Same as nx.shortest_path_length but return inf if no path exists

shortest_path_length_multi(graph, sources, targets)

Length of the shortest path from any of the sources to any of the targets.

get_n_reachable(→ int)

Get the number of targets that are reachable from source. The source node itself will not be counted!

get_track_graph_info(→ TrackGraphInfo)

Get information about how well connected the hits of a single particle are in the graph.

get_track_graph_info_from_data(→ pandas.DataFrame)

Get DataFrame of track graph information for every particle ID in the data.

summarize_track_graph_info(→ dict[str, float])

Summarize track graph information returned by get_track_graph_info_from_data.

get_orphan_counts(→ OrphanCount)

Count number of orphan nodes in a graph. See OrphanCount for details.

get_basic_counts(→ dict[str, int])

Get basic counts of edges and nodes

get_all_graph_construction_stats(→ dict[str, float])

Evaluate graph construction performance for a single batch.

get_largest_segment_fracs(→ numpy.ndarray)

A fast way to get the fraction of hits in the largest segment for each track.

get_cc_labels(→ torch.Tensor)

Get labels for connected components of a graph.

Module Contents#

gnn_tracking.analysis.graphs.get_good_node_mask(data: torch_geometric.data.Data, *, pt_thld: float = 0.9, max_eta: float = 4.0) torch.Tensor#

Get a mask for nodes that are included in metrics and more. A node is included if it passes the lower limit on pt (pt_thld), is not noise, is reconstructable, and passes the cut on eta (max_eta).

gnn_tracking.analysis.graphs.shortest_path_length_catch_no_path(graph: networkx.Graph, source, target) int | float#

Same as nx.shortest_path_length but return inf if no path exists
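
The behavior described above can be sketched as follows. This is a minimal illustration, not the library's implementation: the helper name path_length_or_inf is hypothetical, and it assumes that the missing-path case is signalled by networkx raising nx.NetworkXNoPath.

```python
import math

import networkx as nx


def path_length_or_inf(graph, source, target):
    """Return the shortest path length, or inf if no path exists."""
    try:
        return nx.shortest_path_length(graph, source=source, target=target)
    except nx.NetworkXNoPath:
        # networkx raises instead of returning a sentinel; map it to inf
        return math.inf


# Nodes 0-1-2 form a chain, node 3 is isolated
example_graph = nx.Graph()
example_graph.add_edges_from([(0, 1), (1, 2)])
example_graph.add_node(3)

connected_length = path_length_or_inf(example_graph, 0, 2)
disconnected_length = path_length_or_inf(example_graph, 0, 3)
```

Returning inf rather than raising keeps the result usable in min() and in comparisons without extra error handling.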

gnn_tracking.analysis.graphs.shortest_path_length_multi(graph: networkx.Graph, sources: Iterable[int], targets: Iterable[int])#

Length of the shortest path from any of the sources to any of the targets. If no connection exists, returns inf. If the only target is the source itself, returns 0.
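
One way to picture this is as a minimum over all (source, target) pairs. The sketch below is an assumption about the semantics, not the library's code; the name min_path_length is hypothetical, and a real implementation may use a faster multi-source search.

```python
import itertools
import math

import networkx as nx


def min_path_length(graph, sources, targets):
    """Smallest shortest-path length over all (source, target) pairs."""
    best = math.inf
    for source, target in itertools.product(sources, targets):
        try:
            length = nx.shortest_path_length(graph, source=source, target=target)
        except nx.NetworkXNoPath:
            # This pair is disconnected; it cannot improve the minimum
            continue
        best = min(best, length)
    return best


# Two components: 0-1-2 and 3-4
example_graph = nx.Graph([(0, 1), (1, 2), (3, 4)])
```

With this graph, min_path_length(example_graph, [0, 1], [2]) is 1 (node 1 is adjacent to node 2), sources and targets in different components give inf, and a source that is its own only target gives 0.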

gnn_tracking.analysis.graphs.get_n_reachable(graph: networkx.Graph, source: int, targets: Sequence[int]) int#

Get the number of targets that are reachable from source. The source node itself will not be counted!
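
The counting rule, including the exclusion of the source, can be illustrated with a short sketch. It assumes an undirected graph, so that "reachable" means "in the same connected component"; the name count_reachable is hypothetical and this is not the library's implementation.

```python
import networkx as nx


def count_reachable(graph, source, targets):
    """Count targets in the same connected component as source, excluding source."""
    reachable = nx.node_connected_component(graph, source)
    return sum(1 for target in targets if target != source and target in reachable)


# Two components: 0-1-2 and 3-4
example_graph = nx.Graph([(0, 1), (1, 2), (3, 4)])
n_reachable = count_reachable(example_graph, 0, [0, 1, 2, 3])
```

Here only nodes 1 and 2 are counted: node 0 is the source itself and node 3 lies in a different component.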

class gnn_tracking.analysis.graphs.TrackGraphInfo#

Bases: NamedTuple

Information about how well connected the hits of a track are in the graph.

Here, “component” means connected component of the graph. “segment” means connected component of the graph that only contains hits of the track with the given particle ID.

pid#

The particle ID of the track.

n_hits#

The number of hits in the track.

n_segments#

The number of segments of the track.

n_hits_largest_segment#

The number of hits in the largest segment of the track.

distance_largest_segments#

The shortest path length between the two largest segments

n_hits_largest_component#

The number of hits of the track in the largest component of the track.

pid: int#
n_hits: int#
n_segments: int#
n_hits_largest_segment: int#
distance_largest_segments: int#
n_hits_largest_component: int#
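
To make the difference between "segment" and "component" concrete, the following sketch computes the count fields for a toy graph. It is an illustration of the definitions above under stated assumptions, not the library's implementation: the function name sketch_track_counts is hypothetical, particle_ids is assumed to be indexable by node ID, and the particle is assumed to have at least one hit.

```python
import networkx as nx


def sketch_track_counts(graph, particle_ids, pid):
    """Compute the count fields of TrackGraphInfo for one particle ID."""
    track_nodes = {node for node in graph.nodes if particle_ids[node] == pid}
    # Segments: connected components when only hits of this track are kept
    segments = list(nx.connected_components(graph.subgraph(track_nodes)))
    # Components: connected components of the full graph
    components = nx.connected_components(graph)
    return {
        "n_hits": len(track_nodes),
        "n_segments": len(segments),
        "n_hits_largest_segment": max(len(segment) for segment in segments),
        "n_hits_largest_component": max(
            len(component & track_nodes) for component in components
        ),
    }


# Node 2 belongs to another particle and bridges hits 1 and 3 of particle 7
example_graph = nx.Graph([(0, 1), (1, 2), (2, 3), (4, 5)])
example_pids = [7, 7, 8, 7, 7, 7]
counts = sketch_track_counts(example_graph, example_pids, pid=7)
```

In this example the track has 5 hits in 3 segments ({0, 1}, {3}, {4, 5}), so the largest segment has 2 hits. Because node 2 connects hits 0, 1 and 3 within one component, the largest component holds 3 hits of the track.
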
gnn_tracking.analysis.graphs.get_track_graph_info(graph: networkx.Graph, particle_ids: numpy.ndarray, pid: int) TrackGraphInfo#

Get information about how well connected the hits of a single particle are in the graph.

Parameters:
  • graph – networkx graph of the data

  • particle_ids – The particle IDs of the hits.

  • pid – Particle ID of the true track

Returns:

TrackGraphInfo

gnn_tracking.analysis.graphs.get_track_graph_info_from_data(data: torch_geometric.data.Data, *, w: torch.Tensor | None = None, pt_thld=0.9, threshold: float | None = None, max_eta: float = 4.0) pandas.DataFrame#

Get DataFrame of track graph information for every particle ID in the data. This function basically applies get_track_graph_info to every particle ID.

Parameters:
  • data – Data

  • w – Edge weights. If None, no cut on edge weights is applied

  • pt_thld – pt threshold for particle IDs to consider

  • threshold – Edge classification cutoff (if w is given)

Returns:

DataFrame with columns as in TrackGraphInfo

gnn_tracking.analysis.graphs.summarize_track_graph_info(tgi: pandas.DataFrame) dict[str, float]#

Summarize track graph information returned by get_track_graph_info_from_data.

class gnn_tracking.analysis.graphs.OrphanCount#

Bases: NamedTuple

Stats about the number of orphan nodes in a graph

n_orphan_correct#

Number of orphan nodes that are actually bad nodes (low pt or noise)

n_orphan_incorrect#

Number of orphan nodes that are actually good nodes

n_orphan_total#

Total number of orphan nodes

n_orphan_correct: int#
n_orphan_incorrect: int#
n_orphan_total: int#
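
The three counters can be illustrated with a plain-Python sketch. It rests on an assumption that is not stated on this page: an orphan is taken to be a node that no edge touches. The names OrphanCountSketch and count_orphans are hypothetical, and good_mask stands in for a boolean per-node mask such as the one returned by get_good_node_mask.

```python
from typing import NamedTuple


class OrphanCountSketch(NamedTuple):
    n_orphan_correct: int
    n_orphan_incorrect: int
    n_orphan_total: int


def count_orphans(edges, good_mask):
    """Count nodes without any edge, split by whether they are good nodes."""
    connected = {node for edge in edges for node in edge}
    orphans = [node for node in range(len(good_mask)) if node not in connected]
    # An orphan that is a good node should have been connected: "incorrect"
    n_incorrect = sum(1 for node in orphans if good_mask[node])
    return OrphanCountSketch(
        n_orphan_correct=len(orphans) - n_incorrect,
        n_orphan_incorrect=n_incorrect,
        n_orphan_total=len(orphans),
    )


# Five nodes, a single edge between nodes 0 and 1; nodes 0 and 2 are good
orphan_counts = count_orphans(edges=[(0, 1)], good_mask=[True, False, True, False, False])
```

Nodes 2, 3 and 4 are orphans here. Node 2 is a good node and counts as incorrect, while nodes 3 and 4 are bad nodes and count as correct.
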
gnn_tracking.analysis.graphs.get_orphan_counts(data: torch_geometric.data.Data, *, pt_thld=0.9, max_eta: float = 4.0) OrphanCount#

Count number of orphan nodes in a graph. See OrphanCount for details.

gnn_tracking.analysis.graphs.get_basic_counts(data: torch_geometric.data.Data, *, pt_thld: float = 0.9, max_eta: float = 4.0) dict[str, int]#

Get basic counts of edges and nodes

gnn_tracking.analysis.graphs.get_all_graph_construction_stats(data: torch_geometric.data.Data, pt_thld=0.9, max_eta: float = 4.0) dict[str, float]#

Evaluate graph construction performance for a single batch.

gnn_tracking.analysis.graphs.get_largest_segment_fracs(data: torch_geometric.data.Data, *, pt_thld=0.9, n_particles_sampled=None, max_eta=4) numpy.ndarray#

A fast way to get the fraction of hits in the largest segment for each track.

Parameters:
  • data – Data

  • pt_thld – pt threshold for particle IDs to consider

  • n_particles_sampled – If not None, only consider a subsample of the particles. This speeds up calculation but introduces statistical fluctuations.

  • max_eta – Maximum pseudorapidity

Returns:

Array of fractions.

gnn_tracking.analysis.graphs.get_cc_labels(edge_index: torch.Tensor, num_nodes: int) torch.Tensor#

Get labels for connected components of a graph.
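
What such a labeling looks like can be shown with a small union-find sketch in plain Python. This is not the library's implementation: the real function takes and returns torch tensors, while this sketch uses lists, assumes edge_index has the usual 2 x n_edges layout (row 0 holds edge sources, row 1 edge targets), and picks the smallest node index of each component as its label, which is an arbitrary choice made here.

```python
def cc_labels_sketch(edge_index, num_nodes):
    """Label every node with the smallest node index of its connected component."""
    parent = list(range(num_nodes))

    def find(node):
        # Follow parent links to the root, shortening the path on the way
        while parent[node] != node:
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    for u, v in zip(edge_index[0], edge_index[1]):
        root_u, root_v = find(u), find(v)
        if root_u != root_v:
            # Attach the larger root to the smaller one so the root stays minimal
            parent[max(root_u, root_v)] = min(root_u, root_v)

    return [find(node) for node in range(num_nodes)]


# Edges 0-1, 1-2 and 3-4; node 5 has no edges
labels = cc_labels_sketch([[0, 1, 3], [1, 2, 4]], num_nodes=6)
# labels == [0, 0, 0, 3, 3, 5]
```

Two nodes belong to the same connected component exactly when their labels are equal, so an isolated node such as node 5 gets a label of its own.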