gnn_tracking.metrics.losses.ec#

Classes#

FalsifyLowPtEdgeWeightLoss

Add an option to falsify edges with low pt to edge classification losses.

EdgeWeightBCELoss

Binary cross-entropy loss for edge classification, with an option to falsify edges with low pt.

EdgeWeightFocalLoss

Loss function based on focal loss for edge classification.

HaughtyFocalLoss

Functions#

_binary_focal_loss(→ torch.Tensor)

Extracted function for JIT compilation.

binary_focal_loss(→ torch.Tensor)

Binary Focal Loss, following https://arxiv.org/abs/1708.02002.

falsify_low_pt_edges(→ torch.Tensor)

Modify the ground truth to be predicted by the edge classification so that edges including a hit with pt < pt_thld are considered false.

Module Contents#

gnn_tracking.metrics.losses.ec._binary_focal_loss(*, inpt: torch.Tensor, target: torch.Tensor, alpha: float, gamma: float, pos_weight: torch.Tensor) → torch.Tensor#

Extracted function for JIT compilation.

gnn_tracking.metrics.losses.ec.binary_focal_loss(*, inpt: torch.Tensor, target: torch.Tensor, alpha: float = 0.25, gamma: float = 2.0, pos_weight: torch.Tensor | None = None) → torch.Tensor#

Binary Focal Loss, following https://arxiv.org/abs/1708.02002.

Parameters:
  • inpt – Predictions

  • target – Ground truth labels

  • alpha – Weighting factor that balances positive and negative examples

  • gamma – Focusing parameter

  • pos_weight – Can be used to balance precision/recall
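To make the roles of alpha, gamma and pos_weight concrete, the following is a minimal pure-Python sketch of the binary focal loss from the referenced paper. It is not the library's implementation: the name focal_loss_sketch, the eps clamp, the mean reduction, and the assumptions that the inputs are probabilities in (0, 1) and that pos_weight scales only the positive term are all illustrative.

```python
import math


def focal_loss_sketch(probs, targets, alpha=0.25, gamma=2.0, pos_weight=1.0, eps=1e-12):
    """Mean binary focal loss over paired probabilities and 0/1 targets.

    Hypothetical helper for illustration; assumes `probs` are probabilities,
    not logits, and that `pos_weight` multiplies the positive term only.
    """
    total = 0.0
    for p, t in zip(probs, targets):
        # Clamp so that log() stays finite for p == 0 or p == 1.
        p = min(max(p, eps), 1.0 - eps)
        # Positive examples: down-weighted by (1 - p)^gamma when already well classified.
        pos_term = -alpha * pos_weight * (1.0 - p) ** gamma * t * math.log(p)
        # Negative examples: down-weighted by p^gamma when already well classified.
        neg_term = -(1.0 - alpha) * p ** gamma * (1.0 - t) * math.log(1.0 - p)
        total += pos_term + neg_term
    return total / len(probs)
```

With gamma = 0 and alpha = 0.5 the sketch reduces to half the ordinary binary cross-entropy, which is a convenient sanity check. Larger gamma shrinks the contribution of examples that are already classified confidently.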

gnn_tracking.metrics.losses.ec.falsify_low_pt_edges(*, y: torch.Tensor, edge_index: torch.Tensor | None = None, pt: torch.Tensor | None = None, pt_thld: float = 0.0) → torch.Tensor#

Modify the ground truth to be predicted by the edge classification so that edges including a hit with pt < pt_thld are considered false.

Parameters:
  • y – True classification

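  • edge_index – Indices of the two hits connected by each edge

  • pt – Hit pt

  • pt_thld – pt threshold; an edge that includes a hit with pt below it is treated as false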

Returns:

True classification with additional criteria applied
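The effect described above can be sketched with plain Python lists. This is an illustration of the documented behavior, not the library's code: falsify_low_pt_edges_sketch is a hypothetical name, edge_index is assumed to be a pair of sequences (source hits, target hits), and a threshold of 0 is assumed to leave the labels unchanged so that edge_index and pt can be omitted.

```python
def falsify_low_pt_edges_sketch(y, edge_index=None, pt=None, pt_thld=0.0):
    """Return a copy of `y` where edges touching a hit with pt < pt_thld become 0.

    Hypothetical helper for illustration; `edge_index` is assumed to be
    (sources, targets), each listing one hit index per edge.
    """
    if pt_thld <= 0.0:
        # Assumption: pt is non-negative, so no hit can fall below the threshold.
        return list(y)
    if edge_index is None or pt is None:
        raise ValueError("edge_index and pt are required when pt_thld > 0")
    sources, targets = edge_index
    return [
        # An edge stays true only if it was true and both of its hits pass the cut.
        int(bool(label) and pt[i] >= pt_thld and pt[j] >= pt_thld)
        for label, i, j in zip(y, sources, targets)
    ]
```

For example, with four hits of pt 0.5, 1.2, 2.0 and 0.9 and a threshold of 0.9, every true edge that touches the first hit is relabeled as false, while true edges between the remaining hits keep their label.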

class gnn_tracking.metrics.losses.ec.FalsifyLowPtEdgeWeightLoss(*, pt_thld: float = 0.0)#

Bases: torch.nn.Module, abc.ABC, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Add an option to falsify edges with low pt to edge classification losses.
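The pairing of a public forward with an abstract _forward suggests a template-method layout, in which forward adjusts the ground truth and then delegates to the subclass. The sketch below shows that pattern in plain Python under this assumption. The class names, the list-based inputs, and the treatment of w as probabilities are hypothetical and do not reproduce the library's code.

```python
import math
from abc import ABC, abstractmethod


class LowPtFalsifyingLossSketch(ABC):
    """Hypothetical base class: falsify low-pt edges, then delegate to _forward."""

    def __init__(self, *, pt_thld=0.0):
        self.pt_thld = pt_thld

    def forward(self, *, w, y, edge_index=None, pt=None):
        if self.pt_thld > 0.0:
            if edge_index is None or pt is None:
                raise ValueError("edge_index and pt are required when pt_thld > 0")
            sources, targets = edge_index
            # Relabel true edges that touch a hit below the pt threshold.
            y = [
                int(bool(label) and pt[i] >= self.pt_thld and pt[j] >= self.pt_thld)
                for label, i, j in zip(y, sources, targets)
            ]
        return self._forward(w=w, y=y)

    @abstractmethod
    def _forward(self, *, w, y):
        """Compute the loss from edge weights `w` and (possibly modified) labels `y`."""


class BCELossSketch(LowPtFalsifyingLossSketch):
    """Hypothetical subclass: mean binary cross-entropy on probabilities."""

    def _forward(self, *, w, y):
        eps = 1e-12
        total = 0.0
        for p, t in zip(w, y):
            p = min(max(p, eps), 1.0 - eps)  # keep log() finite
            total -= t * math.log(p) + (1 - t) * math.log(1.0 - p)
        return total / len(w)
```

In this layout a subclass only implements the loss itself, and every subclass receives the pt threshold handling from the base class without repeating it.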

forward(*, w: torch.Tensor, y: torch.Tensor, edge_index: torch.Tensor | None = None, pt: torch.Tensor | None = None, **kwargs) → torch.Tensor#

abstract _forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) → torch.Tensor#

class gnn_tracking.metrics.losses.ec.EdgeWeightBCELoss(*, pt_thld: float = 0.0)#

Bases: FalsifyLowPtEdgeWeightLoss

Binary cross-entropy loss for edge classification, with an option to falsify edges with low pt.

static _forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) → torch.Tensor#

class gnn_tracking.metrics.losses.ec.EdgeWeightFocalLoss(*, alpha=0.25, gamma=2.0, pos_weight=None, **kwargs)#

Bases: FalsifyLowPtEdgeWeightLoss

Loss function based on focal loss for edge classification. See binary_focal_loss for details.

_forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) → torch.Tensor#

class gnn_tracking.metrics.losses.ec.HaughtyFocalLoss(*, alpha: float = 0.25, gamma: float = 2.0, pt_thld=0.0)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

_alpha#

_gamma#

_pt_thld#

forward(*, w: torch.Tensor, y: torch.Tensor, edge_index: torch.Tensor, pt: torch.Tensor, **kwargs) → torch.Tensor#