6
votes

Comment calculer les poids déséquilibrés pour BCEWithLogitsLoss dans Pytorch

J'essaie de résoudre un problème multi-étiquettes avec 270 étiquettes et j'ai converti les étiquettes cibles en une forme encodée à chaud. J'utilise BCEWithLogitsLoss() . Puisque les données d'entraînement sont déséquilibrées, pos_weight argument pos_weight mais je suis un peu confus.

pos_weight (Tensor, facultatif) - un poids d'exemples positifs. Doit être un vecteur de longueur égale au nombre de classes.

Dois-je donner le nombre total de valeurs positives de chaque étiquette en tant que tenseur ou elles signifient autre chose par poids?


1 commentaires

Vous pouvez consulter la discussion ici: discuss.pytorch.org/t / ...


3 Réponses :


2
votes

Solution PyTorch

Eh bien, en fait, j'ai parcouru la documentation et vous pouvez simplement utiliser pos_weight .

Cet argument donne du poids à l'échantillon positif pour chaque classe, donc si vous avez 270 classes, vous devez passer torch.Tensor avec la forme (270,) définissant le poids pour chaque classe.

Voici un extrait de code légèrement modifié de la documentation :

weights = torch.zeros_like(dataset[0])
for element in dataset:
    weights += element

weights = 1 / (weights / torch.min(weights))

Solution personnalisée

En ce qui concerne la pondération, il n'y a pas de solution intégrée, mais vous pouvez en coder une vous-même très facilement:

import torch

class WeightedMultilabel(torch.nn.Module):
    def __init__(self, weights: torch.Tensor):
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.weights = weights.unsqueeze()

    def forward(outputs, targets):
        return self.loss(outputs, targets) * self.weights

Tensor doit être de la même longueur que le nombre de classes dans votre classification multi-étiquettes (270), chacune donnant un poids pour votre exemple spécifique.

Calcul des poids

Vous ajoutez simplement des étiquettes de chaque échantillon de votre ensemble de données, divisez par la valeur minimale et inversez à la fin.

Sorte d'extrait:

# 270 classes, batch size = 64    
target = torch.ones([64, 270], dtype=torch.float32)  
# Logits outputted from your network, no activation
output = torch.full([64, 270], 0.9)
# Weights, each being equal to one. You can input your own here.
pos_weight = torch.ones([270])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target)  # -log(sigmoid(0.9))

L'utilisation de cette classe d'approche qui se produit le moins donnera une perte normale, tandis que d'autres auront des poids inférieurs à 1 .

Cela peut cependant provoquer une certaine instabilité pendant l'entraînement, donc vous voudrez peut-être expérimenter un peu ces valeurs (peut-être log transformation log au lieu de linéaire?)

Autre approche

Vous pouvez penser au suréchantillonnage / sous-échantillonnage (bien que cette opération soit compliquée car vous ajouteriez / supprimeriez également d'autres classes, donc une heuristique avancée serait nécessaire je pense).


0 commentaires

4
votes

La documentation PyTorch pour BCEWithLogitsLoss recommande que pos_weight soit un rapport entre les nombres négatifs et positifs pour chaque classe.

Donc, si len(dataset) vaut 1000, l'élément 0 de votre encodage multihot a 100 comptes positifs, alors l'élément 0 du pos_weights_vector doit être 900/100 = 9 . Cela signifie que la perte croisée binaire se comportera comme si l'ensemble de données contenait 900 exemples positifs au lieu de 100.

Voici ma mise en œuvre:

  def calculate_pos_weights(class_counts):
    pos_weights = np.ones_like(class_counts)
    neg_counts = [len(data)-pos_count for pos_count in class_counts]
    for cdx, pos_count, neg_count in enumerate(zip(class_counts,  neg_counts)):
      pos_weights[cdx] = neg_count / (pos_count + 1e-5)

    return torch.as_tensor(pos_weights, dtype=torch.float)

class_counts est juste une somme par colonne des échantillons positifs. Je l'ai posté sur le forum PyTorch et l'un des développeurs de PyTorch lui a donné sa bénédiction.


2 commentaires

Pouvez-vous clarifier les poids dans le retour? Est-ce pos_weights ou tout autre poids. S'il s'agit de quelque chose de différent, pourriez-vous fournir plus de détails à ce sujet?


@PaladiN désolé, mes variables n'étaient pas claires. J'ai clarifié les noms des variables dans ma réponse.



0
votes

Juste pour fournir une révision rapide sur la réponse de @ crypdick, cette implémentation de la fonction a fonctionné pour moi:

def calculate_pos_weights(class_counts,data):
    pos_weights = np.ones_like(class_counts)
    neg_counts = [len(data)-pos_count for pos_count in class_counts]
    for cdx, (pos_count, neg_count) in enumerate(zip(class_counts,  neg_counts)):
        pos_weights[cdx] = neg_count / (pos_count + 1e-5)

    return torch.as_tensor(pos_weights, dtype=torch.float)

Où les data sont l'ensemble de données auquel vous essayez d'appliquer des pondérations.


0 commentaires