Je comprends que vous pouvez attribuer une taille de lot à un ensemble de données et renvoyer un nouvel objet d'ensemble de données. Existe-t-il une API pour interroger la taille du lot en fonction d'un objet de jeu de données?
J'essaie de trouver les appels à l'adresse:
https://www.tensorflow.org/api_docs/python/tf / data / Dataset
3 Réponses :
Je ne sais pas si vous pouvez simplement l'obtenir en tant qu'attribut, mais vous pouvez simplement parcourir l'ensemble de données une fois et imprimer la forme:
batch size: 3
Si vous savez que votre ensemble de données a des cibles / étiquettes aussi, vous devez itérer comme suit:
# iterating once
for one_batch_x, one_batch_y in f:
print('batch size:', one_batch_x.shape[0])
break
Dans les deux cas, il imprimera:
# create a simple tf.data.Dataset with batchsize 3
import tensorflow as tf
f = tf.data.Dataset.range(10).batch(3) # Dataset with batch_size 3
# iterating once
for one_batch in f:
print('batch size:', one_batch.shape[0])
break
D'accord. J'espérais qu'il y avait une API cachée.
lorsque vous appelez la méthode .batch (32) , elle renvoie un objet tensorflow.python.data.ops.dataset_ops.BatchDataset . Comme documenté dans Ce type d'objet a un attribut privé appelé ._batch_size qui contient un tenseur de batch_size.
Dans tensorflow 2.X, il vous suffit d'appeler la méthode .numpy () de ce tenseur pour le convertir en type numpy.int64 .
Dans tensorflow 1.X, vous devez caler la méthode .eval () .
Dans Tensorflow 1. * accédez à batch_size via dataset._dataset._batch_size:
import tensorflow as tf import numpy as np print(tf.__version__) # 2.0.1 dataset = tf.data.Dataset.from_tensor_slices(np.random.randint(0, 2, 100)).batch(10) batch_size = dataset._batch_size.numpy() print(batch_size) # 10
Dans Tensorflow 2, vous pouvez accéder via dataset._batch_size :
import tensorflow as tf
import numpy as np
print(tf.__version__) # 1.14.0
dataset = tf.data.Dataset.from_tensor_slices(np.random.randint(0, 2, 100)).batch(10)
with tf.compat.v1.Session() as sess:
batch_size = sess.run(dataset._dataset._batch_size)
print(batch_size) # 10