gnn_tracking.analysis.graphs
Classes
TrackGraphInfo | Information about how well connected the hits of a track are in the graph.
OrphanCount | Stats about the number of orphan nodes in a graph.
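Functions
shortest_path_length_catch_no_path(graph, source, target) | Same as nx.shortest_path_length, but returns inf if no path exists.
shortest_path_length_multi(graph, sources, targets) | Shortest path length for any of the sources to reach any of the targets.
get_n_reachable(graph, source, targets) | Get the number of targets that are reachable from source. The source node itself will not be counted.
get_track_graph_info(graph, particle_ids, pid) | Get information about how well connected the hits of a single particle are in the graph.
get_track_graph_info_from_data(data, *, w, pt_thld, threshold, max_eta) | Get a DataFrame of track graph information for every particle ID in the data.
summarize_track_graph_info(tgi) | Summarize track graph information returned by get_track_graph_info_from_data.
get_orphan_counts(data, *, pt_thld, max_eta) | Count the number of orphan nodes in a graph. See OrphanCount for details.
get_basic_counts(data, *, pt_thld, max_eta) | Get basic counts of edges and nodes.
get_all_graph_construction_stats(data, pt_thld, max_eta) | Evaluate graph construction performance for a single batch.
get_largest_segment_fracs(data, *, pt_thld, n_particles_sampled, max_eta) | A fast way to get the fraction of hits in the largest segment for each track.
get_cc_labels(edge_index, num_nodes) | Get labels for connected components of a graph.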
Module Contents
- gnn_tracking.analysis.graphs.shortest_path_length_catch_no_path(graph: networkx.Graph, source, target) → int | float
Same as nx.shortest_path_length, but returns inf if no path exists.
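A minimal sketch of this behavior, for illustration only (not the library's source). It assumes that only `nx.NetworkXNoPath` needs to be caught, so a node missing from the graph would still raise; the name `shortest_path_length_or_inf` is hypothetical.

```python
from __future__ import annotations

import math

import networkx as nx


def shortest_path_length_or_inf(graph: nx.Graph, source, target) -> int | float:
    """Hypothetical stand-in for shortest_path_length_catch_no_path."""
    try:
        return nx.shortest_path_length(graph, source, target)
    except nx.NetworkXNoPath:
        # No path between source and target: report an infinite distance
        return math.inf


g = nx.Graph([(0, 1), (1, 2)])
g.add_node(3)  # isolated node, unreachable from 0
print(shortest_path_length_or_inf(g, 0, 2))  # 2
print(shortest_path_length_or_inf(g, 0, 3))  # inf
```

Returning `math.inf` instead of raising lets callers take minima or compare distances without special-casing disconnected hits.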
- gnn_tracking.analysis.graphs.shortest_path_length_multi(graph: networkx.Graph, sources: Iterable[int], targets: Iterable[int])
Shortest path length for any of the sources to reach any of the targets. If no connection exists, returns inf. If the only target is the source itself, returns 0.
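One way to compute such a multi-source distance is a single Dijkstra run seeded from all sources at once, sketched below. This is an illustrative technique, not necessarily how the library implements it; `multi_source_distance` is a hypothetical name, and the sketch assumes edges carry no `weight` attribute, so NetworkX treats every edge as length 1.

```python
from __future__ import annotations

import math
from collections.abc import Iterable

import networkx as nx


def multi_source_distance(
    graph: nx.Graph, sources: Iterable[int], targets: Iterable[int]
) -> int | float:
    """Hypothetical sketch: shortest path length from any source to any target."""
    # Distance from the nearest source to every reachable node; every source has
    # distance 0, so a target that is also a source yields 0
    dist = nx.multi_source_dijkstra_path_length(graph, list(sources))
    # Unreachable targets are absent from `dist` and therefore count as inf
    return min((dist.get(t, math.inf) for t in targets), default=math.inf)


g = nx.Graph([(0, 1), (1, 2), (2, 3), (4, 5)])
print(multi_source_distance(g, [0, 1], [3]))  # 2
print(multi_source_distance(g, [0], [4, 5]))  # inf
print(multi_source_distance(g, [2], [2]))  # 0
```

With unit edge lengths, this single pass is equivalent to a breadth-first search from a virtual node attached to all sources, which avoids one search per source.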
- gnn_tracking.analysis.graphs.get_n_reachable(graph: networkx.Graph, source: int, targets: Sequence[int]) → int
Get the number of targets that are reachable from source. The source node itself will not be counted!
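A short sketch of the counting rule, assuming reachability is plain graph traversal from source (edge direction respected for directed graphs). `count_reachable` is a hypothetical name, not the library function.

```python
from collections.abc import Sequence

import networkx as nx


def count_reachable(graph: nx.Graph, source: int, targets: Sequence[int]) -> int:
    """Hypothetical sketch: number of targets reachable from source, excluding source."""
    # nx.descendants returns every node reachable from source, never source itself
    reachable = nx.descendants(graph, source)
    return sum(1 for t in targets if t in reachable)


g = nx.Graph([(0, 1), (1, 2), (3, 4)])
print(count_reachable(g, 0, [0, 1, 2, 3]))  # 2: the source 0 and the unreachable 3 are not counted
```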
- class gnn_tracking.analysis.graphs.TrackGraphInfo
Bases: NamedTuple
Information about how well connected the hits of a track are in the graph.
Here, “component” means a connected component of the graph, and “segment” means a connected component of the subgraph that contains only the hits of the track with the given particle ID.
- pid: int
The particle ID of the track.
- n_hits: int
The number of hits in the track.
- n_segments: int
The number of segments of the track.
- n_hits_largest_segment: int
The number of hits in the largest segment of the track.
- distance_largest_segments: int
The shortest path length between the two largest segments.
- n_hits_largest_component: int
The number of the track's hits in the track's largest component.
- gnn_tracking.analysis.graphs.get_track_graph_info(graph: networkx.Graph, particle_ids: numpy.ndarray, pid: int) → TrackGraphInfo
Get information about how well connected the hits of a single particle are in the graph.
- Parameters:
graph – NetworkX graph of the data.
particle_ids – The particle IDs of the hits.
pid – Particle ID of the true track.
- Returns:
TrackGraphInfo
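The quantities described by TrackGraphInfo can be sketched with NetworkX on a toy graph in which hits 0 and 2 of particle 1 are linked only through hit 3 of another particle: they share one component but fall into two segments. This is a hypothetical sketch, not the library's implementation; it assumes an undirected graph, distances measured in the full graph, a distance of 0 when there are fewer than two segments, and reads "largest component" as the component holding the most hits of the track.

```python
from __future__ import annotations

import math

import networkx as nx
import numpy as np


def track_graph_info_sketch(graph: nx.Graph, particle_ids: np.ndarray, pid: int) -> dict:
    """Hypothetical sketch of the TrackGraphInfo fields for one particle."""
    hits = np.flatnonzero(particle_ids == pid).tolist()
    # Segments: connected components of the subgraph induced by this track's hits
    segments = sorted(nx.connected_components(graph.subgraph(hits)), key=len, reverse=True)
    if len(segments) >= 2:
        # Shortest path in the full graph between the two largest segments
        dist = nx.multi_source_dijkstra_path_length(graph, segments[0])
        distance = min((dist.get(n, math.inf) for n in segments[1]), default=math.inf)
    else:
        distance = 0  # assumption: nothing to connect
    # Components of the full graph, scored by how many of this track's hits they hold
    hit_set = set(hits)
    n_in_components = [len(c & hit_set) for c in nx.connected_components(graph)]
    return {
        "pid": pid,
        "n_hits": len(hits),
        "n_segments": len(segments),
        "n_hits_largest_segment": len(segments[0]) if segments else 0,
        "distance_largest_segments": distance,
        "n_hits_largest_component": max(n_in_components, default=0),
    }


# Toy event: hits 0-2 belong to particle 1, hits 3-4 to particle 2
particle_ids = np.array([1, 1, 1, 2, 2])
graph = nx.Graph()
graph.add_nodes_from(range(len(particle_ids)))
graph.add_edges_from([(0, 3), (3, 2), (1, 2)])

info = track_graph_info_sketch(graph, particle_ids, 1)
print(info["n_segments"], info["n_hits_largest_component"])  # 2 3
```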
- gnn_tracking.analysis.graphs.get_track_graph_info_from_data(data: torch_geometric.data.Data, *, w: torch.Tensor | None = None, pt_thld=0.9, threshold: float | None = None, max_eta: float = 4.0) → pandas.DataFrame
Get DataFrame of track graph information for every particle ID in the data. This function basically applies get_track_graph_info to every particle ID.
- Parameters:
data – torch_geometric Data object.
w – Edge weights. If None, no cut on edge weights is applied.
pt_thld – pT threshold for the particle IDs to consider.
threshold – Edge classification cutoff (only used if w is given).
max_eta – Maximum pseudorapidity.
- Returns:
DataFrame with columns as in TrackGraphInfo
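Before the per-particle step, the edges have to be turned into a graph, optionally keeping only edges that pass the classifier cut. The sketch below works on NumPy arrays rather than PyG tensors (a torch tensor could be converted with `.cpu().numpy()`); `build_graph` is a hypothetical helper, and keeping edges with weight strictly above threshold is an assumption about the cut direction.

```python
from __future__ import annotations

import networkx as nx
import numpy as np


def build_graph(
    edge_index: np.ndarray,
    num_nodes: int,
    w: np.ndarray | None = None,
    threshold: float | None = None,
) -> nx.Graph:
    """Hypothetical helper: build an undirected graph, optionally cutting on edge weights."""
    if w is not None:
        if threshold is None:
            raise ValueError("threshold is required when w is given")
        # Assumption: keep only edges whose weight exceeds the classification cutoff
        edge_index = edge_index[:, w > threshold]
    graph = nx.Graph()
    # Add every hit as a node so that hits without edges still appear in the graph
    graph.add_nodes_from(range(num_nodes))
    graph.add_edges_from(edge_index.T.tolist())
    return graph


edge_index = np.array([[0, 1, 2], [1, 2, 3]])
g = build_graph(edge_index, 5, w=np.array([0.9, 0.2, 0.8]), threshold=0.5)
print(sorted(g.edges()))  # [(0, 1), (2, 3)]
```

Each particle ID that passes the pT cut could then be evaluated on this graph with get_track_graph_info, one row per particle.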
- gnn_tracking.analysis.graphs.summarize_track_graph_info(tgi: pandas.DataFrame) → dict[str, float]
Summarize track graph information returned by get_track_graph_info_from_data.
- class gnn_tracking.analysis.graphs.OrphanCount
Bases: NamedTuple
Stats about the number of orphan nodes (nodes without any edges) in a graph.
- n_orphan_correct: int
Number of orphan nodes that are actually bad nodes (low pT or noise).
- n_orphan_incorrect: int
Number of orphan nodes that are actually good nodes.
- n_orphan_total: int
Total number of orphan nodes.
- gnn_tracking.analysis.graphs.get_orphan_counts(data: torch_geometric.data.Data, *, pt_thld=0.9, max_eta: float = 4.0) → OrphanCount
Count the number of orphan nodes in a graph. See OrphanCount for details.
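The three OrphanCount fields can be sketched with NumPy, treating a node as an orphan when it appears in no edge. This is an illustration rather than the library's code: `count_orphans` is hypothetical, and the boolean `is_good` mask (hits from particles that pass the pT and pseudorapidity cuts) is an assumed input, since the real function derives it from data, pt_thld and max_eta.

```python
from __future__ import annotations

import numpy as np


def count_orphans(edge_index: np.ndarray, num_nodes: int, is_good: np.ndarray) -> dict[str, int]:
    """Hypothetical sketch of the OrphanCount fields; is_good is an assumed input."""
    # Mark every node that is an endpoint of at least one edge
    has_edge = np.zeros(num_nodes, dtype=bool)
    has_edge[edge_index.ravel()] = True
    orphan = ~has_edge
    return {
        # Orphaned bad hits (low pT or noise): leaving them unconnected is fine
        "n_orphan_correct": int((orphan & ~is_good).sum()),
        # Orphaned good hits: these were lost by graph construction
        "n_orphan_incorrect": int((orphan & is_good).sum()),
        "n_orphan_total": int(orphan.sum()),
    }


edge_index = np.array([[0, 1], [1, 2]])
is_good = np.array([True, True, True, True, False])
print(count_orphans(edge_index, 5, is_good))
```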
- gnn_tracking.analysis.graphs.get_basic_counts(data: torch_geometric.data.Data, *, pt_thld: float = 0.9, max_eta: float = 4.0) → dict[str, int]
Get basic counts of edges and nodes.
- gnn_tracking.analysis.graphs.get_all_graph_construction_stats(data: torch_geometric.data.Data, pt_thld=0.9, max_eta: float = 4.0) → dict[str, float]
Evaluate graph construction performance for a single batch.
- gnn_tracking.analysis.graphs.get_largest_segment_fracs(data: torch_geometric.data.Data, *, pt_thld=0.9, n_particles_sampled=None, max_eta=4) → numpy.ndarray
A fast way to get the fraction of hits in the largest segment for each track.
- Parameters:
data – torch_geometric Data object.
pt_thld – pT threshold for the particle IDs to consider.
n_particles_sampled – If not None, only consider a subsample of the particles. This speeds up calculation but introduces statistical fluctuations.
max_eta – Maximum pseudorapidity
- Returns:
Array of fractions.
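A sketch of how such fractions could be computed with SciPy's sparse connected components, including the optional subsampling of particles. It is not the library's implementation: `largest_segment_fracs` and `seed` are hypothetical, the pT and pseudorapidity cuts are assumed to have been applied already when choosing `pids`, and every entry of `pids` is assumed to have at least one hit.

```python
from __future__ import annotations

import numpy as np
import scipy.sparse as sp
from scipy.sparse.csgraph import connected_components


def largest_segment_fracs(
    edge_index: np.ndarray,
    particle_ids: np.ndarray,
    pids: np.ndarray,
    n_particles_sampled: int | None = None,
    seed: int = 0,
) -> np.ndarray:
    """Hypothetical sketch: fraction of each track's hits in its largest segment."""
    if n_particles_sampled is not None and n_particles_sampled < len(pids):
        # Subsampling trades statistical precision for speed
        rng = np.random.default_rng(seed)
        pids = rng.choice(pids, size=n_particles_sampled, replace=False)
    fracs = []
    for pid in pids:
        hits = np.flatnonzero(particle_ids == pid)  # sorted hit indices
        # Keep only edges whose endpoints both belong to this track
        mask = np.isin(edge_index[0], hits) & np.isin(edge_index[1], hits)
        # Relabel hit indices to 0..len(hits)-1 for a small adjacency matrix
        local = np.searchsorted(hits, edge_index[:, mask])
        adj = sp.coo_matrix(
            (np.ones(local.shape[1]), (local[0], local[1])), shape=(len(hits), len(hits))
        )
        _, labels = connected_components(adj, directed=False)
        # Size of the largest segment divided by the number of hits of the track
        fracs.append(np.bincount(labels).max() / len(hits))
    return np.array(fracs)


particle_ids = np.array([1, 1, 1, 2, 2])
edge_index = np.array([[0, 3, 1, 3], [3, 2, 2, 4]])
print(largest_segment_fracs(edge_index, particle_ids, np.array([1, 2])))  # approximately [0.667, 1.0]
```

Building one small adjacency matrix per track keeps the work proportional to that track's hits and edges, which is where skipping particles via n_particles_sampled pays off.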
- gnn_tracking.analysis.graphs.get_cc_labels(edge_index: torch.Tensor, num_nodes: int) → torch.Tensor
Get labels for connected components of a graph.
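A NumPy/SciPy sketch of connected-component labeling from an edge index. The library function takes and returns torch tensors; this hypothetical `cc_labels` uses NumPy arrays instead and ignores edge direction, which is an assumption about how components are defined there.

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.csgraph import connected_components


def cc_labels(edge_index: np.ndarray, num_nodes: int) -> np.ndarray:
    """Hypothetical sketch: one integer component label per node."""
    # Sparse adjacency matrix with one entry per edge (rows: sources, columns: targets)
    adj = sp.coo_matrix(
        (np.ones(edge_index.shape[1]), (edge_index[0], edge_index[1])),
        shape=(num_nodes, num_nodes),
    )
    # directed=False ignores edge direction, so nodes linked either way share a label
    _, labels = connected_components(adj, directed=False)
    return labels


labels = cc_labels(np.array([[0, 2], [1, 3]]), 5)
print(labels)  # nodes 0-1 share a label, nodes 2-3 share another, node 4 is alone
```

Isolated nodes still receive their own label because the matrix shape is fixed by num_nodes rather than inferred from the edges.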