如何使用 tf.gather nd
tf.gather_nd
是扩展 tf.gather
的,因为它可以让你不仅可以访问一个张量的 1 维的感觉,但可能所有的人。
参数:
params
:一个等级P
的张量,代表我们想要索引的张量indices
:一个等级Q
的张量,代表我们想要访问的params
的索引
功能的输出取决于 indices
的形状。如果 indices
的最内层尺寸为 P
,我们正在从 params
收集单个元素。如果它小于 P
,我们正在收集切片,就像 tf.gather
一样但没有限制我们只能访问第一维。
从等级 2 的张量中收集元素
要在矩阵中访问 (1, 2)
中的元素,我们可以使用:
# data is [[0, 1, 2, 3, 4, 5],
# [6, 7, 8, 9, 10, 11],
# [12 13 14 15 16 17],
# [18 19 20 21 22 23],
# [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [1, 2])
result
将如预期的那样成为 8
。请注意这与 tf.gather
有何不同:传递给 tf.gather(x, [1, 2])
的相同索引将作为 data
的第 2 和第 3 行给出。
如果要同时检索多个元素,只需传递一个索引对列表:
result = tf.gather_nd(x, [[1, 2], [4, 3], [2, 5]])
这将返回 [ 8 27 17]
从等级 2 的张量中收集行
如果在上面的示例中你想要收集行(即切片)而不是元素,请按如下方式调整 indices
参数:
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [[1], [3]])
这将给你第 2 和第 4 行 data
,即
[[ 6 7 8 9 10 11]
[18 19 20 21 22 23]]
从第 3 级的张量中收集元素
如何访问秩 -2 张量的概念直接转换为更高维度的张量。因此,要访问 rank-3 张量中的元素,indices
的最内层维度必须为 3。
# data is [[[ 0 1]
# [ 2 3]
# [ 4 5]]
#
# [[ 6 7]
# [ 8 9]
# [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[0, 0, 0], [1, 2, 1]])
result
现在看起来像这样:[ 0 11]
从 3 级的张量中收集批量行
让我们把秩 -3 张量想象成一批形状为 tihuan 的矩阵 26。如果要为批处理中的每个元素收集第一行和第二行,可以使用:
# data is [[[ 0 1]
# [ 2 3]
# [ 4 5]]
#
# [[ 6 7]
# [ 8 9]
# [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[[0, 0], [0, 1]], [[1, 0], [1, 1]]])
这将导致:
[[[0 1]
[2 3]]
[[6 7]
[8 9]]]
注意 indices
的形状如何影响输出张量的形状。如果我们在 indices
参数中使用了 rank-2 张量:
result = tf.gather_nd(x, [[0, 0], [0, 1], [1, 0], [1, 1]])
输出本来是
[[0 1]
[2 3]
[6 7]
[8 9]]