gnn_tracking.metrics.losses.oc
Classes
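| Class | Description |
|---|---|
| CondensationLossRG | Implementation of condensation loss that uses a radius graph instead of calculating the whole n^2 distance matrix. |
| CondensationLossTiger | Implementation of condensation loss that directly calculates the n^2 distance matrix. |
| ObjectLoss | Loss functions for predicted object properties. |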
Functions
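| Function | Description |
|---|---|
| _first_occurrences | Return the first occurrence of each unique element in a 1D array |
| _square_distances | Returns squared distances between two sets of points |
| _get_alphas_first_occurences | |
| _get_vr_rg | |
| _get_va | |
| _radius_graph_condensation_loss | Extracted function for condensation loss. See PotentialLoss for details. |
| condensation_loss_tiger | Extracted function for torch compilation. See CondensationLossTiger for documentation. |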
Module Contents
- gnn_tracking.metrics.losses.oc._first_occurrences(x: torch.Tensor) → torch.Tensor
Return the first occurrence of each unique element in a 1D array
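The docstring does not say whether values or indices are returned. The sketch below assumes the helper returns, for each unique value, the index where that value first appears. One vectorized way to compute this in PyTorch is `torch.unique(..., return_inverse=True)` followed by a scatter-min over positions. `first_occurrences` is a hypothetical stand-in, not the module's code:

```python
import torch


def first_occurrences(x: torch.Tensor) -> torch.Tensor:
    """Index of the first occurrence of each unique value in a 1D tensor.

    Hypothetical re-implementation; returning indices (ordered by sorted
    unique value) is an assumption about the real helper's convention.
    """
    # inverse[i] is the position of x[i] among the sorted unique values
    unique, inverse = torch.unique(x, return_inverse=True)
    # Initialize every slot with len(x), which is larger than any valid index
    first = torch.full((unique.numel(),), x.numel(), dtype=torch.long)
    positions = torch.arange(x.numel())
    # For each unique value, keep the smallest position that maps to it
    return first.scatter_reduce(0, inverse, positions, reduce="amin")
```

Indexing `x` with the result recovers the sorted unique values. This gives one representative entry per unique ID.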
- gnn_tracking.metrics.losses.oc._square_distances(edges: torch.Tensor, positions: torch.Tensor) → torch.Tensor
Returns squared distances between two sets of points
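The signature takes `edges` and `positions` rather than two separate point sets. This suggests the distances are taken between the two endpoints of each edge. The sketch below works under that assumption, with `edges` as a (2, E) index tensor in the PyTorch Geometric `edge_index` convention. `square_distances` is a hypothetical name:

```python
import torch


def square_distances(edges: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """Squared Euclidean distance between the endpoints of every edge.

    Assumes edges has shape (2, E) and indexes rows of positions (N, D);
    this layout is an assumption, not taken from the source.
    """
    diff = positions[edges[0]] - positions[edges[1]]  # (E, D)
    return (diff**2).sum(dim=-1)  # (E,)
```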
- gnn_tracking.metrics.losses.oc._get_alphas_first_occurences(beta: torch.Tensor, particle_id: torch.Tensor, mask: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]
- gnn_tracking.metrics.losses.oc._get_vr_rg(*, radius_edges: torch.Tensor, is_cp_j: torch.Tensor, particle_id: torch.Tensor, x: torch.Tensor, q_j: torch.Tensor, radius_threshold: float)
- gnn_tracking.metrics.losses.oc._get_va(*, alphas_k: torch.Tensor, is_cp_j: torch.Tensor, particle_id: torch.Tensor, x: torch.Tensor, q_j: torch.Tensor, mask: torch.Tensor) → torch.Tensor
- gnn_tracking.metrics.losses.oc._radius_graph_condensation_loss(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, q_min: float, mask: torch.Tensor, radius_threshold: float, max_num_neighbors: int) → tuple[dict[str, torch.Tensor], dict[str, Any]]
Extracted function for condensation loss. See PotentialLoss for details.
- Parameters:
mask – Mask for objects cast to nodes
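For orientation, the object condensation (OC) paper (Kieseler, 2020) defines several quantities that `q_min` enters:

- the per-hit charge;
- the attractive and repulsive potentials around each object's condensation point $\alpha_k$ (the hit of object $k$ with the largest $\beta$);
- the potential loss, where $M_{jk} = 1$ if hit $j$ belongs to object $k$ and $0$ otherwise.

The normalization used in this module, and the truncation of the repulsive sum to radius-graph edges, may differ. Treat this as the reference formulation, not the exact expression computed here.

```latex
q_i = \operatorname{arctanh}^2(\beta_i) + q_{\min}

\breve{V}_k(x) = \lVert x - x_{\alpha_k} \rVert^2 \, q_{\alpha_k}, \qquad
\hat{V}_k(x) = \max\bigl(0,\; 1 - \lVert x - x_{\alpha_k} \rVert\bigr) \, q_{\alpha_k}

L_V = \frac{1}{N} \sum_{j=1}^{N} q_j \sum_{k=1}^{K}
      \Bigl[ M_{jk} \, \breve{V}_k(x_j) + (1 - M_{jk}) \, \hat{V}_k(x_j) \Bigr]
```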
- class gnn_tracking.metrics.losses.oc.CondensationLossRG(*, lw_repulsive: float = 1.0, lw_noise: float = 0.0, lw_coward: float = 0.0, q_min: float = 0.01, pt_thld: float = 0.9, max_eta: float = 4.0, max_num_neighbors: int = 256, sample_pids: float = 1.0)
Bases: gnn_tracking.metrics.losses.MultiLossFct, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
Implementation of condensation loss that uses a radius graph instead of calculating the whole n^2 distance matrix.
- Parameters:
lw_repulsive (float, optional) – Loss weight for the repulsive part of the potential loss. Defaults to 1.0.
lw_noise (float, optional) – Loss weight for the noise loss. Defaults to 0.0.
lw_coward (float, optional) – Loss weight for the background loss. Defaults to 0.0.
q_min (float, optional) – Minimum charge of each hit; see the OC paper. Defaults to 0.01.
pt_thld (float, optional) – pT threshold for interesting particles. Defaults to 0.9.
max_eta (float, optional) – eta threshold for interesting particles. Defaults to 4.0.
max_num_neighbors (int, optional) – Maximum number of neighbors to consider for radius graphs. Defaults to 256.
sample_pids (float, optional) – Further subsample particles to conserve memory. Defaults to 1.0 (no sampling).
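`pt_thld` and `max_eta` restrict the loss to "interesting" particles. The sketch below builds such a per-hit selection under three assumptions:

- the cuts are `pt > pt_thld` and `|eta| < max_eta`;
- only reconstructable hits are kept;
- `pt`, `eta` and `reconstructable` are per-hit tensors, as in `forward`.

The library's exact comparison operators may differ, and `interesting_hit_mask` is a hypothetical helper:

```python
import torch


def interesting_hit_mask(
    *,
    pt: torch.Tensor,
    eta: torch.Tensor,
    reconstructable: torch.Tensor,
    pt_thld: float = 0.9,
    max_eta: float = 4.0,
) -> torch.Tensor:
    """Boolean per-hit mask selecting hits of interesting particles (hypothetical)."""
    # Assumed cuts: reconstructable, above the pT threshold, inside the eta range
    return reconstructable.bool() & (pt > pt_thld) & (eta.abs() < max_eta)
```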
- forward(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, reconstructable: torch.Tensor, pt: torch.Tensor, ec_hit_mask: torch.Tensor | None = None, eta: torch.Tensor, **kwargs) → gnn_tracking.metrics.losses.MultiLossFctReturn
- gnn_tracking.metrics.losses.oc.condensation_loss_tiger(*, beta: torch.Tensor, x: torch.Tensor, object_id: torch.Tensor, object_mask: torch.Tensor, q_min: float, noise_threshold: int, max_n_rep: int) → tuple[dict[str, torch.Tensor], dict[str, int | float]]
Extracted function for torch compilation. See CondensationLossTiger for documentation.
- Parameters:
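object_mask – Mask for the particles that should be considered for the loss. This mask is broadcast to n_hits.
- Returns:
Tuple (loss_dct, extra_dct), where loss_dct is a dictionary of losses and extra_dct is a dictionary of extra information.
- Return type:
tuple[dict[str, torch.Tensor], dict[str, int | float]]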
- class gnn_tracking.metrics.losses.oc.CondensationLossTiger(*, lw_repulsive: float = 1.0, lw_noise: float = 0.0, lw_coward: float = 0.0, q_min: float = 0.01, pt_thld: float = 0.9, max_eta: float = 4.0, max_n_rep: int = 0, sample_pids: float = 1.0)
Bases: gnn_tracking.metrics.losses.MultiLossFct, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
Implementation of condensation loss that directly calculates the n^2 distance matrix.
- Parameters:
lw_repulsive (float, optional) – Loss weight for the repulsive part of the potential loss. Defaults to 1.0.
lw_noise (float, optional) – Loss weight for the noise loss. Defaults to 0.0.
lw_coward (float, optional) – Loss weight for the background loss. Defaults to 0.0.
q_min (float, optional) – Minimum charge of each hit; see the OC paper. Defaults to 0.01.
pt_thld (float, optional) – pT threshold for interesting particles. Defaults to 0.9.
max_eta (float, optional) – eta threshold for interesting particles. Defaults to 4.0.
max_n_rep (int, optional) – Maximum number of repulsive edges to consider. Defaults to 0 (all).
sample_pids (float, optional) – Further subsample particles to conserve memory. Defaults to 1.0 (no sampling).
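CondensationLossTiger is described as computing the full distance matrix directly. The minimal sketch below shows what such a dense computation of the OC potential terms looks like in plain PyTorch, following the OC paper's formulation. It compares every hit against every condensation point.

This is not the library's implementation:

- `oc_potential_loss_dense` is a hypothetical name;
- `object_id == 0` is assumed to mark noise;
- the normalization is illustrative;
- `max_n_rep`, `object_mask` and the noise and background terms are ignored.

```python
import torch


def oc_potential_loss_dense(
    beta: torch.Tensor,  # (n_hits,) condensation likelihood in [0, 1)
    x: torch.Tensor,  # (n_hits, d) clustering coordinates
    object_id: torch.Tensor,  # (n_hits,) integer ids; 0 assumed to mean noise
    q_min: float = 0.01,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Attractive and repulsive OC potential terms from a dense distance matrix.

    Hypothetical sketch after the OC paper; assumes at least one non-noise object.
    """
    # Charge per hit; the clamp keeps arctanh finite for beta close to 1
    q = torch.arctanh(beta.clamp(max=1 - 1e-4)) ** 2 + q_min
    pids = torch.unique(object_id[object_id > 0])  # (n_objects,)
    # Condensation point alpha_k: the hit of object k with the highest beta
    alpha_list = []
    for pid in pids:
        hits = torch.nonzero(object_id == pid).squeeze(1)
        alpha_list.append(hits[beta[hits].argmax()])
    alphas = torch.stack(alpha_list)  # (n_objects,)
    # Dense squared distances between every hit and every condensation point
    d2 = ((x[:, None, :] - x[alphas][None, :, :]) ** 2).sum(dim=-1)
    # Clamp before sqrt so the gradient stays finite at zero distance
    dist = d2.clamp_min(1e-12).sqrt()
    same = (object_id[:, None] == pids[None, :]).float()  # M_jk
    q_alpha = q[alphas][None, :]
    v_att = same * d2 * q_alpha  # pull hits towards their own condensation point
    v_rep = (1 - same) * torch.relu(1 - dist) * q_alpha  # hinge repulsion, radius 1
    n_hits = beta.numel()
    attractive = (q[:, None] * v_att).sum() / n_hits
    repulsive = (q[:, None] * v_rep).sum() / n_hits
    return attractive, repulsive
```

The (n_hits, n_objects) matrices make memory grow with the product of hits and objects. That cost is what the radius-graph variant, CondensationLossRG, avoids by only evaluating nearby pairs.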
- forward(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, reconstructable: torch.Tensor, pt: torch.Tensor, ec_hit_mask: torch.Tensor | None = None, eta: torch.Tensor, **kwargs) → gnn_tracking.metrics.losses.MultiLossFctReturn
- class gnn_tracking.metrics.losses.oc.ObjectLoss(mode='efficiency')
Bases: torch.nn.Module
Loss functions for predicted object properties.
- mode
- static _mse(*, pred: torch.Tensor, truth: torch.Tensor) → torch.Tensor
- object_loss(*, pred: torch.Tensor, beta: torch.Tensor, truth: torch.Tensor, particle_id: torch.Tensor) → torch.Tensor
- forward(*, beta: torch.Tensor, pred: torch.Tensor, particle_id: torch.Tensor, track_params: torch.Tensor, reconstructable: torch.Tensor, **kwargs) → torch.Tensor
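`object_loss` takes `beta` and `particle_id` alongside predictions and truth. This fits the OC paper's property loss, in which each hit's error is weighted by $\xi_i \operatorname{arctanh}^2(\beta_i)$, with $\xi_i = 0$ for noise. The sketch below assumes that weighting and a per-hit squared error summed over properties. It does not claim to reproduce what `mode='efficiency'` computes. `mse` and `beta_weighted_object_loss` are hypothetical names, and `particle_id == 0` is assumed to mark noise:

```python
import torch


def mse(*, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
    """Per-hit squared error, summed over the (assumed last) property dimension."""
    return ((pred - truth) ** 2).sum(dim=-1)


def beta_weighted_object_loss(
    *,
    pred: torch.Tensor,  # (n_hits, n_props) predicted properties
    beta: torch.Tensor,  # (n_hits,)
    truth: torch.Tensor,  # (n_hits, n_props) properties of each hit's particle
    particle_id: torch.Tensor,  # (n_hits,); 0 assumed to mark noise
    eps: float = 1e-4,
) -> torch.Tensor:
    # xi_i = 1 for hits of real particles, 0 for noise hits
    xi = (particle_id > 0).float()
    # OC-paper weight xi_i * arctanh^2(beta_i); the clamp keeps it finite near 1
    weight = xi * torch.arctanh(beta.clamp(max=1 - eps)) ** 2
    # Weighted average of the per-hit errors
    return (weight * mse(pred=pred, truth=truth)).sum() / weight.sum()
```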