gnn_tracking.graph_construction.graph_builder

Module Contents

Classes

GraphBuilder

Build graphs out of the input data.

Functions

get_two_hop_tuples(→ set[tuple[int, int]])

Given a list of tuples (a, b), returns the set of tuples (x, y) where (x, t) and (t, y) are in the input list.

gnn_tracking.graph_construction.graph_builder.get_two_hop_tuples(tuples: list[tuple[int, int]]) → set[tuple[int, int]]

Given a list of tuples (a, b), returns the set of tuples (x, y) where (x, t) and (t, y) are in the input list.
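The two-hop relation described above can be sketched as follows. This is an illustrative reimplementation, not the library's own code; the name `two_hop_tuples` is hypothetical, and whether the real function keeps self-pairs such as (x, x) is not stated on this page (the sketch keeps them).

```python
from collections import defaultdict


def two_hop_tuples(tuples: list[tuple[int, int]]) -> set[tuple[int, int]]:
    """Return all (x, y) such that (x, t) and (t, y) are both in `tuples`."""
    # Index the input by first element: a -> {b, ...} for every (a, b).
    successors: dict[int, set[int]] = defaultdict(set)
    for a, b in tuples:
        successors[a].add(b)

    result: set[tuple[int, int]] = set()
    for x, t in tuples:
        # Every (t, y) in the input extends (x, t) to a two-hop pair (x, y).
        for y in successors.get(t, ()):
            result.add((x, y))
    return result


pairs = [(1, 2), (2, 3), (2, 4), (5, 6)]
hops = two_hop_tuples(pairs)
```

Here only (1, 2) can be extended, via (2, 3) and (2, 4), so `hops` contains (1, 3) and (1, 4).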

class gnn_tracking.graph_construction.graph_builder.GraphBuilder(indir: str | os.PathLike, outdir: str | os.PathLike, *, pixel_only=True, redo=True, phi_slope_max=0.005, z0_max=200, dR_max=1.7, remove_intersecting=True, directed=False, measurement_mode=False, write_output=True, log_level=0, collect_data=True, edge_augmentation: str | None = None)

Build graphs out of the input data.

Parameters:
  • indir

  • outdir

  • pixel_only – Only consider pixel detector

  • redo

  • phi_slope_max

  • z0_max

  • dR_max

  • remove_intersecting – Remove “ambiguous” edges, see Fig. 3 in “Charged particle tracking via edge-classifying interaction networks” http://arxiv.org/abs/2103.16701 and mark the remaining ones as incorrect edges

  • directed – Build directed edges

  • measurement_mode

  • write_output – Whether to save the graphs

  • log_level

  • collect_data – Deprecated: Directly load the data into memory

  • edge_augmentation – Add more edges (e.g., next-neighbor connections). Requires remove_intersecting to be False

property data_list

get_measurements()

calc_dphi(phi1: numpy.ndarray, phi2: numpy.ndarray) → numpy.ndarray

Computes phi2 - phi1, with the result given in the range [-pi, pi]
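A minimal sketch of such a wrapped angle difference, assuming both inputs already lie in [-pi, pi] so that a single shift by 2*pi is enough. The function name `wrapped_dphi` is hypothetical, and this is not necessarily how calc_dphi is implemented.

```python
import numpy as np


def wrapped_dphi(phi1: np.ndarray, phi2: np.ndarray) -> np.ndarray:
    """Return phi2 - phi1 mapped back into [-pi, pi]."""
    dphi = np.asarray(phi2, dtype=float) - np.asarray(phi1, dtype=float)
    # The raw difference lies in [-2*pi, 2*pi]; shift the out-of-range values once.
    dphi = np.where(dphi > np.pi, dphi - 2 * np.pi, dphi)
    dphi = np.where(dphi < -np.pi, dphi + 2 * np.pi, dphi)
    return dphi


phi1 = np.array([3.0, -3.0, 0.5])
phi2 = np.array([-3.0, 3.0, 1.0])
dphi = wrapped_dphi(phi1, phi2)
```

The first two pairs straddle the +/-pi boundary, so their raw differences of -6 and 6 are shifted to 2*pi - 6 and 6 - 2*pi respectively, while the third difference is left unchanged.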

calc_eta(r: numpy.ndarray, z: numpy.ndarray) → numpy.ndarray

Computes pseudorapidity (https://en.wikipedia.org/wiki/Pseudorapidity)
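For reference, pseudorapidity is defined as eta = -ln(tan(theta / 2)), where theta is the polar angle measured from the beam (z) axis. The sketch below applies that definition to cylindrical coordinates (r, z); the function name `pseudorapidity` is hypothetical, and the sketch is not taken from the library.

```python
import numpy as np


def pseudorapidity(r: np.ndarray, z: np.ndarray) -> np.ndarray:
    """Compute eta = -ln(tan(theta / 2)) from cylindrical coordinates."""
    # Polar angle with respect to the z axis; r is the transverse radius (r > 0).
    theta = np.arctan2(r, z)
    return -np.log(np.tan(theta / 2.0))


r = np.array([1.0, 1.0, 1.0])
z = np.array([0.0, 1.0, -1.0])
eta = pseudorapidity(r, z)
```

For r > 0 this is equivalent to arcsinh(z / r), so hits at z = 0 have eta = 0 and the sign of eta follows the sign of z.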

get_dataframe(evt: torch_geometric.data.Data, evtid: int) → pandas.DataFrame

Converts a PyTorch Geometric data object to a pandas DataFrame

Parameters:
  • evt – PyTorch Geometric data object

  • evtid – event ID

Returns:

pandas DataFrame

select_edges(hits1: pandas.DataFrame, hits2: pandas.DataFrame, layer1: int, layer2: int) → pandas.DataFrame

Build edges between two layers

Parameters:
  • hits1 – Information about hit 1

  • hits2 – Information about hit 2

  • layer1 – Layer number for hit 1

  • layer2 – Layer number for hit 2

Returns:

DataFrame containing edge indices and extra information
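This page does not spell out how the constructor thresholds phi_slope_max and z0_max are applied. The sketch below shows one plausible form of such a geometric selection, assuming the definitions phi_slope = dphi / dr and z0 = z1 - r1 * dz / dr used in the paper cited under remove_intersecting. The function name `candidate_edges`, the column names `r`, `phi`, `z`, and the omission of the dR_max cut are all assumptions for illustration, not the library's implementation.

```python
import numpy as np
import pandas as pd


def candidate_edges(
    hits1: pd.DataFrame,
    hits2: pd.DataFrame,
    phi_slope_max: float = 0.005,
    z0_max: float = 200.0,
) -> pd.DataFrame:
    """Pair every hit of layer 1 with every hit of layer 2 and apply cuts."""
    # Cross join: one row per (hit in hits1, hit in hits2) combination.
    pairs = hits1.reset_index().merge(
        hits2.reset_index(), how="cross", suffixes=("_1", "_2")
    )
    # Angle difference wrapped into [-pi, pi).
    dphi = (pairs["phi_2"] - pairs["phi_1"] + np.pi) % (2 * np.pi) - np.pi
    dr = pairs["r_2"] - pairs["r_1"]  # assumed non-zero (distinct radii)
    dz = pairs["z_2"] - pairs["z_1"]
    pairs["phi_slope"] = dphi / dr
    # z intercept of the straight line through both hits in the r-z plane.
    pairs["z0"] = pairs["z_1"] - pairs["r_1"] * dz / dr
    keep = (pairs["phi_slope"].abs() < phi_slope_max) & (pairs["z0"].abs() < z0_max)
    columns = ["index_1", "index_2", "phi_slope", "z0"]
    return pairs.loc[keep, columns].reset_index(drop=True)


hits1 = pd.DataFrame({"r": [30.0, 30.0], "phi": [0.0, 1.0], "z": [10.0, -50.0]})
hits2 = pd.DataFrame({"r": [70.0, 70.0], "phi": [0.1, 1.15], "z": [25.0, -120.0]})
edges = candidate_edges(hits1, hits2)
```

With these toy hits, only the pairs (0, 0) and (1, 1) pass the phi_slope cut; the two mixed pairs have an angular slope well above 0.005 and are rejected.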

correct_truth_labels(hits: pandas.DataFrame, edges: pandas.DataFrame, y: numpy.ndarray, particle_ids: numpy.ndarray) → tuple[numpy.ndarray, int]

Corrects for extra edges surviving the barrel intersection cut, i.e., for each particle, counts the number of extra “transition edges” crossing from a barrel layer to an innermost endcap layer; the sum is n_incorrect.

  • [edges] = n_edges x 2

  • [y] = n_edges

  • [particle_ids] = n_edges

Returns:

corrected truth labels, number of incorrect edges

build_edges(hits: pandas.DataFrame) → tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]

Build edges between hits

Parameters:

hits – Point cloud dataframe

Returns:

edge_index (2 x num edges), edge_attr (edge features x num edges), y (truth label, shape = num edges), edge_pt (pt of the track belonging to the first hit)

to_pyg_data(point_cloud, edge_index: numpy.ndarray, edge_attr: numpy.ndarray, y: numpy.ndarray, evtid: int = -1, s: int = -1) → torch_geometric.data.Data

Converts a hit dataframe and edges to a PyTorch Geometric data object

Parameters:
  • point_cloud – Hit dataframe, see get_dataframe

  • edge_index – See build_edges

  • edge_attr – See build_edges

  • y – See build_edges

  • evtid – Event ID

  • s – Sector

Returns:

PyTorch Geometric data object

get_n_truth_edges(df: pandas.DataFrame) → dict[float, int]

static get_event_id_sector_from_str(name: str) → tuple[int, int]

Parses input file names.

Parameters:

name – Input file name

Returns:

Event ID, sector ID
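The expected file name format is not documented on this page. As an illustration only, the sketch below assumes names of the form `data<event id>_s<sector id>.pt` (for example `data21000_s0.pt`); both the pattern and the function name `parse_event_id_sector` are hypothetical and may differ from what get_event_id_sector_from_str actually accepts.

```python
import re

# Assumed naming scheme: "data<event id>_s<sector id>", optionally with an extension.
_NAME_PATTERN = re.compile(r"data(\d+)_s(\d+)")


def parse_event_id_sector(name: str) -> tuple[int, int]:
    """Extract (event ID, sector ID) from a file name of the assumed form."""
    match = _NAME_PATTERN.search(name)
    if match is None:
        raise ValueError(f"Unexpected file name: {name!r}")
    return int(match.group(1)), int(match.group(2))


evtid, sector = parse_event_id_sector("data21000_s0.pt")
```

Raising on an unexpected name, rather than returning a default, makes a misnamed input file fail early instead of being silently assigned to the wrong event or sector.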

process(start=0, stop=1, *, only_sector: int = -1, progressbar=False)

Main processing loop

Parameters:
  • start

  • stop

  • only_sector – Only process files for this sector. If < 0 (default): process all sectors.

Returns: