gnn_tracking.models.edge_classifier#

Models for edge classification

Classes#

MLP

Multi-layer perceptron using ReLU as the activation function.

ResIN

Create a ResIN with identical layers of interaction networks.

ECForGraphTCN

Edge classification step to be used for Graph Track Condensor network

PerfectEdgeClassification

An edge classifier that is perfect because it uses the truth information.

ECFromChkpt

For easy loading of a pretrained EC from a Lightning YAML config.

Functions#

assert_feat_dim(→ None)

get_model(→ torch.nn.Module | None)

Get torch model (specified by class_path, a string) and load a checkpoint.

Module Contents#

class gnn_tracking.models.edge_classifier.MLP(input_size: int, output_size: int, hidden_dim: int | None, L=3, *, bias=True, include_last_activation=False)#

Bases: torch.nn.Module

Multi-layer perceptron using ReLU as the activation function.

Parameters:
  • input_size – Input feature dimension

  • output_size – Output feature dimension

  • hidden_dim – Feature dimension of the hidden layers. If None: Choose maximum of input/output size

  • L – Total number of layers (1 initial layer, L-2 hidden layers, 1 output layer)

  • bias – Include bias in linear layer?

  • include_last_activation – Include activation function for the last layer?

reset_parameters()#
forward(x)#
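Based only on the parameter descriptions above, the layer layout can be sketched in plain PyTorch. This is an illustrative reconstruction, not the library's actual implementation; `sketch_mlp` is a hypothetical name.

```python
import torch
from torch import nn


def sketch_mlp(input_size: int, output_size: int, hidden_dim=None,
               L: int = 3, bias: bool = True,
               include_last_activation: bool = False) -> nn.Sequential:
    """Illustrative MLP mirroring the documented layout:
    1 initial layer, L-2 hidden layers, 1 output layer, ReLU activations."""
    if hidden_dim is None:
        # Documented default: maximum of input/output size
        hidden_dim = max(input_size, output_size)
    dims = [input_size] + [hidden_dim] * (L - 1) + [output_size]
    layers = []
    for i in range(L):
        layers.append(nn.Linear(dims[i], dims[i + 1], bias=bias))
        # ReLU after every layer except (by default) the last one
        if i < L - 1 or include_last_activation:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)


mlp = sketch_mlp(input_size=6, output_size=8, L=3)
out = mlp(torch.randn(10, 6))  # batch of 10 six-dimensional feature vectors
```

With `L=3` this yields exactly three linear layers, matching the documented "1 initial, L-2 hidden, 1 output" layout.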
class gnn_tracking.models.edge_classifier.ResIN(*, node_dim: int, edge_dim: int, object_hidden_dim=40, relational_hidden_dim=40, alpha: float = 0.5, n_layers=1, residual_type: str = 'skip1', residual_kwargs: dict | None = None)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Create a ResIN with identical layers of interaction networks.

Parameters:
  • node_dim – Node feature dimension

  • edge_dim – Edge feature dimension

  • object_hidden_dim – Hidden dimension for the object model MLP

  • relational_hidden_dim – Hidden dimension for the relational model MLP

  • alpha – Strength of the node residual connection

  • n_layers – Total number of layers

  • residual_type – Type of residual network. Options are ‘skip1’, ‘skip2’, ‘skip_top’.

  • residual_kwargs – Additional arguments to the residual network (can depend on the residual type)

property concat_edge_embeddings_length: int#

Length of the concatenated edge embeddings from all intermediate layers, i.e., self.forward()[3].shape[1]

forward(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[torch.Tensor] | None]#
gnn_tracking.models.edge_classifier.assert_feat_dim(feat_vec: torch.Tensor, dim: int) None#
gnn_tracking.models.edge_classifier.get_model(class_path: str, chkpt_path: str = '', freeze: bool = False, whole_module: bool = False, device: None | str = None) torch.nn.Module | None#

Get torch model (specified by class_path, a string) and load a checkpoint. Uses get_lightning_module to get the model.

Parameters:
  • class_path – The path to the lightning module class

  • chkpt_path – The path to the checkpoint. If no checkpoint is specified, we return None.

  • freeze – Whether to freeze the model

  • whole_module – Whether to return the whole lightning module or just the model

  • device – Device to map the loaded model to
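get_model itself delegates to get_lightning_module, but the documented contract (empty checkpoint path returns None; optional freezing) can be sketched with plain PyTorch. The helper name and the toy model are placeholders, not part of the library.

```python
import os
import tempfile

import torch
from torch import nn


def load_and_maybe_freeze(model: nn.Module, chkpt_path: str = "",
                          freeze: bool = False, device=None):
    # Mirrors the documented contract: no checkpoint specified -> return None.
    if not chkpt_path:
        return None
    state = torch.load(chkpt_path, map_location=device)
    model.load_state_dict(state)
    if freeze:
        # Frozen models contribute no gradients during training
        for p in model.parameters():
            p.requires_grad = False
    return model


model = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "ec.pt")
    torch.save(model.state_dict(), path)
    restored = load_and_maybe_freeze(nn.Linear(4, 2), chkpt_path=path,
                                     freeze=True)
```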

class gnn_tracking.models.edge_classifier.ECForGraphTCN(*, node_indim: int, edge_indim: int, interaction_node_dim: int = 5, interaction_edge_dim: int = 4, hidden_dim: int | float | None = None, L_ec: int = 3, alpha: float = 0.5, residual_type='skip1', use_intermediate_edge_embeddings: bool = True, use_node_embedding: bool = True, residual_kwargs: dict | None = None)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Edge classification step to be used for Graph Track Condensor network (Graph TCN)

Parameters:
  • node_indim – Node feature dim

  • edge_indim – Edge feature dim

  • interaction_node_dim – Node dimension for interaction networks. Defaults to 5 for backward compatibility, but this is probably not reasonable.

  • interaction_edge_dim – Edge dimension of interaction networks. Defaults to 4 for backward compatibility, but this is probably not reasonable.

  • hidden_dim – Width of hidden layers in all perceptrons (edge and node encoders, hidden dims for MLPs in object and relation networks). If None: choose the maximum of input/output dims for each MLP separately

  • L_ec – Message passing depth for the edge classifier

  • alpha – Strength of the residual connection for the EC

  • residual_type – Type of residual connection for the EC

  • use_intermediate_edge_embeddings – If true, don’t only feed the final encoding of the stacked interaction networks to the final MLP, but all intermediate encodings

  • use_node_embedding – If true, feed node attributes to the final MLP for EC

  • residual_kwargs – Keyword arguments passed to ResIN

forward(data: torch_geometric.data.Data) dict[str, torch.Tensor]#

Returns dictionary of the following:

  • W: Edge weights

  • node_embedding: Last node embedding (result of last interaction network)

  • edge_embedding: Last edge embedding (result of last interaction network)
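The returned W is one weight per edge. A typical head producing such weights from an edge embedding is sketched below; the head architecture and dimensions are illustrative, not the library's exact final MLP.

```python
import torch
from torch import nn

edge_embedding = torch.randn(50, 8)  # 50 edges, 8-dimensional embeddings

# Small MLP head followed by a sigmoid, so each edge weight lies in (0, 1)
head = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
W = torch.sigmoid(head(edge_embedding)).squeeze(-1)  # shape: (50,)
```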

class gnn_tracking.models.edge_classifier.PerfectEdgeClassification(tpr=1.0, tnr=1.0, false_below_pt=0.0)#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

An edge classifier that is perfect because it uses the truth information. If TPR or TNR is not 1.0, noise is added to the truth information.

This can be used to evaluate the maximal possible performance of a model that relies on edge classification as a first step (e.g., the object condensation approach).

Parameters:
  • tpr – True positive rate

  • tnr – True negative rate

  • false_below_pt – If not 0.0, all true edges between hits corresponding to particles with a pt lower than this threshold are set to false. This is not counted towards the TPR/TNR but applied afterwards.

forward(data: torch_geometric.data.Data) dict[str, torch.Tensor]#
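The noise injection described above can be sketched as follows: with tpr and tnr below 1.0, the corresponding fraction of true/false edge labels is randomly flipped. This is an assumption about the mechanism (the false_below_pt cut is omitted), and `perfect_edge_labels` is a hypothetical name.

```python
import torch


def perfect_edge_labels(y: torch.Tensor, tpr: float = 1.0,
                        tnr: float = 1.0) -> torch.Tensor:
    """y: boolean truth label per edge. Randomly flip a (1 - tpr) fraction
    of true edges and a (1 - tnr) fraction of false edges."""
    out = y.clone()
    flip_true = torch.rand(y.shape) > tpr    # which true edges to flip
    flip_false = torch.rand(y.shape) > tnr   # which false edges to flip
    out[y & flip_true] = False
    out[~y & flip_false] = True
    return out


y = torch.tensor([True, True, False, False])
labels = perfect_edge_labels(y)  # tpr = tnr = 1.0 reproduces the truth
```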
class gnn_tracking.models.edge_classifier.ECFromChkpt#

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

For easy loading of a pretrained EC from a Lightning YAML config.

Parameters:
  • chkpt_path – Path to the checkpoint file.

  • class_name – Name of the lightning module that was used to train the EC. Default should work for most cases.

  • freeze – If True, the model is frozen (i.e., its parameters are not trained).