gnn_tracking.metrics.losses.oc

Classes

CondensationLossRG

Implementation of condensation loss that uses a radius graph instead of calculating the whole n^2 distance matrix.

CondensationLossTiger

Implementation of condensation loss that directly calculates the n^2 distance matrix.

ObjectLoss

Loss functions for predicted object properties.

Functions

_first_occurrences(→ torch.Tensor)

Return the first occurrence of each unique element in a 1D array

_square_distances(→ torch.Tensor)

Returns squared distances between two sets of points

_get_alphas_first_occurences(→ tuple[torch.Tensor, ...)

_get_vr_rg(*, radius_edges, is_cp_j, particle_id, x, ...)

_get_va(→ torch.Tensor)

_radius_graph_condensation_loss(→ tuple[dict[str, ...)

Extracted function for condensation loss. See PotentialLoss for details.

condensation_loss_tiger(→ tuple[dict[str, ...)

Extracted function for torch compilation. See condensation_loss_tiger for docstring.

Module Contents

gnn_tracking.metrics.losses.oc._first_occurrences(x: torch.Tensor) → torch.Tensor

Return the first occurrence of each unique element in a 1D array
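A minimal PyTorch sketch of one way to do this, assuming the result is the indices of the first occurrences (the docstring does not say whether indices or values are returned). first_occurrences is a hypothetical stand-in, not the library function:

```python
import torch


def first_occurrences(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: index of the first occurrence of each unique value in 1D x."""
    # A stable sort keeps equal values in their original relative order,
    # so the first element of each run of equal values is its first occurrence.
    sorted_x, perm = torch.sort(x, stable=True)
    # True where a new run of equal values starts in the sorted tensor
    is_new = torch.ones_like(sorted_x, dtype=torch.bool)
    is_new[1:] = sorted_x[1:] != sorted_x[:-1]
    # perm maps sorted positions back to positions in the original tensor
    return perm[is_new]


print(first_occurrences(torch.tensor([3, 1, 3, 2, 1])).tolist())  # [1, 3, 0]
```

The indices come back ordered by unique value, not by position; the stable sort is what guarantees that the earliest index wins within each run of equal values.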

gnn_tracking.metrics.losses.oc._square_distances(edges: torch.Tensor, positions: torch.Tensor) → torch.Tensor

Returns squared distances between two sets of points
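The signature takes edges rather than two point sets, which suggests that the two sets are the endpoints of each edge. A minimal sketch under that assumption, with edges as a (2, E) index tensor in the PyTorch Geometric convention (row 0 source, row 1 target); square_distances is hypothetical:

```python
import torch


def square_distances(edges: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: squared Euclidean length of every edge."""
    # edges: (2, E) long tensor; positions: (N, D) float tensor
    diff = positions[edges[0]] - positions[edges[1]]
    # Sum over the coordinate dimension -> one value per edge, shape (E,)
    return (diff**2).sum(dim=-1)


pos = torch.tensor([[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]])
edges = torch.tensor([[0, 0], [1, 2]])  # edges 0->1 and 0->2
print(square_distances(edges, pos).tolist())  # [25.0, 2.0]
```

Returning squared rather than plain distances avoids a square root, which is cheaper and keeps gradients finite when two points coincide.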

gnn_tracking.metrics.losses.oc._get_alphas_first_occurences(beta: torch.Tensor, particle_id: torch.Tensor, mask: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]

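In object condensation, the condensation point α of an object is the hit with the highest β among that object's hits. The function name suggests finding it by sorting hits by descending β and taking the first occurrence of each particle ID. A minimal sketch under that assumption; alphas_by_first_occurrence and its return convention (condensation-point indices plus their particle IDs) are hypothetical:

```python
import torch


def alphas_by_first_occurrence(
    beta: torch.Tensor, particle_id: torch.Tensor, mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: index of the highest-beta hit for each particle in mask."""
    sel = torch.nonzero(mask, as_tuple=True)[0]
    # Order the selected hits by beta, highest first
    _, by_beta = torch.sort(beta[sel], descending=True, stable=True)
    hits = sel[by_beta]
    pids = particle_id[hits]
    # A stable sort by pid keeps the descending-beta order inside each particle
    sorted_pids, by_pid = torch.sort(pids, stable=True)
    is_first = torch.ones_like(sorted_pids, dtype=torch.bool)
    is_first[1:] = sorted_pids[1:] != sorted_pids[:-1]
    # First entry per particle = its highest-beta hit
    return hits[by_pid[is_first]], sorted_pids[is_first]


beta = torch.tensor([0.1, 0.9, 0.5, 0.7, 0.2])
pid = torch.tensor([1, 1, 2, 2, 0])
mask = torch.tensor([True, True, True, True, False])
alphas, alpha_pids = alphas_by_first_occurrence(beta, pid, mask)
print(alphas.tolist(), alpha_pids.tolist())  # [1, 3] [1, 2]
```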
gnn_tracking.metrics.losses.oc._get_vr_rg(*, radius_edges: torch.Tensor, is_cp_j: torch.Tensor, particle_id: torch.Tensor, x: torch.Tensor, q_j: torch.Tensor, radius_threshold: float)

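This helper has no docstring. In object condensation, the repulsive potential penalizes hits of other particles that sit close to a condensation point, with a hinge that vanishes beyond a fixed distance. A minimal sketch restricted to radius-graph edges, assuming radius_edges is a (2, E) index tensor with condensation points on the source side and the hinge placed at radius_threshold; neither assumption is stated on this page. Libraries such as torch_cluster provide radius_graph for building the edges; the dense radius_edges_dense below is a hypothetical stand-in so the sketch has no extra dependency, and any normalization the library applies is not modeled:

```python
import torch


def radius_edges_dense(x: torch.Tensor, r: float) -> torch.Tensor:
    """Hypothetical O(n^2) stand-in for a radius graph: all ordered pairs closer than r."""
    dist = torch.cdist(x, x)
    not_self = ~torch.eye(len(x), device=x.device).bool()
    src, dst = torch.nonzero((dist < r) & not_self, as_tuple=True)
    return torch.stack([src, dst])


def repulsive_potential_rg(
    radius_edges: torch.Tensor,
    is_cp: torch.Tensor,
    particle_id: torch.Tensor,
    x: torch.Tensor,
    q: torch.Tensor,
    radius_threshold: float,
) -> torch.Tensor:
    """Hypothetical sketch of the repulsive term restricted to radius-graph edges."""
    i, j = radius_edges
    # Keep edges that start at a condensation point and end at a hit
    # belonging to a different particle
    keep = is_cp[i] & (particle_id[i] != particle_id[j])
    i, j = i[keep], j[keep]
    dist = (x[i] - x[j]).norm(dim=-1)
    # Hinge: pairs farther apart than radius_threshold contribute nothing
    return (q[i] * q[j] * torch.relu(radius_threshold - dist)).sum()


x = torch.tensor([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]])
pid = torch.tensor([1, 2, 2])
is_cp = torch.tensor([True, True, False])
edges = radius_edges_dense(x, 1.0)
print(repulsive_potential_rg(edges, is_cp, pid, x, torch.ones(3), 1.0).item())  # 1.0
```

Both toy condensation points belong to different particles, so each ordered edge contributes 0.5 and the pair is counted twice.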
gnn_tracking.metrics.losses.oc._get_va(*, alphas_k: torch.Tensor, is_cp_j: torch.Tensor, particle_id: torch.Tensor, x: torch.Tensor, q_j: torch.Tensor, mask: torch.Tensor) → torch.Tensor

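This helper has no docstring either. In the object condensation paper, the attractive potential pulls every hit toward the condensation point of its own particle, proportionally to the squared distance and weighted by both charges. A minimal sketch, assuming alphas holds exactly one condensation-point hit index per particle; attractive_potential is hypothetical and omits any normalization:

```python
import torch


def attractive_potential(
    alphas: torch.Tensor,
    particle_id: torch.Tensor,
    x: torch.Tensor,
    q: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """Hypothetical sketch: sum of q_j * q_alpha * ||x_j - x_alpha||^2 over masked hits."""
    hits = torch.nonzero(mask, as_tuple=True)[0]
    # (n_hits, n_alphas): True where a hit belongs to the particle of a condensation point
    match = particle_id[hits][:, None] == particle_id[alphas][None, :]
    has_cp = match.any(dim=1)
    hits, match = hits[has_cp], match[has_cp]
    # With one condensation point per particle, each row has a single True; argmax finds it
    cp = alphas[match.float().argmax(dim=1)]
    sq_dist = ((x[hits] - x[cp]) ** 2).sum(dim=-1)
    return (q[hits] * q[cp] * sq_dist).sum()


x = torch.tensor([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [5.0, 7.0]])
pid = torch.tensor([1, 1, 2, 2])
q = torch.tensor([1.0, 2.0, 1.0, 3.0])
mask = torch.ones(4, dtype=torch.bool)
print(attractive_potential(torch.tensor([0, 2]), pid, x, q, mask).item())  # 14.0
```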
gnn_tracking.metrics.losses.oc._radius_graph_condensation_loss(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, q_min: float, mask: torch.Tensor, radius_threshold: float, max_num_neighbors: int) → tuple[dict[str, torch.Tensor], dict[str, Any]]

Extracted function for condensation loss. See PotentialLoss for details.

Parameters:

mask – Mask for objects cast to nodes
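For orientation, the object condensation paper (Kieseler, 2020) derives each hit's charge from its condensation strength β as q = arctanh²(β) + q_min, which is where the q_min parameter enters. A minimal sketch of that formula; oc_charge is hypothetical and the clamping epsilon is an assumption, not taken from this library:

```python
import torch


def oc_charge(beta: torch.Tensor, q_min: float, eps: float = 1e-4) -> torch.Tensor:
    """Charge as defined in the OC paper: q = arctanh(beta)^2 + q_min."""
    # arctanh diverges at beta = 1, so clamp slightly below it (eps is an assumption)
    beta = beta.clamp(min=0.0, max=1.0 - eps)
    return torch.atanh(beta) ** 2 + q_min


q = oc_charge(torch.tensor([0.0, 0.5, 0.9]), q_min=0.01)
print(q)
```

q_min sets the floor: even a hit with β = 0 keeps a small charge, so it still feels the potentials.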

class gnn_tracking.metrics.losses.oc.CondensationLossRG(*, lw_repulsive: float = 1.0, lw_noise: float = 0.0, lw_coward: float = 0.0, q_min: float = 0.01, pt_thld: float = 0.9, max_eta: float = 4.0, max_num_neighbors: int = 256, sample_pids: float = 1.0)

Bases: gnn_tracking.metrics.losses.MultiLossFct, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Implementation of condensation loss that uses a radius graph instead of calculating the whole n^2 distance matrix.

Parameters:
  • lw_repulsive (float, optional) – Loss weight for the repulsive part of the potential loss. Defaults to 1.0.

  • lw_noise (float, optional) – Loss weight for the noise loss. Defaults to 0.0.

  • lw_coward (float, optional) – Loss weight for the coward loss. Defaults to 0.0.

  • q_min (float, optional) – Minimum charge; see the OC paper. Defaults to 0.01.

  • pt_thld (float, optional) – pT threshold for interesting particles. Defaults to 0.9.

  • max_eta (float, optional) – Eta (pseudorapidity) threshold for interesting particles. Defaults to 4.0.

  • max_num_neighbors (int, optional) – Maximum number of neighbors to consider for radius graphs. Defaults to 256.

  • sample_pids (float, optional) – Fraction of particles to keep when further subsampling to conserve memory. Defaults to 1.0 (no sampling).
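The forward methods take per-hit reconstructable, pt and eta tensors alongside these thresholds. A minimal sketch of how such an "interesting particle" selection and the sample_pids subsampling could look, assuming pt and eta are given per hit, the cuts are pt > pt_thld and |eta| < max_eta, and subsampling keeps a random fraction of particles. interesting_hit_mask is hypothetical and may differ from the library's selection; for example, the role of ec_hit_mask is not modeled:

```python
from __future__ import annotations

import torch


def interesting_hit_mask(
    particle_id: torch.Tensor,
    reconstructable: torch.Tensor,
    pt: torch.Tensor,
    eta: torch.Tensor,
    pt_thld: float = 0.9,
    max_eta: float = 4.0,
    sample_pids: float = 1.0,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Hypothetical sketch: hits of reconstructable particles passing pt and |eta| cuts."""
    mask = reconstructable.bool() & (pt > pt_thld) & (eta.abs() < max_eta)
    if sample_pids < 1.0:
        # Keep a random fraction of the selected particles (all of a particle's hits or none)
        pids = torch.unique(particle_id[mask])
        keep = torch.rand(len(pids), generator=generator) < sample_pids
        mask = mask & torch.isin(particle_id, pids[keep])
    return mask


pid = torch.tensor([1, 1, 2, 3])
reco = torch.tensor([1, 1, 1, 0])
pt = torch.tensor([1.2, 1.2, 0.5, 2.0])
eta = torch.tensor([0.1, 0.1, 0.2, 0.3])
print(interesting_hit_mask(pid, reco, pt, eta).tolist())  # [True, True, False, False]
```

Sampling whole particles rather than individual hits keeps every selected object complete, so its condensation point and attractive term stay well defined.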

forward(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, reconstructable: torch.Tensor, pt: torch.Tensor, ec_hit_mask: torch.Tensor | None = None, eta: torch.Tensor, **kwargs) → gnn_tracking.metrics.losses.MultiLossFctReturn

gnn_tracking.metrics.losses.oc.condensation_loss_tiger(*, beta: torch.Tensor, x: torch.Tensor, object_id: torch.Tensor, object_mask: torch.Tensor, q_min: float, noise_threshold: int, max_n_rep: int) → tuple[dict[str, torch.Tensor], dict[str, int | float]]

Extracted function for torch compilation. See condensation_loss_tiger for docstring.

Parameters:

object_mask – Mask for the particles that should be considered for the loss. It is broadcast to n_hits.

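Returns:

A tuple (loss_dct, extra_dct): loss_dct is a dictionary of losses and extra_dct is a dictionary of extra information.

Return type:

tuple[dict[str, torch.Tensor], dict[str, int | float]]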
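Besides the potential terms, the OC paper has a loss on β itself: one term pushes each object's condensation point toward β = 1, and one pushes noise hits toward β = 0. Reading lw_coward and lw_noise as the weights of these two terms is an assumption based on their names. A minimal sketch that takes an explicit noise mask, since how noise_threshold selects noise hits is not documented here; beta_losses is hypothetical:

```python
from __future__ import annotations

import torch


def beta_losses(
    beta: torch.Tensor, alphas: torch.Tensor, is_noise: torch.Tensor
) -> dict[str, torch.Tensor]:
    """Hypothetical sketch of the two beta terms from the OC paper."""
    # "Coward" term: penalize condensation points whose beta is far from 1
    coward = (1 - beta[alphas]).mean()
    # Noise term: penalize noise hits with high beta (zero if there is no noise)
    noise = beta[is_noise].mean() if is_noise.any() else beta.new_zeros(())
    return {"coward": coward, "noise": noise}


beta = torch.tensor([0.9, 0.2, 0.8, 0.1])
losses = beta_losses(beta, torch.tensor([0, 2]), torch.tensor([False, False, False, True]))
print({k: round(v.item(), 4) for k, v in losses.items()})
```

A weighted total such as lw_coward * coward + lw_noise * noise would be one way to combine them; how the library combines its terms is not shown on this page.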

class gnn_tracking.metrics.losses.oc.CondensationLossTiger(*, lw_repulsive: float = 1.0, lw_noise: float = 0.0, lw_coward: float = 0.0, q_min: float = 0.01, pt_thld: float = 0.9, max_eta: float = 4.0, max_n_rep: int = 0, sample_pids: float = 1.0)

Bases: gnn_tracking.metrics.losses.MultiLossFct, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Implementation of condensation loss that directly calculates the n^2 distance matrix.
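To illustrate what computing the n^2 distance matrix involves, here is a minimal sketch of a repulsive term built from the full condensation-point-to-hit distance matrix via torch.cdist. The hinge at distance 1 follows the OC paper; repulsive_potential_dense is hypothetical and omits normalization:

```python
import torch


def repulsive_potential_dense(
    alphas: torch.Tensor, particle_id: torch.Tensor, x: torch.Tensor, q: torch.Tensor
) -> torch.Tensor:
    """Hypothetical sketch: repulsive term from the full (n_alphas, n_hits) distance matrix."""
    # Every condensation point against every hit
    dist = torch.cdist(x[alphas], x)
    # Only hits of other particles are repelled
    other = particle_id[alphas][:, None] != particle_id[None, :]
    # Hinge at distance 1: pairs farther apart contribute nothing
    hinge = torch.relu(1.0 - dist)
    return (q[alphas][:, None] * q[None, :] * hinge * other).sum()


x = torch.tensor([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]])
pid = torch.tensor([1, 2, 2])
print(repulsive_potential_dense(torch.tensor([0, 1]), pid, x, torch.ones(3)).item())  # 1.0
```

Memory grows with (number of condensation points) × (number of hits), which is the cost a radius graph avoids by only materializing nearby pairs.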

Parameters:
  • lw_repulsive (float, optional) – Loss weight for the repulsive part of the potential loss. Defaults to 1.0.

  • lw_noise (float, optional) – Loss weight for the noise loss. Defaults to 0.0.

  • lw_coward (float, optional) – Loss weight for the coward loss. Defaults to 0.0.

  • q_min (float, optional) – Minimum charge; see the OC paper. Defaults to 0.01.

  • pt_thld (float, optional) – pT threshold for interesting particles. Defaults to 0.9.

  • max_eta (float, optional) – Eta (pseudorapidity) threshold for interesting particles. Defaults to 4.0.

  • max_n_rep (int, optional) – Maximum number of repulsive edges to consider. Defaults to 0 (all).

  • sample_pids (float, optional) – Fraction of particles to keep when further subsampling to conserve memory. Defaults to 1.0 (no sampling).
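One plausible reading of max_n_rep is a cap on the number of repulsive pairs evaluated. A minimal sketch of such a cap, assuming pairs are stored as a (2, E) index tensor and that surplus pairs are dropped by uniform random sampling; both the sampling strategy and cap_repulsive_pairs are hypothetical:

```python
from __future__ import annotations

import torch


def cap_repulsive_pairs(
    pairs: torch.Tensor, max_n_rep: int, generator: torch.Generator | None = None
) -> torch.Tensor:
    """Keep at most max_n_rep columns of a (2, E) pair tensor; 0 means keep all."""
    n = pairs.shape[1]
    if max_n_rep <= 0 or n <= max_n_rep:
        return pairs
    # Uniform random subset without replacement (sampling strategy is an assumption)
    idx = torch.randperm(n, generator=generator)[:max_n_rep]
    return pairs[:, idx]


pairs = torch.arange(20).reshape(2, 10)  # columns are (k, k + 10)
capped = cap_repulsive_pairs(pairs, 4, generator=torch.Generator().manual_seed(0))
print(capped.shape)  # torch.Size([2, 4])
```

Sampling bounds memory for the repulsive term at the price of a noisier loss estimate.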

forward(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, reconstructable: torch.Tensor, pt: torch.Tensor, ec_hit_mask: torch.Tensor | None = None, eta: torch.Tensor, **kwargs) → gnn_tracking.metrics.losses.MultiLossFctReturn

class gnn_tracking.metrics.losses.oc.ObjectLoss(mode='efficiency')

Bases: torch.nn.Module

Loss functions for predicted object properties.

mode

static _mse(*, pred: torch.Tensor, truth: torch.Tensor) → torch.Tensor

object_loss(*, pred: torch.Tensor, beta: torch.Tensor, truth: torch.Tensor, particle_id: torch.Tensor) → torch.Tensor

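These methods have no docstrings. The OC paper weights per-hit property losses by ξ = (1 - n) arctanh²(β), where n = 1 for noise hits, and normalizes by the sum of the weights. A minimal sketch of that weighting with a per-hit MSE, assuming noise hits have particle_id == 0. The mode switch ('efficiency' by default) is not described here and is not modeled; both functions are hypothetical:

```python
import torch


def mse_per_hit(pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
    # Squared error summed over the property dimension -> shape (n_hits,)
    return ((pred - truth) ** 2).sum(dim=-1)


def weighted_object_loss(
    pred: torch.Tensor,
    beta: torch.Tensor,
    truth: torch.Tensor,
    particle_id: torch.Tensor,
    eps: float = 1e-4,
) -> torch.Tensor:
    """Hypothetical sketch of the OC-paper property loss: xi-weighted mean of per-hit MSE."""
    # xi = (1 - n) * arctanh(beta)^2; noise (n = 1) is assumed to be particle_id == 0
    xi = torch.atanh(beta.clamp(max=1 - eps)) ** 2 * (particle_id != 0)
    # Undefined (0/0) if every hit is noise or has beta == 0
    return (xi * mse_per_hit(pred, truth)).sum() / xi.sum()


pred = torch.tensor([[1.0], [0.0], [5.0]])
truth = torch.zeros(3, 1)
loss = weighted_object_loss(pred, torch.tensor([0.5, 0.5, 0.9]), truth, torch.tensor([1, 1, 0]))
print(loss.item())
```

The arctanh² weight lets hits near a condensation point (high β) dominate the property regression, while the noise hit is excluded entirely.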
forward(*, beta: torch.Tensor, pred: torch.Tensor, particle_id: torch.Tensor, track_params: torch.Tensor, reconstructable: torch.Tensor, **kwargs) → torch.Tensor