gnn_tracking.analysis.edge_classification

Module Contents

Classes

ThresholdTrackInfoPlot

Plot track info as a function of EC threshold.

Functions

get_all_ec_stats(→ dict[str, float])

Evaluate edge classification/graph construction performance for a single batch.

collect_all_ec_stats(→ pandas.DataFrame)

Collect edge classification statistics for a model and a data loader by mapping get_all_ec_stats over the data loader with multiprocessing.

gnn_tracking.analysis.edge_classification.get_all_ec_stats(threshold: float, w: torch.Tensor, data: torch_geometric.data.Data, *, pt_thld=0.9, max_eta=4) → dict[str, float]

Evaluate edge classification/graph construction performance for a single batch. Similar to get_all_graph_construction_stats, but also includes edge classification information.

Parameters:
  • threshold – Edge classification threshold

  • w – Edge classification output

  • data – Graph data of the batch to evaluate

  • pt_thld – pt threshold for particle IDs to consider. For edge classification stats (TPR, etc.), two versions are calculated: the stats ending with _thld are calculated only for edges with pt > pt_thld, while the stats without this suffix are calculated for all edges.

Returns:

Dictionary of metrics
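To illustrate the two variants described for pt_thld, here is a minimal standard-library sketch of threshold-based edge classification metrics. It is not the library's implementation: the helper names binary_stats and ec_stats_sketch, the metric keys (TPR, FPR, MCC), and the per-edge pt input are assumptions chosen for illustration. An edge counts as predicted true when its score exceeds the threshold.

```python
import math


def binary_stats(scores, labels, threshold):
    """TPR, FPR and MCC for the prediction ``score > threshold``."""
    tp = fp = tn = fn = 0
    for score, label in zip(scores, labels):
        predicted = score > threshold
        if predicted and label:
            tp += 1
        elif predicted:
            fp += 1
        elif label:
            fn += 1
        else:
            tn += 1
    nan = float("nan")
    tpr = tp / (tp + fn) if tp + fn else nan
    fpr = fp / (fp + tn) if fp + tn else nan
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else nan
    return {"TPR": tpr, "FPR": fpr, "MCC": mcc}


def ec_stats_sketch(threshold, scores, labels, edge_pt, pt_thld=0.9):
    """Hypothetical helper: stats on all edges plus ``_thld`` variants
    restricted to edges with pt > pt_thld."""
    keep = [pt > pt_thld for pt in edge_pt]
    stats_all = binary_stats(scores, labels, threshold)
    stats_thld = binary_stats(
        [s for s, k in zip(scores, keep) if k],
        [y for y, k in zip(labels, keep) if k],
        threshold,
    )
    return {
        "threshold": threshold,
        **stats_all,
        **{f"{key}_thld": value for key, value in stats_thld.items()},
    }


# Toy batch: five edges with scores, truth labels and an associated pt.
stats = ec_stats_sketch(
    0.5,
    scores=[0.9, 0.8, 0.3, 0.6, 0.1],
    labels=[1, 1, 1, 0, 0],
    edge_pt=[2.0, 0.5, 1.5, 1.2, 0.3],
)
# stats["TPR"] == 2/3 (all edges), stats["TPR_thld"] == 0.5 (pt > 0.9 only)
```

Keeping both variants in one flat dictionary, with a suffix for the restricted version, makes each call produce a single record, which is convenient when many thresholds and batches are later collected into one table.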

gnn_tracking.analysis.edge_classification.collect_all_ec_stats(model: torch.nn.Module, data_loader: torch_geometric.data.DataLoader, thresholds: Sequence[float], n_batches: int | None = None, max_workers=6, pt_thld=0.9) → pandas.DataFrame

Collect edge classification statistics for a model and a data loader by mapping get_all_ec_stats over the data loader with multiprocessing.

Parameters:
  • model – Edge classifier model

  • data_loader – Data loader

  • thresholds – List of EC thresholds to evaluate

  • n_batches – Number of batches to evaluate

  • max_workers – Number of workers for multiprocessing

Returns:

DataFrame with columns as in get_all_ec_stats
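The mapping described above can be sketched as evaluating a per-batch statistics function for every (batch, threshold) pair in parallel. This is a hedged sketch, not the library's code: collect_stats_sketch and toy_stats are hypothetical names, the model forward pass is omitted (each batch already carries scores and labels), and the result is one record per pair rather than a DataFrame. The docstring above mentions multiprocessing, but this sketch uses a thread pool so that it runs anywhere; swapping in ProcessPoolExecutor would also require a picklable, module-level worker instead of the lambda.

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import islice, product


def collect_stats_sketch(batches, thresholds, stats_fn, n_batches=None, max_workers=6):
    """Hypothetical helper: call ``stats_fn(threshold, batch)`` for every
    (batch, threshold) pair in parallel and return one record per pair."""
    # islice(..., None) keeps every batch, mirroring n_batches=None.
    batches = list(islice(batches, n_batches))
    jobs = list(product(range(len(batches)), thresholds))
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map returns results in the same order as ``jobs``.
        results = list(pool.map(lambda job: stats_fn(job[1], batches[job[0]]), jobs))
    return [{"batch": i, **res} for (i, _), res in zip(jobs, results)]


def toy_stats(threshold, batch):
    """Stand-in for a per-batch statistics function: TPR only."""
    scores, labels = batch
    tp = sum(1 for s, y in zip(scores, labels) if s > threshold and y)
    return {"threshold": threshold, "TPR": tp / sum(labels)}


batches = [([0.9, 0.2, 0.7], [1, 0, 1]), ([0.4, 0.8], [1, 1])]
records = collect_stats_sketch(batches, thresholds=[0.5, 0.75], stats_fn=toy_stats)
```

Passing the records to pandas.DataFrame yields a table with one row per batch and threshold; whether the library aggregates these rows further is not stated on this page.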

class gnn_tracking.analysis.edge_classification.ThresholdTrackInfoPlot(df: pandas.DataFrame)

Plot track info as a function of EC threshold.

To get the plot in one go, simply call the plot method. Alternatively, use the individual methods to plot the different components separately.

Parameters:

df – DataFrame with columns as in get_all_ec_stats

plot()

Plot all components in one go.

setup_axes()
plot_line(key, **kwargs)
plot_errorline(key, **kwargs)
plot_100()
plot_50()
plot_75()
plot_tpr_fpr()
plot_mcc()
plot_hlines()
add_legend()
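The plot_errorline method is not documented here; one plausible reading is that it draws the mean of a metric across batches together with an uncertainty band. Assuming the DataFrame holds one row per batch and threshold (an assumption, since the page does not say whether rows are pre-aggregated), the values for such a band could be computed as below. The helper mean_and_sem_by_threshold is hypothetical and uses the standard error of the mean as the error estimate.

```python
import math
from collections import defaultdict
from statistics import mean, stdev


def mean_and_sem_by_threshold(records, key):
    """Hypothetical helper: {threshold: (mean, standard error)} of ``key``
    across per-batch records, sorted by threshold."""
    groups = defaultdict(list)
    for record in records:
        groups[record["threshold"]].append(record[key])
    band = {}
    for threshold in sorted(groups):
        values = groups[threshold]
        # The standard error needs at least two values; report NaN otherwise.
        if len(values) > 1:
            sem = stdev(values) / math.sqrt(len(values))
        else:
            sem = float("nan")
        band[threshold] = (mean(values), sem)
    return band


records = [
    {"threshold": 0.5, "TPR": 0.9},
    {"threshold": 0.5, "TPR": 0.7},
    {"threshold": 0.75, "TPR": 0.6},
    {"threshold": 0.75, "TPR": 0.6},
]
band = mean_and_sem_by_threshold(records, "TPR")
```

With matplotlib, the means could then be drawn with Axes.plot and the mean ± standard error range with Axes.fill_between.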