gnn_tracking.models.edge_classifier

Models for edge classification

Module Contents

Classes

ECForGraphTCN

Edge classification step to be used for the Graph Track Condenser network (Graph TCN)

PerfectEdgeClassification

An edge classifier that is perfect because it uses the truth information.

ECFromChkpt

For easy loading of a pretrained EC from a Lightning YAML config.

class gnn_tracking.models.edge_classifier.ECForGraphTCN(*, node_indim: int, edge_indim: int, interaction_node_dim: int = 5, interaction_edge_dim: int = 4, hidden_dim: int | float | None = None, L_ec: int = 3, alpha: float = 0.5, residual_type='skip1', use_intermediate_edge_embeddings: bool = True, use_node_embedding: bool = True, residual_kwargs: dict | None = None)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Edge classification step to be used for the Graph Track Condenser network (Graph TCN)

Parameters:
  • node_indim – Node feature dim

  • edge_indim – Edge feature dim

  • interaction_node_dim – Node dimension for interaction networks. Defaults to 5 for backward compatibility, but this is probably not reasonable.

  • interaction_edge_dim – Edge dimension for interaction networks. Defaults to 4 for backward compatibility, but this is probably not reasonable.

  • hidden_dim – Width of hidden layers in all perceptrons (edge and node encoders, and the MLPs in the object and relation networks). If None, each MLP separately uses the maximum of its input and output dims.

  • L_ec – message passing depth for edge classifier

  • alpha – strength of residual connection for EC

  • residual_type – type of residual connection for EC

  • use_intermediate_edge_embeddings – If True, feed all intermediate encodings of the stacked interaction networks to the final MLP, not just the final encoding

  • use_node_embedding – If true, feed node attributes to the final MLP for EC

  • residual_kwargs – Keyword arguments passed to ResIN
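The `hidden_dim=None` rule described above can be sketched as follows. This is a minimal illustration, not the library's code: `resolve_hidden_dim` is a hypothetical helper, and it only covers the `None` and `int` cases (the meaning of a `float` value is not stated here, so it is left out).

```python
from typing import Optional


def resolve_hidden_dim(in_dim: int, out_dim: int, hidden_dim: Optional[int] = None) -> int:
    """Hypothetical helper: pick the hidden width for a single MLP.

    If hidden_dim is None, fall back to the larger of the MLP's own
    input and output dimensions; otherwise use the given width.
    """
    if hidden_dim is None:
        return max(in_dim, out_dim)
    return hidden_dim


# With the default interaction dims (5 for nodes, 4 for edges), an MLP
# mapping 5 -> 4 features would get a hidden width of 5.
print(resolve_hidden_dim(5, 4))
# An explicit width overrides the rule for every MLP.
print(resolve_hidden_dim(5, 4, hidden_dim=64))
```

Because the rule is applied per MLP, different perceptrons in the same model can end up with different hidden widths when `hidden_dim` is left as None.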

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]

Returns a dictionary with the following keys:

  • W: Edge weights

  • node_embedding: Last node embedding (result of last interaction network)

  • edge_embedding: Last edge embedding (result of last interaction network)

class gnn_tracking.models.edge_classifier.PerfectEdgeClassification(tpr=1.0, tnr=1.0, false_below_pt=0.0)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

An edge classifier that is perfect because it uses the truth information. If TPR or TNR is not 1.0, noise is added to the truth information.

This can be used to evaluate the maximal possible performance of a model that relies on edge classification as a first step (e.g., the object condensation approach).

Parameters:
  • tpr – True positive rate

  • tnr – True negative rate

  • false_below_pt – If not 0.0, all true edges between hits corresponding to particles with a pt lower than this threshold are set to false. This is not counted towards the TPR/TNR but applied afterwards.
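The interplay of the three parameters can be illustrated with a small standalone sketch. It is an assumption about how such noise could be applied, not the class's actual implementation: `noisy_truth` and its list-based inputs are hypothetical, and the real class operates on a `torch_geometric.data.Data` object.

```python
import random


def noisy_truth(y_true, pt, tpr=1.0, tnr=1.0, false_below_pt=0.0, seed=0):
    """Hypothetical sketch: degrade truth edge labels to a given TPR/TNR.

    y_true: per-edge truth labels (True = edge connects hits of one particle)
    pt:     per-edge transverse momentum of the corresponding particle
    """
    rng = random.Random(seed)
    labels = []
    for is_true, edge_pt in zip(y_true, pt):
        if is_true:
            # Keep a true edge with probability tpr.
            label = rng.random() < tpr
            # The pt cut is applied after the noise step, so it does not
            # count towards the TPR.
            if false_below_pt > 0.0 and edge_pt < false_below_pt:
                label = False
        else:
            # Keep a false edge as false with probability tnr.
            label = not (rng.random() < tnr)
        labels.append(label)
    return labels


y = [True, True, False, False]
pt = [1.0, 0.2, 1.0, 0.2]
print(noisy_truth(y, pt))                      # truth is returned unchanged
print(noisy_truth(y, pt, false_below_pt=0.5))  # low-pt true edge is set to false
```

With `tpr=1.0` and `tnr=1.0` the labels equal the truth, which is the "perfect" classifier case; lowering either rate flips a random fraction of the corresponding edges.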

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]

class gnn_tracking.models.edge_classifier.ECFromChkpt

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

For easy loading of a pretrained EC from a Lightning YAML config.

Parameters:
  • chkpt_path – Path to the checkpoint file.

  • class_name – Name of the lightning module that was used to train the EC. Default should work for most cases.

  • freeze – If True, the model is frozen (i.e., its parameters are not trained).