tf.reset_default_graph ()
pour effacer le graphique par défaut. Comment effacer un graphique en quittant le contexte tf.Session ()
?
Exemple (pytest):
import tensorflow as tf def test_1(): x = tf.get_variable('x', initializer=1) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(4 / 0) print(sess.run(x)) def test_2(): x = tf.get_variable('x', initializer=1) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(x))
3 Réponses :
Quelque chose comme celui-ci fonctionnerait-il?
import tensorflow as tf def test_1(): G = tf.Graph() with G.as_default(): x = tf.get_variable('x', initializer=1) with tf.Session() as sess: sess.run(tf.initializers.global_variables()) print(sess.run(x)) print(4 / 0) def test_2(): G = tf.Graph() with G.as_default(): x = tf.get_variable('x', initializer=1) with tf.Session() as sess: sess.run(tf.initializers.global_variables()) print(sess.run(x))
Ensuite, une solution simple serait simplement un essai / sauf: et dans l'exception, vous réinitialisez le graphique et ensuite vous générez une erreur. Pas très joli cependant
mon mauvais, votre solution fonctionne réellement. J'ai oublié d'initialiser les variables.
Oui, j'essayais juste et je ne pouvais même pas faire fonctionner test_2. Plus grand que!
Une solution directe consiste à utiliser la clause try
... finally
(en fait, il peut être préférable de mettre la clause dans le code qui exécute les tests unitaires plutôt que dans les tests unitaires directement):
def test_1(): with tf.Graph().as_default(), tf.Session() as sess: x = tf.get_variable('x', initializer=1) sess.run(tf.global_variables_initializer()) print(4 / 0) print(sess.run(x)) def test_2(): with tf.Graph().as_default(), tf.Session() as sess: x = tf.get_variable('x', initializer=1) sess.run(tf.global_variables_initializer()) print(sess.run(x))
Une autre solution propre consiste à utiliser un graphique pour chaque test unitaire comme indiqué dans la réponse précédente. Voici une solution alternative basée sur cette idée avec une syntaxe légèrement simplifiée:
def test_1(): x = tf.get_variable('x', initializer=1) try: with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(4 / 0) print(sess.run(x)) finally: tf.reset_default_graph() def test_2(): x = tf.get_variable('x', initializer=1) try: with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(x)) finally: tf.reset_default_graph()
De la même manière que la première solution, l'instruction with
peut également être mise autour du code qui exécute les tests unitaires plutôt que d'être répété dans chaque test unitaire.
la 2ème méthode est un peu plus propre. Mais il doit indenter beaucoup de choses dans le test.
Comme je l'ai expliqué à la fin de la réponse, vous pouvez mettre avec tf.Graph (). As_default ()
dans le code qui exécute les tests et conserver test_1
et méthodes de test_2
inchangées.
Je propose d'utiliser les outils proposés par pytest
:
@pytest.fixture def x(): return tf.get_variable('x', initializer=1) @pytest.fixture def session(x): with tf.Session() as sess: sess.run(tf.global_variables_initializer()) yield sess @pytest.fixture(autouse=True) def init_graph(): with tf.Graph().as_default(): yield def test_1(session, x): print(4 / 0) print(session.run(x)) def test_2(session, x): print(session.run(x))
Le projecteur sera automatiquement invoqué avant et après chaque test (flag autouse code >), le code avant / après
yield
est exécuté avant / après le test. De cette façon, les tests de votre question fonctionneront sans aucune modification et vous suivrez le principe DRY, refusant d'écrire du code dupliqué dans chaque test. Autre exemple:
@pytest.fixture(autouse=True) def init_graph(): with tf.Graph().as_default(): yield
créera un nouveau graphique pour chaque test avant l'exécution du test.
Les appareils dans pytest
sont très puissant et peut éliminer complètement les répétitions de code lorsqu'il est utilisé correctement. Par exemple, les tests de votre question sont équivalents à:
@pytest.fixture(autouse=True) def reset(): yield tf.reset_default_graph()
Si vous voulez en savoir plus, commencez par pytest fixtures: explicite, modulaire, évolutif .