gnn_tracking.metrics.losses

This module contains loss functions for the GNN tracking model.

Submodules

Attributes

logger

Classes

MultiLossFctReturn

Return type for loss functions that return multiple losses.

MultiLossFct

Base class for loss functions that return multiple losses.

DummyMultiLoss

Dummy loss function that returns the sum of the x input.

LossClones

Wrapper for a loss function that evaluates it on multiple inputs.

Package Contents

gnn_tracking.metrics.losses.logger
class gnn_tracking.metrics.losses.MultiLossFctReturn

Return type for loss functions that return multiple losses.

loss_dct: dict[str, torch.Tensor]
weight_dct: dict[str, torch.Tensor] | dict[str, float]
extra_metrics: dict[str, Any]
__post_init__() -> None
property loss: torch.Tensor
property weighted_losses: dict[str, torch.Tensor]
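The two properties can plausibly be read as follows: weighted_losses multiplies each entry of loss_dct by the matching entry of weight_dct, and loss sums those weighted values. The sketch below illustrates that reading under stated assumptions: plain floats stand in for torch.Tensor, the class name SketchMultiLossReturn and the loss names "attractive"/"repulsive" are hypothetical, and the key check in __post_init__ is a guess at what validation the real method performs.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class SketchMultiLossReturn:
    """Hypothetical stand-in for MultiLossFctReturn (floats instead of tensors)."""

    loss_dct: dict[str, float]
    weight_dct: dict[str, float]
    extra_metrics: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Assumption: every loss needs a weight and every weight needs a loss.
        if self.loss_dct.keys() != self.weight_dct.keys():
            raise ValueError("loss_dct and weight_dct must have the same keys")

    @property
    def weighted_losses(self) -> dict[str, float]:
        # Scale each individual loss by its weight.
        return {k: self.weight_dct[k] * v for k, v in self.loss_dct.items()}

    @property
    def loss(self) -> float:
        # Total loss = sum of the weighted individual losses.
        return sum(self.weighted_losses.values())


ret = SketchMultiLossReturn(
    loss_dct={"attractive": 0.5, "repulsive": 2.0},
    weight_dct={"attractive": 1.0, "repulsive": 0.25},
)
print(ret.weighted_losses)  # {'attractive': 0.5, 'repulsive': 0.5}
print(ret.loss)  # 1.0
```

Keeping the unweighted values in loss_dct alongside the weights lets a trainer log each component separately while optimizing only the single scalar total.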
class gnn_tracking.metrics.losses.MultiLossFct

Bases: torch.nn.Module

Base class for loss functions that return multiple losses.

forward(*args: Any, **kwargs: Any) -> MultiLossFctReturn
class gnn_tracking.metrics.losses.DummyMultiLoss

Bases: MultiLossFct

Dummy loss function that returns the sum of the x input. Use it to quickly measure the speed of the training loop without the cost of a real loss function.

forward(x: torch.Tensor, **kwargs: Any) -> MultiLossFctReturn
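To illustrate the MultiLossFct contract (forward returns a multi-loss result), here is a torch-free sketch of a dummy loss that sums its x input and ignores everything else. Assumptions: the class names SketchReturn and SketchDummyMultiLoss are hypothetical, a list of floats stands in for the x tensor, and the key "dummy" with weight 1.0 is a guess; the real DummyMultiLoss is a torch.nn.Module whose forward receives tensors.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class SketchReturn:
    """Hypothetical stand-in for MultiLossFctReturn (floats, not tensors)."""

    loss_dct: dict[str, float]
    weight_dct: dict[str, float]
    extra_metrics: dict[str, Any] = field(default_factory=dict)


class SketchDummyMultiLoss:
    """Hypothetical, torch-free analogue of DummyMultiLoss."""

    def forward(self, x: list[float], **kwargs: Any) -> SketchReturn:
        # All other model outputs are ignored; the "loss" is just sum(x).
        # The key "dummy" and the weight 1.0 are illustrative guesses.
        return SketchReturn(loss_dct={"dummy": sum(x)}, weight_dct={"dummy": 1.0})

    # Mimic torch.nn.Module, where calling the module runs forward().
    __call__ = forward


result = SketchDummyMultiLoss()(x=[1.0, 2.0, 3.5], y=[0.0])
print(result.loss_dct)  # {'dummy': 6.5}
```

Because the "loss" is a single cheap reduction, timing a training step with it approximates the cost of everything in the loop except the loss computation.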
class gnn_tracking.metrics.losses.LossClones(loss: torch.nn.Module, prefixes=('w', 'y'))

Bases: torch.nn.Module

Wrapper for a loss function that evaluates it on multiple inputs. The forward method looks for all model outputs whose names start with w_ (or another specified prefix), evaluates the loss function on each of them, and returns a dictionary of losses keyed by the suffixes.

Usage example 1:

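```python
from gnn_tracking.metrics.losses import LossClones

# Assumes PotentialLoss and EdgeWeightBCELoss are already imported.
losses = {
    "potential": (PotentialLoss(), 1.0),
    "edge": (LossClones(EdgeWeightBCELoss()), [1.0, 2.0, 3.0]),
}
```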

This evaluates three clones of the BCE loss function, one for each EC layer, with weights 1.0, 2.0, and 3.0.

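Usage example 2:

```python
from gnn_tracking.metrics.losses import LossClones

# Assumes PotentialLoss and EdgeWeightBCELoss are already imported.
losses = {
    "potential": (PotentialLoss(), 1.0),
    "edge": (LossClones(EdgeWeightBCELoss()), {}),
}
```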

This works with a variable number of layers; every clone gets a weight of 1.

Under the hood, LossClones(EdgeWeightBCELoss())(**model_output) returns a dictionary of the individual losses keyed by their suffixes, similar to how PotentialLoss returns a dictionary of losses.
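The suffix matching can be sketched in plain Python. This is a hedged reading of the docstring, not the library's code: it assumes the first prefix decides which suffixes exist, that for each suffix every output named f"{prefix}_{suffix}" is passed to the wrapped loss under the bare prefix name (so w_1 and y_1 become w and y), and that outputs without any prefix are forwarded unchanged. The function sketch_loss_clones, the toy loss, and the mask output are all hypothetical.

```python
from typing import Any, Callable


def sketch_loss_clones(
    loss: Callable[..., float],
    prefixes: tuple[str, ...] = ("w", "y"),
    **kwargs: Any,
) -> dict[str, float]:
    """Hypothetical analogue of LossClones.forward: one loss value per suffix."""
    first = f"{prefixes[0]}_"
    prefixed = tuple(f"{p}_" for p in prefixes)
    # Suffixes come from outputs named like "w_0", "w_1", ...
    suffixes = [k[len(first):] for k in kwargs if k.startswith(first)]
    losses = {}
    for suffix in suffixes:
        # Map e.g. "w_1" -> "w" and "y_1" -> "y" for this clone only.
        rename = {f"{p}_{suffix}": p for p in prefixes}
        call_kwargs = {
            rename.get(k, k): v
            for k, v in kwargs.items()
            # Keep this suffix's prefixed outputs plus all unprefixed outputs.
            if k in rename or not k.startswith(prefixed)
        }
        losses[suffix] = loss(**call_kwargs)
    return losses


def toy_loss(w: list[float], y: list[float], **_: Any) -> float:
    # Mean squared difference between predictions w and targets y.
    return sum((a - b) ** 2 for a, b in zip(w, y)) / len(w)


out = sketch_loss_clones(
    toy_loss,
    w_0=[0.0, 1.0], y_0=[0.0, 0.0],
    w_1=[1.0, 1.0], y_1=[0.0, 0.0],
    mask=[True, True],
)
print(out)  # {'0': 0.5, '1': 1.0}
```

Renaming to the bare prefix means the wrapped loss keeps its ordinary signature (here w and y) and needs no knowledge of how many layers the model has.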

Parameters:
  • loss – Loss function to be evaluated on multiple inputs

  • prefixes – Prefixes of the model outputs that should be evaluated. The separating underscore is implied: use the prefix w to match w_0, w_1, etc.

forward(**kwargs) -> dict[str, torch.Tensor]