2
votes

Réinitialiser le graphique par défaut en quittant tf.Session () dans les tests unitaires

  • À la fin de chaque test unitaire, j'appelle tf.reset_default_graph () pour effacer le graphique par défaut.
  • Cependant, lorsqu'un test unitaire échoue, le graphique n'est pas effacé. Cela fait également échouer le prochain test unitaire.
  • Comment effacer un graphique en quittant le contexte tf.Session () ?

    Exemple (pytest):

    import tensorflow as tf
    
    
    def test_1():
        x = tf.get_variable('x', initializer=1)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            print(4 / 0)
            print(sess.run(x))
    
    
    def test_2():
        x = tf.get_variable('x', initializer=1)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            print(sess.run(x))
    

0 commentaires

3 Réponses :


2
votes

Quelque chose comme celui-ci fonctionnerait-il?

import tensorflow as tf


def test_1():
    G = tf.Graph()
    with G.as_default():
        x = tf.get_variable('x', initializer=1)
        with tf.Session() as sess:
            sess.run(tf.initializers.global_variables())
            print(sess.run(x))
            print(4 / 0)


def test_2():
    G = tf.Graph()
    with G.as_default():
        x = tf.get_variable('x', initializer=1)
        with tf.Session() as sess:
            sess.run(tf.initializers.global_variables())
            print(sess.run(x))


3 commentaires

Ensuite, une solution simple serait simplement un essai / sauf: et dans l'exception, vous réinitialisez le graphique et ensuite vous générez une erreur. Pas très joli cependant


mon mauvais, votre solution fonctionne réellement. J'ai oublié d'initialiser les variables.


Oui, j'essayais juste et je ne pouvais même pas faire fonctionner test_2. Plus grand que!



2
votes

Une solution directe consiste à utiliser la clause try ... finally (en fait, il peut être préférable de mettre la clause dans le code qui exécute les tests unitaires plutôt que dans les tests unitaires directement):

def test_1():
    with tf.Graph().as_default(), tf.Session() as sess:
        x = tf.get_variable('x', initializer=1)

        sess.run(tf.global_variables_initializer())
        print(4 / 0)
        print(sess.run(x))


def test_2():
    with tf.Graph().as_default(), tf.Session() as sess:
        x = tf.get_variable('x', initializer=1)

        sess.run(tf.global_variables_initializer())
        print(sess.run(x))

Une autre solution propre consiste à utiliser un graphique pour chaque test unitaire comme indiqué dans la réponse précédente. Voici une solution alternative basée sur cette idée avec une syntaxe légèrement simplifiée:

def test_1():
    x = tf.get_variable('x', initializer=1)
    try:
       with tf.Session() as sess:
           sess.run(tf.global_variables_initializer())
           print(4 / 0)
           print(sess.run(x))
    finally:
       tf.reset_default_graph()


def test_2():
    x = tf.get_variable('x', initializer=1)
    try:
        with tf.Session() as sess:
           sess.run(tf.global_variables_initializer())
            print(sess.run(x))
    finally:
        tf.reset_default_graph()

De la même manière que la première solution, l'instruction with peut également être mise autour du code qui exécute les tests unitaires plutôt que d'être répété dans chaque test unitaire.


2 commentaires

la 2ème méthode est un peu plus propre. Mais il doit indenter beaucoup de choses dans le test.


Comme je l'ai expliqué à la fin de la réponse, vous pouvez mettre avec tf.Graph (). As_default () dans le code qui exécute les tests et conserver test_1 et méthodes de test_2 inchangées.



6
votes

Je propose d'utiliser les outils proposés par pytest :

@pytest.fixture
def x():
    return tf.get_variable('x', initializer=1)


@pytest.fixture
def session(x):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        yield sess


@pytest.fixture(autouse=True)
def init_graph():
    with tf.Graph().as_default():
        yield


def test_1(session, x):
    print(4 / 0)
    print(session.run(x))


def test_2(session, x):
    print(session.run(x))

Le projecteur sera automatiquement invoqué avant et après chaque test (flag autouse code >), le code avant / après yield est exécuté avant / après le test. De cette façon, les tests de votre question fonctionneront sans aucune modification et vous suivrez le principe DRY, refusant d'écrire du code dupliqué dans chaque test. Autre exemple:

@pytest.fixture(autouse=True)
def init_graph():
    with tf.Graph().as_default():
        yield

créera un nouveau graphique pour chaque test avant l'exécution du test.

Les appareils dans pytest sont très puissant et peut éliminer complètement les répétitions de code lorsqu'il est utilisé correctement. Par exemple, les tests de votre question sont équivalents à:

@pytest.fixture(autouse=True)
def reset():
    yield
    tf.reset_default_graph()

Si vous voulez en savoir plus, commencez par pytest fixtures: explicite, modulaire, évolutif .


0 commentaires