:py:mod:`gnn_tracking.models.resin`
===================================

.. py:module:: gnn_tracking.models.resin

.. autoapi-nested-parse::

   Deep stacked interaction networks with residual connections



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.models.resin.ResidualNetwork
   gnn_tracking.models.resin.Skip1ResidualNetwork
   gnn_tracking.models.resin.Skip2ResidualNetwork
   gnn_tracking.models.resin.SkipTopResidualNetwork
   gnn_tracking.models.resin.ResIN



Functions
~~~~~~~~~

.. autoapisummary::

   gnn_tracking.models.resin._sqconvex_combination
   gnn_tracking.models.resin.sqconvex_combination



Attributes
~~~~~~~~~~

.. autoapisummary::

   gnn_tracking.models.resin.RESIDUAL_NETWORKS_BY_NAME


.. py:function:: _sqconvex_combination(*, delta: torch.Tensor, residue: torch.Tensor, alpha_residue: float) -> torch.Tensor

   Helper function for JIT compilation. Use `sqconvex_combination` instead.


.. py:function:: sqconvex_combination(*, delta: torch.Tensor, residue: torch.Tensor | None, alpha_residue: float) -> torch.Tensor

   Convex combination of ``delta`` and ``residue``, where ``alpha_residue`` sets the weight of the residue.
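
   A minimal sketch of this helper. It assumes, from the ``sq`` prefix and not
   confirmed by this page, that the weights are the square roots of
   ``alpha_residue`` and ``1 - alpha_residue``, and that a missing ``residue``
   returns ``delta`` unchanged. The function name is hypothetical. Only scalar
   multiplication and addition are used, so it accepts ``torch.Tensor`` inputs
   as well as plain floats:

   .. code-block:: python

      import math
      from typing import Any, Optional


      def sqconvex_combination_sketch(
          *, delta: Any, residue: Optional[Any], alpha_residue: float
      ) -> Any:
          """Hypothetical illustration, not the library implementation."""
          if residue is None:
              # Assumption: with no residue there is nothing to mix in.
              return delta
          # Assumed square-root weights: their squares sum to 1, so for independent
          # inputs of equal variance the output keeps that same variance.
          return (
              math.sqrt(alpha_residue) * residue
              + math.sqrt(1 - alpha_residue) * delta
          )


      print(sqconvex_combination_sketch(delta=3.0, residue=4.0, alpha_residue=0.0))
      # 3.0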


.. py:class:: ResidualNetwork(layers: list[torch.nn.Module], *, alpha: float = 0.5, collect_hidden_edge_embeds: bool = False)


   Bases: :py:obj:`abc.ABC`, :py:obj:`torch.nn.Module`

   Apply a list of layers in sequence with residual connections for the nodes.
   This abstract base class does not implement any particular type of residual
   connection; subclasses provide it by overriding ``_forward``.

   Use one of the subclasses below, or use `ResIN` (a convenience wrapper around
   the subclasses for stacks of identical interaction networks, or INs).

   :param layers: List of layers
   :param alpha: Strength of the node embedding residual connection
   :param collect_hidden_edge_embeds: Whether to collect the edge embeddings from all
                                      layers (leave ``False`` to save memory)
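
   The split between the concrete ``forward`` and the abstract ``_forward``
   appears to follow the template-method pattern. A framework-free sketch of
   that structure (class names are hypothetical, and ``torch.nn.Module`` is
   left out so the snippet runs without PyTorch):

   .. code-block:: python

      from abc import ABC, abstractmethod
      from typing import Any, Callable


      class ResidualNetworkSketch(ABC):
          """Hypothetical stand-in for ResidualNetwork."""

          def __init__(
              self,
              layers: list[Callable],
              *,
              alpha: float = 0.5,
              collect_hidden_edge_embeds: bool = False,
          ) -> None:
              self.layers = list(layers)
              self.alpha = alpha
              self.collect_hidden_edge_embeds = collect_hidden_edge_embeds

          def forward(self, x: Any, edge_index: Any, edge_attr: Any):
              # The public entry point only delegates; each subclass decides how
              # the residual connections are wired.
              return self._forward(x, edge_index, edge_attr)

          @abstractmethod
          def _forward(self, x: Any, edge_index: Any, edge_attr: Any):
              ...


      class PassThroughNetwork(ResidualNetworkSketch):
          """Trivial subclass that ignores its layers (illustration only)."""

          def _forward(self, x, edge_index, edge_attr):
              return x, edge_attr, None


      # ResidualNetworkSketch([]) would raise TypeError: _forward is abstract.
      print(PassThroughNetwork([]).forward(1.0, None, 2.0))  # (1.0, 2.0, None)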

   .. py:method:: forward(x, edge_index, edge_attr) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]

      Forward pass

      :param x: Node features
      :param edge_index: Edge index (graph connectivity)
      :param edge_attr: Edge features

      :returns: node embedding, edge embedding, and the list of edge embeddings from
                all levels, including the input ``edge_attr`` (``None`` if
                ``collect_hidden_edge_embeds`` is False)


   .. py:method:: _forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]
      :abstractmethod:



.. py:class:: Skip1ResidualNetwork(*args, **kwargs)


   Bases: :py:obj:`ResidualNetwork`

   A residual network in which any two successive layers are connected by a
   residual connection.
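
   A framework-free sketch of this wiring under stated assumptions: each layer is
   a hypothetical callable returning ``(delta_x, edge_attr)``; node features are
   mixed with assumed square-root weights (``sqrt(alpha)`` for the residue,
   ``sqrt(1 - alpha)`` for the update); and edge features are passed on without
   a residual connection. The function and the toy layer are hypothetical:

   .. code-block:: python

      import math
      from typing import Any, Callable, Optional


      def skip1_forward(
          layers: list[Callable],
          x: Any,
          edge_index: Any,
          edge_attr: Any,
          *,
          alpha: float = 0.5,
          collect_hidden_edge_embeds: bool = False,
      ) -> tuple[Any, Any, Optional[list]]:
          edge_embeds = [edge_attr]  # level 0: the input edge features
          for layer in layers:
              delta_x, edge_attr = layer(x, edge_index, edge_attr)
              # Residual connection between every two successive layers
              x = math.sqrt(alpha) * x + math.sqrt(1 - alpha) * delta_x
              if collect_hidden_edge_embeds:
                  edge_embeds.append(edge_attr)
          return x, edge_attr, edge_embeds if collect_hidden_edge_embeds else None


      def toy_layer(x, edge_index, edge_attr):
          # Hypothetical layer: doubles node features, increments edge features
          return 2 * x, edge_attr + 1


      print(skip1_forward([toy_layer, toy_layer], 1.0, None, 0.0,
                          alpha=0.0, collect_hidden_edge_embeds=True))
      # (4.0, 2.0, [0.0, 1.0, 2.0])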

   .. py:method:: _forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]



.. py:class:: Skip2ResidualNetwork(layers: list[torch.nn.Module], *, node_dim: int, edge_dim: int, add_bn: bool = False, **kwargs)


   Bases: :py:obj:`ResidualNetwork`

   A residual network built from blocks of two layers. Each of these blocks
   is connected to its predecessor by a residual connection.

   :param layers: List of layers
   :param node_dim: Node feature dimension
   :param edge_dim: Edge feature dimension
   :param add_bn: Whether to add batch normalization layers
   :param \*\*kwargs: Arguments to `ResidualNetwork`
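
   A framework-free sketch of the block structure, assuming each layer is a
   hypothetical callable returning ``(x, edge_attr)`` and that the residual mix
   uses square-root weights (``sqrt(alpha)`` for the block input,
   ``sqrt(1 - alpha)`` for the block output). Batch normalization (``add_bn``)
   is left out, and the sketch simply rejects an odd number of layers; how the
   library handles that case is not documented here:

   .. code-block:: python

      import math
      from typing import Any, Callable, Optional


      def skip2_forward(
          layers: list[Callable],
          x: Any,
          edge_index: Any,
          edge_attr: Any,
          *,
          alpha: float = 0.5,
          collect_hidden_edge_embeds: bool = False,
      ) -> tuple[Any, Any, Optional[list]]:
          if len(layers) % 2:
              raise ValueError("this sketch expects an even number of layers")
          edge_embeds = [edge_attr]
          for first, second in zip(layers[::2], layers[1::2]):
              hidden = x
              for layer in (first, second):  # one block of two layers
                  hidden, edge_attr = layer(hidden, edge_index, edge_attr)
                  if collect_hidden_edge_embeds:
                      edge_embeds.append(edge_attr)
              # Residual connection from the block input to the block output
              x = math.sqrt(alpha) * x + math.sqrt(1 - alpha) * hidden
          return x, edge_attr, edge_embeds if collect_hidden_edge_embeds else None


      def toy_layer(x, edge_index, edge_attr):
          # Hypothetical layer: doubles node features, increments edge features
          return 2 * x, edge_attr + 1


      print(skip2_forward([toy_layer] * 4, 1.0, None, 0.0, alpha=0.0))
      # (16.0, 4.0, None)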

   .. py:method:: _forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]



.. py:class:: SkipTopResidualNetwork(layers: list[torch.nn.Module], connect_to=1, **kwargs)


   Bases: :py:obj:`ResidualNetwork`

   Residual network with skip connections to the top layer.

   :param layers: List of layers
   :param connect_to: Layer from which the skip connection to the top layer starts.
                      0 means the input, 1 means the output of the first layer, etc.
   :param \*\*kwargs: Arguments to `ResidualNetwork`
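
   A framework-free sketch of a single skip connection to the top, assuming each
   layer is a hypothetical callable returning ``(x, edge_attr)`` and that the mix
   uses square-root weights (``sqrt(alpha)`` for the saved features,
   ``sqrt(1 - alpha)`` for the top output). The layers run in plain sequence and
   only the top output is combined with the features saved at ``connect_to``;
   collection of hidden edge embeddings is omitted:

   .. code-block:: python

      import math
      from typing import Any, Callable


      def skip_top_forward(
          layers: list[Callable],
          x: Any,
          edge_index: Any,
          edge_attr: Any,
          *,
          connect_to: int = 1,
          alpha: float = 0.5,
      ) -> tuple[Any, Any]:
          if not 0 <= connect_to < len(layers):
              raise ValueError("connect_to must point below the top layer")
          residue = x  # connect_to == 0: the input itself
          for i_layer, layer in enumerate(layers, start=1):
              x, edge_attr = layer(x, edge_index, edge_attr)
              if i_layer == connect_to:
                  residue = x  # output of layer number i_layer
          # Skip connection from layer `connect_to` to the top layer
          x = math.sqrt(alpha) * residue + math.sqrt(1 - alpha) * x
          return x, edge_attr


      def toy_layer(x, edge_index, edge_attr):
          # Hypothetical layer: doubles node features, increments edge features
          return 2 * x, edge_attr + 1


      print(skip_top_forward([toy_layer] * 3, 1.0, None, 0.0,
                             connect_to=1, alpha=1.0))  # (2.0, 3.0)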

   .. py:method:: _forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]



.. py:data:: RESIDUAL_NETWORKS_BY_NAME
   :type: dict[str, Any]

   Maps the names accepted by the ``residual_type`` argument of `ResIN` to residual network classes.

.. py:class:: ResIN(*, node_dim: int, edge_dim: int, object_hidden_dim=40, relational_hidden_dim=40, alpha: float = 0.5, n_layers=1, residual_type: str = 'skip1', residual_kwargs: dict | None = None)


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Create a residual interaction network (ResIN) whose layers are identical interaction networks.

   :param node_dim: Node feature dimension
   :param edge_dim: Edge feature dimension
   :param object_hidden_dim: Hidden dimension for the object model MLP
   :param relational_hidden_dim: Hidden dimension for the relational model MLP
   :param alpha: Strength of the node residual connection
   :param n_layers: Total number of layers
   :param residual_type: Type of residual network. Options are 'skip1', 'skip2',
                         'skip_top'.
   :param residual_kwargs: Additional keyword arguments passed to the residual
                           network (which ones apply depends on ``residual_type``)
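
   Judging from its parameters, the wrapper picks a residual network class by
   name and hands it ``n_layers`` identically configured layers. A minimal sketch
   of that dispatch, assuming a name-to-class mapping in the spirit of
   ``RESIDUAL_NETWORKS_BY_NAME``; all classes and the ``make_layer`` factory are
   hypothetical placeholders rather than the library's own:

   .. code-block:: python

      from typing import Any, Callable, Optional


      class SketchNetwork:
          """Placeholder that records what a residual network would receive."""

          def __init__(self, layers: list, *, alpha: float = 0.5, **kwargs: Any):
              self.layers = layers
              self.alpha = alpha
              self.kwargs = kwargs


      class Skip1Sketch(SketchNetwork): ...
      class Skip2Sketch(SketchNetwork): ...
      class SkipTopSketch(SketchNetwork): ...


      NETWORKS_BY_NAME = {
          "skip1": Skip1Sketch,
          "skip2": Skip2Sketch,
          "skip_top": SkipTopSketch,
      }


      def build_resin_sketch(
          make_layer: Callable[[], Any],
          *,
          n_layers: int = 1,
          alpha: float = 0.5,
          residual_type: str = "skip1",
          residual_kwargs: Optional[dict] = None,
      ) -> SketchNetwork:
          try:
              network_cls = NETWORKS_BY_NAME[residual_type]
          except KeyError:
              raise ValueError(f"Unknown residual_type {residual_type!r}") from None
          # Separate instances built from the same configuration ("identical" layers)
          layers = [make_layer() for _ in range(n_layers)]
          return network_cls(layers, alpha=alpha, **(residual_kwargs or {}))


      net = build_resin_sketch(
          object,  # placeholder for an interaction-network constructor
          n_layers=3,
          residual_type="skip_top",
          residual_kwargs={"connect_to": 2},
      )
      print(type(net).__name__, len(net.layers), net.kwargs)
      # SkipTopSketch 3 {'connect_to': 2}

   Building separate layer instances, rather than repeating one object, keeps
   the weights of the layers independent in this sketch.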

   .. py:property:: concat_edge_embeddings_length
      :type: int

      Length of the concatenated edge embeddings from all intermediate layers,
      i.e., ``self.forward()[3].shape[1]``.

   .. py:method:: forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[torch.Tensor] | None]



