gnn_tracking.models.edge_filter
Models for edge filtering. Edge filtering is the same task as edge classification, but without message passing: each edge is kept or rejected based solely on the features of that edge.
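The following is a minimal sketch of edge filtering without message passing. It uses plain PyTorch tensors rather than the models in this module. It assumes each edge is described by the concatenated features of its two endpoint nodes. The toy graph, the small classifier, and the 0.5 threshold are all illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy graph (hypothetical data): 5 nodes with 3 features each, 4 candidate edges.
x = torch.randn(5, 3)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 4]])  # row 0: source nodes, row 1: target nodes

# Each edge is described only by its own endpoints; no neighbor information is
# propagated, so every edge is scored independently of the rest of the graph.
src, dst = edge_index
edge_features = torch.cat([x[src], x[dst]], dim=1)  # shape: (num_edges, 2 * 3)

# Any per-edge classifier works; here a small MLP producing one logit per edge.
classifier = nn.Sequential(nn.Linear(6, 16), nn.ReLU(), nn.Linear(16, 1))
scores = torch.sigmoid(classifier(edge_features)).squeeze(-1)  # shape: (num_edges,)

# Edges scoring below the (illustrative) threshold are dropped.
keep = scores > 0.5
filtered_edge_index = edge_index[:, keep]
```

Because there is no message passing, scoring a single edge on its own gives the same result as scoring it together with all the others.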
Classes
- EFDeepSet: EdgeFilter based on the deep sets architecture.
- EFMLP: EdgeFilter based on an MLP architecture.
- Edge filter with geometric cuts only (no learning required).
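The geometric-cut filter needs no training. It keeps an edge only if fixed geometric quantities fall within thresholds. The sketch below assumes hits are given in cylindrical coordinates (r, phi, z). It applies two cuts commonly used in tracking:

- the azimuthal difference between the two hits;
- the z-intercept at r = 0 of the straight line through both hits.

The function name `geometric_edge_mask` and the cut values are hypothetical and not taken from gnn_tracking.

```python
import math

import torch


def geometric_edge_mask(r, phi, z, edge_index, max_dphi=0.1, max_z0=200.0):
    """Hypothetical geometric edge filter: True for edges passing both cuts.

    r, phi, z: 1D tensors of per-node cylindrical coordinates.
    max_dphi, max_z0: illustrative cut values, not taken from gnn_tracking.
    """
    src, dst = edge_index
    # Azimuthal difference, wrapped into [-pi, pi) so edges crossing phi = 0 work.
    dphi = torch.remainder(phi[dst] - phi[src] + math.pi, 2 * math.pi) - math.pi
    # z-intercept at r = 0 of the straight line through both hits.
    # Edges with equal r give inf/nan here and therefore fail the cut.
    z0 = z[src] - r[src] * (z[dst] - z[src]) / (r[dst] - r[src])
    return (dphi.abs() < max_dphi) & (z0.abs() < max_z0)


# Usage: node 0 is an inner hit; nodes 1-3 are outer-layer hits.
r = torch.tensor([30.0, 60.0, 60.0, 60.0])
phi = torch.tensor([0.0, 0.05, 0.02, 2 * math.pi - 0.01])
z = torch.tensor([10.0, 20.0, 500.0, 20.0])
edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]])
mask = geometric_edge_mask(r, phi, z, edge_index)
# Edge 0->2 fails the z0 cut; edge 0->3 passes thanks to the phi wrap-around.
```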
Module Contents
- class gnn_tracking.models.edge_filter.EFDeepSet(*, in_dim: int = 14, hidden_dim: int = 128, depth: int = 3)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
EdgeFilter based on the deep sets architecture.
- node_encoder
- aggregator
- forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
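In the deep sets picture, the two endpoint nodes of an edge form a small set. The architecture has three stages:

1. A shared encoder (phi) embeds each node.
2. A permutation-invariant pooling step combines the two embeddings.
3. A second network (rho) maps the pooled result to an edge score.

The sketch below mirrors the documented `node_encoder` and `aggregator` attributes and the default `in_dim=14`, `hidden_dim=128`, `depth=3`. Several parts are assumptions: the `mlp` helper, the sum/absolute-difference pooling, and the sigmoid output. For simplicity it also takes plain tensors instead of a `torch_geometric.data.Data` object.

```python
import torch
from torch import nn


def mlp(in_dim, hidden_dim, out_dim, depth):
    """Hypothetical helper: ReLU MLP with `depth` linear layers."""
    dims = [in_dim] + [hidden_dim] * (depth - 1) + [out_dim]
    layers = []
    for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(d_in, d_out))
        if i < depth - 1:  # no activation after the final layer
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)


class DeepSetEdgeFilterSketch(nn.Module):
    """Illustrative deep sets edge filter; not the gnn_tracking implementation."""

    def __init__(self, in_dim=14, hidden_dim=128, depth=3):
        super().__init__()
        # phi: shared encoder applied to every node (set element).
        self.node_encoder = mlp(in_dim, hidden_dim, hidden_dim, depth)
        # rho: maps the pooled two-node representation to one logit per edge.
        self.aggregator = mlp(2 * hidden_dim, hidden_dim, 1, depth)

    def forward(self, x, edge_index):
        h = self.node_encoder(x)
        src, dst = edge_index
        # Symmetric pooling: swapping source and target gives the same result.
        pooled = torch.cat([h[src] + h[dst], (h[src] - h[dst]).abs()], dim=1)
        return torch.sigmoid(self.aggregator(pooled)).squeeze(-1)


torch.manual_seed(0)
model = DeepSetEdgeFilterSketch()
x = torch.randn(6, 14)
edge_index = torch.tensor([[0, 1, 2], [3, 4, 5]])
scores = model(x, edge_index)
swapped = model(x, edge_index.flip(0))  # same edges with src and dst exchanged
```

The symmetric pooling is what makes the score independent of edge direction, which is the defining property of a deep sets model over the set {source, target}.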
- class gnn_tracking.models.edge_filter.EFMLP(*, node_indim: int, edge_indim: int = 0, hidden_dim: int, depth: int, beta: float = 0.4)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
EdgeFilter based on an MLP architecture.
- Parameters:
  node_indim – dimension of the node features
  edge_indim – dimension of the edge features. If set to 0, edge features are not used.
  hidden_dim – dimension of the hidden layers
  depth – number of hidden layers
  beta – temperature parameter for the softmax
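The `edge_indim` switch can be illustrated by how a per-edge input vector might be assembled. This is a hypothetical helper, not part of the gnn_tracking API. It assumes the edge MLP sees both endpoint node feature vectors followed by the edge's own features, so its input width would be `2 * node_indim + edge_indim`.

```python
import torch


def build_edge_input(x, edge_index, edge_attr=None, edge_indim=0):
    """Hypothetical helper: assemble one input row per edge for an edge MLP."""
    src, dst = edge_index
    parts = [x[src], x[dst]]  # features of both endpoint nodes
    if edge_indim > 0:
        # Edge features are appended only when edge_indim is nonzero,
        # matching the documented "if set to 0, edge features are not used".
        parts.append(edge_attr)
    return torch.cat(parts, dim=1)  # shape: (num_edges, 2 * node_indim + edge_indim)


x = torch.randn(5, 3)  # node_indim = 3
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
edge_attr = torch.randn(4, 2)  # edge_indim = 2
with_edges = build_edge_input(x, edge_index, edge_attr, edge_indim=2)
without_edges = build_edge_input(x, edge_index, edge_attr, edge_indim=0)
```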
- encoder
- decoder
- layers
- reset_parameters()
- static _reset_layer_parameters(layer, var: float)
- forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor]
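The members above can be combined into a minimal MLP edge filter with:

- an `encoder`;
- a stack of hidden `layers`;
- a `decoder`;
- a variance-controlled initializer.

A sketch follows. Several details are assumptions rather than the gnn_tracking implementation: the layer layout, the 1/fan_in weight variance, the ReLU activations, and the output key `edge_score`. `beta` is omitted because its exact use is not shown here. A `types.SimpleNamespace` stands in for `torch_geometric.data.Data`, since the sketch only reads the `x`, `edge_index`, and `edge_attr` attributes.

```python
import math
import types

import torch
from torch import nn


class EFMLPSketch(nn.Module):
    """Illustrative edge-filter MLP; layout and init variances are assumptions."""

    def __init__(self, *, node_indim, edge_indim=0, hidden_dim, depth):
        super().__init__()
        self.edge_indim = edge_indim
        in_dim = 2 * node_indim + edge_indim  # both endpoints + optional edge features
        self.encoder = nn.Linear(in_dim, hidden_dim)
        self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(depth)])
        self.decoder = nn.Linear(hidden_dim, 1)
        self.reset_parameters()

    def reset_parameters(self):
        # Assumed scheme: weight variance 1 / fan_in for every layer.
        for layer in [self.encoder, *self.layers, self.decoder]:
            self._reset_layer_parameters(layer, var=1.0 / layer.in_features)

    @staticmethod
    def _reset_layer_parameters(layer, var: float):
        # Normal init with the requested variance (std = sqrt(var)); zero bias.
        nn.init.normal_(layer.weight, mean=0.0, std=math.sqrt(var))
        nn.init.zeros_(layer.bias)

    def forward(self, data):
        src, dst = data.edge_index
        parts = [data.x[src], data.x[dst]]
        if self.edge_indim:
            parts.append(data.edge_attr)
        h = torch.relu(self.encoder(torch.cat(parts, dim=1)))
        for layer in self.layers:
            h = torch.relu(layer(h))
        # Hypothetical output key; one score in [0, 1] per edge.
        return {"edge_score": torch.sigmoid(self.decoder(h)).squeeze(-1)}


torch.manual_seed(0)
data = types.SimpleNamespace(
    x=torch.randn(5, 3),
    edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
    edge_attr=torch.randn(4, 2),
)
model = EFMLPSketch(node_indim=3, edge_indim=2, hidden_dim=16, depth=2)
out = model(data)["edge_score"]
```

Keeping the initializer as a static method that takes the target variance allows each layer to be reinitialized with its own scale. For example, an implementation could use a different variance for the encoder, which sees raw features, than for the hidden layers.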