gnn_tracking.training.ml#

Pytorch lightning module with training and validation step for the metric learning approach to graph construction.

Classes#

GraphConstructionKNNScanner

Scan over different values of k to build a graph and calculate the figures of merit.

MultiLossFct

Base class for loss functions that return multiple losses.

GraphConstructionHingeEmbeddingLoss

Loss for graph construction using metric learning.

TrackingModule

Base class for all pytorch lightning modules in this project.

MLModule

Pytorch lightning module with training and validation step for the metric learning approach to graph construction.

Functions#

add_key_suffix(→ dict[str, _P])

Return a copy of the dictionary with the suffix added to all keys.

to_floats(→ Any)

Convert all tensors in a datastructure to floats.

obj_from_or_to_hparams(→ Any)

Used to support initializing Python objects from hyperparameters.

tolerate_some_oom_errors(fct)

Decorator to tolerate a couple of out-of-memory (OOM) errors.

Module Contents#

class gnn_tracking.training.ml.GraphConstructionKNNScanner(ks: list[int] = _DEFAULT_KS, *, targets=(0.8, 0.85, 0.88, 0.9, 0.93, 0.95, 0.97, 0.99), max_radius=1.0, pt_thld=0.9, max_eta=4.0, subsample_pids: int | None = None, max_edges=5000000)#

Bases: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Scan over different values of k to build a graph and calculate the figures of merit.

Parameters:
  • ks – List of ks to scan. Results will be interpolated between these values, so it’s a good idea to not make them too dense.

  • targets – Targets for the 50%-segment fraction that we aim for (we will find the k that gets us closest to these targets and report the number of edges for these places). Does not impact compute time.

  • max_radius – Maximum length of edges for the KNN graph.

  • pt_thld – pt threshold for evaluation of the 50%-segment fraction.

  • subsample_pids – Set to a number to subsample the number of pids in the evaluation of the 50%-segment fraction. This speeds up the evaluation, but leads to less accurate results due to statistical fluctuations.

  • max_edges – Do not attempt to compute metrics for more than this number of edges in the KNN graph.

property results_raw: pandas.DataFrame#

DataFrame with raw results for all graphs and all values of k

get_results() KScanResults#

Get results object

get_foms() dict[str, float]#

Get figures of merit (convenience method that uses the appropriate method of KScanResults).

reset()#

Reset the results. Will be automatically called every time we run on a batch with i_batch == 0.

__call__(data: torch_geometric.data.Data, i_batch: int, *, progress=False, latent: torch.Tensor | None = None) None#

Run on a single graph

Parameters:
  • data – Data object. data.x is the space used for clustering

  • i_batch – Batch number. Will reset saved data for i_batch == 0.

  • progress – Show progress bar

  • latent – Use this instead of data.x

Returns:

None

_evaluate_tracking_metrics_upper_bounds(data: torch_geometric.data.Data) dict[str, float]#

Evaluate upper bounds of tracking metrics assuming a pipeline with perfect EC. See https://arxiv.org/abs/2309.16754

_evaluate_graph(data: torch_geometric.data.Data, k: int) dict[str, float] | None#

Evaluate metrics for a single graph

Parameters:
  • data – Data object to evaluate (same conventions as in __call__)

  • k – Number of neighbors for the KNN graph

Returns:

None if computation was aborted

class gnn_tracking.training.ml.MultiLossFct#

Bases: torch.nn.Module

Base class for loss functions that return multiple losses.

forward(*args: Any, **kwargs: Any) MultiLossFctReturn#
class gnn_tracking.training.ml.GraphConstructionHingeEmbeddingLoss(*, lw_repulsive: float = 1.0, r_emb: float = 1.0, max_num_neighbors: int = 256, pt_thld: float = 0.9, max_eta: float = 4.0, p_attr: float = 1.0, p_rep: float = 1.0, rep_normalization: str = 'n_hits_oi', rep_oi_only: bool = True)#

Bases: gnn_tracking.metrics.losses.MultiLossFct, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Loss for graph construction using metric learning.

Parameters:
  • lw_repulsive – Loss weight for repulsive part of potential loss

  • r_emb – Radius for edge construction

  • max_num_neighbors – Maximum number of neighbors in radius graph building. See rusty1s/pytorch_cluster

  • pt_thld – pt threshold for particles of interest

  • max_eta – maximum eta for particles of interest

  • p_attr – Power for the attraction term (default 1: linear loss)

  • p_rep – Power for the repulsion term (default 1: linear loss)

  • rep_normalization – Normalization for the repulsive term. Can be either “n_rep_edges” (normalizes by the number of repulsive edges < r_emb), “n_hits_oi” (normalizes by the number of hits of interest), or “n_att_edges” (normalizes by the number of attractive edges of interest)

  • rep_oi_only – Only consider repulsion between hits if at least one of the hits is of interest
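The parameters above suggest the usual hinge-embedding form. As a schematic sketch only (the exact formulation is not spelled out in this reference), the loss plausibly combines an attractive and a repulsive term,

```latex
L = L_{\text{att}} + \lambda_{\text{rep}}\, L_{\text{rep}}, \qquad
L_{\text{att}} \propto \sum_{(i,j)\in E_{\text{true}}} d_{ij}^{\,p_{\text{attr}}}, \qquad
L_{\text{rep}} \propto \sum_{d_{ij} < r_{\text{emb}}} \left(r_{\text{emb}} - d_{ij}\right)^{p_{\text{rep}}},
```

where d_ij is the embedding-space distance between hits i and j, λ_rep corresponds to lw_repulsive, and the repulsive sum is normalized according to rep_normalization.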

_get_edges(*, x: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, mask: torch.Tensor, particle_id: torch.Tensor) tuple[torch.Tensor, torch.Tensor]#

Returns edge index for graph

forward(*, x: torch.Tensor, particle_id: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, pt: torch.Tensor, eta: torch.Tensor, reconstructable: torch.Tensor, **kwargs) gnn_tracking.metrics.losses.MultiLossFctReturn#
class gnn_tracking.training.ml.TrackingModule(model: torch.nn.Module, *, optimizer: pytorch_lightning.cli.OptimizerCallable = torch.optim.Adam, scheduler: pytorch_lightning.cli.LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR, preproc: torch.nn.Module | None = None)#

Bases: ImprovedLogLM

Base class for all pytorch lightning modules in this project.

forward(data: torch_geometric.data.Data, _preprocessed=False) torch.Tensor | dict[str, torch.Tensor]#
data_preproc(data) torch_geometric.data.Data#
configure_optimizers() Any#
backward(*args: Any, **kwargs: Any) None#
gnn_tracking.training.ml.add_key_suffix(dct: dict[str, _P], suffix: str = '') dict[str, _P]#

Return a copy of the dictionary with the suffix added to all keys.
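A minimal sketch of the behavior, assuming the suffix is appended to every key (values are untouched):

```python
def add_key_suffix(dct, suffix=""):
    """Return a copy of dct with suffix appended to every key (sketch)."""
    return {key + suffix: value for key, value in dct.items()}


metrics = {"loss": 0.5, "acc": 0.9}
add_key_suffix(metrics, "_train")  # {"loss_train": 0.5, "acc_train": 0.9}
```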

gnn_tracking.training.ml.to_floats(inpt: Any) Any#

Convert all tensors in a datastructure to floats. Works on single tensors, lists, or dictionaries, nested or not.
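The traversal can be illustrated without torch; in this pure-Python sketch, a `float()` call on leaf values stands in for the tensor conversion that the library function performs:

```python
def to_floats(inpt):
    """Recursively convert leaves of a nested structure to floats (sketch).

    The library version converts torch tensors; here plain numbers
    illustrate the recursion over dictionaries and lists.
    """
    if isinstance(inpt, dict):
        return {key: to_floats(value) for key, value in inpt.items()}
    if isinstance(inpt, list):
        return [to_floats(value) for value in inpt]
    try:
        return float(inpt)
    except (TypeError, ValueError):
        return inpt  # leave non-numeric leaves untouched


to_floats({"a": [1, 2]})  # {"a": [1.0, 2.0]}
```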

gnn_tracking.training.ml.obj_from_or_to_hparams(self: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin, key: str, obj: Any) Any#

Used to support initializing python objects from hyperparameters: If obj is a python object other than a dictionary, its hyperparameters are saved (its class path and init args) to self.hparams[key]. If obj is instead a dictionary, its assumed that we have to restore an object based on this information.
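The round-trip idea can be sketched without Lightning: serialize an object as its class path plus init args, and rebuild it from that dictionary. The helper names below are illustrative, not the library's implementation:

```python
import importlib


def obj_to_dict(obj, init_args):
    """Record an object's class path and init kwargs (sketch; the real
    helper obtains this information via Lightning's hparams machinery)."""
    cls = type(obj)
    return {
        "class_path": f"{cls.__module__}.{cls.__qualname__}",
        "init_args": init_args,
    }


def obj_from_dict(dct):
    """Rebuild an object from the dictionary produced above."""
    module_name, _, cls_name = dct["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**dct["init_args"])
```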

gnn_tracking.training.ml.tolerate_some_oom_errors(fct: Callable)#

Decorator to tolerate a couple of out-of-memory (OOM) errors.
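A sketch of how such a decorator could work: skip the batch and carry on when a CUDA out-of-memory RuntimeError is raised, up to a fixed budget. The retry budget and any cache-clearing details of the real decorator may differ:

```python
import functools


def tolerate_some_oom_errors(fct, max_errors=3):
    """Sketch: return None (skip the batch) on an out-of-memory
    RuntimeError, re-raising once more than max_errors have occurred."""
    errors = {"count": 0}

    @functools.wraps(fct)
    def wrapped(*args, **kwargs):
        try:
            return fct(*args, **kwargs)
        except RuntimeError as err:
            if "out of memory" not in str(err):
                raise  # only OOM errors are tolerated
            errors["count"] += 1
            if errors["count"] > max_errors:
                raise  # budget exhausted
            return None  # skip this batch and continue training
    return wrapped
```

Because `max_errors` has a default, the function can be used directly as `@tolerate_some_oom_errors` on a training step.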

class gnn_tracking.training.ml.MLModule(*, loss_fct: gnn_tracking.metrics.losses.MultiLossFct, gc_scanner: gnn_tracking.graph_construction.k_scanner.GraphConstructionKNNScanner | None = None, **kwargs)#

Bases: gnn_tracking.training.base.TrackingModule

Pytorch lightning module with training and validation step for the metric learning approach to graph construction.

get_losses(out: dict[str, Any], data: torch_geometric.data.Data) tuple[torch.Tensor, dict[str, float]]#
training_step(batch: torch_geometric.data.Data, batch_idx: int) torch.Tensor | None#
validation_step(batch: torch_geometric.data.Data, batch_idx: int)#
on_validation_epoch_end() None#
highlight_metric(metric: str) bool#