:py:mod:`gnn_tracking.training.base`
====================================

.. py:module:: gnn_tracking.training.base

.. autoapi-nested-parse::

   Base class used for all PyTorch Lightning modules.



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.training.base.ImprovedLogLM
   gnn_tracking.training.base.TrackingModule




.. py:class:: ImprovedLogLM(**kwargs)


   Bases: :py:obj:`pytorch_lightning.LightningModule`

   This subclass of :py:class:`~pytorch_lightning.LightningModule` adds some
   logging conveniences, e.g., logging statistical uncertainties (estimated from
   batch-to-batch variation) and printing the validation metrics to the console
   after each validation epoch.
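
   The estimator behind the batch-to-batch uncertainties is not stated on this
   page. As an assumption only, if the uncertainty of a logged quantity is the
   unweighted standard error of its mean over the :math:`n` batches of an epoch
   (ignoring ``batch_size``), with per-batch values :math:`x_1, \dots, x_n`, it
   would read:

   .. math::

      \bar{x} = \frac{1}{n} \sum_{i=1}^{n} x_i, \qquad
      \sigma_{\bar{x}} = \frac{1}{\sqrt{n}}
      \sqrt{\frac{1}{n-1} \sum_{i=1}^{n} \left(x_i - \bar{x}\right)^2}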

   .. py:method:: log_dict_with_errors(dct: dict[str, float], batch_size=None) -> None

      Log a dictionary of values with their statistical uncertainties.

      This method only starts the calculation of the uncertainties; it does
      not log them. They are logged by ``_log_errors``, which must be called
      at the end of the train/val/test epoch. The epoch-end hooks of this
      class already take care of this.


   .. py:method:: _log_errors() -> None

      Log the uncertainties calculated in :py:meth:`log_dict_with_errors`.
      Must be called at the end of the train/val/test epoch.
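
      The two-phase pattern of :py:meth:`log_dict_with_errors` and
      ``_log_errors`` (collect values during the epoch, finish the calculation
      at the end of the epoch) can be sketched in plain Python. This is a
      minimal sketch under stated assumptions, not the class's implementation:
      the helper ``BatchErrorAccumulator`` is hypothetical, batch sizes are
      ignored, and the uncertainty is assumed to be the standard error of the
      mean over batches.

      .. code-block:: python

         import math
         from collections import defaultdict


         class BatchErrorAccumulator:
             """Hypothetical helper illustrating collect-then-finalize logging.

             ``update`` plays the role of ``log_dict_with_errors`` (collect
             per-batch values); ``compute_and_reset`` plays the role of
             ``_log_errors`` (finish the calculation at epoch end).
             Assumption: batches are weighted equally (batch size ignored).
             """

             def __init__(self) -> None:
                 self._values: dict[str, list[float]] = defaultdict(list)

             def update(self, dct: dict[str, float]) -> None:
                 # Only record the per-batch values; nothing is logged yet.
                 for key, value in dct.items():
                     self._values[key].append(float(value))

             def compute_and_reset(self) -> dict[str, tuple[float, float]]:
                 # Return (mean, standard error of the mean) per key, then clear
                 # the stored values so the next epoch starts from scratch.
                 results = {}
                 for key, values in self._values.items():
                     n = len(values)
                     mean = sum(values) / n
                     if n > 1:
                         var = sum((v - mean) ** 2 for v in values) / (n - 1)
                         err = math.sqrt(var / n)
                     else:
                         # A single batch gives no spread to estimate from.
                         err = float("nan")
                     results[key] = (mean, err)
                 self._values.clear()
                 return results

      Calling ``update`` once per batch and ``compute_and_reset`` once per
      epoch-end hook mirrors the division of labour described above; the reset
      ensures each epoch's uncertainty uses only that epoch's batches.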


   .. py:method:: on_train_epoch_end(*args) -> None


   .. py:method:: on_validation_epoch_end() -> None


   .. py:method:: on_test_epoch_end() -> None



.. py:class:: TrackingModule(model: torch.nn.Module, *, optimizer: pytorch_lightning.cli.OptimizerCallable = torch.optim.Adam, scheduler: pytorch_lightning.cli.LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR, preproc: torch.nn.Module | None = None)


   Bases: :py:obj:`ImprovedLogLM`

   Base class for all PyTorch Lightning modules in this project.

   .. py:method:: forward(data: torch_geometric.data.Data, _preprocessed=False) -> torch.Tensor | dict[str, torch.Tensor]


   .. py:method:: data_preproc(data) -> torch_geometric.data.Data
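
      The signatures suggest, but do not document, that ``data_preproc``
      applies the optional ``preproc`` module passed to the constructor and
      that ``forward`` skips this step when ``_preprocessed=True``. Under that
      assumption, the pattern can be sketched in plain Python; the class
      ``PreprocSketch`` is hypothetical and not part of ``gnn_tracking``.

      .. code-block:: python

         from typing import Any, Callable, Optional


         class PreprocSketch:
             """Hypothetical stand-in for optional preprocessing before a model."""

             def __init__(
                 self,
                 model: Callable[[Any], Any],
                 preproc: Optional[Callable[[Any], Any]] = None,
             ) -> None:
                 self.model = model
                 self.preproc = preproc

             def data_preproc(self, data: Any) -> Any:
                 # Apply the optional preprocessing step; identity if none given.
                 if self.preproc is None:
                     return data
                 return self.preproc(data)

             def forward(self, data: Any, _preprocessed: bool = False) -> Any:
                 # Skip preprocessing if the caller already applied it, so the
                 # same transformation is never applied twice.
                 if not _preprocessed:
                     data = self.data_preproc(data)
                 return self.model(data)

      A flag like ``_preprocessed`` lets a caller run the preprocessing once
      (for example, to inspect the transformed data) and then pass the result
      to ``forward`` without transforming it again.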


   .. py:method:: configure_optimizers() -> Any
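
      The constructor takes ``optimizer`` and ``scheduler`` as factories
      (``OptimizerCallable`` and ``LRSchedulerCallable`` from
      ``pytorch_lightning.cli``) rather than as instances. A minimal sketch of
      how such factories are commonly turned into objects, assuming (not
      confirmed on this page) that ``configure_optimizers`` follows this
      pattern; the function ``configure_optimizers_sketch`` is hypothetical.

      .. code-block:: python

         from functools import partial
         from typing import Any, Callable, Iterable

         import torch


         def configure_optimizers_sketch(
             model: torch.nn.Module,
             optimizer: Callable[[Iterable], torch.optim.Optimizer] = torch.optim.Adam,
             scheduler: Callable[[torch.optim.Optimizer], Any] = torch.optim.lr_scheduler.ConstantLR,
         ) -> dict[str, Any]:
             # The optimizer factory receives the trainable parameters...
             opt = optimizer(model.parameters())
             # ...and the scheduler factory receives the optimizer it controls.
             sched = scheduler(opt)
             # One of the return formats accepted by
             # LightningModule.configure_optimizers.
             return {"optimizer": opt, "lr_scheduler": sched}


         # Extra hyperparameters can be bound in advance with functools.partial.
         model = torch.nn.Linear(4, 2)
         config = configure_optimizers_sketch(
             model,
             optimizer=partial(torch.optim.Adam, lr=1e-3),
             scheduler=partial(torch.optim.lr_scheduler.StepLR, step_size=10),
         )

      Note that PyTorch's ``ConstantLR`` is not a no-op by default: it scales
      the learning rate by ``factor`` (default ``1/3``) for the first
      ``total_iters`` (default ``5``) scheduler steps.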


   .. py:method:: backward(*args: Any, **kwargs: Any) -> None



