保存模型
在张量流中保存模型非常简单。
假设你有一个输入 x 的线性模型,并想要预测输出 y。这里的损失是均方误差(MSE)。批量大小为 16。
# Define the model
x = tf.placeholder(tf.float32, [16, 10]) # input
y = tf.placeholder(tf.float32, [16, 1]) # output
w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)
res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
这里有 Saver 对象,它可以有多个参数(参见 doc )。
# Define the tf.train.Saver object
# (cf. params section for all the parameters)
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
最后,我们在 tf.Session() 中训练模型,进行 1000 迭代。我们只在每个 100 迭代中保存模型。
# Start a session
max_steps = 1000
with tf.Session() as sess:
# initialize the variables
sess.run(tf.initialize_all_variables())
for step in range(max_steps):
feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)} # dummy input
_, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
# Save the model every 100 iterations
if step % 100 == 0:
saver.save(sess, "./model", global_step=step)
运行此代码后,你应该看到目录中的最后 5 个检查点:
model-500和model-500.metamodel-600和model-600.metamodel-700和model-700.metamodel-800和model-800.metamodel-900和model-900.meta
请注意,在这个例子中,虽然 saver 实际上保存了变量的当前值作为检查点和图形的结构(*.meta),但是没有特别注意如何检索例如占位符 x 和 y 一旦模型是恢复。例如,如果在此训练脚本以外的任何地方进行恢复,则从恢复的图形中检索 x 和 y 可能很麻烦(特别是在更复杂的模型中)。为了避免这种情况,请始终为变量/占位符/操作命名,或者考虑使用 tf.collections,如其中一个备注所示。