:py:mod:`gnn_tracking.models.edge_classifier`
=============================================

.. py:module:: gnn_tracking.models.edge_classifier

.. autoapi-nested-parse::

   Models for edge classification



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.models.edge_classifier.ECForGraphTCN
   gnn_tracking.models.edge_classifier.PerfectEdgeClassification
   gnn_tracking.models.edge_classifier.ECFromChkpt




.. py:class:: ECForGraphTCN(*, node_indim: int, edge_indim: int, interaction_node_dim: int = 5, interaction_edge_dim: int = 4, hidden_dim: int | float | None = None, L_ec: int = 3, alpha: float = 0.5, residual_type='skip1', use_intermediate_edge_embeddings: bool = True, use_node_embedding: bool = True, residual_kwargs: dict | None = None)


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Edge classification step to be used for the Graph Track Condensor network
   (Graph TCN).

   :param node_indim: Node feature dim
   :param edge_indim: Edge feature dim
   :param interaction_node_dim: Node dimension for interaction networks.
                                Defaults to 5 for backward compatibility, but this is probably
                                not a reasonable choice.
   :param interaction_edge_dim: Edge dimension for interaction networks.
                                Defaults to 4 for backward compatibility, but this is probably
                                not a reasonable choice.
   :param hidden_dim: Width of the hidden layers in all perceptrons (edge and node
                      encoders, hidden dims for MLPs in object and relation networks). If
                      None, it is chosen separately for each MLP as the maximum of that
                      MLP's input and output dims.
   :param L_ec: Message-passing depth of the edge classifier.
   :param alpha: Strength of the residual connection for EC.
   :param residual_type: Type of residual connection for EC.
   :param use_intermediate_edge_embeddings: If True, feed not only the final
                                            encoding of the stacked interaction networks to the final MLP, but all
                                            intermediate encodings.
   :param use_node_embedding: If True, feed node attributes to the final MLP for
                              EC.
   :param residual_kwargs: Keyword arguments passed to ``ResIN``.
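
   The ``hidden_dim=None`` rule can be illustrated with a small helper. This is a
   sketch only: ``resolve_hidden_dim`` is hypothetical and not part of
   ``gnn_tracking``; it merely restates the documented default. The signature also
   accepts a float, whose meaning is not described here, so the sketch rejects it
   rather than guess.

   .. code-block:: python

      def resolve_hidden_dim(hidden_dim, in_dim: int, out_dim: int) -> int:
          """Hypothetical helper restating the documented ``hidden_dim`` rule."""
          if hidden_dim is None:
              # Documented default: the larger of this MLP's input and output dims
              return max(in_dim, out_dim)
          if isinstance(hidden_dim, bool) or not isinstance(hidden_dim, int):
              # Float values are allowed by the signature, but their meaning is
              # not documented, so this sketch does not interpret them.
              raise NotImplementedError("only None or int is covered by this sketch")
          return hidden_dim


      # Each MLP is resolved separately (hypothetical dims):
      print(resolve_hidden_dim(None, 7, 5))  # -> 7
      print(resolve_hidden_dim(None, 4, 4))  # -> 4

   Because the rule is applied per MLP, leaving ``hidden_dim`` unset can give
   different widths to different perceptrons in the same model.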

   .. py:method:: forward(data: torch_geometric.data.Data) -> dict[str, torch.Tensor]

      Returns a dictionary with the following keys:

      * ``W``: Edge weights
      * ``node_embedding``: Last node embedding (output of the last interaction network)
      * ``edge_embedding``: Last edge embedding (output of the last interaction network)
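
      A common downstream step is to keep only the edges whose weight ``W`` exceeds
      a threshold. The sketch below uses plain lists so that it runs without
      PyTorch; the helper ``select_edges`` and the threshold of 0.5 are
      hypothetical, and it assumes ``W`` holds one score in [0, 1] per edge, aligned
      with the columns of ``data.edge_index``. With tensors, the same selection can
      be written as ``data.edge_index[:, W > threshold]``.

      .. code-block:: python

         def select_edges(edge_index, w, threshold=0.5):
             """Keep edges whose classifier weight is above ``threshold``.

             edge_index: (sources, targets), mirroring PyG's 2 x E layout
             w: one weight per edge, assumed to lie in [0, 1]
             """
             sources, targets = edge_index
             if not (len(sources) == len(targets) == len(w)):
                 raise ValueError("edge_index and w must describe the same edges")
             kept = [(s, t) for s, t, wi in zip(sources, targets, w) if wi > threshold]
             return [s for s, _ in kept], [t for _, t in kept]


         print(select_edges(([0, 1, 2], [1, 2, 3]), [0.9, 0.2, 0.7]))
         # -> ([0, 2], [1, 3])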



.. py:class:: PerfectEdgeClassification(tpr=1.0, tnr=1.0, false_below_pt=0.0)


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   An edge classifier that is perfect because it uses the truth information.
   If TPR or TNR is not 1.0, noise is added to the truth information.

   This can be used to evaluate the maximal possible performance of a model
   that relies on edge classification as a first step (e.g., the object
   condensation approach).

   :param tpr: True positive rate
   :param tnr: True negative rate
   :param false_below_pt: If not 0.0, all true edges between hits corresponding to
                          particles with a transverse momentum (pT) lower than this threshold
                          are set to false. This is applied after the noise and does not count
                          towards the TPR/TNR.
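
   The noise model described above can be sketched in plain Python. This is a
   minimal illustration, not the module's implementation: the function
   ``noisy_truth_edges``, the per-edge ``pt`` list, and the ``seed`` argument are
   hypothetical, and predictions are plain booleans rather than tensors.

   .. code-block:: python

      import random


      def noisy_truth_edges(y, pt, tpr=1.0, tnr=1.0, false_below_pt=0.0, seed=0):
          """Hypothetical sketch: turn truth edge labels into noisy predictions.

          y:  truth label per edge (True if both hits belong to the same particle)
          pt: pT of the particle associated with each edge (assumed edge attribute)
          """
          rng = random.Random(seed)
          pred = []
          for label in y:
              if label:
                  # A true edge is kept with probability tpr (else: false negative)
                  pred.append(rng.random() < tpr)
              else:
                  # A false edge is rejected with probability tnr (else: false positive)
                  pred.append(rng.random() >= tnr)
          if false_below_pt > 0.0:
              # Applied after the noise, so it does not affect the TPR/TNR above
              pred = [
                  p and not (label and p_t < false_below_pt)
                  for p, label, p_t in zip(pred, y, pt)
              ]
          return pred


      print(noisy_truth_edges([True, False, True], [2.0, 0.1, 0.5], false_below_pt=0.9))
      # -> [True, False, False]

   With ``tpr=tnr=1.0`` and no pT cut, the predictions reproduce the truth labels
   exactly; the fixed seed only makes the sketch's noise reproducible.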

   .. py:method:: forward(data: torch_geometric.data.Data) -> dict[str, torch.Tensor]



.. py:class:: ECFromChkpt


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Convenience wrapper for loading a pretrained EC from a checkpoint, e.g., when
   the model is configured through a Lightning YAML config.

   :param chkpt_path: Path to the checkpoint file.
   :param class_name: Name of the Lightning module that was used to train the EC.
                      The default should work in most cases.
   :param freeze: If True, the model is frozen (i.e., its parameters are not trained).
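
   How ``freeze=True`` is implemented is not spelled out here. In PyTorch, freezing
   usually means disabling gradients for all parameters; PyTorch Lightning's
   ``LightningModule.freeze()`` does this and also switches the module to eval
   mode. The following minimal sketch applies that convention to a plain
   ``nn.Module``; the ``freeze`` helper and the stand-in ``nn.Linear`` model are
   illustrative assumptions, not ``ECFromChkpt``'s code.

   .. code-block:: python

      import torch
      from torch import nn


      def freeze(module: nn.Module) -> nn.Module:
          """Hypothetical helper: stop all parameters of ``module`` from training."""
          for param in module.parameters():
              param.requires_grad_(False)
          # Mirrors LightningModule.freeze(), which also sets eval mode
          module.eval()
          return module


      # Stand-in for a pretrained edge classifier
      model = freeze(nn.Linear(3, 1))
      out = model(torch.ones(1, 3))
      print(out.requires_grad)  # -> False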


