gnn_tracking.models.edge_filter

Models for edge filtering. Edge filtering is the same task as edge classification, but without message passing, i.e., each decision is based solely on the features of the edge under consideration.
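As a minimal sketch of what "without message passing" means in practice, the snippet below builds each edge's input only from its own two endpoint nodes (and, optionally, its own edge features), then keeps the edges whose score passes a cut. It assumes the PyTorch Geometric graph layout (`x` of shape `(num_nodes, F)`, `edge_index` of shape `(2, num_edges)`). The helper names, the toy scoring rule, and the 0.5 threshold are hypothetical and are not part of `gnn_tracking`.

```python
from __future__ import annotations

import torch


def edge_inputs(
    x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor | None = None
) -> torch.Tensor:
    """Hypothetical helper: per-edge input built only from the edge's own endpoints."""
    src, dst = edge_index  # row 0: source node ids, row 1: target node ids
    parts = [x[src], x[dst]]
    if edge_attr is not None:
        parts.append(edge_attr)  # optional features of the edge itself
    return torch.cat(parts, dim=1)


def filter_edges(edge_index: torch.Tensor, scores: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: keep the columns of edge_index whose score exceeds the threshold."""
    return edge_index[:, scores > threshold]


# Toy graph: 3 nodes with one feature each, 3 directed edges
x = torch.tensor([[0.0], [1.0], [2.0]])
edge_index = torch.tensor([[0, 1, 0], [1, 2, 2]])
inputs = edge_inputs(x, edge_index)  # shape (3, 2): [x_src, x_dst] per edge

# Placeholder scoring rule standing in for a trained model:
# score 1.0 if the feature increases by exactly 1 along the edge, else 0.0
scores = (inputs[:, 1] - inputs[:, 0] == 1).float()
kept = filter_edges(edge_index, scores)
print(kept.tolist())  # [[0, 1], [1, 2]]
```

Because no edge reads another edge's input, the edges can be scored fully in parallel; a message-passing edge classifier would instead update node features from neighbors before scoring.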

Classes

MLP

Multi-layer perceptron using ReLU as the activation function.

EFDeepSet

Edge filter based on the Deep Sets architecture.

EFMLP

Edge filter based on an MLP architecture.

GeometricEF

Edge filter with geometric cuts only (no learning required).

Functions

assert_feat_dim(→ None)

Module Contents

class gnn_tracking.models.edge_filter.MLP(input_size: int, output_size: int, hidden_dim: int | None, L=3, *, bias=True, include_last_activation=False)

Bases: torch.nn.Module

Multi-layer perceptron using ReLU as the activation function.

Parameters:
  • input_size – Input feature dimension

  • output_size – Output feature dimension

  • hidden_dim – Feature dimension of the hidden layers. If None, the maximum of input_size and output_size is used.

  • L – Total number of layers (1 initial layer, L-2 hidden layers, 1 output layer)

  • bias – Whether to include a bias term in the linear layers

  • include_last_activation – Whether to apply the activation function after the last layer

reset_parameters()
forward(x)
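The layer layout described by `L` can be sketched as follows. This is a hypothetical re-implementation (`SketchMLP`), not the library's source: it assumes ReLU after every layer except the last (plus after the last one when `include_last_activation=True`), and it rejects `L < 2`, since the documented layout always has an initial layer and an output layer.

```python
from __future__ import annotations

import torch
from torch import nn


class SketchMLP(nn.Module):
    """Hypothetical MLP following the documented layout: 1 initial, L - 2 hidden, 1 output layer."""

    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_dim: int | None,
        L: int = 3,
        *,
        bias: bool = True,
        include_last_activation: bool = False,
    ):
        super().__init__()
        if L < 2:  # assumption: the layout needs at least an initial and an output layer
            raise ValueError("L must be at least 2")
        if hidden_dim is None:
            # Documented default: maximum of input and output size
            hidden_dim = max(input_size, output_size)
        dims = [input_size] + [hidden_dim] * (L - 1) + [output_size]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out, bias=bias) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )
        self.include_last_activation = include_last_activation

    def reset_parameters(self) -> None:
        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
        x = self.layers[-1](x)
        if self.include_last_activation:
            x = torch.relu(x)
        return x


mlp = SketchMLP(4, 2, None, L=3)
out = mlp(torch.randn(5, 4))  # shape (5, 2)
print([(layer.in_features, layer.out_features) for layer in mlp.layers])  # [(4, 4), (4, 4), (4, 2)]
```

With `hidden_dim=None`, the hidden width falls back to `max(4, 2) = 4`, so the three layers map 4 to 4, 4 to 4, and 4 to 2 features.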
gnn_tracking.models.edge_filter.assert_feat_dim(feat_vec: torch.Tensor, dim: int) → None
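`assert_feat_dim` is listed without a description. A plausible reading of its signature is a shape guard on the feature axis. The sketch below is named `check_feat_dim` to mark it as hypothetical; it assumes a 2D input of shape `(num_items, num_features)` and that a mismatch raises `ValueError`. Neither detail is confirmed by this page.

```python
import torch


def check_feat_dim(feat_vec: torch.Tensor, dim: int) -> None:
    """Hypothetical stand-in for assert_feat_dim.

    Assumes feat_vec has shape (num_items, num_features) and that a mismatch
    should raise ValueError; neither detail is confirmed by the documentation.
    """
    if feat_vec.shape[1] != dim:
        msg = f"Expected feature dimension {dim}, got {feat_vec.shape[1]}"
        raise ValueError(msg)


check_feat_dim(torch.zeros(5, 14), 14)  # matching dimension: returns None silently
```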
class gnn_tracking.models.edge_filter.EFDeepSet(*, in_dim: int = 14, hidden_dim: int = 128, depth: int = 3)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Edge filter based on the Deep Sets architecture.

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
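This page does not say how `EFDeepSet` applies the Deep Sets architecture. One plausible reading, sketched below under that assumption, treats each edge as the unordered set of its two endpoint nodes: a shared network phi encodes each node, the encodings are sum-pooled, and a second network rho maps the pooled vector to an edge score. Sum pooling makes the score independent of edge direction. The class name `SketchDeepSetEF`, the `SimpleNamespace` stand-in for `torch_geometric.data.Data`, the sigmoid output, and the output key `"W"` are all assumptions; the real class also mixes in `HyperparametersMixin`, which the sketch omits.

```python
from __future__ import annotations

from types import SimpleNamespace

import torch
from torch import nn


def relu_mlp(in_dim: int, out_dim: int, hidden_dim: int, depth: int) -> nn.Sequential:
    """`depth` linear layers (depth >= 2) with ReLU between them, none after the last."""
    dims = [in_dim] + [hidden_dim] * (depth - 1) + [out_dim]
    layers: list[nn.Module] = []
    for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(d_in, d_out))
        if i < depth - 1:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)


class SketchDeepSetEF(nn.Module):
    """Hypothetical Deep Sets edge filter: each edge is the set {source node, target node}."""

    def __init__(self, *, in_dim: int = 14, hidden_dim: int = 128, depth: int = 3):
        super().__init__()
        self.phi = relu_mlp(in_dim, hidden_dim, hidden_dim, depth)  # per-node encoder
        self.rho = relu_mlp(hidden_dim, 1, hidden_dim, depth)  # pooled set -> edge logit

    def forward(self, data) -> dict[str, torch.Tensor]:
        src, dst = data.edge_index
        # Sum pooling over the two endpoints: invariant to swapping them
        pooled = self.phi(data.x[src]) + self.phi(data.x[dst])
        # "W" is an assumed output key; sigmoid gives one score in (0, 1) per edge
        return {"W": torch.sigmoid(self.rho(pooled)).squeeze(-1)}


torch.manual_seed(0)
model = SketchDeepSetEF(in_dim=3, hidden_dim=8, depth=3)
data = SimpleNamespace(x=torch.randn(4, 3), edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]))
w = model(data)["W"]  # one score per edge, shape (3,)
```

Reversing every edge (swapping the two rows of `edge_index`) leaves the scores unchanged, which is the property the set formulation is meant to provide.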
class gnn_tracking.models.edge_filter.EFMLP(*, node_indim: int, edge_indim: int = 0, hidden_dim: int, depth: int, beta: float = 0.4)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Edge filter based on an MLP architecture.

Parameters:
  • node_indim – Dimension of the node features

  • edge_indim – Dimension of the edge features. If 0, edge features are not used.

  • hidden_dim – Dimension of the hidden layers

  • depth – Number of hidden layers

  • beta – Temperature parameter for the softmax

reset_parameters()
static _reset_layer_parameters(layer, var: float)
forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
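A minimal sketch of an MLP edge filter with the documented `node_indim`, `edge_indim`, `hidden_dim`, and `depth` arguments follows. It assumes each edge's input is the concatenation of both endpoint node features plus, when `edge_indim > 0`, the edge features. It also assumes `_reset_layer_parameters` draws weights from a zero-mean normal distribution with variance `var` (here `1 / fan_in`) and zeroes the bias. `beta` is left out because the page does not say which softmax it tempers. `SketchEFMLP`, the `SimpleNamespace` stand-in for `Data`, and the `"W"` output key are hypothetical.

```python
from __future__ import annotations

import math
from types import SimpleNamespace

import torch
from torch import nn


class SketchEFMLP(nn.Module):
    """Hypothetical MLP edge filter; mirrors the documented arguments except beta."""

    def __init__(self, *, node_indim: int, edge_indim: int = 0, hidden_dim: int, depth: int):
        super().__init__()
        self.edge_indim = edge_indim
        # Edge input: [x_src, x_dst] plus edge features if edge_indim > 0
        dims = [2 * node_indim + edge_indim] + [hidden_dim] * depth
        # `depth` hidden layers (depth >= 1), then one output unit
        self.hidden = nn.ModuleList([nn.Linear(a, b) for a, b in zip(dims[:-1], dims[1:])])
        self.out = nn.Linear(hidden_dim, 1)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        for layer in [*self.hidden, self.out]:
            # Assumed scheme: weight variance 1 / fan_in
            self._reset_layer_parameters(layer, var=1.0 / layer.in_features)

    @staticmethod
    def _reset_layer_parameters(layer: nn.Linear, var: float) -> None:
        # Assumption: zero-mean normal weights with the given variance, zero bias
        nn.init.normal_(layer.weight, mean=0.0, std=math.sqrt(var))
        nn.init.zeros_(layer.bias)

    def forward(self, data) -> dict[str, torch.Tensor]:
        src, dst = data.edge_index
        parts = [data.x[src], data.x[dst]]
        if self.edge_indim > 0:
            parts.append(data.edge_attr)
        h = torch.cat(parts, dim=1)
        for layer in self.hidden:
            h = torch.relu(layer(h))
        # "W" is an assumed output key: one score in (0, 1) per edge
        return {"W": torch.sigmoid(self.out(h)).squeeze(-1)}


torch.manual_seed(0)
model = SketchEFMLP(node_indim=3, edge_indim=2, hidden_dim=8, depth=2)
data = SimpleNamespace(
    x=torch.randn(4, 3),
    edge_index=torch.tensor([[0, 1], [2, 3]]),
    edge_attr=torch.randn(2, 2),
)
w = model(data)["W"]  # shape (2,)
```

With `edge_indim=0` the sketch never reads `data.edge_attr`, matching the documented behavior that edge features are then not used.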
class gnn_tracking.models.edge_filter.GeometricEF(phi_slope_max, z0_max, dR_max)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Edge filter with geometric cuts only (no learning required).

forward(data: torch_geometric.data.Data)
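The three thresholds correspond to cuts commonly used for geometric graph construction in HEP tracking. The sketch below uses the conventional definitions, which are an assumption about what `GeometricEF` actually computes: for an edge from hit 1 to hit 2 in cylindrical coordinates `(r, phi, z)` with pseudorapidity `eta`, `phi_slope = dphi / dr`, `z0 = z1 - r1 * dz / dr`, and `dR = sqrt(deta**2 + dphi**2)`, with `dphi` wrapped into `[-pi, pi)`. An edge is kept only if all three pass. The function name and the choice to pass `r`, `phi`, `z`, `eta` as separate tensors (rather than reading columns of `data.x`) are hypothetical.

```python
import math

import torch


def geometric_edge_mask(
    r: torch.Tensor,
    phi: torch.Tensor,
    z: torch.Tensor,
    eta: torch.Tensor,
    edge_index: torch.Tensor,
    *,
    phi_slope_max: float,
    z0_max: float,
    dR_max: float,
) -> torch.Tensor:
    """Hypothetical geometric edge filter: True for edges that pass all three cuts.

    Assumes edges point outward (dr != 0); edges with dr == 0 give inf/nan slopes.
    """
    src, dst = edge_index
    dr = r[dst] - r[src]
    dz = z[dst] - z[src]
    deta = eta[dst] - eta[src]
    # Wrap the azimuthal difference into [-pi, pi)
    dphi = torch.remainder(phi[dst] - phi[src] + math.pi, 2 * math.pi) - math.pi
    phi_slope = dphi / dr
    z0 = z[src] - r[src] * dz / dr  # z of the segment extrapolated to r = 0
    dR = torch.sqrt(deta**2 + dphi**2)
    return (phi_slope.abs() < phi_slope_max) & (z0.abs() < z0_max) & (dR < dR_max)


r = torch.tensor([1.0, 2.0, 2.0, 2.0, 1.0, 2.0])
phi = torch.tensor([0.0, 0.1, 3.0, 0.0, 3.1, -3.1])
z = torch.tensor([0.0, 0.0, 0.0, 10.0, 0.0, 0.0])
eta = torch.tensor([0.0, 0.0, 0.0, 0.5, 0.0, 0.0])
edge_index = torch.tensor([[0, 0, 0, 4], [1, 2, 3, 5]])
mask = geometric_edge_mask(r, phi, z, eta, edge_index, phi_slope_max=0.5, z0_max=5.0, dR_max=1.0)
print(mask.tolist())  # [True, False, False, True]
```

The second edge fails the `phi_slope` and `dR` cuts, the third fails only the `z0` cut (`|z0| = 10`), and the fourth crosses the phi = ±pi boundary: its raw difference of -6.2 wraps to about 0.083, so it passes. No parameters are learned, consistent with the class description.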