gnn_tracking.analysis.edge_classification
Classes
ThresholdTrackInfoPlot | Plot track info as a function of EC threshold.
Functions
get_all_ec_stats | Evaluate edge classification/graph construction performance for a single batch.
collect_all_ec_stats | Collect edge classification statistics for a model and a data loader, basically mapping get_all_ec_stats over the data loader with multiprocessing.
Module Contents
- gnn_tracking.analysis.edge_classification.get_all_ec_stats(threshold: float, w: torch.Tensor, data: torch_geometric.data.Data, *, pt_thld=0.9, max_eta=4) -> dict[str, float]
Evaluate edge classification/graph construction performance for a single batch. Similar to get_all_graph_construction_stats, but also includes edge classification information.
- Parameters:
threshold – Edge classification threshold
w – Edge classification output
data – Data
pt_thld – pt threshold for particle IDs to consider. For edge classification stats (TPR, etc.), two versions are calculated: the stats whose names end with _thld are calculated only over edges with pt > pt_thld.
- Returns:
Dictionary of metrics
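The two versions of the edge classification stats described under pt_thld can be illustrated with a minimal stdlib sketch. It is not the library's implementation: the function names ec_stats_sketch and _binary_stats, the per-edge truth labels y, the per-edge pt list, the strict `>` cut on w, and the key names (tpr, fpr, mcc and their _thld variants) are all assumptions made for illustration.

```python
import math
from typing import Dict, Sequence


def _binary_stats(pred: Sequence[bool], truth: Sequence[bool]) -> Dict[str, float]:
    """TPR, FPR and MCC of boolean predictions against boolean truth labels."""
    tp = sum(p and t for p, t in zip(pred, truth))
    fp = sum(p and not t for p, t in zip(pred, truth))
    fn = sum(t and not p for p, t in zip(pred, truth))
    tn = sum(not (p or t) for p, t in zip(pred, truth))
    tpr = tp / (tp + fn) if tp + fn else float("nan")
    fpr = fp / (fp + tn) if fp + tn else float("nan")
    # Matthews correlation coefficient; set to 0 when a marginal count is empty
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    return {"tpr": tpr, "fpr": fpr, "mcc": mcc}


def ec_stats_sketch(threshold, w, y, pt, pt_thld=0.9):
    """Hypothetical stand-in for the EC part of get_all_ec_stats (not the library code)."""
    pred = [score > threshold for score in w]  # assumption: strict '>' cut
    truth = [bool(label) for label in y]
    stats = _binary_stats(pred, truth)
    # Second version, restricted to edges with pt > pt_thld; keys get a _thld suffix
    keep = [p > pt_thld for p in pt]
    sub = _binary_stats(
        [p for p, k in zip(pred, keep) if k],
        [t for t, k in zip(truth, keep) if k],
    )
    stats.update({f"{name}_thld": value for name, value in sub.items()})
    return stats
```

For example, ec_stats_sketch(0.5, w=[0.9, 0.2, 0.8, 0.4], y=[1, 0, 0, 1], pt=[1.5, 0.5, 2.0, 0.3]) gives tpr = 0.5 over all four edges and tpr_thld = 1.0 over the two edges with pt > 0.9.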
- gnn_tracking.analysis.edge_classification.collect_all_ec_stats(model: torch.nn.Module, data_loader: torch_geometric.data.DataLoader, thresholds: Sequence[float], n_batches: int | None = None, max_workers=6, pt_thld=0.9) -> pandas.DataFrame
Collect edge classification statistics for a model and a data loader, basically mapping get_all_ec_stats over the data loader with multiprocessing.
- Parameters:
model – Edge classifier model
data_loader – Data loader
thresholds – List of EC thresholds to evaluate
n_batches – Number of batches to evaluate
max_workers – Number of workers for multiprocessing
- Returns:
DataFrame with columns as in get_all_ec_stats
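The "map over the data loader" pattern described above can be sketched as follows, under stated assumptions: collect_stats_sketch, stats_fn and tpr_fn are hypothetical names, a plain callable stands in for the model, and a list of dicts stands in for the data loader. The sketch uses concurrent.futures.ThreadPoolExecutor to stay portable. The docstring mentions multiprocessing, and a ProcessPoolExecutor would be closer to that, but it requires picklable, top-level callables.

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence


def collect_stats_sketch(
    model: Callable[[Any], Sequence[float]],
    batches: Iterable[Any],
    thresholds: Sequence[float],
    stats_fn: Callable[[float, Sequence[float], Any], Dict[str, float]],
    n_batches: Optional[int] = None,
    max_workers: int = 6,
) -> List[Dict[str, Any]]:
    """Hypothetical sketch: map stats_fn over every (batch, threshold) pair."""
    # islice(..., None) keeps every batch, mirroring n_batches=None
    selected = list(islice(batches, n_batches))
    # Run the model once per batch; only the stats step is parallelized
    outputs = [model(data) for data in selected]
    jobs = [(i, thld) for i in range(len(selected)) for thld in thresholds]

    def run(job):
        i, thld = job
        return stats_fn(thld, outputs[i], selected[i])

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(run, jobs))  # map preserves job order
    # One row per (batch, threshold); the "batch"/"threshold" columns are assumptions
    return [{"batch": i, "threshold": thld, **stats} for (i, thld), stats in zip(jobs, results)]


def tpr_fn(threshold, w, data):
    """Toy stats function: fraction of true edges scored above the threshold."""
    hits = [score > threshold for score, label in zip(w, data["y"]) if label]
    return {"tpr": sum(hits) / len(hits)}


batches = [{"scores": [0.9, 0.3, 0.6], "y": [1, 1, 0]}, {"scores": [0.2, 0.8], "y": [1, 1]}]
rows = collect_stats_sketch(lambda data: data["scores"], batches, [0.25, 0.5], tpr_fn)
```

Passing rows to pandas.DataFrame gives a DataFrame with one row per (batch, threshold) pair.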
- class gnn_tracking.analysis.edge_classification.ThresholdTrackInfoPlot(df: pandas.DataFrame)
Plot track info as a function of EC threshold.
To get the plot in one go, simply call the plot method. Alternatively, use the individual methods to plot the different components separately.
- Parameters:
df – DataFrame with columns as in get_all_ec_stats
- df
- ax: matplotlib.pyplot.Axes | None = None
- plot()
Plot all the things.
- setup_axes()
- plot_line(key, **kwargs)
- plot_errorline(key, **kwargs)
- plot_100()
- plot_50()
- plot_75()
- plot_tpr_fpr()
- plot_mcc()
- plot_hlines()
- add_legend()
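A metric curve over EC thresholds needs the per-batch rows reduced to one point per threshold. The method names plot_line and plot_errorline suggest this kind of plot, but this page does not document how they aggregate. The following stdlib sketch is therefore only an assumption: it groups rows by a "threshold" column and computes the mean and standard error of the mean per threshold. The function name aggregate_by_threshold and the row layout are hypothetical.

```python
import math
from collections import defaultdict
from statistics import mean, stdev
from typing import Any, Dict, Iterable, Mapping, Tuple


def aggregate_by_threshold(
    rows: Iterable[Mapping[str, Any]], key: str
) -> Dict[float, Tuple[float, float]]:
    """Mean and standard error of rows[key], grouped by the 'threshold' column."""
    groups = defaultdict(list)
    for row in rows:
        groups[row["threshold"]].append(row[key])
    summary = {}
    for thld in sorted(groups):  # sorted so the x axis increases
        values = groups[thld]
        # stdev needs at least two samples; a single batch gets zero spread
        sem = stdev(values) / math.sqrt(len(values)) if len(values) > 1 else 0.0
        summary[thld] = (mean(values), sem)
    return summary


rows = [
    {"threshold": 0.5, "tpr": 0.5},
    {"threshold": 0.5, "tpr": 1.0},
    {"threshold": 0.25, "tpr": 1.0},
]
summary = aggregate_by_threshold(rows, "tpr")
```

With the library itself, the documented pieces chain as ThresholdTrackInfoPlot(collect_all_ec_stats(model, data_loader, thresholds)).plot(), because both take or return a DataFrame with columns as in get_all_ec_stats.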