gnn_tracking.training.tc#

Lightning module for object condensation training.

Classes#

MultiLossFct

Base class for loss functions that return multiple losses.

ClusterScanner

Base class for cluster scanners. Use any of its subclasses.

TrackingModule

Base class for all pytorch lightning modules in this project.

TCModule

Object condensation for tracks. This lightning module implements losses, training, and validation steps.

Functions#

add_key_suffix(→ dict[str, _P])

Return a copy of the dictionary with the suffix added to all keys.

to_floats(→ Any)

Convert all tensors in a datastructure to floats.

obj_from_or_to_hparams(→ Any)

Used to support initializing Python objects from hyperparameters.

tolerate_some_oom_errors(fct)

Decorator to tolerate a couple of out-of-memory (OOM) errors.

Module Contents#

class gnn_tracking.training.tc.MultiLossFct#

Bases: torch.nn.Module

Base class for loss functions that return multiple losses.

forward(*args: Any, **kwargs: Any) → MultiLossFctReturn#
class gnn_tracking.training.tc.ClusterScanner(*args, **kwargs)#

Bases: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin, abc.ABC

Base class for cluster scanners. Use any of its subclasses.

abstract __call__(data: torch_geometric.data.Data, out: dict[str, torch.Tensor], i_batch: int) → None#
reset() → None#
get_foms() → dict[str, Any]#
class gnn_tracking.training.tc.TrackingModule(model: torch.nn.Module, *, optimizer: pytorch_lightning.cli.OptimizerCallable = torch.optim.Adam, scheduler: pytorch_lightning.cli.LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR, preproc: torch.nn.Module | None = None)#

Bases: ImprovedLogLM

Base class for all pytorch lightning modules in this project.

forward(data: torch_geometric.data.Data, _preprocessed=False) → torch.Tensor | dict[str, torch.Tensor]#
data_preproc(data) → torch_geometric.data.Data#
configure_optimizers() → Any#
backward(*args: Any, **kwargs: Any) → None#
gnn_tracking.training.tc.add_key_suffix(dct: dict[str, _P], suffix: str = '') → dict[str, _P]#

Return a copy of the dictionary with the suffix added to all keys.
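The behavior can be sketched as follows (a minimal stand-in, not the package's implementation; that the suffix is appended rather than prepended is inferred from the function name):

```python
def add_key_suffix(dct: dict, suffix: str = "") -> dict:
    """Return a copy of the dictionary with the suffix added to all keys."""
    return {key + suffix: value for key, value in dct.items()}

# Example: tagging a dictionary of losses with a stage suffix
losses = {"potential": 0.4, "background": 0.1}
print(add_key_suffix(losses, "_train"))
# {'potential_train': 0.4, 'background_train': 0.1}
```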

gnn_tracking.training.tc.to_floats(inpt: Any) → Any#

Convert all tensors in a datastructure to floats. Works on single tensors, lists, or dictionaries, nested or not.
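A hedged sketch of this recursion, kept torch-free: using the presence of an `.item()` method as the "is a tensor" test is an assumption of the sketch (it matches zero-dimensional torch tensors without importing torch).

```python
from typing import Any


def to_floats(inpt: Any) -> Any:
    """Convert tensor-like leaves in a nested structure to plain floats."""
    if isinstance(inpt, dict):
        # Recurse into dictionaries, preserving keys.
        return {k: to_floats(v) for k, v in inpt.items()}
    if isinstance(inpt, list):
        # Recurse into lists, preserving order.
        return [to_floats(v) for v in inpt]
    if hasattr(inpt, "item"):
        # Assumed leaf test: scalar tensors expose .item().
        return float(inpt.item())
    return inpt
```

With a real torch tensor, `to_floats({"loss": torch.tensor(0.5)})` would yield `{"loss": 0.5}`; non-tensor leaves pass through unchanged.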

gnn_tracking.training.tc.obj_from_or_to_hparams(self: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin, key: str, obj: Any) → Any#

Used to support initializing Python objects from hyperparameters: If obj is a Python object other than a dictionary, its hyperparameters (its class path and init args) are saved to self.hparams[key]. If obj is instead a dictionary, it is assumed that we have to restore an object based on this information.

gnn_tracking.training.tc.tolerate_some_oom_errors(fct: Callable)#

Decorator to tolerate a couple of out-of-memory (OOM) errors.
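A minimal sketch of such a decorator. The cap of 10 tolerated errors and the policy of skipping the failed call by returning `None` are assumptions of this sketch, not guaranteed to match the package:

```python
import functools

MAX_OOM_ERRORS = 10  # hypothetical cap on tolerated OOM errors


def tolerate_some_oom_errors(fct):
    """Swallow up to MAX_OOM_ERRORS 'out of memory' RuntimeErrors."""

    @functools.wraps(fct)
    def wrapped(*args, **kwargs):
        try:
            return fct(*args, **kwargs)
        except RuntimeError as err:
            # CUDA OOM surfaces as a RuntimeError mentioning "out of memory".
            if "out of memory" in str(err) and wrapped.oom_count < MAX_OOM_ERRORS:
                wrapped.oom_count += 1
                return None  # skip this call rather than crash the run
            raise

    wrapped.oom_count = 0
    return wrapped
```

Applied to a training step, an occasional OOM batch is skipped, while a persistent stream of OOM errors still aborts the run once the cap is exceeded.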

class gnn_tracking.training.tc.TCModule(*, loss_fct: gnn_tracking.metrics.losses.MultiLossFct, cluster_scanner: gnn_tracking.postprocessing.clusterscanner.ClusterScanner | None = None, **kwargs)#

Bases: gnn_tracking.training.base.TrackingModule

Object condensation for tracks. This lightning module implements losses, training, and validation steps.

Parameters:
  • loss_fct

  • cluster_scanner

  • **kwargs – Passed on to TrackingModule

is_last_val_batch(batch_idx: int) → bool#

Are we validating the last batch of the validation set?

get_losses(out: dict[str, Any], data: torch_geometric.data.Data) → tuple[torch.Tensor, dict[str, float]]#
training_step(data: torch_geometric.data.Data, batch_idx: int) → torch.Tensor#
validation_step(data: torch_geometric.data.Data, batch_idx: int) → None#
_evaluate_cluster_metrics(out: dict[str, Any], data: torch_geometric.data.Data, batch_idx: int) → dict[str, float]#

Evaluate cluster metrics.

highlight_metric(metric: str) → bool#