gnn_tracking.training.tc

Lightning module for object condensation training.

Module Contents

Classes

TCModule

Object condensation for tracks. This Lightning module implements losses, training, and validation steps.

class gnn_tracking.training.tc.TCModule(*, potential_loss: gnn_tracking.metrics.losses.PotentialLoss = PotentialLoss(), background_loss: gnn_tracking.metrics.losses.BackgroundLoss | None = BackgroundLoss(), cluster_scanner: gnn_tracking.postprocessing.clusterscanner.ClusterScanner | None = None, lw_repulsive: float = 1.0, lw_background: float = 1.0, **kwargs)

Bases: gnn_tracking.training.base.TrackingModule

Object condensation for tracks. This Lightning module implements losses, training, and validation steps.

Parameters:
  • potential_loss

  • background_loss

  • cluster_scanner

  • lw_repulsive – Loss weight for repulsive part of potential loss

  • lw_background – Loss weight for background loss

  • **kwargs – Passed on to TrackingModule
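
The two loss weights can be read as coefficients in a weighted sum of loss terms. The following is a minimal sketch under stated assumptions, not the module's actual implementation: the function name `combine_losses`, the loss keys `"attractive"`, `"repulsive"`, and `"background"`, and the fixed weight of 1.0 on the attractive term are all hypothetical and chosen only for illustration.

```python
def combine_losses(
    losses: dict[str, float],
    lw_repulsive: float = 1.0,
    lw_background: float = 1.0,
) -> float:
    """Weighted sum of loss terms (illustrative only).

    Assumption: the attractive term has an implicit weight of 1.0, so
    lw_repulsive and lw_background are relative to it.
    """
    weights = {
        "attractive": 1.0,
        "repulsive": lw_repulsive,
        "background": lw_background,
    }
    # Only terms present in `losses` contribute, which mirrors the case
    # background_loss=None, where no background term would be computed.
    return sum(weights[name] * value for name, value in losses.items())


example_total = combine_losses(
    {"attractive": 1.0, "repulsive": 2.0, "background": 4.0},
    lw_repulsive=0.5,
    lw_background=0.25,
)
```

In this sketch, halving `lw_repulsive` halves the contribution of the repulsive term to the total while leaving the other terms unchanged.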

_validate_settings()

Check that settings make sense and warn/raise exceptions otherwise.

is_last_val_batch(batch_idx: int) → bool

Are we validating the last batch of the validation set?
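
A check like this is typically a comparison of the batch index against the number of validation batches. The sketch below is an assumption about the general idea, not the method's actual body: the standalone function `is_last_batch` and its explicit `n_batches` argument are hypothetical, and how the real method obtains the batch count is not stated here.

```python
def is_last_batch(batch_idx: int, n_batches: int) -> bool:
    """Return True if batch_idx refers to the final batch (0-based index)."""
    return batch_idx == n_batches - 1


# Example: find which of 4 validation batches is flagged as the last one.
flags = [is_last_batch(i, n_batches=4) for i in range(4)]
```

One plausible use of such a flag is to run costly evaluation only once per validation epoch rather than on every batch.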

get_losses(out: dict[str, Any], data: torch_geometric.data.Data) → tuple[torch.Tensor, dict[str, float]]

training_step(data: torch_geometric.data.Data, batch_idx: int) → torch.Tensor

validation_step(data: torch_geometric.data.Data, batch_idx: int) → None

_evaluate_cluster_metrics(out: dict[str, Any], data: torch_geometric.data.Data, batch_idx: int) → dict[str, float]

Evaluate cluster metrics.

highlight_metric(metric: str) → bool