gnn_tracking.training.tc
Lightning module for object condensation training.
Classes
TCModule – Object condensation for tracks. This Lightning module implements losses, training, and validation steps.
Module Contents
- class gnn_tracking.training.tc.TCModule(*, loss_fct: gnn_tracking.metrics.losses.MultiLossFct, cluster_scanner: gnn_tracking.postprocessing.clusterscanner.ClusterScanner | None = None, **kwargs)
Bases: gnn_tracking.training.base.TrackingModule
Object condensation for tracks. This Lightning module implements losses, training, and validation steps.
- Parameters:
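loss_fct – Loss function (a gnn_tracking.metrics.losses.MultiLossFct).
cluster_scanner – Optional cluster scanner (a gnn_tracking.postprocessing.clusterscanner.ClusterScanner) used to evaluate cluster metrics; defaults to None.
**kwargs – Passed on to TrackingModule.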
- loss_fct
- cluster_scanner
- _cluster_scan_input
- _best_cluster_params
- is_last_val_batch(batch_idx: int) → bool
Are we validating the last batch of the validation set?
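One way to implement this check is to compare the 0-based batch index with the total number of validation batches. The sketch below is a hypothetical free function, not the module's own code. In PyTorch Lightning the batch count can be obtained from the trainer (for example Trainer.num_val_batches, which holds one count per validation dataloader); here it is simply passed in as the assumed parameter num_val_batches:

```python
def is_last_val_batch(batch_idx: int, num_val_batches: int) -> bool:
    """Return True if the 0-based batch_idx is the final validation batch.

    Hypothetical stand-in: num_val_batches is passed explicitly here, whereas
    a LightningModule would typically obtain it from its trainer.
    """
    if num_val_batches <= 0:
        raise ValueError("num_val_batches must be positive")
    return batch_idx == num_val_batches - 1


# With 4 validation batches, only index 3 is the last one.
print([is_last_val_batch(i, 4) for i in range(4)])  # [False, False, False, True]
```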
- get_losses(out: dict[str, Any], data: torch_geometric.data.Data) → tuple[torch.Tensor, dict[str, float]]
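The return type of get_losses indicates a pair: a loss tensor, presumably the quantity that training_step returns for backpropagation, and a dict[str, float] of scalar values, presumably for logging. As an illustration of that contract only, here is a pure-Python sketch that combines named loss components with weights. The helper combine_losses, the component names, and the weights are hypothetical, and plain floats stand in for tensors:

```python
from collections.abc import Mapping


def combine_losses(
    components: Mapping[str, float], weights: Mapping[str, float]
) -> tuple[float, dict[str, float]]:
    """Weighted sum of named loss components plus a flat logging dict.

    Hypothetical helper: plain floats stand in for torch.Tensor values and
    the component names are illustrative only.
    """
    total = 0.0
    log: dict[str, float] = {}
    for name, value in components.items():
        weight = weights.get(name, 1.0)  # components without a weight count once
        total += weight * value
        log[name] = float(value)  # raw value, for monitoring
        log[f"{name}_weighted"] = float(weight * value)
    log["total"] = total
    return total, log


total, log = combine_losses(
    {"attractive": 0.5, "repulsive": 0.25}, {"attractive": 1.0, "repulsive": 2.0}
)
print(total)  # 1.0
```

Keeping the unweighted values in the logging dict alongside the weighted ones makes it possible to monitor each component even when its weight changes between runs.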
- training_step(data: torch_geometric.data.Data, batch_idx: int) → torch.Tensor
- validation_step(data: torch_geometric.data.Data, batch_idx: int) → None
- _evaluate_cluster_metrics(out: dict[str, Any], data: torch_geometric.data.Data, batch_idx: int) → dict[str, float]
Evaluate cluster metrics.
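The member names (_cluster_scan_input, _best_cluster_params, is_last_val_batch) suggest that validation collects cluster-scan input batch by batch and evaluates clustering once the last batch has been seen. The following is a hypothetical pure-Python sketch of that accumulate-then-scan pattern. ClusterScanAccumulator, its score callback, and the parameter grid are illustrative stand-ins, not the ClusterScanner API:

```python
from __future__ import annotations

from typing import Callable, Optional


class ClusterScanAccumulator:
    """Hypothetical sketch: buffer per-batch inputs, scan parameters once at the end.

    This is NOT the gnn_tracking ClusterScanner API; the score callback and the
    parameter grid are illustrative stand-ins.
    """

    def __init__(
        self, score: Callable[[list[float], float], float], grid: list[float]
    ) -> None:
        self.score = score  # higher is better
        self.grid = grid  # candidate values of one clustering parameter
        self._inputs: list[float] = []  # accumulated over one validation epoch
        self.best_param: Optional[float] = None

    def add_batch(self, values: list[float], *, is_last: bool) -> dict[str, float]:
        """Store one batch; on the last batch, scan the grid and return metrics."""
        self._inputs.extend(values)
        if not is_last:
            return {}  # nothing to report until the full validation set is seen
        scores = {p: self.score(self._inputs, p) for p in self.grid}
        self.best_param = max(scores, key=scores.__getitem__)
        self._inputs = []  # reset the buffer for the next validation epoch
        return {"best_param": self.best_param, "best_score": scores[self.best_param]}


scan = ClusterScanAccumulator(
    score=lambda xs, p: 1 - abs(sum(xs) / len(xs) - p),  # prefer p near the mean
    grid=[0.1, 0.5, 0.9],
)
scan.add_batch([0.4, 0.6], is_last=False)
print(scan.add_batch([0.5], is_last=True))  # {'best_param': 0.5, 'best_score': 1.0}
```

Scanning once per epoch, rather than per batch, bases the parameter choice on the whole validation set and avoids repeating a potentially expensive scan for every batch.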
- highlight_metric(metric: str) → bool