gnn_tracking.training.callbacks

Module Contents

Classes

PrintValidationMetrics

This callback prints the validation metrics after every epoch.

ExpandWandbConfig

This callback adds more information to the Weights & Biases hyperparameters (optimizer/scheduler, datamodule, version numbers and git hashes, SLURM job id).

Functions

format_results_table(→ rich.table.Table)

Format a dictionary of results as a rich table.

gnn_tracking.training.callbacks.format_results_table(results: dict[str, float], *, header: str = '', printed_results_filter: Callable[[str], bool] | None = None, highlight_metric: Callable[[str], bool] | None = None) → rich.table.Table

Format a dictionary of results as a rich table.

Parameters:
  • results – Dictionary of results

  • header – Header to prepend to the log message

  • printed_results_filter – Function that takes a metric name and returns whether it should be printed in the log output. If None: Print everything

  • highlight_metric – Function that takes a metric name and returns whether it should be highlighted in the log output. If None: Don’t highlight anything

Returns:

Rich table
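
A minimal usage sketch based on the signature above. The metric names, values, and header text are made up for illustration, and the import is guarded so that the snippet also runs where gnn_tracking is not installed.

```python
# Hypothetical validation metrics; names and values are illustrative only.
results: dict[str, float] = {
    "total": 0.42,
    "attractive": 0.10,
    "repulsive": 0.32,
    "_internal_counter": 7.0,
}


def printed_results_filter(name: str) -> bool:
    """Print every metric except those with a leading underscore."""
    return not name.startswith("_")


def highlight_metric(name: str) -> bool:
    """Highlight only the metric called "total"."""
    return name == "total"


# What the two predicates select, computed without the library.
printed = [name for name in results if printed_results_filter(name)]
highlighted = [name for name in printed if highlight_metric(name)]

try:
    from gnn_tracking.training.callbacks import format_results_table
except ImportError:
    # gnn_tracking is not available; the predicates above still work.
    format_results_table = None

if format_results_table is not None:
    # Keyword-only arguments, as in the documented signature.
    table = format_results_table(
        results,
        header="Validation epoch 3",
        printed_results_filter=printed_results_filter,
        highlight_metric=highlight_metric,
    )
```

The returned object is a rich.table.Table, so it can be rendered with a rich Console's print method.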

class gnn_tracking.training.callbacks.PrintValidationMetrics

Bases: pytorch_lightning.Callback

This callback prints the validation metrics after every epoch.

If the lightning module has a printed_results_filter attribute, only metrics for which this function returns True are printed. If the lightning module has a highlight_metric attribute, metrics for which this function returns True are highlighted in the output.
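
The two optional attributes could be provided as plain methods on the module. The sketch below uses stand-in classes rather than a real LightningModule, and resolve_predicates is a hypothetical helper showing one way a callback might look the attributes up; it is not taken from the library's implementation.

```python
from typing import Callable, Optional

Predicate = Callable[[str], bool]


class ModuleWithFilters:
    """Stand-in for a LightningModule that defines both optional attributes."""

    def printed_results_filter(self, name: str) -> bool:
        # Hide per-threshold metrics (hypothetical naming scheme).
        return "_thld" not in name

    def highlight_metric(self, name: str) -> bool:
        # Highlight a single headline metric (hypothetical name).
        return name == "total"


class ModuleWithoutFilters:
    """Stand-in for a module that defines neither attribute."""


def resolve_predicates(
    pl_module: object,
) -> tuple[Optional[Predicate], Optional[Predicate]]:
    """Hypothetical helper: fetch the optional predicates, defaulting to None."""
    return (
        getattr(pl_module, "printed_results_filter", None),
        getattr(pl_module, "highlight_metric", None),
    )


flt, hl = resolve_predicates(ModuleWithFilters())
no_flt, no_hl = resolve_predicates(ModuleWithoutFilters())
```

With the real classes, the callback would be registered through the callbacks argument of pytorch_lightning.Trainer, for example callbacks=[PrintValidationMetrics()].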

on_validation_end(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule) → None

class gnn_tracking.training.callbacks.ExpandWandbConfig

Bases: pytorch_lightning.Callback

This callback adds more information to Weights & Biases hyperparameters:

  • Information about the optimizer/scheduler

  • Information from the datamodule

  • Information about version numbers and git hashes

  • SLURM job id (if set)

This also avoids the problem of hyperparameters being synced only at the end of a run rather than at the beginning, in which case they are not saved if the run is interrupted.
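
As an illustration of the kind of information listed above, the following sketch gathers a git hash and the SLURM job id using only the standard library. The helper names and dictionary keys are hypothetical and not taken from gnn_tracking; SLURM_JOB_ID is the environment variable that SLURM sets for running jobs. How the result is merged into the Weights & Biases config is not shown.

```python
import os
import subprocess
from typing import Any, Mapping, Optional


def get_git_hash() -> Optional[str]:
    """Return the current commit hash, or None if it cannot be determined."""
    try:
        out = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # git is missing, or the working directory is not a repository.
        return None
    return out.stdout.strip()


def collect_extra_config(
    environ: Optional[Mapping[str, str]] = None,
) -> dict[str, Any]:
    """Hypothetical helper: collect extra values to log as hyperparameters."""
    env = os.environ if environ is None else environ
    config: dict[str, Any] = {"git_hash": get_git_hash()}
    job_id = env.get("SLURM_JOB_ID")
    if job_id is not None:
        # Only record the job id when running under SLURM.
        config["slurm_job_id"] = job_id
    return config
```

Passing the environment as an argument keeps the helper easy to test outside a SLURM job.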

_find_loggers(trainer: pytorch_lightning.Trainer) → None

_get_config() → dict

on_train_start(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule) → None