为什么要使用 tf.py func
tf.py_func()
运算符使你可以在 TensorFlow 图的中间运行任意 Python 代码。包装自定义 NumPy 运算符特别方便,因为没有等效的 TensorFlow 运算符(尚未存在)。添加 tf.py_func()
是在图形中使用 sess.run()
调用的替代方法。
另一种方法是将图形分为两部分:
# Part 1 of the graph
inputs = ... # in the TF graph
# Get the numpy array and apply func
val = sess.run(inputs) # get the value of inputs
output_val = func(val) # numpy array
# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...
# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})
使用 tf.py_func
,这更容易:
# Part 1 of the graph
inputs = ...
# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]
# Part 2 of the graph
train_op = ...
# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)