儲存模型

在張量流中儲存模型非常簡單。

假設你有一個輸入 x 的線性模型,並想要預測輸出 y。這裡的損失是均方誤差(MSE)。批量大小為 16。

# Define the model
x = tf.placeholder(tf.float32, [16, 10])  # input
y = tf.placeholder(tf.float32, [16, 1])   # output

w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)

res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))

train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

這裡有 Saver 物件,它可以有多個引數(參見 doc )。

# Define the tf.train.Saver object
# (cf. params section for all the parameters)    
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

最後,我們在 tf.Session() 中訓練模型,進行 1000 迭代。我們只在每個 100 迭代中儲存模型。

# Start a session
max_steps = 1000
with tf.Session() as sess:
    # initialize the variables
    sess.run(tf.initialize_all_variables())

    for step in range(max_steps):
        feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)}  # dummy input
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

        # Save the model every 100 iterations
        if step % 100 == 0:
            saver.save(sess, "./model", global_step=step)

執行此程式碼後,你應該看到目錄中的最後 5 個檢查點:

  • model-500model-500.meta
  • model-600model-600.meta
  • model-700model-700.meta
  • model-800model-800.meta
  • model-900model-900.meta

請注意,在這個例子中,雖然 saver 實際上儲存了變數的當前值作為檢查點和圖形的結構(*.meta),但是沒有特別注意如何檢索例如佔位符 xy 一旦模型是恢復。例如,如果在此訓練指令碼以外的任何地方進行恢復,則從恢復的圖形中檢索 xy 可能很麻煩(特別是在更復雜的模型中)。為了避免這種情況,請始終為變數/佔位符/操作命名,或者考慮使用 tf.collections,如其中一個備註所示。