0
votes

Un callback pour vérifier la saturation de val_acc

Habituellement, nous pouvons définir un rappel pour qu'un modèle arrête l'époque si la précision atteint un certain niveau.

Je travaille sur l'ajustement des paramètres. Le val_acc est hautement instable comme le montre l'image acc_graphs .

def LSTM_model(X_train, y_train, X_test, y_test, num_classes, batch_size=68, units=128, learning_rate=0.005, epochs=20,
               dropout=0.2, recurrent_dropout=0.2):
    class myCallback(tf.keras.callbacks.Callback):

        def on_epoch_end(self, epoch, logs={}):
            if (logs.get('acc') > 0.90):
                print("\nReached 90% accuracy so cancelling training!")
                self.model.stop_training = True

    callbacks = myCallback()

Comme le montrent les graphiques, le val_acc (orange) fluctue dans une plage et ne monte plus vraiment.

Existe-t-il un moyen d'arrêter automatiquement la formation une fois que la tendance générale du val_acc cesse d'augmenter?


0 commentaires

3 Réponses :


3
votes

Vous pouvez y parvenir avec un callback comme celui-ci

from collections import deque
import numpy as np

import tensorflow as tf 
from tensorflow import keras 
import tensorflow.keras.backend as K
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input,Dense

x = np.linspace(0,10,1000)
np.random.shuffle(x)
y = np.sin(x) + x

x_train,x_val,y_train,y_val = train_test_split(x,y,test_size=0.3)

input_x = Input(shape=(1,))
y = Dense(10,activation='relu')(input_x)
y = Dense(10,activation='relu')(y)
y = Dense(1,activation='relu')(y)
model = Model(inputs=input_x,outputs=y)

adamopt = tf.keras.optimizers.Adam(lr=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-8)

class terminate_on_plateau(keras.callbacks.Callback):
    
    def __init__(self):
        self.patience = 10
        self.val_loss = deque([],self.patience)
        self.std_threshold = 1e-2
        
    def on_epoch_end(self,epoch,logs=None):
        val_loss,val_mae = model.evaluate(x_val,y_val)
        self.val_loss.append(val_loss)
        if len(self.val_loss) >= self.patience:
            std = np.std(self.val_loss)
            if std < self.std_threshold:
                print('\n\n EarlyStopping on std invoked! \n\n')
                # clear the deque
                self.val_loss = deque([],self.patience)
                model.stop_training = True
    
model.compile(loss='mse',optimizer=adamopt,metrics=['mae'])
history = model.fit(x_train,y_train,
                    batch_size=8,
                    epochs=100,
                    validation_data=(x_val, y_val),
                    verbose=1,
                    callbacks=[terminate_on_plateau()])

Comme vous pouvez le voir, dans terminate_on_plateau , val_loss des époques est stocké dans un deque de longueur maximale self.patience . Une fois que la longueur du deque atteint self.patience , l'écart-type de val_loss sera calculé pour chaque nouvelle époque , et le processus de formation sera terminé (le deque de val_loss sera également effacé), si le std calculé est plus petit que a seuil.

Vous trouverez ci-dessous un script simple qui vous montre comment utiliser ce

class terminate_on_plateau(keras.callbacks.Callback):
    
    def __init__(self):
        self.patience = 10
        self.val_loss = deque([],self.patience)
        self.std_threshold = 1e-2
        
    def on_epoch_end(self,epoch,logs=None):
        val_loss,val_mae = model.evaluate(x_val,y_val)
        self.val_loss.append(val_loss)
        if len(self.val_loss) >= self.patience:
            std = np.std(self.val_loss)
            if std < self.std_threshold:
                print('\n\n EarlyStopping on std invoked! \n\n')
                # clear the deque
                self.val_loss = deque([],self.patience)
                model.stop_training = True


4 commentaires

Puis-je savoir pourquoi avez-vous choisi de calculer val_loss au lieu de val_acc dans ce cas puisque les graphiques ont tracé la précision? Plus, veuillez me corriger si je me trompe. Puis-je comprendre la «patience» comme la taille de la fenêtre utilisée pour évaluer le std dans les époques? par exemple, std est calculé toutes les 10 époques si la patience est de 10.


@Leo Bien sûr, vous pouvez choisir val_acc , ce faisant, il vous suffit de changer ce model.evaluate(x_val,y_val) en quelque chose qui calcule val_acc . Je choisis val_loss juste pour la simplicité. La patience est le nombre de val_loss (ou val_acc ) qui est utilisé pour calculer l'écart type, je vous suggère de lire un peu sur deque , puisque je l'utilise pour stocker l'historique val_loss dans mon code, si la patience est de 10, std le fera être calculé pour chaque époque une fois que le deque atteint sa longueur maximale, soit 10.


Oui, cela a du sens. Bien que je rencontrais un problème, la deuxième fois que j'entraînais le modèle, il activait immédiatement les rappels, je suppose que c'était parce que les données du rappel de la dernière formation n'avaient pas été supprimées. Pourriez-vous modifier vos codes pour résoudre ce problème?


@Leo Bien sûr, la solution est simple, il suffit d'ajouter cette ligne self.val_loss = deque([],self.patience) juste avant la fin du processus de formation. Consultez mon message mis à jour.



0
votes

Le code ci-dessous est pour un rappel personnalisé qui arrêtera l'entraînement lorsque la quantité surveillée ne s'améliorera pas après un certain nombre d'époques. Réglez le paramètre acc_or_loss sur 'loss' afin de surveiller la perte de validation. Réglez-le sur «acc» pour surveiller la précision de la validation. Je recommande de NE PAS surveiller l'exactitude de la validation, car elle peut basculer énormément, en particulier aux premières époques. J'ai mis des déclarations imprimées pour que vous puissiez voir ce qui se passe pendant la formation. Vous pouvez bien sûr les supprimer plus tard. Si vous surveillez la perte de validation, le rappel interrompt la formation si, pendant un certain nombre d'époques, la perte de validation a dépassé la perte la plus faible constatée aux époques précédentes. Si vous surveillez la précision de la validation, le rappel interrompt la formation si, pendant un certain nombre d'époques, la précision de la validation est restée inférieure à la précision de validation la plus élevée enregistrée aux époques précédentes.

class halt(keras.callbacks.Callback):
    def __init__(self, patience, acc_or_loss):
        self.acc_or_loss=acc_or_loss
        super(halt, self).__init__()
        self.patience=patience # specifies how many epochs without improvement before learning rate is adjusted
        self.lowest_loss=np.inf 
        self.highest_acc=0
        self.count=0
        print ('initializing values ', 'count= ', self.count, '  lowest_loss= ', self.lowest_loss, 'highest acc= ', self.highest_acc)
    def on_epoch_end(self, epoch, logs=None):
        v_loss=logs.get('val_loss')  # get the validation loss for this epoch
        v_acc=logs.get('val_accuracy')
        if self.acc_or_loss=='loss':
            print (' for epoch ', epoch +1, '  v_loss= ', v_loss, ' lowest_loss= ', self.lowest_loss,  'count= ', self.count)
            if v_loss< self.lowest_loss:
                self.lowest_loss=v_loss
                self.count=0
            else:
                self.count=self.count +1
                if self.count>=self.patience:
                    print('There have been ', self.patience, ' epochs with no reduction of validation loss below the lowest loss')
                    print ('Terminating training')
                    self.model.stop_training = True
        else:
            print (' for epoch ', epoch +1, '  v_acc= ', v_acc, ' highest accuracy= ', self.highest_acc,  'count= ', self.count)
            if v_acc>self.highest_acc:
                self.count=0
                self.highest_acc=v_acc
            else:
                self.count=self.count +1
                if self.count>=self.patience:
                    print('There have been ', self.patience, ' epochs with noincrease in validation accuracy')
                    print ('Terminating training')
                    self.model.stop_training = True

patience= 2 #  specify the patience value
acc_or_loss='loss' # specify to monitor validation loss or validation accuracy
callbacks=[halt(patience=patience, acc_or_loss=acc_or_loss)]
# in model.fit include callbacks=callbacks


0 commentaires

0
votes

Ou vous pouvez simplement utiliser l'API Keras dans tensorflow: tf.keras.callbacks.EarlyStopping

Compte tenu de votre question initiale, je ne sais pas pourquoi vous auriez besoin de callbacks personnalisés

Voici un exemple d'application:

history = model.fit([trainX,trainX,trainX],
                    np.array(trainLabels),
                    validation_data = ([testX, testX, testX], np.array(testLabels)),
                    epochs=EPOCH,
                    batch_size=BATCH_SIZE,
                    steps_per_epoch = None,
                    callbacks=[tf.keras.callbacks.EarlyStopping(
                        monitor="val_acc",
                        patience=5,
                        mode="min",
                        restore_best_weights = True)])


0 commentaires