:py:mod:`gnn_tracking.metrics.losses.metric_learning`
=====================================================

.. py:module:: gnn_tracking.metrics.losses.metric_learning


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.metrics.losses.metric_learning.GraphConstructionHingeEmbeddingLoss



Functions
~~~~~~~~~

.. autoapisummary::

   gnn_tracking.metrics.losses.metric_learning._hinge_loss_components



.. py:function:: _hinge_loss_components(*, x: torch.Tensor, att_edges: torch.Tensor, rep_edges: torch.Tensor, r_emb_hinge: float, p_attr: float, p_rep: float, n_hits_oi: int) -> tuple[torch.Tensor, torch.Tensor]
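
   The sketch below is a minimal plain-Python version of the two components. It assumes the following, none of which this page confirms:

   * The attractive term is the mean of ``distance ** p_attr`` over the attractive edges.
   * The repulsive term is the hinge ``max(0, r_emb_hinge - distance) ** p_rep``, summed over the repulsive edges and divided by ``n_hits_oi``.

   The helper name ``hinge_loss_components`` is hypothetical. Edges are lists of ``(i, j)`` index pairs instead of a ``2 x E`` tensor. It illustrates the technique and is not the library's implementation.

   .. code-block:: python

      import math

      Point = tuple[float, ...]
      Edge = tuple[int, int]


      # Hypothetical helper; not part of gnn_tracking.
      def hinge_loss_components(
          *,
          x: list[Point],
          att_edges: list[Edge],
          rep_edges: list[Edge],
          r_emb_hinge: float,
          p_attr: float,
          p_rep: float,
          n_hits_oi: int,
      ) -> tuple[float, float]:
          eps = 1e-9  # avoids division by zero for empty inputs
          # Attraction: pull hits of the same particle together (mean over edges).
          att = sum(math.dist(x[i], x[j]) ** p_attr for i, j in att_edges)
          v_att = att / (len(att_edges) + eps)
          # Repulsion: hinge that only penalizes pairs closer than r_emb_hinge.
          rep = sum(
              max(0.0, r_emb_hinge - math.dist(x[i], x[j])) ** p_rep
              for i, j in rep_edges
          )
          v_rep = rep / (n_hits_oi + eps)  # assumed normalization by hits of interest
          return v_att, v_rep

   Consider this input:

   * ``x = [(0.0, 0.0), (3.0, 4.0), (0.0, 1.0)]``
   * One attractive edge, ``(0, 1)``.
   * Two repulsive edges, ``(0, 2)`` and ``(1, 2)``.
   * ``r_emb_hinge=2.0`` and ``n_hits_oi=2``.

   The attractive term is about 5.0, from one edge of length 5. The repulsive term is about 0.5. Only the pair at distance 1 lies inside the hinge and contributes 1.0, which is then divided by ``n_hits_oi``. One plausible reason to normalize by the number of hits of interest rather than the number of repulsive edges is that the latter shrinks as training pushes particles apart.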


.. py:class:: GraphConstructionHingeEmbeddingLoss(*, lw_repulsive: float = 1.0, r_emb: float = 1.0, max_num_neighbors: int = 256, pt_thld: float = 0.9, max_eta: float = 4.0, p_attr: float = 1.0, p_rep: float = 1.0)


   Bases: :py:obj:`gnn_tracking.metrics.losses.MultiLossFct`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

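   Hinge embedding loss for graph construction using metric learning: hits belonging to the same particle are pulled together in the embedding space, while nearby hits of different particles are pushed apart.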

   :param lw_repulsive: Loss weight for the repulsive part of the hinge loss
   :param r_emb: Radius for edge construction in the embedding space
   :param max_num_neighbors: Maximum number of neighbors in radius graph building.
                             See https://github.com/rusty1s/pytorch_cluster#radius-graph
   :param pt_thld: Transverse momentum (pt) threshold for particles of interest
   :param max_eta: Maximum absolute pseudorapidity (eta) for particles of interest
   :param p_attr: Exponent of the attraction term (default 1: linear loss)
   :param p_rep: Exponent of the repulsion term (default 1: linear loss)
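
   The ``pt_thld`` and ``max_eta`` cuts select the hits of interest. The sketch below is minimal and rests on these assumptions:

   * A hit counts when its particle has ``pt > pt_thld`` and ``abs(eta) < max_eta``.
   * The hit must also be flagged ``reconstructable``. That flag is taken from the ``forward`` signature.
   * The exact comparison operators are a guess.

   ``hits_of_interest`` is a hypothetical helper that works on plain lists instead of tensors.

   .. code-block:: python

      # Hypothetical helper; not part of gnn_tracking.
      def hits_of_interest(
          pt: list[float],
          eta: list[float],
          reconstructable: list[bool],
          *,
          pt_thld: float = 0.9,
          max_eta: float = 4.0,
      ) -> list[bool]:
          # Assumed cuts: strict pt threshold, strict |eta| limit, reconstructable flag.
          return [
              p > pt_thld and abs(e) < max_eta and r
              for p, e, r in zip(pt, eta, reconstructable)
          ]

   With the default cuts, ``hits_of_interest([1.5, 0.5, 2.0, 3.0], [0.1, 0.2, -4.5, 1.0], [True, True, True, False])`` flags only the first hit. The second hit fails the pt cut, the third fails the eta cut, and the fourth is not reconstructable.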

   .. py:method:: _get_edges(*, x: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, mask: torch.Tensor, particle_id: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]

      Return the attractive and repulsive edge indices used by the loss.
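
      The sketch below is a minimal version of the edge split. It assumes two kinds of edges:

      * Attractive edges are the true edges that start at a hit of interest.
      * Repulsive edges are radius-graph edges that start at a hit of interest and connect different particles. These are edges within ``r_emb``, with at most ``max_num_neighbors`` per hit.

      It also makes these simplifications:

      * A brute-force neighbor search stands in for ``torch_cluster.radius_graph``.
      * The ``batch`` argument is dropped, so a single event is assumed.

      ``get_edges`` is a hypothetical helper.

      .. code-block:: python

         import math


         # Hypothetical helper; not part of gnn_tracking.
         def get_edges(
             *,
             x: list[tuple[float, ...]],
             true_edges: list[tuple[int, int]],
             mask: list[bool],
             particle_id: list[int],
             r_emb: float = 1.0,
             max_num_neighbors: int = 256,
         ) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]:
             # Brute-force radius graph: stand-in for torch_cluster.radius_graph.
             near: list[tuple[int, int]] = []
             for i, xi in enumerate(x):
                 neighbors = [
                     j for j, xj in enumerate(x) if j != i and math.dist(xi, xj) < r_emb
                 ]
                 near.extend((i, j) for j in neighbors[:max_num_neighbors])
             # Attractive: true edges that start at a hit of interest.
             att = [(i, j) for i, j in true_edges if mask[i]]
             # Repulsive: nearby pairs of different particles, starting at a hit of interest.
             rep = [
                 (i, j) for i, j in near if mask[i] and particle_id[i] != particle_id[j]
             ]
             return att, rep

      When more than ``max_num_neighbors`` hits are in range, this sketch keeps the lowest indices. That need not match which neighbors ``radius_graph`` keeps.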


   .. py:method:: forward(*, x: torch.Tensor, particle_id: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, pt: torch.Tensor, eta: torch.Tensor, reconstructable: torch.Tensor, **kwargs) -> gnn_tracking.metrics.losses.MultiLossFctReturn
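
      ``forward`` returns a :py:obj:`gnn_tracking.metrics.losses.MultiLossFctReturn`. The sketch below shows how the two components might be weighted. It assumes the attraction has unit weight and the repulsion is scaled by ``lw_repulsive``, as the class parameters suggest. ``combine_losses`` and its dictionary keys are hypothetical and do not describe the fields of ``MultiLossFctReturn``.

      .. code-block:: python

         # Hypothetical helper; not part of gnn_tracking.
         def combine_losses(
             attractive: float, repulsive: float, *, lw_repulsive: float = 1.0
         ) -> tuple[dict[str, float], dict[str, float], float]:
             losses = {"attractive": attractive, "repulsive": repulsive}
             # Assumed weighting: unit weight for attraction, lw_repulsive for repulsion.
             weights = {"attractive": 1.0, "repulsive": lw_repulsive}
             total = sum(weights[k] * losses[k] for k in losses)
             return losses, weights, total

      For example, ``combine_losses(5.0, 0.5, lw_repulsive=2.0)`` gives a total of ``6.0``.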



