gnn_tracking.models.noise_classification
Models for filtering out noise before the graph is built.
Classes
| TruthNoiseClassifierModel | Remove all noise with truth information |
| WithNoiseClassification | Combine a noise filter with another model |
Module Contents
- class gnn_tracking.models.noise_classification.TruthNoiseClassifierModel
Bases: torch.nn.Module
Remove all noise with truth information
- forward(data: torch_geometric.data.Data) → torch_geometric.data.Data
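As an illustration of what truth-based noise removal could look like, here is a minimal, dependency-free sketch. It is not the library's implementation. It assumes that noise hits are marked with `particle_id == 0` and that the input carries per-hit features plus an edge list; neither the convention nor the field names are confirmed on this page. The `Hits` dataclass and `remove_noise_with_truth` are hypothetical stand-ins for `torch_geometric.data.Data` and the `forward` method.

```python
from dataclasses import dataclass


@dataclass
class Hits:
    """Hypothetical stand-in for torch_geometric.data.Data."""

    x: list[list[float]]          # per-hit features
    particle_id: list[int]        # truth label per hit (assumed: 0 means noise)
    edges: list[tuple[int, int]]  # pairs of hit indices


def remove_noise_with_truth(data: Hits) -> Hits:
    """Drop every hit whose truth label marks it as noise."""
    # Map each kept hit's old index to its new, compacted index.
    new_index: dict[int, int] = {}
    for old, pid in enumerate(data.particle_id):
        if pid > 0:
            new_index[old] = len(new_index)

    return Hits(
        x=[data.x[i] for i in new_index],
        particle_id=[data.particle_id[i] for i in new_index],
        # Keep only edges whose two endpoints both survive, then relabel them.
        edges=[
            (new_index[a], new_index[b])
            for a, b in data.edges
            if a in new_index and b in new_index
        ],
    )


example = Hits(
    x=[[0.1], [0.2], [0.3], [0.4]],
    particle_id=[7, 0, 7, 3],
    edges=[(0, 1), (0, 2), (2, 3)],
)
filtered = remove_noise_with_truth(example)
```

In this toy input the second hit is noise, so it and the edge touching it are removed, and the remaining edges are renumbered to match the compacted hit list.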
- class gnn_tracking.models.noise_classification.WithNoiseClassification(noise_model: torch.nn.Module, model: torch.nn.Module)
Bases: torch.nn.Module, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
Combine a noise filter with another model
- noise_model
- model
- forward(data: torch_geometric.data.Data) → dict[str, torch.Tensor | None]
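The composition pattern suggested by the constructor signature can be sketched as follows. This is a dependency-free illustration under stated assumptions, not the library's code: it assumes the noise model returns a filtered copy of the data that is then passed to the wrapped model, which this page does not confirm (the real class might pass a mask instead). `NoiseFilteredPipeline`, `drop_noise`, `toy_model`, and the output keys `"score"` and `"aux"` are all hypothetical.

```python
from typing import Callable, Optional

# Plain dictionaries of lists stand in for torch_geometric.data.Data
# and for the dict[str, torch.Tensor | None] return value.
Data = dict[str, list]
Output = dict[str, Optional[list]]


class NoiseFilteredPipeline:
    """Hypothetical wrapper: run a noise filter, then another model."""

    def __init__(
        self,
        noise_model: Callable[[Data], Data],
        model: Callable[[Data], Output],
    ) -> None:
        self.noise_model = noise_model
        self.model = model

    def __call__(self, data: Data) -> Output:
        # Assumption: the noise model hands back filtered data directly.
        return self.model(self.noise_model(data))


def drop_noise(data: Data) -> Data:
    """Toy noise filter: keep hits with particle_id > 0 (assumed convention)."""
    keep = [i for i, pid in enumerate(data["particle_id"]) if pid > 0]
    return {key: [values[i] for i in keep] for key, values in data.items()}


def toy_model(data: Data) -> Output:
    """Toy downstream model: one populated output and one empty slot."""
    return {"score": [2.0 * v for v in data["x"]], "aux": None}


pipeline = NoiseFilteredPipeline(drop_noise, toy_model)
result = pipeline({"x": [1.0, 2.0, 3.0], "particle_id": [5, 0, 9]})
```

The downstream model only ever sees the two surviving hits, which is the point of filtering before the main model runs.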