gnn_tracking.models.mlp
Fully connected neural network implementations
Classes
- MLP – Multi-layer perceptron, using ReLU as activation function.
- ResFCNN – Fully connected NN with residual connections.
- HeterogeneousResFCNN – Separate FCNNs for pixel and strip data, with residual connections.
- ResMLP – Fully connected NN with residual connections and Gaussian initialization.
Functions
- get_pixel_mask
Module Contents
- class gnn_tracking.models.mlp.MLP(input_size: int, output_size: int, hidden_dim: int | None, L=3, *, bias=True, include_last_activation=False)
Bases: torch.nn.Module
Multi-layer perceptron, using ReLU as activation function.
- Parameters:
input_size – Input feature dimension
output_size – Output feature dimension
hidden_dim – Feature dimension of the hidden layers. If None, the maximum of input_size and output_size is used
L – Total number of layers (1 initial layer, L-2 hidden layers, 1 output layer)
bias – Whether to include a bias term in the linear layers
include_last_activation – Whether to apply the activation function after the last layer
- layers
- reset_parameters()
- forward(x)
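To make the documented layer layout of MLP concrete, here is a minimal sketch in plain PyTorch. It assumes the layers alternate between nn.Linear and nn.ReLU, as the parameter descriptions suggest; the helper build_mlp_layers is hypothetical and is not part of gnn_tracking.

```python
from __future__ import annotations

import torch
from torch import nn


def build_mlp_layers(
    input_size: int,
    output_size: int,
    hidden_dim: int | None,
    L: int = 3,
    *,
    bias: bool = True,
    include_last_activation: bool = False,
) -> nn.Sequential:
    """Hypothetical helper mirroring the documented MLP parameters."""
    if hidden_dim is None:
        # Documented default: the larger of the input and output sizes
        hidden_dim = max(input_size, output_size)
    # 1 initial layer, L - 2 hidden layers, 1 output layer -> L linear layers
    dims = [input_size] + [hidden_dim] * (L - 1) + [output_size]
    layers: list[nn.Module] = []
    for i in range(L):
        layers.append(nn.Linear(dims[i], dims[i + 1], bias=bias))
        is_last = i == L - 1
        # ReLU after every linear layer except (optionally) the last one
        if not is_last or include_last_activation:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)


mlp = build_mlp_layers(input_size=6, output_size=2, hidden_dim=None, L=3)
out = mlp(torch.randn(10, 6))
print(out.shape)  # torch.Size([10, 2])
```

Returning an nn.Sequential keeps the sketch short; the documented class instead keeps its modules in a layers attribute and defines its own forward.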
- class gnn_tracking.models.mlp.ResFCNN(*, in_dim: int, hidden_dim: int, out_dim: int, depth: int, alpha: float = 0.6, bias: bool = True)
Bases: torch.nn.Module
Fully connected NN with residual connections.
- Parameters:
in_dim – Input dimension
hidden_dim – Hidden dimension
out_dim – Output dimension (the dimension of the embedding space)
depth – Network depth: 1 input encoder layer, depth-1 hidden layers, and 1 output decoder layer
alpha – Strength of the residual connection
- _encoder
- _decoder
- _layers
- _alpha = 0.6
- static _reset_layer_parameters(layer, var: float)
- forward(x: torch.Tensor, **ignore) -> torch.Tensor
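The residual mechanism of ResFCNN can be sketched as follows. The exact update rule and initialization variances are not given above, so this sketch assumes a variance-preserving mix x = sqrt(alpha) * x + sqrt(1 - alpha) * relu(layer(x)) and Gaussian weights with variance 1 / fan_in. The class name ResFCNNSketch is hypothetical, and the sketch is not the library's implementation.

```python
from __future__ import annotations

import math

import torch
from torch import nn


class ResFCNNSketch(nn.Module):
    """Hypothetical residual FCNN; update rule and init are assumptions."""

    def __init__(self, *, in_dim: int, hidden_dim: int, out_dim: int,
                 depth: int, alpha: float = 0.6, bias: bool = True) -> None:
        super().__init__()
        self._encoder = nn.Linear(in_dim, hidden_dim, bias=bias)
        self._layers = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim, bias=bias) for _ in range(depth - 1)]
        )
        self._decoder = nn.Linear(hidden_dim, out_dim, bias=bias)
        self._alpha = alpha
        # Assumed init: zero-mean Gaussian weights with variance 1 / fan_in
        for layer in [self._encoder, *self._layers, self._decoder]:
            self._reset_layer_parameters(layer, var=1.0 / layer.in_features)

    @staticmethod
    def _reset_layer_parameters(layer: nn.Linear, var: float) -> None:
        nn.init.normal_(layer.weight, mean=0.0, std=math.sqrt(var))
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor, **ignore) -> torch.Tensor:
        x = self._encoder(x)
        for layer in self._layers:
            # Assumed residual rule: sqrt(alpha) scales the skip path,
            # sqrt(1 - alpha) scales the new layer's output
            x = math.sqrt(self._alpha) * x + math.sqrt(1 - self._alpha) * torch.relu(layer(x))
        return self._decoder(x)


net = ResFCNNSketch(in_dim=14, hidden_dim=64, out_dim=8, depth=5)
embedding = net(torch.randn(100, 14))  # shape (100, 8)
```

Under this assumed rule, alpha = 1 bypasses the hidden layers entirely, while smaller values give the hidden layers more weight relative to the skip path.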
- gnn_tracking.models.mlp.get_pixel_mask(layer: torch.Tensor) -> torch.Tensor
- class gnn_tracking.models.mlp.HeterogeneousResFCNN(*, in_dim: int, out_dim: int, hidden_dim: int, depth: int, alpha: float = 0.6, bias: bool = True)
Bases: torch.nn.Module
Separate FCNNs for pixel and strip data, with residual connections. For parameters, see ResFCNN.
- pixel_fcnn
- strip_fcnn
- forward(x: torch.Tensor, layer: torch.Tensor) -> torch.Tensor
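The routing that HeterogeneousResFCNN performs between its pixel_fcnn and strip_fcnn can be sketched with a boolean mask. This sketch takes the mask directly as input; the documented forward receives a layer tensor instead and presumably derives the mask with get_pixel_mask, whose definition is not documented here. HeterogeneousSketch is a hypothetical name, and any nn.Module with matching dimensions can stand in for the two sub-networks.

```python
from __future__ import annotations

import torch
from torch import nn


class HeterogeneousSketch(nn.Module):
    """Hypothetical: send each hit through a pixel or a strip sub-network."""

    def __init__(self, pixel_fcnn: nn.Module, strip_fcnn: nn.Module, out_dim: int) -> None:
        super().__init__()
        self.pixel_fcnn = pixel_fcnn
        self.strip_fcnn = strip_fcnn
        self.out_dim = out_dim

    def forward(self, x: torch.Tensor, pixel_mask: torch.Tensor) -> torch.Tensor:
        # pixel_mask: bool tensor of shape (n_hits,), True where the hit is a pixel hit
        out = x.new_empty((x.shape[0], self.out_dim))
        out[pixel_mask] = self.pixel_fcnn(x[pixel_mask])
        out[~pixel_mask] = self.strip_fcnn(x[~pixel_mask])
        return out


net = HeterogeneousSketch(nn.Linear(3, 4), nn.Linear(3, 4), out_dim=4)
x = torch.randn(6, 3)
mask = torch.tensor([True, False, True, True, False, False])
y = net(x, mask)  # shape (6, 4); rows 0, 2 and 3 come from pixel_fcnn
```

Each hit passes through exactly one sub-network, so pixel and strip hits receive different learned transformations while ending up in the same output space.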
- class gnn_tracking.models.mlp.ResMLP(in_dim: int, out_dim: int, hidden_dim: int, depth: int = 4, beta: float = 1.0, gamma_0: float = 1.0, eta_0: float = 0.01, activation: torch.nn.Module = Tanh, optimizer: str = 'adam', bias: bool = True, **kwargs)
Bases: torch.nn.Module
Fully connected NN with residual connections and Gaussian initialization.
- Parameters:
in_dim – Input dimension
out_dim – Output dimension
hidden_dim – Number of neurons per internal layer (width)
beta – Strength of the residual connection
gamma_0 – Tuning of the final-layer output normalisation
depth – Number of hidden layers
- layers
- in_dim
- out_dim
- width
- beta = 1.0
- gamma_0 = 1.0
- eta_0 = 0.01
- gamma
- depth = 4
- act
- lr
- reset_parameters()
- get_lr(optimizer)
- forward(x)
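The ResMLP docstring names beta, gamma_0 and depth but gives no formulas. The sketch below shows one common way such hyperparameters can enter a residual network with Gaussian initialization: all weights drawn from N(0, 1), residual branches scaled by beta / sqrt(depth), and the output divided by gamma_0 * width. These scaling rules are assumptions for illustration and may not match ResMLP; eta_0, lr and get_lr are left out because their behaviour is not described here. ResMLPSketch is a hypothetical name.

```python
from __future__ import annotations

import math

import torch
from torch import nn


class ResMLPSketch(nn.Module):
    """Hypothetical residual MLP; all scaling rules below are assumptions.

    h_0     = W_in x / sqrt(in_dim)
    h_{l+1} = h_l + beta / sqrt(depth) * W_l act(h_l) / sqrt(width)
    y       = W_out act(h_depth) / (gamma_0 * width)
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, depth: int = 4,
                 beta: float = 1.0, gamma_0: float = 1.0,
                 activation: type[nn.Module] = nn.Tanh, bias: bool = True) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.width = hidden_dim
        self.depth = depth
        self.beta = beta
        self.gamma_0 = gamma_0
        self.act = activation()
        self.inp = nn.Linear(in_dim, hidden_dim, bias=bias)
        self.hidden = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim, bias=bias) for _ in range(depth)]
        )
        self.out = nn.Linear(hidden_dim, out_dim, bias=bias)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Gaussian init with unit variance; width scaling happens in forward
        for layer in [self.inp, *self.hidden, self.out]:
            nn.init.normal_(layer.weight, mean=0.0, std=1.0)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.inp(x) / math.sqrt(self.in_dim)
        branch_scale = self.beta / math.sqrt(self.depth)
        for layer in self.hidden:
            h = h + branch_scale * layer(self.act(h)) / math.sqrt(self.width)
        return self.out(self.act(h)) / (self.gamma_0 * self.width)


net = ResMLPSketch(in_dim=5, out_dim=3, hidden_dim=128, depth=4)
y = net(torch.randn(7, 5))  # shape (7, 3)
```

Keeping the initialization at unit variance and moving all width and depth dependence into explicit multipliers is what lets beta and gamma_0 act as separate tuning knobs in this sketch.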