gnn_tracking.models.resin

Deep stacked interaction networks with residual connections

Attributes

RESIDUAL_NETWORKS_BY_NAME

Classes

InteractionNetwork

Interaction Network, consisting of a relational model and an object model, both represented as MLPs.

ResidualNetwork

Apply a list of layers in sequence with residual connections for the nodes.

Skip1ResidualNetwork

A residual network in which any two successive layers are connected by a residual connection.

Skip2ResidualNetwork

A residual network built from blocks of two layers. Each of these blocks is connected to its predecessor by a residual connection.

SkipTopResidualNetwork

Residual network with skip connections to the top layer.

ResIN

Create a ResIN with identical layers of interaction networks.

Functions

_sqconvex_combination(*, delta, residue, alpha_residue) → torch.Tensor

Helper function for JIT compilation. Use sqconvex_combination instead.

sqconvex_combination(*, delta, residue, alpha_residue) → torch.Tensor

Convex combination of delta and residue.

Module Contents

class gnn_tracking.models.resin.InteractionNetwork(*, node_indim: int, edge_indim: int, node_outdim=3, edge_outdim=4, node_hidden_dim=40, edge_hidden_dim=40, aggr='add')

Bases: torch_geometric.nn.MessagePassing, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Interaction Network, consisting of a relational model and an object model, both represented as MLPs.

Parameters:
  • node_indim – Node feature dimension

  • edge_indim – Edge feature dimension

  • node_outdim – Output node feature dimension

  • edge_outdim – Output edge feature dimension

  • node_hidden_dim – Hidden dimension for the object model MLP

  • edge_hidden_dim – Hidden dimension for the relational model MLP

  • aggr – How to aggregate the messages
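The dimension parameters fix the shapes of the two MLPs. The following sketch assumes the textbook interaction-network layout, in which the relational model reads the features of both endpoint nodes together with the edge features, and the object model reads a node's own features together with its aggregated messages. The two-layer depth, the ReLU activation, and the helper name build_in_mlps are illustrative assumptions, not the library's code.

```python
import torch
from torch import nn


def build_in_mlps(
    node_indim: int,
    edge_indim: int,
    node_outdim: int = 3,
    edge_outdim: int = 4,
    node_hidden_dim: int = 40,
    edge_hidden_dim: int = 40,
) -> tuple[nn.Module, nn.Module]:
    """Hypothetical helper returning the relational and object MLPs."""
    # Relational model: features of both endpoint nodes plus the edge features.
    relational_model = nn.Sequential(
        nn.Linear(2 * node_indim + edge_indim, edge_hidden_dim),
        nn.ReLU(),
        nn.Linear(edge_hidden_dim, edge_outdim),
    )
    # Object model: a node's own features plus its aggregated messages.
    object_model = nn.Sequential(
        nn.Linear(node_indim + edge_outdim, node_hidden_dim),
        nn.ReLU(),
        nn.Linear(node_hidden_dim, node_outdim),
    )
    return relational_model, object_model


relational_model, object_model = build_in_mlps(node_indim=7, edge_indim=2)
print(relational_model[0].in_features, object_model[0].in_features)  # 16 11
```

Under these assumptions, node_indim=7 and edge_indim=2 give the relational model 2 · 7 + 2 = 16 inputs and, with the default edge_outdim=4, the object model 7 + 4 = 11 inputs.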

forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]

Forward pass

Parameters:
  • x – Input node features

  • edge_index – Edge index of shape (2, n_edges)

  • edge_attr – Input edge features

Returns:

Output node embedding, output edge embedding

message(x_i: torch.Tensor, x_j: torch.Tensor, edge_attr: torch.Tensor) → torch.Tensor

Calculate message of an edge

Parameters:
  • x_i – Features of node 1 (node where the edge ends)

  • x_j – Features of node 2 (node where the edge starts)

  • edge_attr – Edge features

Returns:

Message
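With PyTorch Geometric's default message flow ("source_to_target"), MessagePassing fills x_j with the features of the source nodes (row 0 of edge_index) and x_i with those of the target nodes (row 1). The sketch below performs this gathering by hand for a toy graph and feeds the concatenation into a small MLP standing in for the relational model; the concatenation order and the MLP itself are assumptions for illustration.

```python
import torch
from torch import nn

# Toy graph: 3 nodes with 2 features each, and two directed edges 0 -> 1 and 2 -> 1.
x = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
edge_index = torch.tensor([[0, 2], [1, 1]])  # row 0: sources, row 1: targets
edge_attr = torch.tensor([[0.5], [0.7]])

# Gather endpoint features per edge, as MessagePassing does for source_to_target flow.
x_j = x[edge_index[0]]  # node where the edge starts
x_i = x[edge_index[1]]  # node where the edge ends

# Hypothetical relational model mapping [x_i, x_j, edge_attr] to a 4-dim message.
relational_model = nn.Sequential(nn.Linear(2 + 2 + 1, 8), nn.ReLU(), nn.Linear(8, 4))
message = relational_model(torch.cat([x_i, x_j, edge_attr], dim=1))
print(message.shape)  # torch.Size([2, 4])
```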

update(aggr_out: torch.Tensor, x: torch.Tensor) → torch.Tensor

Update for node embedding

Parameters:
  • aggr_out – Aggregated messages of all edges

  • x – Node features for the node that receives all edges

Returns:

Updated node features/embedding
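For the default aggr='add', aggr_out holds, for every node, the sum of the messages of all edges ending at that node. The sketch below reproduces this sum with Tensor.index_add_ and then applies a small MLP standing in for the object model to the concatenation of x and aggr_out; the MLP and the concatenation order are illustrative assumptions.

```python
import torch
from torch import nn

n_nodes, node_dim, msg_dim = 3, 2, 4
x = torch.randn(n_nodes, node_dim)
edge_index = torch.tensor([[0, 2], [1, 1]])  # edges 0 -> 1 and 2 -> 1
messages = torch.ones(2, msg_dim)  # stand-in for the relational model output

# aggr='add': sum the messages of all edges that end at the same target node.
aggr_out = torch.zeros(n_nodes, msg_dim).index_add_(0, edge_index[1], messages)

# Hypothetical object model updating each node from [x, aggr_out].
object_model = nn.Sequential(nn.Linear(node_dim + msg_dim, 8), nn.ReLU(), nn.Linear(8, 3))
x_new = object_model(torch.cat([x, aggr_out], dim=1))
print(aggr_out[1])  # every entry is 2.0: node 1 receives two messages of ones
```

Nodes 0 and 2 receive no edges, so their aggregated messages stay zero and the object model updates them from their own features alone.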

gnn_tracking.models.resin._sqconvex_combination(*, delta: torch.Tensor, residue: torch.Tensor, alpha_residue: float) → torch.Tensor

Helper function for JIT compilation. Use sqconvex_combination instead.

gnn_tracking.models.resin.sqconvex_combination(*, delta: torch.Tensor, residue: torch.Tensor | None, alpha_residue: float) → torch.Tensor

Convex combination of delta and residue
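The documentation only calls this a convex combination. The sketch below assumes, without confirmation from the source, that the "sq" prefix means the inputs are weighted by sqrt(alpha_residue) and sqrt(1 - alpha_residue), so that the squared weights sum to one; it also assumes that residue=None returns delta unchanged, which the Optional type of residue suggests. The function name is hypothetical. If the real helper uses the plain weights alpha_residue and 1 - alpha_residue, drop the math.sqrt calls.

```python
import math
from typing import Optional

import torch


def sqconvex_combination_sketch(
    *, delta: torch.Tensor, residue: Optional[torch.Tensor], alpha_residue: float
) -> torch.Tensor:
    """Hypothetical square-root-weighted mix of delta and residue."""
    # Without a residue there is nothing to mix in, so pass delta through.
    if residue is None:
        return delta
    if not 0.0 <= alpha_residue <= 1.0:
        raise ValueError("alpha_residue must lie in [0, 1]")
    # Squared weights sum to one: alpha_residue + (1 - alpha_residue) == 1.
    return math.sqrt(alpha_residue) * residue + math.sqrt(1.0 - alpha_residue) * delta


residue = torch.ones(2, 3)
delta = torch.zeros(2, 3)
out = sqconvex_combination_sketch(delta=delta, residue=residue, alpha_residue=0.25)
print(out)  # every entry is sqrt(0.25) = 0.5
```

One motivation for square-root weights is that, for independent inputs of equal variance, the variance of the combination stays the same as that of each input regardless of alpha_residue.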

class gnn_tracking.models.resin.ResidualNetwork(layers: list[torch.nn.Module], *, alpha: float = 0.5, collect_hidden_edge_embeds: bool = False)

Bases: abc.ABC, torch.nn.Module

Apply a list of layers in sequence with residual connections for the nodes. This is an abstract base class: it does not implement any particular pattern of residual connections, which is left to the subclasses (via _forward).

Use one of the subclasses below or use ResIN (a convenience wrapper around the subclasses for layers of identical INs).

Parameters:
  • layers – List of layers

  • alpha – Strength of the node embedding residual connection

  • collect_hidden_edge_embeds – Whether to collect the edge embeddings from all layers (can be set to False to save memory)

forward(x, edge_index, edge_attr) → tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]

Forward pass

Parameters:
  • x – Node features

  • edge_index – Edge index of shape (2, n_edges)

  • edge_attr – Edge features

Returns:

Node embedding, edge embedding, and the list of edge embeddings from all levels including edge_attr (None unless collect_hidden_edge_embeds is True)

abstract _forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) → tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]

class gnn_tracking.models.resin.Skip1ResidualNetwork(*args, **kwargs)

Bases: ResidualNetwork

A residual network in which any two successive layers are connected by a residual connection.
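The loop below sketches the Skip1 pattern: after every layer, the new node features are mixed with the node features that entered that layer, while edge features are passed on without a residual connection. Layers are modelled as plain callables returning (node_delta, edge_out), the square-root mixing rule is an assumption (see the sqconvex_combination entry), and skip1_forward and toy_layer are hypothetical names.

```python
import math
from typing import Callable, Optional

import torch

# A "layer" is any callable (x, edge_index, edge_attr) -> (node_delta, edge_out).
Layer = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]


def combine(delta: torch.Tensor, residue: torch.Tensor, alpha: float) -> torch.Tensor:
    # Assumed square-root weighting of residue and delta.
    return math.sqrt(alpha) * residue + math.sqrt(1 - alpha) * delta


def skip1_forward(
    layers: list[Layer],
    x: torch.Tensor,
    edge_index: torch.Tensor,
    edge_attr: torch.Tensor,
    alpha: float = 0.5,
    collect_hidden_edge_embeds: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[list[torch.Tensor]]]:
    edge_embeds = [edge_attr]  # the input edge features count as level 0
    for layer in layers:
        delta, edge_attr = layer(x, edge_index, edge_attr)
        # Residual connection between every two successive layers (nodes only).
        x = combine(delta, x, alpha)
        edge_embeds.append(edge_attr)
    return x, edge_attr, (edge_embeds if collect_hidden_edge_embeds else None)


def toy_layer(x, edge_index, edge_attr):
    # Stand-in for an interaction network: shifts nodes, doubles edge features.
    return x + 1, 2 * edge_attr


x_out, e_out, hidden = skip1_forward(
    [toy_layer, toy_layer],
    torch.zeros(3, 2),
    torch.tensor([[0, 1], [1, 2]]),
    torch.ones(2, 1),
    collect_hidden_edge_embeds=True,
)
```

With alpha=0.5 both weights equal sqrt(0.5), so after the two toy layers every node feature equals 1 + sqrt(0.5), and hidden holds three edge tensors: the input and one per layer.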

_forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) → tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]

class gnn_tracking.models.resin.Skip2ResidualNetwork(layers: list[torch.nn.Module], *, node_dim: int, edge_dim: int, add_bn: bool = False, **kwargs)

Bases: ResidualNetwork

A residual network built from blocks of two layers. Each of these blocks is connected to its predecessor by a residual connection.

Parameters:
  • layers – List of layers

  • node_dim – Node feature dimension

  • edge_dim – Edge feature dimension

  • add_bn – Add batch norms

  • **kwargs – Arguments to ResidualNetwork
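A minimal sketch of the Skip2 pattern, assuming an even number of layers: the input of each two-layer block is kept and mixed with the block's output, so there is one residual connection per block instead of one per layer. The batch norms enabled by add_bn (presumably the reason node_dim and edge_dim are required) are left out, the mixing rule is the same square-root assumption as for sqconvex_combination, and all names are hypothetical.

```python
import math

import torch


def combine(delta, residue, alpha):
    # Assumed square-root weighting of residue and delta.
    return math.sqrt(alpha) * residue + math.sqrt(1 - alpha) * delta


def skip2_forward(layers, x, edge_index, edge_attr, alpha=0.5):
    """Hypothetical sketch: one residual connection per block of two layers."""
    if len(layers) % 2:
        raise ValueError("this sketch expects an even number of layers")
    for first, second in zip(layers[::2], layers[1::2]):
        residue = x  # block input, reused by the skip connection
        hidden, edge_attr = first(x, edge_index, edge_attr)
        delta, edge_attr = second(hidden, edge_index, edge_attr)
        x = combine(delta, residue, alpha)
    return x, edge_attr


def toy_layer(x, edge_index, edge_attr):
    # Stand-in for an interaction network: shifts nodes, doubles edge features.
    return x + 1, 2 * edge_attr


x_out, e_out = skip2_forward(
    [toy_layer] * 4, torch.zeros(3, 2), torch.tensor([[0, 1], [1, 2]]), torch.ones(2, 1)
)
```

With four toy layers (two blocks) and alpha=0.5, the first block yields 2 · sqrt(0.5) and the second 2 + 2 · sqrt(0.5).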

_forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) → tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]

class gnn_tracking.models.resin.SkipTopResidualNetwork(layers: list[torch.nn.Module], connect_to=1, **kwargs)

Bases: ResidualNetwork

Residual network with skip connections to the top layer.

Parameters:
  • layers – List of layers

  • connect_to – Layer to which to add the skip connection. 0 means to the input, 1 means to the output of the first layer, etc.

  • **kwargs – Arguments to ResidualNetwork
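The connect_to convention above (0 for the input, k for the output of layer k) can be sketched by recording every intermediate node state and mixing the chosen one into the top layer's output. This assumes a single skip connection and the same square-root mixing rule as in the sqconvex_combination entry; skip_top_forward and toy_layer are hypothetical names, not the library's implementation.

```python
import math

import torch


def combine(delta, residue, alpha):
    # Assumed square-root weighting of residue and delta.
    return math.sqrt(alpha) * residue + math.sqrt(1 - alpha) * delta


def skip_top_forward(layers, x, edge_index, edge_attr, connect_to=1, alpha=0.5):
    """Hypothetical sketch: one skip connection from level connect_to to the top."""
    if not 0 <= connect_to <= len(layers):
        raise ValueError("connect_to must be between 0 and len(layers)")
    hidden = [x]  # hidden[0] is the input, hidden[k] the output of layer k
    for layer in layers:
        x, edge_attr = layer(x, edge_index, edge_attr)
        hidden.append(x)
    return combine(hidden[-1], hidden[connect_to], alpha), edge_attr


def toy_layer(x, edge_index, edge_attr):
    # Stand-in for an interaction network: shifts nodes, doubles edge features.
    return x + 1, 2 * edge_attr


args = (torch.zeros(3, 2), torch.tensor([[0, 1], [1, 2]]), torch.ones(2, 1))
top_from_first, _ = skip_top_forward([toy_layer] * 3, *args, connect_to=1)
top_from_input, _ = skip_top_forward([toy_layer] * 3, *args, connect_to=0)
```

Here the three toy layers produce node states 0, 1, 2, 3, so connect_to=1 mixes 1 into the top output 3 (giving 4 · sqrt(0.5)), while connect_to=0 mixes in the zero input (giving 3 · sqrt(0.5)).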

_forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) → tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]

gnn_tracking.models.resin.RESIDUAL_NETWORKS_BY_NAME: dict[str, Any]

class gnn_tracking.models.resin.ResIN(*, node_dim: int, edge_dim: int, object_hidden_dim=40, relational_hidden_dim=40, alpha: float = 0.5, n_layers=1, residual_type: str = 'skip1', residual_kwargs: dict | None = None)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Create a ResIN with identical layers of interaction networks.

Parameters:
  • node_dim – Node feature dimension

  • edge_dim – Edge feature dimension

  • object_hidden_dim – Hidden dimension for the object model MLP

  • relational_hidden_dim – Hidden dimension for the relational model MLP

  • alpha – Strength of the node residual connection

  • n_layers – Total number of layers

  • residual_type – Type of residual network. Options are ‘skip1’, ‘skip2’, ‘skip_top’.

  • residual_kwargs – Additional arguments to the residual network (can depend on the residual type)

property concat_edge_embeddings_length: int

Length of the concatenated edge embeddings from all intermediate layers, in other words self.forward()[3].shape[1].

forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) → tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[torch.Tensor] | None]