gnn_tracking.training.ec#

Lightning module for edge classifier training.

Classes#

TrackingModule

Base class for all PyTorch Lightning modules in this project.

ECModule

Lightning module for edge classifier training.

Functions#

get_maximized_bcs(→ dict[str, float])

Calculate the best possible binary classification stats for a given output and y.

get_roc_auc_scores(true, predicted, max_fprs)

Calculate ROC AUC scores for a given set of maximum FPRs.

obj_from_or_to_hparams(→ Any)

Used to support initializing Python objects from hyperparameters.

denote_pt(→ Any)

Append a suffix designating the pt threshold.

tolerate_some_oom_errors(fct)

Decorator to tolerate a couple of out-of-memory (OOM) errors.

Module Contents#

gnn_tracking.training.ec.get_maximized_bcs(*, output: torch.Tensor, y: torch.Tensor, n_samples=200) dict[str, float]#

Calculate the best possible binary classification stats for a given output and y.

Parameters:
  • output – Predicted weights (model output)

  • y – True labels

  • n_samples – Number of thresholds to sample

Returns:

Dictionary of metrics
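
A minimal usage sketch (the tensor contents and the exact metric keys in the returned dictionary are illustrative assumptions, not taken from the source):

import torch

from gnn_tracking.training.ec import get_maximized_bcs

# Fake edge-classifier output and truth labels, for illustration only
output = torch.rand(1000)            # predicted edge weights in [0, 1]
y = (torch.rand(1000) > 0.5).long()  # binary truth labels

# Scans n_samples thresholds and reports the best achievable classification stats
metrics = get_maximized_bcs(output=output, y=y, n_samples=200)
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")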

gnn_tracking.training.ec.get_roc_auc_scores(true, predicted, max_fprs: Iterable[float | None])#

Calculate ROC AUC scores for a given set of maximum FPRs.
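
A usage sketch with fabricated sample data; the interpretation of None as "unrestricted ROC AUC" follows the type hint, not the source:

import torch

from gnn_tracking.training.ec import get_roc_auc_scores

true = (torch.rand(1000) > 0.5).long()  # binary truth labels
predicted = torch.rand(1000)            # predicted edge weights

# None requests the unrestricted ROC AUC; a float restricts the maximum FPR
scores = get_roc_auc_scores(true, predicted, max_fprs=[None, 0.01, 0.001])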

class gnn_tracking.training.ec.TrackingModule(model: torch.nn.Module, *, optimizer: pytorch_lightning.cli.OptimizerCallable = torch.optim.Adam, scheduler: pytorch_lightning.cli.LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR, preproc: torch.nn.Module | None = None)#

Bases: ImprovedLogLM

Base class for all PyTorch Lightning modules in this project.

forward(data: torch_geometric.data.Data, _preprocessed=False) torch.Tensor | dict[str, torch.Tensor]#
data_preproc(data) torch_geometric.data.Data#
configure_optimizers() Any#
backward(*args: Any, **kwargs: Any) None#
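
A minimal construction sketch. The wrapped model below is a placeholder, and the optimizer/scheduler callables follow the pytorch_lightning.cli convention of taking the parameters (or the optimizer) and returning the configured object:

import torch

from gnn_tracking.training.ec import TrackingModule

model = torch.nn.Linear(14, 1)  # placeholder; in practice one of the project's models

module = TrackingModule(
    model,
    optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
    scheduler=lambda opt: torch.optim.lr_scheduler.ConstantLR(opt),
)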
gnn_tracking.training.ec.obj_from_or_to_hparams(self: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin, key: str, obj: Any) Any#

Used to support initializing Python objects from hyperparameters: if obj is a Python object other than a dictionary, its hyperparameters (its class path and init args) are saved to self.hparams[key]. If obj is instead a dictionary, it is assumed that an object has to be restored from that information.
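
A sketch of the intended usage inside a module's __init__; the class and attribute names below are hypothetical, and whether any additional hyperparameter bookkeeping is required beforehand is an assumption:

import torch
from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin

from gnn_tracking.training.ec import obj_from_or_to_hparams

class MyModule(HyperparametersMixin):
    def __init__(self, loss_fct):
        super().__init__()
        # If loss_fct is an object, its class path and init args are stored under
        # self.hparams["loss_fct"]; if it is a dict (e.g. restored from a
        # checkpoint), the object is reconstructed from that information.
        self.loss_fct = obj_from_or_to_hparams(self, "loss_fct", loss_fct)

module = MyModule(loss_fct=torch.nn.BCEWithLogitsLoss())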

gnn_tracking.training.ec.denote_pt(inpt, pt_min=0.0) Any#

Append a suffix designating the pt threshold. If a string is given, the suffixed string is returned; if a dict is given, all keys are modified.
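
For example (the exact suffix format is an assumption, not taken from the source):

from gnn_tracking.training.ec import denote_pt

denote_pt("max_mcc", pt_min=0.9)                         # something like "max_mcc_pt0.9"
denote_pt({"max_mcc": 0.7, "roc_auc": 0.9}, pt_min=0.9)  # suffix appended to every key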

gnn_tracking.training.ec.tolerate_some_oom_errors(fct: Callable)#

Decorator to tolerate a couple of out-of-memory (OOM) errors.
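
A sketch of how the decorator is meant to be applied; the wrapped function and the exact recovery behavior (skipping vs. retrying the call) are assumptions:

from gnn_tracking.training.ec import tolerate_some_oom_errors

@tolerate_some_oom_errors
def train_one_step(batch):
    # If a CUDA out-of-memory error is raised here, the decorator catches it
    # (up to a limited number of occurrences) instead of aborting training.
    ...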

class gnn_tracking.training.ec.ECModule(*, loss_fct: torch.nn.Module, **kwargs)#

Bases: gnn_tracking.training.base.TrackingModule

Lightning module for edge classifier training.

get_losses(out: dict[str, Any], data: torch_geometric.data.Data) torch.Tensor#
training_step(batch: torch_geometric.data.Data, batch_idx: int) torch.Tensor | None#
validation_step(batch: torch_geometric.data.Data, batch_idx: int)#
highlight_metric(metric: str) bool#
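
A minimal construction sketch. The model and loss function below are placeholders (the project ships its own edge-classifier models and losses), and passing model through **kwargs to the TrackingModule base class is an assumption based on the signatures above:

import torch

from gnn_tracking.training.ec import ECModule

model = torch.nn.Linear(14, 1)           # placeholder edge classifier
loss_fct = torch.nn.BCEWithLogitsLoss()  # placeholder edge loss

module = ECModule(model=model, loss_fct=loss_fct)

# Training then follows the usual Lightning pattern (data loading omitted):
# trainer = pytorch_lightning.Trainer(max_epochs=10)
# trainer.fit(module, datamodule=...)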