gnn_tracking.utils.graph_masks
Module Contents
Functions
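| get_good_node_mask(data, *[, pt_thld, max_eta]) | Get a mask for nodes that are included in metrics and more. |
| get_edge_mask_from_node_mask(node_mask, edge_index) | Get a mask for edges that are between two nodes that are both in the node mask. |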
- gnn_tracking.utils.graph_masks.get_good_node_mask(data: torch_geometric.data.Data, *, pt_thld: float = 0.9, max_eta: float = 4.0) → torch.Tensor
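Get a mask for nodes that are included in metrics and more. A node passes the mask only if it satisfies all of the following: it passes the lower limit on pt (pt_thld), it is not noise, it is reconstructable, and it passes the cut on eta (max_eta).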
- gnn_tracking.utils.graph_masks.get_edge_mask_from_node_mask(node_mask: torch.Tensor, edge_index: torch.Tensor) → torch.Tensor
Get a mask for edges whose two endpoint nodes are both in the node mask.
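To make the node-selection criteria of get_good_node_mask concrete, here is a minimal sketch. It is not the library's implementation. It assumes the per-node attributes are named `pt`, `particle_id` (with 0 marking noise), `reconstructable` (nonzero meaning reconstructable) and `eta`, and that the eta cut is a strict upper bound on |eta|. These names and conventions are assumptions, and `good_node_mask_sketch` is a hypothetical name. A `types.SimpleNamespace` stands in for a `torch_geometric.data.Data` object, since only attribute access is needed.

```python
from types import SimpleNamespace

import torch


def good_node_mask_sketch(data, *, pt_thld: float = 0.9, max_eta: float = 4.0) -> torch.Tensor:
    """Hypothetical sketch of the criteria described for get_good_node_mask.

    Assumption: `data` carries per-node tensors `pt`, `particle_id`,
    `reconstructable` and `eta` (attribute names are not confirmed by the docs).
    """
    pt_ok = data.pt > pt_thld                   # lower limit on pt
    not_noise = data.particle_id > 0            # assumption: noise hits have particle_id 0
    reconstructable = data.reconstructable > 0  # assumption: nonzero flag = reconstructable
    eta_ok = data.eta.abs() < max_eta           # assumption: cut applies to |eta|
    # A node is "good" only if it passes every criterion
    return pt_ok & not_noise & reconstructable & eta_ok


# Toy event with four nodes (hits)
data = SimpleNamespace(
    pt=torch.tensor([1.5, 0.5, 2.0, 1.2]),
    particle_id=torch.tensor([7, 3, 0, 9]),
    reconstructable=torch.tensor([1, 1, 1, 1]),
    eta=torch.tensor([0.3, -1.0, 2.0, -4.5]),
)
mask = good_node_mask_sketch(data)
print(mask.tolist())  # [True, False, False, False]
```

Node 1 fails the pt limit, node 2 is noise, and node 3 fails the eta cut, so only node 0 survives with the default thresholds.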
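The edge selection described for get_edge_mask_from_node_mask can be sketched by looking up the node mask at both endpoints of every edge. This assumes the PyTorch Geometric convention that edge_index has shape (2, num_edges), with sources in row 0 and targets in row 1. The function name below is hypothetical, and the code reflects the documented behavior rather than the library's source.

```python
import torch


def edge_mask_from_node_mask_sketch(node_mask: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    # edge_index has shape (2, num_edges): row 0 = source nodes, row 1 = target nodes
    source_ok = node_mask[edge_index[0]]
    target_ok = node_mask[edge_index[1]]
    # Keep an edge only if both of its endpoints are in the node mask
    return source_ok & target_ok


node_mask = torch.tensor([True, False, True, True])
edge_index = torch.tensor([[0, 0, 1, 2],
                           [2, 1, 3, 3]])
edge_mask = edge_mask_from_node_mask_sketch(node_mask, edge_index)
print(edge_mask.tolist())  # [True, False, False, True]

# Boolean indexing along the edge dimension keeps only the surviving edges
kept_edges = edge_index[:, edge_mask]
print(kept_edges.tolist())  # [[0, 2], [2, 3]]
```

Both edges touching node 1 are dropped because node 1 is outside the node mask. Passing the result of a node mask such as the one from get_good_node_mask as node_mask restricts the edges to those between good nodes.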