使用批处理读取 n 个时期的数据
假设你的数据示例已经读取到 python 的变量,并且你希望以给定大小的批量读取 n 次:
import numpy as np
import tensorflow as tf
data = np.array([1, 2, 3, 4, 5])
n = 4
要批量合并数据,可能使用随机改组,你可以使用 tf.train.batch
或 tf.train.batch_shuffle
,但你需要传递一个会产生 n 次全数据的张量:
limited_tensor = tf.train.limit_epochs(data, n)
batch = tf.train.shuffle_batch([limited_tensor], batch_size=3, enqueue_many=True, capacity=4)
limit_epochs
将 numpy 数组转换为引擎盖下的张量并返回一个张量,产生 n 次并随后抛出 OutOfRangeError。传递给 shuffle_batch
的 enqueue_many=True
参数表示张量列表 [limited_tensor]
中的每个张量应该被解释为包含许多示例。请注意,批处理队列的容量可能小于张量中的示例数。
人们可以像往常一样处理数据:
with tf.Session() as sess:
sess.run(tf.initialize_local_variables())
tf.train.start_queue_runners()
try:
while True:
data_batch = sess.run(batch)
# process data
except tf.errors.OutOfRangeError:
pass