gnn_tracking.graph_construction.data_transformer#
Module Contents#
Classes#
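DataTransformer | Applies a transformation function to all data files and saves them on disk.
ECCut | Applies a cut to the edge classifier output and saves the trimmed-down graphs.
ECCutRefine | Similar to ECCut, but assumes that the edge classifier output is in a field named by the name argument ('ec_score' by default).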
- class gnn_tracking.graph_construction.data_transformer.DataTransformer(transform: torch.nn.Module)#
Applies a transformation function to all data files and saves them on disk.
- process(filename: str, *, input_dir: os.PathLike | str, output_dir: os.PathLike | str, redo: bool = True) -> None #
Process a single file.
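The effect of redo can be illustrated with a small stand-in. This is a minimal sketch, not the library's implementation: process_one is a hypothetical helper, the files are plain text rather than graph data, and the skip rule (leave an existing output untouched when redo is False) is an assumption based on the parameter description.

```python
import tempfile
from pathlib import Path


def process_one(filename, *, input_dir, output_dir, transform, redo=True):
    """Hypothetical stand-in: apply `transform` to one text file.

    Returns True if the file was (re)written, False if it was skipped.
    """
    in_path = Path(input_dir) / filename
    out_path = Path(output_dir) / filename
    if out_path.exists() and not redo:
        return False  # assumed behavior: keep the existing output
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(transform(in_path.read_text()))
    return True


with tempfile.TemporaryDirectory() as tmp:
    src, dst = Path(tmp, "in"), Path(tmp, "out")
    src.mkdir()
    (src / "event0.txt").write_text("abc")
    kwargs = dict(input_dir=src, output_dir=dst, transform=str.upper, redo=False)
    first = process_one("event0.txt", **kwargs)   # output missing, so it is written
    second = process_one("event0.txt", **kwargs)  # output exists, so it is skipped
    result = (dst / "event0.txt").read_text()
```

With redo=True the second call would rewrite the output instead of skipping it.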
- _save_hparams(input_dir: pathlib.Path, output_dir: pathlib.Path) -> None #
Save hyperparameters to disk
- process_directories(input_dirs: list[os.PathLike | str], output_dirs: list[os.PathLike | str], *, redo=True, max_processes=1, chunk_size=1, start=0, n_files=0, seed=None) -> None #
Process all files in the input directories and save them to the output directories.
- Parameters:
input_dirs – Directories to read the input files from
output_dirs – Directories to save the processed files to
redo – If True, overwrite existing files
max_processes – Maximum number of processes to use
chunk_size – Number of files to process in one batch for multiprocessing
start – Index of first file to process
n_files – Number of files to process. If 0, process all files from start on
seed – Seed for shuffling the input files. If None, the files are not shuffled. Shuffling combined with redo=False makes it possible to submit additional worker jobs later on to speed up processing.
- Returns:
None
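How start, n_files and seed interact can be sketched as follows. This is an illustration under stated assumptions, not the library's code: select_files is a hypothetical helper, and both the initial sort and the order of operations (shuffle first, then slice) are assumptions.

```python
import random


def select_files(files, *, start=0, n_files=0, seed=None):
    """Hypothetical helper mirroring the documented parameters."""
    files = sorted(files)  # assumption: start from a deterministic order
    if seed is not None:
        # Seeded shuffle: the same seed always yields the same order.
        random.Random(seed).shuffle(files)
    # n_files == 0 means "everything from `start` on".
    stop = start + n_files if n_files > 0 else None
    return files[start:stop]


files = ["c.pt", "a.pt", "b.pt", "d.pt"]
rest = select_files(files, start=1)
window = select_files(files, start=1, n_files=2)
shuffled = select_files(files, seed=0)
```

Because the shuffle is seeded, several workers started with the same seed see the same file order.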
- class gnn_tracking.graph_construction.data_transformer.ECCut(ec: torch.nn.Module, thld: float)#
Bases: torch.nn.Module, pytorch_lightning.core.mixins.HyperparametersMixin
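Applies a cut to the edge classifier output and saves the trimmed-down graphs.
- Parameters:
ec – Edge classifier module whose output is cut
thld – Threshold for the cut on the edge classifier output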
- forward(data) -> torch_geometric.data.Data #
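The idea of an edge-score cut can be shown without any dependencies. This is a minimal sketch using plain lists in place of tensors: cut_edges is a hypothetical function, and the strict comparison (score > thld) is an assumption, since the source does not say whether the threshold itself is included.

```python
def cut_edges(edge_index, scores, thld):
    """Keep only edges whose score exceeds `thld`.

    edge_index is [sources, targets]; scores holds one value per edge.
    """
    keep = [s > thld for s in scores]  # assumption: strict inequality
    sources = [u for u, k in zip(edge_index[0], keep) if k]
    targets = [v for v, k in zip(edge_index[1], keep) if k]
    kept_scores = [s for s, k in zip(scores, keep) if k]
    return [sources, targets], kept_scores


edge_index = [[0, 0, 1, 2], [1, 2, 2, 3]]
scores = [0.9, 0.2, 0.7, 0.4]
trimmed_index, trimmed_scores = cut_edges(edge_index, scores, thld=0.5)
```

Here the edges 0→2 and 2→3 fall below the threshold and are dropped, while the nodes themselves are left untouched.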
- class gnn_tracking.graph_construction.data_transformer.ECCutRefine(thld: float, name='ec_score')#
Bases: torch.nn.Module, pytorch_lightning.core.mixins.HyperparametersMixin
Similar to ECCut, but assumes that the edge classifier output is in a field of the data object named by the name argument ('ec_score' by default).
- forward(data: torch_geometric.data.Data) -> torch_geometric.data.Data #
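Reading the score from a named field rather than running a classifier might look like the sketch below. A plain dict stands in for the data object, cut_by_stored_score is a hypothetical function, and the strict comparison and the choice to leave node features unchanged are assumptions.

```python
def cut_by_stored_score(data, thld, name="ec_score"):
    """Trim edges using a score already stored under `data[name]`."""
    scores = data[name]
    keep = [i for i, s in enumerate(scores) if s > thld]  # assumed strict cut
    trimmed = dict(data)  # shallow copy, the input dict is not modified
    trimmed["edge_index"] = [[row[i] for i in keep] for row in data["edge_index"]]
    trimmed[name] = [scores[i] for i in keep]
    return trimmed


data = {
    "x": [[0.0], [1.0], [2.0], [3.0]],      # node features, left as they are
    "edge_index": [[0, 1, 2], [1, 2, 3]],   # [sources, targets]
    "ec_score": [0.1, 0.8, 0.6],            # one stored score per edge
}
out = cut_by_stored_score(data, thld=0.5)
```

Passing a different name selects another stored score field without changing the rest of the logic.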