gnn_tracking.training.callbacks#
Classes#

- PrintValidationMetrics – This callback prints the validation metrics after every epoch.
- ExpandWandbConfig – This callback adds more information to Weights & Biases hyperparameters.

Functions#

- format_results_table – Format a dictionary of results as a rich table.
Module Contents#
- gnn_tracking.training.callbacks.format_results_table(results: dict[str, float], *, header: str = '', printed_results_filter: Callable[[str], bool] | None = None, highlight_metric: Callable[[str], bool] | None = None) → rich.table.Table #
Format a dictionary of results as a rich table.
- Parameters:
results – Dictionary of results
header – Header to prepend to the log message
printed_results_filter – Function that takes a metric name and returns whether it should be printed in the log output. If None: Print everything
highlight_metric – Function that takes a metric name and returns whether it should be highlighted in the log output. If None: Don’t highlight anything
- Returns:
Rich table
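A sketch of how the two predicate arguments are typically written. The example applies them to a plain dictionary to show which metrics would be printed and which highlighted; the metric names and values are made up for illustration.

```python
# Hypothetical predicates for format_results_table's keyword arguments.
# Each takes a metric name and returns a bool.

results = {
    "total_loss": 0.42,
    "attractive_loss": 0.13,
    "repulsive_loss": 0.29,
    "n_edges": 120_000.0,
}

def printed_results_filter(name: str) -> bool:
    """Only print loss metrics, hiding bookkeeping values."""
    return name.endswith("_loss")

def highlight_metric(name: str) -> bool:
    """Highlight the headline metric."""
    return name == "total_loss"

# This mirrors the selection performed before rendering the rich table:
# filter the metrics first, then mark the ones to highlight.
printed = {k: v for k, v in results.items() if printed_results_filter(k)}
highlighted = [k for k in printed if highlight_metric(k)]
```

With these predicates, `format_results_table(results, header="Epoch 3", printed_results_filter=printed_results_filter, highlight_metric=highlight_metric)` would render a table containing only the three loss rows, with `total_loss` highlighted.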
- class gnn_tracking.training.callbacks.PrintValidationMetrics#
Bases: pytorch_lightning.Callback
This callback prints the validation metrics after every epoch.
If the lightning module has a printed_results_filter attribute, only metrics for which this function returns True are printed. If the lightning module has a highlight_metric attribute, metrics for which this function returns True are highlighted in the output.
- on_validation_end(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule) → None #
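The attribute contract can be illustrated without a running Trainer. `DummyModule` and `select_for_printing` below are hypothetical stand-ins for the lightning module and the selection the callback performs, assuming both attributes are predicates on metric names.

```python
# Hypothetical stand-in for a LightningModule that opts into filtering.
class DummyModule:
    @staticmethod
    def printed_results_filter(name: str) -> bool:
        # Hide internal bookkeeping metrics from the printed table.
        return not name.startswith("_")

    @staticmethod
    def highlight_metric(name: str) -> bool:
        return name == "val_loss"

def select_for_printing(module, metrics: dict) -> dict:
    """Mimic the callback's filtering: fall back to printing everything
    if the module does not define printed_results_filter."""
    flt = getattr(module, "printed_results_filter", lambda _name: True)
    return {k: v for k, v in metrics.items() if flt(k)}

metrics = {"val_loss": 0.1, "_debug_counter": 7.0}
shown = select_for_printing(DummyModule(), metrics)
```

A module without these attributes still works: every metric is printed and nothing is highlighted.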
- class gnn_tracking.training.callbacks.ExpandWandbConfig#
Bases: pytorch_lightning.Callback
This callback adds more information to Weights & Biases hyperparameters:
- Information about the optimizer/scheduler
- Information from the datamodule
- Information about version numbers and git hashes
- SLURM job id (if set)
This also avoids the problem of hyperparameters being synced only at the end of a run rather than at the beginning, which in particular means they would be lost if the run were interrupted.
- _wandb_logger = None#
- _tensorboard_logger = None#
- _find_loggers(trainer: pytorch_lightning.Trainer) → None #
- _get_config() → dict #
- on_train_start(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule) → None #
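A rough sketch of the kind of dictionary `_get_config` assembles. The function name and keys below are illustrative, not the callback's actual implementation; it only shows how the SLURM job id and a git hash can be collected defensively.

```python
import os
import subprocess

def gather_extra_config() -> dict:
    """Hypothetical version of the extra hyperparameters that
    ExpandWandbConfig logs: versioning info plus the SLURM job id if set."""
    config: dict = {}
    # SLURM job id (only present when running under SLURM).
    job_id = os.environ.get("SLURM_JOB_ID")
    if job_id is not None:
        config["slurm_job_id"] = job_id
    # Git hash of the working tree, if we are inside a repository.
    try:
        config["git_hash"] = subprocess.check_output(
            ["git", "rev-parse", "HEAD"], text=True
        ).strip()
    except (OSError, subprocess.CalledProcessError):
        pass
    return config

extra = gather_extra_config()
```

Because every field is optional, the callback can merge this dictionary into the Weights & Biases config without assuming anything about the environment it runs in.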