如何使用 tf.gather nd
tf.gather_nd
是擴充套件 tf.gather
的,因為它可以讓你不僅可以訪問一個張量的 1 維的感覺,但可能所有的人。
引數:
params
:一個等級P
的張量,代表我們想要索引的張量indices
:一個等級Q
的張量,代表我們想要訪問的params
的索引
功能的輸出取決於 indices
的形狀。如果 indices
的最內層尺寸為 P
,我們正在從 params
收集單個元素。如果它小於 P
,我們正在收集切片,就像 tf.gather
一樣但沒有限制我們只能訪問第一維。
從等級 2 的張量中收集元素
要在矩陣中訪問 (1, 2)
中的元素,我們可以使用:
# data is [[0, 1, 2, 3, 4, 5],
# [6, 7, 8, 9, 10, 11],
# [12 13 14 15 16 17],
# [18 19 20 21 22 23],
# [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [1, 2])
result
將如預期的那樣成為 8
。請注意這與 tf.gather
有何不同:傳遞給 tf.gather(x, [1, 2])
的相同索引將作為 data
的第 2 和第 3 行給出。
如果要同時檢索多個元素,只需傳遞一個索引對列表:
result = tf.gather_nd(x, [[1, 2], [4, 3], [2, 5]])
這將返回 [ 8 27 17]
從等級 2 的張量中收集行
如果在上面的示例中你想要收集行(即切片)而不是元素,請按如下方式調整 indices
引數:
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [[1], [3]])
這將給你第 2 和第 4 行 data
,即
[[ 6 7 8 9 10 11]
[18 19 20 21 22 23]]
從第 3 級的張量中收集元素
如何訪問秩 -2 張量的概念直接轉換為更高維度的張量。因此,要訪問 rank-3 張量中的元素,indices
的最內層維度必須為 3。
# data is [[[ 0 1]
# [ 2 3]
# [ 4 5]]
#
# [[ 6 7]
# [ 8 9]
# [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[0, 0, 0], [1, 2, 1]])
result
現在看起來像這樣:[ 0 11]
從 3 級的張量中收集批量行
讓我們把秩 -3 張量想象成一批形狀為 tihuan 的矩陣 26。如果要為批處理中的每個元素收集第一行和第二行,可以使用:
# data is [[[ 0 1]
# [ 2 3]
# [ 4 5]]
#
# [[ 6 7]
# [ 8 9]
# [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[[0, 0], [0, 1]], [[1, 0], [1, 1]]])
這將導致:
[[[0 1]
[2 3]]
[[6 7]
[8 9]]]
注意 indices
的形狀如何影響輸出張量的形狀。如果我們在 indices
引數中使用了 rank-2 張量:
result = tf.gather_nd(x, [[0, 0], [0, 1], [1, 0], [1, 1]])
輸出本來是
[[0 1]
[2 3]
[6 7]
[8 9]]