:py:mod:`gnn_tracking.postprocessing.dbscanscanner`
===================================================

.. py:module:: gnn_tracking.postprocessing.dbscanscanner


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking.postprocessing.dbscanscanner.OCScanResults
   gnn_tracking.postprocessing.dbscanscanner.DBSCANHyperParamScanner
   gnn_tracking.postprocessing.dbscanscanner.DBSCANHyperParamScannerFixed
   gnn_tracking.postprocessing.dbscanscanner.DBSCANPerformanceDetails



Functions
~~~~~~~~~

.. autoapisummary::

   gnn_tracking.postprocessing.dbscanscanner.dbscan



.. py:function:: dbscan(graphs: numpy.ndarray, eps=0.99, min_samples=1) -> numpy.ndarray

   Convenience wrapper around `sklearn`'s DBSCAN implementation.
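
   The following is a minimal brute-force NumPy sketch of the clustering that DBSCAN performs, to show what ``eps`` and ``min_samples`` control. It is an illustration only: the hypothetical ``dbscan_sketch`` is not part of this module, which delegates to scikit-learn. With the default ``min_samples=1`` every point is a core point. As a result, no point is labeled as noise, and the clusters are the connected components of the graph that links points within ``eps`` of each other.

   .. code-block:: python

      import numpy as np


      def dbscan_sketch(x: np.ndarray, eps: float = 0.99, min_samples: int = 1) -> np.ndarray:
          """Brute-force O(n^2) DBSCAN; returns one label per row of ``x``, -1 for noise."""
          # Pairwise Euclidean distances between all points
          dists = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
          # A point's neighborhood includes the point itself (as in scikit-learn)
          neighbors = dists <= eps
          is_core = neighbors.sum(axis=1) >= min_samples
          labels = np.full(len(x), -1, dtype=int)
          cluster = 0
          for start in range(len(x)):
              # Only unassigned core points can seed a new cluster
              if labels[start] != -1 or not is_core[start]:
                  continue
              labels[start] = cluster
              stack = [start]
              while stack:
                  i = stack.pop()
                  if not is_core[i]:
                      continue  # border points join a cluster but do not expand it
                  for j in np.flatnonzero(neighbors[i]):
                      if labels[j] == -1:
                          labels[j] = cluster
                          stack.append(j)
              cluster += 1
          return labels


      points = np.array([[0.0, 0.0], [0.5, 0.0], [5.0, 5.0]])
      print(dbscan_sketch(points))                 # [0 0 1]
      print(dbscan_sketch(points, min_samples=2))  # [ 0  0 -1]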


.. py:class:: OCScanResults(df: pandas.DataFrame)


   Results of `DBSCANHyperParamScanner` and related scanners.

   .. py:property:: df
      :type: pandas.DataFrame


   .. py:property:: df_mean
      :type: pandas.DataFrame

      Mean and standard deviation, grouped by hyperparameters.
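
      This kind of aggregation can be sketched in pandas by grouping per-trial results by hyperparameters and computing the mean and standard deviation of each metric. The column names ``eps``, ``min_samples`` and ``double_majority_pt0.9`` and all numbers are made up for the example. The actual columns of ``df`` may differ.

      .. code-block:: python

         import pandas as pd

         # Hypothetical per-trial results, e.g. one row per batch and hyperparameter pair
         df = pd.DataFrame(
             {
                 "eps": [0.5, 0.5, 0.8, 0.8],
                 "min_samples": [1, 1, 2, 2],
                 "double_majority_pt0.9": [0.80, 0.84, 0.90, 0.92],
             }
         )
         # The result's columns are a MultiIndex of (metric, "mean" / "std")
         df_mean = df.groupby(["eps", "min_samples"]).agg(["mean", "std"])
         print(df_mean)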

   .. py:method:: get_foms(guide='double_majority_pt0.9') -> dict[str, float]

      Get figures of merit for the parameters that maximize the ``guide`` metric.


   .. py:method:: get_n_best_trials(n: int, guide='double_majority_pt0.9') -> list[dict[str, float]]
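
   Both methods above take a ``guide`` metric. The following is a hedged pandas sketch of how such a selection could work. It makes several assumptions: one row per ``(eps, min_samples)`` pair, a column named after the guide, and the hypothetical helper names ``best_foms`` and ``n_best_trials``. It is not the implementation of either method. The list-of-dicts output has the same shape as the ``trials`` argument of `DBSCANHyperParamScannerFixed`.

   .. code-block:: python

      import pandas as pd

      # Hypothetical aggregated results: one row per (eps, min_samples) pair
      results = pd.DataFrame(
          {
              "eps": [0.5, 0.8, 0.95],
              "min_samples": [1, 2, 1],
              "double_majority_pt0.9": [0.82, 0.91, 0.87],
              "double_majority_pt0.5": [0.60, 0.75, 0.70],
          }
      )


      def best_foms(df: pd.DataFrame, guide: str = "double_majority_pt0.9") -> dict[str, float]:
          """Return all columns of the row that maximizes ``guide``."""
          return df.loc[df[guide].idxmax()].to_dict()


      def n_best_trials(
          df: pd.DataFrame, n: int, guide: str = "double_majority_pt0.9"
      ) -> list[dict[str, float]]:
          """Return the ``n`` hyperparameter pairs with the highest ``guide`` value."""
          return df.nlargest(n, guide)[["eps", "min_samples"]].to_dict("records")


      print(best_foms(results))
      print(n_best_trials(results, 2))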



.. py:class:: DBSCANHyperParamScanner(*, eps_range=(0, 1), min_samples_range=(1, 4), n_trials=10, keep_best=0, n_jobs: int | None = None, guide: str = 'double_majority_pt0.9', pt_thlds=(0.0, 0.5, 0.9, 1.5), max_eta: float = 4.0)


   Bases: :py:obj:`gnn_tracking.postprocessing.clusterscanner.ClusterScanner`

   Scan for hyperparameters of DBSCAN. Use this scanner for validation.
   Even with few trials, it will eventually sample the best region more finely,
   because it keeps the best trials from the previous epoch
   (make sure to choose a non-zero ``keep_best``).

   :param eps_range: Range of DBSCAN radii to scan
   :param min_samples_range: Range (INCLUSIVE!) of the minimum number of samples for
                             DBSCAN
   :param n_trials: Total number of trials
   :param keep_best: Keep this number of the best ``(eps, min_samples)`` pairs from
                     the current epoch and scan over them again in the next
                     epoch.
   :param n_jobs: Number of jobs to use for parallelization
   :param guide: Report tracking metrics for parameters that maximize this metric
   :param pt_thlds: List of pT thresholds for the tracking metrics
   :param max_eta: Maximum eta for the tracking metrics
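
   The sampling strategy described above can be sketched as follows. The sketch carries over up to ``keep_best`` of the previous epoch's best pairs and fills the remaining trials with random samples, drawing ``min_samples`` from an inclusive integer range. The helper ``propose_trials`` and its uniform random sampling are assumptions made for illustration. The scanner's actual sampler may work differently.

   .. code-block:: python

      from __future__ import annotations

      import random


      def propose_trials(
          previous_best: list[dict[str, float]],
          *,
          n_trials: int = 10,
          keep_best: int = 0,
          eps_range: tuple[float, float] = (0, 1),
          min_samples_range: tuple[int, int] = (1, 4),
          seed: int | None = None,
      ) -> list[dict[str, float]]:
          rng = random.Random(seed)
          # Re-scan the best pairs from the previous epoch first
          trials = [dict(t) for t in previous_best[: min(keep_best, n_trials)]]
          while len(trials) < n_trials:
              trials.append(
                  {
                      "eps": rng.uniform(*eps_range),
                      # randint includes both end points, matching the INCLUSIVE range
                      "min_samples": rng.randint(*min_samples_range),
                  }
              )
          return trials


      best = [{"eps": 0.8, "min_samples": 2}, {"eps": 0.95, "min_samples": 1}]
      trials = propose_trials(best, n_trials=10, keep_best=2, seed=0)
      print(trials[:3])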

   .. py:method:: get_results() -> OCScanResults


   .. py:method:: get_foms() -> dict[str, float]


   .. py:method:: _get_best_trials() -> list[dict[str, float]]


   .. py:method:: _reset_trials() -> None


   .. py:method:: reset()

      Reset the results. This is called automatically every time the scanner
      runs on a batch with ``i_batch == 0``.
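
      The reset-on-first-batch behavior can be illustrated with a toy accumulator. The class ``EpochAccumulator`` is hypothetical and only mirrors the pattern described here. Results from all batches of an epoch are collected, and a call with ``i_batch == 0`` starts a new epoch.

      .. code-block:: python

         class EpochAccumulator:
             """Toy stand-in for a scanner that collects per-batch results over an epoch."""

             def __init__(self) -> None:
                 self.results: list[int] = []

             def reset(self) -> None:
                 self.results = []

             def __call__(self, value: int, i_batch: int) -> None:
                 # The first batch of an epoch discards results from the previous epoch
                 if i_batch == 0:
                     self.reset()
                 self.results.append(value)


         acc = EpochAccumulator()
         for epoch in range(2):
             for i_batch, value in enumerate([1, 2, 3]):
                 acc(10 * epoch + value, i_batch)
         print(acc.results)  # only the second epoch remains: [11, 12, 13]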


   .. py:method:: __call__(data: torch_geometric.data.Data, out: dict[str, torch.Tensor], i_batch: int, *, progress=False)



.. py:class:: DBSCANHyperParamScannerFixed(trials: list[dict[str, float]], *, n_jobs: int | None = None, pt_thlds=(0.0, 0.5, 0.9, 1.5), max_eta: float = 4.0)


   Bases: :py:obj:`DBSCANHyperParamScanner`

   Scan a fixed set of DBSCAN hyperparameters. While `DBSCANHyperParamScanner`
   is meant for validation steps, this scanner is meant for detailed testing.

   :param trials: List of trials to run
   :param n_jobs: Number of jobs to use for parallelization
   :param pt_thlds: List of pT thresholds for the tracking metrics
   :param max_eta: Maximum eta for the tracking metrics
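
   A ``trials`` list for a fixed grid could be built with ``itertools.product``, as in the following sketch. The dictionary keys ``eps`` and ``min_samples`` are an assumption based on the DBSCAN parameter names used in this module. Check the scanner's source for the exact keys it expects.

   .. code-block:: python

      import itertools

      eps_values = [0.6, 0.8, 0.95]
      min_samples_values = [1, 2]

      # One dict per (eps, min_samples) combination: 3 x 2 = 6 trials
      trials = [
          {"eps": eps, "min_samples": min_samples}
          for eps, min_samples in itertools.product(eps_values, min_samples_values)
      ]
      print(len(trials))  # 6
      print(trials[0])  # {'eps': 0.6, 'min_samples': 1}

   The resulting list would then be passed as the first argument, e.g.
   ``DBSCANHyperParamScannerFixed(trials, n_jobs=4)``.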

   .. py:method:: _reset_trials() -> None



.. py:class:: DBSCANPerformanceDetails(eps: float, min_samples: int)


   Bases: :py:obj:`DBSCANHyperParamScanner`

   Get information about detailed performance for fixed DBSCAN parameters.
   See `get_results` for outputs.

   :param eps: DBSCAN epsilon
   :param min_samples: DBSCAN min_samples

   .. py:method:: __call__(data: torch_geometric.data.Data, out: dict[str, torch.Tensor], i_batch: int) -> None


   .. py:method:: get_results() -> tuple[list[pandas.DataFrame], list[pandas.DataFrame]]

      Get results.

      :returns: Tuple of ``(h_dfs, c_dfs)``, where ``h_dfs`` is a list of dataframes
                with information about all hits and ``c_dfs`` is a list of dataframes
                with information about all clusters.
                See `tracking_metric_df` for details about the contents of both
                dataframes.
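
      The following sketch assumes that each entry of ``h_dfs`` and ``c_dfs`` corresponds to one batch, which the description above does not state. Under that assumption, each list can be merged into a single dataframe with a batch column using ``pandas.concat``. The columns here are made up for the example.

      .. code-block:: python

         import pandas as pd

         # Stand-ins for two per-batch dataframes; the columns are made up for the example
         h_dfs = [
             pd.DataFrame({"cluster": [0, 0, 1]}),
             pd.DataFrame({"cluster": [0, 1]}),
         ]
         # `keys` adds an outer index level recording which list entry each row came from
         h_all = pd.concat(
             h_dfs, keys=list(range(len(h_dfs))), names=["i_batch", "row"]
         ).reset_index(level="i_batch")
         print(h_all)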


   .. py:method:: get_foms() -> dict[str, float]



