gnn_tracking.models.noise_classification

Models for filtering out noise before the graph is even built.

Classes

TruthNoiseClassifierModel    Remove all noise using truth information.
WithNoiseClassification      Combine a noise filter with another model.

Module Contents

class gnn_tracking.models.noise_classification.TruthNoiseClassifierModel

Bases: torch.nn.Module

Remove all noise using truth information (the ground-truth labels from simulation).

forward(data: torch_geometric.data.Data) → torch_geometric.data.Data
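
The idea of truth-based noise removal can be illustrated with a minimal, dependency-free sketch. It assumes that noise hits carry a truth label particle_id == 0 (a common convention in TrackML-style data, but an assumption here), and it uses a hypothetical Hits dataclass as a stand-in for torch_geometric.data.Data so that it runs without torch. It is not the implementation of TruthNoiseClassifierModel.

```python
from dataclasses import dataclass


@dataclass
class Hits:
    """Hypothetical stand-in for torch_geometric.data.Data (one entry per hit)."""

    x: list[list[float]]  # node features, one list per hit
    particle_id: list[int]  # truth label; 0 is ASSUMED to mark noise


def remove_noise_with_truth(data: Hits) -> Hits:
    """Keep only hits whose truth particle_id is non-zero (assumed noise convention)."""
    keep = [pid != 0 for pid in data.particle_id]
    return Hits(
        x=[row for row, k in zip(data.x, keep) if k],
        particle_id=[pid for pid, k in zip(data.particle_id, keep) if k],
    )


hits = Hits(x=[[0.1], [0.2], [0.3]], particle_id=[7, 0, 7])
clean = remove_noise_with_truth(hits)
print(clean.particle_id)  # [7, 7]
```

With tensors, the same filtering is usually expressed as a boolean node mask; torch_geometric's Data.subgraph accepts such a mask and keeps the node and edge attributes consistent.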

class gnn_tracking.models.noise_classification.WithNoiseClassification(noise_model: torch.nn.Module, model: torch.nn.Module)

Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Combine a noise filter (noise_model) with another model (model).

noise_model

model

forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor | None]
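
The composition pattern can be sketched as follows. This assumes the noise model runs first and the downstream model consumes its filtered output (inferred from the module summary, not confirmed here). WithNoiseFilterSketch, drop_noise, and count_hits are hypothetical names, "data" is a plain dict instead of torch_geometric.data.Data, and the output dict mirrors the dict[str, torch.Tensor | None] return type with floats in place of tensors. The real class subclasses torch.nn.Module and HyperparametersMixin; this sketch drops torch so it runs standalone.

```python
from typing import Callable, Optional

# Hypothetical stand-ins: a "data" object is a plain dict of per-hit lists,
# and the model output maps names to values (or None), like dict[str, Tensor | None].
Data = dict
Output = dict[str, Optional[float]]


class WithNoiseFilterSketch:
    """Apply a noise filter first, then the downstream model (assumed ordering)."""

    def __init__(
        self,
        noise_model: Callable[[Data], Data],
        model: Callable[[Data], Output],
    ) -> None:
        self.noise_model = noise_model
        self.model = model

    def forward(self, data: Data) -> Output:
        filtered = self.noise_model(data)  # drop hits flagged as noise
        return self.model(filtered)  # run the main model on the remaining hits

    __call__ = forward  # mimic nn.Module, where calling the object runs forward


def drop_noise(data: Data) -> Data:
    """Hypothetical noise filter: keep hits with non-zero particle_id."""
    keep = [i for i, pid in enumerate(data["particle_id"]) if pid != 0]
    return {key: [values[i] for i in keep] for key, values in data.items()}


def count_hits(data: Data) -> Output:
    """Hypothetical downstream model: report how many hits survived."""
    return {"n_hits": float(len(data["particle_id"])), "aux": None}


pipeline = WithNoiseFilterSketch(drop_noise, count_hits)
out = pipeline({"particle_id": [3, 0, 3, 0], "r": [1.0, 2.0, 3.0, 4.0]})
print(out)  # {'n_hits': 2.0, 'aux': None}
```

Keeping the filter as a separate, swappable component means a truth-based filter (such as TruthNoiseClassifierModel) and a learned one can be exchanged without touching the downstream model.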