如何使用 tf.gather nd
tf.gather_nd 是扩展 tf.gather 的,因为它可以让你不仅可以访问一个张量的 1 维的感觉,但可能所有的人。
参数:
- params:一个等级- P的张量,代表我们想要索引的张量
- indices:一个等级- Q的张量,代表我们想要访问的- params的索引
功能的输出取决于 indices 的形状。如果 indices 的最内层尺寸为 P,我们正在从 params 收集单个元素。如果它小于 P,我们正在收集切片,就像 tf.gather 一样但没有限制我们只能访问第一维。
从等级 2 的张量中收集元素
要在矩阵中访问 (1, 2) 中的元素,我们可以使用:
# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          [12 13 14 15 16 17],
#          [18 19 20 21 22 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [1, 2])
result 将如预期的那样成为 8。请注意这与 tf.gather 有何不同:传递给 tf.gather(x, [1, 2]) 的相同索引将作为 data 的第 2 和第 3 行给出。
如果要同时检索多个元素,只需传递一个索引对列表:
result = tf.gather_nd(x, [[1, 2], [4, 3], [2, 5]])
这将返回 [ 8 27 17]
从等级 2 的张量中收集行
如果在上面的示例中你想要收集行(即切片)而不是元素,请按如下方式调整 indices 参数:
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [[1], [3]])
这将给你第 2 和第 4 行 data,即
[[ 6  7  8  9 10 11]
 [18 19 20 21 22 23]]
从第 3 级的张量中收集元素
如何访问秩 -2 张量的概念直接转换为更高维度的张量。因此,要访问 rank-3 张量中的元素,indices 的最内层维度必须为 3。
# data is [[[ 0  1]
#          [ 2  3]
#          [ 4  5]]
#
#         [[ 6  7]
#          [ 8  9]
#          [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[0, 0, 0], [1, 2, 1]])
result 现在看起来像这样:[ 0 11]
从 3 级的张量中收集批量行
让我们把秩 -3 张量想象成一批形状为 tihuan 的矩阵 26。如果要为批处理中的每个元素收集第一行和第二行,可以使用:
# data is [[[ 0  1]
#          [ 2  3]
#          [ 4  5]]
#
#         [[ 6  7]
#          [ 8  9]
#          [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[[0, 0], [0, 1]], [[1, 0], [1, 1]]])
这将导致:
[[[0 1]
  [2 3]]
 [[6 7]
  [8 9]]]
注意 indices 的形状如何影响输出张量的形状。如果我们在 indices 参数中使用了 rank-2 张量:
result = tf.gather_nd(x, [[0, 0], [0, 1], [1, 0], [1, 1]])
输出本来是
[[0 1]
 [2 3]
 [6 7]
 [8 9]]