gnn_tracking.utils.lightning#

Module Contents#

Classes#

StandardError

A torch metric that computes the standard error.

SimpleTqdmProgressBar

Fallback progress bar that creates a new tqdm bar for each epoch.

Functions#

save_sub_hyperparameters(→ None)

Take hyperparameters from obj and save them to self under the key key.

load_obj_from_hparams(→ Any)

Load object from hyperparameters.

obj_from_or_to_hparams(→ Any)

Used to support initializing python objects from hyperparameters.

get_object_from_path(→ Any)

Get object from path (string) to its code location.

get_lightning_module(...)

Get model (specified by class_path, a string) and load a checkpoint.

get_model(→ torch.nn.Module | None)

Get torch model (specified by class_path, a string) and load a checkpoint.

find_latest_checkpoint(→ pathlib.Path)

Find the latest Lightning checkpoint.

gnn_tracking.utils.lightning.save_sub_hyperparameters(self: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin, key: str, obj: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin | dict, errors: Literal['warn', 'raise'] = 'warn') → None#

Take hyperparameters from obj and save them to self under the key key.

Parameters:
  • self – The object to save the hyperparameters to.

  • key – The key under which to save the hyperparameters.

  • obj – The object to take the hyperparameters from.

  • errors – Whether to raise an error or just warn.
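
A minimal usage sketch (the LossConfig and TrackingModule classes are hypothetical): inside the __init__ of a class that mixes in HyperparametersMixin, the hyperparameters of a configurable sub-object are stored under their own key:

    from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin

    from gnn_tracking.utils.lightning import save_sub_hyperparameters


    class LossConfig(HyperparametersMixin):
        """Hypothetical sub-object whose init args should be tracked."""

        def __init__(self, weight: float = 1.0):
            super().__init__()
            self.save_hyperparameters()


    class TrackingModule(HyperparametersMixin):
        def __init__(self):
            super().__init__()
            loss = LossConfig(weight=2.0)
            # Save the hyperparameters of `loss` under self.hparams["loss"],
            # warning (rather than raising) if that fails
            save_sub_hyperparameters(self, "loss", loss, errors="warn")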

gnn_tracking.utils.lightning.load_obj_from_hparams(hparams: dict[str, Any], key: str = '') → Any#

Load object from hyperparameters.
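
A hedged sketch of the call pattern, assuming hparams contains an entry previously written by save_sub_hyperparameters; the exact "class_path"/"init_args" layout shown here is an assumption based on the description of obj_from_or_to_hparams below:

    from gnn_tracking.utils.lightning import load_obj_from_hparams

    # Hyperparameters as they might look after save_sub_hyperparameters;
    # the dict layout below is an assumption, not a documented contract
    hparams = {
        "loss": {
            "class_path": "torch.nn.MSELoss",
            "init_args": {},
        }
    }
    loss = load_obj_from_hparams(hparams, key="loss")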

gnn_tracking.utils.lightning.obj_from_or_to_hparams(self: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin, key: str, obj: Any) → Any#

Used to support initializing python objects from hyperparameters: If obj is a python object other than a dictionary, its hyperparameters (class path and init args) are saved to self.hparams[key]. If obj is instead a dictionary, it is assumed that an object has to be restored from this information.
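
A minimal sketch of the intended symmetric use in an __init__ (the MyModule class is hypothetical): the same call handles both first construction and restoration from a checkpoint:

    from typing import Any

    from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin

    from gnn_tracking.utils.lightning import obj_from_or_to_hparams


    class MyModule(HyperparametersMixin):
        def __init__(self, loss: Any):
            super().__init__()
            # `loss` is either an actual object (then saved to hparams) or a
            # dict describing one (then the object is rebuilt from it)
            self.loss = obj_from_or_to_hparams(self, "loss", loss)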

gnn_tracking.utils.lightning.get_object_from_path(path: str, init_args: dict[str, Any] | None = None) → Any#

Get object from path (string) to its code location.
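
A short sketch, assuming that init_args, when given, is used to instantiate the imported class:

    from gnn_tracking.utils.lightning import get_object_from_path

    # Import torch.nn.Linear by its dotted path; presumably instantiated
    # with the given init args
    linear = get_object_from_path(
        "torch.nn.Linear", init_args={"in_features": 3, "out_features": 1}
    )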

gnn_tracking.utils.lightning.get_lightning_module(class_path: str, chkpt_path: str = '', *, freeze: bool = True, device: str | None = None) → pytorch_lightning.LightningModule | None#

Get model (specified by class_path, a string) and load a checkpoint.
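
A hedged example with placeholder class and checkpoint paths:

    from gnn_tracking.utils.lightning import get_lightning_module

    lmodule = get_lightning_module(
        "gnn_tracking.training.tc.TCModule",  # hypothetical class path
        chkpt_path="runs/example/checkpoints/last.ckpt",  # placeholder
        freeze=True,  # freeze the parameters for inference
    )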

gnn_tracking.utils.lightning.get_model(class_path: str, chkpt_path: str = '', freeze: bool = False, whole_module: bool = False, device: None | str = None) → torch.nn.Module | None#

Get torch model (specified by class_path, a string) and load a checkpoint. Uses get_lightning_module to get the model.

Parameters:
  • class_path – The path to the lightning module class

  • chkpt_path – The path to the checkpoint. If no checkpoint is specified, we return None.

  • freeze – Whether to freeze the model

  • whole_module – Whether to return the whole lightning module or just the model

  • device – The device to load the model to.
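
A minimal sketch with placeholder paths, extracting just the torch model from a checkpointed lightning module:

    from gnn_tracking.utils.lightning import get_model

    model = get_model(
        "gnn_tracking.training.tc.TCModule",  # hypothetical class path
        chkpt_path="runs/example/checkpoints/last.ckpt",  # placeholder
        freeze=True,
        whole_module=False,  # return only the inner torch model
    )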

class gnn_tracking.utils.lightning.StandardError#

Bases: torchmetrics.Metric

A torch metric that computes the standard error. This is necessary because LightningModule.log doesn’t take custom reduce functions.

update(x: torch.Tensor)#
compute()#
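
A usage sketch, assuming the metric accumulates all values passed to update and that compute returns the standard error of their mean (standard deviation divided by the square root of the sample count):

    import torch

    from gnn_tracking.utils.lightning import StandardError

    metric = StandardError()
    metric.update(torch.tensor([1.0, 2.0, 3.0]))
    metric.update(torch.tensor([4.0, 5.0]))
    sem = metric.compute()  # standard error over all five values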
class gnn_tracking.utils.lightning.SimpleTqdmProgressBar#

Bases: pytorch_lightning.callbacks.ProgressBar

Fallback progress bar that creates a new tqdm bar for each epoch. Adapted from a reply to Lightning-AI/lightning#2189.

property is_enabled#
on_train_epoch_start(trainer, pl_module)#
on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)#
on_validation_epoch_end(trainer, pl_module) → None#
disable()#
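
A sketch of hooking the fallback bar into a trainer, assuming the default constructor:

    from pytorch_lightning import Trainer

    from gnn_tracking.utils.lightning import SimpleTqdmProgressBar

    # Replaces the default progress bar; a fresh tqdm bar is created per epoch
    trainer = Trainer(max_epochs=3, callbacks=[SimpleTqdmProgressBar()])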
gnn_tracking.utils.lightning.find_latest_checkpoint(log_dir: os.PathLike, trial_name: str = '') → pathlib.Path#

Find the latest Lightning checkpoint.

Parameters:
  • log_dir (os.PathLike) – Path to the directory of your trial or to the directory of the experiment if trial_name is specified.

  • trial_name (str, optional) – Name of the trial if log_dir is the directory of the experiment.
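
Both documented call patterns, with placeholder directory names:

    from gnn_tracking.utils.lightning import find_latest_checkpoint

    # Point directly at a trial directory ...
    ckpt = find_latest_checkpoint("lightning_logs/my_trial")
    # ... or at the experiment directory plus the trial name
    ckpt = find_latest_checkpoint("lightning_logs", trial_name="my_trial")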