gnn_tracking.training.ml

PyTorch Lightning module with training and validation steps for the metric learning approach to graph construction.
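In the metric learning approach, a network typically maps each hit to a point in a learned embedding space so that hits from the same particle land close together. Edges are then drawn between nearby points. The following is a minimal pure-Python sketch of that construction step, assuming a brute-force k-nearest-neighbour rule. `knn_graph` is a hypothetical helper, not part of `gnn_tracking`, and the library may build its graphs differently (for example with a radius cut or a spatial index).

```python
import math


def knn_graph(embeddings: list[list[float]], k: int) -> list[tuple[int, int]]:
    """Return directed edges (i, j) linking each point to its k nearest neighbours.

    Hypothetical helper: brute-force O(n^2) search, for illustration only.
    """
    edges = []
    for i, x in enumerate(embeddings):
        # Distances from point i to every other point (self excluded),
        # sorted so the closest neighbours come first.
        dists = sorted(
            (math.dist(x, y), j) for j, y in enumerate(embeddings) if j != i
        )
        edges.extend((i, j) for _, j in dists[:k])
    return edges


# Two well-separated groups in a 2D embedding space, e.g. hits of two tracks.
points = [[0.0, 0.0], [0.1, 0.0], [0.0, 0.3], [5.0, 5.0], [5.1, 5.0]]
edges = knn_graph(points, k=1)
print(edges)  # [(0, 1), (1, 0), (2, 0), (3, 4), (4, 3)]
```

With `k=1`, every edge stays inside its own group because the embedding separates the two groups well. A poorer embedding, or a larger `k`, would start to add edges between them.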

Classes

MLModule

PyTorch Lightning module with training and validation steps for the metric learning approach to graph construction.

Module Contents

class gnn_tracking.training.ml.MLModule(*, loss_fct: gnn_tracking.metrics.losses.MultiLossFct, gc_scanner: gnn_tracking.graph_construction.k_scanner.GraphConstructionKNNScanner | None = None, **kwargs)

Bases: gnn_tracking.training.base.TrackingModule

PyTorch Lightning module with training and validation steps for the metric learning approach to graph construction.

loss_fct: gnn_tracking.metrics.losses.metric_learning.GraphConstructionHingeEmbeddingLoss
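The `loss_fct` attribute is typed as a hinge embedding loss for graph construction. The following is a minimal pure-Python sketch of a hinge embedding loss as commonly used in metric learning, not the `gnn_tracking` implementation. Its attractive term pulls pairs of hits from the same particle together. Its repulsive term pushes pairs from different particles apart, but only while they are closer than a margin. The function name, the `margin` default, and the unweighted means are assumptions. The library's normalization, pair selection, and weighting may differ.

```python
import math


def _mean(values: list[float]) -> float:
    """Arithmetic mean, defined as 0.0 for an empty list."""
    return sum(values) / len(values) if values else 0.0


def hinge_embedding_loss(
    embeddings: list[list[float]],
    edges: list[tuple[int, int]],
    is_true: list[bool],
    margin: float = 1.0,
) -> tuple[float, float]:
    """Return the (attractive, repulsive) terms of a hinge embedding loss.

    Hypothetical sketch: true edges join hits of the same particle.
    """
    attractive, repulsive = [], []
    for (i, j), true_edge in zip(edges, is_true):
        d = math.dist(embeddings[i], embeddings[j])
        if true_edge:
            # Same particle: penalize any separation (squared distance).
            attractive.append(d**2)
        else:
            # Different particles: penalize only separations below the margin.
            repulsive.append(max(0.0, margin - d) ** 2)
    return _mean(attractive), _mean(repulsive)


emb = [[0.0, 0.0], [0.5, 0.0], [0.0, 0.75]]
att, rep = hinge_embedding_loss(emb, [(0, 1), (0, 2)], [True, False])
# att == 0.25 (true edge at distance 0.5); rep == 0.0625 ((1.0 - 0.75) ** 2)
```

The margin keeps the repulsive term bounded. False pairs that are already farther apart than `margin` contribute nothing, so the optimizer concentrates on the neighbourhoods that later become graph edges.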
gc_scanner
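`gc_scanner` is an optional `GraphConstructionKNNScanner`. Its name suggests that it evaluates kNN graph construction over several values of k. The following is a minimal sketch of such a scan, assuming the figure of merit is edge efficiency, meaning the fraction of same-particle hit pairs that end up as edges, reported together with the graph size. `knn_edges` and `scan_k` are hypothetical helpers. The real scanner may compute other metrics or search k differently.

```python
import math
from itertools import combinations


def knn_edges(points: list[list[float]], k: int) -> set[tuple[int, int]]:
    """Undirected edges (i < j) linking each point to its k nearest neighbours."""
    edges = set()
    for i, x in enumerate(points):
        nearest = sorted(
            (math.dist(x, y), j) for j, y in enumerate(points) if j != i
        )[:k]
        for _, j in nearest:
            edges.add((min(i, j), max(i, j)))
    return edges


def scan_k(
    points: list[list[float]], particle_ids: list[int], ks: list[int]
) -> list[tuple[int, float, int]]:
    """For each k, return (k, edge efficiency, number of edges)."""
    # All pairs of hits that belong to the same particle.
    true_pairs = {
        (i, j)
        for i, j in combinations(range(len(points)), 2)
        if particle_ids[i] == particle_ids[j]
    }
    results = []
    for k in ks:
        edges = knn_edges(points, k)
        efficiency = len(edges & true_pairs) / len(true_pairs)
        results.append((k, efficiency, len(edges)))
    return results


# Two particles with three embedded hits each.
hits = [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [0.0, 10.0], [1.0, 10.0], [3.0, 10.0]]
pids = [0, 0, 0, 1, 1, 1]
for k, eff, n_edges in scan_k(hits, pids, [1, 2, 3]):
    print(f"k={k}: efficiency={eff:.2f}, edges={n_edges}")
# k=1: efficiency=0.67, edges=4
# k=2: efficiency=1.00, edges=6
# k=3: efficiency=1.00, edges=9
```

Raising k recovers more true pairs, but past the point of full efficiency it only adds false edges (k=3 here adds three cross-particle edges). This is the kind of trade-off a scan over k is meant to expose.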
get_losses(out: dict[str, Any], data: torch_geometric.data.Data) -> tuple[torch.Tensor, dict[str, float]]
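`get_losses` returns a pair: a loss tensor and a `dict[str, float]`. The following sketch shows one plausible way to produce that shape from several named loss components. It assumes the total is a weighted sum and that every component, plus the total, is logged as a plain float. Plain floats stand in for tensors so the example runs without PyTorch. `combine_losses`, the default weight of 1.0, and the `"total"` key are hypothetical and not taken from `MultiLossFct`.

```python
def combine_losses(
    components: dict[str, float], weights: dict[str, float]
) -> tuple[float, dict[str, float]]:
    """Weighted sum of named loss components plus a flat dict for logging.

    Hypothetical sketch: components without an explicit weight count with 1.0.
    """
    total = sum(weights.get(name, 1.0) * value for name, value in components.items())
    # Log each component unweighted, plus the weighted total.
    log = {name: float(value) for name, value in components.items()}
    log["total"] = float(total)
    return total, log


total, log = combine_losses({"attractive": 0.25, "repulsive": 0.0625}, {"repulsive": 2.0})
# total == 0.375; log == {"attractive": 0.25, "repulsive": 0.0625, "total": 0.375}
```

Returning the total separately from the float dict keeps one value for backpropagation, while the dict holds plain numbers that are cheap to log.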
training_step(batch: torch_geometric.data.Data, batch_idx: int) -> torch.Tensor | None
validation_step(batch: torch_geometric.data.Data, batch_idx: int)
on_validation_epoch_end() -> None
highlight_metric(metric: str) -> bool