gnn_tracking.analysis.edge_classification

Classes

BinaryClassificationStats

Calculator for binary classification metrics.

ThresholdTrackInfoPlot

Plot track info as a function of EC threshold.

Functions

get_orphan_counts(→ OrphanCount)

Count the number of orphan nodes in a graph. See OrphanCount for details.

get_track_graph_info_from_data(→ pandas.DataFrame)

Get DataFrame of track graph information for every particle ID in the data.

summarize_track_graph_info(→ dict[str, float])

Summarize track graph information returned by get_track_graph_info_from_data.

add_key_suffix(→ dict[str, _P])

Return a copy of the dictionary with the suffix added to all keys.

get_edge_mask_from_node_mask(→ torch.Tensor)

Get a mask for edges that are between two nodes that are both in the node mask.

get_good_node_mask(→ torch.Tensor)

Get a mask for nodes that are included in metrics and more.

get_all_ec_stats(→ dict[str, float])

Evaluate edge classification/graph construction performance for a single batch.

collect_all_ec_stats(→ pandas.DataFrame)

Collect edge classification statistics for a model and a data loader, basically mapping get_all_ec_stats over the data loader with multiprocessing.

Module Contents

gnn_tracking.analysis.edge_classification.get_orphan_counts(data: torch_geometric.data.Data, *, pt_thld=0.9, max_eta: float = 4.0) → OrphanCount

Count the number of orphan nodes in a graph. See OrphanCount for details.

gnn_tracking.analysis.edge_classification.get_track_graph_info_from_data(data: torch_geometric.data.Data, *, w: torch.Tensor | None = None, pt_thld=0.9, threshold: float | None = None, max_eta: float = 4.0) → pandas.DataFrame

Get DataFrame of track graph information for every particle ID in the data. This function basically applies get_track_graph_info to every particle ID.

Parameters:
  • data – Data

  • w – Edge weights. If None, no cut on edge weights is applied.

  • pt_thld – pt threshold for particle IDs to consider

  • threshold – Edge classification cutoff (if w is given)

Returns:

DataFrame with columns as in TrackGraphInfo

gnn_tracking.analysis.edge_classification.summarize_track_graph_info(tgi: pandas.DataFrame) → dict[str, float]

Summarize track graph information returned by get_track_graph_info_from_data.

class gnn_tracking.analysis.edge_classification.BinaryClassificationStats(output: torch.Tensor, y: torch.Tensor, thld: torch.Tensor | float)

Calculator for binary classification metrics. All properties are cached, so they are only calculated once.

Parameters:
  • output – Output weights

  • y – True labels

  • thld – Threshold to consider something true

Returns:

accuracy, TPR, TNR

_true() → torch.Tensor
n_true() → int
_false() → torch.Tensor
n_false() → int
_predicted_false() → torch.Tensor
n_predicted_false() → int
_predicted_true() → torch.Tensor
n_predicted_true() → int
TP() → float
TN() → float
FP() → float
FN() → float
acc() → float
TPR() → float
TNR() → float
FPR() → float
FNR() → float
balanced_acc() → float
F1() → float
MCC() → float
get_all() → dict[str, float]
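
The metric names above follow the usual confusion-matrix definitions. As a rough illustration of how such values can be derived from thresholded outputs, here is a minimal pure-Python sketch. The function `binary_classification_metrics` is a hypothetical stand-in and not the library's implementation; in particular, treating a score strictly above `thld` as "predicted true" and returning NaN for undefined ratios are assumptions.

```python
import math


def binary_classification_metrics(output, y, thld):
    """Confusion-matrix metrics for scores `output`, labels `y`, cut `thld`.

    Hypothetical stand-in for illustration; not the library's code.
    Assumption: a score strictly above `thld` counts as predicted true.
    """
    predicted = [score > thld for score in output]
    truth = [bool(label) for label in y]

    tp = sum(p and t for p, t in zip(predicted, truth))
    tn = sum(not p and not t for p, t in zip(predicted, truth))
    fp = sum(p and not t for p, t in zip(predicted, truth))
    fn = sum(not p and t for p, t in zip(predicted, truth))

    def ratio(num, den):
        # Assumption: undefined ratios are reported as NaN instead of raising.
        return num / den if den else float("nan")

    tpr = ratio(tp, tp + fn)
    tnr = ratio(tn, tn + fp)
    mcc_den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return {
        "acc": ratio(tp + tn, len(truth)),
        "TPR": tpr,
        "TNR": tnr,
        "FPR": ratio(fp, fp + tn),
        "FNR": ratio(fn, fn + tp),
        "balanced_acc": (tpr + tnr) / 2,
        "F1": ratio(2 * tp, 2 * tp + fp + fn),
        "MCC": ratio(tp * tn - fp * fn, mcc_den),
    }


# Six toy edges: three true, three false.
stats = binary_classification_metrics(
    output=[0.9, 0.8, 0.3, 0.6, 0.1, 0.2],
    y=[1, 1, 1, 0, 0, 0],
    thld=0.5,
)
```

The dictionary keys mirror the member names listed above only for readability; how the class itself normalizes TP, TN, FP and FN is not specified here.
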
gnn_tracking.analysis.edge_classification.add_key_suffix(dct: dict[str, _P], suffix: str = '') → dict[str, _P]

Return a copy of the dictionary with the suffix added to all keys.
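
A minimal sketch of the described behavior, assuming the suffix is simply appended to each key. `add_key_suffix_sketch` is a hypothetical re-implementation for illustration, not the library function itself.

```python
def add_key_suffix_sketch(dct, suffix=""):
    """Return a new dict whose keys are the old keys with `suffix` appended."""
    return {f"{key}{suffix}": value for key, value in dct.items()}


stats = {"TPR": 0.95, "FPR": 0.02}
stats_thld = add_key_suffix_sketch(stats, "_thld")

# The original dict is left untouched, so both versions can be merged.
merged = {**stats, **stats_thld}
```

A helper like this is handy when two sets of metrics with identical names need to live in one flat dictionary.
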

gnn_tracking.analysis.edge_classification.get_edge_mask_from_node_mask(node_mask: torch.Tensor, edge_index: torch.Tensor) → torch.Tensor

Get a mask for edges that are between two nodes that are both in the node mask.
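
The idea can be sketched with plain Python lists, assuming the two-row `edge_index` layout used by PyTorch Geometric (row 0: source nodes, row 1: target nodes). `edge_mask_from_node_mask` below is a hypothetical helper, not the library code; with boolean tensors the same logic is commonly written as `node_mask[edge_index[0]] & node_mask[edge_index[1]]`.

```python
def edge_mask_from_node_mask(node_mask, edge_index):
    """Keep an edge only if both of its endpoints are selected in `node_mask`.

    `edge_index` has two rows: source node indices and target node indices.
    """
    sources, targets = edge_index
    return [node_mask[s] and node_mask[t] for s, t in zip(sources, targets)]


node_mask = [True, True, False, True]
edge_index = [
    [0, 1, 2, 3],  # source node of each edge
    [1, 2, 3, 0],  # target node of each edge
]
edge_mask = edge_mask_from_node_mask(node_mask, edge_index)
# Edges 0->1 and 3->0 survive; both edges touching node 2 are dropped.
```
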

gnn_tracking.analysis.edge_classification.get_good_node_mask(data: torch_geometric.data.Data, *, pt_thld: float = 0.9, max_eta: float = 4.0) → torch.Tensor

Get a mask for nodes that are included in metrics and more. A node is kept if it passes the lower limit on pt (pt_thld), is not noise, is reconstructable, and passes the cut on eta (max_eta).
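
How the four criteria combine can be sketched as a logical AND per node. Everything in this example is an assumption made for illustration: the `Hit` record and its field names are hypothetical, the comparisons are taken to be strict, and the eta cut is taken to be symmetric (`abs(eta) < max_eta`). The actual attributes read from `data` are not described here.

```python
from dataclasses import dataclass


@dataclass
class Hit:
    """Hypothetical per-node record; the field names are illustrative only."""

    pt: float
    eta: float
    is_noise: bool
    reconstructable: bool


def good_node_mask(hits, pt_thld=0.9, max_eta=4.0):
    """One boolean per hit: True only if all four criteria hold."""
    return [
        hit.pt > pt_thld              # lower limit on pt (assumed strict)
        and not hit.is_noise          # not noise
        and hit.reconstructable       # reconstructable
        and abs(hit.eta) < max_eta    # eta cut (assumed symmetric)
        for hit in hits
    ]


hits = [
    Hit(pt=1.5, eta=0.3, is_noise=False, reconstructable=True),  # passes
    Hit(pt=0.4, eta=0.3, is_noise=False, reconstructable=True),  # fails pt
    Hit(pt=1.5, eta=4.2, is_noise=False, reconstructable=True),  # fails eta
    Hit(pt=1.5, eta=0.3, is_noise=True, reconstructable=False),  # noise
]
mask = good_node_mask(hits)
```
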

gnn_tracking.analysis.edge_classification.get_all_ec_stats(threshold: float, w: torch.Tensor, data: torch_geometric.data.Data, *, pt_thld=0.9, max_eta=4) → dict[str, float]

Evaluate edge classification/graph construction performance for a single batch. Similar to get_all_graph_construction_stats, but also includes edge classification information.

Parameters:
  • threshold – Edge classification threshold

  • w – Edge classification output

  • data – Data

  • pt_thld – pt threshold for particle IDs to consider. For edge classification stats (TPR, etc.), two versions are calculated; the stats ending in _thld are calculated only from edges with pt > pt_thld.

Returns:

Dictionary of metrics

gnn_tracking.analysis.edge_classification.collect_all_ec_stats(model: torch.nn.Module, data_loader: torch_geometric.data.DataLoader, thresholds: Sequence[float], n_batches: int | None = None, max_workers=6, pt_thld=0.9) → pandas.DataFrame

Collect edge classification statistics for a model and a data loader, basically mapping get_all_ec_stats over the data loader with multiprocessing.

Parameters:
  • model – Edge classifier model

  • data_loader – Data loader

  • thresholds – List of EC thresholds to evaluate

  • n_batches – Number of batches to evaluate

  • max_workers – Number of workers for multiprocessing

Returns:

DataFrame with columns as in get_all_ec_stats
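
The general pattern (evaluate a per-batch metric function for every threshold, then collect one row per threshold) can be sketched as follows. This is a simplified illustration under several assumptions, not the library's implementation: `toy_ec_stats` is a hypothetical stand-in for get_all_ec_stats, the model forward pass is omitted (each toy batch already carries its edge scores), a thread pool stands in for multiprocessing, and per-batch values are averaged with a plain mean.

```python
from concurrent.futures import ThreadPoolExecutor
from statistics import mean


def toy_ec_stats(threshold, w, y):
    """Hypothetical per-batch metric function (stand-in for get_all_ec_stats)."""
    predicted = [score > threshold for score in w]
    n_true = sum(y)
    n_false = len(y) - n_true
    tp = sum(p and t for p, t in zip(predicted, y))
    fp = sum(p and not t for p, t in zip(predicted, y))
    return {"TPR": tp / n_true, "FPR": fp / n_false}


def collect_stats(batches, thresholds, n_batches=None, max_workers=2):
    """Evaluate every (threshold, batch) pair and average over batches."""
    batches = batches[:n_batches]  # n_batches=None keeps every batch
    rows = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        for threshold in thresholds:
            per_batch = list(
                pool.map(lambda batch: toy_ec_stats(threshold, *batch), batches)
            )
            row = {"threshold": threshold}
            for key in per_batch[0]:
                # Assumption: aggregate across batches with a simple mean.
                row[key] = mean(stats[key] for stats in per_batch)
            rows.append(row)
    return rows


# Each toy batch is (edge scores, true edge labels).
batches = [
    ([0.9, 0.7, 0.2, 0.1], [True, True, False, False]),
    ([0.8, 0.4, 0.6, 0.3], [True, True, False, False]),
]
rows = collect_stats(batches, thresholds=[0.5, 0.75])
```

Passing `rows` to `pandas.DataFrame(rows)` would yield a table with one row per threshold and one column per metric.
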

class gnn_tracking.analysis.edge_classification.ThresholdTrackInfoPlot(df: pandas.DataFrame)

Plot track info as a function of EC threshold.

To get the plot in one go, simply call the plot method. Alternatively, use the individual methods to plot the different components separately.

Parameters:

df – DataFrame with columns as in get_all_ec_stats

plot()

Plot all the things.

setup_axes()
plot_line(key, **kwargs)
plot_errorline(key, **kwargs)
plot_100()
plot_50()
plot_75()
plot_tpr_fpr()
plot_mcc()
plot_hlines()
add_legend()