:py:mod:`gnn_tracking.metrics.losses`
=====================================

.. py:module:: gnn_tracking.metrics.losses

.. autoapi-nested-parse::

   This module contains loss functions for the GNN tracking model.



Submodules
----------
.. toctree::
   :titlesonly:
   :maxdepth: 1

   ec/index.rst
   metric_learning/index.rst
   oc/index.rst


Package Contents
----------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.metrics.losses.MultiLossFctReturn
   gnn_tracking.metrics.losses.MultiLossFct
   gnn_tracking.metrics.losses.LossClones




Attributes
~~~~~~~~~~

.. autoapisummary::

   gnn_tracking.metrics.losses.logger


.. py:data:: logger

   

.. py:class:: MultiLossFctReturn


   Return type for loss functions that return multiple losses.
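
   The documented members suggest that ``loss`` combines ``loss_dct`` and ``weight_dct``.
   The torch-free sketch below assumes three things: ``weighted_losses`` multiplies each
   loss by the weight stored under the same key, ``loss`` is the sum of those weighted
   losses, and ``__post_init__`` rejects losses that have no weight. The class name
   ``WeightedLosses``, the loss names, and the plain ``float`` values are illustrative
   stand-ins, not the library's implementation.

   .. code-block:: python

       from dataclasses import dataclass, field
       from typing import Any


       @dataclass
       class WeightedLosses:
           """Hypothetical, torch-free stand-in for MultiLossFctReturn."""

           loss_dct: dict[str, float]
           weight_dct: dict[str, float]
           extra_metrics: dict[str, Any] = field(default_factory=dict)

           def __post_init__(self) -> None:
               # Assumed check: every individual loss needs a matching weight
               missing = set(self.loss_dct) - set(self.weight_dct)
               if missing:
                   raise ValueError(f"Missing weights for {sorted(missing)}")

           @property
           def weighted_losses(self) -> dict[str, float]:
               # Scale each individual loss by the weight with the same key
               return {k: v * self.weight_dct[k] for k, v in self.loss_dct.items()}

           @property
           def loss(self) -> float:
               # Assumed total loss: sum of the weighted individual losses
               return sum(self.weighted_losses.values())


       ret = WeightedLosses(
           loss_dct={"potential": 0.5, "edge": 2.0},
           weight_dct={"potential": 1.0, "edge": 0.25},
       )
       print(ret.weighted_losses)  # {'potential': 0.5, 'edge': 0.5}
       print(ret.loss)  # 1.0

   Keeping the unweighted ``loss_dct`` next to the weights makes it possible to log
   each individual loss while optimizing only the combined value.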

   .. py:property:: loss
      :type: torch.Tensor


   .. py:property:: weighted_losses
      :type: dict[str, torch.Tensor]


   .. py:attribute:: loss_dct
      :type: dict[str, torch.Tensor]

      

   .. py:attribute:: weight_dct
      :type: dict[str, torch.Tensor] | dict[str, float]

      

   .. py:attribute:: extra_metrics
      :type: dict[str, Any]

      

   .. py:method:: __post_init__() -> None



.. py:class:: MultiLossFct


   Bases: :py:obj:`torch.nn.Module`

   Base class for loss functions that return multiple losses.

   .. py:method:: forward(*args: Any, **kwargs: Any) -> MultiLossFctReturn



.. py:class:: LossClones(loss: torch.nn.Module, prefixes=('w', 'y'))


   Bases: :py:obj:`torch.nn.Module`

   Wrapper for a loss function that evaluates it on multiple inputs.
   The forward method looks for all model outputs whose names start with ``w_``
   (or another specified prefix), evaluates the loss function for each of them,
   and returns a dictionary of losses keyed by the suffixes.
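
   The prefix matching can be sketched without torch as follows. The function
   ``clone_losses`` and the toy loss ``toy_loss`` are hypothetical names. The sketch
   assumes that the first entry of ``prefixes`` is used to discover the suffixes and
   that, for each suffix, every output named ``<prefix>_<suffix>`` is passed to the
   wrapped loss under the bare name ``<prefix>`` (for example ``w_0`` becomes ``w``).
   This is a reading of the docstring, not the library's code.

   .. code-block:: python

       from typing import Any, Callable


       def clone_losses(
           loss: Callable[..., float],
           prefixes: tuple[str, ...] = ("w", "y"),
           **kwargs: Any,
       ) -> dict[str, float]:
           """Hypothetical stand-in for LossClones.forward (not the library code)."""
           losses: dict[str, float] = {}
           first = f"{prefixes[0]}_"
           for key in kwargs:
               if not key.startswith(first):
                   continue
               suffix = key[len(first):]
               # Assumption: "<prefix>_<suffix>" is passed on as "<prefix>"
               rename = {f"{p}_{suffix}": p for p in prefixes}
               call_kwargs = {
                   rename.get(k, k): v
                   for k, v in kwargs.items()
                   # keep this suffix's outputs and any un-prefixed ones (e.g. "y")
                   if k in rename or not any(k.startswith(f"{p}_") for p in prefixes)
               }
               losses[suffix] = loss(**call_kwargs)
           return losses


       def toy_loss(*, w: list[float], y: list[float]) -> float:
           # Mean absolute error, standing in for EdgeWeightBCELoss
           return sum(abs(a - b) for a, b in zip(w, y)) / len(w)


       out = clone_losses(toy_loss, w_0=[0.0, 1.0], w_1=[1.0, 1.0], y=[0.0, 1.0])
       print(out)  # {'0': 0.0, '1': 0.5}

   In this sketch a single un-suffixed ``y`` is shared by all clones, while
   ``y_0``, ``y_1``, ... would give each clone its own target.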

   Usage example 1:

   .. code-block:: python

       losses = {
           "potential": (PotentialLoss(), 1.),
           "edge": (LossClones(EdgeWeightBCELoss()), [1.0, 2.0, 3.0])
       }

   This evaluates three clones of the BCE loss function, one for each EC layer,
   with weights 1.0, 2.0 and 3.0.

   Usage example 2:

   .. code-block:: python

       losses = {
           "potential": (PotentialLoss(), 1.),
       "edge": (LossClones(EdgeWeightBCELoss()), {})
       }

   This works with a variable number of layers; all weights are 1.

   Under the hood, ``LossClones(EdgeWeightBCELoss())(**model_output)`` returns
   a dictionary of the individual losses, keyed by their suffixes (similar to
   how ``PotentialLoss`` returns a dictionary of losses).

   :param loss: Loss function to be evaluated on multiple inputs.
   :param prefixes: Prefixes of the model outputs that should be evaluated.
                    An underscore is assumed (set the prefix to ``w`` to match ``w_0``, ``w_1``, etc.).

   .. py:method:: forward(**kwargs) -> dict[str, torch.Tensor]



