gnn_tracking.models.track_condensation_networks#

This module holds the main training models for GNN tracking.

Classes#

DynamicEdgeConv

ECForGraphTCN

Edge classification step to be used for a Graph Track Condenser network (Graph TCN).

PerfectEdgeClassification

An edge classifier that is perfect because it uses the truth information.

IN

Interaction Network, consisting of a relational model and an object model, both represented as MLPs.

MLP

Multi-layer perceptron using ReLU as the activation function.

HeterogeneousResFCNN

Separate FCNNs for pixel and strip data, with residual connections.

ResFCNN

Fully connected NN with residual connections.

ResIN

Create a ResIN with identical layers of interaction networks.

INConvBlock

PointCloudTCN

Model to directly process point clouds rather than start with a graph.

ModularGraphTCN

Track condensation network based on preconstructed graphs.

GraphTCN

ModularGraphTCN with ECForGraphTCN as the edge classification step.

PerfectECGraphTCN

Similar to GraphTCN but with a "perfect" (i.e., truth-based) edge classifier.

PreTrainedECGraphTCN

GraphTCN for use with a pre-trained edge classifier.

GraphTCNForMLGCPipeline

GraphTCN for use with a metric learning graph construction pipeline.

Functions#

obj_from_or_to_hparams(→ Any)

Used to support initializing Python objects from hyperparameters.

Module Contents#

class gnn_tracking.models.track_condensation_networks.DynamicEdgeConv(nn: Callable, k: int, aggr: str = 'max', num_workers: int = 1, **kwargs)#

Bases: torch_geometric.nn.conv.MessagePassing

reset_parameters()#
get_edge_index()#
forward(x: torch.Tensor | torch_geometric.typing.PairTensor, batch: torch_geometric.typing.OptTensor | torch_geometric.typing.PairTensor | None = None) torch.Tensor#
message(x_i: torch.Tensor, x_j: torch.Tensor) torch.Tensor#
__repr__() str#
class gnn_tracking.models.track_condensation_networks.ECForGraphTCN(*, node_indim: int, edge_indim: int, interaction_node_dim: int = 5, interaction_edge_dim: int = 4, hidden_dim: int | float | None = None, L_ec: int = 3, alpha: float = 0.5, residual_type='skip1', use_intermediate_edge_embeddings: bool = True, use_node_embedding: bool = True, residual_kwargs: dict | None = None)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Edge classification step to be used for a Graph Track Condenser network (Graph TCN)

Parameters:
  • node_indim – Node feature dim

  • edge_indim – Edge feature dim

  • interaction_node_dim – Node dimension for interaction networks. Defaults to 5 for backward compatibility, but this is probably not reasonable.

  • interaction_edge_dim – Edge dimension of interaction networks. Defaults to 4 for backward compatibility, but this is probably not reasonable.

  • hidden_dim – width of hidden layers in all perceptrons (edge and node encoders, hidden dims for MLPs in object and relation networks). If None: choose as maximum of input/output dims for each MLP separately

  • L_ec – message passing depth for edge classifier

  • alpha – strength of residual connection for EC

  • residual_type – type of residual connection for EC

  • use_intermediate_edge_embeddings – If true, feed not only the final encoding of the stacked interaction networks to the final MLP, but all intermediate encodings as well

  • use_node_embedding – If true, feed node attributes to the final MLP for EC

  • residual_kwargs – Keyword arguments passed to ResIN

forward(data: torch_geometric.data.Data) dict[str, torch.Tensor]#

Returns dictionary of the following:

  • W: Edge weights

  • node_embedding: Last node embedding (result of last interaction network)

  • edge_embedding: Last edge embedding (result of last interaction network)
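To illustrate the structure of the returned dictionary, here is a minimal sketch with hypothetical sizes (5 nodes, 7 edges, 4-dim edge embedding). The sigmoid-over-an-edge-MLP weight head is an assumption for illustration, not the library's exact implementation:

```python
import torch
from torch import nn

# Hypothetical sizes: 5 nodes, 7 edges, 4-dim final edge embedding.
node_embedding = torch.randn(5, 3)
edge_embedding = torch.randn(7, 4)
# Assumption: W is obtained by a sigmoid over a small MLP on the edge embedding,
# giving one weight in (0, 1) per edge.
w_mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
out = {
    "W": torch.sigmoid(w_mlp(edge_embedding)).squeeze(-1),
    "node_embedding": node_embedding,
    "edge_embedding": edge_embedding,
}
```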

class gnn_tracking.models.track_condensation_networks.PerfectEdgeClassification(tpr=1.0, tnr=1.0, false_below_pt=0.0)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

An edge classifier that is perfect because it uses the truth information. If TPR or TNR is not 1.0, noise is added to the truth information.

This can be used to evaluate the maximal possible performance of a model that relies on edge classification as a first step (e.g., the object condensation approach).

Parameters:
  • tpr – True positive rate

  • tnr – True negative rate

  • false_below_pt – If not 0.0, all true edges between hits corresponding to particles with a pt lower than this threshold are set to false. This is not counted towards the TPR/TNR but applied afterwards.
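The noise model described above can be sketched in plain Python. The function name and the per-edge pt convention (e.g., the lower pt of the two hits an edge connects) are illustrative assumptions, not the library API:

```python
import random

def perfect_edge_classification(y_true, pt, *, tpr=1.0, tnr=1.0,
                                false_below_pt=0.0, seed=0):
    """Illustrative sketch of the noise model (not the library class)."""
    rng = random.Random(seed)
    out = []
    for label in y_true:
        if label:
            out.append(rng.random() < tpr)        # keep true edges with prob. tpr
        else:
            out.append(not (rng.random() < tnr))  # keep false edges false with prob. tnr
    if false_below_pt > 0.0:
        # applied after the TPR/TNR noise, not counted towards it
        out = [w and (p >= false_below_pt) for w, p in zip(out, pt)]
    return out
```

With tpr=tnr=1.0 the output reproduces the truth labels exactly; false_below_pt then additionally vetoes low-pt true edges.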

forward(data: torch_geometric.data.Data) dict[str, torch.Tensor]#
class gnn_tracking.models.track_condensation_networks.IN(*, node_indim: int, edge_indim: int, node_outdim=3, edge_outdim=4, node_hidden_dim=40, edge_hidden_dim=40, aggr='add')#

Bases: torch_geometric.nn.MessagePassing, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Interaction Network, consisting of a relational model and an object model, both represented as MLPs.

Parameters:
  • node_indim – Node feature dimension

  • edge_indim – Edge feature dimension

  • node_outdim – Output node feature dimension

  • edge_outdim – Output edge feature dimension

  • node_hidden_dim – Hidden dimension for the object model MLP

  • edge_hidden_dim – Hidden dimension for the relational model MLP

  • aggr – How to aggregate the messages

forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) tuple[torch.Tensor, torch.Tensor]#

Forward pass

Parameters:
  • x – Input node features

  • edge_index

  • edge_attr – Input edge features

Returns:

Output node embedding, output edge embedding

message(x_i: torch.Tensor, x_j: torch.Tensor, edge_attr: torch.Tensor) torch.Tensor#

Calculate message of an edge

Parameters:
  • x_i – Features of node 1 (node where the edge ends)

  • x_j – Features of node 2 (node where the edge starts)

  • edge_attr – Edge features

Returns:

Message

update(aggr_out: torch.Tensor, x: torch.Tensor) torch.Tensor#

Update for node embedding

Parameters:
  • aggr_out – Aggregated messages of all edges

  • x – Node features for the node that receives all edges

Returns:

Updated node features/embedding
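The message/update steps documented above can be sketched as a single interaction-network pass without torch_geometric. All dimensions, names, and the "add" aggregation are illustrative; this is a sketch of the scheme, not the library class:

```python
import torch
from torch import nn

class InteractionStep(nn.Module):
    """Sketch of one interaction-network pass: a relational MLP computes a
    per-edge message from [x_i, x_j, edge_attr]; messages are summed at the
    receiving node; an object MLP updates node features from [x, aggregated]."""

    def __init__(self, node_dim, edge_dim, node_out, edge_out, hidden=40):
        super().__init__()
        self.relational = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, edge_out),
        )
        self.object = nn.Sequential(
            nn.Linear(node_dim + edge_out, hidden), nn.ReLU(),
            nn.Linear(hidden, node_out),
        )

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index  # edges run src -> dst; x_i is the receiving node
        messages = self.relational(
            torch.cat([x[dst], x[src], edge_attr], dim=1))
        aggr = torch.zeros(x.size(0), messages.size(1))
        aggr.index_add_(0, dst, messages)  # "add" aggregation at the receiver
        return self.object(torch.cat([x, aggr], dim=1)), messages
```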

class gnn_tracking.models.track_condensation_networks.MLP(input_size: int, output_size: int, hidden_dim: int | None, L=3, *, bias=True, include_last_activation=False)#

Bases: torch.nn.Module

Multi-layer perceptron using ReLU as the activation function.

Parameters:
  • input_size – Input feature dimension

  • output_size – Output feature dimension

  • hidden_dim – Feature dimension of the hidden layers. If None: Choose maximum of input/output size

  • L – Total number of layers (1 initial layer, L-2 hidden layers, 1 output layer)

  • bias – Include bias in linear layer?

  • include_last_activation – Include activation function for the last layer?

reset_parameters()#
forward(x)#
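The layer-count and hidden-dimension conventions above can be sketched with a hypothetical make_mlp helper (not the library class):

```python
import torch
from torch import nn

def make_mlp(input_size, output_size, hidden_dim=None, L=3, *,
             bias=True, include_last_activation=False):
    """Sketch of the convention above: L linear layers in total
    (1 input, L-2 hidden, 1 output), ReLU between them; hidden_dim
    defaults to max(input_size, output_size) when None."""
    if hidden_dim is None:
        hidden_dim = max(input_size, output_size)
    dims = [input_size] + [hidden_dim] * (L - 1) + [output_size]
    layers = []
    for i in range(L):
        layers.append(nn.Linear(dims[i], dims[i + 1], bias=bias))
        if i < L - 1 or include_last_activation:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)
```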
class gnn_tracking.models.track_condensation_networks.HeterogeneousResFCNN(*, in_dim: int, out_dim: int, hidden_dim: int, depth: int, alpha: float = 0.6, bias: bool = True)#

Bases: torch.nn.Module

Separate FCNNs for pixel and strip data, with residual connections. For parameters, see ResFCNN.

forward(x: torch.Tensor, layer: torch.Tensor) torch.Tensor#
class gnn_tracking.models.track_condensation_networks.ResFCNN(*, in_dim: int, hidden_dim: int, out_dim: int, depth: int, alpha: float = 0.6, bias: bool = True)#

Bases: torch.nn.Module

Fully connected NN with residual connections.

Parameters:
  • in_dim – Input dimension

  • hidden_dim – Hidden dimension

  • out_dim – Output dimension = embedding space

  • depth – 1 input encoder layer, depth-1 hidden layers, 1 output encoder layer

  • alpha – strength of the residual connection

static _reset_layer_parameters(layer, var: float)#
forward(x: torch.Tensor, **ignore) torch.Tensor#
class gnn_tracking.models.track_condensation_networks.ResIN(*, node_dim: int, edge_dim: int, object_hidden_dim=40, relational_hidden_dim=40, alpha: float = 0.5, n_layers=1, residual_type: str = 'skip1', residual_kwargs: dict | None = None)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Create a ResIN with identical layers of interaction networks.

Parameters:
  • node_dim – Node feature dimension

  • edge_dim – Edge feature dimension

  • object_hidden_dim – Hidden dimension for the object model MLP

  • relational_hidden_dim – Hidden dimension for the relational model MLP

  • alpha – Strength of the node residual connection

  • n_layers – Total number of layers

  • residual_type – Type of residual network. Options are ‘skip1’, ‘skip2’, ‘skip_top’.

  • residual_kwargs – Additional arguments to the residual network (can depend on the residual type)

property concat_edge_embeddings_length: int#

Length of the concatenated edge embeddings from all intermediate layers. Or in other words: self.forward()[3].shape[1]

forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[torch.Tensor] | None]#
gnn_tracking.models.track_condensation_networks.obj_from_or_to_hparams(self: pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin, key: str, obj: Any) Any#

Used to support initializing Python objects from hyperparameters: If obj is a Python object other than a dictionary, its hyperparameters (its class path and init args) are saved to self.hparams[key]. If obj is instead a dictionary, it is assumed that we have to restore an object based on this information.
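The save-or-restore pattern can be sketched in plain Python. This is an illustrative stand-in (hypothetical helper name, init args elided on save), not the library implementation:

```python
import importlib

def obj_from_or_to_hparams_sketch(hparams: dict, key: str, obj):
    """Illustrative stand-in for the save-or-restore logic described above."""
    if isinstance(obj, dict):
        # restore: the dict records how to rebuild the object
        module_name, _, cls_name = obj["class_path"].rpartition(".")
        cls = getattr(importlib.import_module(module_name), cls_name)
        return cls(**obj.get("init_args", {}))
    # save: record the class path (init args elided in this sketch)
    hparams[key] = {
        "class_path": f"{type(obj).__module__}.{type(obj).__qualname__}",
        "init_args": {},
    }
    return obj
```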

class gnn_tracking.models.track_condensation_networks.INConvBlock(indim, h_dim, e_dim, L, k, hidden_dim=100)#

Bases: torch.nn.Module

forward(x: torch.Tensor, alpha: float = 0.5) torch.Tensor#
class gnn_tracking.models.track_condensation_networks.PointCloudTCN(node_indim: int, h_dim=10, e_dim=10, h_outdim=5, hidden_dim=100, N_blocks=3, L=3)#

Bases: torch.nn.Module

Model to directly process point clouds rather than start with a graph.

Parameters:
  • node_indim

  • h_dim – node dimension in latent space

  • e_dim – edge dimension in latent space

  • h_outdim – output dimension in clustering space

  • hidden_dim – hidden width of all nn.Linear layers

  • N_blocks – number of edge_conv + IN blocks

  • L – message passing depth in each block

forward(data: torch_geometric.data.Data, alpha: float = 0.5) dict[str, torch.Tensor | None]#
class gnn_tracking.models.track_condensation_networks.ModularGraphTCN(*, ec: torch.nn.Module | None = None, hc_in: torch.nn.Module, node_indim: int, edge_indim: int, h_dim: int = 5, e_dim: int = 4, h_outdim: int = 2, hidden_dim: int = 40, feed_edge_weights: bool = False, ec_threshold: float = 0.5, mask_orphan_nodes: bool = False, use_ec_embeddings_for_hc: bool = False, alpha_latent: float = 0.0, n_embedding_coords: int = 0, heterogeneous_node_encoder: bool = False)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Track condensation network based on preconstructed graphs. This module combines the following:

  • Node and edge encoders to get to (h_dim, e_dim)

  • a track condensation network hc_in

  • an optional edge classifier

Additional options configure how output from the edge classifier can be included in the track condensation network.

Parameters:
  • ec – Edge classifier

  • hc_in – Track condenser interaction network.

  • node_indim – Node feature dimension

  • edge_indim – Edge feature dimension

  • h_dim – node dimension in the condensation interaction networks

  • e_dim – edge dimension in the condensation interaction networks

  • h_outdim – output dimension in clustering space

  • hidden_dim – width of hidden layers in all perceptrons

  • feed_edge_weights – whether to feed edge weights to the track condenser

  • ec_threshold – threshold for edge classification

  • mask_orphan_nodes – Mask nodes with no connections after EC

  • use_ec_embeddings_for_hc – Use edge classifier embeddings as input to track condenser. This currently assumes that h_dim and e_dim are also the dimensions used in the EC.

  • alpha_latent – Assume that we’re already starting from a latent space given by the first h_outdim node features. In this case, this is the strength of the residual connection

  • n_embedding_coords – Number of embedding coordinates for which to add a residual connection. To be used with alpha_latent.

  • heterogeneous_node_encoder – Whether to use different encoders for pixel/strip
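The EC-related options ec_threshold and mask_orphan_nodes amount to a post-processing step between the edge classifier and the track condenser, which can be sketched as follows (function name and return convention are illustrative):

```python
import torch

def apply_ec_cut(edge_index, edge_attr, w, n_nodes, *, ec_threshold=0.5):
    """Keep edges whose classifier weight exceeds ec_threshold and report
    which nodes are left without any edge (orphans)."""
    keep = w > ec_threshold
    edge_index = edge_index[:, keep]
    edge_attr = edge_attr[keep]
    connected = torch.zeros(n_nodes, dtype=torch.bool)
    connected[edge_index.flatten()] = True
    return edge_index, edge_attr, ~connected  # third item: orphan-node mask
```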

forward(data: torch_geometric.data.Data) dict[str, torch.Tensor | None]#
class gnn_tracking.models.track_condensation_networks.GraphTCN(node_indim: int, edge_indim: int, *, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_ec=3, L_hc=3, alpha_ec: float = 0.5, alpha_hc: float = 0.5, **kwargs)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

ModularGraphTCN with ECForGraphTCN as the edge classification step and several interaction networks as residual layers for the track condenser network.

This is a small wrapper around ModularGraphTCN, mostly to make sure that we can change the underlying implementation without invalidating config files that reference this class.

Parameters:
  • node_indim – Node feature dim

  • edge_indim – Edge feature dim

  • h_dim – node dimension in latent space

  • e_dim – edge dimension in latent space

  • h_outdim – output dimension in clustering space

  • hidden_dim – width of hidden layers in all perceptrons

  • L_ec – message passing depth for edge classifier

  • L_hc – message passing depth for track condenser

  • alpha_ec – strength of residual connection for multi-layer interaction networks in edge classifier

  • alpha_hc – strength of residual connection for multi-layer interaction networks in track condenser

  • **kwargs – Additional keyword arguments passed to ModularGraphTCN

forward(data: torch_geometric.data.Data) dict[str, torch.Tensor | None]#
class gnn_tracking.models.track_condensation_networks.PerfectECGraphTCN(*, node_indim: int, edge_indim: int, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_hc=3, alpha_hc: float = 0.5, ec_tpr=1.0, ec_tnr=1.0, **kwargs)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Similar to GraphTCN but with a “perfect” (i.e., truth-based) edge classifier.

This is a small wrapper around ModularGraphTCN, mostly to make sure that we can change the underlying implementation without invalidating config files that reference this class.

Parameters:
  • node_indim – Node feature dim. Determined by input data.

  • edge_indim – Edge feature dim. Determined by input data.

  • h_dim – node dimension after encoding

  • e_dim – edge dimension after encoding

  • h_outdim – output dimension in clustering space

  • hidden_dim – dimension of hidden layers in all MLPs used in the interaction networks

  • L_hc – message passing depth for track condenser

  • alpha_hc – strength of residual connection for multi-layer interaction networks

  • ec_tpr – true positive rate of the perfect edge classifier

  • ec_tnr – true negative rate of the perfect edge classifier

  • **kwargs – Passed to ModularGraphTCN

forward(data: torch_geometric.data.Data) dict[str, torch.Tensor | None]#
class gnn_tracking.models.track_condensation_networks.PreTrainedECGraphTCN(ec, *, node_indim: int, edge_indim: int, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_hc=3, alpha_hc: float = 0.5, **kwargs)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

GraphTCN for use with a pre-trained edge classifier.

This is a small wrapper around ModularGraphTCN, mostly to make sure that we can change the underlying implementation without invalidating config files that reference this class.

Parameters:
  • ec – Pre-trained edge classifier

  • node_indim – Node feature dim. Determined by input data.

  • edge_indim – Edge feature dim. Determined by input data.

  • h_dim – node dimension after encoding

  • e_dim – edge dimension after encoding

  • h_outdim – output dimension in clustering space

  • hidden_dim – dimension of hidden layers in all MLPs used in the interaction networks

  • L_hc – message passing depth for track condenser

  • alpha_hc – strength of residual connection for multi-layer interaction networks

  • **kwargs – Passed to ModularGraphTCN

forward(data: torch_geometric.data.Data) dict[str, torch.Tensor | None]#
class gnn_tracking.models.track_condensation_networks.GraphTCNForMLGCPipeline(*, node_indim: int, edge_indim: int, h_dim=5, e_dim=4, h_outdim=2, hidden_dim=40, L_hc=3, alpha_hc: float = 0.5, **kwargs)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

GraphTCN for use with a metric learning graph construction pipeline.

This is a small wrapper around ModularGraphTCN, mostly to make sure that we can change the underlying implementation without invalidating config files that reference this class.

Parameters:
  • node_indim – Node feature dim. Determined by input data.

  • edge_indim – Edge feature dim. Determined by input data.

  • h_dim – node dimension after encoding

  • e_dim – edge dimension after encoding

  • h_outdim – output dimension in clustering space

  • hidden_dim – dimension of hidden layers in all MLPs used in the interaction networks

  • L_hc – message passing depth for track condenser

  • alpha_hc – strength of residual connection for multi-layer interaction networks

  • **kwargs – Passed to ModularGraphTCN

forward(data: torch_geometric.data.Data) dict[str, torch.Tensor | None]#