gnn_tracking.metrics.losses.ec
Classes
- FalsifyLowPtEdgeWeightLoss: Add an option to falsify edges with low pt to edge classification losses.
- EdgeWeightBCELoss: Add an option to falsify edges with low pt to edge classification losses.
- EdgeWeightFocalLoss: Loss function based on focal loss for edge classification.
- HaughtyFocalLoss
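Functions
- _binary_focal_loss: Extracted function for JIT compilation.
- binary_focal_loss: Binary Focal Loss, following https://arxiv.org/abs/1708.02002.
- falsify_low_pt_edges: Modify the ground truth for the edge classification so that edges that include a hit with pt < pt_thld are treated as false.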
Module Contents
- gnn_tracking.metrics.losses.ec._binary_focal_loss(*, inpt: torch.Tensor, target: torch.Tensor, alpha: float, gamma: float, pos_weight: torch.Tensor) → torch.Tensor
Extracted function for JIT compilation.
- gnn_tracking.metrics.losses.ec.binary_focal_loss(*, inpt: torch.Tensor, target: torch.Tensor, alpha: float = 0.25, gamma: float = 2.0, pos_weight: torch.Tensor | None = None) → torch.Tensor
Binary Focal Loss, following https://arxiv.org/abs/1708.02002.
- Parameters:
inpt
target
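alpha – Balancing weight: positive examples are weighted by alpha, negative examples by 1 − alpha
gamma – Focusing parameter; larger values down-weight well-classified examples (gamma = 0 reduces to alpha-weighted binary cross entropy)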
pos_weight – Can be used to balance precision/recall
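A minimal sketch of the alpha-balanced focal loss from the linked paper follows. It assumes that inpt holds probabilities (for example sigmoid-activated edge weights) rather than logits, and that pos_weight scales only the positive term; focal_loss_sketch is a hypothetical helper, not the library's implementation.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def focal_loss_sketch(
    probs: torch.Tensor,
    target: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    pos_weight: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Hypothetical alpha-balanced binary focal loss on probabilities."""
    if pos_weight is None:
        # Assumption: no extra weighting of positives by default
        pos_weight = torch.tensor(1.0, device=probs.device)
    # Clamp so that log() stays finite for saturated predictions
    p = probs.clamp(1e-7, 1 - 1e-7)
    t = target.float()
    # True edges: the (1 - p)^gamma factor shrinks the loss of confident, correct predictions
    pos_term = -alpha * pos_weight * (1 - p).pow(gamma) * t * p.log()
    # False edges: the p^gamma factor plays the same role for the negative class
    neg_term = -(1 - alpha) * p.pow(gamma) * (1 - t) * (1 - p).log()
    return (pos_term + neg_term).mean()


probs = torch.tensor([0.9, 0.2, 0.6])
target = torch.tensor([1.0, 0.0, 1.0])
# With gamma = 0 and alpha = 0.5, the focal loss is half the binary cross entropy
bce = F.binary_cross_entropy(probs, target)
print(torch.allclose(focal_loss_sketch(probs, target, alpha=0.5, gamma=0.0), 0.5 * bce))
```

Raising gamma above 0 lowers the contribution of edges that are already classified confidently, which is what lets the loss focus on hard examples.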
- gnn_tracking.metrics.losses.ec.falsify_low_pt_edges(*, y: torch.Tensor, edge_index: torch.Tensor | None = None, pt: torch.Tensor | None = None, pt_thld: float = 0.0) → torch.Tensor
Modify the ground truth for the edge classification so that edges that include a hit with pt < pt_thld are treated as false.
- Parameters:
y – True classification
edge_index
pt – Hit pt
pt_thld – pt threshold; edges that include a hit with pt < pt_thld are labeled false
- Returns:
True classification with additional criteria applied
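As an illustration, the relabeling can be sketched as a boolean mask over edge endpoints. This assumes edge_index has shape (2, n_edges) as in PyTorch Geometric, that an edge counts as low pt if either of its two hits is below pt_thld, and that a non-positive threshold disables the cut; falsify_sketch is a hypothetical name, and the library's exact rule may differ.

```python
import torch


def falsify_sketch(
    y: torch.Tensor, edge_index: torch.Tensor, pt: torch.Tensor, pt_thld: float = 0.0
) -> torch.Tensor:
    """Hypothetical relabeling: edges that touch a hit with pt < pt_thld become false."""
    y = y.bool()
    if pt_thld <= 0.0:
        # No cut requested: keep the original truth labels
        return y
    src, dst = edge_index  # row 0: first hit of each edge, row 1: second hit
    # Keep an edge true only if both of its hits pass the threshold
    passes = (pt[src] >= pt_thld) & (pt[dst] >= pt_thld)
    return y & passes


y = torch.tensor([1, 1, 0, 1])
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
pt = torch.tensor([1.0, 2.0, 0.5, 3.0])
print(falsify_sketch(y, edge_index, pt, pt_thld=0.9))
```

Here edge 1 connects hits 1 and 2; hit 2 has pt 0.5, so that edge is falsified even though it was a true edge, while edge 2 stays false regardless of pt.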
- class gnn_tracking.metrics.losses.ec.FalsifyLowPtEdgeWeightLoss(*, pt_thld: float = 0.0)
Bases: torch.nn.Module, abc.ABC, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
Add an option to falsify edges with low pt to edge classification losses.
- forward(*, w: torch.Tensor, y: torch.Tensor, edge_index: torch.Tensor | None = None, pt: torch.Tensor | None = None, **kwargs) → torch.Tensor
- abstract _forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) → torch.Tensor
- class gnn_tracking.metrics.losses.ec.EdgeWeightBCELoss(*, pt_thld: float = 0.0)
Bases: FalsifyLowPtEdgeWeightLoss
Add an option to falsify edges with low pt to edge classification losses.
- static _forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) → torch.Tensor
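The split between forward and _forward suggests a template-method design: the base class applies the pt cut to the labels once, and subclasses only implement the loss itself. The sketch below illustrates that pattern under stated assumptions: LowPtMaskedLossSketch and BCESketch are hypothetical names, the mask rule assumes both hits of an edge must pass the cut, w is assumed to hold probabilities, and the HyperparametersMixin base is omitted.

```python
import abc

import torch
import torch.nn.functional as F
from torch import nn


class LowPtMaskedLossSketch(nn.Module, abc.ABC):
    """Hypothetical base class: relabel low-pt edges as false, then delegate."""

    def __init__(self, *, pt_thld: float = 0.0):
        super().__init__()
        self.pt_thld = pt_thld

    def forward(self, *, w, y, edge_index=None, pt=None, **kwargs):
        y = y.bool()
        if self.pt_thld > 0.0:
            # Assumed rule: an edge stays true only if both hits pass the cut
            src, dst = edge_index
            y = y & (pt[src] >= self.pt_thld) & (pt[dst] >= self.pt_thld)
        return self._forward(w=w, y=y, **kwargs)

    @abc.abstractmethod
    def _forward(self, *, w, y, **kwargs):
        """Compute the actual loss from edge weights and (possibly relabeled) truth."""


class BCESketch(LowPtMaskedLossSketch):
    @staticmethod
    def _forward(*, w, y, **kwargs):
        # Binary cross entropy between predicted edge weights and truth labels
        return F.binary_cross_entropy(w, y.float())


w = torch.tensor([0.8, 0.7, 0.1])
y = torch.tensor([1, 1, 0])
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
pt = torch.tensor([1.5, 2.0, 0.3])
loss_fn = BCESketch(pt_thld=0.9)
print(loss_fn(w=w, y=y, edge_index=edge_index, pt=pt))
```

Because the relabeling lives in forward, every subclass inherits the pt cut without repeating it; here edge 1 touches hit 2 (pt 0.3), so it is scored as a false edge. The abstract base itself cannot be instantiated.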
- class gnn_tracking.metrics.losses.ec.EdgeWeightFocalLoss(*, alpha=0.25, gamma=2.0, pos_weight=None, **kwargs)
Bases: FalsifyLowPtEdgeWeightLoss
Loss function based on focal loss for edge classification. See binary_focal_loss for details.
- _forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) → torch.Tensor
- class gnn_tracking.metrics.losses.ec.HaughtyFocalLoss(*, alpha: float = 0.25, gamma: float = 2.0, pt_thld=0.0)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
- _alpha
- _gamma
- _pt_thld
- forward(*, w: torch.Tensor, y: torch.Tensor, edge_index: torch.Tensor, pt: torch.Tensor, **kwargs) → torch.Tensor