:py:mod:`gnn_tracking.models.graph_construction`
================================================

.. py:module:: gnn_tracking.models.graph_construction

.. autoapi-nested-parse::

   Models for embeddings used for graph construction.



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.models.graph_construction.GraphConstructionFCNN
   gnn_tracking.models.graph_construction.GraphConstructionResIN
   gnn_tracking.models.graph_construction.MLGraphConstruction
   gnn_tracking.models.graph_construction.MLGraphConstructionFromChkpt
   gnn_tracking.models.graph_construction.MLPCTransformer
   gnn_tracking.models.graph_construction.MLPCTransformerFromMLChkpt



Functions
~~~~~~~~~

.. autoapisummary::

   gnn_tracking.models.graph_construction.knn_with_max_radius



.. py:class:: GraphConstructionFCNN(*, in_dim: int, hidden_dim: int, out_dim: int, depth: int, alpha: float = 0.6)


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Fully connected neural network (FCNN) that maps node features to a metric learning embedding space.

   :param in_dim: Input dimension
   :param hidden_dim: Hidden dimension
   :param out_dim: Output dimension (dimension of the embedding space)
   :param depth: Number of layers
   :param alpha: Strength of the residual connections between layers
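
   A minimal, library-independent sketch of a residual update controlled by
   ``alpha``, assuming the common convex-combination form
   ``h_next = alpha * h + (1 - alpha) * relu(layer(h))``. This form and the helper
   names below are assumptions for illustration; the exact update used inside
   `GraphConstructionFCNN` may differ.

   .. code-block:: python

      def relu(v: list[float]) -> list[float]:
          return [max(0.0, a) for a in v]


      def linear(weights: list[list[float]], bias: list[float], v: list[float]) -> list[float]:
          # Dense layer; weights has one row per output feature
          return [sum(w * a for w, a in zip(row, v)) + b for row, b in zip(weights, bias)]


      def residual_step(h, weights, bias, alpha: float) -> list[float]:
          # Assumed residual form: alpha weights the skip connection,
          # (1 - alpha) weights the transformed features
          update = relu(linear(weights, bias, h))
          return [alpha * a + (1 - alpha) * u for a, u in zip(h, update)]


      h = [1.0, -2.0]
      identity = [[1.0, 0.0], [0.0, 1.0]]  # identity layer, for illustration only
      print(residual_step(h, identity, [0.0, 0.0], alpha=0.6))

   In this sketch, ``alpha`` close to 1 leaves the features almost unchanged, while
   ``alpha`` close to 0 effectively switches off the skip connection.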

   .. py:method:: reset_parameters()


   .. py:method:: _reset_layer_parameters(layer, var: float)
      :staticmethod:


   .. py:method:: forward(data: torch_geometric.data.Data) -> dict[str, torch.Tensor]



.. py:class:: GraphConstructionResIN(*, node_indim: int, edge_indim: int, h_outdim: int = 8, hidden_dim: int = 40, alpha: float = 0.5, n_layers: int = 1, alpha_fcnn: float = 0.5)


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Graph construction refinement with a stack of interaction networks with
   residual connections between them. Refinement means that we assume that we start
   from the latent space and edges produced by `GraphConstructionFCNN`.

   :param node_indim: Input dimension of the node features
   :param edge_indim: Input dimension of the edge features
   :param h_outdim: Output dimension of the node features
   :param hidden_dim: All other dimensions
   :param alpha: Strength of residual connections for the INs
   :param n_layers: Number of INs
   :param alpha_fcnn: Strength of residual connections connecting back to the initial
                      latent space from the FCNN (assuming that the first ``h_outdim``
                      features are its latent space output)

   .. py:method:: forward(data: torch_geometric.data.Data) -> dict[str, torch.Tensor]



.. py:function:: knn_with_max_radius(x: torch.Tensor, k: int, max_radius: float | None = None) -> torch.Tensor

   A version of kNN that excludes edges with a distance larger than a given radius.

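   :param x: Node coordinates (e.g., positions in the embedding space)
   :param k: Number of neighbors
   :param max_radius: Maximum distance between connected nodes; edges between
                      nodes that are farther apart are excluded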

   :returns: edge index
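
   The idea can be illustrated with a brute-force, pure-Python sketch. This is not
   the library's implementation, which operates on ``torch.Tensor`` inputs. The
   sketch assumes that ``x`` is a list of coordinate vectors, excludes self-loops,
   keeps neighbors at distance ``<= max_radius``, and returns the edge index as a
   pair of source/target lists; all of these conventions, and the function name,
   are assumptions.

   .. code-block:: python

      import math


      def knn_with_max_radius_sketch(x, k, max_radius=None):
          """Hypothetical brute-force kNN that drops neighbors beyond max_radius."""
          sources, targets = [], []
          for i, xi in enumerate(x):
              # Distances from node i to every other node (self-loops excluded)
              dists = sorted((math.dist(xi, xj), j) for j, xj in enumerate(x) if j != i)
              for d, j in dists[:k]:
                  # Keep only the k nearest neighbors that also lie within the radius
                  if max_radius is None or d <= max_radius:
                      sources.append(i)
                      targets.append(j)
          return sources, targets


      points = [[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]]
      print(knn_with_max_radius_sketch(points, k=2, max_radius=1.0))

   Here the third point has two neighbors by kNN, but both are farther away than
   ``max_radius``, so it ends up with no edges.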


.. py:class:: MLGraphConstruction(ml: torch.nn.Module | None = None, *, ec: torch.nn.Module | None = None, max_radius: float = 1, max_num_neighbors: int = 256, use_embedding_features=False, ratio_of_false=None, build_edge_features=True, ec_threshold=None, ml_freeze: bool = True, ec_freeze: bool = True, embedding_slice: tuple[int | None, int | None] = (None, None))


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Builds a graph from the embedding space. If you want to start from a checkpoint,
   use `MLGraphConstruction.from_chkpt`.

   :param ml: Metric learning embedding module. If not specified, it is assumed that
              the node features from the data object are already the embedding
              coordinates. To use a subset of the embedding coordinates, use
              ``embedding_slice``.
   :param ec: Edge classification module that is applied directly as an edge filter
   :param max_radius: Maximum radius for kNN
   :param max_num_neighbors: Maximum number of neighbors for kNN
   :param use_embedding_features: Add embedding space features to node features
                                  (only if ``ml`` is not None)
   :param ratio_of_false: Subsample false edges (using truth information)
   :param build_edge_features: Build edge features (as difference and sum of node
                               features).
   :param ec_threshold: EC threshold for edge filter/classification
   :param embedding_slice: Used if ``ml`` is None. If ``(None, None)`` (the default),
                           all node features are used. Otherwise, the first element is
                           the start index and the second element is the end index.
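
   A minimal pure-Python sketch of two of the options above: ``embedding_slice``
   interpreted as Python slice bounds ``(start, end)``, and ``build_edge_features``
   taken as the concatenation of the difference and the sum of the two endpoint
   features. The helper names, the per-node list representation, and the
   concatenation order are assumptions for illustration, not the library's
   implementation.

   .. code-block:: python

      def select_embedding(features, embedding_slice=(None, None)):
          # (None, None) behaves like features[:], i.e. all features are kept
          start, end = embedding_slice
          return features[start:end]


      def build_edge_features(x, sources, targets):
          # Assumed layout per edge (i, j): [x_i - x_j] followed by [x_i + x_j]
          feats = []
          for i, j in zip(sources, targets):
              diff = [a - b for a, b in zip(x[i], x[j])]
              total = [a + b for a, b in zip(x[i], x[j])]
              feats.append(diff + total)
          return feats


      node_features = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
      x = [select_embedding(f, (0, 2)) for f in node_features]  # [[1.0, 2.0], [4.0, 5.0]]
      print(build_edge_features(x, sources=[0], targets=[1]))  # [[-3.0, -3.0, 5.0, 7.0]]

   Using both the difference and the sum keeps the edge features informative even
   though the difference alone changes sign when the edge direction is flipped.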

   .. py:property:: out_dim
      :type: tuple[int, int]

      Node and edge output dimensions.

   .. py:method:: _validate_config() -> None

      Ensure that config makes sense.


   .. py:method:: from_chkpt(ml_chkpt_path: str = '', ec_chkpt_path: str = '', *, ml_class_name: str = 'gnn_tracking.training.ml.MLModule', ec_class_name: str = 'gnn_tracking.training.ec.ECModule', ml_model_only: bool = True, ec_model_only: bool = True, **kwargs) -> MLGraphConstruction
      :classmethod:

      Build `MLGraphConstruction` from checkpointed models.

      :param ml_chkpt_path: Path to metric learning checkpoint
      :param ec_chkpt_path: Path to edge filter checkpoint. If empty, no EC will be
                            used.
      :param ml_class_name: Class name of metric learning lightning module
                            (default should almost always be fine)
      :param ec_class_name: Class name of edge filter lightning module
                            (default should almost always be fine)
      :param ml_model_only: Load only the torch model (excluding the preprocessing
                            steps of the lightning module)
      :param ec_model_only: See ``ml_model_only``
      :param \*\*kwargs: Additional arguments passed to `MLGraphConstruction`


   .. py:method:: forward(data: torch_geometric.data.Data) -> torch_geometric.data.Data



.. py:class:: MLGraphConstructionFromChkpt


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Alias for `MLGraphConstruction.from_chkpt` for use in YAML configuration files.


.. py:class:: MLPCTransformer(model: torch.nn.Module, *, original_features: bool = False, freeze: bool = True)


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Transforms a point cloud (PC) using a metric learning (ML) model.
   This is just a thin wrapper around the ML module with specification of what
   to do with the resulting latent space.
   In contrast to `MLGraphConstructionFromChkpt`, this class does not build a
   graph from the latent space but returns a transformed point cloud.
   Use `MLPCTransformer.from_ml_chkpt` to build from a checkpointed ML model.

   .. warning::
       In the current implementation, the original ``Data`` object is modified
       by `forward`.

   :param model: Metric learning model. Should return the latent space under the key ``"H"``.
   :param original_features: Include original node features as node features (after
                             the transformed ones)
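
   A minimal pure-Python sketch of how the node features could be assembled, based
   only on the parameter descriptions above: the latent space is taken from the
   model output under the key ``"H"``, and with ``original_features`` the original
   node features are appended after it. The function name, the toy model, and the
   list-of-lists representation are hypothetical; the real `forward` works on a
   ``Data`` object and modifies it in place, whereas this sketch returns a new list.

   .. code-block:: python

      def transform_point_cloud(model, x, original_features=False):
          # The model is expected to return the latent space under the key "H"
          latent = model(x)["H"]
          if original_features:
              # Transformed features first, then the original node features
              return [list(h) + list(f) for h, f in zip(latent, x)]
          return [list(h) for h in latent]


      def toy_model(x):
          # Hypothetical stand-in for a metric learning model: 1D latent space
          return {"H": [[sum(f)] for f in x]}


      x = [[10.0, 20.0], [30.0, 40.0]]
      print(transform_point_cloud(toy_model, x, original_features=True))
      # [[30.0, 10.0, 20.0], [70.0, 30.0, 40.0]]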

   .. py:method:: from_ml_chkpt(chkpt_path: str, *, class_name: str = 'gnn_tracking.training.ml.MLModule', **kwargs)
      :classmethod:

      Build `MLPCTransformer` from checkpointed ML model.

      :param chkpt_path: Path to checkpoint
      :param class_name: Lightning module class name that was used for training.
                         The default probably covers most cases.
      :param \*\*kwargs: Additional kwargs passed to `MLPCTransformer` constructor


   .. py:method:: forward(data: torch_geometric.data.Data) -> torch_geometric.data.Data



.. py:class:: MLPCTransformerFromMLChkpt(model: torch.nn.Module, *, original_features: bool = False, freeze: bool = True)


   Bases: :py:obj:`MLPCTransformer`

   Transforms a point cloud (PC) using a metric learning (ML) model.
   This is just a thin wrapper around the ML module with specification of what
   to do with the resulting latent space.
   In contrast to `MLGraphConstructionFromChkpt`, this class does not build a
   graph from the latent space but returns a transformed point cloud.
   Use `MLPCTransformer.from_ml_chkpt` to build from a checkpointed ML model.

   .. warning::
       In the current implementation, the original ``Data`` object is modified
       by `forward`.

   :param model: Metric learning model. Should return the latent space under the key ``"H"``.
   :param original_features: Include original node features as node features (after
                             the transformed ones)


