gnn_tracking.training.ec

Lightning module for edge classifier training.

Module Contents

Classes

ECModule

Lightning module for edge classifier training.

class gnn_tracking.training.ec.ECModule(*, loss_fct: torch.nn.Module, **kwargs)

Bases: gnn_tracking.training.base.TrackingModule

Lightning module for edge classifier training.

get_losses(out: dict[str, Any], data: torch_geometric.data.Data) -> torch.Tensor
training_step(batch: torch_geometric.data.Data, batch_idx: int) -> torch.Tensor | None
validation_step(batch: torch_geometric.data.Data, batch_idx: int)
highlight_metric(metric: str) -> bool
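
The signatures above suggest a common pattern: the loss function is passed in as a keyword-only argument, get_losses applies it to the model output and the batch, and training_step returns the resulting loss. The following is a minimal sketch of that pattern only, not the actual ECModule implementation. It uses plain Python lists instead of torch tensors so that it runs without dependencies. The class ToyEdgeClassifierModule, the function bce_loss, and the key names "W", "y", and "scores" are all hypothetical and are not taken from gnn_tracking.

```python
import math
from typing import Any, Callable


def bce_loss(scores: list[float], labels: list[int]) -> float:
    """Mean binary cross-entropy over per-edge scores in (0, 1)."""
    eps = 1e-12  # guards against log(0)
    total = 0.0
    for p, y in zip(scores, labels):
        total -= y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps)
    return total / len(scores)


class ToyEdgeClassifierModule:
    """Hypothetical stand-in that mirrors the ECModule method names."""

    def __init__(self, *, loss_fct: Callable[[list[float], list[int]], float]):
        # Keyword-only, like the documented constructor signature.
        self.loss_fct = loss_fct

    def get_losses(self, out: dict[str, Any], data: dict[str, Any]) -> float:
        # Assumed key names: "W" holds edge scores, "y" holds edge labels.
        return self.loss_fct(out["W"], data["y"])

    def training_step(self, batch: dict[str, Any], batch_idx: int) -> float:
        # A real module would run its model on the batch here; in this
        # sketch the edge scores are simply read from the batch.
        out = {"W": batch["scores"]}
        return self.get_losses(out, batch)


module = ToyEdgeClassifierModule(loss_fct=bce_loss)
batch = {"scores": [0.9, 0.2, 0.7], "y": [1, 0, 1]}
loss = module.training_step(batch, batch_idx=0)
print(f"training loss: {loss:.4f}")
```

Passing the loss as a constructor argument keeps the training logic independent of the particular loss, so a different edge classification loss can be swapped in without changing the module.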