gnn_tracking.models.graph_construction

Models for embeddings used for graph construction.

Classes

GraphConstructionFCNN

Fully connected neural network for graph construction.

GraphConstructionResFCNN

Fully connected neural network for graph construction.

GraphConstructionHeteroResFCNN

Fully connected neural network for graph construction.

GraphConstructionHeteroEncResFCNN

Fully connected neural network for graph construction.

GraphConstructionResIN

Graph construction refinement with a stack of interaction networks with residual connections between them.

MLGraphConstruction

Builds a graph from the embedding space. If you want to start from a checkpoint, use MLGraphConstruction.from_chkpt.

MLGraphConstructionFromChkpt

Alias for MLGraphConstruction.from_chkpt for use in yaml files

MLPCTransformer

Transforms a point cloud (PC) using a metric learning (ML) model.

MLPCTransformerFromMLChkpt

Transforms a point cloud (PC) using a metric learning (ML) model.

Functions

knn_with_max_radius(→ torch.Tensor)

A version of kNN that excludes edges with a distance larger than a given radius.

Module Contents

class gnn_tracking.models.graph_construction.GraphConstructionFCNN(*, in_dim: int, hidden_dim: int, out_dim: int, depth: int, alpha: float = 0.6)

Bases: gnn_tracking.models.mlp.ResFCNN, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Fully connected neural network for graph construction. Contains an additional normalization parameter for the latent space.

_latent_normalization
forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
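The latent normalization parameter can be pictured with a minimal pure-Python sketch. It assumes that the parameter is a single learnable scalar that rescales the network output, and that forward returns the latent coordinates under the key "H" (the key MLPCTransformer expects). `forward_with_latent_normalization` and `toy_mlp` are hypothetical stand-ins, not library code.

```python
from typing import Callable

Vector = list[float]


def forward_with_latent_normalization(
    x: list[Vector],
    mlp: Callable[[Vector], Vector],
    latent_normalization: float,
) -> dict[str, list[Vector]]:
    """Apply the MLP to every node and rescale the result.

    Assumption: the normalization is one scalar multiplying the whole latent
    vector. In a torch model it would be a learnable nn.Parameter.
    """
    h = [[latent_normalization * v for v in mlp(node)] for node in x]
    return {"H": h}


def toy_mlp(node: Vector) -> Vector:
    # Hypothetical stand-in for the trained FCNN: sum and difference of inputs.
    return [node[0] + node[1], node[0] - node[1]]


out = forward_with_latent_normalization([[1.0, 2.0], [3.0, 1.0]], toy_mlp, 0.5)
print(out)  # {'H': [[1.5, -0.5], [2.0, 1.0]]}
```

Because the scale is learned separately from the MLP weights, the overall spread of the latent space can presumably adjust without retuning every layer.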
class gnn_tracking.models.graph_construction.GraphConstructionResFCNN(*, in_dim: int, hidden_dim: int, out_dim: int, depth: int, alpha: float = 0.6)

Bases: gnn_tracking.models.mlp.ResMLP, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Fully connected neural network for graph construction. Contains an additional normalization parameter for the latent space.

_latent_normalization
forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
class gnn_tracking.models.graph_construction.GraphConstructionHeteroResFCNN(*, in_dim: int, hidden_dim: int, out_dim: int, depth: int, alpha: float = 0.6)

Bases: gnn_tracking.models.mlp.HeterogeneousResFCNN, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Fully connected neural network for graph construction. Fully heterogeneous (i.e., two separate MLPs for node and edge features). Contains an additional normalization parameter for the latent space.

_latent_normalization
forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
class gnn_tracking.models.graph_construction.GraphConstructionHeteroEncResFCNN(*, in_dim: int, hidden_dim_enc: int, hidden_dim: int, out_dim: int, depth_enc: int, depth: int, alpha: float = 0.6)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Fully connected neural network for graph construction with heterogeneous encoding. Contains an additional normalization parameter for the latent space.

encoder
fcnn
_latent_normalization
forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
class gnn_tracking.models.graph_construction.GraphConstructionResIN(*, node_indim: int, edge_indim: int, h_outdim: int = 8, hidden_dim: int = 40, alpha: float = 0.5, n_layers: int = 1, alpha_fcnn: float = 0.5)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Graph construction refinement with a stack of interaction networks with residual connections between them. Refinement means that we assume we’re starting from the latent space and edges produced by GraphConstructionFCNN.

Parameters:
  • node_indim – Input dimension of the node features

  • edge_indim – Input dimension of the edge features

  • h_outdim – Output dimension of the node features

  • hidden_dim – Dimension used for all other (hidden) layers

  • alpha – Strength of residual connections for the INs

  • n_layers – Number of INs

  • alpha_fcnn – Strength of residual connections connecting back to the initial latent space from the FCNN (assuming that the first h_outdim features are its latent space output)
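The two residual strengths can be illustrated with a small sketch. It assumes the convex-combination convention h_out = alpha * h_in + (1 - alpha) * update. This is an assumption; this page does not spell out the formula. As in the alpha_fcnn description, only the first h_outdim input node features (the FCNN latent space) are blended into the output. `residual_mix` and `blend_with_fcnn_latent` are hypothetical helpers.

```python
Vector = list[float]


def residual_mix(h_in: Vector, update: Vector, alpha: float) -> Vector:
    """Convex residual connection: alpha=1 keeps the input, alpha=0 keeps the update."""
    if len(h_in) != len(update):
        raise ValueError("residual connection needs equal dimensions")
    return [alpha * a + (1.0 - alpha) * b for a, b in zip(h_in, update)]


def blend_with_fcnn_latent(
    node_features: Vector, decoded: Vector, h_outdim: int, alpha_fcnn: float
) -> Vector:
    """Connect the decoder output back to the FCNN latent space.

    Assumes the first ``h_outdim`` node features are the FCNN latent coordinates.
    """
    fcnn_latent = node_features[:h_outdim]
    return residual_mix(fcnn_latent, decoded, alpha_fcnn)


# Node with a 2D FCNN latent space followed by one extra feature.
features = [1.0, 3.0, 7.0]
decoded = [3.0, 1.0]
print(blend_with_fcnn_latent(features, decoded, h_outdim=2, alpha_fcnn=0.5))
# [2.0, 2.0]
```

Under this convention, alpha_fcnn close to 1 keeps the refined latent space near the FCNN output.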

_node_encoder
_edge_encoder
_resin
_decoder: Callable
_latent_normalization
forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
gnn_tracking.models.graph_construction.knn_with_max_radius(x: torch.Tensor, k: int, max_radius: float | None = None) → torch.Tensor

A version of kNN that excludes edges with a distance larger than a given radius.

Parameters:
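  • x – Coordinates of the points

  • k – Number of neighbors

  • max_radius – Maximum distance between connected points; edges with a larger distance are excluded. If None, no radius cut is applied.

Returns:

Edge index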
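A brute-force, pure-Python sketch of the idea follows. For each point, it takes the k nearest other points and drops any pair farther apart than max_radius. The result is a 2 × E edge index. Row 0 holds the neighbor and row 1 the query point; this direction convention, and the exclusion of self-loops, are assumptions about the library function. The real function works on torch tensors and presumably uses an efficient kNN routine instead of this O(n²) loop. `knn_with_max_radius_sketch` is a hypothetical name.

```python
from __future__ import annotations

import math

Point = list[float]


def knn_with_max_radius_sketch(
    x: list[Point], k: int, max_radius: float | None = None
) -> list[list[int]]:
    """Return a 2 x E edge index [neighbors, query points].

    Each point i is connected to (at most) its k nearest other points j,
    keeping only pairs with distance <= max_radius when a radius is given.
    """
    sources: list[int] = []
    targets: list[int] = []
    for i, xi in enumerate(x):
        # Distances from point i to every other point, nearest first.
        neighbors = sorted(
            (math.dist(xi, xj), j) for j, xj in enumerate(x) if j != i
        )
        for dist, j in neighbors[:k]:
            if max_radius is not None and dist > max_radius:
                break  # sorted, so all remaining neighbors are farther away
            sources.append(j)
            targets.append(i)
    return [sources, targets]


points = [[0.0, 0.0], [1.0, 0.0], [0.0, 3.0]]
print(knn_with_max_radius_sketch(points, k=2, max_radius=2.0))  # [[1, 0], [0, 1]]
```

In this example, the third point gets no edges because its nearest neighbor is 3.0 away, beyond the radius. Without the radius cut, it would still be connected.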

class gnn_tracking.models.graph_construction.MLGraphConstruction(ml: torch.nn.Module | None = None, *, ec: torch.nn.Module | None = None, max_radius: float = 1, max_num_neighbors: int = 256, use_embedding_features=False, ratio_of_false=None, build_edge_features=True, ec_threshold=None, ml_freeze: bool = True, ec_freeze: bool = True, embedding_slice: tuple[int | None, int | None] = (None, None))

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Builds a graph from the embedding space. If you want to start from a checkpoint, use MLGraphConstruction.from_chkpt.

Parameters:
  • ml – Metric learning embedding module. If not specified, it is assumed that the node features from the data object are already the embedding coordinates. To use a subset of the embedding coordinates, use embedding_slice.

  • ec – Edge filter (edge classification) module to apply directly to the constructed edges

  • max_radius – Maximum radius for kNN

  • max_num_neighbors – Maximum number of neighbors (k) for kNN

  • use_embedding_features – Add embedding space features to node features (only if ml is not None)

  • ratio_of_false – Subsample false edges (using truth information)

  • build_edge_features – Build edge features (as difference and sum of node features).

  • ec_threshold – EC threshold for edge filter/classification

  • embedding_slice – Used if ml is None. With the default (None, None), all node features are used. Otherwise, the first element is the start index and the second element is the end index of the node features to use.
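The slicing, edge-feature, and thresholding options can be pictured with a pure-Python sketch. It makes three assumptions. First, embedding_slice acts like a regular Python slice on the feature axis, so the end index is exclusive. Second, the edge features for an edge (i, j) are the concatenation [x_i - x_j, x_i + x_j]; the order of the two parts is not specified here. Third, the edge filter keeps edges whose score exceeds ec_threshold. All function names are hypothetical.

```python
from __future__ import annotations

Vector = list[float]


def select_embedding(
    x: list[Vector], embedding_slice: tuple[int | None, int | None] = (None, None)
) -> list[Vector]:
    """Keep only the feature columns selected by ``embedding_slice``.

    ``(None, None)`` selects all features, like ``row[:]``.
    """
    sl = slice(*embedding_slice)
    return [row[sl] for row in x]


def build_edge_features(x: list[Vector], edge_index: list[list[int]]) -> list[Vector]:
    """Edge features as [x_i - x_j, x_i + x_j] for each edge (i, j) (assumed order)."""
    features = []
    for i, j in zip(*edge_index):
        diff = [a - b for a, b in zip(x[i], x[j])]
        total = [a + b for a, b in zip(x[i], x[j])]
        features.append(diff + total)
    return features


def apply_ec_threshold(
    edge_index: list[list[int]], scores: list[float], threshold: float
) -> list[list[int]]:
    """Keep only edges whose (hypothetical) edge-classifier score exceeds the threshold."""
    kept = [(i, j) for i, j, s in zip(*edge_index, scores) if s > threshold]
    return [[i for i, _ in kept], [j for _, j in kept]]


x = select_embedding([[1.0, 2.0, 9.0], [3.0, 5.0, 9.0]], (0, 2))
edges = [[0, 1], [1, 0]]
print(build_edge_features(x, edges))  # [[-2.0, -3.0, 4.0, 7.0], [2.0, 3.0, 4.0, 7.0]]
print(apply_ec_threshold(edges, [0.9, 0.2], 0.5))  # [[0], [1]]
```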

_ml = None
_ef = None
_validate_config() → None

Ensure that config makes sense.

classmethod from_chkpt(ml_chkpt_path: str = '', ec_chkpt_path: str = '', *, ml_class_name: str = 'gnn_tracking.training.ml.MLModule', ec_class_name: str = 'gnn_tracking.training.ec.ECModule', ml_model_only: bool = True, ec_model_only: bool = True, **kwargs) → MLGraphConstruction

Build MLGraphConstruction from checkpointed models.

Parameters:
  • ml_chkpt_path – Path to metric learning checkpoint

  • ec_chkpt_path – Path to edge filter checkpoint. If empty, no EC will be used.

  • ml_class_name – Class name of metric learning lightning module (default should almost always be fine)

  • ec_class_name – Class name of edge filter lightning module (default should almost always be fine)

  • ml_model_only – If True, only the torch model is loaded (excluding preprocessing steps from the lightning module)

  • ec_model_only – See ml_model_only

  • **kwargs – Additional arguments passed to MLGraphConstruction

property out_dim: tuple[int, int]

Returns the node and edge output dimensions.

forward(data: torch_geometric.data.Data) → torch_geometric.data.Data
class gnn_tracking.models.graph_construction.MLGraphConstructionFromChkpt

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Alias for MLGraphConstruction.from_chkpt for use in yaml files
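One way to write such a YAML-friendly alias is a class whose `__new__` forwards its arguments to the classmethod. A config system that instantiates classes by name then ends up calling the factory. The sketch uses hypothetical stand-in classes (`Builder`, `BuilderFromChkpt`) and does not load any checkpoint. Whether the library implements the alias exactly this way is an assumption.

```python
from __future__ import annotations


class Builder:
    """Hypothetical stand-in for MLGraphConstruction."""

    def __init__(self, ml: str | None = None, max_radius: float = 1.0) -> None:
        self.ml = ml
        self.max_radius = max_radius

    @classmethod
    def from_chkpt(cls, ml_chkpt_path: str = "", **kwargs) -> Builder:
        # A real implementation would load the model from ``ml_chkpt_path``.
        # Here the path itself stands in for the loaded model.
        return cls(ml=ml_chkpt_path or None, **kwargs)


class BuilderFromChkpt:
    """Alias: instantiating this class returns ``Builder.from_chkpt(...)``."""

    def __new__(cls, *args, **kwargs) -> Builder:
        # The returned object is not an instance of ``cls``, so Python
        # skips ``__init__`` and hands back the Builder directly.
        return Builder.from_chkpt(*args, **kwargs)


built = BuilderFromChkpt(ml_chkpt_path="ml.ckpt", max_radius=0.5)
print(type(built).__name__, built.ml, built.max_radius)  # Builder ml.ckpt 0.5
```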

class gnn_tracking.models.graph_construction.MLPCTransformer(model: torch.nn.Module, *, original_features: bool = False, freeze: bool = True)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Transforms a point cloud (PC) using a metric learning (ML) model. This is just a thin wrapper around the ML module with specification of what to do with the resulting latent space. In contrast to MLGraphConstructionFromChkpt, this class does not build a graph from the latent space but returns a transformed point cloud. Use MLPCTransformer.from_ml_chkpt to build from a checkpointed ML model.

Warning

In the current implementation, the original Data object is modified by forward.

Parameters:
  • model – Metric learning model. Should return the latent space under the key “H”

  • original_features – Include original node features as node features (after the transformed ones)
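A pure-Python sketch of the transformation step follows. It uses a plain dict in place of a torch_geometric Data object, and `model` is a hypothetical callable that returns the latent space under the key "H". With original_features=True, the original node features are appended after the transformed ones, as described above. Like the documented class, the sketch overwrites `data["x"]` in place, which is why the warning applies: keep a copy if you still need the untransformed features.

```python
from typing import Any, Callable

Vector = list[float]


def transform_point_cloud(
    data: dict[str, Any],
    model: Callable[[dict[str, Any]], dict[str, list[Vector]]],
    original_features: bool = False,
) -> dict[str, Any]:
    """Replace node features with the model's latent coordinates (in place)."""
    h = model(data)["H"]
    if original_features:
        # Transformed features first, then the original ones.
        data["x"] = [hi + xi for hi, xi in zip(h, data["x"])]
    else:
        data["x"] = h
    return data


def toy_model(data: dict[str, Any]) -> dict[str, list[Vector]]:
    # Hypothetical metric learning model: doubles every coordinate.
    return {"H": [[2.0 * v for v in row] for row in data["x"]]}


data = {"x": [[1.0, 2.0]]}
out = transform_point_cloud(data, toy_model, original_features=True)
print(out["x"])  # [[2.0, 4.0, 1.0, 2.0]]
print(out is data)  # True: the input object was modified
```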

_ml: torch.nn.Module = None
classmethod from_ml_chkpt(chkpt_path: str, *, class_name: str = 'gnn_tracking.training.ml.MLModule', **kwargs)

Build MLPCTransformer from checkpointed ML model.

Parameters:
  • chkpt_path – Path to checkpoint

  • class_name – Lightning module class name that was used for training. The default probably covers most cases.

  • **kwargs – Additional kwargs passed to MLPCTransformer constructor

forward(data: torch_geometric.data.Data) → torch_geometric.data.Data
class gnn_tracking.models.graph_construction.MLPCTransformerFromMLChkpt(model: torch.nn.Module, *, original_features: bool = False, freeze: bool = True)

Bases: MLPCTransformer

Transforms a point cloud (PC) using a metric learning (ML) model. This is just a thin wrapper around the ML module with specification of what to do with the resulting latent space. In contrast to MLGraphConstructionFromChkpt, this class does not build a graph from the latent space but returns a transformed point cloud. Use MLPCTransformer.from_ml_chkpt to build from a checkpointed ML model.

Warning

In the current implementation, the original Data object is modified by forward.

Parameters:
  • model – Metric learning model. Should return the latent space under the key “H”

  • original_features – Include original node features as node features (after the transformed ones)