gnn_tracking.models.edge_classifier
Models for edge classification
Module Contents
Classes
ECForGraphTCN | Edge classification step to be used for Graph Track Condensor network
PerfectEdgeClassification | An edge classifier that is perfect because it uses the truth information.
ECFromChkpt | For easy loading of a pretrained EC from a Lightning YAML config.
- class gnn_tracking.models.edge_classifier.ECForGraphTCN(*, node_indim: int, edge_indim: int, interaction_node_dim: int = 5, interaction_edge_dim: int = 4, hidden_dim: int | float | None = None, L_ec: int = 3, alpha: float = 0.5, residual_type='skip1', use_intermediate_edge_embeddings: bool = True, use_node_embedding: bool = True, residual_kwargs: dict | None = None)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.HyperparametersMixin
Edge classification step to be used for Graph Track Condensor network (Graph TCN).
- Parameters:
node_indim – Node feature dim
edge_indim – Edge feature dim
interaction_node_dim – Node dimension for interaction networks. Defaults to 5 for backward compatibility, but this is probably not reasonable.
interaction_edge_dim – Edge dimension for interaction networks. Defaults to 4 for backward compatibility, but this is probably not reasonable.
hidden_dim – Width of the hidden layers in all perceptrons (edge and node encoders, hidden dims for MLPs in object and relation networks). If None, it is chosen as the maximum of the input and output dims, separately for each MLP.
L_ec – message passing depth for edge classifier
alpha – strength of residual connection for EC
residual_type – type of residual connection for EC
use_intermediate_edge_embeddings – If True, feed not only the final encoding of the stacked interaction networks to the final MLP, but all intermediate encodings as well
use_node_embedding – If True, feed node attributes to the final MLP for EC
residual_kwargs – Keyword arguments passed to ResIN
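The hidden_dim rule for None can be illustrated with a small helper. This is a minimal sketch, not the library's code: the function name resolve_hidden_dim is hypothetical, and float values are deliberately not handled because their meaning is not documented on this page.

```python
from typing import Optional


def resolve_hidden_dim(in_dim: int, out_dim: int, hidden_dim: Optional[int] = None) -> int:
    """Hypothetical helper: pick the hidden width for one MLP."""
    if hidden_dim is None:
        # Documented rule: the maximum of this MLP's input and output dims
        return max(in_dim, out_dim)
    if isinstance(hidden_dim, float):
        # Float semantics are not described here, so this sketch refuses them
        raise NotImplementedError("float hidden_dim is not covered by this sketch")
    # An explicit int is used for every MLP
    return hidden_dim


# Each MLP resolves its own width independently
print(resolve_hidden_dim(14, 5))      # 14
print(resolve_hidden_dim(3, 8))       # 8
print(resolve_hidden_dim(3, 8, 16))   # 16
```

Because the rule is applied per MLP, an encoder and a relation network in the same model can end up with different hidden widths when hidden_dim is None.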
- forward(data: torch_geometric.data.Data) -> dict[str, torch.Tensor]
Returns a dictionary with the following keys:
W: Edge weights
node_embedding: Last node embedding (result of the last interaction network)
edge_embedding: Last edge embedding (result of the last interaction network)
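A common way to evaluate the edge weights W is to threshold them and compare against truth labels, yielding a true positive rate (TPR) and true negative rate (TNR). The sketch below assumes plain Python lists in place of the torch.Tensor that forward actually returns; the helper edge_rates, the truth list, and the 0.5 threshold are hypothetical and chosen only for illustration.

```python
def edge_rates(weights, truth, threshold=0.5):
    """Hypothetical helper: (TPR, TNR) of thresholded edge weights vs. truth labels."""
    tp = fn = tn = fp = 0
    for w, is_true in zip(weights, truth):
        predicted = w > threshold  # edge is kept if its weight exceeds the threshold
        if is_true:
            if predicted:
                tp += 1
            else:
                fn += 1
        else:
            if predicted:
                fp += 1
            else:
                tn += 1
    # NaN when there are no true (or no false) edges to evaluate
    tpr = tp / (tp + fn) if tp + fn else float("nan")
    tnr = tn / (tn + fp) if tn + fp else float("nan")
    return tpr, tnr


out = {"W": [0.9, 0.2, 0.7, 0.4]}  # stand-in for the "W" entry of the forward output
truth = [True, False, False, True]  # hypothetical truth labels, one per edge
print(edge_rates(out["W"], truth))  # (0.5, 0.5)
```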
- class gnn_tracking.models.edge_classifier.PerfectEdgeClassification(tpr=1.0, tnr=1.0, false_below_pt=0.0)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.HyperparametersMixin
An edge classifier that is perfect because it uses the truth information. If TPR or TNR is not 1.0, noise is added to the truth information.
This can be used to evaluate the maximal possible performance of a model that relies on edge classification as a first step (e.g., the object condensation approach).
- Parameters:
tpr – True positive rate
tnr – True negative rate
false_below_pt – If not 0.0, all true edges between hits corresponding to particles with a pt lower than this threshold are set to false. This is not counted towards the TPR/TNR but applied afterwards.
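The noise model described above can be sketched in plain Python. This is an illustrative reading of the parameter descriptions, not the class's actual implementation: the function name, the per-edge edge_pt input (the pt of the particle an edge belongs to), and the list-of-booleans output are all assumptions. True edges are kept with probability tpr, false edges stay false with probability tnr, and the pt cut is applied afterwards so it does not affect the sampled rates.

```python
import random


def perfect_edge_classification(y_true, edge_pt, tpr=1.0, tnr=1.0, false_below_pt=0.0, rng=None):
    """Hypothetical sketch: noisy copy of the truth labels, then a pt cut."""
    rng = rng if rng is not None else random.Random(0)
    predictions = []
    for is_true, pt in zip(y_true, edge_pt):
        if is_true:
            # A true edge survives with probability tpr
            pred = rng.random() < tpr
        else:
            # A false edge is flipped to true with probability 1 - tnr
            pred = rng.random() >= tnr
        # Applied after the noise, so it is not counted towards TPR/TNR
        if false_below_pt > 0.0 and is_true and pt < false_below_pt:
            pred = False
        predictions.append(pred)
    return predictions


y = [True, False, True, False]
print(perfect_edge_classification(y, [1.0] * 4))  # [True, False, True, False]
print(perfect_edge_classification([True, True, False], [0.5, 2.0, 0.1], false_below_pt=0.9))
# [False, True, False]
```

With the default tpr=1.0 and tnr=1.0 the output reproduces the truth exactly, which is what makes the class useful as an upper bound on downstream performance.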
- forward(data: torch_geometric.data.Data) -> dict[str, torch.Tensor]
- class gnn_tracking.models.edge_classifier.ECFromChkpt
Bases: torch.nn.Module, pytorch_lightning.core.mixins.HyperparametersMixin
For easy loading of a pretrained EC from a Lightning YAML config.
- Parameters:
chkpt_path – Path to the checkpoint file.
class_name – Name of the lightning module that was used to train the EC. Default should work for most cases.
freeze – If True, the model is frozen (i.e., its parameters are not trained).