gnn_tracking.models.graph_construction

Models for embeddings used for graph construction.

Attributes

logger

Classes

MLP

Multilayer perceptron using ReLU as the activation function.

HeterogeneousResFCNN

Separate FCNNs for pixel and strip data, with residual connections.

ResFCNN

Fully connected NN with residual connections.

ResIN

Create a ResIN with identical layers of interaction networks.

GraphConstructionFCNN

Fully connected neural network for graph construction.

GraphConstructionHeteroResFCNN

Fully connected neural network for graph construction.

GraphConstructionHeteroEncResFCNN

Fully connected neural network for graph construction.

GraphConstructionResIN

Graph construction refinement with a stack of interaction networks with residual connections between them.

MLGraphConstruction

Builds a graph from the embedding space. To start from a checkpoint, use MLGraphConstruction.from_chkpt.

MLGraphConstructionFromChkpt

Alias for MLGraphConstruction.from_chkpt for use in YAML files.

MLPCTransformer

Transforms a point cloud (PC) using a metric learning (ML) model.

MLPCTransformerFromMLChkpt

Transforms a point cloud (PC) using a metric learning (ML) model.

Functions

assert_feat_dim(→ None)

get_model(→ torch.nn.Module | None)

Get torch model (specified by class_path, a string) and load a checkpoint.

obj_from_or_to_hparams(→ Any)

Used to support initializing Python objects from hyperparameters.

freeze_if(→ torch.nn.Module | None)

Freezes all parameters of a model if do_freeze is True. If model is None, None is returned.

knn_with_max_radius(→ torch.Tensor)

A version of kNN that excludes edges with a distance larger than a given radius.

Module Contents

class gnn_tracking.models.graph_construction.MLP(input_size: int, output_size: int, hidden_dim: int | None, L=3, *, bias=True, include_last_activation=False)

Bases: torch.nn.Module

Multilayer perceptron using ReLU as the activation function.

Parameters:
  • input_size – Input feature dimension

  • output_size – Output feature dimension

  • hidden_dim – Feature dimension of the hidden layers. If None, the maximum of input_size and output_size is used.

  • L – Total number of layers (1 initial layer, L-2 hidden layers, 1 output layer)

  • bias – Whether to include a bias term in the linear layers

  • include_last_activation – Whether to apply the activation function after the last layer
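The layer layout implied by hidden_dim, L and include_last_activation can be spelled out without torch. The helper describe_mlp below is hypothetical (not part of gnn_tracking) and only lists the layers the parameters describe; it assumes L >= 2, as the documented split into an initial, hidden and output layer suggests, and that a ReLU follows every linear layer except possibly the last.

```python
from typing import List, Optional


def describe_mlp(
    input_size: int,
    output_size: int,
    hidden_dim: Optional[int] = None,
    L: int = 3,
    include_last_activation: bool = False,
) -> List[str]:
    """List the layers implied by the MLP parameters (hypothetical helper)."""
    if L < 2:
        # Assumption: the documented layout needs an initial and an output layer
        raise ValueError("L must be at least 2")
    if hidden_dim is None:
        # Documented default: maximum of input and output size
        hidden_dim = max(input_size, output_size)
    # 1 initial layer, L-2 hidden layers, 1 output layer
    shapes = [(input_size, hidden_dim)]
    shapes += [(hidden_dim, hidden_dim)] * (L - 2)
    shapes.append((hidden_dim, output_size))

    layers = []
    for i, (n_in, n_out) in enumerate(shapes):
        layers.append(f"Linear({n_in}, {n_out})")
        is_last = i == len(shapes) - 1
        # ReLU after every layer except the last, unless explicitly requested
        if not is_last or include_last_activation:
            layers.append("ReLU")
    return layers


print(describe_mlp(3, 8))
# ['Linear(3, 8)', 'ReLU', 'Linear(8, 8)', 'ReLU', 'Linear(8, 8)']
```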

reset_parameters()
forward(x)
class gnn_tracking.models.graph_construction.HeterogeneousResFCNN(*, in_dim: int, out_dim: int, hidden_dim: int, depth: int, alpha: float = 0.6, bias: bool = True)

Bases: torch.nn.Module

Separate FCNNs for pixel and strip data, with residual connections. For parameters, see ResFCNN.
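Conceptually, each hit goes through one of two networks depending on its detector type. The sketch below shows only that routing step, with plain callables standing in for the two FCNNs. How HeterogeneousResFCNN derives the pixel/strip split from its layer tensor is not documented here, so the boolean is_pixel mask and the name route_hits are assumptions.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def route_hits(
    x: Sequence[Vector],
    is_pixel: Sequence[bool],
    pixel_net: Callable[[Vector], Vector],
    strip_net: Callable[[Vector], Vector],
) -> List[Vector]:
    """Apply pixel_net to pixel hits and strip_net to strip hits, preserving hit order."""
    if len(x) != len(is_pixel):
        raise ValueError("x and is_pixel must have the same length")
    return [
        pixel_net(list(row)) if pixel else strip_net(list(row))
        for row, pixel in zip(x, is_pixel)
    ]


def double(v: Vector) -> Vector:  # stand-in for the pixel FCNN
    return [2 * a for a in v]


def negate(v: Vector) -> Vector:  # stand-in for the strip FCNN
    return [-a for a in v]


print(route_hits([[1.0, 2.0], [3.0, 4.0]], [True, False], double, negate))
# [[2.0, 4.0], [-3.0, -4.0]]
```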

forward(x: torch.Tensor, layer: torch.Tensor) → torch.Tensor
class gnn_tracking.models.graph_construction.ResFCNN(*, in_dim: int, hidden_dim: int, out_dim: int, depth: int, alpha: float = 0.6, bias: bool = True)

Bases: torch.nn.Module

Fully connected NN with residual connections.

Parameters:
  • in_dim – Input dimension

  • hidden_dim – Hidden dimension

  • out_dim – Output dimension (i.e., the dimension of the embedding space)

  • depth – 1 input encoder layer, depth-1 hidden layers, and 1 output layer

  • alpha – Strength of the residual connections
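The meaning of depth and alpha can be illustrated with two small helpers. Both are hypothetical and not part of gnn_tracking: resfcnn_layer_shapes lists the linear layers that the documented depth implies, and residual_mix shows one common way a residual strength is applied, as a convex combination of a layer's input and output. The exact weighting used inside ResFCNN is not stated here and may differ.

```python
from typing import List, Sequence, Tuple


def resfcnn_layer_shapes(in_dim: int, hidden_dim: int, out_dim: int, depth: int) -> List[Tuple[int, int]]:
    """(in, out) sizes of the linear layers implied by `depth` (hypothetical helper)."""
    if depth < 1:
        raise ValueError("depth must be at least 1")
    return (
        [(in_dim, hidden_dim)]                      # input encoder layer
        + [(hidden_dim, hidden_dim)] * (depth - 1)  # depth-1 hidden layers
        + [(hidden_dim, out_dim)]                   # output layer (embedding space)
    )


def residual_mix(x: Sequence[float], fx: Sequence[float], alpha: float) -> List[float]:
    """Blend a hidden layer's input x with its output fx.

    Assumed convention: alpha=1 keeps only the skip path, alpha=0 only the layer output.
    """
    return [alpha * a + (1 - alpha) * b for a, b in zip(x, fx)]


print(resfcnn_layer_shapes(in_dim=14, hidden_dim=64, out_dim=8, depth=3))
# [(14, 64), (64, 64), (64, 64), (64, 8)]
```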

static _reset_layer_parameters(layer, var: float)
forward(x: torch.Tensor, **ignore) → torch.Tensor
class gnn_tracking.models.graph_construction.ResIN(*, node_dim: int, edge_dim: int, object_hidden_dim=40, relational_hidden_dim=40, alpha: float = 0.5, n_layers=1, residual_type: str = 'skip1', residual_kwargs: dict | None = None)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Create a ResIN with identical layers of interaction networks.

Parameters:
  • node_dim – Node feature dimension

  • edge_dim – Edge feature dimension

  • object_hidden_dim – Hidden dimension for the object model MLP

  • relational_hidden_dim – Hidden dimension for the relational model MLP

  • alpha – Strength of the node residual connection

  • n_layers – Total number of layers

  • residual_type – Type of residual network. Options are ‘skip1’, ‘skip2’, ‘skip_top’.

  • residual_kwargs – Additional arguments to the residual network (can depend on the residual type)

property concat_edge_embeddings_length: int

Length of the concatenated edge embeddings from all intermediate layers, i.e., the size of dimension 1 once the edge embeddings returned as self.forward()[3] are concatenated.

forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) → tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[torch.Tensor] | None]
gnn_tracking.models.graph_construction.assert_feat_dim(feat_vec: torch.Tensor, dim: int) → None
gnn_tracking.models.graph_construction.get_model(class_path: str, chkpt_path: str = '', freeze: bool = False, whole_module: bool = False, device: None | str = None) → torch.nn.Module | None

Get torch model (specified by class_path, a string) and load a checkpoint. Uses get_lightning_module to get the model.

Parameters:
  • class_path – The path to the lightning module class

  • chkpt_path – The path to the checkpoint. If no checkpoint is specified, None is returned.

  • freeze – Whether to freeze the model

  • whole_module – Whether to return the whole lightning module or just the model

  • device – Device to load the model onto

gnn_tracking.models.graph_construction.obj_from_or_to_hparams(self: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin, key: str, obj: Any) → Any

Used to support initializing Python objects from hyperparameters. If obj is a Python object other than a dictionary, its hyperparameters (its class path and init args) are saved to self.hparams[key]. If obj is instead a dictionary, it is assumed that an object has to be restored from this information.
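A minimal sketch of the save/restore logic described above, assuming that hparams is a plain dict, that the saved information uses the keys class_path and init_args, and that objects expose their init args through an hparams attribute (as HyperparametersMixin subclasses do). The function name obj_from_or_to_hparams_sketch is hypothetical, and gnn_tracking's implementation may differ in these details.

```python
import importlib
from typing import Any, Dict


def obj_from_or_to_hparams_sketch(hparams: Dict[str, Any], key: str, obj: Any) -> Any:
    """Save obj's class path and init args to hparams[key], or rebuild an object from a dict."""
    if isinstance(obj, dict):
        # Restore: import the class named by class_path and call it with init_args.
        # Assumes a top-level class (nested qualnames would not round-trip).
        module_name, _, class_name = obj["class_path"].rpartition(".")
        cls = getattr(importlib.import_module(module_name), class_name)
        return cls(**obj.get("init_args", {}))
    # Save: record the class path and init args so the object can be rebuilt later
    cls = type(obj)
    hparams[key] = {
        "class_path": f"{cls.__module__}.{cls.__qualname__}",
        "init_args": dict(getattr(obj, "hparams", {})),
    }
    return obj


# Restoring from a dict works for any importable class, e.g. a stdlib one
hp: Dict[str, Any] = {}
delta = obj_from_or_to_hparams_sketch(
    hp, "delta", {"class_path": "datetime.timedelta", "init_args": {"days": 2}}
)
print(delta)  # 2 days, 0:00:00
```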

gnn_tracking.models.graph_construction.logger
gnn_tracking.models.graph_construction.freeze_if(model: torch.nn.Module | None, do_freeze: bool = False) → torch.nn.Module | None

Freezes all parameters of a model if do_freeze is True. If model is None, None is returned. This is a trivial convenience function to avoid if-else statements.

Returns:

The model with all parameters frozen (but model is also modified in-place).
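The behavior described above fits in a few lines. This is a sketch under the assumption that freezing means setting requires_grad to False on every item of model.parameters(), as one would for a torch.nn.Module; the name freeze_if_sketch is hypothetical and the library version may differ.

```python
from typing import Optional, TypeVar

M = TypeVar("M")


def freeze_if_sketch(model: Optional[M], do_freeze: bool = False) -> Optional[M]:
    """Disable gradients for all parameters of model if do_freeze is True."""
    if model is None:
        return None
    if do_freeze:
        # In-place: the caller's model object is modified as well
        for param in model.parameters():
            param.requires_grad = False
    return model


# With torch installed, e.g.: freeze_if_sketch(torch.nn.Linear(4, 2), do_freeze=True)
```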

class gnn_tracking.models.graph_construction.GraphConstructionFCNN(*, in_dim: int, hidden_dim: int, out_dim: int, depth: int, alpha: float = 0.6)

Bases: gnn_tracking.models.mlp.ResFCNN, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Fully connected neural network for graph construction. Contains an additional normalization parameter for the latent space.

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
class gnn_tracking.models.graph_construction.GraphConstructionHeteroResFCNN(*, in_dim: int, hidden_dim: int, out_dim: int, depth: int, alpha: float = 0.6)

Bases: gnn_tracking.models.mlp.HeterogeneousResFCNN, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Fully connected neural network for graph construction. Fully heterogeneous (i.e., two separate MLPs for pixel and strip hits, see HeterogeneousResFCNN). Contains an additional normalization parameter for the latent space.

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
class gnn_tracking.models.graph_construction.GraphConstructionHeteroEncResFCNN(*, in_dim: int, hidden_dim_enc: int, hidden_dim: int, out_dim: int, depth_enc: int, depth: int, alpha: float = 0.6)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Fully connected neural network for graph construction with heterogeneous encoding. Contains an additional normalization parameter for the latent space.

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
class gnn_tracking.models.graph_construction.GraphConstructionResIN(*, node_indim: int, edge_indim: int, h_outdim: int = 8, hidden_dim: int = 40, alpha: float = 0.5, n_layers: int = 1, alpha_fcnn: float = 0.5)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Graph construction refinement with a stack of interaction networks with residual connections between them. Refinement means that we assume that we’re starting from the latent space and edges produced by GraphConstructionFCNN.

Parameters:
  • node_indim – Input dimension of the node features

  • edge_indim – Input dimension of the edge features

  • h_outdim – Output dimension of the node features

  • hidden_dim – Dimension used for all other (hidden) layers

  • alpha – Strength of residual connections for the INs

  • n_layers – Number of INs

  • alpha_fcnn – Strength of residual connections connecting back to the initial latent space from the FCNN (assuming that the first h_outdim features are its latent space output)

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
gnn_tracking.models.graph_construction.knn_with_max_radius(x: torch.Tensor, k: int, max_radius: float | None = None) → torch.Tensor

A version of kNN that excludes edges with a distance larger than a given radius.

Parameters:
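  • x – Point coordinates, one row per point (e.g., embedding-space positions)

  • k – Number of neighbors

  • max_radius – Maximum distance of an edge; neighbors farther away are excluded. If None, no radius cut is applied.

Returns:

The edge index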
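The idea can be illustrated with a brute-force O(n²) sketch in plain Python. This is not how gnn_tracking computes the graph (it operates on tensors); the function name is hypothetical, and the edge direction (row 0 holds the query point, row 1 its neighbor) and the exclusion of self-loops are assumptions.

```python
import math
from typing import List, Optional, Sequence


def knn_with_max_radius_sketch(
    x: Sequence[Sequence[float]], k: int, max_radius: Optional[float] = None
) -> List[List[int]]:
    """Connect each point to its k nearest neighbors, dropping edges longer than max_radius."""
    sources: List[int] = []
    targets: List[int] = []
    for i, xi in enumerate(x):
        # Euclidean distance to every other point, nearest first (ties broken by index)
        neighbors = sorted((math.dist(xi, xj), j) for j, xj in enumerate(x) if j != i)
        for dist, j in neighbors[:k]:
            if max_radius is not None and dist > max_radius:
                break  # the list is sorted, so all remaining neighbors are farther away
            sources.append(i)
            targets.append(j)
    return [sources, targets]


points = [(0.0, 0.0), (1.0, 0.0), (5.0, 0.0)]
print(knn_with_max_radius_sketch(points, k=2, max_radius=2.0))  # [[0, 1], [1, 0]]
```

Without the radius cut, the isolated third point would still be connected to its two nearest neighbors; the cut is what keeps such long edges out of the graph.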

class gnn_tracking.models.graph_construction.MLGraphConstruction(ml: torch.nn.Module | None = None, *, ec: torch.nn.Module | None = None, max_radius: float = 1, max_num_neighbors: int = 256, use_embedding_features=False, ratio_of_false=None, build_edge_features=True, ec_threshold=None, ml_freeze: bool = True, ec_freeze: bool = True, embedding_slice: tuple[int | None, int | None] = (None, None))

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Builds a graph from the embedding space. To start from a checkpoint, use MLGraphConstruction.from_chkpt.

Parameters:
  • ml – Metric learning embedding module. If not specified, it is assumed that the node features from the data object are already the embedding coordinates. To use a subset of the embedding coordinates, use embedding_slice.

  • ec – Edge classifier (edge filter) that is applied directly to the constructed edges

  • max_radius – Maximum radius for kNN

  • max_num_neighbors – Maximum number of neighbors (k) for kNN

  • use_embedding_features – Add embedding space features to node features (only if ml is not None)

  • ratio_of_false – Subsample false edges (using truth information)

  • build_edge_features – Build edge features (as difference and sum of node features).

  • ec_threshold – EC threshold for edge filter/classification

  • embedding_slice – Used if ml is None. If both elements are None (the default), all node features are used. Otherwise, the first element is the start index and the second element is the end index of the node features used as embedding coordinates.
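Two of the options above, embedding_slice and build_edge_features, can be sketched in plain Python. Both helpers are hypothetical; treating the end index as exclusive (Python slice semantics), computing the difference as source minus target, and placing the difference before the sum are assumptions rather than documented behavior.

```python
from typing import List, Optional, Sequence, Tuple


def select_embedding(
    x: Sequence[Sequence[float]],
    embedding_slice: Tuple[Optional[int], Optional[int]] = (None, None),
) -> List[List[float]]:
    """Keep node-feature columns start:end; (None, None) keeps all of them."""
    start, end = embedding_slice
    return [list(row[start:end]) for row in x]


def build_edge_features_sketch(
    x: Sequence[Sequence[float]], edge_index: Sequence[Sequence[int]]
) -> List[List[float]]:
    """Edge features as [x_src - x_dst, x_src + x_dst] for every edge."""
    features = []
    for src, dst in zip(edge_index[0], edge_index[1]):
        diff = [a - b for a, b in zip(x[src], x[dst])]
        total = [a + b for a, b in zip(x[src], x[dst])]
        features.append(diff + total)
    return features


nodes = [[1.0, 2.0, 9.0], [3.0, 5.0, 9.0]]
emb = select_embedding(nodes, (0, 2))  # [[1.0, 2.0], [3.0, 5.0]]
print(build_edge_features_sketch(emb, [[0], [1]]))  # [[-2.0, -3.0, 4.0, 7.0]]
```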

_validate_config() → None

Checks that the configuration is consistent.

classmethod from_chkpt(ml_chkpt_path: str = '', ec_chkpt_path: str = '', *, ml_class_name: str = 'gnn_tracking.training.ml.MLModule', ec_class_name: str = 'gnn_tracking.training.ec.ECModule', ml_model_only: bool = True, ec_model_only: bool = True, **kwargs) → MLGraphConstruction

Build MLGraphConstruction from checkpointed models.

Parameters:
  • ml_chkpt_path – Path to metric learning checkpoint

  • ec_chkpt_path – Path to edge filter checkpoint. If empty, no EC will be used.

  • ml_class_name – Class name of metric learning lightning module (default should almost always be fine)

  • ec_class_name – Class name of edge filter lightning module (default should almost always be fine)

  • ml_model_only – If True, only the torch model is loaded (excluding preprocessing steps from the Lightning module)

  • ec_model_only – See ml_model_only

  • **kwargs – Additional arguments passed to MLGraphConstruction

property out_dim: tuple[int, int]

Returns the node and edge output dimensions.

forward(data: torch_geometric.data.Data) → torch_geometric.data.Data
class gnn_tracking.models.graph_construction.MLGraphConstructionFromChkpt

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Alias for MLGraphConstruction.from_chkpt for use in YAML files.

class gnn_tracking.models.graph_construction.MLPCTransformer(model: torch.nn.Module, *, original_features: bool = False, freeze: bool = True)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Transforms a point cloud (PC) using a metric learning (ML) model. This is just a thin wrapper around the ML module with specification of what to do with the resulting latent space. In contrast to MLGraphConstructionFromChkpt, this class does not build a graph from the latent space but returns a transformed point cloud. Use MLPCTransformer.from_ml_chkpt to build from a checkpointed ML model.

Warning

In the current implementation, the original Data object is modified by forward.

Parameters:
  • model – Metric learning model. Should return latent space with key “H”

  • original_features – Include original node features as node features (after the transformed ones)
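A plain-Python sketch of what happens to the latent space, assuming that the transformed features are written to data.x. Here types.SimpleNamespace and a plain function stand in for a torch_geometric Data object and the ML model, and the helper name is hypothetical. As the warning above notes, the input object is modified in place.

```python
from types import SimpleNamespace
from typing import Any, Dict, List


def transform_point_cloud_sketch(model: Any, data: Any, original_features: bool = False) -> Any:
    """Replace data.x by the model's latent space "H", optionally followed by the original features."""
    latent = model(data)["H"]
    if original_features:
        # Transformed features first, then the original ones
        data.x = [list(h) + list(x) for h, x in zip(latent, data.x)]
    else:
        data.x = [list(h) for h in latent]
    return data  # same object: the input is modified in place


def fake_model(data: Any) -> Dict[str, List[List[float]]]:
    """Stand-in for a metric learning model returning its latent space under "H"."""
    return {"H": [[0.5], [1.5]]}


data = SimpleNamespace(x=[[1.0, 2.0], [3.0, 4.0]])
out = transform_point_cloud_sketch(fake_model, data, original_features=True)
print(out.x)  # [[0.5, 1.0, 2.0], [1.5, 3.0, 4.0]]
```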

classmethod from_ml_chkpt(chkpt_path: str, *, class_name: str = 'gnn_tracking.training.ml.MLModule', **kwargs)

Build MLPCTransformer from checkpointed ML model.

Parameters:
  • chkpt_path – Path to checkpoint

  • class_name – Lightning module class name that was used for training. The default probably covers most cases.

  • **kwargs – Additional kwargs passed to MLPCTransformer constructor

forward(data: torch_geometric.data.Data) → torch_geometric.data.Data
class gnn_tracking.models.graph_construction.MLPCTransformerFromMLChkpt(model: torch.nn.Module, *, original_features: bool = False, freeze: bool = True)

Bases: MLPCTransformer

Transforms a point cloud (PC) using a metric learning (ML) model. This is just a thin wrapper around the ML module with specification of what to do with the resulting latent space. In contrast to MLGraphConstructionFromChkpt, this class does not build a graph from the latent space but returns a transformed point cloud. Use MLPCTransformer.from_ml_chkpt to build from a checkpointed ML model.

Warning

In the current implementation, the original Data object is modified by forward.

Parameters:
  • model – Metric learning model. Should return latent space with key “H”

  • original_features – Include original node features as node features (after the transformed ones)