:py:mod:`gnn_tracking.graph_construction.data_transformer`
==========================================================

.. py:module:: gnn_tracking.graph_construction.data_transformer


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.graph_construction.data_transformer.DataTransformer
   gnn_tracking.graph_construction.data_transformer.ECCut
   gnn_tracking.graph_construction.data_transformer.ECCutRefine




.. py:class:: DataTransformer(transform: torch.nn.Module)


   Applies a transformation to every input data file and saves the results
   to disk.

   .. py:method:: process(filename: str, *, input_dir: os.PathLike | str, output_dir: os.PathLike | str, redo: bool = True) -> None

      Process a single file.
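
      The effect of ``redo`` can be pictured with the minimal sketch below. It
      is not the library's implementation: ``should_process`` is a hypothetical
      helper, and the sketch assumes that the output file keeps the name of the
      input file and that ``redo=False`` means existing outputs are skipped.

      .. code-block:: python

         from pathlib import Path


         def should_process(filename, output_dir, *, redo=True):
             """Hypothetical check: decide whether a file needs to be processed.

             Assumption: the output is written to ``output_dir / filename``.
             """
             output_path = Path(output_dir) / filename
             # With redo=True existing outputs are overwritten, so always process.
             # With redo=False a file is only processed if no output exists yet.
             return redo or not output_path.exists()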


   .. py:method:: _save_hparams(input_dir: pathlib.Path, output_dir: pathlib.Path) -> None

      Save the hyperparameters to disk.


   .. py:method:: process_directories(input_dirs: list[os.PathLike | str], output_dirs: list[os.PathLike | str], *, redo=True, max_processes=1, chunk_size=1, start=0, n_files=0, seed=None) -> None

      Process all files in the input directories and save them to the output
      directories.

      :param input_dirs: Directories containing the input files
      :param output_dirs: Directories to save the processed files to
      :param redo: If True, overwrite existing files
      :param max_processes: Maximum number of processes to use
      :param chunk_size: Number of files to process in one batch for multiprocessing
      :param start: Index of first file to process
      :param n_files: Number of files to process. If 0, process all files from `start`
                      on
      :param seed: Seed for shuffling the input files. If None, the files are not
                   shuffled. Shuffling combined with `redo=False` makes it possible
                   to submit additional worker jobs later on for faster processing.

      :returns: None
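
      How ``start``, ``n_files`` and ``seed`` might interact can be sketched as
      follows. This is an illustration based only on the parameter descriptions
      above, not the actual implementation: ``select_files`` is a hypothetical
      helper, and sorting the files and shuffling them before slicing are
      assumptions.

      .. code-block:: python

         import random


         def select_files(filenames, *, start=0, n_files=0, seed=None):
             """Hypothetical helper: pick the files that one job would process."""
             # Assumption: start from a deterministic (sorted) order.
             files = sorted(filenames)
             if seed is not None:
                 # A seeded shuffle gives a reproducible, job-specific order.
                 random.Random(seed).shuffle(files)
             # n_files == 0 means "all files from `start` on".
             stop = start + n_files if n_files > 0 else None
             return files[start:stop]

      With ``names = ["c.pt", "a.pt", "b.pt", "d.pt"]``, calling
      ``select_files(names, start=1, n_files=2)`` returns ``["b.pt", "c.pt"]``.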



.. py:class:: ECCut(ec: torch.nn.Module, thld: float)


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Applies a cut to the edge classifier output and saves the trimmed-down
   graphs.

   :param ec: Edge classifier
   :param thld: Threshold for the cut on the edge classifier output

   .. py:method:: forward(data) -> torch_geometric.data.Data



.. py:class:: ECCutRefine(thld: float, name='ec_score')


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Similar to `ECCut`, but assumes that the edge classifier output is in a
   field of the data object whose name is given by ``name`` (default
   ``'ec_score'``).

   .. py:method:: forward(data: torch_geometric.data.Data) -> torch_geometric.data.Data
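
      The idea of cutting on a stored score can be sketched without any
      dependencies. The example is illustrative only: a plain dictionary stands
      in for the ``torch_geometric.data.Data`` object, ``cut_by_score`` is a
      hypothetical helper, and both the ``edge_index`` layout and the choice to
      keep edges with a score strictly above ``thld`` are assumptions.

      .. code-block:: python

         def cut_by_score(data, thld, name="ec_score"):
             """Hypothetical helper: keep only edges whose score exceeds ``thld``.

             ``data["edge_index"]`` is a pair ``(sources, targets)`` of lists and
             ``data[name]`` holds one score per edge.
             """
             scores = data[name]
             # Indices of the edges that survive the cut (assumed: score > thld).
             keep = [i for i, score in enumerate(scores) if score > thld]
             sources, targets = data["edge_index"]
             # Return a new dictionary so that the input graph is left untouched.
             return {
                 **data,
                 "edge_index": (
                     [sources[i] for i in keep],
                     [targets[i] for i in keep],
                 ),
                 name: [scores[i] for i in keep],
             }

      Other per-edge attributes, such as edge features or truth labels, would
      need the same mask applied; the sketch leaves them out for brevity.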



