10
votes

Comment utiliser Model.fit qui prend en charge les générateurs (après la dépréciation de fit_generator)

J'ai cet avertissement de dépréciation lors de l'utilisation de Model.fit_generator dans tensorflow:

WARNING:tensorflow: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
Please use Model.fit, which supports generators.

Comment puis-je utiliser Model.fit au lieu de Model.fit_generator ?


0 commentaires

3 Réponses :


13
votes

Model.fit_generator est obsolète à partir de tensorflow 2.1.0 qui se trouve actuellement dans rc1 . Vous pouvez trouver la documentation pour tf-2.1.0-rc1 ici: https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras/Model#fit

Comme vous pouvez le voir, le premier argument de Model.fit peut prendre un générateur, alors passez-le simplement à votre générateur.


0 commentaires

7
votes

Comme mentionné dans la documentation de tensorflow:

x: données d'entrée.

  1. Cela peut être: Un tableau Numpy (ou de type tableau), ou une liste de tableaux (au cas où le modèle aurait plusieurs entrées).
    1. Un tenseur TensorFlow, ou une liste de tenseurs (dans le cas où le modèle a plusieurs entrées).
    2. Un dict mappant les noms d'entrée au tableau / tenseurs correspondants, si le modèle a nommé des entrées.
    3. Un jeu de données tf.data. Devrait retourner un tuple de (entrées, cibles) ou (entrées, cibles, poids_échantillon)
    4. Un générateur ou keras.utils.Sequence retournant (entrées, cibles) ou (entrées, cibles, poids d'échantillons). Une description plus détaillée du comportement de décompression pour les types d'itérateur (jeu de données, générateur, séquence) est donnée ci-dessous.

vous pouvez simplement passer le générateur à Model.fit comme similaire à Model.fit_generator

data_gen_train = ImageDataGenerator(rescale=1/255.)

data_gen_valid = ImageDataGenerator(rescale=1/255.)

train_generator = data_gen_train.flow_from_directory(train_dir, target_size=(128,128), batch_size=128, class_mode="binary")

valid_generator = data_gen_valid.flow_from_directory(validation_dir, target_size=(128,128), batch_size=128, class_mode="binary")

model.fit(train_generator, epochs=2, validation_data=valid_generator) .


3 commentaires

Votre réponse est incorrecte ou, devrais-je dire, partiellement correcte. valid_generator n'est pas pris en charge dans model.fit


@sidk oui ça l'est.


@sidk oui, il semble que fit param validation_data ne prend pas en charge l'ensemble de données, le générateur ou keras.utils.Sequence et validation_split ne sont pas pris en charge lorsque le paramètre x est un ensemble de données, un générateur ou keras.utils.Sequence per fit info at tensorflow.org / api_docs / python / tf / keras / Modèle



0
votes

La documentation dit que dans le cas d'un générateur, x = doit être mis en tant que générateur et y = ne doit pas être spécifié, ce qui est cohérent avec la signature fit_generator où seule une instance de ImageDataGenerator est transmise.


0 commentaires