Finds the right k for k-NN to optimize the graph construction.

Module Contents



KScanResults – This object holds the results of scanning over ks.

GraphConstructionKNNScanner – Scan over different values of k to build a graph and calculate the figures of merit.



class gnn_tracking.graph_construction.k_scanner.KScanResults(results: pandas.DataFrame, targets: Sequence[float])

This object holds the results of scanning over ks. It performs interpolation to get the figures of merit (FOMs).

Parameters:

  • results – The results of the scan: (k, n_edges, frac50, …)

  • targets – The target 50%-segment fractions that we’re interested in.

_extra_metrics = ('k', 'frac75', 'frac100', 'efficiency', 'purity')
get_foms() → dict[str, float]
plot() → matplotlib.pyplot.Axes

Plot the interpolation.

_eval_spline(k: float) → dict[str, float]
_get_target_k(target: float) → float

Return the k at which the 50%-segment fraction equals target.
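Finding this k amounts to inverting the 50%-segment fraction as a function of k. A minimal sketch, assuming frac50 is non-decreasing in k and swapping in piecewise-linear interpolation for the spline that the class evaluates via _eval_spline; the function name target_k is hypothetical:

```python
from bisect import bisect_left
from typing import Sequence


def target_k(ks: Sequence[float], frac50: Sequence[float], target: float) -> float:
    """Return the (interpolated) k at which frac50 reaches ``target``.

    Hypothetical helper. Assumes ``ks`` is sorted ascending and ``frac50``
    is non-decreasing in k; uses linear interpolation, not a spline.
    """
    if len(ks) != len(frac50) or len(ks) < 2:
        raise ValueError("need at least two matching (k, frac50) points")
    if not frac50[0] <= target <= frac50[-1]:
        raise ValueError("target lies outside the scanned range")
    # First scan point whose frac50 is >= target
    i = bisect_left(frac50, target)
    if i == 0:
        return float(ks[0])
    # Here frac50[i - 1] < target <= frac50[i], so the denominator is > 0
    k0, k1 = ks[i - 1], ks[i]
    f0, f1 = frac50[i - 1], frac50[i]
    return k0 + (target - f0) * (k1 - k0) / (f1 - f0)
```

For instance, target_k([2, 4, 8, 16], [0.5, 0.8, 0.9, 0.95], 0.85) gives approximately 6.0, halfway between the scan points k = 4 and k = 8.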

_get_foms_at_target(target: float) → dict[str, float]
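Once the target k is known, the other scanned quantities can be read off at that k. The sketch below illustrates the idea under stated assumptions: linear interpolation instead of a spline, a plain dict of columns instead of the results DataFrame, and made-up key names of the form "<metric>_frac50_<target>". The helpers interp and foms_at_target are hypothetical and not part of gnn_tracking.

```python
from bisect import bisect_left
from typing import Mapping, Sequence


def interp(xs: Sequence[float], ys: Sequence[float], x: float) -> float:
    """Piecewise-linear interpolation of ys(xs) at x (xs sorted ascending)."""
    if not xs[0] <= x <= xs[-1]:
        raise ValueError("x outside the scanned range")
    i = bisect_left(xs, x)
    if i == 0:
        return float(ys[0])
    x0, x1 = xs[i - 1], xs[i]
    return ys[i - 1] + (x - x0) * (ys[i] - ys[i - 1]) / (x1 - x0)


def foms_at_target(
    scan: Mapping[str, Sequence[float]], target: float
) -> dict[str, float]:
    """Interpolate every scanned column at the k where frac50 == target.

    ``scan`` maps column names ("k", "frac50", "n_edges", ...) to one value
    per scanned k. Assumes frac50 is non-decreasing in k. Key names are
    hypothetical.
    """
    # Invert frac50(k): treat k as a function of frac50
    k = interp(scan["frac50"], scan["k"], target)
    return {
        f"{name}_frac50_{target}": interp(scan["k"], values, k)
        for name, values in scan.items()
        if name != "frac50"
    }
```

Reporting k itself alongside n_edges mirrors the fact that 'k' appears in _extra_metrics, though the exact output format of get_foms is not shown here.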
class gnn_tracking.graph_construction.k_scanner.GraphConstructionKNNScanner(ks: list[int] = _DEFAULT_KS, *, targets=(0.8, 0.85, 0.88, 0.9, 0.93, 0.95, 0.97, 0.99), max_radius=1.0, pt_thld=0.9, max_eta=4.0, subsample_pids: int | None = None, max_edges=5000000)

Bases: pytorch_lightning.core.mixins.HyperparametersMixin

Scan over different values of k to build a graph and calculate the figures of merit.

Parameters:

  • ks – List of ks to scan. Results are interpolated between these values, so it’s a good idea not to make them too dense.

  • targets – Target values for the 50%-segment fraction. For each target, the scanner finds the k that comes closest to it and reports the number of edges at that k. Does not affect compute time.

  • max_radius – Maximum length of edges in the k-NN graph.

  • pt_thld – pt threshold for the evaluation of the 50%-segment fraction.

  • subsample_pids – Set to an integer to subsample the pids used when evaluating the 50%-segment fraction. This speeds up the evaluation, but makes the result less accurate (statistical fluctuations).

  • max_edges – Do not attempt to compute metrics for k-NN graphs with more than this number of edges.
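To make the roles of k and max_radius concrete, here is a brute-force sketch of a radius-limited k-NN graph in pure Python. It is not the scanner's implementation, which this page does not show; knn_edges is a hypothetical helper, and the edge direction (from a point to its neighbour) is an arbitrary choice.

```python
import math
from typing import Sequence


def knn_edges(
    points: Sequence[Sequence[float]], k: int, max_radius: float
) -> list[tuple[int, int]]:
    """Directed edges (i, j) from each point i to its k nearest neighbours j.

    Neighbours farther than ``max_radius`` are dropped, so a point can end
    up with fewer than k edges. Brute force, O(n^2 log n): illustration only.
    """
    edges = []
    for i, p in enumerate(points):
        # (distance, index) pairs to every other point, nearest first;
        # ties are broken deterministically by index
        dists = sorted(
            (math.dist(p, q), j) for j, q in enumerate(points) if j != i
        )
        for d, j in dists[:k]:
            if d <= max_radius:
                edges.append((i, j))
    return edges
```

In this sketch the edge set for k is a subset of the edge set for k + 1, so the edge count never decreases with k; a cap such as max_edges therefore rules out all k beyond the first one that exceeds it.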

property results_raw: pandas.DataFrame

DataFrame with the raw results for all graphs and all values of k.

get_results() → KScanResults

Get the results object.

get_foms() → dict[str, float]

Get the figures of merit (convenience method that calls the corresponding method of KScanResults).


Reset the results. This is called automatically every time the scanner runs on a batch with i_batch == 0.

__call__(data:, i_batch: int, *, progress=False, latent: torch.Tensor | None = None) → None

Run the scan on a graph.

Parameters:

  • data – Data object. data.x is the space used for clustering.

  • i_batch – Batch number. Saved data is reset when i_batch == 0.

  • progress – Show a progress bar.

  • latent – If given, use this tensor instead of data.x as the space used for clustering.
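The reset-on-batch-0 bookkeeping described above can be sketched with a toy class that accumulates one row of metrics per (batch, k) pair. Everything here is hypothetical: ToyKScanner, its reset method, and the evaluate callback (standing in for _evaluate_graph) are illustrations only, and skipping an aborted k rather than stopping the loop is an assumption.

```python
from typing import Callable, Optional, Sequence


class ToyKScanner:
    """Hypothetical stand-in for the scanner's result bookkeeping."""

    def __init__(
        self,
        ks: Sequence[int],
        evaluate: Callable[[object, int], Optional[dict]],
    ) -> None:
        self.ks = list(ks)
        # evaluate(data, k) returns a dict of metrics, or None if aborted
        self.evaluate = evaluate
        self.rows: list[dict] = []

    def reset(self) -> None:
        """Forget all accumulated results."""
        self.rows = []

    def __call__(self, data: object, i_batch: int) -> None:
        # Batch 0 starts a fresh pass over the data: drop old results
        if i_batch == 0:
            self.reset()
        for k in self.ks:
            metrics = self.evaluate(data, k)
            if metrics is None:
                # Aborted (assumption: just skip this k)
                continue
            self.rows.append({"i_batch": i_batch, "k": k, **metrics})
```

Calling the object with i_batch == 0 starts a new table, while later batches append to it; a per-(batch, k) table like this is one plausible shape for the raw results behind results_raw.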



_evaluate_graph(data:, k: int) → dict[str, float] | None

Evaluate metrics for a single graph.

Parameters:

  • data

  • k

Returns:

The metrics, or None if the computation was aborted.
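A sketch of the "None if aborted" contract, combined with two of the metrics listed in _extra_metrics. The metric definitions are assumptions: hit pairs are treated as unordered, efficiency is the fraction of same-particle hit pairs that appear as edges, and purity is the fraction of edges that join hits of the same particle. The segment fractions (frac50, frac75, frac100) and the pt_thld cut are not reproduced, and evaluate_edges is a hypothetical helper.

```python
from collections import Counter
from typing import Optional, Sequence


def evaluate_edges(
    edges: Sequence[tuple[int, int]], pids: Sequence[int], max_edges: int
) -> Optional[dict[str, float]]:
    """Edge-level metrics for one graph, or None if it has too many edges.

    Hypothetical definitions; ``pids[i]`` is the particle id of hit i.
    """
    if len(edges) > max_edges:
        return None  # abort: too many edges to evaluate
    # Deduplicate (i, j) and (j, i) into one unordered pair
    pairs = {tuple(sorted(e)) for e in edges}
    n_true = sum(pids[i] == pids[j] for i, j in pairs)
    # All unordered same-particle pairs: n * (n - 1) / 2 per particle
    n_possible = sum(n * (n - 1) // 2 for n in Counter(pids).values())
    return {
        "n_edges": float(len(edges)),
        "efficiency": n_true / n_possible if n_possible else 0.0,
        "purity": n_true / len(pairs) if pairs else 0.0,
    }
```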