gnn_tracking.metrics.losses#

This module contains loss functions for the GNN tracking model.

Module Contents#

Classes#

FalsifyLowPtEdgeWeightLoss

Add an option to falsify edges with low pt to edge classification losses.

EdgeWeightBCELoss

Binary cross-entropy loss for edge classification, with an option to falsify edges with low pt.

EdgeWeightFocalLoss

Loss function based on focal loss for edge classification.

HaughtyFocalLoss

PotentialLoss

Potential/condensation loss (specific to object condensation approach).

BackgroundLoss

ObjectLoss

Loss functions for predicted object properties.

LossFctType

Type of a loss function

LossClones

Wrapper for a loss function that evaluates it on multiple inputs.

GraphConstructionHingeEmbeddingLoss

Loss for graph construction using metric learning.

Functions#

_binary_focal_loss(→ torch.Tensor)

Extracted function for JIT compilation.

binary_focal_loss(→ torch.Tensor)

Binary Focal Loss, following https://arxiv.org/abs/1708.02002.

falsify_low_pt_edges(→ torch.Tensor)

Modify the ground truth to-be-predicted by the edge classification to consider edges that include a hit with pt < pt_thld as false.

_condensation_loss(→ dict[str, torch.Tensor])

Extracted function for JIT-compilation. See PotentialLoss for details.

_first_occurrences(→ torch.Tensor)

Return the first occurrence of each unique element in a 1D array

_square_distances(→ torch.Tensor)

Returns squared distances between two sets of points

_fast_condensation_loss(→ dict[str, torch.Tensor])

Extracted function for condensation loss. See PotentialLoss for details.

_background_loss(→ torch.Tensor)

Extracted function for JIT-compilation. See BackgroundLoss for details.

unpack_loss_returns(→ dict[str, float | torch.Tensor])

Some of our loss functions return a dictionary or a list of individual losses.

_hinge_loss_components(→ tuple[torch.Tensor, torch.Tensor])

Attributes#

loss_weight_type

gnn_tracking.metrics.losses._binary_focal_loss(*, inpt: torch.Tensor, target: torch.Tensor, alpha: float, gamma: float, pos_weight: torch.Tensor) torch.Tensor#

Extracted function for JIT compilation.

gnn_tracking.metrics.losses.binary_focal_loss(*, inpt: torch.Tensor, target: torch.Tensor, alpha: float = 0.25, gamma: float = 2.0, pos_weight: torch.Tensor = None) torch.Tensor#

Binary Focal Loss, following https://arxiv.org/abs/1708.02002.

Parameters:
  • inpt

  • target

  • alpha – Weight for positive/negative results

  • gamma – Focusing parameter

  • pos_weight – Can be used to balance precision/recall
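For a single prediction, the focal loss from the paper can be written out in a few lines. The following is a scalar, dependency-free sketch and not this module's tensor implementation: it assumes the input is already a probability in (0, 1), leaves out `pos_weight`, and uses the hypothetical name `focal_loss_scalar`.

```python
import math


def focal_loss_scalar(p: float, target: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Focal loss for one predicted probability p in (0, 1) and a binary target."""
    # p_t is the probability the model assigns to the true class
    p_t = p if target == 1 else 1.0 - p
    # alpha_t weights positives by alpha and negatives by (1 - alpha)
    alpha_t = alpha if target == 1 else 1.0 - alpha
    # (1 - p_t) ** gamma down-weights examples that are already well classified
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)
```

With `gamma=0` the focusing factor disappears and the expression reduces to an alpha-weighted binary cross-entropy; larger `gamma` shifts the weight towards hard, misclassified edges.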

gnn_tracking.metrics.losses.falsify_low_pt_edges(*, y: torch.Tensor, edge_index: torch.Tensor | None = None, pt: torch.Tensor | None = None, pt_thld: float = 0.0) torch.Tensor#

Modify the ground truth to-be-predicted by the edge classification to consider edges that include a hit with pt < pt_thld as false.

Parameters:
  • y – True classification

  • edge_index

  • pt – Hit pt

  • pt_thld – Apply pt threshold

Returns:

True classification with additional criteria applied
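The rule described above can be sketched with plain Python lists instead of tensors. The name `falsify_low_pt_edges_sketch` is hypothetical, and the sketch assumes that `edge_index` is a pair of equally long index lists (sender hits, receiver hits) and that both `edge_index` and `pt` are supplied.

```python
def falsify_low_pt_edges_sketch(y, edge_index, pt, pt_thld=0.0):
    """Return new labels where edges touching a hit with pt < pt_thld are False."""
    senders, receivers = edge_index
    return [
        # An edge stays true only if it was true and both of its hits pass the cut
        bool(label) and pt[s] >= pt_thld and pt[r] >= pt_thld
        for label, s, r in zip(y, senders, receivers)
    ]


y = [True, True, False, True]
edge_index = ([0, 1, 2, 0], [1, 2, 3, 3])
pt = [1.0, 0.5, 2.0, 3.0]
labels = falsify_low_pt_edges_sketch(y, edge_index, pt, pt_thld=0.9)
```

In this toy graph, hit 1 has pt 0.5, so the two true edges that touch it are relabelled as false, while the edge between hits 0 and 3 keeps its label.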

class gnn_tracking.metrics.losses.FalsifyLowPtEdgeWeightLoss(*, pt_thld: float = 0.0)#

Bases: torch.nn.Module, abc.ABC, pytorch_lightning.core.mixins.HyperparametersMixin

Add an option to falsify edges with low pt to edge classification losses.

forward(*, w: torch.Tensor, y: torch.Tensor, edge_index: torch.Tensor | None = None, pt: torch.Tensor | None = None, **kwargs) torch.Tensor#
abstract _forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) torch.Tensor#
class gnn_tracking.metrics.losses.EdgeWeightBCELoss(*, pt_thld: float = 0.0)#

Bases: FalsifyLowPtEdgeWeightLoss

Binary cross-entropy loss for edge classification, with an option to falsify edges with low pt.

static _forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) torch.Tensor#
class gnn_tracking.metrics.losses.EdgeWeightFocalLoss(*, alpha=0.25, gamma=2.0, pos_weight=None, **kwargs)#

Bases: FalsifyLowPtEdgeWeightLoss

Loss function based on focal loss for edge classification. See binary_focal_loss for details.

_forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) torch.Tensor#
class gnn_tracking.metrics.losses.HaughtyFocalLoss(*, alpha: float = 0.25, gamma: float = 2.0, pt_thld=0.0)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.HyperparametersMixin

forward(*, w: torch.Tensor, y: torch.Tensor, edge_index: torch.Tensor, pt: torch.Tensor, **kwargs) torch.Tensor#
gnn_tracking.metrics.losses._condensation_loss(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, mask: torch.Tensor, q_min: float, radius_threshold: float) dict[str, torch.Tensor]#

Extracted function for JIT-compilation. See PotentialLoss for details.

gnn_tracking.metrics.losses._first_occurrences(input_array: torch.Tensor) torch.Tensor#

Return the first occurrence of each unique element in a 1D array
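One plausible reading of this description is "for every distinct value, the index at which it first appears". The sketch below illustrates that reading with a plain list; the name `first_occurrences_sketch` and the ordering of the result (by sorted unique value) are assumptions, not confirmed behaviour of the module.

```python
def first_occurrences_sketch(values):
    """Return the index of the first occurrence of each unique value."""
    first_index = {}
    for i, v in enumerate(values):
        # setdefault keeps the earliest index and ignores later repeats
        first_index.setdefault(v, i)
    # Assumption: results are ordered by the sorted unique values
    return [first_index[v] for v in sorted(first_index)]


indices = first_occurrences_sketch([3, 1, 3, 2, 1])
```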

gnn_tracking.metrics.losses._square_distances(edges: torch.Tensor, positions: torch.Tensor) torch.Tensor#

Returns squared distances between two sets of points
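Given the signature, a natural interpretation is that each edge contributes the squared Euclidean distance between the positions of its two endpoints. The sketch below shows that interpretation with plain lists; `square_distances_sketch` is a hypothetical name and the (senders, receivers) layout of `edges` is an assumption.

```python
def square_distances_sketch(edges, positions):
    """Squared Euclidean distance between the two endpoints of every edge."""
    senders, receivers = edges
    return [
        # Sum of squared coordinate differences, no square root taken
        sum((a - b) ** 2 for a, b in zip(positions[s], positions[r]))
        for s, r in zip(senders, receivers)
    ]


positions = [(0.0, 0.0), (3.0, 4.0), (1.0, 1.0)]
edges = ([0, 0], [1, 2])
d2 = square_distances_sketch(edges, positions)
```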

gnn_tracking.metrics.losses._fast_condensation_loss(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, q_min: float, mask: torch.Tensor, radius_threshold: float, max_num_neighbors: int) dict[str, torch.Tensor]#

Extracted function for condensation loss. See PotentialLoss for details.

class gnn_tracking.metrics.losses.PotentialLoss(q_min=0.01, radius_threshold=1.0, attr_pt_thld=0.9, max_neighbors=0)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.HyperparametersMixin

Potential/condensation loss (specific to object condensation approach).

Parameters:
  • q_min – Minimal charge q

  • radius_threshold – Parameter of repulsive potential

  • attr_pt_thld – Truth-level threshold for hits/tracks to consider in attractive loss [GeV]

  • max_neighbors – Parameter to determine maximum number of edges drawn from each node while calculating repulsive loss. If set to 0, non-approximate loss is calculated.
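To make the roles of `q_min` and `radius_threshold` concrete, the sketch below writes out the per-hit quantities from the object condensation paper (Kieseler, 2020) as scalar functions. That this class uses exactly these forms is an assumption; the function names are hypothetical, and masking, the pt threshold and the neighbor approximation are left out.

```python
import math


def charge(beta: float, q_min: float = 0.01) -> float:
    """Charge of a hit: grows with the condensation likelihood beta in [0, 1)."""
    return math.atanh(beta) ** 2 + q_min


def attractive_potential(dist: float, q_alpha: float) -> float:
    """Pulls a hit towards the condensation point of its own particle."""
    return dist ** 2 * q_alpha


def repulsive_potential(dist: float, q_alpha: float, radius_threshold: float = 1.0) -> float:
    """Pushes a hit away from other particles' condensation points, but only
    while it is closer than radius_threshold (hinge)."""
    return max(0.0, radius_threshold - dist) * q_alpha
```

Here `q_min` keeps every hit's charge strictly positive even at `beta = 0`, and `radius_threshold` sets the distance beyond which a foreign condensation point no longer contributes to the repulsive term.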

forward(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, reconstructable: torch.Tensor, pt: torch.Tensor, ec_hit_mask: torch.Tensor | None = None, **kwargs) dict[str, torch.Tensor]#
gnn_tracking.metrics.losses._background_loss(*, beta: torch.Tensor, particle_id: torch.Tensor, sb: float) torch.Tensor#

Extracted function for JIT-compilation. See BackgroundLoss for details.

class gnn_tracking.metrics.losses.BackgroundLoss(sb=0.1)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.HyperparametersMixin

forward(*, beta: torch.Tensor, particle_id: torch.Tensor, ec_hit_mask: torch.Tensor | None = None, **kwargs) torch.Tensor#
class gnn_tracking.metrics.losses.ObjectLoss(mode='efficiency')#

Bases: torch.nn.Module

Loss functions for predicted object properties.

static _mse(*, pred: torch.Tensor, truth: torch.Tensor) torch.Tensor#
object_loss(*, pred: torch.Tensor, beta: torch.Tensor, truth: torch.Tensor, particle_id: torch.Tensor) torch.Tensor#
forward(*, beta: torch.Tensor, pred: torch.Tensor, particle_id: torch.Tensor, track_params: torch.Tensor, reconstructable: torch.Tensor, **kwargs) torch.Tensor#
class gnn_tracking.metrics.losses.LossFctType#

Bases: Protocol

Type of a loss function

__call__(*args: Any, **kwargs: Any) torch.Tensor#
to(device: torch.device) LossFctType#
gnn_tracking.metrics.losses.loss_weight_type#
gnn_tracking.metrics.losses.unpack_loss_returns(key: str, returns: Any) dict[str, float | torch.Tensor]#

Some of our loss functions return a dictionary or a list of individual losses. This function unpacks these into a dictionary of individual losses with appropriate keys.

Parameters:
  • key – str (name of the loss function)

  • returns – dict or list or single value

Returns:

dict of individual losses
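The unpacking described above could look roughly like the following sketch. The name `unpack_loss_returns_sketch` is hypothetical, and the key format (`<key>_<sub-key>` for dictionaries, `<key>_<index>` for lists) is an assumption rather than the documented behaviour.

```python
def unpack_loss_returns_sketch(key, returns):
    """Flatten a dict, list or single loss value into a dict of named losses."""
    if isinstance(returns, dict):
        # Assumed key format: "<loss name>_<sub-loss name>"
        return {f"{key}_{k}": v for k, v in returns.items()}
    if isinstance(returns, (list, tuple)):
        # Assumed key format: "<loss name>_<position>"
        return {f"{key}_{i}": v for i, v in enumerate(returns)}
    # A single value keeps the name of the loss function
    return {key: returns}
```

Flattening everything into one dictionary means downstream code (logging, weighting, summing) can treat single-valued and multi-valued loss functions the same way.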

class gnn_tracking.metrics.losses.LossClones(loss: torch.nn.Module, prefixes=('w', 'y'))#

Bases: torch.nn.Module

Wrapper for a loss function that evaluates it on multiple inputs. The forward method will look for all model outputs that start with w_ (or another specified prefix) and evaluate the loss function for each of them, returning a dictionary of losses (with keys equal to the suffixes).

Usage example 1:

losses = {
    "potential": (PotentialLoss(), 1.),
    "edge": (LossClones(EdgeWeightBCELoss()), [1.0, 2.0, 3.0])
}

This evaluates three clones of the BCE loss function, one for each EC layer.

Usage example 2:

losses = {
    "potential": (PotentialLoss(), 1.),
    "edge": (LossClones(EdgeWeightBCELoss()), {})
}

This works with a variable number of layers; all weights are 1.

Under the hood, LossClones(EdgeWeightBCELoss())(model output) returns a dictionary of the individual losses, keyed by their suffixes (similar to how PotentialLoss returns a dictionary of losses).

Parameters:
  • loss – Loss function to be evaluated on multiple inputs

  • prefixes – Prefixes of the model outputs that should be evaluated. An underscore is assumed (set prefix to w for w_0, w_1, etc.)
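The prefix matching described above can be illustrated without PyTorch. In the sketch below, `clones_forward_sketch` and `toy_loss` are hypothetical stand-ins: the first mimics the lookup of `w_<suffix>` and `y_<suffix>` outputs, the second replaces a real loss such as EdgeWeightBCELoss. How the actual class handles missing or extra keys is not covered.

```python
def clones_forward_sketch(loss_fn, prefixes=("w", "y"), **kwargs):
    """Evaluate loss_fn once per suffix found among the prefixed keyword arguments."""
    main = prefixes[0]
    # Collect the suffixes of all outputs named "<main>_<suffix>"
    suffixes = sorted(k[len(main) + 1:] for k in kwargs if k.startswith(main + "_"))
    losses = {}
    for suffix in suffixes:
        # Map e.g. w_0 and y_0 back to the plain argument names w and y
        clone_kwargs = {p: kwargs[f"{p}_{suffix}"] for p in prefixes}
        losses[suffix] = loss_fn(**clone_kwargs)
    return losses


def toy_loss(*, w, y):
    """Mean absolute error, standing in for a real edge classification loss."""
    return sum(abs(a - b) for a, b in zip(w, y)) / len(w)


out = clones_forward_sketch(
    toy_loss,
    w_0=[1.0, 0.0], y_0=[1.0, 1.0],
    w_1=[0.0, 0.0], y_1=[0.0, 0.0],
    x="unrelated model output",
)
```

The result has one entry per suffix (`"0"` and `"1"` here), which is why the weights in usage example 1 are given as a list with one value per clone.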

forward(**kwargs) dict[str, torch.Tensor]#
gnn_tracking.metrics.losses._hinge_loss_components(*, x: torch.Tensor, edge_index: torch.Tensor, particle_id: torch.Tensor, pt: torch.Tensor, r_emb_hinge: float, pt_thld: float, p_attr: float, p_rep: float) tuple[torch.Tensor, torch.Tensor]#
class gnn_tracking.metrics.losses.GraphConstructionHingeEmbeddingLoss(*, r_emb=1, max_num_neighbors: int = 256, attr_pt_thld: float = 0.9, p_attr: float = 1, p_rep: float = 1)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.HyperparametersMixin

Loss for graph construction using metric learning.

Parameters:
  • r_emb – Radius for edge construction

  • max_num_neighbors – Maximum number of neighbors in radius graph building. See rusty1s/pytorch_cluster

  • p_attr – Power for the attraction term (default 1: linear loss)

  • p_rep – Power for the repulsion term (default 1: linear loss)
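The interplay of `r_emb`, `p_attr` and `p_rep` can be illustrated with a generic hinge-style embedding loss on precomputed edge lengths. This is a sketch of the general technique, not this class's exact formula or normalisation; `hinge_components_sketch` is a hypothetical name and the inputs (distances plus a same-particle flag per edge) are simplifications.

```python
def _mean(values):
    """Average of a list, defined as 0.0 for an empty list."""
    return sum(values) / len(values) if values else 0.0


def hinge_components_sketch(dists, same_particle, r_emb=1.0, p_attr=1.0, p_rep=1.0):
    """Return (attractive, repulsive) loss terms for a list of edge lengths."""
    # Edges between hits of the same particle are pulled together
    attractive = [d ** p_attr for d, same in zip(dists, same_particle) if same]
    # Edges between different particles are pushed apart until they exceed r_emb
    repulsive = [
        max(0.0, r_emb - d) ** p_rep
        for d, same in zip(dists, same_particle)
        if not same
    ]
    return _mean(attractive), _mean(repulsive)


att, rep = hinge_components_sketch(
    dists=[0.5, 0.2, 2.0, 1.0],
    same_particle=[True, False, False, True],
)
```

With the default powers of 1 both terms are linear in the distance; setting a power to 2 penalises long true edges (or very short false edges) more strongly.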

_build_graph(x: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, pt: torch.Tensor) torch.Tensor#
forward(*, x: torch.Tensor, particle_id: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, pt: torch.Tensor, **kwargs) dict[str, torch.Tensor]#