gnn_tracking.training.base

Base class used for all PyTorch Lightning modules.

Module Contents

Classes

ImprovedLogLM

This subclass of LightningModule adds some convenience to logging.

TrackingModule

Base class for all pytorch lightning modules in this project.

class gnn_tracking.training.base.ImprovedLogLM(**kwargs)

Bases: pytorch_lightning.LightningModule

This subclass of LightningModule adds some convenience to logging, e.g., logging of statistical uncertainties (batch-to-batch) and logging of the validation metrics to the console after each validation epoch.

log_dict_with_errors(dct: dict[str, float], batch_size=None) → None

Log a dictionary of values with their statistical uncertainties.

This method only starts calculating the uncertainties. To log them, _log_errors must be called at the end of the train/val/test epoch; the epoch-end hooks configured in this class take care of this.

_log_errors() → None

Log the uncertainties calculated in log_dict_with_errors. Needs to be called at the end of the train/val/test epoch.
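The split between collecting values per batch and logging them at epoch end can be illustrated with the minimal sketch below. It is not the project's implementation: the class name `ErrorLoggingSketch`, the `logged` dictionary (a stand-in for Lightning's `self.log`), and the `_err` key suffix are hypothetical. It also assumes that the batch-to-batch uncertainty is the standard error of the mean (sample standard deviation divided by the square root of the number of batches), and it omits the `batch_size` argument.

```python
import math
from collections import defaultdict


class ErrorLoggingSketch:
    """Hypothetical stand-in for the accumulate-then-log pattern.

    Assumption: the uncertainty is the standard error of the mean of the
    per-batch values. The real class may compute or name it differently.
    """

    def __init__(self):
        # metric name -> per-batch values collected during the current epoch
        self._values = defaultdict(list)
        # stand-in for Lightning's self.log: metric name -> logged value
        self.logged = {}

    def log_dict_with_errors(self, dct):
        # Only collect the values here; nothing is logged yet.
        for key, value in dct.items():
            self._values[key].append(float(value))

    def _log_errors(self):
        # Called at epoch end: log mean and standard error, then reset.
        for key, values in self._values.items():
            n = len(values)
            mean = sum(values) / n
            if n > 1:
                var = sum((v - mean) ** 2 for v in values) / (n - 1)
                err = math.sqrt(var / n)
            else:
                err = float("nan")  # a single batch gives no spread estimate
            self.logged[key] = mean
            self.logged[f"{key}_err"] = err
        self._values.clear()


if __name__ == "__main__":
    m = ErrorLoggingSketch()
    for batch_loss in [1.0, 2.0, 3.0]:
        m.log_dict_with_errors({"loss": batch_loss})
    m._log_errors()
    print(m.logged)
```

Deferring the computation to the epoch end is what makes a batch-to-batch uncertainty possible at all: the spread can only be estimated once all batches of the epoch have been seen.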

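on_train_epoch_end(*args) → None
on_validation_epoch_end() → None
on_test_epoch_end() → None

These are the epoch-end hooks that call _log_errors, so the uncertainties collected by log_dict_with_errors are logged at the end of each train, validation, and test epoch.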
class gnn_tracking.training.base.TrackingModule(model: torch.nn.Module, *, optimizer: pytorch_lightning.cli.OptimizerCallable = torch.optim.Adam, scheduler: pytorch_lightning.cli.LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR, preproc: torch.nn.Module | None = None)

Bases: ImprovedLogLM

Base class for all pytorch lightning modules in this project.
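The `optimizer` and `scheduler` arguments are callables rather than instances. In Lightning, `OptimizerCallable` takes an iterable of parameters and returns an optimizer, and `LRSchedulerCallable` takes an optimizer and returns a scheduler. The sketch below illustrates that pattern without depending on torch. `FakeOptimizer`, `FakeScheduler`, and `SketchModule` are hypothetical stand-ins. It also assumes, without confirmation from the source, that `configure_optimizers` calls `optimizer(self.parameters())` and then `scheduler(optimizer)`, and returns them in Lightning's `{"optimizer": ..., "lr_scheduler": ...}` dictionary format.

```python
from functools import partial


class FakeOptimizer:
    """Hypothetical stand-in for a torch.optim optimizer."""

    def __init__(self, params, lr=1e-3):
        self.params = list(params)
        self.lr = lr


class FakeScheduler:
    """Hypothetical stand-in for a torch.optim.lr_scheduler scheduler."""

    def __init__(self, optimizer, factor=1.0):
        self.optimizer = optimizer
        self.factor = factor


class SketchModule:
    """Assumed usage of optimizer/scheduler callables (not the real class)."""

    def __init__(self, params, *, optimizer=FakeOptimizer, scheduler=FakeScheduler):
        self._params = params
        self.optimizer = optimizer  # callable: params -> optimizer
        self.scheduler = scheduler  # callable: optimizer -> scheduler

    def parameters(self):
        return iter(self._params)

    def configure_optimizers(self):
        # The optimizer can only be built once the parameters are known,
        # which is why a callable is passed in instead of an instance.
        opt = self.optimizer(self.parameters())
        sched = self.scheduler(opt)
        return {"optimizer": opt, "lr_scheduler": sched}


if __name__ == "__main__":
    # Hyperparameters are bound up front with functools.partial.
    module = SketchModule(["w", "b"], optimizer=partial(FakeOptimizer, lr=0.01))
    cfg = module.configure_optimizers()
    print(cfg["optimizer"].lr, cfg["optimizer"].params)
```

With the real class, `partial(torch.optim.Adam, lr=...)` plays the role of `partial(FakeOptimizer, lr=0.01)`. This design lets a config file or CLI choose the optimizer and its hyperparameters while the module supplies its own parameters.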

forward(data: torch_geometric.data.Data, _preprocessed=False) → torch.Tensor | dict[str, torch.Tensor]
data_preproc(data) → torch_geometric.data.Data
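The source does not document how `preproc`, `data_preproc`, and the `_preprocessed` flag of `forward` interact. The sketch below shows one plausible reading, and every part of it is an assumption. It assumes that `data_preproc` returns the data unchanged when `preproc` is `None` and applies `preproc` otherwise, and that `forward` skips preprocessing when `_preprocessed=True` so that data is not preprocessed twice. `SketchTrackingModule` is a hypothetical torch-free stand-in.

```python
class SketchTrackingModule:
    """Hypothetical stand-in; the real preprocessing logic may differ."""

    def __init__(self, model, preproc=None):
        self.model = model
        self.preproc = preproc  # optional preprocessing callable

    def data_preproc(self, data):
        # Assumption: without a preprocessing module, data passes through.
        if self.preproc is None:
            return data
        return self.preproc(data)

    def forward(self, data, _preprocessed=False):
        # Assumption: callers that already ran data_preproc pass
        # _preprocessed=True so the preprocessing is not applied twice.
        if not _preprocessed:
            data = self.data_preproc(data)
        return self.model(data)


if __name__ == "__main__":
    m = SketchTrackingModule(model=lambda x: x * 10, preproc=lambda x: x + 1)
    print(m.forward(1))                     # preprocessed, then passed to the model
    print(m.forward(5, _preprocessed=True))  # preprocessing skipped
```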
configure_optimizers() → Any
backward(*args: Any, **kwargs: Any) → None