gnn_tracking.models.track_condensation_networks

This module holds the main training models for GNN tracking.

Module Contents

Classes

INConvBlock

PointCloudTCN

Model to directly process point clouds rather than start with a graph.

ModularGraphTCN

Track condensation network based on preconstructed graphs.

GraphTCN

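ModularGraphTCN with ECForGraphTCN as edge classification step and several interaction networks as residual layers for the track condenser network.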

PerfectECGraphTCN

Similar to GraphTCN but with a "perfect" (i.e., truth-based) edge classifier.

PreTrainedECGraphTCN

GraphTCN for use with a pre-trained edge classifier.

GraphTCNForMLGCPipeline

GraphTCN for use with a metric learning graph construction pipeline.

class gnn_tracking.models.track_condensation_networks.INConvBlock(indim, h_dim, e_dim, L, k, hidden_dim=100)

Bases: torch.nn.Module

forward(x: torch.Tensor, alpha: float = 0.5) → torch.Tensor

class gnn_tracking.models.track_condensation_networks.PointCloudTCN(node_indim: int, h_dim=10, e_dim=10, h_outdim=5, hidden_dim=100, N_blocks=3, L=3)

Bases: torch.nn.Module

Model to directly process point clouds rather than start with a graph.

Parameters:
  • node_indim – node feature dimension

  • h_dim – node dimension in latent space

  • e_dim – edge dimension in latent space

  • h_outdim – output dimension in clustering space

  • hidden_dim – hidden width of all nn.Linear layers

  • N_blocks – number of edge_conv + IN blocks

  • L – message passing depth in each block
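Because PointCloudTCN receives hits without a preconstructed graph, edges have to be built before message passing can run. The sketch below shows one common way to do that, a k-nearest-neighbour search, in dependency-free Python. It is an illustration only: the function `knn_edges` is hypothetical, and that the model builds its graph this way is an assumption (the `k` argument of INConvBlock and the "edge_conv" blocks merely suggest a neighbour-based construction). The library itself operates on torch tensors.

```python
import math


def knn_edges(points, k):
    """Return directed edges (i, j) linking each point i to its k nearest neighbours j.

    Hypothetical helper for illustration; not part of gnn_tracking.
    """
    edges = []
    for i, p in enumerate(points):
        # Euclidean distance from point i to every other point
        dists = [(math.dist(p, q), j) for j, q in enumerate(points) if j != i]
        dists.sort()
        edges.extend((i, j) for _, j in dists[:k])
    return edges


# Four 2D points; each one is connected to its single nearest neighbour.
points = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (5.0, 5.0)]
edges = knn_edges(points, k=1)
print(edges)
```

In a model with several blocks, such a search could be repeated in the latent space of each block, so that the graph changes as the node embeddings are refined.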

forward(data: torch_geometric.data.Data, alpha: float = 0.5) → dict[str, torch.Tensor | None]

class gnn_tracking.models.track_condensation_networks.ModularGraphTCN(*, ec: torch.nn.Module | None = None, hc_in: torch.nn.Module, node_indim: int, edge_indim: int, h_dim: int = 5, e_dim: int = 4, h_outdim: int = 2, hidden_dim: int = 40, feed_edge_weights: bool = False, ec_threshold: float = 0.5, mask_orphan_nodes: bool = False, use_ec_embeddings_for_hc: bool = False, alpha_latent: float = 0.0, n_embedding_coords: int = 0, heterogeneous_node_encoder: bool = False)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Track condensation network based on preconstructed graphs. This module combines the following:

  • Node and edge encoders to get to (h_dim, e_dim)

  • a track condensation network hc_in

  • an optional edge classifier

Additional options configure how output from the edge classifier can be included in the track condensation network.

Parameters:
  • ec – Edge classifier

  • hc_in – Track condenser interaction network.

  • node_indim – Node feature dimension

  • edge_indim – Edge feature dimension

  • h_dim – node dimension in the condensation interaction networks

  • e_dim – edge dimension in the condensation interaction networks

  • h_outdim – output dimension in clustering space

  • hidden_dim – width of hidden layers in all perceptrons

  • feed_edge_weights – whether to feed edge weights to the track condenser

  • ec_threshold – threshold for edge classification

  • mask_orphan_nodes – Mask nodes with no connections after EC

  • use_ec_embeddings_for_hc – Use edge classifier embeddings as input to track condenser. This currently assumes that h_dim and e_dim are also the dimensions used in the EC.

  • alpha_latent – Strength of the residual connection, used when we assume that we’re already starting from a latent space given by the first h_outdim node features.

  • n_embedding_coords – Number of embedding coordinates for which to add a residual connection. To be used with alpha_latent.

  • heterogeneous_node_encoder – Whether to use different encoders for pixel/strip
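The interplay of ec_threshold and mask_orphan_nodes can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the library's implementation: the function `apply_edge_classifier` is hypothetical, edges are assumed to be kept when their weight is strictly above the threshold, and an "orphan" is taken to be a node that appears in no surviving edge.

```python
def apply_edge_classifier(n_nodes, edges, weights, ec_threshold=0.5,
                          mask_orphan_nodes=False):
    """Drop low-weight edges and optionally flag nodes left without any edge.

    Hypothetical helper for illustration; not part of gnn_tracking.
    """
    # Assumption: an edge survives if its weight exceeds the threshold.
    kept = [e for e, w in zip(edges, weights) if w > ec_threshold]
    if mask_orphan_nodes:
        connected = {node for edge in kept for node in edge}
        node_mask = [i in connected for i in range(n_nodes)]
    else:
        # Without masking, every node is passed on to the track condenser.
        node_mask = [True] * n_nodes
    return kept, node_mask


edges = [(0, 1), (1, 2), (2, 3)]
weights = [0.9, 0.2, 0.4]
kept, node_mask = apply_edge_classifier(4, edges, weights, mask_orphan_nodes=True)
print(kept)       # only (0, 1) has a weight above 0.5
print(node_mask)  # nodes 2 and 3 have lost all their edges
```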

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor | None]

class gnn_tracking.models.track_condensation_networks.GraphTCN(node_indim: int, edge_indim: int, *, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_ec=3, L_hc=3, alpha_ec: float = 0.5, alpha_hc: float = 0.5, **kwargs)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

ModularGraphTCN with ECForGraphTCN as edge classification step and several interaction networks as residual layers for the track condenser network.

This is a small wrapper around ModularGraphTCN, mostly to make sure that we can change the underlying implementation without invalidating config files that reference this class.

Parameters:
  • node_indim – Node feature dim

  • edge_indim – Edge feature dim

  • h_dim – node dimension in latent space

  • e_dim – edge dimension in latent space

  • h_outdim – output dimension in clustering space

  • hidden_dim – width of hidden layers in all perceptrons

  • L_ec – message passing depth for edge classifier

  • L_hc – message passing depth for track condenser

  • alpha_ec – strength of residual connection for multi-layer interaction networks in edge classifier

  • alpha_hc – strength of residual connection for multi-layer interaction networks in track condenser

  • **kwargs – Additional keyword arguments passed to ModularGraphTCN
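To give a feeling for how a message passing depth (L_ec, L_hc) and a residual strength (alpha_ec, alpha_hc) interact, here is a small sketch in plain Python. It assumes the common convention that each layer mixes its input and its update as `alpha * x + (1 - alpha) * update`; the exact formula used by the library is not stated in this page, and the names `residual_step` and `run_stack` are hypothetical.

```python
def residual_step(x, layer, alpha=0.5):
    """Blend the layer input with the layer output (assumed convention)."""
    update = layer(x)
    return [alpha * xi + (1 - alpha) * ui for xi, ui in zip(x, update)]


def run_stack(x, layers, alpha=0.5):
    """Apply len(layers) residual layers in sequence (depth L = len(layers))."""
    for layer in layers:
        x = residual_step(x, layer, alpha)
    return x


def double(values):
    """Toy stand-in for an interaction network layer."""
    return [2 * v for v in values]


x = [1.0, 2.0]
print(run_stack(x, [double] * 3, alpha=0.5))  # each layer scales x by 1.5
print(run_stack(x, [double] * 3, alpha=1.0))  # alpha = 1 ignores the updates
```

Under this convention, alpha close to 1 keeps the node features nearly unchanged across layers, while alpha close to 0 lets every layer overwrite them.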

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor | None]

class gnn_tracking.models.track_condensation_networks.PerfectECGraphTCN(*, node_indim: int, edge_indim: int, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_hc=3, alpha_hc: float = 0.5, ec_tpr=1.0, ec_tnr=1.0, **kwargs)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Similar to GraphTCN but with a “perfect” (i.e., truth based) edge classifier.

This is a small wrapper around ModularGraphTCN, mostly to make sure that we can change the underlying implementation without invalidating config files that reference this class.

Parameters:
  • node_indim – Node feature dim. Determined by input data.

  • edge_indim – Edge feature dim. Determined by input data.

  • h_dim – node dimension after encoding

  • e_dim – edge dimension after encoding

  • h_outdim – output dimension in clustering space

  • hidden_dim – dimension of hidden layers in all MLPs used in the interaction networks

  • L_hc – message passing depth for track condenser

  • alpha_hc – strength of residual connection for multi-layer interaction networks

  • ec_tpr – true positive rate of the perfect edge classifier

  • ec_tnr – true negative rate of the perfect edge classifier

  • **kwargs – Passed to ModularGraphTCN
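The roles of ec_tpr and ec_tnr can be illustrated with a truth-based classifier that deliberately degrades its own output. The following plain-Python sketch is an assumption about how such a classifier could behave, not the library's code: the function `perfect_edge_classifier` and its `seed` argument are hypothetical, and each edge is assumed to be flipped independently.

```python
import random


def perfect_edge_classifier(truth, tpr=1.0, tnr=1.0, seed=0):
    """Return edge decisions derived from truth labels.

    True edges are accepted with probability tpr; false edges are
    rejected with probability tnr. Hypothetical helper for illustration.
    """
    rng = random.Random(seed)
    decisions = []
    for is_true in truth:
        if is_true:
            # A true edge is kept with probability tpr.
            decisions.append(rng.random() < tpr)
        else:
            # A false edge is rejected with probability tnr, else it leaks through.
            decisions.append(rng.random() >= tnr)
    return decisions


truth = [True, False, True, True, False]
print(perfect_edge_classifier(truth))                    # reproduces the truth
print(perfect_edge_classifier(truth, tpr=0.0, tnr=1.0))  # every edge rejected
```

Lowering tpr or tnr below 1.0 makes it possible to study how robust the track condenser is against a given level of edge classifier errors.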

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor | None]

class gnn_tracking.models.track_condensation_networks.PreTrainedECGraphTCN(ec, *, node_indim: int, edge_indim: int, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_hc=3, alpha_hc: float = 0.5, **kwargs)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

GraphTCN for use with a pre-trained edge classifier.

This is a small wrapper around ModularGraphTCN, mostly to make sure that we can change the underlying implementation without invalidating config files that reference this class.

Parameters:
  • ec – Pre-trained edge classifier

  • node_indim – Node feature dim. Determined by input data.

  • edge_indim – Edge feature dim. Determined by input data.

  • h_dim – node dimension after encoding

  • e_dim – edge dimension after encoding

  • h_outdim – output dimension in clustering space

  • hidden_dim – dimension of hidden layers in all MLPs used in the interaction networks

  • L_hc – message passing depth for track condenser

  • alpha_hc – strength of residual connection for multi-layer interaction networks

  • **kwargs – Passed to ModularGraphTCN

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor | None]

class gnn_tracking.models.track_condensation_networks.GraphTCNForMLGCPipeline(*, node_indim: int, edge_indim: int, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_hc=3, alpha_hc: float = 0.5, **kwargs)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

GraphTCN for use with a metric learning graph construction pipeline.

This is a small wrapper around ModularGraphTCN, mostly to make sure that we can change the underlying implementation without invalidating config files that reference this class.

Parameters:
  • node_indim – Node feature dim. Determined by input data.

  • edge_indim – Edge feature dim. Determined by input data.

  • h_dim – node dimension after encoding

  • e_dim – edge dimension after encoding

  • h_outdim – output dimension in clustering space

  • hidden_dim – dimension of hidden layers in all MLPs used in the interaction networks

  • L_hc – message passing depth for track condenser

  • alpha_hc – strength of residual connection for multi-layer interaction networks

  • **kwargs – Passed to ModularGraphTCN
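The wrapper pattern shared by these classes (a fixed constructor signature in front of a more configurable model) can be sketched as below. All names are hypothetical stand-ins chosen for illustration (`ModularNet`, `StableNet`, the `"result"` key) and plain Python objects replace torch modules; the sketch only shows the idea of building a sub-network inside the wrapper and forwarding the remaining keyword arguments.

```python
class ModularNet:
    """Stand-in for a fully configurable model (hypothetical)."""

    def __init__(self, *, hc_in, h_dim=5, mask_orphan_nodes=False):
        self.hc_in = hc_in
        self.h_dim = h_dim
        self.mask_orphan_nodes = mask_orphan_nodes

    def forward(self, data):
        return {"result": self.hc_in(data)}


class StableNet:
    """Thin wrapper whose signature stays fixed for config files (hypothetical)."""

    def __init__(self, *, h_dim=5, L_hc=3, **kwargs):
        # The sub-network is assembled here, so its internals can change
        # without touching any config that only names StableNet.
        def hc_in(data):
            return [x * L_hc for x in data]

        # Remaining keyword arguments go straight to the modular model.
        self._model = ModularNet(hc_in=hc_in, h_dim=h_dim, **kwargs)

    def forward(self, data):
        return self._model.forward(data)


net = StableNet(h_dim=8, mask_orphan_nodes=True)
print(net.forward([1, 2]))
```

A config file that references only the wrapper and its arguments stays valid even if the modular model underneath is reorganised.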

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor | None]