:py:mod:`gnn_tracking.models.interaction_network`
=================================================

.. py:module:: gnn_tracking.models.interaction_network


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.models.interaction_network.InteractionNetwork




.. py:class:: InteractionNetwork(*, node_indim: int, edge_indim: int, node_outdim=3, edge_outdim=4, node_hidden_dim=40, edge_hidden_dim=40, aggr='add')


   Bases: :py:obj:`torch_geometric.nn.MessagePassing`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Interaction network consisting of a relational model, which computes edge
   embeddings, and an object model, which computes node embeddings; both are
   represented as MLPs.

   :param node_indim: Node feature dimension
   :param edge_indim: Edge feature dimension
   :param node_outdim: Output node feature dimension
   :param edge_outdim: Output edge feature dimension
   :param node_hidden_dim: Hidden dimension for the object model MLP
   :param edge_hidden_dim: Hidden dimension for the relational model MLP
   :param aggr: Aggregation applied to the messages arriving at each node (e.g. ``'add'``, ``'mean'`` or ``'max'``)
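
   The input and output widths of the two MLPs follow from the constructor
   arguments. The sketch below rests on assumptions that the parameter list does
   not state: the relational model is assumed to receive the concatenation
   ``[x_i, x_j, edge_attr]`` and to emit ``edge_outdim`` features, and the object
   model is assumed to receive the node features concatenated with the aggregated
   messages. ``mlp_dims`` is a hypothetical helper, not part of ``gnn_tracking``.

   .. code-block:: python

      def mlp_dims(
          node_indim: int,
          edge_indim: int,
          node_outdim: int = 3,
          edge_outdim: int = 4,
          node_hidden_dim: int = 40,
          edge_hidden_dim: int = 40,
      ) -> dict[str, tuple[int, int, int]]:
          """Return (input, hidden, output) widths of both MLPs (assumed layout)."""
          return {
              # Relational model (edges): [x_i, x_j, edge_attr] -> edge embedding
              "relational": (2 * node_indim + edge_indim, edge_hidden_dim, edge_outdim),
              # Object model (nodes): [x, aggregated messages] -> node embedding
              "object": (node_indim + edge_outdim, node_hidden_dim, node_outdim),
          }


      print(mlp_dims(node_indim=6, edge_indim=4))
      # {'relational': (16, 40, 4), 'object': (10, 40, 3)}

   Under these assumptions, changing ``edge_outdim`` affects both MLPs, because the
   edge embeddings double as the messages that the object model consumes.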

   .. py:method:: forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]

      Forward pass

      :param x: Input node features
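      :param edge_index: Edge indices of shape ``(2, num_edges)``; row 0 holds the source node and row 1 the target node of each edge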
      :param edge_attr: Input edge features

      :returns: Output node embedding, output edge embedding
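
      A minimal NumPy sketch of what one forward pass computes. It assumes
      one-hidden-layer ReLU MLPs without biases, ``aggr='add'``, a relational model
      fed ``[x_i, x_j, edge_attr]`` and an object model fed ``[x, aggr_out]``. The
      names ``mlp`` and ``interaction_forward`` are hypothetical; the actual module
      is a PyTorch model whose MLP internals may differ.

      .. code-block:: python

         import numpy as np


         def mlp(a: np.ndarray, weights: list[np.ndarray]) -> np.ndarray:
             """One-hidden-layer ReLU MLP without biases (illustrative only)."""
             w1, w2 = weights
             return np.maximum(a @ w1, 0.0) @ w2


         def interaction_forward(x, edge_index, edge_attr, relational_w, object_w):
             """Sketch of one interaction-network step; returns (node_out, edge_out)."""
             src, dst = edge_index  # edge_index has shape (2, num_edges)
             # Relational model on [x_i, x_j, edge_attr], with i = target, j = source
             e_out = mlp(np.concatenate([x[dst], x[src], edge_attr], axis=1), relational_w)
             # aggr='add': sum the messages arriving at each node
             aggr = np.zeros((x.shape[0], e_out.shape[1]))
             np.add.at(aggr, dst, e_out)
             # Object model on [x, aggregated messages]
             x_out = mlp(np.concatenate([x, aggr], axis=1), object_w)
             return x_out, e_out


         rng = np.random.default_rng(0)
         num_nodes, num_edges, node_indim, edge_indim = 5, 8, 6, 4
         x = rng.normal(size=(num_nodes, node_indim))
         edge_index = rng.integers(0, num_nodes, size=(2, num_edges))
         edge_attr = rng.normal(size=(num_edges, edge_indim))
         # Default hidden widths (40) and output widths (edge 4, node 3)
         relational_w = [rng.normal(size=(2 * node_indim + edge_indim, 40)), rng.normal(size=(40, 4))]
         object_w = [rng.normal(size=(node_indim + 4, 40)), rng.normal(size=(40, 3))]
         x_out, e_out = interaction_forward(x, edge_index, edge_attr, relational_w, object_w)
         print(x_out.shape, e_out.shape)  # (5, 3) (8, 4)

      The node output has one row per node and the edge output one row per edge,
      matching the two return values documented above.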


   .. py:method:: message(x_i: torch.Tensor, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor

      Calculate the message for a single edge

      :param x_i: Features of the target node *i* (the node where the edge ends)
      :param x_j: Features of the source node *j* (the node where the edge starts)
      :param edge_attr: Edge features

      :returns: Message
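
      For orientation, the snippet below shows how ``x_i`` and ``x_j`` line up with
      ``edge_index`` under PyTorch Geometric's default ``flow="source_to_target"``
      (assumed here, since the flow is not stated above), together with the per-edge
      input the relational model is assumed to receive. Plain NumPy indexing stands
      in for the framework's gather step.

      .. code-block:: python

         import numpy as np

         # Tiny graph: 3 nodes with 2 features each, directed edges 0->1 and 2->1
         x = np.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])
         edge_index = np.array([[0, 2],   # row 0: source nodes
                                [1, 1]])  # row 1: target nodes
         edge_attr = np.array([[10.0], [20.0]])

         # i-suffixed features come from the targets, j-suffixed from the sources
         x_i = x[edge_index[1]]
         x_j = x[edge_index[0]]

         # Assumed relational-model input: one row per edge, [x_i, x_j, edge_attr]
         message_input = np.concatenate([x_i, x_j, edge_attr], axis=1)
         print(message_input.tolist())
         # [[2.0, 3.0, 0.0, 1.0, 10.0], [2.0, 3.0, 4.0, 5.0, 20.0]]

      Both edges end at node 1, so both rows start with node 1's features, while
      the source features differ per edge.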


   .. py:method:: update(aggr_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor

      Update the node embedding

      :param aggr_out: Messages aggregated over the incoming edges of each node
      :param x: Input features of the receiving (target) nodes

      :returns: Updated node features/embedding
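
      The aggregation that produces ``aggr_out`` is a scatter over the target node
      of each edge. The NumPy sketch below shows ``aggr='add'`` and, for comparison,
      a mean, followed by the object-model input, which is assumed to be the
      concatenation ``[x, aggr_out]``. It illustrates the arithmetic only and is not
      the PyTorch Geometric implementation.

      .. code-block:: python

         import numpy as np

         num_nodes = 3
         targets = np.array([1, 1, 2])     # edge_index[1]: where each edge ends
         messages = np.array([[1.0, 0.0],  # one message per edge
                              [2.0, 1.0],
                              [5.0, 5.0]])

         # aggr='add': scatter-sum the messages onto their target nodes
         aggr_out = np.zeros((num_nodes, messages.shape[1]))
         np.add.at(aggr_out, targets, messages)

         # aggr='mean' for comparison: divide by in-degree (isolated nodes stay 0)
         in_degree = np.bincount(targets, minlength=num_nodes)
         aggr_mean = aggr_out / np.maximum(in_degree, 1)[:, None]

         # Assumed object-model input: node features concatenated with aggr_out
         x = np.array([[0.1], [0.2], [0.3]])
         object_input = np.concatenate([x, aggr_out], axis=1)
         print(aggr_out.tolist())   # [[0.0, 0.0], [3.0, 1.0], [5.0, 5.0]]
         print(aggr_mean.tolist())  # [[0.0, 0.0], [1.5, 0.5], [5.0, 5.0]]
         print(object_input.shape)  # (3, 3)

      Node 0 has no incoming edges, so its aggregate is zero and the object model
      sees only its own features for that node.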



