gnn_tracking.training.ec
Lightning module for edge classifier training.
Classes
ECModule | Lightning module for edge classifier training.
Module Contents
- class gnn_tracking.training.ec.ECModule(*, loss_fct: torch.nn.Module, **kwargs)
Bases: gnn_tracking.training.base.TrackingModule
Lightning module for edge classifier training.
- loss_fct
- get_losses(out: dict[str, Any], data: torch_geometric.data.Data) → torch.Tensor
- training_step(batch: torch_geometric.data.Data, batch_idx: int) → torch.Tensor | None
- validation_step(batch: torch_geometric.data.Data, batch_idx: int)
- highlight_metric(metric: str) → bool
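The reference above lists only signatures. As a rough, dependency-free sketch of how these members could fit together, the Python below mirrors the pattern the signatures suggest: a loss callable injected through the keyword-only loss_fct argument, applied by get_losses to the model output and the graph, returned from training_step and only recorded in validation_step. Everything concrete here is an assumption and not the gnn_tracking implementation: the names ToyECModule and ToyGraph, the model keyword, the output key "W" (predicted edge weights), the label attribute y, the metric names, and the choice of binary cross-entropy as the loss.

```python
import math
from dataclasses import dataclass
from typing import Any, Callable, Sequence


def binary_cross_entropy(pred: Sequence[float], truth: Sequence[float]) -> float:
    """Mean binary cross-entropy between predicted edge weights and 0/1 edge labels."""
    eps = 1e-7  # clamp predictions away from 0 and 1 to avoid log(0)
    total = 0.0
    for p, t in zip(pred, truth):
        p = min(max(p, eps), 1.0 - eps)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(truth)


@dataclass
class ToyGraph:
    """Hypothetical stand-in for torch_geometric.data.Data: holds only edge truth labels."""

    y: list[float]


class ToyECModule:
    """Hypothetical, dependency-free analogue of an edge-classifier training module."""

    def __init__(
        self,
        *,
        loss_fct: Callable[[Sequence[float], Sequence[float]], float],
        model: Callable[[ToyGraph], dict[str, Any]],
    ) -> None:
        self.loss_fct = loss_fct  # injected loss, analogous to ECModule's loss_fct
        self.model = model  # hypothetical keyword; stands in for whatever **kwargs configure
        self.logged: dict[str, float] = {}

    def get_losses(self, out: dict[str, Any], data: ToyGraph) -> float:
        # Assumption: the model returns predicted edge weights under the key "W".
        return self.loss_fct(out["W"], data.y)

    def training_step(self, batch: ToyGraph, batch_idx: int) -> float:
        loss = self.get_losses(self.model(batch), batch)
        self.logged["loss/train"] = loss
        return loss  # in Lightning, the returned loss is what the optimizer minimizes

    def validation_step(self, batch: ToyGraph, batch_idx: int) -> None:
        # Validation only records the loss; nothing is returned.
        self.logged["loss/val"] = self.get_losses(self.model(batch), batch)

    def highlight_metric(self, metric: str) -> bool:
        # Hypothetical rule: flag only the validation loss as a headline metric.
        return metric == "loss/val"


# Usage with a fixed "model" that always predicts the same three edge weights.
graph = ToyGraph(y=[1.0, 0.0, 1.0])
module = ToyECModule(loss_fct=binary_cross_entropy, model=lambda g: {"W": [0.9, 0.2, 0.7]})
train_loss = module.training_step(graph, batch_idx=0)
module.validation_step(graph, batch_idx=0)
```

Injecting the loss rather than hard-coding it lets one module train with different edge losses (for example a class-weighted variant) without subclassing; in the real class, loss_fct is typed as torch.nn.Module.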