gnn_tracking.models.mlp

Fully connected neural network implementations

Classes

MLP

Multilayer perceptron using ReLU as the activation function.

ResFCNN

Fully connected NN with residual connections.

HeterogeneousResFCNN

Separate FCNNs for pixel and strip data, with residual connections.

Functions

get_pixel_mask(layer) → torch.Tensor

Module Contents

class gnn_tracking.models.mlp.MLP(input_size: int, output_size: int, hidden_dim: int | None, L=3, *, bias=True, include_last_activation=False)

Bases: torch.nn.Module

Multilayer perceptron using ReLU as the activation function.

Parameters:
  • input_size – Input feature dimension

  • output_size – Output feature dimension

  • hidden_dim – Feature dimension of the hidden layers. If None, the maximum of input_size and output_size is used.

  • L – Total number of layers (1 initial layer, L-2 hidden layers, 1 output layer)

  • bias – Whether to include a bias term in the linear layers

  • include_last_activation – Whether to apply the activation function after the last layer
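To illustrate how these parameters combine, here is a minimal sketch, not the library's implementation. The helper describe_mlp is hypothetical and only lists the implied layer stack as strings. It assumes that a ReLU follows every linear layer except (by default) the last one, and that L is at least 2.

```python
from typing import List, Optional


def describe_mlp(
    input_size: int,
    output_size: int,
    hidden_dim: Optional[int] = None,
    L: int = 3,
    include_last_activation: bool = False,
) -> List[str]:
    """Hypothetical helper: list the layer stack implied by the MLP parameters."""
    # Assumption of this sketch: at least one input and one output layer.
    if L < 2:
        raise ValueError("this sketch assumes L >= 2")
    # Documented default: hidden width is the larger of input/output size.
    if hidden_dim is None:
        hidden_dim = max(input_size, output_size)
    # L linear layers need L + 1 feature sizes:
    # the input, L - 1 intermediate widths, and the output.
    dims = [input_size] + [hidden_dim] * (L - 1) + [output_size]
    stack = []
    for i in range(L):
        stack.append(f"Linear({dims[i]}, {dims[i + 1]})")
        is_last = i == L - 1
        # Assumed: ReLU after every layer; after the last one only on request.
        if not is_last or include_last_activation:
            stack.append("ReLU")
    return stack


print(describe_mlp(3, 2))
# ['Linear(3, 3)', 'ReLU', 'Linear(3, 3)', 'ReLU', 'Linear(3, 2)']
```

With the default L=3 this yields one input layer, one hidden layer, and one output layer, matching the L-2 hidden layers in the parameter description.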

layers: list[torch.nn.Module]
reset_parameters()
forward(x)
class gnn_tracking.models.mlp.ResFCNN(*, in_dim: int, hidden_dim: int, out_dim: int, depth: int, alpha: float = 0.6, bias: bool = True)

Bases: torch.nn.Module

Fully connected NN with residual connections.

Parameters:
  • in_dim – Input dimension

  • hidden_dim – Hidden dimension

  • out_dim – Output dimension, i.e., the dimension of the embedding space

  • depth – Determines the layer count: 1 input encoder layer, depth-1 hidden layers, and 1 output decoder layer (depth+1 layers in total)

  • alpha – Strength of the residual connection
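The role of alpha can be illustrated with a minimal pure-Python sketch. It assumes a blending rule that the parameter description does not spell out: each hidden layer updates h to sqrt(alpha) * h + sqrt(1 - alpha) * ReLU(W h + b), so alpha = 1 reduces to a pure skip connection. The names residual_forward, linear and relu are hypothetical, weights are plain nested lists, and other details of the real forward pass are not modelled.

```python
import math
from typing import List, Optional, Sequence, Tuple

Vector = List[float]
Matrix = Sequence[Sequence[float]]


def linear(W: Matrix, x: Vector, b: Optional[Vector] = None) -> Vector:
    """Matrix-vector product W @ x, plus an optional bias vector."""
    out = [sum(w * xi for w, xi in zip(row, x)) for row in W]
    if b is not None:
        out = [o + bi for o, bi in zip(out, b)]
    return out


def relu(v: Vector) -> Vector:
    return [max(0.0, vi) for vi in v]


def residual_forward(
    x: Vector,
    encoder: Matrix,
    hidden: Sequence[Tuple[Matrix, Vector]],
    decoder: Matrix,
    alpha: float = 0.6,
) -> Vector:
    """Hypothetical ResFCNN-style pass; `hidden` holds depth - 1 (W, b) pairs."""
    h = linear(encoder, x)  # input encoder: in_dim -> hidden_dim
    for W, b in hidden:
        update = relu(linear(W, h, b))
        # Assumed blending rule with weights sqrt(alpha) and sqrt(1 - alpha).
        h = [
            math.sqrt(alpha) * hi + math.sqrt(1.0 - alpha) * ui
            for hi, ui in zip(h, update)
        ]
    return linear(decoder, h)  # output decoder: hidden_dim -> out_dim
```

With depth = 2, hidden contains a single (W, b) pair. Square-root weights are one way to make the blend variance-preserving: if h and the update are uncorrelated with equal variance, the weights squared sum to alpha + (1 - alpha) = 1, so the variance of h is unchanged.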

_encoder
_decoder
_layers
_alpha
static _reset_layer_parameters(layer, var: float)
forward(x: torch.Tensor, **ignore) → torch.Tensor
gnn_tracking.models.mlp.get_pixel_mask(layer: torch.Tensor) → torch.Tensor
class gnn_tracking.models.mlp.HeterogeneousResFCNN(*, in_dim: int, out_dim: int, hidden_dim: int, depth: int, alpha: float = 0.6, bias: bool = True)

Bases: torch.nn.Module

Separate FCNNs for pixel and strip data, with residual connections. For parameters, see ResFCNN.

pixel_fcnn
strip_fcnn
forward(x: torch.Tensor, layer: torch.Tensor) → torch.Tensor
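The routing idea behind HeterogeneousResFCNN can be sketched as follows, under stated assumptions: each hit (row of x) is sent to the pixel network or the strip network depending on its entry in layer, and the results are reassembled in the original row order. The criterion used by get_pixel_mask is not documented here, so the sketch takes a hypothetical is_pixel predicate, and the threshold in the usage example is made up. pixel_fn and strip_fn stand in for pixel_fcnn and strip_fcnn.

```python
from typing import Callable, List, Sequence

Row = List[float]


def heterogeneous_forward(
    x: Sequence[Row],
    layer: Sequence[int],
    pixel_fn: Callable[[Row], Row],
    strip_fn: Callable[[Row], Row],
    is_pixel: Callable[[int], bool],
) -> List[Row]:
    """Hypothetical sketch: apply pixel_fn to pixel hits, strip_fn to strip hits."""
    if len(x) != len(layer):
        raise ValueError("x and layer must have one entry per hit")
    # Row order is preserved, so output row i belongs to input hit i.
    return [
        pixel_fn(list(row)) if is_pixel(lay) else strip_fn(list(row))
        for row, lay in zip(x, layer)
    ]


# Made-up threshold for illustration only; not the library's pixel criterion.
out = heterogeneous_forward(
    x=[[1.0], [2.0], [3.0]],
    layer=[0, 5, 3],
    pixel_fn=lambda r: [2.0 * v for v in r],
    strip_fn=lambda r: [-v for v in r],
    is_pixel=lambda lay: lay < 4,
)
print(out)  # [[2.0], [-2.0], [6.0]]
```

A batched tensor implementation would more likely evaluate both networks on the whole batch and select per row with the mask; the per-row loop here only keeps the sketch dependency-free.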