gnn_tracking.models.interaction_network

Module Contents

Classes

InteractionNetwork

Interaction Network, consisting of a relational model and an object model, both represented as MLPs.

class gnn_tracking.models.interaction_network.InteractionNetwork(*, node_indim: int, edge_indim: int, node_outdim=3, edge_outdim=4, node_hidden_dim=40, edge_hidden_dim=40, aggr='add')

Bases: torch_geometric.nn.MessagePassing, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Interaction Network, consisting of a relational model and an object model, both represented as MLPs.

Parameters:
  • node_indim – Input node feature dimension

  • edge_indim – Input edge feature dimension

  • node_outdim – Output node feature dimension

  • edge_outdim – Output edge feature dimension

  • node_hidden_dim – Hidden dimension for the object model MLP

  • edge_hidden_dim – Hidden dimension for the relational model MLP

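  • aggr – Aggregation scheme used to combine the messages arriving at each node, as accepted by torch_geometric.nn.MessagePassing (e.g. 'add', 'mean', 'max'); defaults to 'add'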
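To make the dimension bookkeeping concrete, here is a minimal NumPy sketch of the two MLPs. It assumes (this page does not confirm it) that the relational model receives the per-edge concatenation [x_i, x_j, edge_attr] and the object model receives the per-node concatenation [x, aggregated messages]. make_mlp is a hypothetical stand-in for the library's MLP, with random, untrained weights.

```python
import numpy as np


def make_mlp(indim: int, hidden_dim: int, outdim: int, rng: np.random.Generator):
    """Hypothetical two-layer MLP (Linear -> ReLU -> Linear) with random weights."""
    w1 = rng.standard_normal((indim, hidden_dim)) * 0.1
    b1 = np.zeros(hidden_dim)
    w2 = rng.standard_normal((hidden_dim, outdim)) * 0.1
    b2 = np.zeros(outdim)

    def mlp(z: np.ndarray) -> np.ndarray:
        return np.maximum(z @ w1 + b1, 0.0) @ w2 + b2

    return mlp


# Illustrative input sizes; output and hidden sizes are the documented defaults.
node_indim, edge_indim = 5, 2
node_outdim, edge_outdim = 3, 4
node_hidden_dim = edge_hidden_dim = 40

rng = np.random.default_rng(0)
# Assumption: relational model input is [x_i, x_j, edge_attr] for each edge.
relational_model = make_mlp(2 * node_indim + edge_indim, edge_hidden_dim, edge_outdim, rng)
# Assumption: object model input is [x, aggregated messages] for each node.
object_model = make_mlp(node_indim + edge_outdim, node_hidden_dim, node_outdim, rng)

edge_out = relational_model(np.zeros((7, 2 * node_indim + edge_indim)))  # 7 edges
node_out = object_model(np.zeros((6, node_indim + edge_outdim)))          # 6 nodes
print(edge_out.shape, node_out.shape)  # (7, 4) (6, 3)
```

Under these assumptions edge_outdim also sets part of the object model's input size, because the aggregated messages are the relational model's outputs.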

forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]

Forward pass

Parameters:
  • x – Input node features

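  • edge_index – Edge indices of shape (2, num_edges); row 0 holds the node where each edge starts (source), row 1 the node where it ends (target)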

  • edge_attr – Input edge features

Returns:

Tuple of (output node embedding, output edge embedding)
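The following NumPy sketch shows how message, aggregation and update could fit together in forward for aggr='add'. It illustrates the idea under assumptions and is not the library's code: interaction_network_forward is a hypothetical helper, edge_index follows the PyTorch Geometric convention (row 0 = source, row 1 = target), and the two models are stand-in single-layer maps.

```python
import numpy as np


def interaction_network_forward(x, edge_index, edge_attr, relational_model, object_model):
    """Hypothetical IN forward pass with sum aggregation; returns (node_out, edge_out)."""
    src, dst = edge_index                 # edge e points from src[e] to dst[e]
    x_j, x_i = x[src], x[dst]             # start-node and end-node features per edge
    # Relational model: one output (message) per edge.
    edge_out = relational_model(np.concatenate([x_i, x_j, edge_attr], axis=1))
    # Sum the messages at the node where each edge ends.
    aggr_out = np.zeros((x.shape[0], edge_out.shape[1]))
    np.add.at(aggr_out, dst, edge_out)
    # Object model: one output per node from [node features, aggregated messages].
    node_out = object_model(np.concatenate([x, aggr_out], axis=1))
    return node_out, edge_out


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 5))                  # 4 nodes, node_indim = 5
edge_index = np.array([[0, 1, 2], [1, 2, 3]])    # 3 edges: 0->1, 1->2, 2->3
edge_attr = rng.standard_normal((3, 2))          # edge_indim = 2
w_rel = rng.standard_normal((2 * 5 + 2, 4))      # edge_outdim = 4
w_obj = rng.standard_normal((5 + 4, 3))          # node_outdim = 3

node_out, edge_out = interaction_network_forward(
    x, edge_index, edge_attr,
    relational_model=lambda z: np.tanh(z @ w_rel),
    object_model=lambda z: np.tanh(z @ w_obj),
)
print(node_out.shape, edge_out.shape)  # (4, 3) (3, 4)
```

Node 0 has no incoming edges, so in this sketch its aggregated message is zero and its output depends on its own features only.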

message(x_i: torch.Tensor, x_j: torch.Tensor, edge_attr: torch.Tensor) → torch.Tensor

Calculate the message for each edge

Parameters:
  • x_i – Features of the target node i (the node where the edge ends)

  • x_j – Features of the source node j (the node where the edge starts)

  • edge_attr – Edge features

Returns:

Message
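A small NumPy example of which features end up in x_i and x_j, assuming the PyTorch Geometric default flow ('source_to_target'), in which row 0 of edge_index holds the start nodes. The concatenation order [x_i, x_j, edge_attr] is an assumption for illustration; in the model, the relational MLP would then be applied to each row.

```python
import numpy as np

x = np.array([[0.0], [10.0], [20.0]])   # one feature per node: node k has value 10 * k
edge_index = np.array([[0, 1],          # row 0: node where each edge starts
                       [1, 2]])         # row 1: node where each edge ends
edge_attr = np.array([[0.5], [0.7]])    # one feature per edge

x_j = x[edge_index[0]]   # start-node features per edge
x_i = x[edge_index[1]]   # end-node features per edge

# Relational-model input, one row per edge: [x_i, x_j, edge_attr].
message_input = np.concatenate([x_i, x_j, edge_attr], axis=1)
# Edge 0 (0 -> 1) gives [10, 0, 0.5]; edge 1 (1 -> 2) gives [20, 10, 0.7].
```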

update(aggr_out: torch.Tensor, x: torch.Tensor) → torch.Tensor

Update the node embeddings from the aggregated messages

Parameters:
  • aggr_out – Messages aggregated over all incoming edges of each node

  • x – Input node features of the nodes receiving the aggregated messages

Returns:

Updated node features/embedding
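The aggregation that produces aggr_out, and the input that update could pass to the object model, can be sketched in NumPy as follows. aggregate is a hypothetical helper that mimics the 'add', 'mean' and 'max' schemes (nodes without incoming edges get zeros, as in PyTorch Geometric); concatenating [x, aggr_out] is an assumption about update that this page does not confirm.

```python
import numpy as np


def aggregate(messages: np.ndarray, dst: np.ndarray, num_nodes: int, aggr: str = "add") -> np.ndarray:
    """Hypothetical scatter-style aggregation of per-edge messages at their target nodes."""
    dim = messages.shape[1]
    if aggr == "add":
        out = np.zeros((num_nodes, dim))
        np.add.at(out, dst, messages)
    elif aggr == "mean":
        out = np.zeros((num_nodes, dim))
        np.add.at(out, dst, messages)
        counts = np.bincount(dst, minlength=num_nodes).reshape(-1, 1)
        out = out / np.maximum(counts, 1)  # avoid division by zero for isolated nodes
    elif aggr == "max":
        out = np.full((num_nodes, dim), -np.inf)
        np.maximum.at(out, dst, messages)
        out[np.isneginf(out)] = 0.0        # nodes without incoming edges get 0
    else:
        raise ValueError(f"unknown aggr: {aggr}")
    return out


messages = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # one message per edge
dst = np.array([1, 1, 2])                                  # target node of each edge
x = np.ones((3, 2))                                        # 3 nodes, 2 features each

aggr_out = aggregate(messages, dst, num_nodes=3, aggr="add")
# node 0 -> [0, 0] (no incoming edges), node 1 -> [4, 6], node 2 -> [5, 6]
update_input = np.concatenate([x, aggr_out], axis=1)  # would be fed to the object model
```

With aggr="mean", node 1 would instead receive [2, 3]; with aggr="max", it would receive [3, 4].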