gnn_tracking.models.interaction_network
Module Contents
Classes
- class gnn_tracking.models.interaction_network.InteractionNetwork(*, node_indim: int, edge_indim: int, node_outdim=3, edge_outdim=4, node_hidden_dim=40, edge_hidden_dim=40, aggr='add')
Bases: torch_geometric.nn.MessagePassing, pytorch_lightning.core.mixins.HyperparametersMixin
Interaction Network, consisting of a relational model and an object model, both represented as MLPs.
- Parameters:
node_indim – Input node feature dimension
edge_indim – Input edge feature dimension
node_outdim – Output node feature dimension
edge_outdim – Output edge feature dimension
node_hidden_dim – Hidden dimension for the object model MLP
edge_hidden_dim – Hidden dimension for the relational model MLP
aggr – How the messages arriving at each node are aggregated (default 'add')
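Taken together, the parameters describe the usual interaction-network step. In the notation below (chosen here, not taken from the library), φ_R is the relational model and φ_O the object model. Feeding φ_R both endpoint features plus the edge features, and feeding φ_O a node's own features plus its aggregated messages, is an assumption about how the inputs are concatenated. Under that reading, φ_R maps 2·node_indim + edge_indim features to edge_outdim, and φ_O maps node_indim + edge_outdim features to node_outdim.

```latex
\begin{aligned}
e'_{ij} &= \phi_R\bigl([\,x_i,\; x_j,\; e_{ij}\,]\bigr) \\
a_i     &= \sum_{j \,:\, (j \to i) \in E} e'_{ij} \qquad \text{(aggr='add')} \\
x'_i    &= \phi_O\bigl([\,x_i,\; a_i\,]\bigr)
\end{aligned}
```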
- forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]
Forward pass
- Parameters:
x – Input node features
edge_index – Graph connectivity in COO format (PyTorch Geometric convention), shape [2, num_edges]
edge_attr – Input edge features
- Returns:
Output node embedding, output edge embedding
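The forward pass can be sketched end to end as below. This is a minimal NumPy sketch, not the library's implementation. `make_mlp` and `interaction_network_forward` are hypothetical helpers, and three things are assumptions: the two-layer MLP shape, the concatenation order, and returning the relational-model output as the edge embedding (which is consistent with the documented return value). NumPy stands in for torch / PyTorch Geometric so the snippet runs on its own.

```python
import numpy as np

rng = np.random.default_rng(0)


def make_mlp(in_dim: int, hidden_dim: int, out_dim: int):
    """Hypothetical two-layer MLP with random weights (stand-in for a trained model)."""
    w1 = rng.normal(size=(in_dim, hidden_dim))
    w2 = rng.normal(size=(hidden_dim, out_dim))
    return lambda z: np.maximum(z @ w1, 0.0) @ w2  # Linear -> ReLU -> Linear


def interaction_network_forward(x, edge_index, edge_attr, relational_model, object_model):
    """Sketch of one interaction-network step, assuming PyG's source_to_target flow."""
    src, dst = edge_index                      # row 0: source j, row 1: target i
    x_i, x_j = x[dst], x[src]                  # gather endpoint features per edge
    # Relational model: one message (edge embedding) per edge (assumed input order)
    edge_out = relational_model(np.concatenate([x_i, x_j, edge_attr], axis=1))
    # aggr='add': sum the messages arriving at each target node
    aggr_out = np.zeros((x.shape[0], edge_out.shape[1]))
    np.add.at(aggr_out, dst, edge_out)
    # Object model: new node embedding from own features + aggregated messages
    node_out = object_model(np.concatenate([x, aggr_out], axis=1))
    return node_out, edge_out


# Toy graph: 4 nodes, 3 directed edges (0->1, 1->2, 3->2)
node_indim, edge_indim = 5, 2
node_outdim, edge_outdim, hidden = 3, 4, 40   # defaults from the signature above
x = rng.normal(size=(4, node_indim))
edge_index = np.array([[0, 1, 3], [1, 2, 2]])
edge_attr = rng.normal(size=(3, edge_indim))

relational = make_mlp(2 * node_indim + edge_indim, hidden, edge_outdim)
objects = make_mlp(node_indim + edge_outdim, hidden, node_outdim)
node_out, edge_out = interaction_network_forward(x, edge_index, edge_attr, relational, objects)
print(node_out.shape, edge_out.shape)  # (4, 3) (3, 4)
```

With the default node_outdim=3 and edge_outdim=4, the two outputs have shapes (num_nodes, 3) and (num_edges, 4), one row per node and one row per edge respectively.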
- message(x_i: torch.Tensor, x_j: torch.Tensor, edge_attr: torch.Tensor) → torch.Tensor
Calculate message of an edge
- Parameters:
x_i – Features of the target node (where the edge ends)
x_j – Features of the source node (where the edge starts)
edge_attr – Edge features
- Returns:
Message
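To make the x_i / x_j convention concrete, here is a NumPy illustration of how per-edge endpoint features are gathered before a message is computed. It assumes PyTorch Geometric's default flow="source_to_target", under which edge_index[0] holds the source node j and edge_index[1] the target node i; NumPy stands in for torch so the snippet runs on its own.

```python
import numpy as np

# Node features: row k holds the features of node k (values chosen to be easy to read)
x = np.array([[0.0], [10.0], [20.0], [30.0]])

# PyG convention (flow="source_to_target"): edge_index[0] = source j, edge_index[1] = target i
edge_index = np.array([[0, 1, 3],
                       [1, 2, 2]])

x_j = x[edge_index[0]]  # features of the node where each edge starts
x_i = x[edge_index[1]]  # features of the node where each edge ends

for (j, i), xj, xi in zip(edge_index.T, x_j, x_i):
    print(f"edge {j}->{i}: x_j={xj[0]:.0f}, x_i={xi[0]:.0f}")
```

Each row of x_i and x_j lines up with one edge, so message can treat them as batched per-edge inputs alongside edge_attr.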
- update(aggr_out: torch.Tensor, x: torch.Tensor) → torch.Tensor
Update for node embedding
- Parameters:
aggr_out – Aggregated messages of all edges arriving at each node
x – Node features of the receiving nodes
- Returns:
Updated node features/embedding
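The aggregation that produces aggr_out for the default aggr='add' can be sketched as a scatter-sum over target nodes. This NumPy sketch is illustrative only: the message values are made up, and placing each node's own features next to its aggregated messages (the hypothetical update_input) is an assumption about what the object model receives.

```python
import numpy as np

num_nodes = 4
dst = np.array([1, 2, 2])                    # target node i of each edge (edge_index[1])
messages = np.array([[1.0, 2.0],             # message on edge 0->1
                     [3.0, 4.0],             # message on edge 1->2
                     [5.0, 6.0]])            # message on edge 3->2

# aggr='add': sum all messages arriving at each node; nodes without incoming edges get zeros
aggr_out = np.zeros((num_nodes, messages.shape[1]))
np.add.at(aggr_out, dst, messages)

# Hypothetical update input: each node's own features next to its aggregated messages,
# which the object model would then map to node_outdim features
x = np.ones((num_nodes, 3))
update_input = np.concatenate([x, aggr_out], axis=1)
print(aggr_out)
print(update_input.shape)
```

np.add.at is used rather than aggr_out[dst] += messages because fancy-indexed += performs only one write per repeated index, so node 2 would keep just one of its two incoming messages.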