Lightning module for object condensation training.

Module Contents




class (*, potential_loss: gnn_tracking.metrics.losses.PotentialLoss = PotentialLoss(), background_loss: gnn_tracking.metrics.losses.BackgroundLoss | None = BackgroundLoss(), cluster_scanner: gnn_tracking.postprocessing.clusterscanner.ClusterScanner | None = None, lw_repulsive: float = 1.0, lw_background: float = 1.0, **kwargs)


Object condensation for tracks. This lightning module implements losses, training, and validation steps.

Parameters:

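  • potential_loss – Potential loss (attractive and repulsive parts)

  • background_loss – Background loss (may be None)

  • cluster_scanner – Optional cluster scanner (default None)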

  • lw_repulsive – Loss weight for repulsive part of potential loss

  • lw_background – Loss weight for background loss

  • **kwargs – Passed on to TrackingModule


Check that the settings make sense; warn or raise an exception otherwise.

is_last_val_batch(batch_idx: int) → bool

Are we validating the last batch of the validation set?
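Because `batch_idx` is zero-based, the last validation batch has index `n - 1`, where `n` is the number of validation batches. The sketch below takes `n` as an explicit argument. Inside a LightningModule, `n` could plausibly come from `self.trainer.num_val_batches`, which holds one count per validation dataloader. That is an assumption about how this module obtains the count. A check like this is useful for doing expensive per-epoch work only once, after all validation batches have been seen.

```python
def is_last_val_batch(batch_idx: int, num_val_batches: int) -> bool:
    """Return True if batch_idx is the final validation batch (sketch).

    num_val_batches is passed in explicitly here; the documented method
    takes only batch_idx and must obtain the batch count elsewhere.
    """
    # batch_idx counts from 0, so the last of n batches has index n - 1.
    return batch_idx == num_val_batches - 1
```

With three batches, `[is_last_val_batch(i, 3) for i in range(3)]` evaluates to `[False, False, True]`.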

get_losses(out: dict[str, Any], data) → tuple[torch.Tensor, dict[str, float]]
training_step(data, batch_idx: int) → torch.Tensor
validation_step(data, batch_idx: int) → None
_evaluate_cluster_metrics(out: dict[str, Any], data, batch_idx: int) → dict[str, float]

Evaluate cluster metrics.

highlight_metric(metric: str) → bool