gnn_tracking.graph_construction.data_transformer

Module Contents

Classes

DataTransformer

Applies a transformation function to all data files and saves them on disk.

ECCut

Applies a cut to the edge classifier output and saves the trimmed down graphs.

ECCutRefine

Similar to ECCut, but assumes that the edge classifier output is already stored in the data object, in the field given by name (default 'ec_score').

class gnn_tracking.graph_construction.data_transformer.DataTransformer(transform: torch.nn.Module)

Applies a transformation function to all data files and saves them on disk.

process(filename: str, *, input_dir: os.PathLike | str, output_dir: os.PathLike | str, redo: bool = True) → None

Process a single file.
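A minimal sketch of what processing one file with a redo flag could look like. It assumes each file holds a single pickled object and that the transform is any callable; process_file and the pickle format are hypothetical stand-ins, since the real class operates on PyTorch Geometric graphs and may serialize them differently (for example with torch.load/torch.save).

```python
from __future__ import annotations

import pickle
from pathlib import Path
from typing import Any, Callable


def process_file(
    filename: str,
    *,
    input_dir: Path | str,
    output_dir: Path | str,
    transform: Callable[[Any], Any],
    redo: bool = True,
) -> bool:
    """Transform one file; return True if a new output file was written.

    Hypothetical stand-in for DataTransformer.process: pickle replaces the
    graph (de)serialization used by the real class.
    """
    in_path = Path(input_dir) / filename
    out_path = Path(output_dir) / filename
    # With redo=False, an existing output file is left untouched.
    if out_path.exists() and not redo:
        return False
    with in_path.open("rb") as f:
        data = pickle.load(f)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("wb") as f:
        pickle.dump(transform(data), f)
    return True
```

Returning a bool (unlike the documented None) just makes the skip behavior easy to observe in this sketch.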

_save_hparams(input_dir: pathlib.Path, output_dir: pathlib.Path) → None

Save hyperparameters to disk.

process_directories(input_dirs: list[os.PathLike | str], output_dirs: list[os.PathLike | str], *, redo=True, max_processes=1, chunk_size=1, start=0, n_files=0, seed=None) → None

Process all files in the input directories and save them to the output directories.

Parameters:
  • input_dirs – Directories containing the input files

  • output_dirs – Directories to save the transformed files to

  • redo – If True, overwrite existing files

  • max_processes – Maximum number of processes to use

  • chunk_size – Number of files to process in one batch for multiprocessing

  • start – Index of first file to process

  • n_files – Number of files to process. If 0, process all files from start on

  • seed – Seed for shuffling the input files. If None, files are not shuffled. Shuffling combined with redo=False makes it easier to submit additional worker jobs later on to speed up processing.

Returns:

None
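The start, n_files, and seed parameters together decide which files get processed. The sketch below shows one plausible reading of the documented semantics (sort, optionally shuffle with a seeded RNG, then slice); select_files is hypothetical, and the real method may order or shuffle files differently.

```python
from __future__ import annotations

import random


def select_files(
    filenames: list[str],
    *,
    start: int = 0,
    n_files: int = 0,
    seed: int | None = None,
) -> list[str]:
    """Pick the subset of files to process (hypothetical helper)."""
    files = sorted(filenames)
    if seed is not None:
        # Seeded shuffle: reproducible for a given seed, different across seeds.
        random.Random(seed).shuffle(files)
    # n_files == 0 means "all files from start on".
    end = None if n_files == 0 else start + n_files
    return files[start:end]
```

Using a private random.Random(seed) instance keeps the shuffle reproducible without touching the global random state.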

class gnn_tracking.graph_construction.data_transformer.ECCut(ec: torch.nn.Module, thld: float)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Applies a cut to the edge classifier output and saves the trimmed down graphs.

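Parameters:
  • ec – Edge classifier whose output is cut on

  • thld – Threshold applied to the edge classifier output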

forward(data) → torch_geometric.data.Data
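Conceptually, the cut keeps only the edges whose edge classifier score passes thld. A framework-free sketch with plain lists, assuming edges are stored as two parallel lists of source and target indices (like a 2×E edge_index), one score per edge, and that edges are kept when the score is strictly greater than the threshold. cut_edges is hypothetical; the real module works on torch_geometric.data.Data tensors and may use a different comparison.

```python
from __future__ import annotations


def cut_edges(
    edge_index: tuple[list[int], list[int]],
    scores: list[float],
    thld: float,
) -> tuple[tuple[list[int], list[int]], list[float]]:
    """Drop edges whose score is not above thld (hypothetical helper)."""
    src, dst = edge_index
    if not (len(src) == len(dst) == len(scores)):
        raise ValueError("edge_index and scores must have one entry per edge")
    # Indices of edges that survive the cut.
    keep = [i for i, s in enumerate(scores) if s > thld]
    new_index = ([src[i] for i in keep], [dst[i] for i in keep])
    return new_index, [scores[i] for i in keep]
```

With PyTorch tensors the same idea is a boolean mask, e.g. mask = scores > thld followed by edge_index[:, mask].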

class gnn_tracking.graph_construction.data_transformer.ECCutRefine(thld: float, name='ec_score')

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Similar to ECCut, but assumes that the edge classifier output is already stored in the data object, in the field given by name (default 'ec_score').

forward(data: torch_geometric.data.Data) → torch_geometric.data.Data
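The refinement reads a precomputed score stored on the data object under name instead of running an edge classifier itself. A hedged sketch using types.SimpleNamespace as a stand-in for torch_geometric.data.Data; refine_cut is hypothetical and assumes edges with a score strictly greater than thld are kept.

```python
from __future__ import annotations

from types import SimpleNamespace


def refine_cut(
    data: SimpleNamespace, thld: float, name: str = "ec_score"
) -> SimpleNamespace:
    """Keep edges whose stored score exceeds thld (hypothetical helper)."""
    scores = getattr(data, name)  # precomputed edge classifier output
    src, dst = data.edge_index
    keep = [i for i, s in enumerate(scores) if s > thld]
    out = SimpleNamespace(**vars(data))  # shallow copy; the input is not modified
    out.edge_index = ([src[i] for i in keep], [dst[i] for i in keep])
    setattr(out, name, [scores[i] for i in keep])
    return out
```

Node-level fields (here x) are carried over unchanged; only the edge list and the per-edge score field are trimmed.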