gnn_tracking.metrics.losses.metric_learning
Module Contents
Classes
GraphConstructionHingeEmbeddingLoss | Loss for graph construction using metric learning.
Functions
- gnn_tracking.metrics.losses.metric_learning._hinge_loss_components(*, x: torch.Tensor, att_edges: torch.Tensor, rep_edges: torch.Tensor, r_emb_hinge: float, p_attr: float, p_rep: float, n_hits_oi: int) → tuple[torch.Tensor, torch.Tensor]
- class gnn_tracking.metrics.losses.metric_learning.GraphConstructionHingeEmbeddingLoss(*, lw_repulsive: float = 1.0, r_emb: float = 1.0, max_num_neighbors: int = 256, pt_thld: float = 0.9, max_eta: float = 4.0, p_attr: float = 1.0, p_rep: float = 1.0)
Bases: gnn_tracking.metrics.losses.MultiLossFct, pytorch_lightning.core.mixins.HyperparametersMixin
Loss for graph construction using metric learning.
- Parameters:
lw_repulsive – Loss weight for repulsive part of potential loss
r_emb – Radius for edge construction
max_num_neighbors – Maximum number of neighbors in radius graph building. See rusty1s/pytorch_cluster
pt_thld – pT threshold for particles of interest
max_eta – Maximum eta for particles of interest
p_attr – Power for the attraction term (default 1: linear loss)
p_rep – Power for the repulsion term (default 1: linear loss)
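As a rough illustration of how p_attr, p_rep and r_emb could enter a hinge embedding loss, the PyTorch sketch below computes an attractive and a repulsive term from two edge lists. The name hinge_loss_sketch, the normalization by edge count, and the clamping with relu are assumptions made for this example; they are not taken from the library's implementation.

```python
import torch


def hinge_loss_sketch(x, att_edges, rep_edges, r_emb, p_attr=1.0, p_rep=1.0):
    """Illustrative attractive/repulsive hinge terms (hypothetical helper)."""
    eps = 1e-9
    # Embedding-space distance along each edge; edge tensors have shape (2, n_edges).
    d_att = torch.norm(x[att_edges[0]] - x[att_edges[1]], dim=-1)
    d_rep = torch.norm(x[rep_edges[0]] - x[rep_edges[1]], dim=-1)
    # Attraction: pull connected hits together (assumed normalization: edge count).
    v_att = d_att.pow(p_attr).sum() / (att_edges.shape[1] + eps)
    # Repulsion: penalize pairs that sit closer than r_emb (assumed relu clamp).
    v_rep = torch.relu(r_emb - d_rep.pow(p_rep)).sum() / (rep_edges.shape[1] + eps)
    return v_att, v_rep


# Three hits embedded in 2D; hits 0 and 1 belong together, hits 0 and 2 do not.
x = torch.tensor([[0.0, 0.0], [0.3, 0.4], [0.0, 0.5]])
att_edges = torch.tensor([[0], [1]])
rep_edges = torch.tensor([[0], [2]])
v_att, v_rep = hinge_loss_sketch(x, att_edges, rep_edges, r_emb=1.0)
```

With p_attr = p_rep = 1 both terms are linear in the distance, which matches the "linear loss" default described above. One plausible way to combine them is a weighted sum in which lw_repulsive scales the repulsive term.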
- _get_edges(*, x: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, mask: torch.Tensor, particle_id: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]
Returns edge index for graph
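To make the idea of radius-based edge construction concrete, here is a minimal pure-PyTorch sketch for a single graph. It is a simplified stand-in: it ignores the batch vector and the max_num_neighbors cap, and the helper names radius_edges_sketch and split_edges_sketch, as well as the rule for splitting edges by particle ID, are assumptions for illustration rather than the behavior of _get_edges.

```python
import torch


def radius_edges_sketch(x, r):
    """All directed pairs (i, j), i != j, whose embedding distance is below r."""
    dist = torch.cdist(x, x)  # (n, n) pairwise Euclidean distances
    not_self = ~torch.eye(x.shape[0], dtype=torch.bool)
    return ((dist < r) & not_self).nonzero().t()  # shape (2, n_edges)


def split_edges_sketch(edges, particle_id):
    """Split edges into same-particle and different-particle pairs (assumed rule)."""
    same = particle_id[edges[0]] == particle_id[edges[1]]
    return edges[:, same], edges[:, ~same]


# Four hits; the last one is far away from the others.
x = torch.tensor([[0.0, 0.0], [0.5, 0.0], [0.0, 0.5], [5.0, 0.0]])
particle_id = torch.tensor([1, 1, 2, 2])
edges = radius_edges_sketch(x, r=1.0)
att_edges, rep_edges = split_edges_sketch(edges, particle_id)
```

The dense distance matrix scales quadratically with the number of hits, so this only suits small examples; the radius graph routine from rusty1s/pytorch_cluster mentioned in the parameter list is the kind of tool meant for realistic sizes.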
- forward(*, x: torch.Tensor, particle_id: torch.Tensor, batch: torch.Tensor, true_edge_index: torch.Tensor, pt: torch.Tensor, eta: torch.Tensor, reconstructable: torch.Tensor, **kwargs) → gnn_tracking.metrics.losses.MultiLossFctReturn
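The pt, eta and reconstructable inputs of forward, together with pt_thld and max_eta, suggest a selection of "particles of interest". The sketch below shows one way such a mask could be built. The helper name poi_mask_sketch, the exact cuts, and the convention that particle_id > 0 marks non-noise hits are assumptions for this example and are not confirmed by the documentation above.

```python
import torch


def poi_mask_sketch(pt, eta, reconstructable, particle_id, pt_thld=0.9, max_eta=4.0):
    """Boolean mask of hits from particles of interest (hypothetical helper)."""
    return (
        (pt > pt_thld)  # transverse momentum above threshold
        & (eta.abs() < max_eta)  # within the pseudorapidity acceptance
        & (reconstructable > 0)  # flagged as reconstructable
        & (particle_id > 0)  # assumed convention: IDs <= 0 are noise
    )


pt = torch.tensor([0.5, 1.0, 2.0, 1.5])
eta = torch.tensor([0.0, 1.0, 5.0, -2.0])
reconstructable = torch.tensor([1, 1, 1, 0])
particle_id = torch.tensor([1, 2, 3, 4])
mask = poi_mask_sketch(pt, eta, reconstructable, particle_id)
n_hits_oi = int(mask.sum())  # number of hits of interest
```

In this toy input only the second hit passes all cuts: the first fails the pT threshold, the third lies outside the eta range, and the fourth is not reconstructable.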