gnn_tracking.utils.loading

Module Contents

Classes

TrackingDataset

Dataset for tracking applications

TrackingDataModule

This subclass of LightningDataModule configures all data for the ML pipeline.

TestTrackingDataModule

Version of TrackingDataModule only used for testing purposes.

class gnn_tracking.utils.loading.TrackingDataset(in_dir: str | os.PathLike | list[str] | list[os.PathLike], *, start=0, stop=None, sector: int | None = None)

Bases: torch_geometric.data.Dataset

Dataset for tracking applications

Parameters:
  • in_dir – Directory or list of directories containing the data files

  • start – Index of the first file to consider (files are taken from the directories in in_dir, in the order given)

  • stop – Index of the last file to be considered

  • sector – If not None, only files with this sector number will be considered

static _get_paths(in_dir: str | os.PathLike | list[str] | list[os.PathLike], *, start=0, stop: int | None = None, sector: int | None = None) → list[pathlib.Path]

Collect all paths that should be in this dataset.
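The selection described by in_dir, start, stop and sector can be illustrated with a small standalone sketch. It is not the library's implementation: the function name collect_paths, the "*.pt" file extension, the "_s<sector>.pt" naming scheme, and the slice-like (exclusive) treatment of stop are all assumptions made for illustration.

```python
import tempfile
from pathlib import Path


def collect_paths(in_dir, *, start=0, stop=None, sector=None):
    """Hypothetical stand-in for the path selection described above."""
    # Accept either a single directory or a list of directories.
    dirs = in_dir if isinstance(in_dir, list) else [in_dir]
    # Assumption: data files are "*.pt"; sector files end in "_s<sector>.pt".
    pattern = "*.pt" if sector is None else f"*_s{sector}.pt"
    paths = []
    for d in dirs:
        # Directories are visited in the order given; files sorted by name.
        paths.extend(sorted(Path(d).glob(pattern)))
    # Assumption: start/stop behave like a Python slice (stop is exclusive).
    return paths[start:stop]


# Exercise the sketch on a throwaway directory with four empty files.
with tempfile.TemporaryDirectory() as tmp:
    for name in ["data1_s0.pt", "data1_s1.pt", "data2_s0.pt", "data2_s1.pt"]:
        (Path(tmp) / name).touch()
    all_names = [p.name for p in collect_paths(tmp)]
    sector_names = [p.name for p in collect_paths(tmp, sector=1)]
    first_two = [p.name for p in collect_paths(tmp, stop=2)]
```

Whether the real stop index is inclusive or exclusive is not stated here, so check the behavior against the installed version before relying on it.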

len() → int

get(idx: int) → torch_geometric.data.Data

class gnn_tracking.utils.loading.TrackingDataModule(*, identifier: str, train: dict | None = None, val: dict | None = None, test: dict | None = None, cpus: int = 1)

Bases: pytorch_lightning.LightningDataModule

This subclass of LightningDataModule configures all data for the ML pipeline.

Parameters:
  • identifier – Identifier of the dataset (e.g., graph_v5)

  • train – Config dictionary for training data (see below)

  • val – Config dictionary for validation data (see below)

  • test – Config dictionary for test data (see below)

  • cpus – Number of CPUs to use for loading data.

The following keys are available for each config dictionary:

  • dirs: List of dirs to load from (required)

  • start=0: Index of first file to load

  • stop=None: Index of last file to load

  • sector=None: Sector to load from (if None, load all sectors)

  • batch_size=1: Batch size

Training has the following additional keys:

  • sample_size=None: Number of samples to load for each epoch (if None, load all samples)
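As an illustration of the keys listed above, the following sketch builds a training and a validation config dictionary and fills in the documented defaults. The helper with_defaults and the directory paths are hypothetical and not part of gnn_tracking. The sketch only shows which keys exist and what they default to.

```python
# Defaults as listed above; "dirs" has no default because it is required.
DEFAULTS = {"start": 0, "stop": None, "sector": None, "batch_size": 1}


def with_defaults(config, *, training=False):
    """Hypothetical helper: fill unspecified keys with the documented defaults."""
    if "dirs" not in config:
        raise ValueError("'dirs' is required")
    defaults = dict(DEFAULTS)
    if training:
        # Only the training config accepts sample_size.
        defaults["sample_size"] = None
    unknown = set(config) - set(defaults) - {"dirs"}
    if unknown:
        raise ValueError(f"Unknown keys: {sorted(unknown)}")
    # Values given by the user override the defaults.
    return {**defaults, **config}


# Placeholder paths; replace with the real graph directories.
train = with_defaults(
    {"dirs": ["/path/to/graphs/part_1"], "stop": 800, "batch_size": 4, "sample_size": 200},
    training=True,
)
val = with_defaults({"dirs": ["/path/to/graphs/part_2"], "stop": 100})

# With gnn_tracking installed, dictionaries of this shape would be passed as
# TrackingDataModule(identifier="graph_v5", train=train, val=val).
```

In practice only the keys that differ from the defaults need to be given, so {"dirs": [...]} alone is a complete config.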

property datasets: dict[str, TrackingDataset]

static _fix_datatypes(dct: dict[str, Any] | None) → dict[str, Any] | None

Fix datatypes of config dictionary. This is necessary because when configuring values from the command line, all values might be strings.
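The kind of conversion this implies can be sketched as follows. This is an assumption about the behavior, not the actual implementation: the function name fix_datatypes, the set of keys treated as integers, and the handling of the string "None" are illustrative choices.

```python
def fix_datatypes(dct):
    """Hypothetical sketch: coerce string values coming from the command line."""
    if dct is None:
        return None
    fixed = dict(dct)  # copy so the caller's dictionary is left untouched
    # Assumption: these keys hold integers or None.
    for key in ("start", "stop", "sector", "batch_size", "sample_size"):
        value = fixed.get(key)
        if isinstance(value, str):
            fixed[key] = None if value.lower() == "none" else int(value)
    return fixed


# Values as they might arrive from a command-line parser: all strings.
raw = {"dirs": ["/path/to/graphs"], "start": "0", "stop": "None", "batch_size": "8"}
fixed = fix_datatypes(raw)
```

Values that are already integers or None pass through unchanged, so applying the conversion to a config written in Python is harmless.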

_get_dataset(key: str) → TrackingDataset

setup(stage: str) → None

_get_dataloader(key: str) → torch_geometric.loader.DataLoader

train_dataloader()

val_dataloader()

test_dataloader()

class gnn_tracking.utils.loading.TestTrackingDataModule(graphs: list[torch_geometric.data.Data])

Bases: pytorch_lightning.LightningDataModule

Version of TrackingDataModule only used for testing purposes.

setup(stage: str) → None

train_dataloader()

val_dataloader()

test_dataloader()