gnn_tracking.metrics.losses.metric_learning

Classes

MultiLossFct

Base class for loss functions that return multiple losses.

MultiLossFctReturn

Return type for loss functions that return multiple losses.

GraphConstructionHingeEmbeddingLoss

Loss for graph construction using metric learning.

OldGraphConstructionHingeEmbeddingLoss

Loss for graph construction using metric learning.

Functions

get_good_node_mask_tensors(→ torch.Tensor)

See get_good_node_mask

_hinge_loss_components(→ tuple[torch.Tensor, torch.Tensor])

_old_hinge_loss_components(→ tuple[torch.Tensor, ...)

Module Contents

class gnn_tracking.metrics.losses.metric_learning.MultiLossFct

Bases: torch.nn.Module

Base class for loss functions that return multiple losses.

forward(*args: Any, **kwargs: Any) → MultiLossFctReturn

class gnn_tracking.metrics.losses.metric_learning.MultiLossFctReturn

Return type for loss functions that return multiple losses.
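MultiLossFctReturn bundles several loss components with their weights. The sketch below shows one way its members (listed below) could fit together. It assumes that `weighted_losses` multiplies each entry of `loss_dct` by the weight stored under the same key in `weight_dct`, and that `loss` is the sum of these weighted terms. `MultiLossReturnSketch` is a hypothetical stand-in, not the library class. It uses plain floats instead of `torch.Tensor` so that it runs without PyTorch. The key check in `__post_init__` is also an assumption.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class MultiLossReturnSketch:
    """Hypothetical stand-in for MultiLossFctReturn (floats instead of tensors)."""

    loss_dct: dict[str, float]
    weight_dct: dict[str, float]
    extra_metrics: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Assumed consistency check: every loss component needs exactly one weight.
        if self.loss_dct.keys() != self.weight_dct.keys():
            raise ValueError("loss_dct and weight_dct must have the same keys")

    @property
    def weighted_losses(self) -> dict[str, float]:
        # Scale each component by the weight stored under the same key.
        return {k: v * self.weight_dct[k] for k, v in self.loss_dct.items()}

    @property
    def loss(self) -> float:
        # Total loss: sum of the weighted components.
        return sum(self.weighted_losses.values())


ret = MultiLossReturnSketch(
    loss_dct={"attractive": 0.4, "repulsive": 0.2},
    weight_dct={"attractive": 1.0, "repulsive": 0.5},
)
print(ret.weighted_losses)  # {'attractive': 0.4, 'repulsive': 0.1}
print(ret.loss)  # 0.5
```

In this reading, a training loop would backpropagate through `loss` and log the individual `weighted_losses` and `extra_metrics` separately.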

loss_dct: dict[str, torch.Tensor]

weight_dct: dict[str, torch.Tensor] | dict[str, float]

extra_metrics: dict[str, Any]

__post_init__() → None

property loss: torch.Tensor

property weighted_losses: dict[str, torch.Tensor]

gnn_tracking.metrics.losses.metric_learning.get_good_node_mask_tensors(*, pt, particle_id, reconstructable, eta, pt_thld: float = 0.9, max_eta: float = 4.0) → torch.Tensor

See get_good_node_mask
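get_good_node_mask_tensors selects the "hits of interest" used by the losses below. The following is a minimal sketch under several assumptions. It assumes a hit qualifies when all of the following hold: it belongs to a real particle (`particle_id > 0`, with 0 marking noise hits as in the TrackML convention), that particle is reconstructable, its pt exceeds `pt_thld`, and its |eta| is below `max_eta`. `good_node_mask_sketch` is a hypothetical name and works on Python lists instead of tensors. See get_good_node_mask for the exact criteria.

```python
from __future__ import annotations


def good_node_mask_sketch(
    *,
    pt: list[float],
    particle_id: list[int],
    reconstructable: list[bool],
    eta: list[float],
    pt_thld: float = 0.9,
    max_eta: float = 4.0,
) -> list[bool]:
    """Hypothetical, framework-free 'hits of interest' mask (assumed criteria)."""
    return [
        # Real particle, reconstructable, above the pt threshold, inside the eta range.
        pid > 0 and rec and p > pt_thld and abs(e) < max_eta
        for p, pid, rec, e in zip(pt, particle_id, reconstructable, eta)
    ]


mask = good_node_mask_sketch(
    pt=[1.2, 0.5, 2.0, 1.5],
    particle_id=[7, 7, 0, 9],
    reconstructable=[True, True, True, True],
    eta=[0.3, 0.3, 1.0, -4.2],
)
print(mask)  # [True, False, False, False]
```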

gnn_tracking.metrics.losses.metric_learning._hinge_loss_components(*, x: torch.Tensor, att_edges: torch.Tensor, rep_edges: torch.Tensor, r_emb_hinge: float, p_attr: float, p_rep: float, n_hits_oi: int, normalization: str) → tuple[torch.Tensor, torch.Tensor]

class gnn_tracking.metrics.losses.metric_learning.GraphConstructionHingeEmbeddingLoss(*, lw_repulsive: float = 1.0, r_emb: float = 1.0, max_num_neighbors: int = 256, pt_thld: float = 0.9, max_eta: float = 4.0, p_attr: float = 1.0, p_rep: float = 1.0, rep_normalization: str = 'n_hits_oi', rep_oi_only: bool = True)

Bases: gnn_tracking.metrics.losses.MultiLossFct, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Loss for graph construction using metric learning.

Parameters:
  • lw_repulsive – Loss weight for the repulsive part of the potential loss

  • r_emb – Radius for edge construction

  • max_num_neighbors – Maximum number of neighbors in radius graph building. See rusty1s/pytorch_cluster

  • pt_thld – Transverse momentum (pt) threshold for particles of interest

  • max_eta – Maximum pseudorapidity (eta) for particles of interest

  • p_attr – Power for the attraction term (default 1: linear loss)

  • p_rep – Power for the repulsion term (default 1: linear loss)

  • rep_normalization – Normalization for the repulsive term. One of "n_rep_edges" (normalize by the number of repulsive edges shorter than r_emb), "n_hits_oi" (normalize by the number of hits of interest; the default), or "n_att_edges" (normalize by the number of attractive edges of interest)

  • rep_oi_only – Only consider repulsion between hits if at least one of the hits is of interest
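The parameters above describe a hinge embedding loss with two terms. The attractive term pulls hits of the same particle together. The repulsive term pushes hits of different particles apart until they are at least r_emb apart. The sketch below makes three assumptions. First, the attractive term is the mean of d^p_attr over attractive edges. Second, the repulsive term sums (r_emb - d)^p_rep over repulsive edges shorter than r_emb. Third, that sum is divided by the count selected through rep_normalization. `hinge_components_sketch` is hypothetical. It loosely mirrors the arguments of _hinge_loss_components but operates on lists of coordinate tuples rather than tensors.

```python
from __future__ import annotations

import math


def hinge_components_sketch(
    *,
    x: list[tuple[float, ...]],
    att_edges: list[tuple[int, int]],
    rep_edges: list[tuple[int, int]],
    r_emb: float = 1.0,
    p_attr: float = 1.0,
    p_rep: float = 1.0,
    n_hits_oi: int = 1,
    normalization: str = "n_hits_oi",
) -> tuple[float, float]:
    """Hypothetical hinge-loss components on plain Python data (not the library code)."""
    eps = 1e-9  # guards against division by zero for empty edge sets

    # Attractive term: mean of d**p_attr over edges between hits of the same particle.
    att_sum = sum(math.dist(x[i], x[j]) ** p_attr for i, j in att_edges)
    v_att = att_sum / (len(att_edges) + eps)

    # Repulsive term: only pairs closer than r_emb are penalized (linearly for p_rep = 1).
    rep_dists = [d for d in (math.dist(x[i], x[j]) for i, j in rep_edges) if d < r_emb]
    rep_sum = sum((r_emb - d) ** p_rep for d in rep_dists)

    # Normalization of the repulsive term; option names follow rep_normalization.
    if normalization == "n_rep_edges":
        norm = len(rep_dists)
    elif normalization == "n_hits_oi":
        norm = n_hits_oi
    elif normalization == "n_att_edges":
        norm = len(att_edges)
    else:
        raise ValueError(f"Unknown normalization: {normalization!r}")
    return v_att, rep_sum / (norm + eps)


x = [(0.0, 0.0), (0.5, 0.0), (0.0, 0.75), (3.0, 0.0)]
v_att, v_rep = hinge_components_sketch(
    x=x, att_edges=[(0, 1)], rep_edges=[(0, 2), (1, 3)], n_hits_oi=2
)
print(v_att, v_rep)  # approximately 0.5 and 0.125
```

In the usage example, the pair (1, 3) is 2.5 apart and therefore contributes nothing to the repulsive term. Only the pair (0, 2), at distance 0.75, is penalized.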

_get_edges(*, x: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, mask: torch.Tensor, particle_id: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]

Returns the edge indices for the graph (a tuple of two edge-index tensors).
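_get_edges produces the edge sets consumed by the loss. The brute-force sketch below is one plausible reading of the parameter descriptions, not the library's code. It proceeds in three steps:

1. A radius graph with cutoff r_emb yields candidate pairs. It is capped at max_num_neighbors per hit and restricted to hits with the same batch index.
2. rep_oi_only keeps only pairs that involve a hit of interest.
3. Only pairs from different particles become repulsive edges.

Treating the attractive edges as the true edges that start at a hit of interest is an assumption. All names are hypothetical. A real implementation would presumably use a GPU radius search such as the one in rusty1s/pytorch_cluster.

```python
from __future__ import annotations

import math


def radius_graph_sketch(
    x: list[tuple[float, ...]], batch: list[int], r: float, max_num_neighbors: int
) -> list[tuple[int, int]]:
    """Brute-force radius graph: directed pairs (i, j), i != j, same batch entry, distance < r."""
    edges = []
    for i in range(len(x)):
        n_found = 0
        for j in range(len(x)):
            if n_found >= max_num_neighbors:
                break  # cap the number of neighbors per source hit
            if i != j and batch[i] == batch[j] and math.dist(x[i], x[j]) < r:
                edges.append((i, j))
                n_found += 1
    return edges


def get_edges_sketch(
    *,
    x: list[tuple[float, ...]],
    batch: list[int],
    true_edges: list[tuple[int, int]],
    mask: list[bool],
    particle_id: list[int],
    r_emb: float = 1.0,
    max_num_neighbors: int = 256,
    rep_oi_only: bool = True,
) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]:
    """Hypothetical construction of (attractive, repulsive) edges."""
    # Assumption: attractive edges are the true edges that start at a hit of interest.
    att_edges = [(i, j) for i, j in true_edges if mask[i]]

    near = radius_graph_sketch(x, batch, r_emb, max_num_neighbors)
    if rep_oi_only:
        # Without neighbor capping the radius graph holds both (i, j) and (j, i),
        # so filtering on the source keeps every pair with at least one hit of interest.
        near = [(i, j) for i, j in near if mask[i]]

    # Repulsive edges: nearby hits that belong to different particles.
    rep_edges = [(i, j) for i, j in near if particle_id[i] != particle_id[j]]
    return att_edges, rep_edges


att, rep = get_edges_sketch(
    x=[(0.0, 0.0), (0.2, 0.0), (0.5, 0.0), (5.0, 0.0)],
    batch=[0, 0, 0, 0],
    true_edges=[(0, 1), (1, 0)],
    mask=[True, True, False, True],
    particle_id=[1, 1, 2, 3],
)
print(att)  # [(0, 1), (1, 0)]
print(rep)  # [(0, 2), (1, 2)]
```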

forward(*, x: torch.Tensor, particle_id: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, pt: torch.Tensor, eta: torch.Tensor, reconstructable: torch.Tensor, **kwargs) → gnn_tracking.metrics.losses.MultiLossFctReturn

gnn_tracking.metrics.losses.metric_learning._old_hinge_loss_components(*, x: torch.Tensor, edge_index: torch.Tensor, particle_id: torch.Tensor, pt: torch.Tensor, r_emb_hinge: float, pt_thld: float, p_attr: float, p_rep: float) → tuple[torch.Tensor, torch.Tensor]

class gnn_tracking.metrics.losses.metric_learning.OldGraphConstructionHingeEmbeddingLoss(*, r_emb=1, max_num_neighbors: int = 256, attr_pt_thld: float = 0.9, p_attr: float = 1, p_rep: float = 1, lw_repulsive: float = 1.0)

Bases: gnn_tracking.metrics.losses.MultiLossFct, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Loss for graph construction using metric learning.

Parameters:
  • r_emb – Radius for edge construction

  • max_num_neighbors – Maximum number of neighbors in radius graph building. See rusty1s/pytorch_cluster

  • p_attr – Power for the attraction term (default 1: linear loss)

  • p_rep – Power for the repulsion term (default 1: linear loss)

_build_graph(x: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, pt: torch.Tensor) → torch.Tensor

forward(*, x: torch.Tensor, particle_id: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, pt: torch.Tensor, **kwargs) → dict[str, torch.Tensor]