gnn_tracking.models.noise_classification

Models for filtering out noise before the graph is built.

Module Contents

Classes

TruthNoiseClassifierModel

Remove all noise using truth information.

WithNoiseClassification

Combine a noise filter with a downstream model.

class gnn_tracking.models.noise_classification.TruthNoiseClassifierModel

Bases: torch.nn.Module

Remove all noise using truth information.

forward(data: torch_geometric.data.Data) -> torch_geometric.data.Data

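The page does not spell out how truth information identifies noise. The following is a minimal sketch, not the library's implementation. It assumes noise hits are labeled with `particle_id == 0` (the TrackML convention) and uses a plain dict of per-hit tensors as a stand-in for `torch_geometric.data.Data`. The class name `TruthNoiseFilterSketch` is hypothetical.

```python
import torch
from torch import nn


class TruthNoiseFilterSketch(nn.Module):
    """Hypothetical stand-in for TruthNoiseClassifierModel.

    Assumptions: noise hits carry particle_id == 0, and every per-hit
    tensor in the input dict shares the same first dimension.
    """

    def forward(self, hits: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Keep only hits that belong to a real particle (assumed convention).
        keep = hits["particle_id"] > 0
        # Apply the same boolean mask to every per-hit tensor.
        return {key: value[keep] for key, value in hits.items()}


hits = {
    "x": torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
    "particle_id": torch.tensor([7, 0, 3]),
}
filtered = TruthNoiseFilterSketch()(hits)
print(filtered["particle_id"])  # tensor([7, 3])
```

Because it relies only on truth labels, this sketch has no trainable parameters.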
class gnn_tracking.models.noise_classification.WithNoiseClassification(noise_model: torch.nn.Module, model: torch.nn.Module)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Combine a noise filter with a downstream model.

forward(data: torch_geometric.data.Data) -> dict[str, torch.Tensor | None]
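A hedged sketch of how a wrapper with the constructor arguments `noise_model` and `model` could chain them: the noise model runs first and its output feeds the downstream model. The class names `WithNoiseFilterSketch`, `KeepRealHits`, and `EmbedHits`, the dict-of-tensors input, and the `"H"` output key are all hypothetical. The real class takes `torch_geometric.data.Data`, also inherits `HyperparametersMixin`, and its actual forward logic may differ.

```python
import torch
from torch import nn


class KeepRealHits(nn.Module):
    """Hypothetical noise model: drops hits whose particle_id is 0."""

    def forward(self, hits: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        keep = hits["particle_id"] > 0
        return {key: value[keep] for key, value in hits.items()}


class EmbedHits(nn.Module):
    """Hypothetical downstream model: maps 2D hit features to 3D."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(2, 3)

    def forward(self, hits: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return {"H": self.linear(hits["x"])}


class WithNoiseFilterSketch(nn.Module):
    """Hypothetical wrapper: run the noise model, then the main model."""

    def __init__(self, noise_model: nn.Module, model: nn.Module) -> None:
        super().__init__()
        # Assigning modules as attributes registers them as submodules.
        self.noise_model = noise_model
        self.model = model

    def forward(self, hits: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Filter noise first, then pass only the surviving hits on.
        return self.model(self.noise_model(hits))


torch.manual_seed(0)
combined = WithNoiseFilterSketch(KeepRealHits(), EmbedHits())
hits = {
    "x": torch.rand(5, 2),
    "particle_id": torch.tensor([1, 0, 2, 0, 4]),
}
out = combined(hits)
print(out["H"].shape)  # torch.Size([3, 3])
```

Storing both models as attributes of an `nn.Module` registers them as submodules, so `combined.parameters()` yields the parameters of both and a single optimizer can update them together.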