:py:mod:`gnn_tracking.utils.graph_masks`
========================================

.. py:module:: gnn_tracking.utils.graph_masks


Module Contents
---------------


Functions
~~~~~~~~~

.. autoapisummary::

   gnn_tracking.utils.graph_masks.get_good_node_mask
   gnn_tracking.utils.graph_masks.get_good_node_mask_tensors
   gnn_tracking.utils.graph_masks.get_edge_mask_from_node_mask



.. py:function:: get_good_node_mask(data: torch_geometric.data.Data, *, pt_thld: float = 0.9, max_eta: float = 4.0) -> torch.Tensor

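   Get a mask over the nodes of ``data`` that selects the nodes included in
   metrics and other downstream computations. A node is selected only if it
   passes all of the following cuts: its pt passes the lower limit
   ``pt_thld``, it is not a noise hit, it is reconstructable, and its eta is
   within the limit set by ``max_eta``.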
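
   The cuts above can be sketched in plain PyTorch as follows. This is a
   minimal, hypothetical re-creation rather than the library's code. The helper
   names are made up, a ``types.SimpleNamespace`` stands in for the ``Data``
   object, and the exact comparisons are assumptions: a strict lower pt cut,
   noise hits carrying ``particle_id`` 0 while real particles have positive IDs,
   and the eta cut applied to ``|eta|``. The per-node attribute names on
   ``data`` are also assumed to match the keyword arguments of
   :py:func:`get_good_node_mask_tensors`.

   .. code-block:: python

      from types import SimpleNamespace

      import torch


      def good_node_mask_sketch(
          *,
          pt: torch.Tensor,
          particle_id: torch.Tensor,
          reconstructable: torch.Tensor,
          eta: torch.Tensor,
          pt_thld: float = 0.9,
          max_eta: float = 4.0,
      ) -> torch.Tensor:
          """Hypothetical tensor-level node selection (not the library code)."""
          pt_ok = pt > pt_thld  # assumption: strict lower cut on pt
          not_noise = particle_id > 0  # assumption: noise hits have particle_id 0
          recoable = reconstructable.bool()  # nonzero / True means reconstructable
          eta_ok = eta.abs() < max_eta  # assumption: the cut is on |eta|
          return pt_ok & not_noise & recoable & eta_ok


      def good_node_mask_from_data_sketch(data, *, pt_thld=0.9, max_eta=4.0):
          """Hypothetical Data-level wrapper that forwards per-node tensors."""
          # assumption: `data` stores per-node tensors under these attribute names
          return good_node_mask_sketch(
              pt=data.pt,
              particle_id=data.particle_id,
              reconstructable=data.reconstructable,
              eta=data.eta,
              pt_thld=pt_thld,
              max_eta=max_eta,
          )


      # Six toy nodes; SimpleNamespace stands in for torch_geometric.data.Data.
      data = SimpleNamespace(
          pt=torch.tensor([1.5, 0.5, 2.0, 1.2, 3.0, 0.95]),
          particle_id=torch.tensor([7, 3, 0, 4, 9, 2]),
          reconstructable=torch.tensor([1, 1, 1, 0, 1, 1]),
          eta=torch.tensor([0.3, -1.0, 2.0, 0.1, -4.5, -3.9]),
      )
      mask = good_node_mask_from_data_sketch(data)
      print(mask.tolist())  # [True, False, False, False, False, True]

   Because the individual cuts are combined with ``&``, tightening any single
   threshold (for example ``pt_thld=1.0``) can only remove nodes from the mask,
   never add them.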


.. py:function:: get_good_node_mask_tensors(*, pt, particle_id, reconstructable, eta, pt_thld: float = 0.9, max_eta: float = 4.0) -> torch.Tensor

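   Tensor-based variant of :py:func:`get_good_node_mask`. It applies the same
   cuts, but takes the per-node quantities ``pt``, ``particle_id``,
   ``reconstructable``, and ``eta`` as separate tensors instead of a ``Data``
   object.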


.. py:function:: get_edge_mask_from_node_mask(node_mask: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor

   Get a mask over the edges in ``edge_index`` that selects an edge only if
   both of its endpoint nodes are selected by ``node_mask``.
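
   The endpoint check can be sketched in a few lines of PyTorch. This is a
   hypothetical re-creation rather than the library's code; it assumes that
   ``node_mask`` is a boolean tensor with one entry per node and that
   ``edge_index`` follows the PyTorch Geometric COO layout of shape
   ``(2, num_edges)``, with source nodes in the first row and target nodes in
   the second.

   .. code-block:: python

      import torch


      def edge_mask_from_node_mask_sketch(
          node_mask: torch.Tensor, edge_index: torch.Tensor
      ) -> torch.Tensor:
          """Hypothetical re-creation (not the library code)."""
          # Gather the mask value of every edge's source and target node, then
          # keep an edge only if both endpoints are selected.
          return node_mask[edge_index[0]] & node_mask[edge_index[1]]


      node_mask = torch.tensor([True, False, True, True])
      # Four edges: 0->2, 0->1, 2->3, 1->3
      edge_index = torch.tensor([[0, 0, 2, 1], [2, 1, 3, 3]])
      edge_mask = edge_mask_from_node_mask_sketch(node_mask, edge_index)
      print(edge_mask.tolist())  # [True, False, True, False]
      print(edge_index[:, edge_mask].tolist())  # [[0, 2], [2, 3]]

   Indexing with ``edge_index[:, edge_mask]`` keeps the surviving edges but does
   not relabel them, so their node indices still refer to the full node set.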


