gnn_tracking.models.track_condensation_networks
This module holds the main training models for GNN tracking.
Classes
- INConvBlock
- PointCloudTCN – Model to directly process point clouds rather than start with a graph.
- ModularGraphTCN – Track condensation network based on preconstructed graphs.
- GraphTCN – ModularGraphTCN with ECForGraphTCN as the edge classification step.
- PerfectECGraphTCN – Similar to GraphTCN but with a "perfect" (i.e., truth based) edge classifier.
- PreTrainedECGraphTCN – GraphTCN for use with a pre-trained edge classifier.
- GraphTCNForMLGCPipeline – GraphTCN for use with a metric learning graph construction pipeline.
Module Contents
- class gnn_tracking.models.track_condensation_networks.INConvBlock(indim, h_dim, e_dim, L, k, hidden_dim=100)
Bases: torch.nn.Module
- relu
- node_encoder
- edge_conv
- edge_encoder
- layers = []
- forward(x: torch.Tensor, alpha: float = 0.5) → torch.Tensor
- class gnn_tracking.models.track_condensation_networks.PointCloudTCN(node_indim: int, h_dim=10, e_dim=10, h_outdim=5, hidden_dim=100, N_blocks=3, L=3)
Bases: torch.nn.Module
Model to directly process point clouds rather than start with a graph.
- Parameters:
node_indim – Node feature dimension
h_dim – node dimension in latent space
e_dim – edge dimension in latent space
h_outdim – output dimension in clustering space
hidden_dim – hidden width of all nn.Linear layers
N_blocks – number of edge_conv + IN blocks
L – message passing depth in each block
- layers
- B
- X
- forward(data: torch_geometric.data.Data, alpha: float = 0.5) → dict[str, torch.Tensor | None]
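PointCloudTCN starts from hits rather than from a graph, so it has to build edges itself. The edge_conv member and the k argument of INConvBlock suggest a k-nearest-neighbour construction. The sketch below illustrates that technique in plain Python, assuming Euclidean distance in feature space. The function knn_edges is hypothetical and not part of gnn_tracking, which presumably works on tensors.

```python
import math


def knn_edges(points: list[tuple[float, ...]], k: int) -> list[tuple[int, int]]:
    """Hypothetical helper: directed edges (i, j) from each point i to its k nearest neighbours j.

    Assumes Euclidean distance; ties are broken by the lower neighbour index.
    """
    edges: list[tuple[int, int]] = []
    for i, p in enumerate(points):
        # Distance from point i to every other point (self-loops excluded)
        dists = sorted((math.dist(p, q), j) for j, q in enumerate(points) if j != i)
        edges.extend((i, j) for _, j in dists[:k])
    return edges


pts = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (5.0, 5.0)]
print(knn_edges(pts, k=1))
```

A dynamic graph like this is rebuilt from the current embedding, so neighbourhoods can change from block to block. That is a common motivation for edge convolutions on point clouds.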
- class gnn_tracking.models.track_condensation_networks.ModularGraphTCN(*, ec: torch.nn.Module | None = None, hc_in: torch.nn.Module, node_indim: int, edge_indim: int, h_dim: int = 5, e_dim: int = 4, h_outdim: int = 2, hidden_dim: int = 40, feed_edge_weights: bool = False, ec_threshold: float = 0.5, mask_orphan_nodes: bool = False, use_ec_embeddings_for_hc: bool = False, alpha_latent: float = 0.0, n_embedding_coords: int = 0, heterogeneous_node_encoder: bool = False)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
Track condensation network based on preconstructed graphs. This module combines the following:
  - node and edge encoders to get to (h_dim, e_dim)
  - a track condensation network hc_in
  - an optional edge classifier
Additional options configure how output from the edge classifier can be included in the track condensation network.
- Parameters:
ec – Edge classifier
hc_in – Track condenser interaction network.
node_indim – Node feature dimension
edge_indim – Edge feature dimension
h_dim – node dimension in the condensation interaction networks
e_dim – edge dimension in the condensation interaction networks
h_outdim – output dimension in clustering space
hidden_dim – width of hidden layers in all perceptrons
feed_edge_weights – whether to feed edge weights to the track condenser
ec_threshold – threshold for edge classification
mask_orphan_nodes – Mask nodes with no connections after EC
use_ec_embeddings_for_hc – Use edge classifier embeddings as input to the track condenser. This currently assumes that h_dim and e_dim are also the dimensions used in the EC.
alpha_latent – Assume that we’re already starting from a latent space given by the first h_outdim node features. In this case, this is the strength of the residual connection.
n_embedding_coords – Number of embedding coordinates for which to add a residual connection. To be used with alpha_latent.
heterogeneous_node_encoder – Whether to use different encoders for pixel/strip
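The ec_threshold and mask_orphan_nodes options can be pictured as a filtering step. Edges whose classifier score passes the threshold are kept, and nodes left without any surviving edge can be masked out. This is a minimal sketch in plain Python. The helper filter_edges, the list-based inputs, and the strict ">" comparison are assumptions, not the ModularGraphTCN implementation.

```python
def filter_edges(
    edges: list[tuple[int, int]],
    ec_scores: list[float],
    num_nodes: int,
    ec_threshold: float = 0.5,
    mask_orphan_nodes: bool = False,
) -> tuple[list[tuple[int, int]], list[bool]]:
    """Hypothetical helper: return the kept edges and a per-node mask (True = node kept)."""
    # Keep only edges scored above the threshold (strict comparison is an assumption)
    kept = [edge for edge, score in zip(edges, ec_scores) if score > ec_threshold]
    if not mask_orphan_nodes:
        return kept, [True] * num_nodes
    # A node is an orphan if no surviving edge touches it
    connected: set[int] = set()
    for src, dst in kept:
        connected.update((src, dst))
    return kept, [i in connected for i in range(num_nodes)]


kept, mask = filter_edges([(0, 1), (1, 2), (2, 3)], [0.9, 0.2, 0.7], num_nodes=5, mask_orphan_nodes=True)
print(kept, mask)
```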
- relu
- ec
- hc_in
- node_enc_indim
- edge_enc_indim
- hc_edge_encoder
- p_beta
- p_cluster
- _latent_normalization
- forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor | None]
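The p_beta and p_cluster members suggest that the forward pass predicts a condensation strength β per hit and coordinates in the h_outdim-dimensional clustering space, as in object condensation. Turning these outputs into track candidates happens after the model. Below is a minimal sketch of the greedy condensation-point procedure commonly used for object condensation inference. The helper condense and the thresholds t_beta and t_d are hypothetical, and gnn_tracking's own postprocessing may differ.

```python
import math


def condense(
    beta: list[float],
    coords: list[tuple[float, ...]],
    t_beta: float = 0.5,
    t_d: float = 1.0,
) -> list[int]:
    """Hypothetical greedy clustering: returns a cluster label per node (-1 = noise)."""
    labels = [-1] * len(beta)
    # Visit candidate condensation points in order of decreasing beta
    order = sorted(range(len(beta)), key=lambda i: beta[i], reverse=True)
    cluster_id = 0
    for i in order:
        if beta[i] < t_beta:
            break  # all remaining nodes have even lower beta
        if labels[i] != -1:
            continue  # already claimed by a stronger condensation point
        # Node i becomes a condensation point and claims unassigned nodes within t_d
        for j, c in enumerate(coords):
            if labels[j] == -1 and math.dist(coords[i], c) < t_d:
                labels[j] = cluster_id
        cluster_id += 1
    return labels


print(condense([0.9, 0.1, 0.8, 0.2], [(0, 0), (0.5, 0), (5, 5), (9, 9)]))
```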
- class gnn_tracking.models.track_condensation_networks.GraphTCN(node_indim: int, edge_indim: int, *, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_ec=3, L_hc=3, alpha_ec: float = 0.5, alpha_hc: float = 0.5, **kwargs)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
ModularGraphTCN with ECForGraphTCN as the edge classification step and several interaction networks as residual layers for the track condenser network.
This is a small wrapper around ModularGraphTCN, mostly to make sure that we can change the underlying implementation without invalidating config files that reference this class.
- Parameters:
node_indim – Node feature dim
edge_indim – Edge feature dim
h_dim – node dimension in latent space
e_dim – edge dimension in latent space
h_outdim – output dimension in clustering space
hidden_dim – width of hidden layers in all perceptrons
L_ec – message passing depth for edge classifier
L_hc – message passing depth for track condenser
alpha_ec – strength of residual connection for multi-layer interaction networks in edge classifier
alpha_hc – strength of residual connection for multi-layer interaction networks in track condenser
**kwargs – Additional keyword arguments passed to ModularGraphTCN
- ec
- hc_in
- _gtcn
- forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor | None]
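L_hc and alpha_hc (and likewise L_ec and alpha_ec) describe a stack of interaction networks joined by residual connections. The following is a minimal sketch that assumes the update h ← alpha·h + (1 − alpha)·layer(h) on plain lists. The helper residual_stack and this exact blending rule are assumptions and are not taken from gnn_tracking.

```python
from collections.abc import Callable

Layer = Callable[[list[float]], list[float]]


def residual_stack(h: list[float], layers: list[Layer], alpha: float = 0.5) -> list[float]:
    """Hypothetical helper: apply layers in sequence with a weighted residual connection."""
    for layer in layers:
        update = layer(h)
        # alpha = 1 keeps the input unchanged; alpha = 0 discards the residual
        h = [alpha * old + (1 - alpha) * new for old, new in zip(h, update)]
    return h


# Three identical toy "layers", mirroring a depth of L_hc = 3
layers = [lambda h: [x + 1.0 for x in h]] * 3
print(residual_stack([0.0], layers, alpha=0.5))
```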
- class gnn_tracking.models.track_condensation_networks.PerfectECGraphTCN(*, node_indim: int, edge_indim: int, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_hc=3, alpha_hc: float = 0.5, ec_tpr=1.0, ec_tnr=1.0, **kwargs)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
Similar to GraphTCN but with a “perfect” (i.e., truth based) edge classifier.
This is a small wrapper around ModularGraphTCN, mostly to make sure that we can change the underlying implementation without invalidating config files that reference this class.
- Parameters:
node_indim – Node feature dim. Determined by input data.
edge_indim – Edge feature dim. Determined by input data.
h_dim – node dimension after encoding
e_dim – edge dimension after encoding
h_outdim – output dimension in clustering space
hidden_dim – dimension of hidden layers in all MLPs used in the interaction networks
L_hc – message passing depth for track condenser
alpha_hc – strength of residual connection for multi-layer interaction networks
ec_tpr – true positive rate of the perfect edge classifier
ec_tnr – true negative rate of the perfect edge classifier
**kwargs – Passed to ModularGraphTCN
- ec
- hc_in
- _gtcn
- forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor | None]
- class gnn_tracking.models.track_condensation_networks.PreTrainedECGraphTCN(ec, *, node_indim: int, edge_indim: int, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_hc=3, alpha_hc: float = 0.5, **kwargs)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
GraphTCN for use with a pre-trained edge classifier.
This is a small wrapper around ModularGraphTCN, mostly to make sure that we can change the underlying implementation without invalidating config files that reference this class.
- Parameters:
ec – Pre-trained edge classifier
node_indim – Node feature dim. Determined by input data.
edge_indim – Edge feature dim. Determined by input data.
h_dim – node dimension after encoding
e_dim – edge dimension after encoding
h_outdim – output dimension in clustering space
hidden_dim – dimension of hidden layers in all MLPs used in the interaction networks
L_hc – message passing depth for track condenser
alpha_hc – strength of residual connection for multi-layer interaction networks
**kwargs – Passed to ModularGraphTCN
- ec
- hc_in
- _gtcn
- forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor | None]
- class gnn_tracking.models.track_condensation_networks.GraphTCNForMLGCPipeline(*, node_indim: int, edge_indim: int, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_hc=3, alpha_hc: float = 0.5, **kwargs)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
GraphTCN for use with a metric learning graph construction pipeline.
This is a small wrapper around ModularGraphTCN, mostly to make sure that we can change the underlying implementation without invalidating config files that reference this class.
- Parameters:
node_indim – Node feature dim. Determined by input data.
edge_indim – Edge feature dim. Determined by input data.
h_dim – node dimension after encoding
e_dim – edge dimension after encoding
h_outdim – output dimension in clustering space
hidden_dim – dimension of hidden layers in all MLPs used in the interaction networks
L_hc – message passing depth for track condenser
alpha_hc – strength of residual connection for multi-layer interaction networks
**kwargs – Passed to ModularGraphTCN
- hc_in
- _gtcn
- forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor | None]
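In a metric learning graph construction pipeline, hits are first mapped into an embedding space where hits from the same track lie close together. Edges are then drawn between nearby points, and this class consumes the resulting graph. This minimal sketch assumes a fixed-radius neighbourhood (a kNN rule is an equally common choice). The helper radius_graph and the radius r are hypothetical and not part of gnn_tracking.

```python
import math


def radius_graph(embeddings: list[tuple[float, ...]], r: float) -> list[tuple[int, int]]:
    """Hypothetical helper: directed edges between all pairs closer than r in embedding space."""
    edges: list[tuple[int, int]] = []
    for i, a in enumerate(embeddings):
        for j, b in enumerate(embeddings):
            # Skip self-loops and keep only close pairs
            if i != j and math.dist(a, b) < r:
                edges.append((i, j))
    return edges


print(radius_graph([(0.0, 0.0), (0.1, 0.0), (3.0, 3.0)], r=0.5))
```

This quadratic loop only illustrates the idea. Real pipelines use spatial indexing or GPU neighbour search.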