gnn_tracking.training.base#

Base class used for all pytorch lightning modules.

Classes#

StandardError

A torch metric that computes the standard error.

ImprovedLogLM

This subclass of LightningModule adds some convenience to logging.

TrackingModule

Base class for all pytorch lightning modules in this project.

Functions#

obj_from_or_to_hparams(→ Any)

Used to support initializing python objects from hyperparameters.

get_logger([name, level])

Sets up global logger.

tolerate_some_oom_errors(fct)

Decorator to tolerate a couple of out-of-memory (OOM) errors.

Module Contents#

class gnn_tracking.training.base.StandardError#

Bases: torchmetrics.Metric

A torch metric that computes the standard error. This is necessary because LightningModule.log doesn’t take custom reduce functions.

update(x: torch.Tensor)#
compute()#
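The metric accumulates values via update and reports via compute. As a rough illustration of the quantity it computes (not the actual torchmetrics implementation, whose state handling and class internals are not shown here), a pure-Python sketch:

```python
import math
import statistics


class StandardErrorSketch:
    """Pure-Python sketch of the quantity StandardError computes:
    the standard error of the mean of all values seen so far.
    The real class subclasses torchmetrics.Metric."""

    def __init__(self) -> None:
        self._values: list[float] = []

    def update(self, x: float) -> None:
        # Accumulate one value (e.g. one metric value per batch).
        self._values.append(x)

    def compute(self) -> float:
        # Standard error of the mean: sample std dev / sqrt(n).
        return statistics.stdev(self._values) / math.sqrt(len(self._values))
```

This is the reduction that LightningModule.log cannot express on its own, hence the dedicated metric class.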
gnn_tracking.training.base.obj_from_or_to_hparams(self: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin, key: str, obj: Any) → Any#

Used to support initializing python objects from hyperparameters: If obj is a python object other than a dictionary, its hyperparameters (its class path and init args) are saved to self.hparams[key]. If obj is instead a dictionary, it is assumed that an object has to be restored from this information.

gnn_tracking.training.base.get_logger(name='gnn-tracking', level=LOG_DEFAULT_LEVEL)#

Sets up global logger.

gnn_tracking.training.base.tolerate_some_oom_errors(fct: Callable)#

Decorator to tolerate a couple of out-of-memory (OOM) errors.
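The pattern behind such a decorator can be sketched as below. The error limit, the message matching, and the skip behavior are assumptions for illustration; the actual decorator in gnn_tracking may differ in all three:

```python
import functools
from typing import Callable


def tolerate_some_oom_errors_sketch(fct: Callable, max_errors: int = 10) -> Callable:
    """Hypothetical sketch: swallow up to `max_errors` out-of-memory
    RuntimeErrors, returning None (i.e. skipping the batch) instead of
    crashing the training run."""
    errors_seen = 0

    @functools.wraps(fct)
    def wrapper(*args, **kwargs):
        nonlocal errors_seen
        try:
            return fct(*args, **kwargs)
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise  # only OOM errors are tolerated
            errors_seen += 1
            if errors_seen > max_errors:
                raise  # too many OOMs: give up
            return None  # skip this batch and continue
    return wrapper
```

Skipping an occasional oversized batch is often preferable to aborting a long training run, while the cap keeps a systematically broken setup from failing silently.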

class gnn_tracking.training.base.ImprovedLogLM(**kwargs)#

Bases: pytorch_lightning.LightningModule

This subclass of LightningModule adds some convenience to logging, e.g., logging of statistical uncertainties (batch-to-batch) and logging of the validation metrics to the console after each validation epoch.

log_dict_with_errors(dct: dict[str, float], batch_size=None) → None#

Log a dictionary of values with their statistical uncertainties.

This method only starts calculating the uncertainties. To log them, _log_errors needs to be called at the end of the train/val/test epoch (done with the hooks configured in this class).

_log_errors() → None#

Log the uncertainties calculated in log_dict_with_errors. Needs to be called at the end of the train/val/test epoch.
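The two-phase accumulate-then-report pattern of these methods can be sketched in plain Python (class and attribute names here are illustrative; the real methods log through LightningModule.log and the StandardError metric):

```python
import math
import statistics


class ErrorLoggerSketch:
    """Sketch of the log_dict_with_errors / _log_errors pattern:
    collect per-batch values, report mean ± standard error at epoch end."""

    def __init__(self) -> None:
        self._values: dict[str, list[float]] = {}

    def log_dict_with_errors(self, dct: dict[str, float]) -> None:
        # Called once per batch; only accumulates, does not report yet.
        for key, value in dct.items():
            self._values.setdefault(key, []).append(value)

    def _log_errors(self) -> dict[str, tuple[float, float]]:
        # Called at epoch end: reduce to (mean, standard error) per key.
        out = {}
        for key, xs in self._values.items():
            mean = statistics.fmean(xs)
            err = statistics.stdev(xs) / math.sqrt(len(xs)) if len(xs) > 1 else 0.0
            out[key] = (mean, err)
        self._values.clear()  # reset for the next epoch
        return out
```

In the real class, the epoch-end hooks below (on_train_epoch_end etc.) take care of calling _log_errors, so users only ever call log_dict_with_errors.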

on_train_epoch_end(*args) → None#
on_validation_epoch_end() → None#
on_test_epoch_end() → None#
class gnn_tracking.training.base.TrackingModule(model: torch.nn.Module, *, optimizer: pytorch_lightning.cli.OptimizerCallable = torch.optim.Adam, scheduler: pytorch_lightning.cli.LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR, preproc: torch.nn.Module | None = None)#

Bases: ImprovedLogLM

Base class for all pytorch lightning modules in this project.

forward(data: torch_geometric.data.Data, _preprocessed=False) → torch.Tensor | dict[str, torch.Tensor]#
data_preproc(data) → torch_geometric.data.Data#
configure_optimizers() → Any#
backward(*args: Any, **kwargs: Any) → None#