J'ai cet avertissement de dépréciation lors de l'utilisation de Model.fit_generator
dans tensorflow:
WARNING:tensorflow: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version. Instructions for updating: Please use Model.fit, which supports generators.
Comment puis-je utiliser Model.fit
au lieu de Model.fit_generator
?
3 Réponses :
Model.fit_generator
est obsolète à partir de tensorflow 2.1.0 qui se trouve actuellement dans rc1 . Vous pouvez trouver la documentation pour tf-2.1.0-rc1 ici: https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras/Model#fit
Comme vous pouvez le voir, le premier argument de Model.fit
peut prendre un générateur, alors passez-le simplement à votre générateur.
Comme mentionné dans la documentation de tensorflow:
x: données d'entrée.
- Cela peut être: Un tableau Numpy (ou de type tableau), ou une liste de tableaux (au cas où le modèle aurait plusieurs entrées).
- Un tenseur TensorFlow, ou une liste de tenseurs (dans le cas où le modèle a plusieurs entrées).
- Un dict mappant les noms d'entrée au tableau / tenseurs correspondants, si le modèle a nommé des entrées.
- Un jeu de données tf.data. Devrait retourner un tuple de (entrées, cibles) ou (entrées, cibles, poids_échantillon)
- Un générateur ou keras.utils.Sequence retournant (entrées, cibles) ou (entrées, cibles, poids d'échantillons). Une description plus détaillée du comportement de décompression pour les types d'itérateur (jeu de données, générateur, séquence) est donnée ci-dessous.
vous pouvez simplement passer le générateur à Model.fit comme similaire à Model.fit_generator
data_gen_train = ImageDataGenerator(rescale=1/255.)
data_gen_valid = ImageDataGenerator(rescale=1/255.)
train_generator = data_gen_train.flow_from_directory(train_dir, target_size=(128,128), batch_size=128, class_mode="binary")
valid_generator = data_gen_valid.flow_from_directory(validation_dir, target_size=(128,128), batch_size=128, class_mode="binary")
model.fit(train_generator, epochs=2, validation_data=valid_generator)
.
Votre réponse est incorrecte ou, devrais-je dire, partiellement correcte. valid_generator n'est pas pris en charge dans model.fit
@sidk oui ça l'est.
@sidk oui, il semble que fit
param validation_data
ne prend pas en charge l'ensemble de données, le générateur ou keras.utils.Sequence et validation_split
ne sont pas pris en charge lorsque le paramètre x
est un ensemble de données, un générateur ou keras.utils.Sequence per fit
info at tensorflow.org / api_docs / python / tf / keras / Modèle
La documentation dit que dans le cas d'un générateur, x = doit être mis en tant que générateur et y = ne doit pas être spécifié, ce qui est cohérent avec la signature fit_generator où seule une instance de ImageDataGenerator
est transmise.