gnn_tracking.models.edge_filter

Models for edge filtering. Edge filtering is the same task as edge classification, but without message passing: each decision is made solely from the features of the edge under consideration.
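The "no message passing" distinction can be illustrated with a small sketch. Every edge is scored from its own feature vector only, so changing the features of one edge never changes the score of another, and edges whose score exceeds a threshold are kept. The function name `filter_edges`, the default threshold of 0.5, and the list-based inputs are hypothetical and not part of gnn_tracking; how the models in this module turn scores into a reduced graph is not documented here.

```python
from typing import Callable, Sequence

# Hypothetical aliases for this sketch: an edge is a (source, target) pair of node indices
Edge = tuple[int, int]
Features = Sequence[float]


def filter_edges(
    edges: Sequence[Edge],
    edge_features: Sequence[Features],
    score_fn: Callable[[Features], float],
    threshold: float = 0.5,
) -> tuple[list[float], list[Edge]]:
    """Score every edge independently and keep those scoring above ``threshold``.

    No message passing: the score of edge ``k`` is computed from
    ``edge_features[k]`` alone, never from neighbouring edges or nodes.
    """
    if len(edges) != len(edge_features):
        raise ValueError("need exactly one feature vector per edge")
    scores = [score_fn(f) for f in edge_features]
    kept = [e for e, s in zip(edges, scores) if s > threshold]
    return scores, kept
```

For example, `filter_edges([(0, 1), (1, 2), (2, 3)], [[0.9], [0.1], [0.7]], score_fn=lambda f: f[0])` keeps `(0, 1)` and `(2, 3)`. An edge classifier with message passing would instead update each edge's representation from its neighbours before scoring.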

Classes

EFDeepSet

Edge filter based on the deep sets architecture.

EFMLP

Edge filter based on an MLP architecture.

GeometricEF

Edge filter with geometric cuts only (no learning required).

Module Contents

class gnn_tracking.models.edge_filter.EFDeepSet(*, in_dim: int = 14, hidden_dim: int = 128, depth: int = 3)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Edge filter based on the deep sets architecture.

node_encoder
aggregator
forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
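A deep set computes rho(sum_i phi(x_i)) over a set of inputs, which makes the result independent of the order of the set elements. The sketch below applies that idea to a single edge, assuming (this is not stated above) that the pooled set is the two endpoint nodes of the edge: a per-node encoder plays a role loosely analogous to `node_encoder`, and a head mapping the pooled vector to one logit plays a role loosely analogous to `aggregator`. The class `DeepSetEdgeScorer`, its random untrained weights, and the sigmoid output are hypothetical illustrations in plain Python, not the EFDeepSet implementation.

```python
import math
import random


def _linear(rng, n_in, n_out):
    """Random (weights, bias) pair standing in for a trained linear layer."""
    scale = 1.0 / math.sqrt(n_in)
    weights = [[rng.gauss(0.0, scale) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def _apply(layer, x, relu=True):
    weights, bias = layer
    out = [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]
    return [max(0.0, v) for v in out] if relu else out


def _sigmoid(z):
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


class DeepSetEdgeScorer:
    """Hypothetical deep-set edge scorer: rho(phi(x_i) + phi(x_j))."""

    def __init__(self, in_dim=14, hidden_dim=128, depth=3, seed=0):
        if depth < 1:
            raise ValueError("depth must be at least 1")
        rng = random.Random(seed)
        # phi: per-node encoder with `depth` layers
        dims = [in_dim] + [hidden_dim] * depth
        self.phi = [_linear(rng, a, b) for a, b in zip(dims, dims[1:])]
        # rho: maps the pooled vector to a single logit
        self.rho = [_linear(rng, hidden_dim, hidden_dim) for _ in range(depth - 1)]
        self.rho.append(_linear(rng, hidden_dim, 1))

    def _encode(self, x):
        for layer in self.phi:
            x = _apply(layer, x)
        return x

    def score(self, x_i, x_j):
        """Edge weight in [0, 1] for an edge between nodes with features x_i and x_j."""
        # Sum pooling over the set {x_i, x_j}: swapping the endpoints gives the same result
        h = [a + b for a, b in zip(self._encode(x_i), self._encode(x_j))]
        for layer in self.rho[:-1]:
            h = _apply(layer, h)
        return _sigmoid(_apply(self.rho[-1], h, relu=False)[0])
```

Because the pooling is a sum, `score(x_i, x_j)` and `score(x_j, x_i)` are identical, so the sketch treats each edge as undirected. A model that concatenates the endpoint features, as an MLP typically would, does not have this symmetry.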

class gnn_tracking.models.edge_filter.EFMLP(*, node_indim: int, edge_indim: int = 0, hidden_dim: int, depth: int, beta: float = 0.4)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Edge filter based on an MLP architecture.

Parameters:
  • node_indim – dimension of the node features

  • edge_indim – dimension of the edge features. If set to 0, edge features are not used.

  • hidden_dim – dimension of the hidden layers

  • depth – number of hidden layers

  • beta – temperature parameter for the softmax

encoder
decoder
layers
reset_parameters()
static _reset_layer_parameters(layer, var: float)
forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
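The roles of node_indim, edge_indim, hidden_dim and depth can be sketched with a plain-Python MLP that scores each edge from its own inputs. This assumes, without confirmation from the text above, that the edge input is the concatenation of both endpoint node features plus, when edge_indim > 0, the edge's own features, and that the output is a sigmoid weight per edge. The class `MLPEdgeScorer` and its random untrained weights are hypothetical; `beta` is deliberately left out because its exact use in EFMLP is not shown here.

```python
import math
import random


def _dense(rng, n_in, n_out):
    """Random (weights, bias) pair standing in for a trained linear layer."""
    scale = 1.0 / math.sqrt(n_in)
    weights = [[rng.gauss(0.0, scale) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def _forward(layer, x):
    weights, bias = layer
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]


def _relu(x):
    return [max(0.0, v) for v in x]


def _sigmoid(z):
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


class MLPEdgeScorer:
    """Hypothetical MLP edge scorer: encoder -> `depth` hidden layers -> decoder."""

    def __init__(self, *, node_indim, edge_indim=0, hidden_dim, depth, seed=0):
        rng = random.Random(seed)
        self.edge_indim = edge_indim
        in_dim = 2 * node_indim + edge_indim  # both endpoints (+ edge features if used)
        self.encoder = _dense(rng, in_dim, hidden_dim)
        self.layers = [_dense(rng, hidden_dim, hidden_dim) for _ in range(depth)]
        self.decoder = _dense(rng, hidden_dim, 1)

    def __call__(self, x, edge_index, edge_attr=None):
        """Return one weight in [0, 1] per (i, j) pair in ``edge_index``."""
        if self.edge_indim > 0 and edge_attr is None:
            raise ValueError("edge_attr is required when edge_indim > 0")
        weights = []
        for k, (i, j) in enumerate(edge_index):
            features = list(x[i]) + list(x[j])
            if self.edge_indim > 0:
                features += list(edge_attr[k])
            # edge_attr is ignored entirely when edge_indim == 0
            h = _relu(_forward(self.encoder, features))
            for layer in self.layers:
                h = _relu(_forward(layer, h))
            weights.append(_sigmoid(_forward(self.decoder, h)[0]))
        return weights
```

Each edge is processed on its own, which matches the module-level description of filtering without message passing. In a real PyTorch model the per-edge loop would be replaced by one batched matrix multiplication over all edges.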

class gnn_tracking.models.edge_filter.GeometricEF(phi_slope_max, z0_max, dR_max)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Edge filter with geometric cuts only (no learning required).

forward(data: torch_geometric.data.Data)