gnn_tracking.graph_construction.k_scanner#

Finds the right k for k-NN to optimize the graph construction.

Attributes#

logger

_DEFAULT_KS

Classes#

KScanResults

This object holds the results of scanning over ks. It performs interpolation to get the figures of merit (FOMs).

GraphConstructionKNNScanner

Scan over different values of k to build a graph and calculate the figures of merit.

Functions#

get_cc_labels(→ torch.Tensor)

Get labels for connected components of a graph.

get_largest_segment_fracs(→ numpy.ndarray)

A fast way to get the fraction of hits in the largest segment for each track.

flatten_track_metrics(→ dict[str, float])

Flatten the result of custom_metrics by using pt suffixes to arrive at a flat dictionary, rather than a nested one.

tracking_metrics_data(→ dict[float, TrackingMetrics])

Convenience function to apply tracking_metrics to a Data object.

get_efficiency_purity_edges(→ dict[str, float])

Calculate efficiency and purity for edges based on data.true_edge_index.

knn_with_max_radius(→ torch.Tensor)

A version of kNN that excludes edges with a distance larger than a given radius.

add_key_prefix(→ dict[str, _P])

Return a copy of the dictionary with the prefix added to all keys.

pivot_record_list(→ dict)

Transform list of key value pairs into dict of lists.

Module Contents#

gnn_tracking.graph_construction.k_scanner.get_cc_labels(edge_index: torch.Tensor, num_nodes: int) torch.Tensor#

Get labels for connected components of a graph.
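
To illustrate what a connected-component labelling looks like, here is a minimal pure-Python sketch based on union-find. It is not the library's implementation: the function name `cc_labels` and the plain list-of-pairs input are assumptions made for this example, whereas the real function takes and returns torch tensors.

```python
def cc_labels(edges: list[tuple[int, int]], num_nodes: int) -> list[int]:
    """Label every node with the index of its connected component."""
    parent = list(range(num_nodes))

    def find(i: int) -> int:
        # Walk up to the root, shortening the path along the way.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    # Merge the components of the two endpoints of every edge.
    for a, b in edges:
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    # Relabel roots as consecutive integers in order of first appearance.
    root_to_label: dict[int, int] = {}
    return [
        root_to_label.setdefault(find(i), len(root_to_label))
        for i in range(num_nodes)
    ]


labels = cc_labels([(0, 1), (1, 2), (3, 4)], num_nodes=6)
```

Nodes without any edge (node 5 here) end up in a component of their own.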

gnn_tracking.graph_construction.k_scanner.get_largest_segment_fracs(data: torch_geometric.data.Data, *, pt_thld=0.9, n_particles_sampled=None, max_eta=4) numpy.ndarray#

A fast way to get the fraction of hits in the largest segment for each track.

Parameters:
  • data

  • pt_thld

  • n_particles_sampled – If not None, only consider a subsample of the particles. This speeds up calculation but introduces statistical fluctuations.

  • max_eta – Maximum pseudorapidity

Returns:

Array of fractions.
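
The quantity itself can be sketched in a few lines of plain Python. This is only an illustration under simplifying assumptions: the name `largest_segment_fracs` and its two list arguments are hypothetical, and the pt, eta and subsampling selections of the real function are left out.

```python
from collections import Counter, defaultdict


def largest_segment_fracs(particle_ids, segment_labels):
    """For each particle, return the fraction of its hits in its largest segment."""
    # Count, per particle, how many of its hits fall into each segment.
    hits_per_track = defaultdict(Counter)
    for pid, segment in zip(particle_ids, segment_labels):
        hits_per_track[pid][segment] += 1
    return {
        pid: max(counts.values()) / sum(counts.values())
        for pid, counts in hits_per_track.items()
    }


# Particle 1 has 4 hits split 3/1 over two segments; particle 2 is in one piece.
fracs = largest_segment_fracs([1, 1, 1, 1, 2, 2], [0, 0, 0, 1, 2, 2])
```

Here the segment labels could be, for example, connected-component labels of the constructed graph.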

gnn_tracking.graph_construction.k_scanner.flatten_track_metrics(custom_metrics_result: dict[float, dict[str, float]]) dict[str, float]#

Flatten the result of custom_metrics by using pt suffixes to arrive at a flat dictionary, rather than a nested one.
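
The flattening idea can be sketched as follows. The suffix format `_pt0.9`, the function name `flatten_by_pt` and the metric name `efficiency` are assumptions for illustration only; the keys actually produced by the library may look different.

```python
def flatten_by_pt(nested: dict[float, dict[str, float]]) -> dict[str, float]:
    """Turn {pt: {metric: value}} into {metric_pt<pt>: value}."""
    flat = {}
    for pt, metrics in nested.items():
        # Hypothetical suffix format, chosen for this sketch only.
        suffix = f"_pt{pt:.1f}"
        for name, value in metrics.items():
            flat[f"{name}{suffix}"] = value
    return flat


flat = flatten_by_pt({0.9: {"efficiency": 0.8}, 1.5: {"efficiency": 0.9}})
```

A flat dictionary like this is convenient for loggers that expect scalar values under string keys.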

gnn_tracking.graph_construction.k_scanner.tracking_metrics_data(data: torch_geometric.data.Data, labels, pt_thlds: Iterable[float], predicted_count_thld=3, max_eta=4) dict[float, TrackingMetrics]#

Convenience function to apply tracking_metrics to a Data object.

Parameters:
  • data – Data object

  • labels – Predicted labels/cluster index for each hit. Negative labels are treated as noise

  • pt_thlds – pt thresholds to calculate the metrics for

  • predicted_count_thld – Minimal number of hits in a cluster for it to not be rejected.

  • max_eta – Maximum eta value to count

gnn_tracking.graph_construction.k_scanner.get_efficiency_purity_edges(data: torch_geometric.data.Data, pt_thld: float = 0.9, max_eta: float = 4.0) dict[str, float]#

Calculate efficiency and purity for edges based on data.true_edge_index.

Only edges where at least one of the two nodes is accepted by the pt threshold (and reconstructable etc.) are considered.
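
As a rough illustration of the two quantities, the sketch below compares a set of constructed edges against a set of true edges. It assumes undirected edges and skips the pt, eta and reconstructability selection described above; the function name `edge_efficiency_purity` is hypothetical.

```python
def edge_efficiency_purity(built_edges, true_edges):
    """Efficiency: share of true edges that were built.
    Purity: share of built edges that are true."""
    # frozenset makes (a, b) and (b, a) compare equal (undirected assumption).
    built = {frozenset(edge) for edge in built_edges}
    true = {frozenset(edge) for edge in true_edges}
    n_correct = len(built & true)
    return {
        "efficiency": n_correct / len(true) if true else float("nan"),
        "purity": n_correct / len(built) if built else float("nan"),
    }


result = edge_efficiency_purity(
    built_edges=[(0, 1), (1, 2), (2, 3), (0, 3)],
    true_edges=[(1, 0), (1, 2), (4, 5)],
)
```

In this toy case two of the three true edges are found, and two of the four built edges are true.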

gnn_tracking.graph_construction.k_scanner.knn_with_max_radius(x: torch.Tensor, k: int, max_radius: float | None = None) torch.Tensor#

A version of kNN that excludes edges with a distance larger than a given radius.

Parameters:
  • x

  • k – Number of neighbors

  • max_radius

Returns:

edge index
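
The behaviour can be pictured with a brute-force sketch in plain Python. This is not how the library computes it: the name `knn_max_radius`, the list-based input and the source-to-target orientation of the returned edge index are all assumptions for this example.

```python
import math


def knn_max_radius(points, k, max_radius=None):
    """Brute-force kNN edges, dropping neighbors farther away than max_radius.

    Returns [sources, targets], where each target node receives edges
    from its (at most) k nearest neighbors.
    """
    sources, targets = [], []
    for i, point in enumerate(points):
        # Distances from point i to every other point, nearest first.
        dists = sorted(
            (math.dist(point, other), j)
            for j, other in enumerate(points)
            if j != i
        )
        for dist, j in dists[:k]:
            if max_radius is None or dist <= max_radius:
                sources.append(j)
                targets.append(i)
    return [sources, targets]


points = [(0.0, 0.0), (1.0, 0.0), (5.0, 0.0)]
unrestricted = knn_max_radius(points, k=1)
restricted = knn_max_radius(points, k=1, max_radius=2.0)
```

With the radius cut, the isolated point at x = 5 no longer gets an edge, so nodes can end up with fewer than k neighbors.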

gnn_tracking.graph_construction.k_scanner.add_key_prefix(dct: dict[str, _P], prefix: str = '') dict[str, _P]#

Return a copy of the dictionary with the prefix added to all keys.
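
The described behaviour amounts to a single dict comprehension. The sketch below uses the hypothetical name `with_prefix` and is not copied from the library.

```python
def with_prefix(dct: dict, prefix: str = "") -> dict:
    """Return a new dict whose keys are prefixed; the input is left untouched."""
    return {f"{prefix}{key}": value for key, value in dct.items()}


original = {"efficiency": 0.9, "purity": 0.5}
prefixed = with_prefix(original, "val_")
```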

gnn_tracking.graph_construction.k_scanner.pivot_record_list(records: list[dict]) dict#

Transform list of key value pairs into dict of lists.
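
A minimal sketch of such a pivot is shown below. The name `pivot` is hypothetical, and how the library treats records with differing keys is not stated here, so this version simply appends whatever keys each record has.

```python
def pivot(records: list[dict]) -> dict:
    """Turn [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] into {'a': [1, 3], 'b': [2, 4]}."""
    pivoted: dict = {}
    for record in records:
        for key, value in record.items():
            pivoted.setdefault(key, []).append(value)
    return pivoted


columns = pivot([{"k": 1, "n_edges": 10}, {"k": 2, "n_edges": 20}])
```

The dict-of-lists form can be passed directly to a pandas.DataFrame constructor.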

gnn_tracking.graph_construction.k_scanner.logger#
class gnn_tracking.graph_construction.k_scanner.KScanResults(results: pandas.DataFrame, targets: Sequence[float])#

This object holds the results of scanning over ks. It performs interpolation to get the figures of merit (FOMs).

Parameters:
  • results – The results of the scan: (k, n_edges, frac50, …)

  • targets – The target 50%-segment fractions that we’re interested in

_extra_metrics = ('k', 'frac75', 'frac100', 'efficiency', 'purity')#
get_foms() dict[str, float]#

Get figures of merit

plot() matplotlib.pyplot.Axes#

Plot interpolation

_spline() tuple[scipy.interpolate.CubicSpline, list[str], list[str]]#

Spline object. Do not use this object directly but rather only via _eval_spline.

Returns:

Spline object, list of columns that are nan, list of columns that are not nan

_eval_spline(k: float) dict[str, float]#

Get figures of merit at k, using a spline for evaluation between data points.

_get_target_k(target: float) float#

The k at which the 50%-segment fraction equals target.
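
The idea of inverting the scan results can be sketched with piecewise-linear interpolation. Note the swap: the class itself works with a scipy CubicSpline, while this example uses linear segments to stay dependency-free. The name `target_k` and the assumption that the 50%-segment fraction grows with k are made for this illustration only.

```python
def target_k(ks, frac50s, target):
    """Find the k at which the interpolated frac50 curve reaches target.

    ks and frac50s are the scanned points, sorted by k.
    Returns NaN if the target lies outside the scanned range.
    """
    scanned = list(zip(ks, frac50s))
    for (k0, f0), (k1, f1) in zip(scanned, scanned[1:]):
        # Look for the segment that brackets the target, then solve for k.
        if f0 <= target <= f1 and f1 > f0:
            return k0 + (target - f0) * (k1 - k0) / (f1 - f0)
    return float("nan")


k_at_85 = target_k([2, 4, 8], [0.5, 0.8, 0.9], target=0.85)
k_unreachable = target_k([2, 4, 8], [0.5, 0.8, 0.9], target=0.99)
```

The resulting k is generally not an integer; it is only used to read off interpolated figures of merit.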

_get_foms_at_target(target: float) dict[str, float]#

Get figures of merit at k given by self._get_target_k

gnn_tracking.graph_construction.k_scanner._DEFAULT_KS#
class gnn_tracking.graph_construction.k_scanner.GraphConstructionKNNScanner(ks: list[int] = _DEFAULT_KS, *, targets=(0.8, 0.85, 0.88, 0.9, 0.93, 0.95, 0.97, 0.99), max_radius=1.0, pt_thld=0.9, max_eta=4.0, subsample_pids: int | None = None, max_edges=5000000)#

Bases: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Scan over different values of k to build a graph and calculate the figures of merit.

Parameters:
  • ks – List of ks to scan. Results will be interpolated between these values, so it’s a good idea to not make them too dense.

  • targets – Target values for the 50%-segment fraction. For each target, we find the k that gets us closest to it and report the number of edges at that k. Does not impact compute time.

  • max_radius – Maximum length of edges for the kNN graph.

  • pt_thld – pt threshold for evaluation of the 50%-segment fraction.

  • subsample_pids – If set, evaluate the 50%-segment fraction on only this many particle IDs (pids). This speeds up the evaluation but makes the result less accurate because of statistical fluctuations.

  • max_edges – Do not attempt to compute metrics if the kNN graph has more than this number of edges.

property results_raw: pandas.DataFrame#

DataFrame with raw results for all graphs and all k

get_results() KScanResults#

Get results object

get_foms() dict[str, float]#

Get figures of merit (convenience method that uses the appropriate method of KScanResults).

reset()#

Reset the results. Will be automatically called every time we run on a batch with i_batch == 0.

__call__(data: torch_geometric.data.Data, i_batch: int, *, progress=False, latent: torch.Tensor | None = None) None#

Run on graph

Parameters:
  • data – Data object. data.x is the space used for clustering

  • i_batch – Batch number. Will reset saved data for i_batch == 0.

  • progress – Show progress bar

  • latent – Use this instead of data.x

Returns:

None

_evaluate_tracking_metrics_upper_bounds(data: torch_geometric.data.Data) dict[str, float]#

Evaluate upper bounds of tracking metrics assuming a pipeline with perfect EC. See https://arxiv.org/abs/2309.16754

_evaluate_graph(data: torch_geometric.data.Data, k: int) dict[str, float] | None#

Evaluate metrics for a single graph and a single value of k.

Parameters:
  • data

  • k

Returns:

None if computation was aborted