gnn_tracking.training.base#
Base class used for all pytorch lightning modules.
Classes#
| Class | Description |
|---|---|
| ImprovedLogLM | This subclass of LightningModule adds some convenience to logging. |
| TrackingModule | Base class for all pytorch lightning modules in this project. |
Module Contents#
- class gnn_tracking.training.base.ImprovedLogLM(**kwargs)#
Bases: pytorch_lightning.LightningModule
This subclass of LightningModule adds some convenience to logging, e.g., logging of statistical uncertainties (batch-to-batch) and logging of the validation metrics to the console after each validation epoch.
- _uncertainties#
- print_validation_results = True#
- log_dict_with_errors(dct: dict[str, float], batch_size=None) None #
Log a dictionary of values with their statistical uncertainties.
This method only starts the uncertainty calculation: the values are accumulated batch by batch, and the uncertainties are logged when _log_errors is called at the end of the train/val/test epoch (this happens automatically via the epoch-end hooks configured in this class). See the usage sketch at the end of this class entry.
- _log_errors() None #
Log the uncertainties calculated in log_dict_with_errors. Needs to be called at the end of the train/val/test epoch.
- on_train_epoch_end(*args) None #
- on_validation_epoch_end() None #
- on_test_epoch_end() None #
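A minimal usage sketch for ImprovedLogLM. The subclass name, the toy layer, the metric name, and the loss computation are illustrative assumptions and not part of this module; only log_dict_with_errors and the epoch-end hooks come from the documented API.

```python
import torch
from torch import nn

from gnn_tracking.training.base import ImprovedLogLM


class ToyValidationModule(ImprovedLogLM):
    """Hypothetical subclass used only to illustrate log_dict_with_errors."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.layer = nn.Linear(4, 1)

    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        loss = self.layer(batch).pow(2).mean()  # placeholder loss
        # Values are accumulated batch by batch; their batch-to-batch
        # uncertainties are logged by _log_errors, which is triggered
        # automatically by on_validation_epoch_end.
        self.log_dict_with_errors({"val_loss": loss.item()}, batch_size=len(batch))
        return loss
```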
- class gnn_tracking.training.base.TrackingModule(model: torch.nn.Module, *, optimizer: pytorch_lightning.cli.OptimizerCallable = torch.optim.Adam, scheduler: pytorch_lightning.cli.LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR, preproc: torch.nn.Module | None = None)#
Bases: ImprovedLogLM
Base class for all pytorch lightning modules in this project. A construction sketch is given at the end of this section.
- model#
- logg#
- preproc#
- optimizer#
- scheduler#
- forward(data: torch_geometric.data.Data, _preprocessed=False) torch.Tensor | dict[str, torch.Tensor] #
- data_preproc(data) torch_geometric.data.Data #
- configure_optimizers() Any #
- backward(*args: Any, **kwargs: Any) None #
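A construction sketch for TrackingModule, under stated assumptions: the ToyGNN model, the optimizer settings, and the data shapes are made up for illustration. The optimizer and scheduler arguments follow the pytorch_lightning.cli callable conventions (an OptimizerCallable is called with the model parameters, an LRSchedulerCallable with the resulting optimizer), which functools.partial provides here; configure_optimizers is presumably where these callables are invoked.

```python
from functools import partial

import torch
from torch_geometric.data import Data

from gnn_tracking.training.base import TrackingModule


class ToyGNN(torch.nn.Module):
    """Hypothetical stand-in for a real tracking GNN."""

    def forward(self, data: Data) -> torch.Tensor:
        return data.x.sum(dim=-1)  # placeholder node scores


lmodule = TrackingModule(
    model=ToyGNN(),
    # OptimizerCallable: builds the optimizer from the model parameters
    optimizer=partial(torch.optim.Adam, lr=1e-3),
    # LRSchedulerCallable: builds the scheduler from the optimizer instance
    scheduler=partial(torch.optim.lr_scheduler.ConstantLR, factor=1.0),
)

data = Data(x=torch.randn(10, 3), edge_index=torch.randint(0, 10, (2, 20)))
# forward() is expected to apply data_preproc (if a preproc module is set)
# and then the wrapped model.
out = lmodule(data)
```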