gnn_tracking.models.interaction_network

Classes

MLP

Multi-layer perceptron using ReLU as the activation function.

InteractionNetwork

Interaction network consisting of a relational model and an object model, both represented as MLPs.

Functions

assert_feat_dim(feat_vec, dim) → None

Module Contents

class gnn_tracking.models.interaction_network.MLP(input_size: int, output_size: int, hidden_dim: int | None, L=3, *, bias=True, include_last_activation=False)

Bases: torch.nn.Module

Multi-layer perceptron using ReLU as the activation function.

Parameters:
  • input_size – Input feature dimension

  • output_size – Output feature dimension

  • hidden_dim – Feature dimension of the hidden layers. If None, the maximum of input_size and output_size is used.

  • L – Total number of layers (1 initial layer, L-2 hidden layers, 1 output layer)

  • bias – Whether to include a bias term in the linear layers

  • include_last_activation – Whether to apply the activation function after the last layer
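The parameter descriptions above can be illustrated with a minimal plain-PyTorch sketch. The class name MLPSketch is hypothetical and this is not the library source; the exact layer ordering (ReLU after every linear layer except, by default, the last) is an assumption based only on the parameter descriptions.

```python
import torch
from torch import nn


class MLPSketch(nn.Module):
    """Hypothetical re-creation of the documented MLP parameters (not the library code)."""

    def __init__(self, input_size, output_size, hidden_dim=None, L=3, *, bias=True, include_last_activation=False):
        super().__init__()
        if hidden_dim is None:
            # Documented default: the larger of the input and output sizes
            hidden_dim = max(input_size, output_size)
        # L linear layers: input -> hidden, (L - 2) x hidden -> hidden, hidden -> output
        dims = [input_size] + [hidden_dim] * (L - 1) + [output_size]
        layers = []
        n_linear = len(dims) - 1
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            layers.append(nn.Linear(d_in, d_out, bias=bias))
            is_last = i == n_linear - 1
            # Assumed: ReLU follows every layer except the last, unless requested
            if not is_last or include_last_activation:
                layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)

    def reset_parameters(self):
        # Re-initialize every linear layer in place
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                layer.reset_parameters()

    def forward(self, x):
        return self.layers(x)
```

For example, MLPSketch(4, 2, None, L=3) maps a (5, 4) batch to shape (5, 2) through three linear layers with hidden width max(4, 2) = 4.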

reset_parameters()
forward(x)
gnn_tracking.models.interaction_network.assert_feat_dim(feat_vec: torch.Tensor, dim: int) → None
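assert_feat_dim has no docstring. Judging only from its name and signature, it presumably checks that a feature tensor has dim features. The following is a hypothetical sketch of such a check, assuming the features lie along the last axis; the library's actual check may differ.

```python
import torch


def assert_feat_dim(feat_vec: torch.Tensor, dim: int) -> None:
    # Hypothetical re-implementation: assumes features are on the last axis
    actual = feat_vec.shape[-1]
    assert actual == dim, f"Expected feature dimension {dim}, got {actual}"
```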
class gnn_tracking.models.interaction_network.InteractionNetwork(*, node_indim: int, edge_indim: int, node_outdim=3, edge_outdim=4, node_hidden_dim=40, edge_hidden_dim=40, aggr='add')

Bases: torch_geometric.nn.MessagePassing, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Interaction network consisting of a relational model and an object model, both represented as MLPs.

Parameters:
  • node_indim – Node feature dimension

  • edge_indim – Edge feature dimension

  • node_outdim – Output node feature dimension

  • edge_outdim – Output edge feature dimension

  • node_hidden_dim – Hidden dimension for the object model MLP

  • edge_hidden_dim – Hidden dimension for the relational model MLP

  • aggr – Aggregation scheme applied to the messages (e.g. 'add', 'mean' or 'max', as supported by torch_geometric.nn.MessagePassing)

forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]

Forward pass

Parameters:
  • x – Input node features

  • edge_index – Edge indices of shape (2, num_edges), following the torch_geometric convention

  • edge_attr – Input edge features

Returns:

Tuple of (output node embedding, output edge embedding)
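A minimal plain-PyTorch sketch of one forward pass with 'add' aggregation helps show how the two models fit together. InteractionStepSketch and mlp are hypothetical names, torch_geometric is not used, and this is not the library implementation. The input widths of the relational model (2 * node_indim + edge_indim) and object model (node_indim + edge_outdim) are assumptions, as is returning the per-edge messages as the output edge embedding.

```python
import torch
from torch import nn


def mlp(d_in: int, d_out: int, d_hidden: int) -> nn.Sequential:
    # Small 3-layer ReLU MLP standing in for the module's MLP class
    return nn.Sequential(
        nn.Linear(d_in, d_hidden), nn.ReLU(),
        nn.Linear(d_hidden, d_hidden), nn.ReLU(),
        nn.Linear(d_hidden, d_out),
    )


class InteractionStepSketch(nn.Module):
    """Hypothetical interaction network step (not the library code)."""

    def __init__(self, node_indim, edge_indim, node_outdim=3, edge_outdim=4,
                 node_hidden_dim=40, edge_hidden_dim=40):
        super().__init__()
        # Assumed relational model: [x_i, x_j, edge_attr] -> edge embedding
        self.relational_model = mlp(2 * node_indim + edge_indim, edge_outdim, edge_hidden_dim)
        # Assumed object model: [x, aggregated messages] -> node embedding
        self.object_model = mlp(node_indim + edge_outdim, node_outdim, node_hidden_dim)

    def forward(self, x, edge_index, edge_attr):
        # Messages flow from edge_index[0] (source) to edge_index[1] (target)
        src, dst = edge_index
        edge_out = self.relational_model(torch.cat([x[dst], x[src], edge_attr], dim=1))
        # 'add' aggregation: sum the messages arriving at each target node
        aggr = edge_out.new_zeros((x.size(0), edge_out.size(1)))
        aggr.index_add_(0, dst, edge_out)
        node_out = self.object_model(torch.cat([x, aggr], dim=1))
        return node_out, edge_out
```

With node_indim=5 and edge_indim=2, six nodes and four edges yield a node embedding of shape (6, 3) and an edge embedding of shape (4, 4), matching the default node_outdim and edge_outdim.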

message(x_i: torch.Tensor, x_j: torch.Tensor, edge_attr: torch.Tensor) → torch.Tensor

Compute the message for each edge.

Parameters:
  • x_i – Features of the target node (where the edge ends)

  • x_j – Features of the source node (where the edge starts)

  • edge_attr – Edge features

Returns:

Message
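The x_i/x_j convention can be seen on a toy graph. Under torch_geometric's default flow="source_to_target", x_j is gathered from edge_index[0] and x_i from edge_index[1]. The concatenation order [x_i, x_j, edge_attr] used as relational-model input here is an assumption for illustration.

```python
import torch

# Toy graph: 3 nodes with 1-D features; edges 0->1 and 2->1
x = torch.tensor([[10.0], [20.0], [30.0]])
edge_index = torch.tensor([[0, 2], [1, 1]])  # row 0: sources, row 1: targets
edge_attr = torch.tensor([[0.5], [0.7]])

x_j = x[edge_index[0]]  # features of the node where each edge starts
x_i = x[edge_index[1]]  # features of the node where each edge ends

# Assumed relational-model input: one row [x_i, x_j, edge_attr] per edge
message_input = torch.cat([x_i, x_j, edge_attr], dim=1)
```

Both edges end at node 1, so both rows of message_input start with 20.0, while the x_j column holds the two different source features 10.0 and 30.0.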

update(aggr_out: torch.Tensor, x: torch.Tensor) → torch.Tensor

Update the node embedding.

Parameters:
  • aggr_out – Messages of all incoming edges, aggregated per node

  • x – Features of the receiving node

Returns:

Updated node features/embedding
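With aggr='add', aggr_out holds, for every node, the sum of the messages of all edges ending at that node. The sketch below reproduces that aggregation with Tensor.index_add_ on hypothetical message values. The concatenation [x, aggr_out] as object-model input is an assumption, not confirmed by this page.

```python
import torch

# Hypothetical per-edge messages for edges 0->1, 2->1 and 1->0
messages = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
targets = torch.tensor([1, 1, 0])  # edge_index[1]: node where each edge ends
num_nodes = 3

# aggr="add": sum the messages of all edges ending at each node
aggr_out = messages.new_zeros((num_nodes, messages.size(1)))
aggr_out.index_add_(0, targets, messages)

# Assumed object-model input: node features next to their aggregated messages
x = torch.tensor([[0.1], [0.2], [0.3]])
update_input = torch.cat([x, aggr_out], dim=1)
```

Node 1 receives two edges, so its row is [1 + 3, 2 + 4] = [4, 6]. Node 2 receives none, so its aggregated message stays at zero.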