gnn_tracking.training.run

Classes

PrintValidationMetrics

This callback prints the validation metrics after every epoch.

TrackingDataModule

This subclass of LightningDataModule configures all data for the ML pipeline.

Functions

cli_main()

Module Contents

class gnn_tracking.training.run.PrintValidationMetrics

Bases: pytorch_lightning.Callback

This callback prints the validation metrics after every epoch.

If the lightning module has a printed_results_filter attribute, only metrics for which this function returns True are printed. If the lightning module has a highlight_metric attribute, the metric returned by this function is highlighted in the output.
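The filtering and highlighting rule above can be illustrated with a plain-Python sketch. This is not the callback's actual implementation: the function name format_validation_metrics, the ">> " highlight marker, the metric names, and the assumption that both printed_results_filter and highlight_metric are callables that take a metric name and return a bool are all hypothetical.

```python
from __future__ import annotations

from collections.abc import Callable, Mapping


def format_validation_metrics(
    metrics: Mapping[str, float],
    printed_results_filter: Callable[[str], bool] | None = None,
    highlight_metric: Callable[[str], bool] | None = None,
) -> list[str]:
    """Return one formatted line per printed metric (hypothetical sketch)."""
    lines = []
    for name in sorted(metrics):
        # Drop metrics that the optional filter rejects
        if printed_results_filter is not None and not printed_results_filter(name):
            continue
        # Prefix highlighted metrics with a visible marker
        highlighted = highlight_metric is not None and highlight_metric(name)
        prefix = ">> " if highlighted else "   "
        lines.append(f"{prefix}{name}: {metrics[name]:.4f}")
    return lines


# Hypothetical metric names and values
metrics = {"val_loss": 0.1234, "efficiency": 0.91, "lr": 0.001}
for line in format_validation_metrics(
    metrics,
    printed_results_filter=lambda name: name != "lr",
    highlight_metric=lambda name: name == "efficiency",
):
    print(line)
```

Without either attribute, every metric is printed and none is marked, which matches the "only if the attribute exists" behavior described above.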

on_validation_end(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule) → None

class gnn_tracking.training.run.TrackingDataModule(*, identifier: str, train: dict | None = None, val: dict | None = None, test: dict | None = None, cpus: int = 1)

Bases: pytorch_lightning.LightningDataModule

This subclass of LightningDataModule configures all data for the ML pipeline.

Parameters:
  • identifier – Identifier of the dataset (e.g., graph_v5)

  • train – Config dictionary for training data (see below)

  • val – Config dictionary for validation data (see below)

  • test – Config dictionary for test data (see below)

  • cpus – Number of CPUs to use for loading data.

The following keys are available for each config dictionary:

  • dirs: List of dirs to load from (required)

  • start=0: Index of first file to load

  • stop=None: Index of last file to load

  • sector=None: Sector to load from (if None, load all sectors)

  • batch_size=1: Batch size

Training has the following additional keys:

  • sample_size=None: Number of samples to load for each epoch (if None, load all samples)
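A minimal sketch of how the keys and defaults listed above fit together. The directory paths are placeholders, and the helper with_defaults is hypothetical: TrackingDataModule may fill in its defaults differently internally. This sketch also rejects sample_size outside the training split, which is an assumption based on it being listed as a training-only key.

```python
from __future__ import annotations

from typing import Any

# Defaults documented for every split
COMMON_DEFAULTS: dict[str, Any] = {"start": 0, "stop": None, "sector": None, "batch_size": 1}
# The training split additionally accepts sample_size
TRAIN_DEFAULTS: dict[str, Any] = {**COMMON_DEFAULTS, "sample_size": None}


def with_defaults(cfg: dict[str, Any], *, training: bool = False) -> dict[str, Any]:
    """Return a copy of cfg with the documented defaults filled in (hypothetical helper)."""
    if "dirs" not in cfg:
        raise KeyError("'dirs' is required")
    defaults = TRAIN_DEFAULTS if training else COMMON_DEFAULTS
    # Reject keys that are not documented for this split
    unknown = set(cfg) - set(defaults) - {"dirs"}
    if unknown:
        raise KeyError(f"unknown keys: {sorted(unknown)}")
    return {**defaults, **cfg}


# Placeholder directories; substitute your own graph directories
train_cfg = with_defaults(
    {"dirs": ["/path/to/graphs/part_1"], "stop": 100, "batch_size": 4, "sample_size": 50},
    training=True,
)
val_cfg = with_defaults({"dirs": ["/path/to/graphs/part_9"], "stop": 10})
print(train_cfg)
print(val_cfg)
```

With the package installed, such dictionaries would then be passed by keyword, for example TrackingDataModule(identifier="graph_v5", train=train_cfg, val=val_cfg, cpus=4).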

property datasets: dict[str, TrackingDataset]
static _fix_datatypes(dct: dict[str, Any] | None) → dict[str, Any] | None

Fix the datatypes of a config dictionary. This is necessary because, when values are configured from the command line, they may all be strings.
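A minimal sketch of the kind of conversion this method performs. It is not the library's code: the function name fix_datatypes_sketch, the choice of keys to cast (the integer-valued keys documented above: start, stop, sector, batch_size, sample_size), and mapping the string "None" to None are all assumptions.

```python
from __future__ import annotations

from typing import Any

# Keys documented as integers (or None) in the config dictionaries
INT_KEYS = ("start", "stop", "sector", "batch_size", "sample_size")


def fix_datatypes_sketch(dct: dict[str, Any] | None) -> dict[str, Any] | None:
    """Cast string values coming from the command line to int or None (hypothetical)."""
    if dct is None:
        return None
    fixed = dict(dct)  # do not mutate the caller's dictionary
    for key in INT_KEYS:
        value = fixed.get(key)
        if isinstance(value, str):
            # Treat the literal string "None" as a missing value
            fixed[key] = None if value.strip() == "None" else int(value)
    return fixed


print(fix_datatypes_sketch({"dirs": ["/path/to/graphs"], "start": "0", "stop": "None", "batch_size": "2"}))
```

Values that already have the right type pass through unchanged, so the function is safe to call on configs that did not come from the command line.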

_get_dataset(key: str) → TrackingDataset
setup(stage: str) → None
_get_dataloader(key: str) → torch_geometric.loader.DataLoader
train_dataloader()
val_dataloader()
test_dataloader()
gnn_tracking.training.run.cli_main()