gnn_tracking.utils.lightning#
Classes#
- StandardError – A torch metric that computes the standard error.
- SimpleTqdmProgressBar – Fallback progress bar that creates a new tqdm bar for each epoch.
Functions#
- save_sub_hyperparameters – Take hyperparameters from obj and save them to self under the key key.
- load_obj_from_hparams – Load object from hyperparameters.
- obj_from_or_to_hparams – Used to support initializing Python objects from hyperparameters.
- get_object_from_path – Get object from path (string) to its code location.
- get_lightning_module – Get a lightning module (specified by class_path, a string) and load a checkpoint.
- get_model – Get torch model (specified by class_path, a string) and load a checkpoint.
- find_latest_checkpoint – Find the latest lightning checkpoint.
Module Contents#
- gnn_tracking.utils.lightning.save_sub_hyperparameters(self: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin, key: str, obj: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin | dict, errors: Literal['warn', 'raise'] = 'warn') → None #
Take hyperparameters from obj and save them to self under the key key.
- Parameters:
self – The object to save the hyperparameters to.
key – The key under which to save the hyperparameters.
obj – The object to take the hyperparameters from.
errors – Whether to raise an error or just warn.
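A minimal sketch of nesting hyperparameters this way; Encoder, Model, and hidden_dim are hypothetical names used for illustration, not part of gnn_tracking:

```python
from pytorch_lightning.core.mixins import HyperparametersMixin

from gnn_tracking.utils.lightning import save_sub_hyperparameters


class Encoder(HyperparametersMixin):
    def __init__(self, hidden_dim: int = 64):
        super().__init__()
        self.save_hyperparameters()


class Model(HyperparametersMixin):
    def __init__(self):
        super().__init__()
        encoder = Encoder(hidden_dim=128)
        # Save the encoder's hyperparameters under self.hparams["encoder"]
        save_sub_hyperparameters(self, "encoder", encoder)
```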
- gnn_tracking.utils.lightning.load_obj_from_hparams(hparams: dict[str, Any], key: str = '') → Any #
Load object from hyperparameters.
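A hedged sketch of the restore direction; the exact dict layout below (class path plus init args, following the description of obj_from_or_to_hparams) and my_package.Encoder are assumptions for illustration:

```python
from gnn_tracking.utils.lightning import load_obj_from_hparams

# Assumed layout, matching the "class path and init args" description of
# obj_from_or_to_hparams; "my_package.Encoder" is a hypothetical class.
hparams = {
    "encoder": {
        "class_path": "my_package.Encoder",
        "init_args": {"hidden_dim": 128},
    }
}
encoder = load_obj_from_hparams(hparams, key="encoder")
```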
- gnn_tracking.utils.lightning.obj_from_or_to_hparams(self: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin, key: str, obj: Any) → Any #
Used to support initializing Python objects from hyperparameters: If obj is a Python object other than a dictionary, its hyperparameters (its class path and init args) are saved to self.hparams[key]. If obj is instead a dictionary, it is assumed that an object has to be restored from this information.
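A sketch of this round trip in a module's __init__, assuming the behavior described above; Model and encoder are illustrative names:

```python
from pytorch_lightning.core.mixins import HyperparametersMixin

from gnn_tracking.utils.lightning import obj_from_or_to_hparams


class Model(HyperparametersMixin):
    def __init__(self, encoder):
        super().__init__()
        # If encoder is a live object, its class path and init args are
        # saved to self.hparams["encoder"]; if it is a dict, the object
        # is restored from that information instead.
        self.encoder = obj_from_or_to_hparams(self, "encoder", encoder)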
- gnn_tracking.utils.lightning.get_object_from_path(path: str, init_args: dict[str, Any] | None = None) → Any #
Get object from path (string) to its code location.
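For example (a sketch; that passing init_args triggers instantiation is an assumption based on the signature):

```python
from gnn_tracking.utils.lightning import get_object_from_path

# Resolve a dotted path to a class and, presumably, instantiate it with
# init_args; torch.nn.ReLU is just a conveniently importable example.
relu = get_object_from_path("torch.nn.ReLU", init_args={})
```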
- gnn_tracking.utils.lightning.get_lightning_module(class_path: str, chkpt_path: str = '', *, freeze: bool = True, device: str | None = None) → pytorch_lightning.LightningModule | None #
Get a lightning module (specified by class_path, a string) and load a checkpoint.
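A usage sketch; both the class path and the checkpoint path below are hypothetical placeholders:

```python
from gnn_tracking.utils.lightning import get_lightning_module

lmodule = get_lightning_module(
    "my_package.MyLightningModule",  # hypothetical class path
    chkpt_path="lightning_logs/version_0/checkpoints/last.ckpt",
    freeze=True,
)
```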
- gnn_tracking.utils.lightning.get_model(class_path: str, chkpt_path: str = '', freeze: bool = False, whole_module: bool = False, device: None | str = None) → torch.nn.Module | None #
Get torch model (specified by class_path, a string) and load a checkpoint. Uses get_lightning_module to get the model.
- Parameters:
class_path – The path to the lightning module class
chkpt_path – The path to the checkpoint. If no checkpoint is specified, we return None.
freeze – Whether to freeze the model
whole_module – Whether to return the whole lightning module or just the model
device – Device to load the model on.
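A usage sketch following the parameter descriptions above; the class path and checkpoint path are placeholders:

```python
from gnn_tracking.utils.lightning import get_model

# With an empty chkpt_path the function returns None, per the
# parameter description above.
model = get_model(
    "my_package.MyLightningModule",
    chkpt_path="lightning_logs/version_0/checkpoints/last.ckpt",
    freeze=True,
)
```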
- class gnn_tracking.utils.lightning.StandardError#
Bases: torchmetrics.Metric
A torch metric that computes the standard error. This is necessary because LightningModule.log doesn’t take custom reduce functions.
- update(x: torch.Tensor)#
- compute()#
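A minimal sketch of the usual torchmetrics update/compute cycle; that update accepts a 1D tensor of sample values is an assumption based on the signature above:

```python
import torch

from gnn_tracking.utils.lightning import StandardError

metric = StandardError()
metric.update(torch.tensor([1.0, 2.0, 3.0, 4.0]))
# Standard error of the accumulated values (std / sqrt(n))
se = metric.compute()
```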
- class gnn_tracking.utils.lightning.SimpleTqdmProgressBar#
Bases: pytorch_lightning.callbacks.ProgressBar
Fallback progress bar that creates a new tqdm bar for each epoch. Adapted from Lightning-AI/lightning#2189.
- bar = None#
- enabled = True#
- property is_enabled#
- on_train_epoch_start(trainer, pl_module)#
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)#
- on_validation_epoch_end(trainer, pl_module) → None #
- disable()#
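Like any Lightning callback, it is passed to the Trainer; a minimal sketch:

```python
import pytorch_lightning as pl

from gnn_tracking.utils.lightning import SimpleTqdmProgressBar

trainer = pl.Trainer(callbacks=[SimpleTqdmProgressBar()])
```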
- gnn_tracking.utils.lightning.find_latest_checkpoint(log_dir: os.PathLike, trial_name: str = '') → pathlib.Path #
Find the latest lightning checkpoint.
- Parameters:
log_dir (os.PathLike) – Path to the directory of your trial or to the directory of the experiment if trial_name is specified.
trial_name (str, optional) – Name of the trial if log_dir is the directory of the experiment.
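A usage sketch with a hypothetical directory layout:

```python
from pathlib import Path

from gnn_tracking.utils.lightning import find_latest_checkpoint

# log_dir points at the experiment directory, so trial_name is needed
# to pick out the trial (both names here are placeholders).
ckpt = find_latest_checkpoint(
    Path("lightning_logs/my_experiment"), trial_name="my_trial"
)
```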