:py:mod:`gnn_tracking.models.track_condensation_networks`
=========================================================

.. py:module:: gnn_tracking.models.track_condensation_networks

.. autoapi-nested-parse::

   This module holds the main training models for GNN tracking.



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.models.track_condensation_networks.INConvBlock
   gnn_tracking.models.track_condensation_networks.PointCloudTCN
   gnn_tracking.models.track_condensation_networks.ModularGraphTCN
   gnn_tracking.models.track_condensation_networks.GraphTCN
   gnn_tracking.models.track_condensation_networks.PerfectECGraphTCN
   gnn_tracking.models.track_condensation_networks.PreTrainedECGraphTCN
   gnn_tracking.models.track_condensation_networks.GraphTCNForMLGCPipeline




.. py:class:: INConvBlock(indim, h_dim, e_dim, L, k, hidden_dim=100)


   Bases: :py:obj:`torch.nn.Module`

   .. py:method:: forward(x: torch.Tensor, alpha: float = 0.5) -> torch.Tensor



.. py:class:: PointCloudTCN(node_indim: int, h_dim=10, e_dim=10, h_outdim=5, hidden_dim=100, N_blocks=3, L=3)


   Bases: :py:obj:`torch.nn.Module`

   Model to directly process point clouds rather than start with a graph.

   :param node_indim: node feature dimension
   :param h_dim: node dimension in latent space
   :param e_dim: edge dimension in latent space
   :param h_outdim: output dimension in clustering space
   :param hidden_dim: hidden width of all nn.Linear layers
   :param N_blocks: number of edge_conv + IN blocks
   :param L: message passing depth in each block

   .. py:method:: forward(data: torch_geometric.data.Data, alpha: float = 0.5) -> dict[str, torch.Tensor | None]

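Since this model starts from a point cloud, some graph has to be built from the hits before edge convolutions can run. The sketch below shows one common way to do that, a k-nearest-neighbour graph, in plain Python. It is an illustration under assumptions, not the module's implementation: the source does not state how `PointCloudTCN` builds its graph, and reading the ``k`` argument of `INConvBlock` as the neighbour count is a guess. The helper ``knn_edges`` is hypothetical.

```python
import math


def knn_edges(points, k):
    """Return directed edges (i, j) from each point i to its k nearest neighbours j.

    Hypothetical helper for illustration; not part of gnn_tracking.
    """
    edges = []
    for i, p in enumerate(points):
        # Euclidean distance from point i to every other point
        others = [(math.dist(p, q), j) for j, q in enumerate(points) if j != i]
        # Keep the k closest neighbours (ties broken by index)
        for _, j in sorted(others)[:k]:
            edges.append((i, j))
    return edges


# Two well-separated pairs of hits on a line
points = [(0.0, 0.0), (1.0, 0.0), (5.0, 0.0), (6.0, 0.0)]
edges = knn_edges(points, k=1)
```

With ``k=1`` each hit connects only to its closest neighbour, so the two pairs stay disconnected from each other.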


.. py:class:: ModularGraphTCN(*, ec: torch.nn.Module | None = None, hc_in: torch.nn.Module, node_indim: int, edge_indim: int, h_dim: int = 5, e_dim: int = 4, h_outdim: int = 2, hidden_dim: int = 40, feed_edge_weights: bool = False, ec_threshold: float = 0.5, mask_orphan_nodes: bool = False, use_ec_embeddings_for_hc: bool = False, alpha_latent: float = 0.0, n_embedding_coords: int = 0)


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Track condensation network based on preconstructed graphs. This module
   combines the following:

   * Node and edge encoders to get to `(h_dim, e_dim)`
   * a track condensation network `hc_in`
   * an optional edge classifier

   Additional options configure how output from the edge classifier can be included
   in the track condensation network.

   :param ec: Edge classifier
   :param hc_in: Track condenser interaction network.
   :param node_indim: Node feature dimension
   :param edge_indim: Edge feature dimension
   :param h_dim: node dimension in the condensation interaction networks
   :param e_dim: edge dimension in the condensation interaction networks
   :param h_outdim: output dimension in clustering space
   :param hidden_dim: width of hidden layers in all perceptrons
   :param feed_edge_weights: whether to feed edge weights to the track condenser
   :param ec_threshold: threshold for edge classification
   :param mask_orphan_nodes: Mask nodes with no connections after EC
   :param use_ec_embeddings_for_hc: Use edge classifier embeddings as input to
                                    track condenser. This currently assumes that h_dim and e_dim are
                                    also the dimensions used in the EC.
   :param alpha_latent: Assume that we're already starting from a latent space given
                        by the first ``h_outdim`` node features. In this case, this is the
                        strength of the residual connection
   :param n_embedding_coords: Number of embedding coordinates for which to add a
                              residual connection. To be used with `alpha_latent`.

   .. py:method:: forward(data: torch_geometric.data.Data) -> dict[str, torch.Tensor | None]

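The interplay of ``ec_threshold`` and ``mask_orphan_nodes`` can be sketched in plain Python. This is a minimal illustration under assumptions, not the library's code: edges are written as a list of node pairs rather than a `torch_geometric` edge index tensor, the comparison is assumed to be strict (``>``), and both helper names are hypothetical.

```python
def threshold_edges(edge_index, edge_weights, ec_threshold=0.5):
    """Keep edges whose classifier score exceeds ec_threshold (strict > is assumed)."""
    return [e for e, w in zip(edge_index, edge_weights) if w > ec_threshold]


def orphan_mask(n_nodes, kept_edges):
    """Return True for every node that still touches at least one kept edge."""
    connected = {node for edge in kept_edges for node in edge}
    return [i in connected for i in range(n_nodes)]


edges = [(0, 1), (1, 2), (2, 3)]
scores = [0.9, 0.2, 0.4]  # hypothetical edge classifier outputs

kept = threshold_edges(edges, scores)  # only (0, 1) passes the 0.5 cut
mask = orphan_mask(4, kept)            # nodes 2 and 3 lost all their edges
```

Nodes whose mask entry is ``False`` are the "orphans" that ``mask_orphan_nodes`` would exclude once the edge classifier has pruned the graph.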


.. py:class:: GraphTCN(node_indim: int, edge_indim: int, *, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_ec=3, L_hc=3, alpha_ec: float = 0.5, alpha_hc: float = 0.5, **kwargs)


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   `ModularGraphTCN` with `ECForGraphTCN` as the edge classification step and
   several interaction networks as residual layers for the track condenser
   network.

   This is a small wrapper around `ModularGraphTCN`, mostly to make sure that
   we can change the underlying implementation without invalidating config
   files that reference this class.

   :param node_indim: Node feature dim
   :param edge_indim: Edge feature dim
   :param h_dim: node dimension in latent space
   :param e_dim: edge dimension in latent space
   :param h_outdim: output dimension in clustering space
   :param hidden_dim: width of hidden layers in all perceptrons
   :param L_ec: message passing depth for edge classifier
   :param L_hc: message passing depth for track condenser
   :param alpha_ec: strength of residual connection for multi-layer interaction
                    networks in edge classifier
   :param alpha_hc: strength of residual connection for multi-layer interaction
                    networks in track condenser
   :param \*\*kwargs: Additional keyword arguments passed to `ModularGraphTCN`

   .. py:method:: forward(data: torch_geometric.data.Data) -> dict[str, torch.Tensor | None]

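To make "strength of residual connection" concrete, the sketch below stacks several layers with a blended update. The convex combination ``alpha * x + (1 - alpha) * layer(x)`` is an assumption about the form of the residual connection; the source only says that ``alpha_ec`` and ``alpha_hc`` set its strength. Plain Python lists stand in for tensors, and ``residual_update`` and ``run_layers`` are hypothetical names.

```python
def residual_update(x, dx, alpha=0.5):
    """Blend the old features x with the layer output dx (assumed convex combination)."""
    return [alpha * xi + (1.0 - alpha) * dxi for xi, dxi in zip(x, dx)]


def run_layers(x, layers, alpha=0.5):
    """Apply each layer in turn, mixing its output with the running features."""
    for layer in layers:
        x = residual_update(x, layer(x), alpha)
    return x


def double(features):
    """Toy stand-in for one interaction network layer."""
    return [2.0 * f for f in features]


# Three layers, mirroring the default message passing depth L_hc=3
out = run_layers([1.0], [double, double, double], alpha=0.5)
```

Under this assumed form, ``alpha=1.0`` passes the input through unchanged and ``alpha=0.0`` removes the residual path entirely.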


.. py:class:: PerfectECGraphTCN(*, node_indim: int, edge_indim: int, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_hc=3, alpha_hc: float = 0.5, ec_tpr=1.0, ec_tnr=1.0, **kwargs)


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Similar to `GraphTCN` but with a "perfect" (i.e., truth-based) edge
   classifier.

   This is a small wrapper around `ModularGraphTCN`, mostly to make sure that
   we can change the underlying implementation without invalidating config
   files that reference this class.

   :param node_indim: Node feature dim. Determined by input data.
   :param edge_indim: Edge feature dim. Determined by input data.
   :param h_dim: node dimension after encoding
   :param e_dim: edge dimension after encoding
   :param h_outdim: output dimension in clustering space
   :param hidden_dim: dimension of hidden layers in all MLPs used in the interaction
                      networks
   :param L_hc: message passing depth for track condenser
   :param alpha_hc: strength of residual connection for multi-layer interaction
                    networks
   :param ec_tpr: true positive rate of the perfect edge classifier
   :param ec_tnr: true negative rate of the perfect edge classifier
   :param \*\*kwargs: Passed to `ModularGraphTCN`

   .. py:method:: forward(data: torch_geometric.data.Data) -> dict[str, torch.Tensor | None]

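One way to read ``ec_tpr`` and ``ec_tnr`` is as the rates at which a truth-based classifier labels true and false edges correctly. The sketch below implements that reading in plain Python. It is an assumption about the behaviour rather than the library's code, and ``perfect_ec`` is a hypothetical function name.

```python
import random


def perfect_ec(truth, ec_tpr=1.0, ec_tnr=1.0, rng=None):
    """Return edge weights derived from truth labels, degraded to the given rates.

    True edges get weight 1.0 with probability ec_tpr, false edges get
    weight 0.0 with probability ec_tnr; otherwise the label is flipped.
    """
    rng = rng or random.Random()
    weights = []
    for is_true in truth:
        if is_true:
            weights.append(1.0 if rng.random() < ec_tpr else 0.0)
        else:
            weights.append(0.0 if rng.random() < ec_tnr else 1.0)
    return weights


truth = [True, False, True]
exact = perfect_ec(truth)                             # defaults reproduce the truth
flipped = perfect_ec(truth, ec_tpr=0.0, ec_tnr=0.0)   # every label is inverted
```

Lowering either rate below 1.0 gives a controlled way to study how the track condenser copes with an imperfect edge classifier.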


.. py:class:: PreTrainedECGraphTCN(ec, *, node_indim: int, edge_indim: int, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_hc=3, alpha_hc: float = 0.5, **kwargs)


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   GraphTCN for use with a pre-trained edge classifier.

   This is a small wrapper around `ModularGraphTCN`, mostly to make sure that
   we can change the underlying implementation without invalidating config
   files that reference this class.

   :param ec: Pre-trained edge classifier
   :param node_indim: Node feature dim. Determined by input data.
   :param edge_indim: Edge feature dim. Determined by input data.
   :param h_dim: node dimension after encoding
   :param e_dim: edge dimension after encoding
   :param h_outdim: output dimension in clustering space
   :param hidden_dim: dimension of hidden layers in all MLPs used in the interaction
                      networks
   :param L_hc: message passing depth for track condenser
   :param alpha_hc: strength of residual connection for multi-layer interaction
                    networks
   :param \*\*kwargs: Passed to `ModularGraphTCN`

   .. py:method:: forward(data: torch_geometric.data.Data) -> dict[str, torch.Tensor | None]



.. py:class:: GraphTCNForMLGCPipeline(*, node_indim: int, edge_indim: int, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_hc=3, alpha_hc: float = 0.5, **kwargs)


   Bases: :py:obj:`torch.nn.Module`, :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   GraphTCN for use with a metric learning graph construction pipeline.

   This is a small wrapper around `ModularGraphTCN`, mostly to make sure that
   we can change the underlying implementation without invalidating config
   files that reference this class.

   :param node_indim: Node feature dim. Determined by input data.
   :param edge_indim: Edge feature dim. Determined by input data.
   :param h_dim: node dimension after encoding
   :param e_dim: edge dimension after encoding
   :param h_outdim: output dimension in clustering space
   :param hidden_dim: dimension of hidden layers in all MLPs used in the interaction
                      networks
   :param L_hc: message passing depth for track condenser
   :param alpha_hc: strength of residual connection for multi-layer interaction
                    networks
   :param \*\*kwargs: Passed to `ModularGraphTCN`

   .. py:method:: forward(data: torch_geometric.data.Data) -> dict[str, torch.Tensor | None]



