gnn_tracking.utils.loading
Module Contents
Classes
Class | Description
TrackingDataset | Dataset for tracking applications
TrackingDataModule | This subclass of LightningDataModule configures all data for the ML pipeline.
| Version of TrackingDataLoader only used for testing purposes.
- class gnn_tracking.utils.loading.TrackingDataset(in_dir: str | os.PathLike | list[str] | list[os.PathLike], *, start=0, stop=None, sector: int | None = None)
Bases: torch_geometric.data.Dataset
Dataset for tracking applications
- Parameters:
in_dir – Directory or list of directories containing the data files
start – Index of the first file to be considered (files from the input directories are considered in order)
stop – Index of the last file to be considered
sector – If not None, only files with this sector number will be considered
- static _get_paths(in_dir: str | os.PathLike | list[str] | list[os.PathLike], *, start=0, stop: int | None = None, sector: int | None = None) -> list[pathlib.Path]
Collect all paths that should be in this dataset.
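As a rough illustration of how in_dir, start, stop, and sector could combine when collecting paths, here is a minimal sketch. The function name collect_paths, the "*_s<sector>.pt" file-naming scheme, and the exclusive reading of stop are assumptions made for illustration, not taken from the gnn_tracking source.

```python
import tempfile
from pathlib import Path
from typing import List, Optional, Union

PathLike = Union[str, Path]


def collect_paths(
    in_dir: Union[PathLike, List[PathLike]],
    *,
    start: int = 0,
    stop: Optional[int] = None,
    sector: Optional[int] = None,
) -> List[Path]:
    """Hypothetical stand-in for TrackingDataset._get_paths (not the library code)."""
    dirs = in_dir if isinstance(in_dir, list) else [in_dir]
    # ASSUMPTION: one graph per *.pt file, with the sector encoded as a
    # "_s<sector>" suffix in the file name.
    pattern = "*.pt" if sector is None else f"*_s{sector}.pt"
    # Directories are visited in the order given; files within each
    # directory are sorted so the resulting sequence is deterministic.
    files = [p for d in dirs for p in sorted(Path(d).glob(pattern))]
    # ASSUMPTION: stop is exclusive, like a Python slice.
    return files[start:stop]


# Usage: two hypothetical input directories with two sectors each.
with tempfile.TemporaryDirectory() as a, tempfile.TemporaryDirectory() as b:
    for name in ("data21000_s0.pt", "data21000_s1.pt"):
        (Path(a) / name).touch()
    for name in ("data21001_s0.pt", "data21001_s1.pt"):
        (Path(b) / name).touch()
    selected = [p.name for p in collect_paths([a, b], start=1, sector=0)]

print(selected)  # ['data21001_s0.pt']
```

Because start and stop index into the concatenated sequence, start=1 here skips the single sector-0 file of the first directory.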
- len() -> int
- get(idx: int) -> torch_geometric.data.Data
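len() and get() are the two methods a torch_geometric.data.Dataset subclass implements; the base class builds len(dataset) and indexing on top of them. The sketch below shows that pattern for a dataset backed by a list of file paths. It assumes one file is read per get() call and uses pickle in place of torch.load so that it runs without PyTorch; LazyFileDataset is a hypothetical name.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, List


class LazyFileDataset:
    """Hypothetical sketch of the len()/get() pattern; not TrackingDataset itself."""

    def __init__(self, paths: List[Path]) -> None:
        # Only the paths are kept in memory; file contents are read on demand.
        self._paths = list(paths)

    def len(self) -> int:
        return len(self._paths)

    def get(self, idx: int) -> Any:
        # Stand-in for torch.load(...): one file holds one event graph.
        with open(self._paths[idx], "rb") as f:
            return pickle.load(f)

    # PyG's base class roughly maps len(ds) and ds[i] onto len()/get().
    def __len__(self) -> int:
        return self.len()

    def __getitem__(self, idx: int) -> Any:
        return self.get(idx)


with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for i in range(3):
        p = Path(tmp) / f"graph_{i}.pkl"
        with open(p, "wb") as f:
            pickle.dump({"event": i}, f)
        paths.append(p)
    ds = LazyFileDataset(paths)
    n = len(ds)
    second = ds[1]

print(n, second)  # 3 {'event': 1}
```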
- class gnn_tracking.utils.loading.TrackingDataModule(*, identifier: str, train: dict | None = None, val: dict | None = None, test: dict | None = None, cpus: int = 1)
Bases: pytorch_lightning.LightningDataModule
This subclass of LightningDataModule configures all data for the ML pipeline.
- Parameters:
identifier – Identifier of the dataset (e.g., graph_v5)
train – Config dictionary for training data (see below)
val – Config dictionary for validation data (see below)
test – Config dictionary for test data (see below)
cpus – Number of CPUs to use for loading data.
The following keys are available for each config dictionary:
- dirs: List of directories to load from (required)
- start=0: Index of the first file to load
- stop=None: Index of the last file to load
- sector=None: Sector to load from (if None, load all sectors)
- batch_size=1: Batch size
Training has the following additional key:
- sample_size=None: Number of samples to load for each epoch (if None, load all samples)
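For illustration, the config dictionaries might look like the following. The helper with_defaults, the directory paths, and the choice to reject unknown keys are hypothetical; only the key names and defaults come from the list above. The resulting dictionaries would be passed as the keyword-only train, val, and test arguments, e.g. TrackingDataModule(identifier="graph_v5", train=train_cfg, val=val_cfg).

```python
from typing import Any, Dict

# Defaults as documented for every config dictionary.
COMMON_DEFAULTS: Dict[str, Any] = {
    "start": 0,
    "stop": None,
    "sector": None,
    "batch_size": 1,
}
# Training additionally accepts sample_size.
TRAIN_DEFAULTS: Dict[str, Any] = {**COMMON_DEFAULTS, "sample_size": None}


def with_defaults(cfg: Dict[str, Any], *, train: bool = False) -> Dict[str, Any]:
    """Hypothetical helper: check required/unknown keys, then fill in defaults."""
    if "dirs" not in cfg:
        raise ValueError("'dirs' is required")
    defaults = TRAIN_DEFAULTS if train else COMMON_DEFAULTS
    unknown = set(cfg) - set(defaults) - {"dirs"}
    if unknown:
        raise ValueError(f"unknown keys: {sorted(unknown)}")
    # User-supplied values override the defaults.
    return {**defaults, **cfg}


# Hypothetical directory names, for illustration only.
train_cfg = with_defaults(
    {"dirs": ["/data/graph_v5/part_1"], "stop": 800, "batch_size": 5, "sample_size": 100},
    train=True,
)
val_cfg = with_defaults({"dirs": ["/data/graph_v5/part_9"], "stop": 10})

print(train_cfg["start"], train_cfg["sample_size"], val_cfg["batch_size"])  # 0 100 1
```

Rejecting unknown keys is a design choice of this sketch: it turns a misspelled key (or sample_size in a validation config) into an immediate error instead of a silently ignored setting.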
- property datasets: dict[str, TrackingDataset]
- static _fix_datatypes(dct: dict[str, Any] | None) -> dict[str, Any] | None
Fix the data types of a config dictionary. This is necessary because values configured from the command line may all arrive as strings.
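A minimal sketch of this kind of type repair, assuming the integer-valued keys are those from the config list above and that the string "None" stands for None. fix_datatypes and INT_KEYS are hypothetical names; the real method's conversion rules are not shown here.

```python
from typing import Any, Dict, Optional

# ASSUMPTION: these are the keys whose values should be integers (or None).
INT_KEYS = ("start", "stop", "sector", "batch_size", "sample_size")


def fix_datatypes(dct: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Hypothetical sketch of the idea behind _fix_datatypes, not its actual code."""
    if dct is None:
        return None
    fixed = dict(dct)  # don't mutate the caller's dictionary
    for key in INT_KEYS:
        value = fixed.get(key)
        if isinstance(value, str):
            # Command-line parsers may hand over "None" or "10" as strings.
            fixed[key] = None if value == "None" else int(value)
    return fixed


print(fix_datatypes({"dirs": ["/data"], "stop": "10", "sector": "None"}))
# {'dirs': ['/data'], 'stop': 10, 'sector': None}
```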
- _get_dataset(key: str) -> TrackingDataset
- setup(stage: str) -> None
- _get_dataloader(key: str) -> torch_geometric.loader.DataLoader
- train_dataloader()
- val_dataloader()
- test_dataloader()
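The sample_size key changes what one training epoch covers. The pure-Python sketch below shows one plausible reading: each epoch draws a fresh random subset of sample_size indices without replacement and splits it into batches of batch_size. Whether TrackingDataModule samples exactly this way is an assumption, and epoch_batches is a hypothetical helper. In PyTorch, a similar effect is commonly achieved with torch.utils.data.RandomSampler and its num_samples argument.

```python
import random
from typing import List, Optional


def epoch_batches(
    n_items: int,
    *,
    batch_size: int = 1,
    sample_size: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> List[List[int]]:
    """Hypothetical sketch: index batches for one training epoch."""
    rng = rng or random.Random()
    if sample_size is None:
        # Use every sample once, in shuffled order.
        indices = list(range(n_items))
        rng.shuffle(indices)
    else:
        # Draw a fresh subset of sample_size distinct indices for this epoch.
        indices = rng.sample(range(n_items), sample_size)
    # Split into consecutive batches; the last one may be smaller.
    return [indices[i : i + batch_size] for i in range(0, len(indices), batch_size)]


rng = random.Random(0)
epoch_1 = epoch_batches(100, batch_size=4, sample_size=10, rng=rng)
epoch_2 = epoch_batches(100, batch_size=4, sample_size=10, rng=rng)
print([len(b) for b in epoch_1])  # [4, 4, 2]
```

Drawing without replacement keeps the samples within one epoch distinct, while a new draw per epoch makes it likely that the whole training set is visited over many epochs.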