gnn_tracking.utils.loading#
Attributes#
Classes#
TrackingDataset | Dataset for tracking applications
TrackingDataModule | This subclass of LightningDataModule configures all data for the ML pipeline.
TestTrackingDataModule | Version of TrackingDataLoader only used for testing purposes.
Module Contents#
- gnn_tracking.utils.loading.DEFAULT_FEATURES = ('r', 'phi', 'z', 'eta_rz', 'u', 'v', 'charge_frac', 'leta', 'lphi', 'lx', 'ly', 'lz', 'geta', 'gphi')#
- class gnn_tracking.utils.loading.TrackingDataset(in_dir: str | os.PathLike | list[str] | list[os.PathLike], *, start=0, stop=None, sector: int | None = None, point_cloud_builder: gnn_tracking.preprocessing.point_cloud_builder.TrackMLPointCloudBuilder | gnn_tracking.preprocessing.point_cloud_builder.CMSPointCloudBuilder | gnn_tracking.preprocessing.point_cloud_builder.MDPointCloudBuilder | None, feature_subset_names: list[str] | None = None, pt_cut: float | None = None)#
Bases: torch_geometric.data.Dataset
Dataset for tracking applications
- Parameters:
in_dir – Directory or list of directories containing the data files
start – Index of the first file to be considered (files from the input directories are considered in order)
stop – Index of the last file to be considered
sector – If not None, only files with this sector number will be considered
- point_cloud_builder#
- _processed_paths = []#
- file_number = 0#
- prev_file_number = -1#
- sector_results = []#
- pt_cut = None#
- _get_paths(in_dir: str | os.PathLike | list[str] | list[os.PathLike], *, start=0, stop: int | None = None, sector: int | None = None) -> list[pathlib.Path]#
Collect all paths that should be in this dataset.
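The file selection controlled by `start`, `stop`, and `sector` can be sketched as follows. This is a minimal illustration, not the library's implementation: the function name `collect_paths` is hypothetical, and both the `*_s{sector}.pt` file naming and the treatment of `stop` as an exclusive slice bound are assumptions made for the example.

```python
from __future__ import annotations

import itertools
import os
from pathlib import Path


def collect_paths(
    in_dir: str | os.PathLike | list,
    *,
    start: int = 0,
    stop: int | None = None,
    sector: int | None = None,
) -> list[Path]:
    """Hypothetical stand-in for the path collection step."""
    dirs = in_dir if isinstance(in_dir, list) else [in_dir]
    # ASSUMPTION: graphs are stored as .pt files whose names end in
    # "_s<sector>.pt"; the real naming scheme may differ.
    pattern = "*.pt" if sector is None else f"*_s{sector}.pt"
    # Directories are visited in the order given; files are sorted
    # within each directory.
    files = list(
        itertools.chain.from_iterable(sorted(Path(d).glob(pattern)) for d in dirs)
    )
    # ASSUMPTION: stop is treated as an exclusive bound (Python slicing).
    return files[start:stop]
```

Calling `collect_paths(["dir_a", "dir_b"], stop=10, sector=3)` would, under these assumptions, return at most ten paths for sector 3, taken first from `dir_a` and then from `dir_b`.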
- len() -> int#
- get(idx: int) -> torch_geometric.data.Data#
- _make_pt_cut(data: torch_geometric.data.Data) -> torch_geometric.data.Data#
- _cut_to_pixel_cms(data: torch_geometric.data.Data) -> torch_geometric.data.Data#
- class gnn_tracking.utils.loading.TrackingDataModule(*, identifier: str, train: dict | None = None, val: dict | None = None, test: dict | None = None, predict: dict | None = None, cpus: int = 1, builder_params: dict | None = None, feature_subset_names: list[str] | None = None, pt_cut: float | None = None)#
Bases: pytorch_lightning.LightningDataModule
This subclass of LightningDataModule configures all data for the ML pipeline.
- Parameters:
identifier – Identifier of the dataset (e.g., graph_v5)
train – Config dictionary for training data (see below)
val – Config dictionary for validation data (see below)
test – Config dictionary for test data (see below)
predict – Config dictionary for prediction data (see below)
cpus – Number of CPUs to use for loading data.
The following keys are available for each config dictionary:
dirs: List of dirs to load from (required)
start=0: Index of first file to load
stop=None: Index of last file to load
sector=None: Sector to load from (if None, load all sectors)
batch_size=1: Batch size
Training has the following additional keys:
sample_size=None: Number of samples to load for each epoch (if None, load all samples)
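The config dictionaries described above can be written out as plain Python dicts. The sketch below is illustrative only: `complete_config` is a hypothetical helper (not part of gnn_tracking) that checks a dictionary against the documented keys and fills in the documented defaults, and the directory paths are placeholders.

```python
from __future__ import annotations

from typing import Any

# Defaults copied from the key list above.
COMMON_DEFAULTS: dict[str, Any] = {
    "start": 0,
    "stop": None,
    "sector": None,
    "batch_size": 1,
}
# Training accepts one additional key.
TRAIN_DEFAULTS: dict[str, Any] = {**COMMON_DEFAULTS, "sample_size": None}


def complete_config(
    config: dict[str, Any], *, training: bool = False
) -> dict[str, Any]:
    """Hypothetical helper: validate keys and fill in documented defaults."""
    if "dirs" not in config:
        raise ValueError("'dirs' is required")
    defaults = TRAIN_DEFAULTS if training else COMMON_DEFAULTS
    unknown = set(config) - set(defaults) - {"dirs"}
    if unknown:
        raise ValueError(f"Unknown keys: {sorted(unknown)}")
    # Explicitly given values override the defaults.
    return {**defaults, **config}


# Placeholder paths; substitute your own directories.
train_config = complete_config(
    {"dirs": ["/path/to/train"], "stop": 100, "batch_size": 4, "sample_size": 50},
    training=True,
)
val_config = complete_config({"dirs": ["/path/to/val"], "stop": 10})

print(train_config)
print(val_config)
```

Following the constructor signature above, dictionaries of this shape would be passed as `TrackingDataModule(identifier="graph_v5", train=train_config, val=val_config, cpus=4)`.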
- _configs#
- _datasets#
- _cpus = 1#
- builder_params = None#
- feature_subset_names = None#
- pt_cut = None#
- property datasets: dict[str, TrackingDataset]#
- static _fix_datatypes(dct: dict[str, Any] | None) -> dict[str, Any] | None#
Fix datatypes of config dictionary. This is necessary because when configuring values from the command line, all values might be strings.
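The kind of conversion described here can be sketched as below. This is not the library's code: `coerce_config` is a hypothetical name, the set of integer-valued keys is inferred from the documented config keys, and mapping the string `"None"` back to `None` is an assumption made for the example.

```python
from __future__ import annotations

from typing import Any

# ASSUMPTION: these are the integer-valued keys, inferred from the
# config key list documented for TrackingDataModule.
INT_KEYS = ("start", "stop", "sector", "batch_size", "sample_size")


def coerce_config(dct: dict[str, Any] | None) -> dict[str, Any] | None:
    """Hypothetical sketch: turn string values from a CLI into integers."""
    if dct is None:
        return None
    fixed = dict(dct)  # work on a copy, leave the caller's dict untouched
    for key in INT_KEYS:
        value = fixed.get(key)
        if isinstance(value, str):
            # ASSUMPTION: the literal string "None" stands for a missing value.
            fixed[key] = None if value.strip().lower() == "none" else int(value)
    return fixed


cli_config = {"dirs": ["/path/to/train"], "start": "0", "stop": "100", "batch_size": "8"}
print(coerce_config(cli_config))
```

Values that are already integers or `None` pass through unchanged, so applying the function to a config built in Python code is harmless.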
- _get_dataset(key: str) -> TrackingDataset#
- setup(stage: str) -> None#
- _get_dataloader(key: str) -> torch_geometric.loader.DataLoader#
- train_dataloader()#
- val_dataloader()#
- test_dataloader()#
- predict_dataloader()#
- class gnn_tracking.utils.loading.TestTrackingDataModule(graphs: list[torch_geometric.data.Data])#
Bases: pytorch_lightning.LightningDataModule
Version of TrackingDataLoader only used for testing purposes.
- graphs#
- datasets#
- setup(stage: str) -> None#
- train_dataloader()#
- val_dataloader()#
- test_dataloader()#