在 TensorFlow 中保存和恢复模型
在上面的恢复模型部分,如果我理解正确,你构建模型然后恢复变量。我相信只要在使用 tf.add_to_collection()
保存时添加相关的张量/占位符,就不需要重建模型。例如:
tf.add_to_collection('cost_op', cost_op)
然后,你可以恢复已保存的图表并使用获取 cost_op
的权限
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('model.meta')`
new_saver.restore(sess, 'model')
cost_op = tf.get_collection('cost_op')[0]
即使你没有运行 tf.add_to_collection()
,你也可以检索你的张量,但这个过程有点麻烦,你可能需要做一些挖掘才能找到合适的名字。例如:
在构建张量流图的脚本中,我们定义了一组张量 lab_squeeze
:
...
with tf.variable_scope("inputs"):
y=tf.convert_to_tensor([[0,1],[1,0]])
split_labels=tf.split(1,0,x,name='lab_split')
split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
saver=tf.train.Saver(sess,split_labels)
saver.save("./checkpoint.chk")
我们稍后可以回忆一下如下:
with tf.Session() as sess:
g=tf.get_default_graph()
new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')`
new_saver.restore(sess, './checkpoint.chk')
split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']
split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0')
split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")
有很多方法可以找到张量的名称 - 你可以在张量板上的图表中找到它,或者你可以使用以下内容搜索它:
sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph