gnn_tracking.metrics.losses#
This module contains loss functions for the GNN tracking model.
Submodules#
Attributes#
Classes#
MultiLossFctReturn | Return type for loss functions that return multiple losses.
MultiLossFct | Base class for loss functions that return multiple losses.
DummyMultiLoss | Dummy loss function that returns the sum of the x input.
LossClones | Wrapper for a loss function that evaluates it on multiple inputs.
Package Contents#
- gnn_tracking.metrics.losses.logger#
- class gnn_tracking.metrics.losses.MultiLossFctReturn#
Return type for loss functions that return multiple losses.
- loss_dct: dict[str, torch.Tensor]#
- weight_dct: dict[str, torch.Tensor] | dict[str, float]#
- extra_metrics: dict[str, Any]#
- __post_init__() None #
- property loss: torch.Tensor#
- property weighted_losses: dict[str, torch.Tensor]#
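The relationship between the fields and properties above can be illustrated with a minimal sketch. It is not the library's implementation: `MultiLossSketch` is a hypothetical stand-in that uses plain floats instead of `torch.Tensor`, and it assumes that `__post_init__` checks that both dictionaries share the same keys, that `weighted_losses` multiplies each loss by its weight, and that `loss` is the sum of the weighted losses.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class MultiLossSketch:
    """Hypothetical stand-in for MultiLossFctReturn (floats instead of tensors)."""

    loss_dct: dict[str, float]
    weight_dct: dict[str, float]
    extra_metrics: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Assumption: every loss term needs a weight with the same key.
        if self.loss_dct.keys() != self.weight_dct.keys():
            raise ValueError("loss_dct and weight_dct must have the same keys")

    @property
    def weighted_losses(self) -> dict[str, float]:
        # Assumption: each loss term is scaled by its own weight.
        return {k: v * self.weight_dct[k] for k, v in self.loss_dct.items()}

    @property
    def loss(self) -> float:
        # Assumption: the total loss is the sum of the weighted terms.
        return sum(self.weighted_losses.values())


# The keys "attractive" and "repulsive" are made up for this example.
result = MultiLossSketch(
    loss_dct={"attractive": 0.5, "repulsive": 2.0},
    weight_dct={"attractive": 1.0, "repulsive": 0.25},
)
print(result.weighted_losses, result.loss)
```

Under these assumptions, a training loop would backpropagate through `loss` while logging the individual entries of `weighted_losses` and `extra_metrics`.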
- class gnn_tracking.metrics.losses.MultiLossFct#
Bases:
torch.nn.Module
Base class for loss functions that return multiple losses.
- forward(*args: Any, **kwargs: Any) MultiLossFctReturn #
- class gnn_tracking.metrics.losses.DummyMultiLoss#
Bases:
MultiLossFct
Dummy loss function that returns the sum of the x input. This can be used to quickly test the speed of the training loop without any actual loss function.
- forward(x: torch.Tensor, **kwargs: Any) MultiLossFctReturn #
- class gnn_tracking.metrics.losses.LossClones(loss: torch.nn.Module, prefixes=('w', 'y'))#
Bases:
torch.nn.Module
Wrapper for a loss function that evaluates it on multiple inputs. The forward method will look for all model outputs that start with w_ (or another specified prefix) and evaluate the loss function for each of them, returning a dictionary of losses (with keys equal to the suffixes).
Usage example 1:
losses = { "potential": (PotentialLoss(), 1.), "edge": (LossClones(EdgeWeightBCELoss()), [1.0, 2.0, 3.0]) }
This evaluates three clones of the BCE loss function, one for each EC layer.
Usage example 2:
losses = { "potential": (PotentialLoss(), 1.), "edge": (LossClones(EdgeWeightBCELoss()), {}) }
This works with a variable number of layers. The weights are all 1.
Under the hood,
LossClones(EdgeWeightBCELoss())(model output)
will output a dictionary of the individual losses, keyed by their suffixes (in a similar way to how PotentialLoss returns a dictionary of losses).
- Parameters:
loss – Loss function to be evaluated on multiple inputs
prefixes – Prefixes of the model outputs that should be evaluated. An underscore is assumed (set prefix to w for w_0, w_1, etc.)
- _loss#
- _prefixes#
- forward(**kwargs) dict[str, torch.Tensor] #
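The prefix matching described for LossClones can be sketched without torch. This is a hedged illustration, not the library's code: `LossClonesSketch` and `toy_loss` are hypothetical names, and the sketch assumes that suffixes are discovered from the first prefix, that each wrapped call receives the matching inputs under the bare prefix names (`w`, `y`), and that unrelated keyword arguments are ignored.

```python
from typing import Any, Callable


class LossClonesSketch:
    """Hypothetical, torch-free stand-in for LossClones."""

    def __init__(self, loss: Callable[..., float], prefixes=("w", "y")):
        self._loss = loss
        self._prefixes = tuple(prefixes)

    def forward(self, **kwargs: Any) -> dict[str, float]:
        # Use the first prefix to discover the suffixes ("0", "1", ...).
        lead = self._prefixes[0] + "_"
        suffixes = sorted(k[len(lead):] for k in kwargs if k.startswith(lead))
        losses = {}
        for suffix in suffixes:
            # Assumption: w_0 and y_0 are passed to the wrapped loss as w and y.
            inputs = {p: kwargs[f"{p}_{suffix}"] for p in self._prefixes}
            losses[suffix] = self._loss(**inputs)
        return losses


def toy_loss(w, y):
    # Mean absolute error, standing in for a real edge loss.
    return sum(abs(a - b) for a, b in zip(w, y)) / len(w)


clones = LossClonesSketch(toy_loss)
per_layer = clones.forward(
    w_0=[1.0, 0.0], y_0=[1.0, 1.0],
    w_1=[0.0, 0.0], y_1=[0.0, 0.0],
    x="ignored",  # does not start with a known prefix
)
print(per_layer)
```

The returned dictionary is keyed by suffix, so a list or dictionary of per-layer weights can be matched to the individual entries.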