gnn_tracking.training.tc

Lightning module for object condensation training.

Module Contents

Classes

TCModule

Object condensation for tracks. This Lightning module implements losses, training, and validation steps.

class gnn_tracking.training.tc.TCModule(*, loss_fct: gnn_tracking.metrics.losses.MultiLossFct, cluster_scanner: gnn_tracking.postprocessing.clusterscanner.ClusterScanner | None = None, **kwargs)

Bases: gnn_tracking.training.base.TrackingModule

Object condensation for tracks. This Lightning module implements losses, training, and validation steps.

Parameters:
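  • loss_fct – Loss function (a gnn_tracking.metrics.losses.MultiLossFct). Required keyword argument.

  • cluster_scanner – Optional gnn_tracking.postprocessing.clusterscanner.ClusterScanner. Defaults to None.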

  • **kwargs – Passed on to TrackingModule

is_last_val_batch(batch_idx: int) → bool

Are we validating the last batch of the validation set?
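A minimal sketch of the check that is_last_val_batch describes, assuming the total number of validation batches is known and that batch_idx is 0-based. The function name and the n_val_batches parameter are hypothetical; how TCModule actually obtains the batch count is not documented here.

```python
def is_last_val_batch(batch_idx: int, n_val_batches: int) -> bool:
    """Return True if the 0-based ``batch_idx`` is the final validation batch.

    ``n_val_batches`` is a hypothetical stand-in for however the module
    learns the size of the validation set.
    """
    if n_val_batches <= 0:
        raise ValueError("n_val_batches must be positive")
    return batch_idx == n_val_batches - 1


# With 4 validation batches, only index 3 is the last one.
print([is_last_val_batch(i, 4) for i in range(4)])  # [False, False, False, True]
```

A per-batch check like this lets expensive work be deferred until the whole validation set has been seen.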

get_losses(out: dict[str, Any], data: torch_geometric.data.Data) → tuple[torch.Tensor, dict[str, float]]
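get_losses returns a total loss together with a dictionary of float values. The sketch below illustrates only that return shape, assuming the total is a weighted sum of named loss components. The helper combine_losses, the component names, and the weighting scheme are hypothetical and are not taken from MultiLossFct. Plain floats stand in for torch.Tensor so the example runs without PyTorch.

```python
from collections.abc import Mapping


def combine_losses(
    components: Mapping[str, float], weights: Mapping[str, float]
) -> tuple[float, dict[str, float]]:
    """Return (total, loggable) for named loss components.

    Hypothetical helper: ``components`` and ``weights`` stand in for whatever
    the real loss function produces. In the real module the total would be a
    ``torch.Tensor`` so that it can be back-propagated.
    """
    # Weighted sum of all components; a missing weight raises KeyError.
    total = sum(weights[name] * value for name, value in components.items())
    # Plain floats are convenient for logging; include the total as well.
    loggable = {name: float(value) for name, value in components.items()}
    loggable["total"] = float(total)
    return total, loggable


total, logged = combine_losses(
    {"attractive": 0.5, "repulsive": 0.25}, {"attractive": 1.0, "repulsive": 2.0}
)
print(total)   # 1.0
print(logged)  # {'attractive': 0.5, 'repulsive': 0.25, 'total': 1.0}
```

Keeping the optimizable total separate from the float dictionary means logging never holds on to the autograd graph.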
training_step(data: torch_geometric.data.Data, batch_idx: int) → torch.Tensor
validation_step(data: torch_geometric.data.Data, batch_idx: int) → None
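In PyTorch Lightning, the value returned from training_step is the loss that gets optimized, while validation_step may return None and only log metrics. This matches the return types above. The stub below illustrates that division of labor without depending on Lightning. The class StepSketch, its toy loss, the record helper, and the trn_/val_ prefixes are all hypothetical.

```python
class StepSketch:
    """Hypothetical stand-in for a Lightning module's step hooks."""

    def __init__(self) -> None:
        self.logged: dict[str, float] = {}

    def get_losses(self, batch: list[float]) -> tuple[float, dict[str, float]]:
        # Toy loss: mean of the batch values (placeholder for the real loss).
        loss = sum(batch) / len(batch)
        return loss, {"total": loss}

    def record(self, metrics: dict[str, float], prefix: str) -> None:
        # Stand-in for logging; stores each value under a prefixed name.
        for name, value in metrics.items():
            self.logged[f"{prefix}{name}"] = value

    def training_step(self, batch: list[float], batch_idx: int) -> float:
        loss, loggable = self.get_losses(batch)
        self.record(loggable, prefix="trn_")
        return loss  # returned so the optimizer can step on it

    def validation_step(self, batch: list[float], batch_idx: int) -> None:
        _, loggable = self.get_losses(batch)
        self.record(loggable, prefix="val_")
        # Nothing is returned: validation only records metrics.


module = StepSketch()
print(module.training_step([1.0, 3.0], batch_idx=0))    # 2.0
print(module.validation_step([2.0, 4.0], batch_idx=0))  # None
print(module.logged)  # {'trn_total': 2.0, 'val_total': 3.0}
```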
_evaluate_cluster_metrics(out: dict[str, Any], data: torch_geometric.data.Data, batch_idx: int) → dict[str, float]

Evaluate cluster metrics.

highlight_metric(metric: str) → bool