使用 Graph.finalize() 来捕获添加到图中的节点

使用 TensorFlow 的最常见模式包括首先构建 TensorFlow 运算符的数据流图(如 tf.constant()tf.matmul(),然后通过在循环中调用 tf.Session.run() 方法(例如训练循环)来运行步骤 )。

内存泄漏的常见来源是训练循环包含将节点添加到图形的调用,并且这些调用在每次迭代中运行,从而导致图形增长。这些可能是显而易见的(例如,调用 TensorFlow 运算符,如 tf.square()),隐式(例如调用 TensorFlow 库函数创建运算符,如 tf.train.Saver()),或微妙(例如调用 tf.Tensor 和 NumPy 数组上的重载运算符) ,隐含地调用 tf.convert_to_tensor() 并向图中添加新的 tf.constant()

tf.Graph.finalize() 方法可以帮助赶上这样的泄漏:它标志着一个图形为只读,如果有什么被添加到图中引发了异常。例如:

loss = ...
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)
    sess.graph.finalize()  # Graph is read-only after this statement.

    for _ in range(1000000):
        sess.run(train_op)
        loss_sq = tf.square(loss)  # Exception will be thrown here.
        sess.run(loss_sq)

在这种情况下,重载的*运算符会尝试向图中添加新节点:

loss = ...
# ...
with tf.Session() as sess:
    # ...
    sess.graph.finalize()  # Graph is read-only after this statement.
    # ...
    dbl_loss = loss * 2.0  # Exception will be thrown here.