使用批處理讀取 n 個時期的資料
假設你的資料示例已經讀取到 python 的變數,並且你希望以給定大小的批量讀取 n 次:
import numpy as np
import tensorflow as tf
data = np.array([1, 2, 3, 4, 5])
n = 4
要批量合併資料,可能使用隨機改組,你可以使用 tf.train.batch
或 tf.train.batch_shuffle
,但你需要傳遞一個會產生 n 次全資料的張量:
limited_tensor = tf.train.limit_epochs(data, n)
batch = tf.train.shuffle_batch([limited_tensor], batch_size=3, enqueue_many=True, capacity=4)
limit_epochs
將 numpy 陣列轉換為引擎蓋下的張量並返回一個張量,產生 n 次並隨後丟擲 OutOfRangeError。傳遞給 shuffle_batch
的 enqueue_many=True
參數列示張量列表 [limited_tensor]
中的每個張量應該被解釋為包含許多示例。請注意,批處理佇列的容量可能小於張量中的示例數。
人們可以像往常一樣處理資料:
with tf.Session() as sess:
sess.run(tf.initialize_local_variables())
tf.train.start_queue_runners()
try:
while True:
data_batch = sess.run(batch)
# process data
except tf.errors.OutOfRangeError:
pass