gnn_tracking.models.graph_construction

Models for embeddings used for graph construction.

Module Contents

Classes

GraphConstructionFCNN

Fully connected neural network for graph construction.

GraphConstructionHeteroResFCNN

Fully connected neural network for graph construction.

GraphConstructionHeteroEncResFCNN

Fully connected neural network for graph construction.

GraphConstructionResIN

Graph construction refinement with a stack of interaction networks with residual connections between them.

MLGraphConstruction

Builds graph from embedding space. If you want to start from a checkpoint, use MLGraphConstruction.from_chkpt.

MLGraphConstructionFromChkpt

Alias for MLGraphConstruction.from_chkpt for use in YAML files.

MLPCTransformer

Transforms a point cloud (PC) using a metric learning (ML) model.

MLPCTransformerFromMLChkpt

Transforms a point cloud (PC) using a metric learning (ML) model.

Functions

knn_with_max_radius(→ torch.Tensor)

A version of kNN that excludes edges with a distance larger than a given radius.

class gnn_tracking.models.graph_construction.GraphConstructionFCNN(*, in_dim: int, hidden_dim: int, out_dim: int, depth: int, alpha: float = 0.6)

Bases: gnn_tracking.models.mlp.ResFCNN, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Fully connected neural network for graph construction. Contains additional normalization parameter for the latent space.
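
A rough, plain-Python sketch of what the latent-space normalization parameter amounts to. It assumes (not confirmed by this page) that the parameter is a single learnable scalar that multiplies the network output, and that the output is returned under the key "H", the key that MLPCTransformer documents for metric learning models. `LatentScaleSketch` and `backbone` are hypothetical stand-ins, not library names.

```python
from typing import Callable, Dict, List

Vector = List[float]


class LatentScaleSketch:
    """Hypothetical stand-in: a backbone network followed by one scale factor."""

    def __init__(self, backbone: Callable[[Vector], Vector], scale: float = 1.0) -> None:
        self.backbone = backbone
        # Assumption: a single scalar (a trainable parameter in a real model)
        self.latent_normalization = scale

    def forward(self, x: List[Vector]) -> Dict[str, List[Vector]]:
        # Embed every node with the backbone, then rescale all latent coordinates
        h = [[self.latent_normalization * v for v in self.backbone(row)] for row in x]
        return {"H": h}


# Toy backbone: identity map on 2D node features
model = LatentScaleSketch(backbone=lambda row: list(row), scale=2.0)
out = model.forward([[1.0, 0.5], [0.0, -1.0]])
print(out["H"])  # [[2.0, 1.0], [0.0, -2.0]]
```

A single global scale lets training stretch or shrink the whole latent space with one parameter, presumably so that distances can adapt to any fixed radii used downstream without retuning every layer.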

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]

class gnn_tracking.models.graph_construction.GraphConstructionHeteroResFCNN(*, in_dim: int, hidden_dim: int, out_dim: int, depth: int, alpha: float = 0.6)

Bases: gnn_tracking.models.mlp.HeterogeneousResFCNN, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Fully connected neural network for graph construction. Fully heterogeneous (i.e., two separate MLPs for node and edge features). Contains additional normalization parameter for the latent space.

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]

class gnn_tracking.models.graph_construction.GraphConstructionHeteroEncResFCNN(*, in_dim: int, hidden_dim_enc: int, hidden_dim: int, out_dim: int, depth_enc: int, depth: int, alpha: float = 0.6)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Fully connected neural network for graph construction. Heterogeneous encoding. Contains additional normalization parameter for the latent space.

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]

class gnn_tracking.models.graph_construction.GraphConstructionResIN(*, node_indim: int, edge_indim: int, h_outdim: int = 8, hidden_dim: int = 40, alpha: float = 0.5, n_layers: int = 1, alpha_fcnn: float = 0.5)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Graph construction refinement with a stack of interaction networks with residual connections between them. Refinement means that we assume that we’re starting from the latent space and edges produced by GraphConstructionFCNN.

Parameters:
  • node_indim – Input dimension of the node features

  • edge_indim – Input dimension of the edge features

  • h_outdim – Output dimension of the node features

  • hidden_dim – All other dimensions

  • alpha – Strength of residual connections for the INs

  • n_layers – Number of INs

  • alpha_fcnn – Strength of residual connections connecting back to the initial latent space from the FCNN (assuming that the first h_outdim features are its latent space output)
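The two residual parameters can be illustrated with a plain-Python sketch. It assumes a simple convex combination in which alpha weights the previous value (so a larger alpha means a stronger residual connection); the library's exact weighting is not specified here. The interaction networks are replaced by hypothetical per-node functions, and message passing over edges is omitted entirely.

```python
from typing import Callable, List

Vector = List[float]
Layer = Callable[[Vector], Vector]


def mix(old: Vector, new: Vector, alpha: float) -> Vector:
    """Assumed residual form: convex combination, alpha weights the old value."""
    return [alpha * o + (1.0 - alpha) * n for o, n in zip(old, new)]


def refine(
    x: Vector, h_outdim: int, layers: List[Layer], alpha: float, alpha_fcnn: float
) -> Vector:
    # Per the docs, the first h_outdim input features are the FCNN latent output
    fcnn_latent = x[:h_outdim]
    h = fcnn_latent
    # Stack of hypothetical layers, each wrapped in an alpha-weighted skip connection
    for layer in layers:
        h = mix(h, layer(h), alpha)
    # Final skip connection back to the initial FCNN latent space
    return mix(fcnn_latent, h, alpha_fcnn)


# One toy "layer" that maps everything to zero, to make the weighting visible
result = refine([1.0, 2.0, 9.0], h_outdim=2, layers=[lambda h: [0.0] * len(h)],
                alpha=0.5, alpha_fcnn=0.5)
print(result)  # [0.75, 1.5]
```

With alpha_fcnn close to 1, the output stays close to the FCNN latent space, which fits the idea of the interaction networks only refining an existing embedding.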

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]

gnn_tracking.models.graph_construction.knn_with_max_radius(x: torch.Tensor, k: int, max_radius: float | None = None) → torch.Tensor

A version of kNN that excludes edges with a distance larger than a given radius.

Parameters:
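  • x – Node coordinates (e.g., embedding-space positions) from which the kNN graph is built

  • k – Number of neighbors

  • max_radius – Maximum edge length; edges whose endpoints are farther apart than this are excluded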

Returns:

edge index
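The effect of the radius cut can be shown with a brute-force, plain-Python sketch. The real function takes a torch.Tensor; this O(N²) version only illustrates the idea. The function name, the edge direction (neighbor as source, query point as target), the exclusion of self-loops, and treating max_radius=None as "no cut" are all assumptions of this sketch.

```python
import math
from typing import List, Optional, Sequence


def knn_with_max_radius_sketch(
    x: Sequence[Sequence[float]], k: int, max_radius: Optional[float] = None
) -> List[List[int]]:
    """Brute-force kNN that drops neighbors farther away than max_radius.

    Returns a 2 x E edge index [sources, targets], where each source j is one of
    the k nearest neighbors of target i. Self-loops are excluded (assumption).
    """
    sources: List[int] = []
    targets: List[int] = []
    for i, xi in enumerate(x):
        # (distance, index) pairs to all other points, nearest first
        candidates = sorted(
            (math.dist(xi, xj), j) for j, xj in enumerate(x) if j != i
        )
        for dist, j in candidates[:k]:
            # Plain kNN would keep every candidate; the radius cut drops far ones
            if max_radius is None or dist <= max_radius:
                sources.append(j)
                targets.append(i)
    return [sources, targets]


points = [[0.0], [1.0], [5.0]]
print(knn_with_max_radius_sketch(points, k=1, max_radius=2.0))  # [[1, 0], [0, 1]]
print(knn_with_max_radius_sketch(points, k=1))                  # [[1, 0, 1], [0, 1, 2]]
```

The isolated point at 5.0 shows the difference from plain kNN: with the radius cut it receives no edge instead of being connected to a distant neighbor.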

class gnn_tracking.models.graph_construction.MLGraphConstruction(ml: torch.nn.Module | None = None, *, ec: torch.nn.Module | None = None, max_radius: float = 1, max_num_neighbors: int = 256, use_embedding_features=False, ratio_of_false=None, build_edge_features=True, ec_threshold=None, ml_freeze: bool = True, ec_freeze: bool = True, embedding_slice: tuple[int | None, int | None] = (None, None))

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Builds graph from embedding space. If you want to start from a checkpoint, use MLGraphConstruction.from_chkpt.

Parameters:
  • ml – Metric learning embedding module. If not specified, it is assumed that the node features from the data object are already the embedding coordinates. To use a subset of the embedding coordinates, use embedding_slice.

  • ec – Edge filter (edge classifier) module that is applied directly to the constructed edges

  • max_radius – Maximum radius for kNN

  • max_num_neighbors – Maximum number of neighbors for kNN

  • use_embedding_features – Add embedding space features to node features (only if ml is not None)

  • ratio_of_false – Subsample false edges (using truth information)

  • build_edge_features – Build edge features (as difference and sum of node features).

  • ec_threshold – EC threshold for edge filter/classification

  • embedding_slice – Used if ml is None. With the default (None, None), all node features are used. Otherwise, the first element is the start index and the second element is the end index of the node features to use.
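embedding_slice is documented as a (start, end) pair, and build_edge_features as building edge features from the difference and sum of node features. The plain-Python sketch below shows both ideas on lists of floats. The function names are hypothetical, and the source/target order and the concatenation layout of the edge features are assumptions.

```python
from typing import List, Optional, Sequence, Tuple

Vector = List[float]


def slice_features(
    x: Sequence[Vector], embedding_slice: Tuple[Optional[int], Optional[int]]
) -> List[Vector]:
    # (None, None) selects every feature, like x[:, None:None] on a tensor
    start, end = embedding_slice
    return [list(row[start:end]) for row in x]


def build_edge_features(x: Sequence[Vector], edge_index: List[List[int]]) -> List[Vector]:
    # For each edge (i, j): [x_i - x_j, x_i + x_j] (order and layout assumed)
    features = []
    for i, j in zip(edge_index[0], edge_index[1]):
        diff = [a - b for a, b in zip(x[i], x[j])]
        total = [a + b for a, b in zip(x[i], x[j])]
        features.append(diff + total)
    return features


x = [[1.0, 2.0, 3.0], [0.0, 1.0, 1.0]]
emb = slice_features(x, (0, 2))
print(emb)  # [[1.0, 2.0], [0.0, 1.0]]
print(build_edge_features(x, [[0], [1]]))  # [[1.0, 1.0, 2.0, 1.0, 3.0, 4.0]]
```

Using both the difference and the sum keeps information about the relative position of the two hits as well as about where the pair sits overall, which a difference alone would discard.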

property out_dim: tuple[int, int]

Returns the output dimensions of the node and edge features.

_validate_config() → None

Ensure that config makes sense.

classmethod from_chkpt(ml_chkpt_path: str = '', ec_chkpt_path: str = '', *, ml_class_name: str = 'gnn_tracking.training.ml.MLModule', ec_class_name: str = 'gnn_tracking.training.ec.ECModule', ml_model_only: bool = True, ec_model_only: bool = True, **kwargs) → MLGraphConstruction

Build MLGraphConstruction from checkpointed models.

Parameters:
  • ml_chkpt_path – Path to metric learning checkpoint

  • ec_chkpt_path – Path to edge filter checkpoint. If empty, no EC will be used.

  • ml_class_name – Class name of metric learning lightning module (default should almost always be fine)

  • ec_class_name – Class name of edge filter lightning module (default should almost always be fine)

  • ml_model_only – Only the torch model is loaded (excluding preprocessing steps from the lightning module)

  • ec_model_only – See ml_model_only

  • **kwargs – Additional arguments passed to MLGraphConstruction

forward(data: torch_geometric.data.Data) → torch_geometric.data.Data

class gnn_tracking.models.graph_construction.MLGraphConstructionFromChkpt

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Alias for MLGraphConstruction.from_chkpt for use in YAML files.

class gnn_tracking.models.graph_construction.MLPCTransformer(model: torch.nn.Module, *, original_features: bool = False, freeze: bool = True)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Transforms a point cloud (PC) using a metric learning (ML) model. This is just a thin wrapper around the ML module with specification of what to do with the resulting latent space. In contrast to MLGraphConstructionFromChkpt, this class does not build a graph from the latent space but returns a transformed point cloud. Use MLPCTransformer.from_ml_chkpt to build from a checkpointed ML model.

Warning

In the current implementation, the original Data object is modified by forward.

Parameters:
  • model – Metric learning model. Should return latent space with key “H”

  • original_features – Include original node features as node features (after the transformed ones)
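A plain-Python sketch of the documented behavior: the model's latent space (key "H") becomes the new node features, original_features appends the original features after the transformed ones, and the input object itself is modified, as the warning above states. `ToyData` is a hypothetical stand-in for torch_geometric.data.Data with only an `x` attribute, and the exact way forward replaces the features is an assumption.

```python
from typing import Callable, Dict, List

Vector = List[float]


class ToyData:
    """Hypothetical minimal stand-in for torch_geometric.data.Data (only .x)."""

    def __init__(self, x: List[Vector]) -> None:
        self.x = x


def transform_point_cloud(
    data: ToyData,
    model: Callable[[ToyData], Dict[str, List[Vector]]],
    original_features: bool = False,
) -> ToyData:
    # The ML model is expected to return its latent space under the key "H"
    h = model(data)["H"]
    if original_features:
        # Transformed features first, original node features after them
        data.x = [hi + xi for hi, xi in zip(h, data.x)]
    else:
        data.x = h
    # The input object is modified in place and returned
    return data


def double(d: ToyData) -> Dict[str, List[Vector]]:
    # Toy "metric learning model": scales every feature by 2
    return {"H": [[2.0 * v for v in row] for row in d.x]}


data = ToyData([[1.0], [3.0]])
out = transform_point_cloud(data, double, original_features=True)
print(out.x)        # [[2.0, 1.0], [6.0, 3.0]]
print(out is data)  # True
```

Because the input object is reused, callers that still need the untransformed features should pass a copy.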

classmethod from_ml_chkpt(chkpt_path: str, *, class_name: str = 'gnn_tracking.training.ml.MLModule', **kwargs)

Build MLPCTransformer from checkpointed ML model.

Parameters:
  • chkpt_path – Path to checkpoint

  • class_name – Lightning module class name that was used for training. The default probably covers most cases.

  • **kwargs – Additional kwargs passed to MLPCTransformer constructor

forward(data: torch_geometric.data.Data) → torch_geometric.data.Data

class gnn_tracking.models.graph_construction.MLPCTransformerFromMLChkpt(model: torch.nn.Module, *, original_features: bool = False, freeze: bool = True)

Bases: MLPCTransformer

Transforms a point cloud (PC) using a metric learning (ML) model. This is just a thin wrapper around the ML module with specification of what to do with the resulting latent space. In contrast to MLGraphConstructionFromChkpt, this class does not build a graph from the latent space but returns a transformed point cloud. Use MLPCTransformer.from_ml_chkpt to build from a checkpointed ML model.

Warning

In the current implementation, the original Data object is modified by forward.

Parameters:
  • model – Metric learning model. Should return latent space with key “H”

  • original_features – Include original node features as node features (after the transformed ones)