:py:mod:`gnn_tracking.utils.loading`
====================================

.. py:module:: gnn_tracking.utils.loading


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.utils.loading.TrackingDataset
   gnn_tracking.utils.loading.TrackingDataModule
   gnn_tracking.utils.loading.TestTrackingDataModule




.. py:class:: TrackingDataset(in_dir: str | os.PathLike | list[str] | list[os.PathLike], *, start=0, stop=None, sector: int | None = None)


   Bases: :py:obj:`torch_geometric.data.Dataset`

   Dataset of graphs for tracking applications, read from data files in one or more directories.

   :param in_dir: Directory or list of directories containing the data files
   :param start: Index of the first file to be considered (files are indexed
                 across all directories in `in_dir`, taken in the order given)
   :param stop: Index of the last file to be considered
   :param sector: If not None, only files with this sector number are considered

   .. py:method:: _get_paths(in_dir: str | os.PathLike | list[str] | list[os.PathLike], *, start=0, stop: int | None = None, sector: int | None = None) -> list[pathlib.Path]
      :staticmethod:

      Collect all paths that should be in this dataset.
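
      The selection controlled by `start`, `stop`, and `sector` can be illustrated
      with the stand-alone sketch below. It is not the library's code: the helper
      name `select_paths` is hypothetical, and it assumes that graph files end in
      `.pt`, that sector files are named `<name>_s<sector>.pt`, and that `stop`
      behaves like an exclusive Python slice bound.

      .. code-block:: python

         from __future__ import annotations

         from pathlib import Path


         def select_paths(
             in_dir: str | Path | list[str] | list[Path],
             *,
             start: int = 0,
             stop: int | None = None,
             sector: int | None = None,
         ) -> list[Path]:
             """Hypothetical stand-in for `TrackingDataset._get_paths`."""
             dirs = in_dir if isinstance(in_dir, list) else [in_dir]
             # Assumed naming scheme: sector files end in "_s<sector>.pt"
             pattern = "*.pt" if sector is None else f"*_s{sector}.pt"
             paths: list[Path] = []
             for d in dirs:
                 # Sort within each directory; keep the directories in the given order
                 paths.extend(sorted(Path(d).glob(pattern)))
             # Assumed: `stop` is an exclusive bound, as in Python slicing
             return paths[start:stop]

      Under these assumptions, `select_paths(["train_a", "train_b"], start=10, stop=20, sector=0)`
      would return the 11th to 20th sector-0 files across both (placeholder) directories.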


   .. py:method:: len() -> int


   .. py:method:: get(idx: int) -> torch_geometric.data.Data
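
   `torch_geometric.data.Dataset` subclasses implement `len()` and `get(idx)`; the
   base class routes `len(dataset)` and `dataset[idx]` through them. The sketch below
   illustrates that contract for a dataset that keeps only file paths in memory and
   reads each item on demand. The class name `LazyFileDataset` is hypothetical, and it
   uses `pickle` in place of the PyTorch loading the real class presumably relies on,
   so that it runs without third-party packages.

   .. code-block:: python

      from __future__ import annotations

      import pickle
      from pathlib import Path
      from typing import Any, Callable


      def _load_pickle(path: Path) -> Any:
          with path.open("rb") as f:
              return pickle.load(f)


      class LazyFileDataset:
          """Hypothetical illustration of the `len`/`get` contract."""

          def __init__(self, paths: list[Path], loader: Callable[[Path], Any] = _load_pickle):
              # Only the paths are stored; nothing is read from disk yet
              self._paths = list(paths)
              self._loader = loader

          def len(self) -> int:
              return len(self._paths)

          def get(self, idx: int) -> Any:
              # Each item is loaded from its file only when requested
              return self._loader(self._paths[idx])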



.. py:class:: TrackingDataModule(*, identifier: str, train: dict | None = None, val: dict | None = None, test: dict | None = None, cpus: int = 1)


   Bases: :py:obj:`pytorch_lightning.LightningDataModule`

   This `LightningDataModule` subclass configures the training, validation,
   and test data for the ML pipeline.

   :param identifier: Identifier of the dataset (e.g., `graph_v5`)
   :param train: Config dictionary for training data (see below)
   :param val: Config dictionary for validation data (see below)
   :param test: Config dictionary for test data (see below)
   :param cpus: Number of CPUs to use for loading data.

   The following keys are available for each config dictionary:

   - `dirs`: List of directories to load from (required)
   - `start=0`: Index of first file to load
   - `stop=None`: Index of last file to load
   - `sector=None`: Sector to load from (if None, load all sectors)
   - `batch_size=1`: Batch size

   Training has the following additional keys:

   - `sample_size=None`: Number of samples to load for each epoch
     (if None, load all samples)
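
   As an illustration of the keys listed above, the hypothetical helper `with_defaults`
   below checks that `dirs` is present and fills in the documented defaults. It is a
   sketch of the documented semantics, not code from this module; the directory paths
   are placeholders.

   .. code-block:: python

      from __future__ import annotations

      from typing import Any

      # Documented defaults; `sample_size` only applies to training data
      COMMON_DEFAULTS: dict[str, Any] = {"start": 0, "stop": None, "sector": None, "batch_size": 1}
      TRAIN_DEFAULTS: dict[str, Any] = {**COMMON_DEFAULTS, "sample_size": None}


      def with_defaults(cfg: dict[str, Any], *, train: bool = False) -> dict[str, Any]:
          """Hypothetical helper: require `dirs` and fill in the documented defaults."""
          if "dirs" not in cfg:
              raise ValueError("'dirs' is required")
          defaults = TRAIN_DEFAULTS if train else COMMON_DEFAULTS
          # Explicitly given values override the defaults
          return {**defaults, **cfg}


      train_cfg = with_defaults(
          {"dirs": ["/path/to/train"], "batch_size": 2, "sample_size": 500}, train=True
      )
      val_cfg = with_defaults({"dirs": ["/path/to/val"], "stop": 10})

   Dictionaries of this shape would then be passed as, e.g.,
   `TrackingDataModule(identifier="graph_v5", train=train_cfg, val=val_cfg)`.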

   .. py:property:: datasets
      :type: dict[str, TrackingDataset]


   .. py:method:: _fix_datatypes(dct: dict[str, Any] | None) -> dict[str, Any] | None
      :staticmethod:

      Convert the values of a config dictionary to their expected types.
      This is necessary because values configured from the command line
      may all arrive as strings.
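
      A minimal sketch of such a conversion, assuming that the integer-valued keys
      are the ones listed for the config dictionaries and that the strings `"None"`
      and `"null"` stand for `None`. The function name `fix_datatypes` and its exact
      rules are hypothetical; the module's `_fix_datatypes` may handle other cases.

      .. code-block:: python

         from __future__ import annotations

         from typing import Any

         # Assumption: the integer-valued keys of the config dictionaries
         _INT_KEYS = ("start", "stop", "sector", "batch_size", "sample_size")


         def fix_datatypes(dct: dict[str, Any] | None) -> dict[str, Any] | None:
             """Hypothetical sketch: coerce CLI strings such as "3" or "None"."""
             if dct is None:
                 return None
             fixed = dict(dct)  # leave the caller's dictionary untouched
             for key in _INT_KEYS:
                 value = fixed.get(key)
                 if isinstance(value, str):
                     is_none = value.strip().lower() in ("none", "null", "")
                     fixed[key] = None if is_none else int(value)
             if isinstance(fixed.get("dirs"), str):
                 # A single directory given as a string becomes a one-element list
                 fixed["dirs"] = [fixed["dirs"]]
             return fixed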


   .. py:method:: _get_dataset(key: str) -> TrackingDataset


   .. py:method:: setup(stage: str) -> None
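
      PyTorch Lightning calls `setup` with the stage that is about to run (`"fit"`,
      `"validate"`, `"test"`, or `"predict"`). The sketch below shows one plausible
      mapping from stage to the data splits that need to be built; the mapping and
      the helper `splits_for_stage` are assumptions, not taken from this module.

      .. code-block:: python

         from __future__ import annotations

         # Assumption: which splits each Lightning stage needs
         STAGE_SPLITS: dict[str, tuple[str, ...]] = {
             "fit": ("train", "val"),
             "validate": ("val",),
             "test": ("test",),
         }


         def splits_for_stage(stage: str, configured: set[str]) -> list[str]:
             """Return the configured splits that `setup(stage)` would build."""
             if stage not in STAGE_SPLITS:
                 raise ValueError(f"Unknown stage {stage!r}")
             # Splits whose config is None (not configured) are skipped
             return [s for s in STAGE_SPLITS[stage] if s in configured]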


   .. py:method:: _get_dataloader(key: str) -> torch_geometric.loader.DataLoader


   .. py:method:: train_dataloader()
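
      For training, `sample_size` sets the number of samples per epoch. The sketch
      below illustrates one way to realize this with the standard library; the helper
      `epoch_indices` and its replacement rule (draw with replacement only when
      `sample_size` exceeds the number of graphs) are assumptions. In PyTorch, the
      comparable building block is `torch.utils.data.RandomSampler` with its
      `replacement` and `num_samples` arguments; whether this module uses it is not
      stated here.

      .. code-block:: python

         from __future__ import annotations

         import random


         def epoch_indices(n_graphs: int, sample_size: int | None, rng: random.Random) -> list[int]:
             """Hypothetical sketch: indices of the training graphs visited in one epoch."""
             if sample_size is None:
                 # Visit every graph exactly once, in shuffled order
                 indices = list(range(n_graphs))
                 rng.shuffle(indices)
                 return indices
             if sample_size <= n_graphs:
                 # Enough graphs: draw without replacement
                 return rng.sample(range(n_graphs), sample_size)
             # More samples requested than graphs exist: draw with replacement
             return [rng.randrange(n_graphs) for _ in range(sample_size)]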


   .. py:method:: val_dataloader()


   .. py:method:: test_dataloader()



.. py:class:: TestTrackingDataModule(graphs: list[torch_geometric.data.Data])


   Bases: :py:obj:`pytorch_lightning.LightningDataModule`

   Version of `TrackingDataModule` that wraps a fixed list of graphs; only used for testing purposes.

   .. py:method:: setup(stage: str) -> None


   .. py:method:: train_dataloader()


   .. py:method:: val_dataloader()


   .. py:method:: test_dataloader()



