保存模型
在张量流中保存模型非常简单。
假设你有一个输入 x
的线性模型,并想要预测输出 y
。这里的损失是均方误差(MSE)。批量大小为 16。
# Define the model
x = tf.placeholder(tf.float32, [16, 10]) # input
y = tf.placeholder(tf.float32, [16, 1]) # output
w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)
res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
这里有 Saver 对象,它可以有多个参数(参见 doc )。
# Define the tf.train.Saver object
# (cf. params section for all the parameters)
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
最后,我们在 tf.Session()
中训练模型,进行 1000
迭代。我们只在每个 100
迭代中保存模型。
# Start a session
max_steps = 1000
with tf.Session() as sess:
# initialize the variables
sess.run(tf.initialize_all_variables())
for step in range(max_steps):
feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)} # dummy input
_, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
# Save the model every 100 iterations
if step % 100 == 0:
saver.save(sess, "./model", global_step=step)
运行此代码后,你应该看到目录中的最后 5 个检查点:
model-500
和model-500.meta
model-600
和model-600.meta
model-700
和model-700.meta
model-800
和model-800.meta
model-900
和model-900.meta
请注意,在这个例子中,虽然 saver
实际上保存了变量的当前值作为检查点和图形的结构(*.meta
),但是没有特别注意如何检索例如占位符 x
和 y
一旦模型是恢复。例如,如果在此训练脚本以外的任何地方进行恢复,则从恢复的图形中检索 x
和 y
可能很麻烦(特别是在更复杂的模型中)。为了避免这种情况,请始终为变量/占位符/操作命名,或者考虑使用 tf.collections
,如其中一个备注所示。