gnn_tracking.models.graph_construction
Models for embeddings used for graph construction.
Classes
GraphConstructionFCNN | Fully connected neural network for graph construction.
GraphConstructionHeteroResFCNN | Fully connected neural network for graph construction. Fully heterogeneous.
GraphConstructionHeteroEncResFCNN | Fully connected neural network for graph construction. Heterogeneous encoding.
GraphConstructionResIN | Graph construction refinement with a stack of interaction networks with residual connections between them.
MLGraphConstruction | Builds graph from embedding space. If you want to start from a checkpoint, use MLGraphConstruction.from_chkpt.
MLGraphConstructionFromChkpt | Alias for MLGraphConstruction.from_chkpt for use in yaml files
MLPCTransformer | Transforms a point cloud (PC) using a metric learning (ML) model.
MLPCTransformerFromMLChkpt | Transforms a point cloud (PC) using a metric learning (ML) model.
Functions
knn_with_max_radius | A version of kNN that excludes edges with a distance larger than a given radius.
Module Contents
- class gnn_tracking.models.graph_construction.GraphConstructionFCNN(*, in_dim: int, hidden_dim: int, out_dim: int, depth: int, alpha: float = 0.6)
Bases: gnn_tracking.models.mlp.ResFCNN, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
Fully connected neural network for graph construction. Contains an additional normalization parameter for the latent space.
- _latent_normalization
- forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
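The update rule behind alpha and _latent_normalization is not spelled out on this page. The plain-Python sketch below illustrates one plausible reading under two assumptions: each hidden layer is combined with its input as alpha * input + (1 - alpha) * relu(layer(input)), and _latent_normalization is a single learnable scalar that rescales the embedding returned under the key "H". The helpers relu, dense, and res_fcnn_forward are hypothetical; bias terms and the input/output projections are omitted, and the real model operates on torch tensors.

```python
from __future__ import annotations


def relu(v: list[float]) -> list[float]:
    # Element-wise rectified linear unit
    return [max(0.0, a) for a in v]


def dense(weights: list[list[float]], v: list[float]) -> list[float]:
    # Matrix-vector product, one output per weight row (bias omitted)
    return [sum(w * a for w, a in zip(row, v)) for row in weights]


def res_fcnn_forward(
    x: list[float],
    hidden_layers: list[list[list[float]]],
    alpha: float = 0.6,
    latent_normalization: float = 1.0,
) -> dict[str, list[float]]:
    """Hypothetical residual FCNN forward pass with square hidden layers."""
    h = x
    for weights in hidden_layers:
        update = relu(dense(weights, h))
        # Assumed residual rule: alpha keeps the input, 1 - alpha adds the update
        h = [alpha * old + (1 - alpha) * new for old, new in zip(h, update)]
    # Assumed latent normalization: one learnable scalar rescales the embedding
    return {"H": [latent_normalization * a for a in h]}
```

Under this reading, alpha = 1 leaves the input untouched, while smaller values let each layer contribute more of its own output.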
- class gnn_tracking.models.graph_construction.GraphConstructionHeteroResFCNN(*, in_dim: int, hidden_dim: int, out_dim: int, depth: int, alpha: float = 0.6)
Bases: gnn_tracking.models.mlp.HeterogeneousResFCNN, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
Fully connected neural network for graph construction. Fully heterogeneous (i.e., two separate MLPs for node and edge features). Contains an additional normalization parameter for the latent space.
- _latent_normalization
- forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
- class gnn_tracking.models.graph_construction.GraphConstructionHeteroEncResFCNN(*, in_dim: int, hidden_dim_enc: int, hidden_dim: int, out_dim: int, depth_enc: int, depth: int, alpha: float = 0.6)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
Fully connected neural network for graph construction. Heterogeneous encoding. Contains an additional normalization parameter for the latent space.
- encoder
- fcnn
- _latent_normalization
- forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
- class gnn_tracking.models.graph_construction.GraphConstructionResIN(*, node_indim: int, edge_indim: int, h_outdim: int = 8, hidden_dim: int = 40, alpha: float = 0.5, n_layers: int = 1, alpha_fcnn: float = 0.5)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
Graph construction refinement with a stack of interaction networks with residual connections between them. Refinement means that we assume that we’re starting from the latent space and edges from GraphConstructionFCNN.
- Parameters:
node_indim – Input dimension of the node features
edge_indim – Input dimension of the edge features
h_outdim – Output dimension of the node features
hidden_dim – All other dimensions
alpha – Strength of residual connections for the INs
n_layers – Number of INs
alpha_fcnn – Strength of residual connections connecting back to the initial latent space from the FCNN (assuming that the first h_outdim features are its latent space output)
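Per the alpha_fcnn description, the first h_outdim input node features are taken to be the FCNN latent coordinates, and the refined output is connected back to them with strength alpha_fcnn. The exact formula is not given on this page; the sketch below assumes a per-node convex combination alpha_fcnn * latent + (1 - alpha_fcnn) * decoded. The helper mix_with_fcnn_latent is hypothetical and uses plain lists instead of torch tensors.

```python
from __future__ import annotations


def mix_with_fcnn_latent(
    node_features: list[list[float]],
    decoded: list[list[float]],
    h_outdim: int,
    alpha_fcnn: float = 0.5,
) -> list[list[float]]:
    """Hypothetical skip connection back to the FCNN latent space."""
    out = []
    for x_i, d_i in zip(node_features, decoded):
        # Assumption: the first h_outdim input features are the FCNN embedding
        latent = x_i[:h_outdim]
        if len(d_i) != h_outdim:
            raise ValueError("decoder output must have h_outdim features")
        # Assumed convex combination of FCNN latent and refined output
        out.append([alpha_fcnn * a + (1 - alpha_fcnn) * b for a, b in zip(latent, d_i)])
    return out
```

Under this assumption, alpha_fcnn = 1 passes the FCNN embedding through unchanged and alpha_fcnn = 0 keeps only the interaction network output.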
- _node_encoder
- _edge_encoder
- _resin
- _decoder: Callable
- _latent_normalization
- forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
- gnn_tracking.models.graph_construction.knn_with_max_radius(x: torch.Tensor, k: int, max_radius: float | None = None) → torch.Tensor
A version of kNN that excludes edges with a distance larger than a given radius.
- Parameters:
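x – Coordinates of the points between which distances are computed
k – Number of neighbors
max_radius – Edges with a distance larger than this radius are excluded (if None, no radius cut is applied)
- Returns:
Edge index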
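The idea of knn_with_max_radius can be illustrated with a brute-force, plain-Python sketch: connect each point to its k nearest neighbors, then drop edges whose length exceeds max_radius. This is not the library's implementation (which operates on torch tensors); the Euclidean metric, the exclusion of self-loops, and the [sources, targets] ordering of the returned edge index are assumptions of this sketch, and knn_with_max_radius_sketch is a hypothetical name.

```python
from __future__ import annotations

import math


def knn_with_max_radius_sketch(
    x: list[list[float]], k: int, max_radius: float | None = None
) -> list[list[int]]:
    """Brute-force kNN graph with an optional radius cut (illustration only).

    Returns a 2 x E edge index [sources, targets]; each target node is
    connected to its k nearest other nodes (self-loops excluded).
    """
    sources: list[int] = []
    targets: list[int] = []
    for i, xi in enumerate(x):
        # Distances from node i to all other nodes, nearest first
        dists = sorted((math.dist(xi, xj), j) for j, xj in enumerate(x) if j != i)
        for d, j in dists[:k]:
            # Exclude neighbors that are farther away than max_radius
            if max_radius is not None and d > max_radius:
                continue
            sources.append(j)
            targets.append(i)
    return [sources, targets]
```

Because the radius cut is applied after the neighbor search, a point can end up with fewer than k edges, but never more.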
- class gnn_tracking.models.graph_construction.MLGraphConstruction(ml: torch.nn.Module | None = None, *, ec: torch.nn.Module | None = None, max_radius: float = 1, max_num_neighbors: int = 256, use_embedding_features=False, ratio_of_false=None, build_edge_features=True, ec_threshold=None, ml_freeze: bool = True, ec_freeze: bool = True, embedding_slice: tuple[int | None, int | None] = (None, None))
Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
Builds graph from embedding space. If you want to start from a checkpoint, use MLGraphConstruction.from_chkpt.
- Parameters:
ml – Metric learning embedding module. If not specified, it is assumed that the node features from the data object are already the embedding coordinates. To use a subset of the embedding coordinates, use embedding_slice.
ec – Edge filter module; if given, it is applied directly
max_radius – Maximum radius for kNN
max_num_neighbors – Number of neighbors for kNN
use_embedding_features – Add embedding space features to node features (only if ml is not None)
ratio_of_false – Subsample false edges (using truth information)
build_edge_features – Build edge features (as difference and sum of node features)
ec_threshold – EC threshold for edge filter/classification
embedding_slice – Only used if ml is None (if ml is not None, all node features are used). The first element of the tuple is the start index and the second element is the end index of the node features to use; the default (None, None) selects all node features.
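Several of the parameters above describe simple operations on the embedding and the kNN edges. The plain-Python sketch below illustrates three of them under stated assumptions: how a (start, end) embedding_slice picks feature columns (with Python slice semantics, so (None, None) keeps all of them), how edge features could be formed from the difference and sum of the two endpoint features (the concatenation order is an assumption), and how an ec_threshold could filter edges by classifier score (keeping scores greater than or equal to the threshold is an assumption). All function names are hypothetical and not part of gnn_tracking.

```python
from __future__ import annotations


def select_embedding(
    x: list[list[float]],
    embedding_slice: tuple[int | None, int | None] = (None, None),
) -> list[list[float]]:
    # (start, end) behaves like x[:, start:end]; (None, None) keeps every feature
    sl = slice(*embedding_slice)
    return [row[sl] for row in x]


def build_edge_features(
    x: list[list[float]], edge_index: list[list[int]]
) -> list[list[float]]:
    # Assumed layout per edge: [x_src - x_dst] followed by [x_src + x_dst]
    feats = []
    for s, t in zip(*edge_index):
        diff = [a - b for a, b in zip(x[s], x[t])]
        plus = [a + b for a, b in zip(x[s], x[t])]
        feats.append(diff + plus)
    return feats


def apply_ec_threshold(
    edge_index: list[list[int]], scores: list[float], ec_threshold: float
) -> list[list[int]]:
    # Assumption: edges whose classifier score reaches the threshold are kept
    keep = [e for e, score in enumerate(scores) if score >= ec_threshold]
    return [[edge_index[0][e] for e in keep], [edge_index[1][e] for e in keep]]
```

For two nodes [1.0, 2.0] and [4.0, 6.0] joined by the edge 0 → 1, this layout gives the edge feature [-3.0, -4.0, 5.0, 8.0].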
- _ml
- _ef
- _validate_config() → None
Ensure that config makes sense.
- classmethod from_chkpt(ml_chkpt_path: str = '', ec_chkpt_path: str = '', *, ml_class_name: str = 'gnn_tracking.training.ml.MLModule', ec_class_name: str = 'gnn_tracking.training.ec.ECModule', ml_model_only: bool = True, ec_model_only: bool = True, **kwargs) → MLGraphConstruction
Build MLGraphConstruction from checkpointed models.
- Parameters:
ml_chkpt_path – Path to metric learning checkpoint
ec_chkpt_path – Path to edge filter checkpoint. If empty, no EC will be used.
ml_class_name – Class name of metric learning lightning module (default should almost always be fine)
ec_class_name – Class name of edge filter lightning module (default should almost always be fine)
ml_model_only – If True, only the torch model is loaded (excluding preprocessing steps from the lightning module)
ec_model_only – See ml_model_only
**kwargs – Additional arguments passed to MLGraphConstruction
- property out_dim: tuple[int, int]
Output dimensions of the node and edge features, as a (node, edge) tuple
- forward(data: torch_geometric.data.Data) → torch_geometric.data.Data
- class gnn_tracking.models.graph_construction.MLGraphConstructionFromChkpt
Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
Alias for MLGraphConstruction.from_chkpt for use in yaml files
- class gnn_tracking.models.graph_construction.MLPCTransformer(model: torch.nn.Module, *, original_features: bool = False, freeze: bool = True)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
Transforms a point cloud (PC) using a metric learning (ML) model. This is just a thin wrapper around the ML module with specification of what to do with the resulting latent space. In contrast to MLGraphConstructionFromChkpt, this class does not build a graph from the latent space but returns a transformed point cloud. Use MLPCTransformer.from_ml_chkpt to build from a checkpointed ML model.
Warning
In the current implementation, the original Data object is modified by forward.
- Parameters:
model – Metric learning model. Should return latent space with key “H”
original_features – Include original node features as node features (after the transformed ones)
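The original_features flag controls only how the new node features are assembled from the latent space "H" and the input features. A minimal sketch of that assembly step, assuming per-node lists instead of tensors; assemble_node_features is a hypothetical helper, not part of the MLPCTransformer API. Unlike forward, which per the warning above modifies the Data object, this sketch returns new lists and leaves its inputs untouched.

```python
from __future__ import annotations


def assemble_node_features(
    x: list[list[float]],
    latent: list[list[float]],
    original_features: bool = False,
) -> list[list[float]]:
    """Hypothetical helper: build new node features from the ML output "H"."""
    if not original_features:
        # Only the transformed (latent space) coordinates become node features
        return [list(h) for h in latent]
    # Original features are appended after the transformed ones
    return [list(h) + list(xi) for h, xi in zip(latent, x)]
```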
- _ml: torch.nn.Module
- classmethod from_ml_chkpt(chkpt_path: str, *, class_name: str = 'gnn_tracking.training.ml.MLModule', **kwargs)
Build MLPCTransformer from checkpointed ML model.
- Parameters:
chkpt_path – Path to checkpoint
class_name – Lightning module class name that was used for training. The default probably covers most cases.
**kwargs – Additional kwargs passed to MLPCTransformer constructor
- forward(data: torch_geometric.data.Data) → torch_geometric.data.Data
- class gnn_tracking.models.graph_construction.MLPCTransformerFromMLChkpt(model: torch.nn.Module, *, original_features: bool = False, freeze: bool = True)
Bases: MLPCTransformer
Transforms a point cloud (PC) using a metric learning (ML) model. This is just a thin wrapper around the ML module with specification of what to do with the resulting latent space. In contrast to MLGraphConstructionFromChkpt, this class does not build a graph from the latent space but returns a transformed point cloud. Use MLPCTransformer.from_ml_chkpt to build from a checkpointed ML model.
Warning
In the current implementation, the original Data object is modified by forward.
- Parameters:
model – Metric learning model. Should return latent space with key “H”
original_features – Include original node features as node features (after the transformed ones)