:py:mod:`gnn_tracking.metrics.losses.oc`
========================================

.. py:module:: gnn_tracking.metrics.losses.oc


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.metrics.losses.oc.CondensationLossRG
   gnn_tracking.metrics.losses.oc.CondensationLossTiger
   gnn_tracking.metrics.losses.oc.ObjectLoss



Functions
~~~~~~~~~

.. autoapisummary::

   gnn_tracking.metrics.losses.oc._first_occurrences
   gnn_tracking.metrics.losses.oc._square_distances
   gnn_tracking.metrics.losses.oc._get_alphas_first_occurences
   gnn_tracking.metrics.losses.oc._get_vr_rg
   gnn_tracking.metrics.losses.oc._get_va
   gnn_tracking.metrics.losses.oc._radius_graph_condensation_loss
   gnn_tracking.metrics.losses.oc.condensation_loss_tiger



.. py:function:: _first_occurrences(x: torch.Tensor) -> torch.Tensor

   Return the first occurrence of each unique element in a 1D array
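
   A minimal pure-Python sketch of this behavior. It assumes the function returns the
   *indices* of the first occurrences, ordered by sorted unique value, as
   ``numpy.unique(..., return_index=True)`` does. The name ``first_occurrences`` is
   hypothetical, and lists stand in for tensors.

   .. code-block:: python

      def first_occurrences(values: list[int]) -> list[int]:
          """Index of the first occurrence of each unique value, ordered by value."""
          first: dict[int, int] = {}
          for i, v in enumerate(values):
              # setdefault only stores the index the first time v is seen
              first.setdefault(v, i)
          return [first[v] for v in sorted(first)]


      print(first_occurrences([3, 1, 3, 2, 1]))  # [1, 3, 0]

   Value 1 first appears at index 1, value 2 at index 3 and value 3 at index 0.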


.. py:function:: _square_distances(edges: torch.Tensor, positions: torch.Tensor) -> torch.Tensor

   Returns the squared distances between the two endpoints of each edge in
   ``edges``, with node coordinates taken from ``positions``.
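
   A pure-Python sketch, assuming ``edges`` holds two parallel index lists (sources and
   targets, like a ``(2, E)`` edge-index tensor) and ``positions`` holds one coordinate
   vector per node. The helper name is hypothetical. In torch, the same computation can
   be written as ``((positions[edges[0]] - positions[edges[1]]) ** 2).sum(dim=-1)``.

   .. code-block:: python

      def square_distances(
          edges: tuple[list[int], list[int]], positions: list[list[float]]
      ) -> list[float]:
          """Squared Euclidean distance between the two endpoints of every edge."""
          src, dst = edges
          return [
              sum((a - b) ** 2 for a, b in zip(positions[i], positions[j]))
              for i, j in zip(src, dst)
          ]


      pos = [[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]]
      print(square_distances(([0, 0], [1, 2]), pos))  # [25.0, 2.0]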


.. py:function:: _get_alphas_first_occurences(beta: torch.Tensor, particle_id: torch.Tensor, mask: torch.Tensor, q_min: float) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
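
   This helper has no docstring. Judging by its name and arguments, it likely selects
   the condensation point :math:`\alpha_k` of each particle (the hit with the highest
   :math:`\beta`) and computes charges
   :math:`q = \operatorname{arctanh}^2(\beta) + q_\text{min}` as in the OC paper. The
   sketch below implements that OC-paper definition under those assumptions. The name
   ``condensation_points``, the convention that ids ``<= 0`` mark noise and the omission
   of ``mask`` are illustrative choices, not the library's code.

   .. code-block:: python

      import math


      def condensation_points(
          beta: list[float], particle_id: list[int], q_min: float
      ) -> tuple[list[int], list[int], list[float]]:
          """Return (alphas, pids, charges), one entry per non-noise particle."""
          best: dict[int, int] = {}
          for i, (b, pid) in enumerate(zip(beta, particle_id)):
              if pid <= 0:  # assumption: ids <= 0 mark noise hits
                  continue
              # keep the highest-beta hit of each particle as its condensation point
              if pid not in best or b > beta[best[pid]]:
                  best[pid] = i
          pids = sorted(best)
          alphas = [best[p] for p in pids]
          # OC-paper charge q = arctanh(beta)^2 + q_min (needs 0 <= beta < 1)
          charges = [math.atanh(beta[a]) ** 2 + q_min for a in alphas]
          return alphas, pids, charges


      alphas, pids, charges = condensation_points([0.1, 0.9, 0.5, 0.3], [1, 1, 2, 0], q_min=0.01)
      print(alphas, pids)  # [1, 2] [1, 2]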


.. py:function:: _get_vr_rg(*, radius_edges: torch.Tensor, is_cp_j: torch.Tensor, particle_id: torch.Tensor, x: torch.Tensor, q_j: torch.Tensor, radius_threshold: float)
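
   No docstring is given here either. Assuming ``_get_vr_rg`` computes the OC repulsive
   potential restricted to radius-graph edges, a sketch can apply the paper's hinge
   :math:`q_j\, q_\alpha \max(0, 1 - d)` with the unit radius replaced by
   ``radius_threshold``. The edge format ``(cp, hit)``, the function name and the
   normalization by the number of hits are assumptions.

   .. code-block:: python

      import math


      def repulsive_potential(
          edges: list[tuple[int, int]],
          x: list[list[float]],
          q: list[float],
          particle_id: list[int],
          radius_threshold: float,
      ) -> float:
          """Hinge repulsion over radius-graph edges (cp, hit), averaged over hits."""
          total = 0.0
          for cp, j in edges:
              if particle_id[cp] == particle_id[j]:
                  continue  # hits of the same particle are attracted, not repelled
              d = math.dist(x[cp], x[j])
              total += q[cp] * q[j] * max(0.0, radius_threshold - d)
          return total / len(x)


      x = [[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]]
      v = repulsive_potential([(0, 1), (0, 2)], x, [1.0, 1.0, 1.0], [1, 2, 2], 1.0)
      # only the (0, 1) edge lies inside the radius: v = 0.5 / 3 hits

   Because only radius-graph edges are visited, the cost scales with the number of
   edges rather than with the number of condensation-point/hit pairs.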


.. py:function:: _get_va(*, alphas_k: torch.Tensor, is_cp_j: torch.Tensor, particle_id: torch.Tensor, x: torch.Tensor, q_j: torch.Tensor, mask: torch.Tensor) -> torch.Tensor
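
   Also undocumented. Assuming ``_get_va`` is the OC attractive potential, each hit
   :math:`j` of particle :math:`k` contributes
   :math:`q_j\, q_{\alpha_k} \lVert x_j - x_{\alpha_k} \rVert^2`. The sketch below
   normalizes by the number of hits as in the OC paper. The name
   ``attractive_potential`` and the argument layout are hypothetical.

   .. code-block:: python

      import math


      def attractive_potential(
          alphas: list[int],
          pids: list[int],
          x: list[list[float]],
          q: list[float],
          particle_id: list[int],
      ) -> float:
          """OC-style attraction q_j * q_alpha * |x_j - x_alpha|^2, averaged over hits."""
          cp_of = dict(zip(pids, alphas))  # particle id -> condensation-point index
          total = 0.0
          for j, pid in enumerate(particle_id):
              a = cp_of.get(pid)
              if a is None:
                  continue  # noise hits and unselected particles are not attracted
              total += q[j] * q[a] * math.dist(x[j], x[a]) ** 2
          return total / len(x)


      # hit 0 is pulled towards its condensation point (hit 1); hit 2 is noise
      v = attractive_potential([1], [1], [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]], [1.0, 2.0, 1.0], [1, 1, 0])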


.. py:function:: _radius_graph_condensation_loss(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, q_min: float, mask: torch.Tensor, radius_threshold: float, max_num_neighbors: int) -> tuple[dict[str, torch.Tensor], dict[str, Any]]

   Extracted function for condensation loss. See `PotentialLoss` for details.

   :param mask: Mask for objects cast to nodes


.. py:class:: CondensationLossRG(*, lw_repulsive: float = 1.0, lw_noise: float = 0.0, lw_coward: float = 0.0, q_min: float = 0.01, pt_thld: float = 0.9, max_eta: float = 4.0, max_num_neighbors: int = 256, sample_pids: float = 1.0)


   Bases: :py:obj:`gnn_tracking.metrics.losses.MultiLossFct`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Implementation of condensation loss that uses a radius graph instead of
   calculating the whole n^2 distance matrix.
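
   To illustrate why a radius graph avoids the full n^2 matrix, here is a pure-Python
   cell-list sketch. Points are binned into a grid with cell size :math:`r`, so a
   neighbor closer than :math:`r` can only lie in the same or an adjacent cell. This is
   a stand-in technique for illustration only and is not taken from this module; GPU
   code typically relies on a library routine such as ``torch_cluster.radius_graph``.
   The cap mimics ``max_num_neighbors``, but which neighbors survive it depends on
   iteration order.

   .. code-block:: python

      import math
      from collections import defaultdict
      from itertools import product


      def radius_graph(
          x: list[tuple[float, ...]], r: float, max_num_neighbors: int = 256
      ) -> list[tuple[int, int]]:
          """Directed edges (i, j), i != j, with |x_i - x_j| < r, via a uniform grid."""
          def cell_of(p: tuple[float, ...]) -> tuple[int, ...]:
              return tuple(math.floor(c / r) for c in p)

          cells: dict[tuple[int, ...], list[int]] = defaultdict(list)
          for i, p in enumerate(x):
              cells[cell_of(p)].append(i)
          dim = len(x[0]) if x else 0
          edges = []
          for i, p in enumerate(x):
              home = cell_of(p)
              found = 0
              # a neighbor closer than r differs by at most one cell along each axis
              for offset in product((-1, 0, 1), repeat=dim):
                  cell = tuple(h + o for h, o in zip(home, offset))
                  for j in cells.get(cell, []):
                      if found >= max_num_neighbors:
                          break
                      if j != i and math.dist(p, x[j]) < r:
                          edges.append((i, j))
                          found += 1
          return edges


      print(sorted(radius_graph([(0.0, 0.0), (0.5, 0.0), (3.0, 0.0)], r=1.0)))
      # [(0, 1), (1, 0)]

   For roughly uniform point densities, each point only inspects :math:`3^d` nearby
   cells, so the work grows with the number of points and their local density rather
   than with every pair.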

   :param lw_repulsive: Loss weight for the repulsive part of the potential loss
   :param lw_noise: Loss weight for the noise loss
   :param lw_coward: Loss weight for the coward (background) loss
   :param q_min: Minimum charge; see the object condensation (OC) paper.
                 Defaults to 0.01.
   :type q_min: float, optional
   :param pt_thld: Minimum pT for particles of interest. Defaults to 0.9.
   :type pt_thld: float, optional
   :param max_eta: Maximum pseudorapidity (eta) for particles of interest.
                   Defaults to 4.0.
   :type max_eta: float, optional
   :param max_num_neighbors: Maximum number of neighbors to consider
                             when building the radius graph. Defaults to 256.
   :type max_num_neighbors: int, optional
   :param sample_pids: Fraction of particles to keep when further subsampling
                       them to conserve memory. Defaults to 1.0 (no sampling).
   :type sample_pids: float, optional
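
   The thresholds above can be read as a per-particle selection. The following sketch
   assumes that a particle is of interest when it is not noise (id > 0), is
   reconstructable, has pT above ``pt_thld`` and :math:`|\eta|` below ``max_eta``, and
   that ``sample_pids`` is the fraction of those particles kept at random. The exact
   comparisons (strict or not, signed or absolute eta) and the name
   ``select_particles`` are assumptions.

   .. code-block:: python

      import random


      def select_particles(
          particle_id: list[int],
          pt: list[float],
          eta: list[float],
          reconstructable: list[bool],
          pt_thld: float = 0.9,
          max_eta: float = 4.0,
          sample_pids: float = 1.0,
          seed: int = 0,
      ) -> set[int]:
          """Ids of particles of interest (per-hit inputs), optionally subsampled."""
          keep = {
              pid
              for pid, p, e, rec in zip(particle_id, pt, eta, reconstructable)
              if pid > 0 and rec and p > pt_thld and abs(e) < max_eta
          }
          if sample_pids < 1.0:
              # keep a random fraction of the selected particles
              rng = random.Random(seed)
              k = round(len(keep) * sample_pids)
              keep = set(rng.sample(sorted(keep), k))
          return keep


      print(select_particles([1, 1, 2, 3, 0], [1.0, 1.0, 0.5, 2.0, 5.0],
                             [0.1, 0.1, 0.2, 4.5, 0.0], [True] * 5))  # {1}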

   .. py:method:: forward(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, reconstructable: torch.Tensor, pt: torch.Tensor, ec_hit_mask: torch.Tensor | None = None, eta: torch.Tensor, **kwargs) -> gnn_tracking.metrics.losses.MultiLossFctReturn



.. py:function:: condensation_loss_tiger(*, beta: torch.Tensor, x: torch.Tensor, object_id: torch.Tensor, object_mask: torch.Tensor, q_min: float, noise_threshold: int, max_n_rep: int) -> tuple[dict[str, torch.Tensor], dict[str, int | float]]

   Extracted function for torch compilation. See `CondensationLossTiger` for the
   full docstring.

   :param object_mask: Mask for the particles that should be considered for the loss;
                       this is broadcast to ``n_hits``.

   :returns: Tuple ``(loss_dct, extra_dct)``, where ``loss_dct`` is a dictionary of
             losses and ``extra_dct`` is a dictionary of extra information.
   :rtype: tuple[dict[str, torch.Tensor], dict[str, int | float]]
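
   Besides the potential terms, the OC paper adds a :math:`\beta` loss: the
   :math:`\beta` of each condensation point is pushed towards 1 via
   :math:`1 - \beta_\alpha`, and the :math:`\beta` of noise hits towards 0. The sketch
   below shows those two terms only, assuming noise means
   ``particle_id <= noise_threshold``. The function name and the dictionary keys
   ``"coward"`` and ``"noise"`` (chosen to mirror ``lw_coward`` and ``lw_noise``) are
   hypothetical.

   .. code-block:: python

      def beta_losses(
          beta: list[float],
          particle_id: list[int],
          alphas: list[int],
          noise_threshold: int = 0,
      ) -> dict[str, float]:
          """Mean (1 - beta) over condensation points and mean beta over noise hits."""
          noise = [b for b, pid in zip(beta, particle_id) if pid <= noise_threshold]
          coward = sum(1.0 - beta[a] for a in alphas) / len(alphas) if alphas else 0.0
          noise_loss = sum(noise) / len(noise) if noise else 0.0
          return {"coward": coward, "noise": noise_loss}


      # condensation points are hits 1 and 2; hit 3 (id 0) is noise
      losses = beta_losses([0.1, 0.9, 0.5, 0.3], [1, 1, 2, 0], alphas=[1, 2])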


.. py:class:: CondensationLossTiger(*, lw_repulsive: float = 1.0, lw_noise: float = 0.0, lw_coward: float = 0.0, q_min: float = 0.01, pt_thld: float = 0.9, max_eta: float = 4.0, max_n_rep: int = 0, sample_pids: float = 1.0)


   Bases: :py:obj:`gnn_tracking.metrics.losses.MultiLossFct`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Implementation of condensation loss that directly calculates the n^2
   distance matrix.
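
   In contrast to the radius-graph variant, the dense approach evaluates every
   condensation-point/hit pair. The sketch below computes the OC-paper potentials from
   a full :math:`K \times N` distance matrix, using the paper's unit repulsion radius.
   The function name, argument layout and normalization by the number of hits are
   assumptions, and ``max_n_rep`` subsampling is omitted.

   .. code-block:: python

      import math


      def dense_potentials(
          alphas: list[int],
          pids: list[int],
          x: list[list[float]],
          q: list[float],
          particle_id: list[int],
      ) -> tuple[float, float]:
          """Attractive and repulsive OC potentials from the full K x N distance matrix."""
          dist = [[math.dist(x[a], xj) for xj in x] for a in alphas]  # K x N
          v_att = v_rep = 0.0
          for k, (a, pid) in enumerate(zip(alphas, pids)):
              for j, d in enumerate(dist[k]):
                  if particle_id[j] == pid:
                      v_att += q[j] * q[a] * d ** 2  # pull own hits in
                  else:
                      v_rep += q[j] * q[a] * max(0.0, 1.0 - d)  # push others out
          n = len(x)
          return v_att / n, v_rep / n


      v_att, v_rep = dense_potentials(
          [1, 2], [1, 2], [[0.0, 0.0], [1.0, 1.0], [0.5, 0.0]], [1.0, 2.0, 1.0], [1, 1, 2]
      )

   The matrix needs memory proportional to :math:`K \times N`, which is the cost the
   radius-graph variant avoids.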

   :param lw_repulsive: Loss weight for the repulsive part of the potential loss
   :param lw_noise: Loss weight for the noise loss
   :param lw_coward: Loss weight for the coward (background) loss
   :param q_min: Minimum charge; see the object condensation (OC) paper.
                 Defaults to 0.01.
   :type q_min: float, optional
   :param pt_thld: Minimum pT for particles of interest. Defaults to 0.9.
   :type pt_thld: float, optional
   :param max_eta: Maximum pseudorapidity (eta) for particles of interest.
                   Defaults to 4.0.
   :type max_eta: float, optional
   :param max_n_rep: Maximum number of repulsive edges to consider.
                     Defaults to 0 (all).
   :type max_n_rep: int, optional
   :param sample_pids: Fraction of particles to keep when further subsampling
                       them to conserve memory. Defaults to 1.0 (no sampling).
   :type sample_pids: float, optional

   .. py:method:: forward(*, beta: torch.Tensor, x: torch.Tensor, particle_id: torch.Tensor, reconstructable: torch.Tensor, pt: torch.Tensor, ec_hit_mask: torch.Tensor | None = None, eta: torch.Tensor, **kwargs) -> gnn_tracking.metrics.losses.MultiLossFctReturn



.. py:class:: ObjectLoss(mode='efficiency')


   Bases: :py:obj:`torch.nn.Module`

   Loss functions for predicted object properties.
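
   For reference, the OC paper weights per-hit property losses by
   :math:`\xi_i = (1 - n_i)\operatorname{arctanh}^2(\beta_i)`, where :math:`n_i = 1`
   for noise hits, and normalizes by :math:`\sum_i \xi_i`. The sketch below applies
   that weighting to a per-hit squared error summed over property dimensions. How
   ``mode='efficiency'`` changes the weighting is not documented here, so this is not
   necessarily what ``object_loss`` computes; the name ``weighted_property_loss`` and
   the "id 0 is noise" convention are assumptions.

   .. code-block:: python

      import math


      def weighted_property_loss(
          pred: list[list[float]],
          truth: list[list[float]],
          beta: list[float],
          particle_id: list[int],
      ) -> float:
          """Squared error per hit, weighted by xi = arctanh(beta)^2 for non-noise hits."""
          num = den = 0.0
          for p, t, b, pid in zip(pred, truth, beta, particle_id):
              if pid == 0:
                  continue  # noise hits carry no target properties
              xi = math.atanh(b) ** 2
              num += xi * sum((pi - ti) ** 2 for pi, ti in zip(p, t))
              den += xi
          return num / den if den else 0.0


      # two real hits with equal beta (errors 4 and 2) and one ignored noise hit
      loss = weighted_property_loss(
          [[1.0, 2.0], [0.0, 0.0], [9.0, 9.0]],
          [[1.0, 0.0], [1.0, 1.0], [0.0, 0.0]],
          [0.5, 0.5, 0.9],
          [1, 1, 0],
      )

   Hits with high :math:`\beta` dominate the average, so the network is pushed to
   predict properties accurately where it is most confident about condensation.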

   .. py:method:: _mse(*, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor
      :staticmethod:


   .. py:method:: object_loss(*, pred: torch.Tensor, beta: torch.Tensor, truth: torch.Tensor, particle_id: torch.Tensor) -> torch.Tensor


   .. py:method:: forward(*, beta: torch.Tensor, pred: torch.Tensor, particle_id: torch.Tensor, track_params: torch.Tensor, reconstructable: torch.Tensor, **kwargs) -> torch.Tensor



