儲存模型
在張量流中儲存模型非常簡單。
假設你有一個輸入 x
的線性模型,並想要預測輸出 y
。這裡的損失是均方誤差(MSE)。批量大小為 16。
# Define the model
x = tf.placeholder(tf.float32, [16, 10]) # input
y = tf.placeholder(tf.float32, [16, 1]) # output
w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)
res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
這裡有 Saver 物件,它可以有多個引數(參見 doc )。
# Define the tf.train.Saver object
# (cf. params section for all the parameters)
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
最後,我們在 tf.Session()
中訓練模型,進行 1000
迭代。我們只在每個 100
迭代中儲存模型。
# Start a session
max_steps = 1000
with tf.Session() as sess:
# initialize the variables
sess.run(tf.initialize_all_variables())
for step in range(max_steps):
feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)} # dummy input
_, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
# Save the model every 100 iterations
if step % 100 == 0:
saver.save(sess, "./model", global_step=step)
執行此程式碼後,你應該看到目錄中的最後 5 個檢查點:
model-500
和model-500.meta
model-600
和model-600.meta
model-700
和model-700.meta
model-800
和model-800.meta
model-900
和model-900.meta
請注意,在這個例子中,雖然 saver
實際上儲存了變數的當前值作為檢查點和圖形的結構(*.meta
),但是沒有特別注意如何檢索例如佔位符 x
和 y
一旦模型是恢復。例如,如果在此訓練指令碼以外的任何地方進行恢復,則從恢復的圖形中檢索 x
和 y
可能很麻煩(特別是在更復雜的模型中)。為了避免這種情況,請始終為變數/佔位符/操作命名,或者考慮使用 tf.collections
,如其中一個備註所示。