gnn_tracking.training.ml
PyTorch Lightning module with training and validation steps for the metric learning approach to graph construction.
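In the metric learning approach, hits are mapped into an embedding space where hits from the same particle should lie close together and hits from different particles should lie far apart. The following minimal pure-Python sketch shows a generic hinge-embedding-style loss that captures this idea. The function names, the `margin` and `repulsive_weight` parameters, and the exact form of both terms are assumptions for illustration; this is not the loss that `MultiLossFct` actually computes. It returns a total plus a dict of float components, mirroring the `(loss, dict[str, float])` shape of `get_losses`.

```python
import math
from itertools import combinations


def distance(a: list[float], b: list[float]) -> float:
    """Euclidean distance between two embedding vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def hinge_embedding_loss(
    embeddings: list[list[float]],
    particle_ids: list[int],
    margin: float = 1.0,  # hypothetical: minimum desired distance between particles
    repulsive_weight: float = 1.0,  # hypothetical: weight of the repulsive term
) -> tuple[float, dict[str, float]]:
    """Toy metric-learning loss evaluated over all pairs of hits.

    Attractive term: mean squared distance between hits of the same particle.
    Repulsive term: mean squared hinge max(0, margin - d) between hits of
    different particles, so pairs already farther apart than `margin` cost 0.
    """
    attractive, repulsive = [], []
    for i, j in combinations(range(len(embeddings)), 2):
        d = distance(embeddings[i], embeddings[j])
        if particle_ids[i] == particle_ids[j]:
            attractive.append(d**2)
        else:
            repulsive.append(max(0.0, margin - d) ** 2)
    att = sum(attractive) / len(attractive) if attractive else 0.0
    rep = sum(repulsive) / len(repulsive) if repulsive else 0.0
    total = att + repulsive_weight * rep
    # Total loss plus its individual components (plain floats, for logging).
    return total, {"attractive": att, "repulsive": rep}
```

For example, two well-separated particles with two hits each only incur a small attractive cost, while two different-particle hits closer than `margin` are penalized by the repulsive term.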
Classes
MLModule | PyTorch Lightning module with training and validation steps for the metric learning approach to graph construction.
Module Contents
- class gnn_tracking.training.ml.MLModule(*, loss_fct: gnn_tracking.metrics.losses.MultiLossFct, gc_scanner: gnn_tracking.graph_construction.k_scanner.GraphConstructionKNNScanner | None = None, **kwargs)
Bases: gnn_tracking.training.base.TrackingModule
PyTorch Lightning module with training and validation steps for the metric learning approach to graph construction.
- gc_scanner
- get_losses(out: dict[str, Any], data: torch_geometric.data.Data) → tuple[torch.Tensor, dict[str, float]]
- training_step(batch: torch_geometric.data.Data, batch_idx: int) → torch.Tensor | None
- validation_step(batch: torch_geometric.data.Data, batch_idx: int)
- on_validation_epoch_end() → None
- highlight_metric(metric: str) → bool
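The optional `gc_scanner` argument takes a `GraphConstructionKNNScanner`. Judging from its name, it evaluates graphs built by connecting each embedded hit to its k nearest neighbours for several values of k. The pure-Python sketch below illustrates that idea under this assumption only; `knn_edges`, `true_edges`, `scan_k`, and the efficiency/purity definitions are hypothetical and are not the scanner's actual interface or metrics.

```python
import math


def knn_edges(points: list[list[float]], k: int) -> set[tuple[int, int]]:
    """Connect each point to its k nearest neighbours (undirected, stored as i < j)."""
    edges = set()
    for i, p in enumerate(points):
        # Sort all other points by Euclidean distance to p.
        others = sorted(
            (j for j in range(len(points)) if j != i),
            key=lambda j: math.dist(p, points[j]),
        )
        for j in others[:k]:
            edges.add((min(i, j), max(i, j)))
    return edges


def true_edges(particle_ids: list[int]) -> set[tuple[int, int]]:
    """All pairs of hits that belong to the same particle."""
    n = len(particle_ids)
    return {
        (i, j)
        for i in range(n)
        for j in range(i + 1, n)
        if particle_ids[i] == particle_ids[j]
    }


def scan_k(
    points: list[list[float]], particle_ids: list[int], ks: list[int]
) -> dict[int, dict[str, float]]:
    """For each k, report edge efficiency (recall) and purity (precision)."""
    truth = true_edges(particle_ids)
    results = {}
    for k in ks:
        edges = knn_edges(points, k)
        found = len(edges & truth)
        results[k] = {
            # Fraction of true same-particle pairs recovered by the kNN graph.
            "efficiency": found / len(truth) if truth else 0.0,
            # Fraction of kNN edges that connect hits of the same particle.
            "purity": found / len(edges) if edges else 0.0,
        }
    return results
```

Increasing k can only add edges, so efficiency tends to rise while purity tends to fall; scanning k makes this trade-off visible for a given embedding.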