使用 Graph.finalize() 来捕获添加到图中的节点
使用 TensorFlow 的最常见模式包括首先构建 TensorFlow 运算符的数据流图(如 tf.constant()
和 tf.matmul()
,然后通过在循环中调用 tf.Session.run()
方法(例如训练循环)来运行步骤 )。
内存泄漏的常见来源是训练循环包含将节点添加到图形的调用,并且这些调用在每次迭代中运行,从而导致图形增长。这些可能是显而易见的(例如,调用 TensorFlow 运算符,如 tf.square()
),隐式(例如调用 TensorFlow 库函数创建运算符,如 tf.train.Saver()
),或微妙(例如调用 tf.Tensor
和 NumPy 数组上的重载运算符) ,隐含地调用 tf.convert_to_tensor()
并向图中添加新的 tf.constant()
。
该 tf.Graph.finalize()
方法可以帮助赶上这样的泄漏:它标志着一个图形为只读,如果有什么被添加到图中引发了异常。例如:
loss = ...
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
sess.graph.finalize() # Graph is read-only after this statement.
for _ in range(1000000):
sess.run(train_op)
loss_sq = tf.square(loss) # Exception will be thrown here.
sess.run(loss_sq)
在这种情况下,重载的*
运算符会尝试向图中添加新节点:
loss = ...
# ...
with tf.Session() as sess:
# ...
sess.graph.finalize() # Graph is read-only after this statement.
# ...
dbl_loss = loss * 2.0 # Exception will be thrown here.