0
votes

Existe-t-il un moyen de trouver la taille du lot pour un tf.data.Dataset

Je comprends que vous pouvez attribuer une taille de lot à un ensemble de données et renvoyer un nouvel objet d'ensemble de données. Existe-t-il une API pour interroger la taille du lot en fonction d'un objet de jeu de données?

J'essaie de trouver les appels à l'adresse:

https://www.tensorflow.org/api_docs/python/tf / data / Dataset


0 commentaires

3 Réponses :


1
votes

Je ne sais pas si vous pouvez simplement l'obtenir en tant qu'attribut, mais vous pouvez simplement parcourir l'ensemble de données une fois et imprimer la forme:

batch size:  3

Si vous savez que votre ensemble de données a des cibles / étiquettes aussi, vous devez itérer comme suit:

# iterating once
for one_batch_x, one_batch_y in f:
    print('batch size:', one_batch_x.shape[0])
    break

Dans les deux cas, il imprimera:

# create a simple tf.data.Dataset with batchsize 3
import tensorflow as tf 
f = tf.data.Dataset.range(10).batch(3) # Dataset with batch_size 3

# iterating once
for one_batch in f:
    print('batch size:', one_batch.shape[0])
    break


1 commentaires

D'accord. J'espérais qu'il y avait une API cachée.




0
votes

Dans Tensorflow 1. * accédez à batch_size via dataset._dataset._batch_size:

import tensorflow as tf
import numpy as np
print(tf.__version__) # 2.0.1

dataset = tf.data.Dataset.from_tensor_slices(np.random.randint(0, 2, 100)).batch(10)

batch_size = dataset._batch_size.numpy()

print(batch_size) # 10

Dans Tensorflow 2, vous pouvez accéder via dataset._batch_size :

import tensorflow as tf
import numpy as np
print(tf.__version__) # 1.14.0

dataset = tf.data.Dataset.from_tensor_slices(np.random.randint(0, 2, 100)).batch(10)

with tf.compat.v1.Session() as sess:
    batch_size = sess.run(dataset._dataset._batch_size)
    print(batch_size) # 10


0 commentaires