gnn_tracking.utils.graph_masks

Functions

get_good_node_mask(→ torch.Tensor)

Get a mask for nodes that are included in metrics and more.

get_good_node_mask_tensors(→ torch.Tensor)

See get_good_node_mask.

get_edge_mask_from_node_mask(→ torch.Tensor)

Get a mask for edges that are between two nodes that are both in the node mask.

Module Contents

gnn_tracking.utils.graph_masks.get_good_node_mask(data: torch_geometric.data.Data, *, pt_thld: float = 0.9, max_eta: float = 4.0) → torch.Tensor

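Get a mask for nodes that are included in metrics and more. A node passes the mask only if it meets all of these criteria: its pt is above the lower limit pt_thld, it is not a noise hit, it is reconstructable, and it passes the eta cut set by max_eta.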
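The selection above is a logical AND of four per-node conditions. The following is a minimal sketch of that logic using plain PyTorch tensors, with keyword arguments mirroring get_good_node_mask_tensors. It is not the library's implementation: the function name good_node_mask_sketch is hypothetical, and it assumes that noise hits carry particle_id 0, that a nonzero reconstructable flag means "reconstructable", that the pt limit is a strict inequality, and that the eta cut is applied to |eta|.

```python
import torch


def good_node_mask_sketch(
    *,
    pt: torch.Tensor,
    particle_id: torch.Tensor,
    reconstructable: torch.Tensor,
    eta: torch.Tensor,
    pt_thld: float = 0.9,
    max_eta: float = 4.0,
) -> torch.Tensor:
    """Hypothetical re-implementation of the node selection (not the library code)."""
    pt_ok = pt > pt_thld  # lower limit on pt (strict inequality assumed)
    not_noise = particle_id > 0  # assumption: noise hits carry particle_id 0
    recon_ok = reconstructable > 0  # assumption: nonzero flag means reconstructable
    eta_ok = eta.abs() < max_eta  # assumption: the cut is on |eta|
    # A node is kept only if every condition holds
    return pt_ok & not_noise & recon_ok & eta_ok


# Five toy nodes; each of the last four fails exactly one criterion
pt = torch.tensor([1.5, 0.5, 2.0, 1.2, 3.0])
particle_id = torch.tensor([7, 7, 0, 9, 4])
reconstructable = torch.tensor([1, 1, 1, 0, 1])
eta = torch.tensor([0.3, -1.0, 2.2, 0.0, -4.5])

mask = good_node_mask_sketch(
    pt=pt, particle_id=particle_id, reconstructable=reconstructable, eta=eta
)
print(mask.tolist())  # [True, False, False, False, False]
```

Node 1 fails the pt limit, node 2 is noise, node 3 is not reconstructable, and node 4 fails the eta cut, so only node 0 survives. Making the parameters keyword-only, as both functions in this module do, avoids accidentally swapping same-shaped tensors such as pt and eta.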

gnn_tracking.utils.graph_masks.get_good_node_mask_tensors(*, pt, particle_id, reconstructable, eta, pt_thld: float = 0.9, max_eta: float = 4.0) → torch.Tensor

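Same selection as get_good_node_mask, but taking the per-node tensors pt, particle_id, reconstructable and eta directly as keyword arguments instead of a Data object. See get_good_node_mask.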

gnn_tracking.utils.graph_masks.get_edge_mask_from_node_mask(node_mask: torch.Tensor, edge_index: torch.Tensor) → torch.Tensor

Get a mask for edges that are between two nodes that are both in the node mask.
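An edge passes only if both of its endpoints pass the node mask, which can be expressed by indexing the node mask with the source and target rows of edge_index. The following is a minimal sketch of that idea, not the library's code; the name edge_mask_from_node_mask_sketch is hypothetical, and it assumes the PyTorch Geometric convention that edge_index has shape (2, num_edges).

```python
import torch


def edge_mask_from_node_mask_sketch(
    node_mask: torch.Tensor, edge_index: torch.Tensor
) -> torch.Tensor:
    """Keep an edge only if both of its endpoints pass the node mask."""
    src, dst = edge_index  # rows of a (2, num_edges) tensor: source and target nodes
    # Look up the mask value of each endpoint, then require both to be True
    return node_mask[src] & node_mask[dst]


node_mask = torch.tensor([True, False, True, True])
# Edges: 0->2, 0->1, 2->3, 1->3
edge_index = torch.tensor([[0, 0, 2, 1], [2, 1, 3, 3]])
edge_mask = edge_mask_from_node_mask_sketch(node_mask, edge_index)
print(edge_mask.tolist())  # [True, False, True, False]
```

The resulting boolean mask can then select the surviving edges, for example edge_index[:, edge_mask], which here keeps the edges 0->2 and 2->3.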