https://omoindrot.github.io/triplet-loss
- The goal of the triplet loss is to make sure that:
- Two examples with the same label have their embeddings close together in the embedding space
- Two examples with different labels have their embeddings far apart
- To formalise this requirement, the loss will be defined over triplets of embeddings:
- an anchor
- a positive of the same class as the anchor
- a negative of a different class
- For some distance d on the embedding space, the loss of a triplet (a, p, n) is:
- L = max(d(a, p) - d(a, n) + margin, 0)
- where margin is a value we choose that says: "the negative should be farther away than the positive by this margin"
- This loss pushes:
- d(anchor, pos) towards 0
- d(anchor, neg) to be greater than d(anchor, pos) + margin
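- a minimal sketch of this loss for a single triplet, assuming Euclidean distance for d (the function name and margin value are just illustrative):

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=1.0):
    # d(a, p): the loss pushes this towards 0
    d_ap = np.linalg.norm(anchor - positive)
    # d(a, n): the loss pushes this above d(a, p) + margin
    d_an = np.linalg.norm(anchor - negative)
    return max(d_ap - d_an + margin, 0.0)
```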
- hard negatives are negative examples that are closer to the anchor than the positive is, i.e. d(a, n) < d(a, p)
- online mining: we have computed a batch of B embeddings from a batch of B inputs. Now we want to generate triplets from these B embeddings.
- batch all: select all the valid triplets, and average the loss on the hard and semi-hard triplets.
- a crucial point here is to not take into account the easy triplets (those with loss 0), as averaging over them would make the overall loss very small
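- a NumPy sketch of batch all, assuming labels is a 1-D integer array of length B (the linked post implements this with TensorFlow masks; the function name and margin here are illustrative):

```python
import numpy as np

def batch_all_triplet_loss(embeddings, labels, margin=1.0):
    B = len(labels)
    # Pairwise Euclidean distance matrix, shape (B, B)
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)

    # loss[i, j, k] = d(i, j) - d(i, k) + margin for anchor i, positive j, negative k
    loss = dist[:, :, None] - dist[:, None, :] + margin

    # Valid triplet: i, j, k distinct, labels[i] == labels[j], labels[i] != labels[k]
    same = labels[:, None] == labels[None, :]
    i, j, k = np.ogrid[:B, :B, :B]
    distinct = (i != j) & (i != k) & (j != k)
    valid = distinct & same[:, :, None] & ~same[:, None, :]

    loss = np.maximum(loss * valid, 0.0)

    # Average only over the non-easy triplets (those with strictly positive loss)
    num_positive = np.count_nonzero(loss > 1e-16)
    return loss.sum() / max(num_positive, 1)
```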
- batch hard: for each anchor, select the hardest positive (biggest distance d(a, p)) and the hardest negative among the batch
- works best according to "In Defense of the Triplet Loss for Person Re-Identification" (https://arxiv.org/abs/1703.07737)
- implementation:
- to get the hardest positive:
- generate the pairwise distance matrix between all examples in the batch
- zero out the entries whose label differs from the anchor's (those pairs are not valid positives), then take the maximum distance over each row
- to get the hardest negative (tricky, since we need the minimum distance over each row, and the diagonal is 0)
- solution: set the diagonal, and more generally every same-label entry, to infinity before taking the row minimum
- now calculate: triplet_loss = max(hardest_positive_dist - hardest_negative_dist + margin, 0.0)
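- putting the steps above together, a NumPy sketch of batch hard (the linked post does this in TensorFlow; the function name and margin are illustrative, and each anchor is assumed to have at least one positive and one negative in the batch):

```python
import numpy as np

def batch_hard_triplet_loss(embeddings, labels, margin=1.0):
    B = len(labels)
    # Pairwise Euclidean distance matrix, shape (B, B)
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)

    same = labels[:, None] == labels[None, :]

    # Hardest positive: row maximum over valid positives only
    # (same label, distinct index); other entries are zeroed out,
    # which is safe for a max since distances are non-negative.
    pos_mask = same & ~np.eye(B, dtype=bool)
    hardest_positive_dist = (dist * pos_mask).max(axis=1)

    # Hardest negative: row minimum, with every same-label entry
    # (including the 0 diagonal) set to infinity so it cannot win.
    neg_dist = np.where(same, np.inf, dist)
    hardest_negative_dist = neg_dist.min(axis=1)

    return np.maximum(
        hardest_positive_dist - hardest_negative_dist + margin, 0.0
    ).mean()
```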