:py:mod:`gnn_tracking.training.ml`
==================================

.. py:module:: gnn_tracking.training.ml

.. autoapi-nested-parse::

   PyTorch Lightning module with the training and validation steps for the metric
   learning approach to graph construction.



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.training.ml.MLModule




.. py:class:: MLModule(*, loss_fct: gnn_tracking.metrics.losses.MultiLossFct, gc_scanner: gnn_tracking.graph_construction.k_scanner.GraphConstructionKNNScanner | None = None, **kwargs)


   Bases: :py:obj:`gnn_tracking.training.base.TrackingModule`

   PyTorch Lightning module with the training and validation steps for the metric
   learning approach to graph construction.
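
   Here is a rough, dependency-free illustration of the idea. It does not show this
   class's implementation. It assumes the trained model has already mapped every hit
   to a point in a latent space where hits of the same particle lie close together.
   A graph is then built by connecting each hit to its ``k`` nearest neighbours in
   that space. The helpers ``knn_graph`` and ``edge_efficiency_purity`` are
   hypothetical and are not part of ``gnn_tracking``. They use a brute-force
   Euclidean search rather than a real nearest-neighbour index:

   .. code-block:: python

      import math
      from collections import Counter


      def knn_graph(points: list[list[float]], k: int) -> list[tuple[int, int]]:
          """Hypothetical helper: directed edges (i, j) from each point i to its
          k nearest other points j. Brute-force O(n^2) search, toy use only."""
          edges: list[tuple[int, int]] = []
          for i, p in enumerate(points):
              # Rank all other points by Euclidean distance to p (ties by index).
              ranked = sorted(
                  (math.dist(p, q), j) for j, q in enumerate(points) if j != i
              )
              edges.extend((i, j) for _, j in ranked[:k])
          return edges


      def edge_efficiency_purity(
          edges: list[tuple[int, int]], particle_ids: list[int]
      ) -> tuple[float, float]:
          """Hypothetical helper.
          Efficiency: share of same-particle ordered pairs that became edges.
          Purity: share of edges that connect hits of the same particle."""
          n_true = sum(1 for i, j in edges if particle_ids[i] == particle_ids[j])
          # Ordered pairs (i, j), i != j, whose hits belong to the same particle.
          n_possible = sum(c * (c - 1) for c in Counter(particle_ids).values())
          efficiency = n_true / n_possible if n_possible else 0.0
          purity = n_true / len(edges) if edges else 0.0
          return efficiency, purity


      # Toy latent-space embeddings: two tight clusters, one per particle.
      embeddings = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]]
      particle_ids = [0, 0, 1, 1]

      for k in (1, 2):
          edges = knn_graph(embeddings, k)
          print(k, edge_efficiency_purity(edges, particle_ids))
      # 1 (1.0, 1.0)
      # 2 (1.0, 0.5)

   A larger ``k`` only adds edges, so efficiency cannot drop while purity usually
   falls. On the toy data, ``k=1`` already recovers every true edge, and ``k=2`` adds
   one fake edge per hit. The optional ``gc_scanner`` argument (a
   ``GraphConstructionKNNScanner``) presumably evaluates this kind of scan over
   ``k``, but its actual interface is not described here.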

   .. py:method:: get_losses(out: dict[str, Any], data: torch_geometric.data.Data) -> tuple[torch.Tensor, dict[str, float]]
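
      The return type suggests a total loss tensor for backpropagation plus a
      dictionary of plain floats for logging. The sketch below shows that pattern
      under one assumption: the total is a weighted sum of named loss components.
      ``combine_losses``, the component names and the ``_weighted`` key suffix are
      hypothetical. Plain floats stand in for tensors, and the sketch does not show
      how ``MultiLossFct`` is actually called:

      .. code-block:: python

         def combine_losses(
             losses: dict[str, float], weights: dict[str, float]
         ) -> tuple[float, dict[str, float]]:
             """Hypothetical helper: weighted sum of loss components plus a flat
             metrics dict for logging."""
             # Components without an explicit weight default to 1.0 (an assumption).
             weighted = {
                 name: weights.get(name, 1.0) * value for name, value in losses.items()
             }
             total = sum(weighted.values())
             metrics = {
                 **losses,
                 **{f"{name}_weighted": value for name, value in weighted.items()},
                 "total": total,
             }
             return total, metrics


         # Illustrative component names for a metric learning loss.
         total, metrics = combine_losses(
             {"attractive": 0.5, "repulsive": 0.25}, weights={"repulsive": 2.0}
         )
         print(total)  # 1.0
         print(metrics["repulsive_weighted"])  # 0.5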


   .. py:method:: training_step(batch: torch_geometric.data.Data, batch_idx: int) -> torch.Tensor | None


   .. py:method:: validation_step(batch: torch_geometric.data.Data, batch_idx: int)


   .. py:method:: on_validation_epoch_end() -> None


   .. py:method:: highlight_metric(metric: str) -> bool



