在 TensorFlow 中儲存和恢復模型

在上面的恢復模型部分,如果我理解正確,你構建模型然後恢復變數。我相信只要在使用 tf.add_to_collection() 儲存時新增相關的張量/佔位符,就不需要重建模型。例如:

tf.add_to_collection('cost_op', cost_op)

然後,你可以恢復已儲存的圖表並使用獲取 cost_op 的許可權

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('model.meta')` 
    new_saver.restore(sess, 'model')
    cost_op = tf.get_collection('cost_op')[0]

即使你沒有執行 tf.add_to_collection(),你也可以檢索你的張量,但這個過程有點麻煩,你可能需要做一些挖掘才能找到合適的名字。例如:

在構建張量流圖的指令碼中,我們定義了一組張量 lab_squeeze

...
with tf.variable_scope("inputs"):
    y=tf.convert_to_tensor([[0,1],[1,0]])
    split_labels=tf.split(1,0,x,name='lab_split')
    split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
    saver=tf.train.Saver(sess,split_labels)
    saver.save("./checkpoint.chk")
    

我們稍後可以回憶一下如下:

with tf.Session() as sess:
    g=tf.get_default_graph()
    new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')` 
    new_saver.restore(sess, './checkpoint.chk')
    split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']

    split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0') 
    split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")

有很多方法可以找到張量的名稱 - 你可以在張量板上的圖表中找到它,或者你可以使用以下內容搜尋它:

sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph