使用張量的類似 Numpy 的索引
這個例子基於這篇文章: TensorFlow - 類似 numpy 的張量索引 。
在 Numpy 中,你可以使用陣列索引到陣列中。例如,為了在二維陣列中選擇 (1, 2)
和 (3, 2)
中的元素,你可以這樣做:
# data is [[0, 1, 2, 3, 4, 5],
# [6, 7, 8, 9, 10, 11],
# [12 13 14 15 16 17],
# [18 19 20 21 22 23],
# [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
a = [1, 3]
b = [2, 2]
selected = data[a, b]
print(selected)
這將列印:
[ 8 20]
要在 Tensorflow 中獲得相同的行為,你可以使用 tf.gather_nd
,它是 tf.gather
的擴充套件。上面的例子可以這樣寫:
x = tf.constant(data)
idx1 = tf.constant(a)
idx2 = tf.constant(b)
result = tf.gather_nd(x, tf.stack((idx1, idx2), -1))
with tf.Session() as sess:
print(sess.run(result))
這將列印:
[ 8 20]
tf.stack
相當於 np.asarray
,在這種情況下,沿著最後一個維度(在本例中為第一維)堆疊兩個索引向量,以產生:
[[1 2]
[3 2]]