AVERTISSEMENT: tensorflow: les modes sample_weight ont été contraints de… à ['…']


47

Formation d'un classificateur d'images en utilisant .fit_generator()ou .fit()et en passant un dictionnaire à class_weight=comme argument.

Je n'ai jamais eu d'erreurs dans TF1.x mais en 2.1 j'obtiens la sortie suivante lors du démarrage de la formation:

WARNING:tensorflow:sample_weight modes were coerced from
  ...
    to  
  ['...']

Que signifie contraindre quelque chose de ...à ['...']?

La source de cet avertissement sur tensorflowle dépôt est ici , les commentaires placés sont:

Essayez de contraindre sample_weight_modes à la structure cible. Cela dépend implicitement du fait que le modèle aplatit les sorties pour sa représentation interne.


7
C'est drôle de voir une question aussi récente comme le seul résultat de recherche pour mes propres avertissements.
jmkjaer

1
@jorijnsmit pouvez-vous fournir le code pour reproduire le problème / avertissement?
thushv89

2
Passer à TF2 avec %tensorflow_version 2.xsuffit pour faire apparaître cet avertissement: colab.research.google.com/gist/jorijnsmit/…
jorijnsmit

1
@jorijnsmit, Non, je reçois le même avertissement mais j'ai en fait installé TF2.1 comme pip install tensorflow(dans l'environnement pyenv / virtualenv)
lurix66

1
Oui en effet @ lurix66, le code qui génère cette erreur est introduit dans 2.1.0rc0.
jorijnsmit

Réponses:


11

Cela ressemble à un faux message. Je reçois le même message d'avertissement après la mise à niveau vers TensorFlow 2.1, mais je n'utilise aucun poids de classe ou exemple de poids. J'utilise un générateur qui renvoie un tuple comme celui-ci:

return inputs, targets

Et maintenant, je viens de le changer comme suit pour faire disparaître l'avertissement:

return inputs, targets, [None]

Je ne sais pas si cela est pertinent, mais mon modèle utilise 3 entrées, donc ma inputsvariable est en fait une liste de 3 tableaux numpy. targetsest juste un tableau numpy unique.

En tout cas, ce n'est qu'un avertissement. La formation fonctionne bien dans les deux cas.

Modifier pour TensorFlow 2.2:

Ce bug semble avoir été corrigé dans TensorFlow 2.2, ce qui est génial. Cependant, le correctif ci-dessus échouera dans TF 2.2, car il essaiera d'obtenir la forme des poids d'échantillonnage, qui échouera évidemment avec AttributeError: 'NoneType' object has no attribute 'shape'. Annulez donc le correctif ci-dessus lors de la mise à niveau vers 2.2.


Cela fonctionne aussi pour moi.
Robert Lugg

14

Je crois que c'est un bug avec tensorflow qui se produira lorsque vous appelez model.compile()avec le paramètre par défaut sample_weight_mode=None, puis appelez model.fit()avec sample_weightou spécifié class_weight.

Depuis les dépôts tensorflow:

  • fit() appelle finalement _process_training_inputs()
  • _process_training_inputs() ensembles sample_weight_modes = [None] basés sur model.sample_weight_mode = Nonepuis crée un DataAdapteravecsample_weight_modes = [None]
  • les DataAdapterappels broadcast_sample_weight_modes()avec sample_weight_modes = [None]lors de l' initialisation
  • broadcast_sample_weight_modes() semble attendre sample_weight_modes = None mais reçoit[None]
  • il affirme qu'il [None]s'agit d'une structure différente de sample_weight/ class_weight, l'écrase Noneen s'adaptant à la structure de sample_weight/ class_weightet émet un avertissement

Attention à part cela n'a aucun effet fit()comme sample_weight_modesdans le DataAdapterest remis à None.

Notez que la documentation tensorflow indique qu'il sample_weightdoit s'agir d'un tableau numpy. Si vous appelez fit()avec à la sample_weight.tolist()place, vous n'obtiendrez pas d'avertissement mais vous serez sample_weightécrasé en silence Nonequand il _process_numpy_inputs()est appelé en prétraitement et reçoit une entrée de longueur supérieure à un.


1
Une explication très approfondie, merci. La seule chose que je ne comprends pas, c'est que l'avertissement décrit d' ...être contraint [...], alors que dans votre cas, il [None]est contraint de None...
jorijnsmit

4

J'ai pris votre Gist et installé Tensorflow 2.0, au lieu de TFA et cela a fonctionné sans un tel avertissement.

Voici l' essentiel du code complet. Le code d'installation de Tensorflow est indiqué ci-dessous:

!pip install tensorflow==2.0

Une capture d'écran de l'exécution réussie est présentée ci-dessous:

entrez la description de l'image ici

Mise à jour: ce bug est corrigé dansTensorflow Version 2.2.


5
Merci pour votre réponse. Vous avez raison, le message d'avertissement n'est pas introduit avant la version 2.1.0rc0. Cependant, je crains que ma question demeure: "Que signifie contraindre quelque chose de ...à ['...']?"
jorijnsmit

3
J'ai remarqué que certaines choses probablement involontaires se produisent lorsque sample_weight_mode=Noneet target_structuresont de type dict, sample_weight_modesest alors [None]et l'exception broadcast_sample_weight_modesest interceptée en raison de la dict. Cela peut-il être considéré comme un bug?
Franz Knülle

2
Nan. La question continue de recueillir des vues et des votes positifs, mais aucune réponse.
jorijnsmit

1
@gkennos: Si vous pensez qu'il s'agit d'un bogue, pouvez-vous déposer un bogue dans le référentiel Github Tensorflow.
Prise en charge de Tensorflow le

1
Il est sans aucun doute un bug, mais il est maintenant fixé à tensorflow 2.2
JLH

2

au lieu de fournir un dictionnaire

weights = {'0': 42.0, '1': 1.0}

j'ai essayé une liste

weights = [42.0, 1.0]

et l'avertissement a disparu.


Merci mec! J'essayais (sans succès) avec des dictionnaires. En utilisant la liste, l'erreur est corrigée!
Victor Mondejar-Guerra

Bien que cela supprime l'erreur, pour moi, cela rompt la pondération pour chaque classe, ce qui produit de moins bons résultats. Je vérifierais la cohérence avant de passer à une liste.
CanofDrink
En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.