:py:mod:`gnn_tracking.metrics.losses.ec`
========================================

.. py:module:: gnn_tracking.metrics.losses.ec


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.metrics.losses.ec.FalsifyLowPtEdgeWeightLoss
   gnn_tracking.metrics.losses.ec.EdgeWeightBCELoss
   gnn_tracking.metrics.losses.ec.EdgeWeightFocalLoss
   gnn_tracking.metrics.losses.ec.HaughtyFocalLoss



Functions
~~~~~~~~~

.. autoapisummary::

   gnn_tracking.metrics.losses.ec._binary_focal_loss
   gnn_tracking.metrics.losses.ec.binary_focal_loss
   gnn_tracking.metrics.losses.ec.falsify_low_pt_edges



.. py:function:: _binary_focal_loss(*, inpt: torch.Tensor, target: torch.Tensor, alpha: float, gamma: float, pos_weight: torch.Tensor) -> torch.Tensor

   Core computation of :py:func:`binary_focal_loss`, factored out into a separate function for JIT compilation.


.. py:function:: binary_focal_loss(*, inpt: torch.Tensor, target: torch.Tensor, alpha: float = 0.25, gamma: float = 2.0, pos_weight: torch.Tensor | None = None) -> torch.Tensor

   Binary focal loss, following Lin et al., "Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002).

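   For reference, the paper defines the per-example loss as follows, where
   :math:`p` is the predicted probability of the positive class and :math:`y`
   is the binary label (0 or 1). How ``pos_weight`` enters the computation is
   specific to this module and is not part of the formula below.

   .. math::

      \mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \log(p_t),
      \qquad
      p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases},
      \qquad
      \alpha_t = \begin{cases} \alpha & \text{if } y = 1 \\ 1 - \alpha & \text{otherwise} \end{cases}
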
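   :param inpt: Predicted probabilities of the positive class
   :param target: Binary ground-truth labels
   :param alpha: Balancing factor: positive examples are weighted by ``alpha``, negative examples by ``1 - alpha``
   :param gamma: Focusing parameter; larger values down-weight well-classified examples more strongly
   :param pos_weight: Additional weight for positive examples; can be used to balance precision against recall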

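   A dependency-free sketch of the same computation, using plain Python floats
   instead of tensors. It assumes that ``probs`` are probabilities in [0, 1],
   that ``pos_weight`` simply multiplies the positive-class term, and that the
   result is averaged over all elements. The name
   ``binary_focal_loss_sketch`` and the ``eps`` clamping are illustrative
   additions, not part of this module.

   .. code-block:: python

      import math


      def binary_focal_loss_sketch(
          probs: list[float],
          targets: list[float],
          alpha: float = 0.25,
          gamma: float = 2.0,
          pos_weight: float = 1.0,
          eps: float = 1e-12,
      ) -> float:
          """Mean binary focal loss over a list of probabilities (illustrative only)."""
          assert 0.0 <= alpha <= 1.0 and gamma >= 0.0
          assert probs and len(probs) == len(targets)
          total = 0.0
          for p, y in zip(probs, targets):
              # Clamp the probability to avoid log(0).
              p = min(max(p, eps), 1.0 - eps)
              # Positive term: weighted by alpha and the (assumed) pos_weight,
              # down-weighted by (1 - p)^gamma when the prediction is confident.
              pos_term = -alpha * pos_weight * (1.0 - p) ** gamma * y * math.log(p)
              # Negative term: weighted by 1 - alpha, down-weighted by p^gamma.
              neg_term = -(1.0 - alpha) * p**gamma * (1.0 - y) * math.log(1.0 - p)
              total += pos_term + neg_term
          # Assumed mean reduction over all elements.
          return total / len(probs)


      print(binary_focal_loss_sketch([0.9, 0.2, 0.6], [1.0, 0.0, 1.0]))

   Setting ``gamma = 0`` and ``alpha = 0.5`` (with ``pos_weight = 1``) reduces
   the sketch to half the ordinary binary cross-entropy, which makes a quick
   sanity check.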

.. py:function:: falsify_low_pt_edges(*, y: torch.Tensor, edge_index: torch.Tensor | None = None, pt: torch.Tensor | None = None, pt_thld: float = 0.0) -> torch.Tensor

   Modify the edge-classification ground truth so that edges that include a hit
   with ``pt < pt_thld`` are considered false.

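   :param y: True edge classification
   :param edge_index: Edge index specifying the two hits of each edge
   :param pt: pt of each hit
   :param pt_thld: pt threshold; edges that include a hit below it are set to false

   :returns: True classification with the pt criterion applied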

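   A minimal plain-Python sketch of the relabeling, assuming ``edge_index`` is
   given as a pair of source and target hit-index lists and ``pt`` as a list
   indexed by hit. It follows the description above literally and checks both
   hits of each edge; the actual function operates on tensors and may differ in
   detail. ``falsify_low_pt_edges_sketch`` is a hypothetical name.

   .. code-block:: python

      def falsify_low_pt_edges_sketch(
          *,
          y: list[bool],
          edge_index: tuple[list[int], list[int]],
          pt: list[float],
          pt_thld: float = 0.0,
      ) -> list[bool]:
          """Set the label of every edge touching a hit with pt < pt_thld to False."""
          sources, targets = edge_index
          assert len(y) == len(sources) == len(targets)
          # An edge stays true only if it was true and both of its hits pass the cut.
          return [
              label and pt[i] >= pt_thld and pt[j] >= pt_thld
              for label, i, j in zip(y, sources, targets)
          ]


      labels = falsify_low_pt_edges_sketch(
          y=[True, True, False],
          edge_index=([0, 1, 2], [1, 2, 0]),
          pt=[1.2, 1.5, 0.4],
          pt_thld=0.9,
      )
      print(labels)  # [True, False, False]

   Here the second edge touches hit 2 (pt 0.4) and is falsified; with the
   default threshold of 0, no label changes.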

.. py:class:: FalsifyLowPtEdgeWeightLoss(*, pt_thld: float = 0.0)


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`abc.ABC`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Base class for edge-classification losses that adds an option to treat
   edges with low pt as false (see :py:func:`falsify_low_pt_edges`).

   .. py:method:: forward(*, w: torch.Tensor, y: torch.Tensor, edge_index: torch.Tensor | None = None, pt: torch.Tensor | None = None, **kwargs) -> torch.Tensor


   .. py:method:: _forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor
      :abstractmethod:

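   The split between ``forward`` and the abstract ``_forward`` can be sketched
   in plain Python as follows. This is a hypothetical analogue, not the class
   itself: it assumes that ``forward`` first relabels low-pt edges and then
   delegates to ``_forward``, uses lists instead of tensors, and pairs the base
   class with a binary cross-entropy subclass in the spirit of
   :py:class:`EdgeWeightBCELoss`. All names below are illustrative.

   .. code-block:: python

      import math
      from abc import ABC, abstractmethod


      class FalsifyLowPtLossSketch(ABC):
          """Hypothetical plain-Python analogue of FalsifyLowPtEdgeWeightLoss."""

          def __init__(self, *, pt_thld: float = 0.0) -> None:
              self.pt_thld = pt_thld

          def __call__(self, *, w, y, edge_index=None, pt=None) -> float:
              if self.pt_thld > 0.0:
                  # Assumed: edge_index and pt are only consulted when a cut is set.
                  sources, targets = edge_index
                  y = [
                      bool(label) and pt[i] >= self.pt_thld and pt[j] >= self.pt_thld
                      for label, i, j in zip(y, sources, targets)
                  ]
              # Delegate the actual loss computation to the subclass.
              return self._forward(w=w, y=[float(label) for label in y])

          @abstractmethod
          def _forward(self, *, w, y) -> float:
              """Compute the loss from edge weights ``w`` and (modified) labels ``y``."""


      class BCELossSketch(FalsifyLowPtLossSketch):
          """Mean binary cross-entropy on top of the relabeling step."""

          def _forward(self, *, w, y) -> float:
              eps = 1e-12
              total = 0.0
              for p, t in zip(w, y):
                  p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
                  total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
              return total / len(w)


      loss_fn = BCELossSketch(pt_thld=0.9)
      # The second edge touches hit 2 (pt 0.3), so its label becomes False.
      value = loss_fn(
          w=[0.9, 0.8],
          y=[True, True],
          edge_index=([0, 1], [1, 2]),
          pt=[1.5, 2.0, 0.3],
      )
      print(value)

   Keeping the relabeling in the base class means every subclass only has to
   implement the loss itself.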


.. py:class:: EdgeWeightBCELoss(*, pt_thld: float = 0.0)


   Bases: :py:obj:`FalsifyLowPtEdgeWeightLoss`

   Binary cross-entropy loss for edge classification, with the option to treat
   low-pt edges as false inherited from :py:class:`FalsifyLowPtEdgeWeightLoss`.

   .. py:method:: _forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor
      :staticmethod:



.. py:class:: EdgeWeightFocalLoss(*, alpha=0.25, gamma=2.0, pos_weight=None, **kwargs)


   Bases: :py:obj:`FalsifyLowPtEdgeWeightLoss`

   Focal-loss-based loss for edge classification.
   See :py:func:`binary_focal_loss` for details.

   .. py:method:: _forward(*, w: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor



.. py:class:: HaughtyFocalLoss(*, alpha: float = 0.25, gamma: float = 2.0, pt_thld=0.0)


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   .. py:method:: forward(*, w: torch.Tensor, y: torch.Tensor, edge_index: torch.Tensor, pt: torch.Tensor, **kwargs) -> torch.Tensor



