gnn_tracking.models.resin
Deep stacked interaction networks with residual connections
Module Contents
Classes
ResidualNetwork | Apply a list of layers in sequence with residual connections for the nodes.
Skip1ResidualNetwork | A residual network in which any two successive layers are connected by a residual connection.
Skip2ResidualNetwork | A residual network built from blocks of two layers. Each of these blocks is connected to its predecessor by a residual connection.
SkipTopResidualNetwork | Residual network with skip connections to the top layer.
ResIN | Create a ResIN with identical layers of interaction networks.
Functions
_sqconvex_combination | Helper function for JIT compilation. Use sqconvex_combination instead.
sqconvex_combination | Convex combination of delta and residue.
Attributes
RESIDUAL_NETWORKS_BY_NAME
- gnn_tracking.models.resin._sqconvex_combination(*, delta: torch.Tensor, residue: torch.Tensor, alpha_residue: float) -> torch.Tensor
Helper function for JIT compilation. Use sqconvex_combination instead.
- gnn_tracking.models.resin.sqconvex_combination(*, delta: torch.Tensor, residue: torch.Tensor | None, alpha_residue: float) -> torch.Tensor
Convex combination of delta and residue, weighted by alpha_residue.
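This page does not state the exact weighting that sqconvex_combination applies. As a minimal, framework-free sketch, assume the plain form alpha_residue * residue + (1 - alpha_residue) * delta on Python floats. The `sq` prefix suggests the library may weight the terms differently (for example with square roots), which is not confirmed here. The name `convex_combination_sketch` and the handling of `residue=None` are assumptions for illustration, not part of gnn_tracking.

```python
from __future__ import annotations


def convex_combination_sketch(
    *, delta: list[float], residue: list[float] | None, alpha_residue: float
) -> list[float]:
    """Element-wise alpha_residue * residue + (1 - alpha_residue) * delta.

    Hypothetical stand-in, not the gnn_tracking implementation.
    """
    if not 0.0 <= alpha_residue <= 1.0:
        raise ValueError("alpha_residue must lie in [0, 1]")
    if residue is None:
        # Assumption: with no residue available, delta passes through unchanged.
        return list(delta)
    if len(residue) != len(delta):
        raise ValueError("delta and residue must have the same length")
    return [
        alpha_residue * r + (1.0 - alpha_residue) * d
        for r, d in zip(residue, delta)
    ]


out = convex_combination_sketch(
    delta=[2.0, 4.0], residue=[0.0, 8.0], alpha_residue=0.5
)
```

With alpha_residue = 0 the sketch returns delta, and with alpha_residue = 1 it returns the residue, which is the sense in which alpha sets the "strength" of the residual connection.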
- class gnn_tracking.models.resin.ResidualNetwork(layers: list[torch.nn.Module], *, alpha: float = 0.5, collect_hidden_edge_embeds: bool = False)
Bases: abc.ABC, torch.nn.Module
Apply a list of layers in sequence with residual connections for the nodes. This abstract base class does not implement any particular type of residual connection.
Use one of the subclasses below, or use ResIN (a convenience wrapper around the subclasses for layers of identical INs).
- Parameters:
layers – List of layers
alpha – Strength of the node embedding residual connection
collect_hidden_edge_embeds – Whether to collect the edge embeddings from all layers (set to False to save memory)
- forward(x, edge_index, edge_attr) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
Forward pass
- Parameters:
x – Node features
edge_index – Edge indices
edge_attr – Edge features
- Returns:
node embedding, edge embedding, concatenated edge embeddings from all levels (including edge_attr, unless collect_hidden_edge_embeds is False)
- abstract _forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]
- class gnn_tracking.models.resin.Skip1ResidualNetwork(*args, **kwargs)
Bases: ResidualNetwork
A residual network in which any two successive layers are connected by a residual connection.
- _forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]
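The wiring described for Skip1ResidualNetwork, a residual connection around every single layer, can be illustrated without any framework. The sketch below is a simplification under stated assumptions: node states are plain floats, each layer is a callable on that float (the real layers also take edge_index and edge_attr), and the combination is a plain convex combination with alpha weighting the previous state. `skip1_forward_sketch` and `Layer` are hypothetical names, not part of gnn_tracking.

```python
from typing import Callable, Sequence

# Hypothetical stand-in for a layer: maps a node state to a proposed update.
Layer = Callable[[float], float]


def skip1_forward_sketch(
    x: float, layers: Sequence[Layer], alpha: float = 0.5
) -> float:
    """Residual connection around every single layer (illustration only)."""
    for layer in layers:
        delta = layer(x)
        # Assumed plain convex combination; alpha weights the previous state.
        x = alpha * x + (1.0 - alpha) * delta
    return x


layers = [lambda v: v + 2.0, lambda v: 3.0 * v]
result = skip1_forward_sketch(1.0, layers)
```

Setting alpha to 0 in this sketch reduces it to plain sequential application of the layers, while alpha = 1 leaves the input unchanged.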
- class gnn_tracking.models.resin.Skip2ResidualNetwork(layers: list[torch.nn.Module], *, node_dim: int, edge_dim: int, add_bn: bool = False, **kwargs)
Bases: ResidualNetwork
A residual network built from blocks of two layers. Each of these blocks is connected to its predecessor by a residual connection.
- Parameters:
layers – List of layers
node_dim – Node feature dimension
edge_dim – Edge feature dimension
add_bn – Whether to add batch norm layers
**kwargs – Additional keyword arguments passed to ResidualNetwork
- _forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]
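To make the block structure of Skip2ResidualNetwork concrete, the following framework-free sketch groups layers into pairs and places one residual connection around each pair. It rests on several assumptions: node states are plain floats, layers are simple callables, the combination is a plain convex combination, the number of layers must be even, and batch norm (add_bn) is left out entirely. `skip2_forward_sketch` and `Layer` are hypothetical names, not part of gnn_tracking.

```python
from typing import Callable, Sequence

# Hypothetical stand-in for a layer: maps a node state to a new node state.
Layer = Callable[[float], float]


def skip2_forward_sketch(
    x: float, layers: Sequence[Layer], alpha: float = 0.5
) -> float:
    """One residual connection per block of two layers (illustration only)."""
    if len(layers) % 2 != 0:
        # Assumption of this sketch: layers come in complete pairs.
        raise ValueError("expected an even number of layers")
    for i in range(0, len(layers), 2):
        # Run both layers of the block back to back, without a residual between them.
        delta = layers[i + 1](layers[i](x))
        # Tie the block output back to the block input.
        x = alpha * x + (1.0 - alpha) * delta
    return x


layers = [lambda v: v + 2.0, lambda v: 3.0 * v, lambda v: v - 1.0, lambda v: 2.0 * v]
result = skip2_forward_sketch(1.0, layers)
```

Compared with a residual around every layer, this arrangement lets each block apply two transformations before its output is mixed with the block input.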
- class gnn_tracking.models.resin.SkipTopResidualNetwork(layers: list[torch.nn.Module], connect_to=1, **kwargs)
Bases: ResidualNetwork
Residual network with skip connections to the top layer.
- Parameters:
layers – List of layers
connect_to – Layer to which to add the skip connection. 0 means to the input, 1 means to the output of the first layer, etc.
**kwargs – Additional keyword arguments passed to ResidualNetwork
- _forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
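The description of SkipTopResidualNetwork leaves the exact wiring open. One plausible reading, sketched below without any framework, is that the state at position connect_to is stored once and every later layer's output is mixed with that stored state. This reading is an assumption, as are the float node states, the plain convex combination, and the names `skip_top_forward_sketch` and `Layer`, none of which come from gnn_tracking.

```python
from typing import Callable, Sequence

# Hypothetical stand-in for a layer: maps a node state to a new node state.
Layer = Callable[[float], float]


def skip_top_forward_sketch(
    x: float, layers: Sequence[Layer], connect_to: int = 1, alpha: float = 0.5
) -> float:
    """Tie every layer after `connect_to` back to one stored state (assumed reading)."""
    if not 0 <= connect_to <= len(layers):
        raise ValueError("connect_to must be between 0 and len(layers)")
    # Layers before the connection point run without any residual connection.
    for layer in layers[:connect_to]:
        x = layer(x)
    # connect_to=0 stores the input, connect_to=1 the output of the first layer.
    anchor = x
    for layer in layers[connect_to:]:
        x = alpha * anchor + (1.0 - alpha) * layer(x)
    return x


layers = [lambda v: v + 2.0, lambda v: 3.0 * v, lambda v: v - 1.0]
result = skip_top_forward_sketch(1.0, layers, connect_to=1)
```

Under this reading, all skip connections share a single source, in contrast to designs where each layer or block is connected to its immediate predecessor.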
- gnn_tracking.models.resin.RESIDUAL_NETWORKS_BY_NAME: dict[str, Any]
- class gnn_tracking.models.resin.ResIN(*, node_dim: int, edge_dim: int, object_hidden_dim=40, relational_hidden_dim=40, alpha: float = 0.5, n_layers=1, residual_type: str = 'skip1', residual_kwargs: dict | None = None)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.HyperparametersMixin
Create a ResIN with identical layers of interaction networks.
- Parameters:
node_dim – Node feature dimension
edge_dim – Edge feature dimension
object_hidden_dim – Hidden dimension for the object model MLP
relational_hidden_dim – Hidden dimension for the relational model MLP
alpha – Strength of the node residual connection
n_layers – Total number of layers
residual_type – Type of residual network. Options are 'skip1', 'skip2', 'skip_top'.
residual_kwargs – Additional arguments to the residual network (can depend on the residual type)
- property concat_edge_embeddings_length: int
Length of the concatenated edge embeddings from all intermediate layers, in other words self.forward()[3].shape[1].
- forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[torch.Tensor] | None]