Est-il possible de modifier la métrique utilisée par le rappel Early Stop à Keras?


13

Lors de l'utilisation du rappel Early Stop dans Keras, la formation s'arrête lorsqu'une mesure (généralement une perte de validation) n'augmente pas. Existe-t-il un moyen d'utiliser une autre métrique (comme la précision, le rappel, la mesure f) au lieu d'une perte de validation? Tous les exemples que j'ai vus jusqu'à présent sont similaires à celui-ci: callbacks.EarlyStopping (monitor = 'val_loss', patience = 5, verbose = 0, mode = 'auto')

Réponses:


11

Vous pouvez utiliser n'importe quelle fonction métrique que vous avez spécifiée lors de la compilation du modèle.

Supposons que vous ayez la fonction métrique suivante:

def my_metric(y_true, y_pred):
     return some_metric_computation(y_true, y_pred)

La seule exigence de cette fonction est qu'elle accepte le vrai y et le y prévu.

Lorsque vous compilez le modèle, vous spécifiez cette métrique, de la même manière que vous spécifiez la génération de métriques comme la «précision»:

model.compile(metrics=['accuracy', my_metric], ...)

Notez que nous utilisons le nom de la fonction my_metric sans '' (contrairement à la construction de 'precision').

Ensuite, si vous définissez votre EarlyStopping, utilisez simplement le nom de la fonction (cette fois avec ''):

EarlyStopping(monitor='my_metric', mode='min')

Assurez-vous de spécifier le mode (min si inférieur est meilleur, max si supérieur est meilleur).

Vous pouvez l'utiliser comme n'importe quelle métrique intégrée. Cela fonctionne probablement aussi avec d'autres rappels comme ModelCheckpoint (mais je n'ai pas testé cela). En interne, Keras ajoute simplement la nouvelle métrique à la liste des métriques disponibles pour ce modèle en utilisant le nom de la fonction.

Si vous spécifiez des données à valider dans votre model.fit (...), vous pouvez également les utiliser pour EarlyStopping en utilisant 'val_my_metric'.


3

Bien sûr, créez le vôtre!

class EarlyStopByF1(keras.callbacks.Callback):
    def __init__(self, value = 0, verbose = 0):
        super(keras.callbacks.Callback, self).__init__()
        self.value = value
        self.verbose = verbose


    def on_epoch_end(self, epoch, logs={}):
         predict = np.asarray(self.model.predict(self.validation_data[0]))
         target = self.validation_data[1]
         score = f1_score(target, prediction)
         if score > self.value:
            if self.verbose >0:
                print("Epoch %05d: early stopping Threshold" % epoch)
            self.model.stop_training = True


callbacks = [EarlyStopByF1(value = .90, verbose =1)]
model.fit(X, y, batch_size = 32, nb_epoch=nb_epoch, verbose = 1, 
validation_data(X_val,y_val), callbacks=callbacks)

Je n'ai pas testé cela, mais cela devrait être la saveur générale de la façon dont vous procédez. Si cela ne fonctionne pas, faites le moi savoir et je réessayerai le week-end. Je suppose également que votre propre score f1 est déjà implémenté. Si ce n'est pas seulement importer pour sklearn.


+1 Fonctionne toujours au 2/11/2020 en utilisant les derniers Keras et Python 3.7
Austin
En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.