gnn_tracking.metrics.losses#
This module contains loss functions for the GNN tracking model.
Module Contents#
Classes#
| Class | Description |
|---|---|
| FalsifyLowPtEdgeWeightLoss | Add an option to falsify edges with low pt to edge classification losses. |
| EdgeWeightBCELoss | Add an option to falsify edges with low pt to edge classification losses. |
| EdgeWeightFocalLoss | Loss function based on focal loss for edge classification. |
| HaughtyFocalLoss | |
| PotentialLoss | Potential/condensation loss (specific to object condensation approach). |
| BackgroundLoss | |
| ObjectLoss | Loss functions for predicted object properties. |
| LossFctType | Type of a loss function |
| LossClones | Wrapper for a loss function that evaluates it on multiple inputs. |
| GraphConstructionHingeEmbeddingLoss | Loss for graph construction using metric learning. |
Functions#
| Function | Description |
|---|---|
| _binary_focal_loss | Extracted function for JIT compilation. |
| binary_focal_loss | Binary Focal Loss, following https://arxiv.org/abs/1708.02002. |
| falsify_low_pt_edges | Modify the ground truth to-be-predicted by the edge classification |
| _condensation_loss | Extracted function for JIT-compilation. See PotentialLoss for details. |
| _first_occurrences | Return the first occurrence of each unique element in a 1D array |
| _square_distances | Returns squared distances between two sets of points |
| _fast_condensation_loss | Extracted function for condensation loss. See PotentialLoss for details. |
| _background_loss | Extracted function for JIT-compilation. See BackgroundLoss for details. |
| unpack_loss_returns | Some of our loss functions return a dictionary or a list of individual losses. |
| _hinge_loss_components | |
Attributes#
| Attribute | Description |
|---|---|
| loss_weight_type | |
- gnn_tracking.metrics.losses._binary_focal_loss(*, inpt: torch.Tensor, target: torch.Tensor, alpha: float, gamma: float, pos_weight: torch.Tensor) torch.Tensor #
Extracted function for JIT compilation.
- gnn_tracking.metrics.losses.binary_focal_loss(*, inpt: torch.Tensor, target: torch.Tensor, alpha: float = 0.25, gamma: float = 2.0, pos_weight: torch.Tensor = None) torch.Tensor #
Binary Focal Loss, following https://arxiv.org/abs/1708.02002.
- Parameters:
inpt – Model predictions
target – Binary ground truth labels
alpha – Weight for positive/negative results
gamma – Focusing parameter
pos_weight – Can be used to balance precision/recall
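To make the parameters concrete, here is a minimal stand-alone sketch of a binary focal loss (illustration only, not the library code), assuming `inpt` holds probabilities in [0, 1]; details such as clamping and how `pos_weight` enters may differ in the real function.

```python
import torch

# Stand-alone sketch of a binary focal loss (illustration only, not the
# library code); assumes `inpt` holds probabilities in [0, 1].
def focal_loss_sketch(inpt, target, alpha=0.25, gamma=2.0):
    p_t = inpt * target + (1 - inpt) * (1 - target)        # prob. assigned to the true class
    alpha_t = alpha * target + (1 - alpha) * (1 - target)  # class-balancing weight
    bce = -torch.log(p_t.clamp(min=1e-7))                  # elementwise cross entropy
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

w = torch.rand(10)                   # predicted probabilities
y = (torch.rand(10) > 0.5).float()   # binary targets
print(focal_loss_sketch(w, y))
```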
- gnn_tracking.metrics.losses.falsify_low_pt_edges(*, y: torch.Tensor, edge_index: torch.Tensor | None = None, pt: torch.Tensor | None = None, pt_thld: float = 0.0) torch.Tensor #
Modify the ground truth for edge classification so that edges that include a hit with pt < pt_thld are considered false.
- Parameters:
y – True classification
edge_index – Edge index (2 × n_edges) mapping each edge to its two hits
pt – Hit pt
pt_thld – pt threshold below which edges are considered false
- Returns:
True classification with additional criteria applied
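A minimal sketch of the intended behaviour (not the library implementation), assuming edge_index has shape (2, n_edges), y is a per-edge label tensor, and pt is a per-hit tensor:

```python
import torch

# Illustration: an edge whose endpoints include a hit with pt below pt_thld
# is marked false.
def falsify_low_pt_edges_sketch(y, edge_index, pt, pt_thld):
    low_pt_hit = pt < pt_thld                            # (n_hits,)
    edge_has_low_pt = low_pt_hit[edge_index].any(dim=0)  # (n_edges,)
    return y * (~edge_has_low_pt).float()

y = torch.tensor([1.0, 1.0, 0.0])                   # true edge labels
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])   # edges (0-1), (1-2), (2-0)
pt = torch.tensor([1.5, 0.2, 2.0])                  # hit pt [GeV]
print(falsify_low_pt_edges_sketch(y, edge_index, pt, pt_thld=0.9))  # tensor([0., 0., 0.])
```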
- class gnn_tracking.metrics.losses.FalsifyLowPtEdgeWeightLoss(*, pt_thld: float = 0.0)#
Bases: torch.nn.Module, abc.ABC, pytorch_lightning.core.mixins.HyperparametersMixin
Add an option to falsify edges with low pt to edge classification losses.
- forward(*, w: torch.Tensor, y: torch.Tensor, edge_index: torch.Tensor | None = None, pt: torch.Tensor | None = None, **kwargs) torch.Tensor #
- abstract _forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) torch.Tensor #
- class gnn_tracking.metrics.losses.EdgeWeightBCELoss(*, pt_thld: float = 0.0)#
Bases: FalsifyLowPtEdgeWeightLoss
Add an option to falsify edges with low pt to edge classification losses.
- static _forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) torch.Tensor #
- class gnn_tracking.metrics.losses.EdgeWeightFocalLoss(*, alpha=0.25, gamma=2.0, pos_weight=None, **kwargs)#
Bases: FalsifyLowPtEdgeWeightLoss
Loss function based on focal loss for edge classification. See binary_focal_loss for details.
- _forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) torch.Tensor #
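A usage sketch for both classes, following the documented constructor and forward() signatures; the tensor shapes (per-edge w and y, per-hit pt) and passing pt_thld through EdgeWeightFocalLoss's **kwargs are assumptions.

```python
import torch
from gnn_tracking.metrics.losses import EdgeWeightBCELoss, EdgeWeightFocalLoss

# Assumed shapes: per-edge weights/labels, per-hit pt.
n_hits, n_edges = 100, 300
w = torch.rand(n_edges)                          # predicted edge weights in [0, 1]
y = (torch.rand(n_edges) > 0.5).float()          # ground-truth edge labels
edge_index = torch.randint(0, n_hits, (2, n_edges))
pt = torch.rand(n_hits) * 5                      # hit pt [GeV]

bce = EdgeWeightBCELoss(pt_thld=0.9)
focal = EdgeWeightFocalLoss(alpha=0.25, gamma=2.0, pt_thld=0.9)
print(bce(w=w, y=y, edge_index=edge_index, pt=pt))
print(focal(w=w, y=y, edge_index=edge_index, pt=pt))
```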
- class gnn_tracking.metrics.losses.HaughtyFocalLoss(*, alpha: float = 0.25, gamma: float = 2.0, pt_thld=0.0)#
Bases: torch.nn.Module, pytorch_lightning.core.mixins.HyperparametersMixin
- forward(*, w: torch.Tensor, y: torch.Tensor, edge_index: torch.Tensor, pt: torch.Tensor, **kwargs) torch.Tensor #
- gnn_tracking.metrics.losses._condensation_loss(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, mask: torch.Tensor, q_min: float, radius_threshold: float) dict[str, torch.Tensor] #
Extracted function for JIT-compilation. See PotentialLoss for details.
- gnn_tracking.metrics.losses._first_occurrences(input_array: torch.Tensor) torch.Tensor #
Return the first occurrence of each unique element in a 1D array
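A possible implementation sketch (the library version may differ), returning the index of the first occurrence of each unique value:

```python
import torch

# Sketch: index of the first occurrence of each unique value in a 1D tensor.
def first_occurrences_sketch(x: torch.Tensor) -> torch.Tensor:
    unique, inverse = torch.unique(x, return_inverse=True)
    first = torch.full((len(unique),), len(x), dtype=torch.long)
    # "amin" keeps the smallest position per unique value
    first.scatter_reduce_(0, inverse, torch.arange(len(x)), reduce="amin")
    return first

print(first_occurrences_sketch(torch.tensor([3, 1, 3, 2, 1])))  # tensor([1, 3, 0])
```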
- gnn_tracking.metrics.losses._square_distances(edges: torch.Tensor, positions: torch.Tensor) torch.Tensor #
Returns squared distances between two sets of points
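A sketch under the assumption that edges is a (2, n_edges) index tensor and positions an (n_nodes, dim) coordinate tensor:

```python
import torch

# Squared Euclidean distance between the two endpoints of each edge.
def square_distances_sketch(edges: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    delta = positions[edges[0]] - positions[edges[1]]
    return (delta**2).sum(dim=-1)

positions = torch.tensor([[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]])
edges = torch.tensor([[0, 1], [1, 2]])
print(square_distances_sketch(edges, positions))  # tensor([25., 13.])
```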
- gnn_tracking.metrics.losses._fast_condensation_loss(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, q_min: float, mask: torch.Tensor, radius_threshold: float, max_num_neighbors: int) dict[str, torch.Tensor] #
Extracted function for condensation loss. See PotentialLoss for details.
- class gnn_tracking.metrics.losses.PotentialLoss(q_min=0.01, radius_threshold=1.0, attr_pt_thld=0.9, max_neighbors=0)#
Bases: torch.nn.Module, pytorch_lightning.core.mixins.HyperparametersMixin
Potential/condensation loss (specific to object condensation approach).
- Parameters:
q_min – Minimal charge q
radius_threshold – Parameter of repulsive potential
attr_pt_thld – Truth-level threshold for hits/tracks to consider in attractive loss [GeV]
max_neighbors – Maximum number of edges drawn from each node when calculating the repulsive loss. If set to 0, the exact (non-approximate) loss is calculated.
- forward(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, reconstructable: torch.Tensor, pt: torch.Tensor, ec_hit_mask: torch.Tensor | None = None, **kwargs) dict[str, torch.Tensor] #
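A usage sketch based on the documented constructor and forward() signatures; the per-hit tensor shapes and the dtypes of reconstructable/particle_id are assumptions.

```python
import torch
from gnn_tracking.metrics.losses import PotentialLoss

n_hits, latent_dim = 50, 8
loss_fct = PotentialLoss(q_min=0.01, radius_threshold=1.0, attr_pt_thld=0.9)
losses = loss_fct(
    beta=torch.rand(n_hits),                      # condensation likelihoods in [0, 1]
    x=torch.randn(n_hits, latent_dim),            # latent-space coordinates
    particle_id=torch.randint(0, 10, (n_hits,)),  # truth particle labels
    reconstructable=torch.ones(n_hits, dtype=torch.bool),
    pt=torch.rand(n_hits) * 5,                    # hit pt [GeV]
)
print(losses)  # dictionary of individual loss terms
```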
- gnn_tracking.metrics.losses._background_loss(*, beta: torch.Tensor, particle_id: torch.Tensor, sb: float) torch.Tensor #
Extracted function for JIT-compilation. See BackgroundLoss for details.
- class gnn_tracking.metrics.losses.BackgroundLoss(sb=0.1)#
Bases: torch.nn.Module, pytorch_lightning.core.mixins.HyperparametersMixin
- forward(*, beta: torch.Tensor, particle_id: torch.Tensor, ec_hit_mask: torch.Tensor | None = None, **kwargs) torch.Tensor #
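Since the class carries no further description here, the following is only a hedged sketch of a background term in the spirit of the object condensation approach (arXiv:2002.03605); the actual _background_loss may differ, and treating non-positive particle_id values as noise is an assumption.

```python
import torch

# Hedged sketch of a background term: penalize low beta for condensation
# points of real objects and high beta for noise hits (weighted by sb).
def background_loss_sketch(beta, particle_id, sb=0.1):
    noise = particle_id <= 0  # assumption: non-positive IDs mark noise hits
    object_terms = [1 - beta[particle_id == pid].max() for pid in torch.unique(particle_id[~noise])]
    object_term = torch.stack(object_terms).mean() if object_terms else beta.new_zeros(())
    noise_term = beta[noise].mean() if noise.any() else beta.new_zeros(())
    return object_term + sb * noise_term

beta = torch.rand(30)
particle_id = torch.randint(0, 5, (30,))
print(background_loss_sketch(beta, particle_id, sb=0.1))
```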
- class gnn_tracking.metrics.losses.ObjectLoss(mode='efficiency')#
Bases: torch.nn.Module
Loss functions for predicted object properties.
- static _mse(*, pred: torch.Tensor, truth: torch.Tensor) torch.Tensor #
- object_loss(*, pred: torch.Tensor, beta: torch.Tensor, truth: torch.Tensor, particle_id: torch.Tensor) torch.Tensor #
- forward(*, beta: torch.Tensor, pred: torch.Tensor, particle_id: torch.Tensor, track_params: torch.Tensor, reconstructable: torch.Tensor, **kwargs) torch.Tensor #
- class gnn_tracking.metrics.losses.LossFctType#
Bases: Protocol
Type of a loss function
- __call__(*args: Any, **kwargs: Any) torch.Tensor #
- to(device: torch.device) LossFctType #
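A minimal, hypothetical example of an object satisfying this protocol: it is callable, returns a tensor, and implements .to(device).

```python
import torch

class ConstantLoss:
    """Hypothetical loss object conforming to the LossFctType protocol."""

    def __init__(self, value: float = 0.0):
        self._value = torch.tensor(value)

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        return self._value

    def to(self, device: torch.device) -> "ConstantLoss":
        self._value = self._value.to(device)
        return self

loss_fct = ConstantLoss(0.5)
print(loss_fct())  # tensor(0.5000)
```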
- gnn_tracking.metrics.losses.loss_weight_type#
- gnn_tracking.metrics.losses.unpack_loss_returns(key: str, returns: Any) dict[str, float | torch.Tensor] #
Some of our loss functions return a dictionary or a list of individual losses. This function unpacks these into a dictionary of individual losses with appropriate keys.
- Parameters:
key – str (name of the loss function)
returns – dict or list or single value
- Returns:
dict of individual losses
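A usage sketch; the exact naming of the returned keys is an assumption and only indicated with "e.g.".

```python
import torch
from gnn_tracking.metrics.losses import unpack_loss_returns

print(unpack_loss_returns("edge", torch.tensor(0.5)))
# e.g. {"edge": tensor(0.5)}
print(unpack_loss_returns("potential", {"attractive": 0.1, "repulsive": 0.2}))
# e.g. {"potential_attractive": 0.1, "potential_repulsive": 0.2}
```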
- class gnn_tracking.metrics.losses.LossClones(loss: torch.nn.Module, prefixes=('w', 'y'))#
Bases: torch.nn.Module
Wrapper for a loss function that evaluates it on multiple inputs. The forward method will look for all model outputs that start with w_ (or another specified prefix) and evaluate the loss function for each of them, returning a dictionary of losses (with keys equal to the suffixes).
Usage example 1:
losses = {"potential": (PotentialLoss(), 1.0), "edge": (LossClones(EdgeWeightBCELoss()), [1.0, 2.0, 3.0])}
This will evaluate three clones of the BCE loss function, one for each EC layer.
Usage example 2:
losses = {"potential": (PotentialLoss(), 1.0), "edge": (LossClones(EdgeWeightBCELoss()), {})}
This works with a variable number of layers; all weights default to 1.
Under the hood,
LossClones(EdgeWeightBCELoss())(**model_output)
will output a dictionary of the individual losses, keyed by their suffixes (similar to how PotentialLoss returns a dictionary of losses).
- Parameters:
loss – Loss function to be evaluated on multiple inputs
prefixes – Prefixes of the model outputs that should be evaluated. An underscore is assumed (set prefix to w for w_0, w_1, etc.)
- forward(**kwargs) dict[str, torch.Tensor] #
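A usage sketch of the forward pass, following the "_0", "_1" suffix convention described above; the tensor shapes and the exact keys of the returned dictionary are assumptions.

```python
import torch
from gnn_tracking.metrics.losses import EdgeWeightBCELoss, LossClones

clones = LossClones(EdgeWeightBCELoss())
model_output = {
    "w_0": torch.rand(20), "y_0": (torch.rand(20) > 0.5).float(),
    "w_1": torch.rand(20), "y_1": (torch.rand(20) > 0.5).float(),
}
print(clones(**model_output))  # e.g. {"0": tensor(...), "1": tensor(...)}
```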
- gnn_tracking.metrics.losses._hinge_loss_components(*, x: torch.Tensor, edge_index: torch.Tensor, particle_id: torch.Tensor, pt: torch.Tensor, r_emb_hinge: float, pt_thld: float, p_attr: float, p_rep: float) tuple[torch.Tensor, torch.Tensor] #
- class gnn_tracking.metrics.losses.GraphConstructionHingeEmbeddingLoss(*, r_emb=1, max_num_neighbors: int = 256, attr_pt_thld: float = 0.9, p_attr: float = 1, p_rep: float = 1)#
Bases: torch.nn.Module, pytorch_lightning.core.mixins.HyperparametersMixin
Loss for graph construction using metric learning.
- Parameters:
r_emb – Radius for edge construction
max_num_neighbors – Maximum number of neighbors in radius graph building. See rusty1s/pytorch_cluster
attr_pt_thld – Truth-level pt threshold for hits to consider in the attraction term [GeV]
p_attr – Power for the attraction term (default 1: linear loss)
p_rep – Power for the repulsion term (default 1: linear loss)
- _build_graph(x: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, pt: torch.Tensor) torch.Tensor #
- forward(*, x: torch.Tensor, particle_id: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, pt: torch.Tensor, **kwargs) dict[str, torch.Tensor] #
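A usage sketch based on the documented forward() signature; the tensor shapes and the interpretation of true_edge_index as truth-level hit-to-hit edges are assumptions, and radius graph building requires torch_cluster to be installed.

```python
import torch
from gnn_tracking.metrics.losses import GraphConstructionHingeEmbeddingLoss

n_hits, emb_dim = 100, 8
loss_fct = GraphConstructionHingeEmbeddingLoss(r_emb=1.0, max_num_neighbors=256)
losses = loss_fct(
    x=torch.randn(n_hits, emb_dim),               # learned embedding coordinates
    particle_id=torch.randint(0, 20, (n_hits,)),
    batch=torch.zeros(n_hits, dtype=torch.long),  # single-event batch vector
    true_edge_index=torch.randint(0, n_hits, (2, 200)),
    pt=torch.rand(n_hits) * 5,                    # hit pt [GeV]
)
print(losses)  # dictionary with attraction/repulsion terms
```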