4
votes

Exception avec rappel dans Keras - Tensorflow 2.0 - Python

Le code suivant exécute un modèle Sequential Keras, assez simple, sur les données MNIST qui sont empaquetées avec Keras.

En exécutant le morceau de code suivant, j'obtiens une exception.

Le code est facilement reproductible.

Epoch 1/10
59296/60000 [============================>.] - ETA: 0s - loss: 0.2005 - accuracy: 0.9400

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-26-f5e673b24d24> in <module>()
     23               metrics=['accuracy'])
     24 
---> 25 model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])

C:\Program Files (x86)\Microsoft Visual Studio\Shared\Anaconda3_64\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    871           validation_steps=validation_steps,
    872           validation_freq=validation_freq,
--> 873           steps_name='steps_per_epoch')
    874 
    875   def evaluate(self,

C:\Program Files (x86)\Microsoft Visual Studio\Shared\Anaconda3_64\lib\site-packages\tensorflow\python\keras\engine\training_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_freq, mode, validation_in_fit, prepared_feed_values_from_dataset, steps_name, **kwargs)
    406     if mode == ModeKeys.TRAIN:
    407       # Epochs only apply to `fit`.
--> 408       callbacks.on_epoch_end(epoch, epoch_logs)
    409     progbar.on_epoch_end(epoch, epoch_logs)
    410 

C:\Program Files (x86)\Microsoft Visual Studio\Shared\Anaconda3_64\lib\site-packages\tensorflow\python\keras\callbacks.py in on_epoch_end(self, epoch, logs)
    288     logs = logs or {}
    289     for callback in self.callbacks:
--> 290       callback.on_epoch_end(epoch, logs)
    291 
    292   def on_train_batch_begin(self, batch, logs=None):

<ipython-input-26-f5e673b24d24> in on_epoch_end(self, epoch, logs)
      3 class myCallback(tf.keras.callbacks.Callback):
      4       def on_epoch_end(self, epoch, logs={}):
----> 5         if(logs.get('acc')>0.99):
      6           print("\nReached 99% accuracy so cancelling training!")
      7           self.model.stop_training = True

TypeError: '>' not supported between instances of 'NoneType' and 'float'

L'exception est:

import tensorflow as tf

class myCallback(tf.keras.callbacks.Callback):
      def on_epoch_end(self, epoch, logs={}):
        if(logs.get('acc')>0.99):
          print("\nReached 99% accuracy so cancelling training!")
          self.model.stop_training = True

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

callbacks = myCallback()

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])


3 commentaires

Avez-vous essayé if (logs.get ('precision')> 0.99): ?


Lors de la création d'un rappel, si nous avons besoin d'un seuil de précision pour l'entraînement, les versions précédentes de TF ont logs.get ('acc') mais dans cette version, nous devons utiliser logs.get ('precision') pour que cela fonctionne. Il n'y avait aucune documentation concernant ce changement


L'erreur est due au fait que logs.get ('acc') doit correspondre à la valeur de la métrique dans model.compile (optimizer = 'adam', loss = 'sparse_categorical_crossentropy', metrics = ['precision'] )


9 Réponses :


0
votes

Je pense que cela peut venir de la façon dont vous appelez votre fonction:

Si votre fonction est

model.fit(x_train, y_train, epochs=10, callbacks=[myCallback()])

Elle devrait être appelée comme ça:

XXX


0 commentaires

-2
votes

Le problème est logs.get ('acc')> 0.99 . De votre côté, logs.get ('acc') est Aucun pour une raison quelconque.

Il suffit d'exécuter:

None> 0.99 et vous obtiendrez la même erreur. Vous avez probablement migré votre code de Python 2 où cela fonctionnerait réellement :).

Vous pouvez simplement modifier cela avec

if(logs.get('acc') is None): # in this case you cannot compare...

Ou vous pouvez utiliser essayez : ... sauf : blocs.

BTW, le même code fonctionne très bien de mon côté.


0 commentaires

7
votes

Dans la fonction model.compile, vous avez défini metrics = ['precision']. Vous devez utiliser «exactitude» dans logs.get, c'est-à-dire logs.get («exactitude»).


0 commentaires

1
votes

C'est juste qu'avec la mise à jour de tensorflow vers la version 2.x, la balise du dictionnaire 'acc' a été changée en 'precision' donc remplacer la ligne 5 comme suit devrait faire l'affaire!

if (logs.get ('precision')> 0.99):


0 commentaires

1
votes

Changez simplement logs.get ('precision') -> logs.get ('acc'). Cela devrait fonctionner correctement!


0 commentaires

-2
votes

Pour une raison quelconque, j'ai fait ['acc'] dans la classe de rappel, avec ['precision'] dans les métriques et cela a fonctionné.

entrez la description de l'image ici


0 commentaires

1
votes

J'ai eu le même problème. Je l'ai changé en "acc" et cela a fonctionné comme un charme. J'ai apporté les modifications suivantes.

class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs={}):
        if(logs.get("acc") >= 0.99):
            print("Reached 99% accuracy so cancelling training!")
            self.model.stop_training = True

Et dans le rappel,

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])


0 commentaires

0
votes

Dans le notebook Jupyter, j'ai dû utiliser "acc", mais dans google Colab "precision" à la place. Je suppose que cela dépend de la version de tensorflow installée.


0 commentaires

0
votes

Vous utilisez probablement tensorflow 1., vous pouvez donc essayer: if (logs.get ('acc')> 0.998) et metrics = ['acc']


0 commentaires