gnn_tracking.metrics.losses.oc

Attributes

logger

Classes

MultiLossFct

Base class for loss functions that return multiple losses.

MultiLossFctReturn

Return type for loss functions that return multiple losses.

CondensationLossRG

Implementation of condensation loss that uses a radius graph instead of calculating the whole n^2 distance matrix.

CondensationLossTiger

Implementation of condensation loss that directly calculates the n^2 distance matrix.

ObjectLoss

Loss functions for predicted object properties.

Functions

get_good_node_mask_tensors(→ torch.Tensor)

See get_good_node_mask

_first_occurrences(→ torch.Tensor)

Return the first occurrence of each unique element in a 1D array

_square_distances(→ torch.Tensor)

Returns squared distances between two sets of points

_get_alphas_first_occurences(→ tuple[torch.Tensor, ...)

_get_vr_rg(*, radius_edges, is_cp_j, particle_id, x, ...)

_get_va(→ torch.Tensor)

_radius_graph_condensation_loss(→ tuple[dict[str, ...)

Extracted function for condensation loss. See PotentialLoss for details.

condensation_loss_tiger(→ tuple[dict[str, ...)

Extracted function for torch compilation. See CondensationLossTiger for the full docstring.

Module Contents

class gnn_tracking.metrics.losses.oc.MultiLossFct

Bases: torch.nn.Module

Base class for loss functions that return multiple losses.

forward(*args: Any, **kwargs: Any) → MultiLossFctReturn
class gnn_tracking.metrics.losses.oc.MultiLossFctReturn

Return type for loss functions that return multiple losses.

loss_dct: dict[str, torch.Tensor]
weight_dct: dict[str, torch.Tensor] | dict[str, float]
extra_metrics: dict[str, Any]
__post_init__() → None
property loss: torch.Tensor
property weighted_losses: dict[str, torch.Tensor]
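MultiLossFctReturn only lists its fields and properties here. The following is a minimal sketch of how such a container could behave, assuming (this page does not confirm it) that `weighted_losses` scales each entry of `loss_dct` by the matching entry of `weight_dct`, that `loss` is the sum of the weighted entries, and that `__post_init__` checks that every loss has a weight. `MultiLossFctReturnSketch` is a hypothetical name, not the library class.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

import torch


@dataclass
class MultiLossFctReturnSketch:
    """Hypothetical stand-in for MultiLossFctReturn (not the library class)."""

    loss_dct: dict[str, torch.Tensor]
    weight_dct: dict[str, torch.Tensor] | dict[str, float]
    extra_metrics: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Assumed validation: every loss component needs a weight.
        missing = set(self.loss_dct) - set(self.weight_dct)
        if missing:
            raise ValueError(f"No weight given for: {sorted(missing)}")

    @property
    def weighted_losses(self) -> dict[str, torch.Tensor]:
        # Scale each raw loss component by its weight.
        return {k: self.weight_dct[k] * v for k, v in self.loss_dct.items()}

    @property
    def loss(self) -> torch.Tensor:
        # Assumed: the total loss is the sum of the weighted components.
        return sum(self.weighted_losses.values())
```

One benefit of keeping raw losses and weights separate is that the unweighted components can still be logged individually.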
gnn_tracking.metrics.losses.oc.get_good_node_mask_tensors(*, pt, particle_id, reconstructable, eta, pt_thld: float = 0.9, max_eta: float = 4.0) → torch.Tensor

See get_good_node_mask
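get_good_node_mask itself is not shown on this page. The following sketch shows what the parameters suggest. It assumes a node is "good" when its particle has pT above `pt_thld`, |η| below `max_eta`, is reconstructable, and is not noise. Treating `particle_id > 0` as "not noise" is an assumed convention, and whether the library uses strict or non-strict comparisons is not stated here.

```python
import torch


def good_node_mask_sketch(
    *,
    pt: torch.Tensor,
    particle_id: torch.Tensor,
    reconstructable: torch.Tensor,
    eta: torch.Tensor,
    pt_thld: float = 0.9,
    max_eta: float = 4.0,
) -> torch.Tensor:
    """Hypothetical stand-in for get_good_node_mask_tensors."""
    return (
        (pt > pt_thld)              # particles above the pT threshold
        & (eta.abs() < max_eta)     # inside the eta acceptance
        & reconstructable.bool()    # particle is reconstructable
        & (particle_id > 0)         # assumed: particle_id <= 0 marks noise
    )
```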

gnn_tracking.metrics.losses.oc.logger
gnn_tracking.metrics.losses.oc._first_occurrences(x: torch.Tensor) → torch.Tensor

Return the first occurrence of each unique element in a 1D array
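The following sketch shows one vectorized way to compute these indices without a Python loop. It stable-sorts the values so that equal elements keep their original order, then keeps the first element of each run of equal values. This illustrates the technique only; the library's implementation may differ. Here the result is ordered by value, matching `torch.unique`.

```python
import torch


def first_occurrences_sketch(x: torch.Tensor) -> torch.Tensor:
    """Index of the first occurrence of each unique value in 1D ``x``."""
    # A stable sort keeps equal values in their original relative order,
    # so the first element of each run of equal values is the earliest one.
    sorted_x, order = torch.sort(x, stable=True)
    is_first = torch.ones_like(sorted_x, dtype=torch.bool)
    is_first[1:] = sorted_x[1:] != sorted_x[:-1]
    return order[is_first]
```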

gnn_tracking.metrics.losses.oc._square_distances(edges: torch.Tensor, positions: torch.Tensor) → torch.Tensor

Returns squared distances between two sets of points
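The signature takes an edge index and a single positions tensor rather than two separate point sets. The following sketch assumes that `edges` has shape (2, n_edges), following the PyTorch Geometric convention, and that the "two sets of points" are the edge sources and the edge targets.

```python
import torch


def square_distances_sketch(edges: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    # edges: (2, n_edges) index pairs; positions: (n_points, dim)
    diff = positions[edges[0]] - positions[edges[1]]
    # One squared Euclidean distance per edge
    return (diff**2).sum(dim=-1)
```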

gnn_tracking.metrics.losses.oc._get_alphas_first_occurences(beta: torch.Tensor, particle_id: torch.Tensor, mask: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]
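_get_alphas_first_occurences has no docstring. In the OC paper, the condensation point α_k of object k is its hit with the largest β. Based on the function's name and on that definition, a plausible sketch is to sort the masked hits by decreasing β and then take the first occurrence of each particle ID. The meaning of `mask` and of the second return value are assumptions, and `alphas_first_occurrences_sketch` is a hypothetical name.

```python
import torch


def alphas_first_occurrences_sketch(
    beta: torch.Tensor, particle_id: torch.Tensor, mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # Assumed: mask is a bool per hit selecting the particles of interest.
    idx = torch.nonzero(mask).squeeze(1)
    # Order the candidate hits by decreasing beta.
    idx = idx[torch.argsort(beta[idx], descending=True)]
    pids = particle_id[idx]
    # A stable sort by pid keeps the beta order inside each pid group,
    # so the first hit of every group is that particle's highest-beta hit.
    pids_sorted, perm = torch.sort(pids, stable=True)
    is_first = torch.ones_like(pids_sorted, dtype=torch.bool)
    is_first[1:] = pids_sorted[1:] != pids_sorted[:-1]
    alphas = idx[perm[is_first]]
    return alphas, pids_sorted[is_first]
```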
gnn_tracking.metrics.losses.oc._get_vr_rg(*, radius_edges: torch.Tensor, is_cp_j: torch.Tensor, particle_id: torch.Tensor, x: torch.Tensor, q_j: torch.Tensor, radius_threshold: float)
gnn_tracking.metrics.losses.oc._get_va(*, alphas_k: torch.Tensor, is_cp_j: torch.Tensor, particle_id: torch.Tensor, x: torch.Tensor, q_j: torch.Tensor, mask: torch.Tensor) → torch.Tensor
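_get_vr_rg and _get_va have no docstrings. Judging by their names, they compute the repulsive (V_r) and attractive (V_a) potentials. For reference, the standard formulation from the OC paper (Kieseler, 2020) is given below. The normalization and hinge radius used by this module (for example `radius_threshold`) may differ. In the formulas, M_{jk} = 1 if hit j belongs to object k and 0 otherwise, N_B is the number of noise hits, n_i = 1 marks a noise hit, and s_B weights the noise term.

```latex
\begin{aligned}
q_i &= \operatorname{arctanh}^2(\beta_i) + q_{\min},
\qquad
\alpha_k = \operatorname*{arg\,max}_{i:\,M_{ik}=1} q_i, \\
\breve{V}_k(x_j) &= \lVert x_j - x_{\alpha_k} \rVert^2 \, q_{\alpha_k},
\qquad
\hat{V}_k(x_j) = \max\bigl(0,\; 1 - \lVert x_j - x_{\alpha_k} \rVert\bigr)\, q_{\alpha_k}, \\
L_V &= \frac{1}{N} \sum_{j=1}^{N} q_j \sum_{k=1}^{K}
  \Bigl( M_{jk}\, \breve{V}_k(x_j) + (1 - M_{jk})\, \hat{V}_k(x_j) \Bigr), \\
L_\beta &= \frac{1}{K} \sum_{k=1}^{K} \bigl(1 - \beta_{\alpha_k}\bigr)
  + \frac{s_B}{N_B} \sum_{i=1}^{N} n_i\, \beta_i .
\end{aligned}
```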
gnn_tracking.metrics.losses.oc._radius_graph_condensation_loss(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, q_min: float, mask: torch.Tensor, radius_threshold: float, max_num_neighbors: int) → tuple[dict[str, torch.Tensor], dict[str, Any]]

Extracted function for condensation loss. See PotentialLoss for details.

Parameters:

mask – Mask for objects cast to nodes

class gnn_tracking.metrics.losses.oc.CondensationLossRG(*, lw_repulsive: float = 1.0, lw_noise: float = 0.0, lw_coward: float = 0.0, q_min: float = 0.01, pt_thld: float = 0.9, max_eta: float = 4.0, max_num_neighbors: int = 256, sample_pids: float = 1.0)

Bases: gnn_tracking.metrics.losses.MultiLossFct, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Implementation of condensation loss that uses a radius graph instead of calculating the whole n^2 distance matrix.

Parameters:
  • lw_repulsive – Loss weight for repulsive part of potential loss

  • lw_noise – Loss weight for noise loss

  • lw_coward – Loss weight for coward loss

  • q_min (float, optional) – Minimum charge added to every hit (see OC paper). Defaults to 0.01.

  • pt_thld (float, optional) – pT threshold for particles of interest. Defaults to 0.9.

  • max_eta (float, optional) – η threshold for particles of interest. Defaults to 4.0.

  • max_num_neighbors (int, optional) – Maximum number of neighbors to consider for radius graphs. Defaults to 256.

  • sample_pids (float, optional) – Further subsample particles to conserve memory. Defaults to 1.0 (no sampling).

forward(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, reconstructable: torch.Tensor, pt: torch.Tensor, ec_hit_mask: torch.Tensor | None = None, eta: torch.Tensor, **kwargs) → gnn_tracking.metrics.losses.MultiLossFctReturn
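With a radius graph, repulsion only needs to be evaluated on the edges the graph returns, because pairs farther apart than the hinge radius contribute nothing. The following sketch illustrates that idea in the spirit of _get_vr_rg. It makes several assumptions. `radius_edges` has shape (2, n_edges), for example as produced by `torch_cluster.radius_graph`. `q_j` holds the charge of every hit. The hinge is `max(0, radius_threshold - d)`, which reduces to the OC paper's hinge for a radius of 1. Normalizing by the number of hits is also an assumption.

```python
import torch


def repulsive_from_radius_edges_sketch(
    *,
    radius_edges: torch.Tensor,  # (2, n_edges): (hit i, neighbor j)
    is_cp_j: torch.Tensor,       # (n_hits,) bool: hit is a condensation point
    particle_id: torch.Tensor,   # (n_hits,)
    x: torch.Tensor,             # (n_hits, dim) latent coordinates
    q_j: torch.Tensor,           # (n_hits,) charges
    radius_threshold: float,
) -> torch.Tensor:
    i, j = radius_edges
    # Only neighbors that are condensation points of a different particle repel.
    keep = is_cp_j[j] & (particle_id[i] != particle_id[j])
    i, j = i[keep], j[keep]
    dist = (x[i] - x[j]).norm(dim=-1)
    # Hinge potential: zero beyond radius_threshold
    v = q_j[i] * q_j[j] * torch.relu(radius_threshold - dist)
    return v.sum() / len(x)
```

Because only radius-graph edges are visited, memory presumably scales with the number of edges (capped per hit by max_num_neighbors in CondensationLossRG) rather than with n^2.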
gnn_tracking.metrics.losses.oc.condensation_loss_tiger(*, beta: torch.Tensor, x: torch.Tensor, object_id: torch.Tensor, object_mask: torch.Tensor, q_min: float, noise_threshold: int, max_n_rep: int) → tuple[dict[str, torch.Tensor], dict[str, int | float]]

Extracted function for torch compilation. See CondensationLossTiger for the full docstring.

Parameters:

object_mask – Mask for the particles that should be considered for the loss. This is broadcast to n_hits.

Returns:

  • loss_dct – Dictionary of losses

  • extra_dct – Dictionary of extra information

Return type:

tuple[dict[str, torch.Tensor], dict[str, int | float]]

class gnn_tracking.metrics.losses.oc.CondensationLossTiger(*, lw_repulsive: float = 1.0, lw_noise: float = 0.0, lw_coward: float = 0.0, q_min: float = 0.01, pt_thld: float = 0.9, max_eta: float = 4.0, max_n_rep: int = 0, sample_pids: float = 1.0)

Bases: gnn_tracking.metrics.losses.MultiLossFct, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Implementation of condensation loss that directly calculates the n^2 distance matrix.

Parameters:
  • lw_repulsive – Loss weight for repulsive part of potential loss

  • lw_noise – Loss weight for noise loss

  • lw_coward – Loss weight for coward loss

  • q_min (float, optional) – Minimum charge added to every hit (see OC paper). Defaults to 0.01.

  • pt_thld (float, optional) – pT threshold for particles of interest. Defaults to 0.9.

  • max_eta (float, optional) – η threshold for particles of interest. Defaults to 4.0.

  • max_n_rep (int, optional) – Maximum number of repulsive edges to consider. Defaults to 0 (all).

  • sample_pids (float, optional) – Further subsample particles to conserve memory. Defaults to 1.0 (no sampling).

forward(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, reconstructable: torch.Tensor, pt: torch.Tensor, ec_hit_mask: torch.Tensor | None = None, eta: torch.Tensor, **kwargs) → gnn_tracking.metrics.losses.MultiLossFctReturn
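For comparison, here is a dense sketch of the OC paper's potential and β terms. It materializes the full (n_hits × n_objects) distance matrix with `torch.cdist`. This illustrates the formulas and is not CondensationLossTiger itself. It assumes that object ID 0 marks noise (`noise_id` is a hypothetical parameter) and that the keys `attractive`, `repulsive`, `coward`, and `noise` are hypothetical names. "coward" is read here as the paper's 1 − β_α term. No loss weights are applied.

```python
import torch


def dense_condensation_loss_sketch(
    beta: torch.Tensor,       # (n_hits,) condensation likelihood in (0, 1)
    x: torch.Tensor,          # (n_hits, dim) latent coordinates
    object_id: torch.Tensor,  # (n_hits,) particle ID per hit
    q_min: float = 0.01,
    noise_id: int = 0,        # assumed ID marking noise hits
) -> dict[str, torch.Tensor]:
    # Charge q_i = arctanh(beta_i)^2 + q_min; the clamp keeps arctanh finite.
    q = torch.atanh(beta.clamp(max=1 - 1e-4)) ** 2 + q_min
    is_noise = object_id == noise_id
    pids = torch.unique(object_id[~is_noise])       # assumes >= 1 object
    member = object_id[:, None] == pids[None, :]    # M_jk, (n_hits, n_objects)
    # alpha_k: member hit with the largest charge (q > 0, so non-members lose)
    alphas = torch.argmax(q[:, None] * member, dim=0)
    q_alpha = q[alphas]
    dist = torch.cdist(x, x[alphas])                # full (n_hits, n_objects) matrix
    v_att = (member * dist**2 * q_alpha).sum(dim=1)
    v_rep = ((~member) * torch.relu(1 - dist) * q_alpha).sum(dim=1)
    n_hits = len(beta)
    noise = beta[is_noise].mean() if is_noise.any() else beta.new_zeros(())
    return {
        "attractive": (q * v_att).sum() / n_hits,
        "repulsive": (q * v_rep).sum() / n_hits,
        "coward": (1 - beta[alphas]).mean(),  # paper's (1 - beta_alpha) term
        "noise": noise,
    }
```

The memory cost grows with n_hits × n_objects, which is the trade-off the radius-graph variant avoids.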
class gnn_tracking.metrics.losses.oc.ObjectLoss(mode='efficiency')

Bases: torch.nn.Module

Loss functions for predicted object properties.

static _mse(*, pred: torch.Tensor, truth: torch.Tensor) → torch.Tensor
object_loss(*, pred: torch.Tensor, beta: torch.Tensor, truth: torch.Tensor, particle_id: torch.Tensor) → torch.Tensor
forward(*, beta: torch.Tensor, pred: torch.Tensor, particle_id: torch.Tensor, track_params: torch.Tensor, reconstructable: torch.Tensor, **kwargs) → torch.Tensor
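ObjectLoss documents only its signatures. The following sketch implements the β-weighted property loss from the OC paper, which weights each hit's regression error by ξ_i = arctanh²(β_i) and ignores noise hits. The per-hit squared error, the `particle_id > 0` noise convention, and the omission of `mode` are all assumptions, so this is not the behavior of `object_loss`.

```python
import torch


def object_loss_sketch(
    *,
    pred: torch.Tensor,         # (n_hits, n_params) predicted properties
    beta: torch.Tensor,         # (n_hits,)
    truth: torch.Tensor,        # (n_hits, n_params) true properties
    particle_id: torch.Tensor,  # (n_hits,); assumed <= 0 for noise
) -> torch.Tensor:
    sq_err = ((pred - truth) ** 2).sum(dim=-1)        # per-hit squared error
    xi = torch.atanh(beta.clamp(max=1 - 1e-4)) ** 2   # OC-paper hit weight
    xi = xi * (particle_id > 0)                       # drop noise hits
    # Weighted mean; assumes at least one non-noise hit
    return (xi * sq_err).sum() / xi.sum()
```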