:py:mod:`gnn_tracking.training.callbacks`
=========================================

.. py:module:: gnn_tracking.training.callbacks


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.training.callbacks.PrintValidationMetrics
   gnn_tracking.training.callbacks.ExpandWandbConfig



Functions
~~~~~~~~~

.. autoapisummary::

   gnn_tracking.training.callbacks.format_results_table



.. py:function:: format_results_table(results: dict[str, float], *, header: str = '', printed_results_filter: Callable[[str], bool] | None = None, highlight_metric: Callable[[str], bool] | None = None) -> rich.table.Table

   Format a dictionary of results as a rich table.

   :param results: Dictionary mapping metric names to values
   :param header: Header to prepend to the log message
   :param printed_results_filter: Function that takes a metric name and returns
                                  whether it should be printed in the log output.
                                  If ``None``, all metrics are printed.
   :param highlight_metric: Function that takes a metric name and returns
                            whether it should be highlighted in the log output.
                            If ``None``, no metric is highlighted.

   :returns: Rich table
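
   The filtering and highlighting rules above can be illustrated with a
   standard-library-only sketch. ``sketch_results_rows`` is a hypothetical
   helper, not part of ``gnn_tracking``: it returns plain
   ``(name, value, highlighted)`` tuples instead of a ``rich.table.Table``,
   and the five-decimal formatting is an arbitrary choice for illustration.

   .. code-block:: python

      from __future__ import annotations

      from collections.abc import Callable


      def sketch_results_rows(
          results: dict[str, float],
          *,
          printed_results_filter: Callable[[str], bool] | None = None,
          highlight_metric: Callable[[str], bool] | None = None,
      ) -> list[tuple[str, str, bool]]:
          """Hypothetical helper: (name, formatted value, highlighted) rows."""
          rows = []
          for name, value in results.items():
              # None means "print everything"
              if printed_results_filter is not None and not printed_results_filter(name):
                  continue
              # None means "highlight nothing"
              highlighted = highlight_metric is not None and highlight_metric(name)
              rows.append((name, f"{value:.5f}", highlighted))
          return rows


      rows = sketch_results_rows(
          {"val_loss": 0.25, "accuracy": 0.9, "_debug": 1.0},
          printed_results_filter=lambda name: not name.startswith("_"),
          highlight_metric=lambda name: name == "accuracy",
      )
      # rows == [("val_loss", "0.25000", False), ("accuracy", "0.90000", True)]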


.. py:class:: PrintValidationMetrics


   Bases: :py:obj:`pytorch_lightning.Callback`

   This callback prints the validation metrics at the end of every validation
   loop (by default, once per training epoch).

   If the Lightning module has a ``printed_results_filter`` attribute, only
   metrics for which this function returns ``True`` are printed.
   If the Lightning module has a ``highlight_metric`` attribute, metrics for
   which this function returns ``True`` are highlighted in the output.
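
   The attribute lookup can be sketched without Lightning. ``ToyModule`` and
   ``pick_callables`` are hypothetical names, and the sketch assumes that a
   missing attribute is treated like ``None`` (print everything, highlight
   nothing), mirroring the defaults of ``format_results_table``; the source
   does not state how the callback performs the lookup.

   .. code-block:: python

      from __future__ import annotations

      from collections.abc import Callable


      class ToyModule:
          """Hypothetical stand-in for a LightningModule that opts in."""

          def printed_results_filter(self, name: str) -> bool:
              # Hide debug metrics (hypothetical naming convention)
              return not name.startswith("debug_")

          def highlight_metric(self, name: str) -> bool:
              # Highlight a single key metric (hypothetical metric name)
              return name == "val_loss"


      def pick_callables(
          module: object,
      ) -> tuple[Callable[[str], bool] | None, Callable[[str], bool] | None]:
          # Fall back to None when the module does not define the attribute
          return (
              getattr(module, "printed_results_filter", None),
              getattr(module, "highlight_metric", None),
          )


      print_filter, highlight = pick_callables(ToyModule())
      # print_filter("debug_grad_norm") is False; highlight("val_loss") is True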

   .. py:method:: on_validation_end(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule) -> None



.. py:class:: ExpandWandbConfig


   Bases: :py:obj:`pytorch_lightning.Callback`

   This callback adds more information to Weights & Biases hyperparameters:

   * Information about the optimizer and scheduler
   * Information from the datamodule
   * Version numbers and git hashes
   * SLURM job ID (if set)

   This also ensures that hyperparameters are synced when training starts rather
   than only at the end of the run, so they are not lost if the run is interrupted.
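
   The kind of information listed above can be gathered with the standard
   library alone. ``collect_run_metadata`` is a hypothetical helper, not the
   callback's ``_get_config``; the key names, the default package list, and
   reading the git hash of the current working directory (rather than of the
   installed package) are all assumptions made for this sketch.

   .. code-block:: python

      from __future__ import annotations

      import importlib.metadata
      import os
      import platform
      import subprocess
      from collections.abc import Iterable, Mapping
      from typing import Any


      def collect_run_metadata(
          env: Mapping[str, str] | None = None,
          packages: Iterable[str] = ("torch",),
      ) -> dict[str, Any]:
          """Hypothetical helper gathering extra hyperparameters for a W&B run."""
          env = os.environ if env is None else env
          config: dict[str, Any] = {"python_version": platform.python_version()}
          # Version numbers of installed distributions (skipped if not installed)
          for pkg in packages:
              try:
                  config[f"{pkg}_version"] = importlib.metadata.version(pkg)
              except importlib.metadata.PackageNotFoundError:
                  pass
          # SLURM exports SLURM_JOB_ID inside a job allocation
          job_id = env.get("SLURM_JOB_ID")
          if job_id is not None:
              config["slurm_job_id"] = job_id
          # Commit hash of the current working directory, if it is a git checkout
          try:
              config["git_hash"] = subprocess.run(
                  ["git", "rev-parse", "HEAD"],
                  capture_output=True,
                  text=True,
                  check=True,
              ).stdout.strip()
          except (OSError, subprocess.CalledProcessError):
              pass
          return config

   With Lightning's ``WandbLogger``, a dictionary like this could then be merged
   into the run configuration, for example via
   ``logger.experiment.config.update(...)``.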

   .. py:method:: _find_loggers(trainer: pytorch_lightning.Trainer) -> None


   .. py:method:: _get_config() -> dict


   .. py:method:: on_train_start(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule) -> None



