:py:mod:`gnn_tracking.graph_construction.k_scanner`
===================================================

.. py:module:: gnn_tracking.graph_construction.k_scanner

.. autoapi-nested-parse::

   Finds the right k for k-NN to optimize the graph construction.



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.graph_construction.k_scanner.KScanResults
   gnn_tracking.graph_construction.k_scanner.GraphConstructionKNNScanner




Attributes
~~~~~~~~~~

.. autoapisummary::

   gnn_tracking.graph_construction.k_scanner._DEFAULT_KS


.. py:class:: KScanResults(results: pandas.DataFrame, targets: Sequence[float])


   This object holds the results of a scan over k. It interpolates between the
   scanned values of k to obtain the figures of merit (FOMs).

   :param results: The results of the scan: (k, n_edges, frac50, ...)
   :param targets: The target 50%-segment fractions that we're interested in

   .. py:attribute:: _extra_metrics
      :value: ('k', 'frac75', 'frac100', 'efficiency', 'purity')

      

   .. py:method:: get_foms() -> dict[str, float]

      Get figures of merit


   .. py:method:: plot() -> matplotlib.pyplot.Axes

      Plot interpolation


   .. py:method:: _spline() -> tuple[scipy.interpolate.CubicSpline, list[str], list[str]]

      Build the spline object. Do not use this object directly; use it only
      via `_eval_spline`.

      :returns: Spline object, list of columns that are NaN, list of columns
                that are not NaN


   .. py:method:: _eval_spline(k: float) -> dict[str, float]

      Get figures of merit at k, using a spline for evaluation
      between data points.
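
      A minimal sketch of the spline step behind `_spline` and `_eval_spline`,
      assuming that `results` has one row per scanned k and that a column counts
      as "nan" if it contains any NaN. The helpers `build_spline` and `eval_spline`
      and the toy numbers are hypothetical; only `scipy.interpolate.CubicSpline`
      is a real API here.

      .. code-block:: python

         import math

         import numpy as np
         import pandas as pd
         from scipy.interpolate import CubicSpline


         def build_spline(results: pd.DataFrame) -> tuple[CubicSpline, list[str], list[str]]:
             """Hypothetical helper: one cubic spline over k for every metric column."""
             results = results.sort_values("k")
             metric_cols = [c for c in results.columns if c != "k"]
             # Assumption: columns containing NaN cannot be splined and are reported separately.
             nan_cols = [c for c in metric_cols if results[c].isna().any()]
             ok_cols = [c for c in metric_cols if c not in nan_cols]
             # With a 2D y and axis=0, each column gets its own spline in k.
             spline = CubicSpline(
                 results["k"].to_numpy(dtype=float),
                 results[ok_cols].to_numpy(dtype=float),
                 axis=0,
             )
             return spline, nan_cols, ok_cols


         def eval_spline(results: pd.DataFrame, k: float) -> dict[str, float]:
             """Hypothetical helper: all metrics at a possibly non-integer k."""
             spline, nan_cols, ok_cols = build_spline(results)
             foms = {col: float(v) for col, v in zip(ok_cols, spline(k))}
             # NaN columns are passed through as NaN instead of being interpolated.
             foms.update({col: math.nan for col in nan_cols})
             return foms


         # Toy scan results (synthetic numbers for illustration only).
         toy = pd.DataFrame(
             {
                 "k": [2, 4, 6, 8, 10],
                 "n_edges": [2000, 4000, 6000, 8000, 10000],
                 "frac50": [0.60, 0.80, 0.90, 0.95, 0.97],
                 "efficiency": [np.nan] * 5,
             }
         )
         # n_edges is linear in k here, so the spline returns ~5000 at k = 5.
         foms = eval_spline(toy, 5.0)

      Interpolating keeps the scan cheap: metrics are only computed at a few
      integer values of k, while the FOMs can be read off at any k in between.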


   .. py:method:: _get_target_k(target: float) -> float

      Value of k at which the 50%-segment fraction equals `target`
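
      One way to find this k is to solve `frac50(k) - target = 0` on the spline
      with a bracketing root finder. The sketch below assumes that the 50%-segment
      fraction increases with k and returns NaN when the target is not reached
      within the scanned range; `target_k` is a hypothetical helper, and
      `scipy.optimize.brentq` is used here as a stand-in for whatever root finding
      the module actually does.

      .. code-block:: python

         import numpy as np
         from scipy.interpolate import CubicSpline
         from scipy.optimize import brentq


         def target_k(ks, frac50, target: float) -> float:
             """Hypothetical helper: k at which the interpolated frac50 equals target."""
             ks = np.asarray(ks, dtype=float)
             spline = CubicSpline(ks, np.asarray(frac50, dtype=float))
             lo, hi = ks[0], ks[-1]
             # Assumption: frac50 grows with k, so a reachable target lies between the ends.
             if not float(spline(lo)) <= target <= float(spline(hi)):
                 return float("nan")  # target not reached within the scanned range
             # brentq needs f(lo) and f(hi) to have opposite signs (or one of them to be zero).
             return brentq(lambda k: float(spline(k)) - target, lo, hi)


         # Toy data: frac50 = 0.90 exactly at the scanned point k = 6.
         k_at_90 = target_k([2, 4, 6, 8, 10], [0.60, 0.80, 0.90, 0.95, 0.97], 0.90)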


   .. py:method:: _get_foms_at_target(target: float) -> dict[str, float]

      Get figures of merit at the k returned by `self._get_target_k` for `target`



.. py:data:: _DEFAULT_KS

   

.. py:class:: GraphConstructionKNNScanner(ks: list[int] = _DEFAULT_KS, *, targets=(0.8, 0.85, 0.88, 0.9, 0.93, 0.95, 0.97, 0.99), max_radius=1.0, pt_thld=0.9, max_eta=4.0, subsample_pids: int | None = None, max_edges=5000000)


   Bases: :py:obj:`pytorch_lightning.core.mixins.HyperparametersMixin`

   Scan over different values of k to build a graph and calculate the figures
   of merit.

   :param ks: List of ks to scan. Results are interpolated between these values,
              so the grid does not need to be dense (every additional k adds compute time).
   :param targets: Target values of the 50%-segment fraction (for each target, we find
                   the k that comes closest to it and report the number of edges
                   at that k). Does not impact compute time.
   :param max_radius: Maximum length of edges in the k-NN graph.
   :param pt_thld: pT threshold for the evaluation of the 50%-segment fraction.
   :param subsample_pids: If set, evaluate the 50%-segment fraction on only this many
                          pids. This speeds up the evaluation of the 50%-segment
                          fraction, but makes the result less accurate (statistical
                          fluctuations).
   :param max_edges: Do not compute metrics for k-NN graphs with more than this
                     number of edges.
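
   To make `ks`, `max_radius`, and `max_edges` concrete, here is a minimal sketch
   of a k-NN graph with a radius cutoff. It uses SciPy's `cKDTree` as a stand-in
   for the k-NN routine the module actually uses (which works on torch tensors) and
   assumes that `max_radius` acts as a hard cutoff on neighbor distance. The helper
   `knn_edges` is hypothetical.

   .. code-block:: python

      import numpy as np
      from scipy.spatial import cKDTree


      def knn_edges(x: np.ndarray, k: int, max_radius: float = 1.0, max_edges: int = 5_000_000):
          """Hypothetical helper: directed edges i -> j to the k (>= 1) nearest neighbors j
          of i that lie within max_radius. Returns shape (2, n_edges), or None if too large."""
          tree = cKDTree(x)
          # Ask for k + 1 neighbors because each point is returned as its own neighbor.
          dist, idx = tree.query(x, k=k + 1, distance_upper_bound=max_radius)
          src = np.repeat(np.arange(len(x)), k + 1)
          dst = idx.ravel()
          # Neighbors beyond max_radius come back with dist = inf (and idx = len(x)).
          keep = np.isfinite(dist.ravel()) & (dst != src)
          edges = np.stack([src[keep], dst[keep]])
          if edges.shape[1] > max_edges:
              return None  # too many edges: skip computing metrics for this graph
          return edges


      # The point at 3.0 has no neighbor within max_radius, so it gets no edges.
      points = np.array([[0.0], [0.4], [0.9], [3.0]])
      edges = knn_edges(points, k=2, max_radius=1.0)

   Because the number of edges grows roughly linearly with k, the `max_edges`
   guard mainly protects the largest values in `ks`.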

   .. py:property:: results_raw
      :type: pandas.DataFrame

      DataFrame with raw results for all graphs and all k

   .. py:method:: get_results() -> KScanResults

      Get results object


   .. py:method:: get_foms() -> dict[str, float]

      Get figures of merit (convenience method that uses the appropriate method
      of `KScanResults`).


   .. py:method:: reset()

      Reset the results. This is called automatically whenever the scanner runs
      on a batch with `i_batch == 0`.


   .. py:method:: __call__(data: torch_geometric.data.Data, i_batch: int, *, progress=False, latent: torch.Tensor | None = None) -> None

      Run the scan on a single graph.

      :param data: Data object. `data.x` is the space in which the k-NN graphs are built.
      :param i_batch: Batch number. Saved data is reset for `i_batch == 0`.
      :param progress: Show a progress bar
      :param latent: If given, use this instead of `data.x`

      :returns: None
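
      The reset-on-`i_batch == 0` behavior means that results accumulate over one
      pass through the data and are cleared at the start of the next pass. The
      sketch below mimics only that bookkeeping: `ToyKScanner`, its placeholder
      metric, and `results_by_k` are hypothetical, and averaging the per-graph rows
      for each k is an assumption about how results are combined.

      .. code-block:: python

         import pandas as pd


         class ToyKScanner:
             """Hypothetical stand-in that mimics only the per-batch bookkeeping."""

             def __init__(self, ks: list[int]):
                 self.ks = list(ks)
                 self._rows: list[dict[str, float]] = []

             def reset(self) -> None:
                 self._rows = []

             def __call__(self, n_hits: int, i_batch: int) -> None:
                 if i_batch == 0:
                     # First batch of a new pass over the data: drop the previous results.
                     self.reset()
                 for k in self.ks:
                     # Placeholder metric; the real scanner builds and evaluates a k-NN graph here.
                     self._rows.append({"i_batch": i_batch, "k": k, "n_edges": n_hits * k})

             @property
             def results_raw(self) -> pd.DataFrame:
                 # One row per (graph, k) pair.
                 return pd.DataFrame(self._rows)

             def results_by_k(self) -> pd.DataFrame:
                 # Assumption: per-graph results are averaged over graphs for each k.
                 return self.results_raw.drop(columns="i_batch").groupby("k", as_index=False).mean()


         scanner = ToyKScanner([2, 4])
         for _pass in range(2):
             for i_batch, n_hits in enumerate([100, 300]):
                 scanner(n_hits, i_batch)
         # Only the second pass survives: 2 graphs x 2 ks = 4 rows.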


   .. py:method:: _evaluate_tracking_metrics_upper_bounds(data: torch_geometric.data.Data) -> dict[str, float]

      Evaluate upper bounds of tracking metrics assuming a pipeline with
      perfect edge classification (EC). See https://arxiv.org/abs/2309.16754
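
      With a perfect edge classifier, only edges between hits of the same particle
      survive, and each particle's hits split into connected "segments". A minimal
      sketch of a segment fraction, assuming the 50%-segment fraction is the share
      of particles for which at least 50% of their hits lie in a single segment.
      The pT threshold (`pt_thld`), noise handling, and any other particle selection
      are omitted, and `segment_fraction` is a hypothetical helper built on SciPy's
      `connected_components`.

      .. code-block:: python

         import numpy as np
         from scipy.sparse import coo_matrix
         from scipy.sparse.csgraph import connected_components


         def segment_fraction(edge_index: np.ndarray, particle_id: np.ndarray, threshold: float = 0.5) -> float:
             """Hypothetical helper: fraction of particles whose largest segment holds
             at least `threshold` of their hits, keeping only true edges (perfect EC)."""
             src, dst = edge_index
             # A perfect edge classifier keeps exactly the edges joining hits of one particle.
             true = particle_id[src] == particle_id[dst]
             n_hits = len(particle_id)
             adj = coo_matrix((np.ones(true.sum()), (src[true], dst[true])), shape=(n_hits, n_hits))
             _, labels = connected_components(adj, directed=False)
             pids = np.unique(particle_id)
             n_good = 0
             for pid in pids:
                 hit_labels = labels[particle_id == pid]
                 # Size of the largest segment (connected component) among this particle's hits.
                 largest = np.bincount(hit_labels).max()
                 if largest >= threshold * len(hit_labels):
                     n_good += 1
             return n_good / len(pids)


         # Particles 1 and 2 pass; particle 3 has no true edges, so it fails.
         pid = np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3])
         edge_index = np.array([[0, 1, 4, 6], [1, 2, 5, 7]])
         frac50 = segment_fraction(edge_index, pid)

      Setting `threshold` to 0.75 or 1.0 gives the analogous `frac75` and
      `frac100` quantities under the same assumption.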


   .. py:method:: _evaluate_graph(data: torch_geometric.data.Data, k: int) -> dict[str, float] | None

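      Evaluate metrics for a single graph built with the given k.

      :param data: Data object
      :param k: Number of nearest neighbors used to build the graph

      :returns: Dictionary of metrics, or None if the computation was aborted
                (e.g., because the graph has more than `max_edges` edges)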



