gnn_tracking.metrics.losses.metric_learning

Classes

GraphConstructionHingeEmbeddingLoss

Loss for graph construction using metric learning.

OldGraphConstructionHingeEmbeddingLoss

Loss for graph construction using metric learning.

Functions

_hinge_loss_components(→ tuple[torch.Tensor, torch.Tensor])

_old_hinge_loss_components(→ tuple[torch.Tensor, ...)

Module Contents

gnn_tracking.metrics.losses.metric_learning._hinge_loss_components(*, x: torch.Tensor, att_edges: torch.Tensor, rep_edges: torch.Tensor, r_emb_hinge: float, p_attr: float, p_rep: float, n_hits_oi: int, normalization: str) → tuple[torch.Tensor, torch.Tensor]

class gnn_tracking.metrics.losses.metric_learning.GraphConstructionHingeEmbeddingLoss(*, lw_repulsive: float = 1.0, r_emb: float = 1.0, max_num_neighbors: int = 256, pt_thld: float = 0.9, max_eta: float = 4.0, p_attr: float = 1.0, p_rep: float = 1.0, rep_normalization: str = 'n_hits_oi', rep_oi_only: bool = True)

Bases: gnn_tracking.metrics.losses.MultiLossFct, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Loss for graph construction using metric learning.

Parameters:
  • lw_repulsive – Loss weight for the repulsive part of the potential loss

  • r_emb – Radius for edge construction

  • max_num_neighbors – Maximum number of neighbors in radius graph building. See rusty1s/pytorch_cluster

  • pt_thld – pt threshold for particles of interest

  • max_eta – maximum eta for particles of interest

  • p_attr – Power for the attraction term (default 1: linear loss)

  • p_rep – Power for the repulsion term (default 1: linear loss)

  • rep_normalization – Normalization for the repulsive term. One of “n_rep_edges” (normalizes by the number of repulsive edges, i.e., pairs closer than r_emb), “n_hits_oi” (default; normalizes by the number of hits of interest) or “n_att_edges” (normalizes by the number of attractive edges of interest)

  • rep_oi_only – Only consider repulsion between hits if at least one of the hits is of interest
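
The two terms these parameters control can be illustrated with a plain-Python sketch (lists instead of tensors, so it runs without PyTorch). It assumes the usual hinge-embedding form: the attractive term averages d^p_attr over attractive edges, and the repulsive term sums max(0, r_emb - d)^p_rep over repulsive edges and divides by the count selected via rep_normalization. The function name hinge_terms, the epsilon, and applying p_rep outside the hinge are assumptions; with the default p_rep = 1 that placement does not matter. This is not the library's implementation.

```python
import math
from typing import Sequence

Point = Sequence[float]
Edge = tuple[int, int]


def hinge_terms(
    x: Sequence[Point],
    att_edges: Sequence[Edge],
    rep_edges: Sequence[Edge],
    r_emb: float = 1.0,
    p_attr: float = 1.0,
    p_rep: float = 1.0,
    n_hits_oi: int = 0,
    rep_normalization: str = "n_hits_oi",
) -> tuple[float, float]:
    """Hypothetical plain-Python version of the two hinge-loss terms."""
    eps = 1e-9  # avoids division by zero for empty edge sets

    # Attraction: pull hits of the same particle together (distance ** p_attr).
    att = sum(math.dist(x[i], x[j]) ** p_attr for i, j in att_edges)
    v_att = att / (len(att_edges) + eps)

    # Repulsion: penalize pairs of different particles closer than r_emb.
    rep = sum(max(0.0, r_emb - math.dist(x[i], x[j])) ** p_rep for i, j in rep_edges)

    # Assumed mapping of the three documented normalization options to counts.
    norms = {
        "n_rep_edges": len(rep_edges),
        "n_hits_oi": n_hits_oi,
        "n_att_edges": len(att_edges),
    }
    if rep_normalization not in norms:
        raise ValueError(f"Unknown normalization: {rep_normalization}")
    v_rep = rep / (norms[rep_normalization] + eps)
    return v_att, v_rep
```

For example, with one attractive edge of length 0.5 and repulsive edges of length 0.75 and 0.25, the repulsive sum is 0.25 + 0.75 = 1.0, which the default “n_hits_oi” option then divides by the number of hits of interest.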

_get_edges(*, x: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, mask: torch.Tensor, particle_id: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]

Returns the attractive and repulsive edge indices for the graph.
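
A brute-force sketch of how such an edge selection could work with the parameters documented above: a radius graph within r_emb (capped at max_num_neighbors per hit and restricted to hits of the same event via batch), repulsive edges between different particles, and attractive edges taken from the true edges that start at a hit of interest. The helper get_edges and these filtering rules are assumptions for illustration; the documented radius graph building relies on rusty1s/pytorch_cluster and operates on tensors.

```python
import math
from typing import Sequence

Edge = tuple[int, int]


def get_edges(
    x: Sequence[Sequence[float]],
    batch: Sequence[int],
    true_edges: Sequence[Edge],
    mask: Sequence[bool],
    particle_id: Sequence[int],
    r_emb: float = 1.0,
    max_num_neighbors: int = 256,
    rep_oi_only: bool = True,
) -> tuple[list[Edge], list[Edge]]:
    """Hypothetical brute-force version of the attractive/repulsive edge selection."""
    n = len(x)
    # Radius graph: for each hit, keep up to max_num_neighbors other hits of the
    # same event within r_emb (here simply the first ones by index).
    near: list[Edge] = []
    for i in range(n):
        neighbors = [
            j
            for j in range(n)
            if j != i and batch[j] == batch[i] and math.dist(x[i], x[j]) < r_emb
        ][:max_num_neighbors]
        near.extend((i, j) for j in neighbors)

    # Repulsive edges: nearby pairs from different particles; with rep_oi_only,
    # only edges starting at a hit of interest are kept.
    rep = [
        (i, j)
        for i, j in near
        if particle_id[i] != particle_id[j] and (mask[i] or not rep_oi_only)
    ]
    # Attractive edges: true edges that start at a hit of interest.
    att = [(i, j) for i, j in true_edges if mask[i]]
    return att, rep
```

Because the radius graph contains both directions of each pair, checking the mask on the first endpoint keeps a repulsive pair whenever at least one of its hits is of interest, which matches the description of rep_oi_only.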

forward(*, x: torch.Tensor, particle_id: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, pt: torch.Tensor, eta: torch.Tensor, reconstructable: torch.Tensor, **kwargs) → gnn_tracking.metrics.losses.MultiLossFctReturn
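
forward receives pt, eta and reconstructable alongside particle_id, and the constructor takes pt_thld, max_eta and lw_repulsive. The sketch below shows one plausible way these combine. It assumes that a hit is of interest when its particle is not noise (particle_id > 0), has pt above pt_thld, has |eta| below max_eta, and is reconstructable, and that the two loss terms are combined as attractive + lw_repulsive * repulsive. The function names and both rules are assumptions, not documented behavior.

```python
from typing import Sequence


def hits_of_interest(
    particle_id: Sequence[int],
    pt: Sequence[float],
    eta: Sequence[float],
    reconstructable: Sequence[bool],
    pt_thld: float = 0.9,
    max_eta: float = 4.0,
) -> list[bool]:
    """Hypothetical mask: True for hits assumed to be 'of interest'."""
    return [
        pid > 0  # assumption: particle_id 0 marks noise hits
        and p > pt_thld
        and abs(e) < max_eta
        and bool(r)
        for pid, p, e, r in zip(particle_id, pt, eta, reconstructable)
    ]


def weighted_total(v_att: float, v_rep: float, lw_repulsive: float = 1.0) -> float:
    """Assumed combination of the attractive and repulsive terms into one scalar."""
    return v_att + lw_repulsive * v_rep
```

The mask would feed both the n_hits_oi normalization (as the number of True entries) and the rep_oi_only filtering.
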

gnn_tracking.metrics.losses.metric_learning._old_hinge_loss_components(*, x: torch.Tensor, edge_index: torch.Tensor, particle_id: torch.Tensor, pt: torch.Tensor, r_emb_hinge: float, pt_thld: float, p_attr: float, p_rep: float) → tuple[torch.Tensor, torch.Tensor]
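
In the older variant, a single edge_index carries both terms. The sketch below assumes several rules. An edge counts as true when both hits belong to the same non-noise particle. Only true edges whose first hit has pt above pt_thld attract, and the remaining true edges are ignored. Every non-true edge repels. Both terms are divided by the number of attractive edges. The name old_hinge_terms and all of these rules are assumptions for illustration, written in plain Python rather than torch.

```python
import math
from typing import Sequence


def old_hinge_terms(
    x: Sequence[Sequence[float]],
    edges: Sequence[tuple[int, int]],
    particle_id: Sequence[int],
    pt: Sequence[float],
    r_emb: float = 1.0,
    pt_thld: float = 0.9,
    p_attr: float = 1.0,
    p_rep: float = 1.0,
) -> tuple[float, float]:
    """Hypothetical plain-Python version of the older hinge-loss terms."""
    att_sum = rep_sum = 0.0
    n_att = 0
    for i, j in edges:
        d = math.dist(x[i], x[j])
        # Assumption: a true edge joins two hits of the same non-noise particle.
        true_edge = particle_id[i] == particle_id[j] and particle_id[i] > 0
        if true_edge and pt[i] > pt_thld:
            att_sum += d ** p_attr  # high-pt true edge: attract
            n_att += 1
        elif not true_edge:
            rep_sum += max(0.0, r_emb - d) ** p_rep  # false edge: repel inside r_emb
    # Assumption: both terms share the attractive-edge count as normalization.
    normalization = n_att + 1e-8
    return att_sum / normalization, rep_sum / normalization
```
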
class gnn_tracking.metrics.losses.metric_learning.OldGraphConstructionHingeEmbeddingLoss(*, r_emb=1, max_num_neighbors: int = 256, attr_pt_thld: float = 0.9, p_attr: float = 1, p_rep: float = 1, lw_repulsive: float = 1.0)

Bases: gnn_tracking.metrics.losses.MultiLossFct, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Loss for graph construction using metric learning.

Parameters:
  • r_emb – Radius for edge construction

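  • max_num_neighbors – Maximum number of neighbors in radius graph building. See rusty1s/pytorch_cluster

  • attr_pt_thld – pt threshold for particles whose hits contribute to the attraction term

  • p_attr – Power for the attraction term (default 1: linear loss)

  • p_rep – Power for the repulsion term (default 1: linear loss)

  • lw_repulsive – Loss weight for the repulsive part of the potential loss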

_build_graph(x: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, pt: torch.Tensor) → torch.Tensor

forward(*, x: torch.Tensor, particle_id: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, pt: torch.Tensor, **kwargs) → dict[str, torch.Tensor]