J'ai un ensemble de données avec 3 classes avec les éléments suivants:
- Classe 1: 900 éléments
- Classe 2: 15 000 éléments
- Classe 3: 800 éléments
Je dois prédire les classes 1 et 3, qui signalent des écarts importants par rapport à la norme. La classe 2 est le cas «normal» par défaut dont je me fiche.
Quel type de fonction de perte devrais-je utiliser ici? Je pensais utiliser CrossEntropyLoss, mais comme il y a un déséquilibre de classe, cela devrait être pondéré, je suppose? Comment cela fonctionne-t-il dans la pratique? Comme ça (en utilisant PyTorch)?
summed = 900 + 15000 + 800
weight = torch.tensor([900, 15000, 800]) / summed
crit = nn.CrossEntropyLoss(weight=weight)
Ou faut-il inverser le poids? soit 1 / poids?
Est-ce la bonne approche pour commencer ou y a-t-il d'autres / meilleures méthodes que je pourrais utiliser?
Merci