gnn_tracking.training.tc
Lightning module for object condensation training.
Module Contents
Classes
TCModule: Object condensation for tracks.
- class gnn_tracking.training.tc.TCModule(*, potential_loss: gnn_tracking.metrics.losses.PotentialLoss = PotentialLoss(), background_loss: gnn_tracking.metrics.losses.BackgroundLoss | None = BackgroundLoss(), cluster_scanner: gnn_tracking.postprocessing.clusterscanner.ClusterScanner | None = None, lw_repulsive: float = 1.0, lw_background: float = 1.0, **kwargs)
Bases: gnn_tracking.training.base.TrackingModule
Object condensation for tracks. This Lightning module implements losses, training, and validation steps.
- Parameters:
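potential_loss – Object condensation potential loss (attractive and repulsive parts)
background_loss – Background loss (optional, may be None)
cluster_scanner – Optional cluster scanner used to evaluate cluster metrics during validation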
lw_repulsive – Loss weight for repulsive part of potential loss
lw_background – Loss weight for background loss
**kwargs – Passed on to TrackingModule
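The two loss weights scale parts of the total training loss. A minimal sketch, assuming the attractive part of the potential loss has unit weight and the total is attractive + lw_repulsive * repulsive + lw_background * background, with the background term dropped when no background loss is configured. The helper `combine_losses` is hypothetical and works on plain floats instead of tensors; it only mirrors the `(total, dict of floats)` shape of the `get_losses` return type.

```python
from __future__ import annotations


def combine_losses(
    attractive: float,
    repulsive: float,
    background: float | None,
    lw_repulsive: float = 1.0,
    lw_background: float = 1.0,
) -> tuple[float, dict[str, float]]:
    """Hypothetical helper: weighted sum of object condensation loss terms.

    Assumes the attractive term has unit weight and that the background
    term is skipped entirely when no background loss is configured.
    """
    parts = {"attractive": attractive, "repulsive": repulsive}
    total = attractive + lw_repulsive * repulsive
    if background is not None:
        parts["background"] = background
        total += lw_background * background
    parts["total"] = total
    return total, parts


# Doubling the repulsive weight and halving the background weight
total, parts = combine_losses(0.5, 0.2, 0.1, lw_repulsive=2.0, lw_background=0.5)
print(f"{total:.2f}")  # prints 0.95
```

Returning the individual terms alongside the total makes it easy to log each component separately while only the total is backpropagated.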
- _validate_settings()
Check that settings make sense and warn/raise exceptions otherwise.
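The specific checks performed by `_validate_settings` are not documented here. As an illustration of the warn-versus-raise pattern only, the hypothetical `validate_settings` below raises for settings that cannot work (negative loss weights) and warns for settings that are merely suspicious (a background weight without a background loss). Both rules are assumptions, not the module's actual checks.

```python
from __future__ import annotations

import warnings


def validate_settings(
    background_loss: object | None, lw_repulsive: float, lw_background: float
) -> None:
    """Hypothetical checks in the spirit of ``_validate_settings``.

    Raises for settings that cannot work and warns for settings that are
    merely suspicious.
    """
    if lw_repulsive < 0 or lw_background < 0:
        raise ValueError("Loss weights must be non-negative")
    if background_loss is None and lw_background != 0:
        warnings.warn(
            "lw_background is set but no background loss is configured; "
            "the weight has no effect",
            stacklevel=2,
        )
```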
- is_last_val_batch(batch_idx: int) -> bool
Are we validating the last batch of the validation set?
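A minimal sketch of this check, assuming 0-based batch indices and a known total number of validation batches. Inside a Lightning module that count would come from the trainer (for example `trainer.num_val_batches`, which holds one entry per validation dataloader); here it is passed in explicitly so the function runs on its own.

```python
def is_last_val_batch(batch_idx: int, num_val_batches: int) -> bool:
    """Return True if the 0-based ``batch_idx`` is the final validation batch."""
    return batch_idx == num_val_batches - 1


# With 5 validation batches, only index 4 is the last one
print([is_last_val_batch(i, 5) for i in range(5)])
# prints [False, False, False, False, True]
```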
- get_losses(out: dict[str, Any], data: torch_geometric.data.Data) -> tuple[torch.Tensor, dict[str, float]]
- training_step(data: torch_geometric.data.Data, batch_idx: int) -> torch.Tensor
- validation_step(data: torch_geometric.data.Data, batch_idx: int) -> None
- _evaluate_cluster_metrics(out: dict[str, Any], data: torch_geometric.data.Data, batch_idx: int) -> dict[str, float]
Evaluate cluster metrics.
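Cluster metrics such as efficiencies are usually only meaningful once every validation batch has been seen, which is one plausible reason for pairing this method with `is_last_val_batch`. A minimal sketch of that accumulate-then-report pattern, assuming the scanner is fed once per batch and summarized after the final batch. `ToyClusterScanner`, its `summarize` method, and the single scalar `score` per batch are hypothetical simplifications, not the `ClusterScanner` API.

```python
from __future__ import annotations

from statistics import mean


class ToyClusterScanner:
    """Hypothetical stand-in for a cluster scanner that accumulates batches."""

    def __init__(self) -> None:
        self._scores: list[float] = []

    def __call__(self, score: float) -> None:
        # Record one per-batch result; a real scanner would look at clusters
        self._scores.append(score)

    def summarize(self) -> dict[str, float]:
        # Aggregate over every batch seen in this validation epoch
        return {"mean_score": mean(self._scores)}


def evaluate_cluster_metrics(
    scanner: ToyClusterScanner | None,
    score: float,
    batch_idx: int,
    num_val_batches: int,
) -> dict[str, float]:
    if scanner is None:
        # No scanner configured: nothing to report
        return {}
    scanner(score)
    if batch_idx == num_val_batches - 1:
        # Only the last batch returns the aggregated metrics
        return scanner.summarize()
    return {}


scanner = ToyClusterScanner()
for i, s in enumerate([0.2, 0.4, 0.6]):
    print(evaluate_cluster_metrics(scanner, s, i, num_val_batches=3))
```

Returning an empty dictionary for intermediate batches lets the caller log the result unconditionally.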
- highlight_metric(metric: str) -> bool
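`highlight_metric` has no docstring. One plausible reading of its signature is a predicate that marks a few key metrics for emphasis, for example when printing validation results. A sketch under that assumption; the metric names in `HIGHLIGHTED_METRICS` are hypothetical.

```python
# Hypothetical metric names; the real set used by TCModule is not documented here
HIGHLIGHTED_METRICS = frozenset({"total", "mean_score"})


def highlight_metric(metric: str) -> bool:
    """Return True if ``metric`` should be emphasized in logs (assumption)."""
    return metric in HIGHLIGHTED_METRICS


print([m for m in ["attractive", "total", "repulsive"] if highlight_metric(m)])
# prints ['total']
```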