numpy.take_along_axis#

numpy.take_along_axis(arr, indices, axis=-1)[源代码]#

通过匹配一维索引和数据切片,从输入数组中获取值.

这会迭代在索引和数据数组中沿指定轴定向的匹配的 1d 切片,并使用前者来查找后者中的值.这些切片可以具有不同的长度.

返回沿轴的索引的函数(如 argsortargpartition )会为此函数生成合适的索引.

参数:
arrndarray (Ni…, M, Nk…)

源数组

indicesndarray (Ni…, J, Nk…)

沿 arr 的每个 1d 切片获取的索引.这必须与 arr 的维度匹配,但维度 Ni 和 Nj 只需要与 arr 进行广播.

axisint 或 None,可选

沿其获取 1d 切片的轴.如果 axis 为 None,则输入数组将被视为首先展平为 1d,以便与 sortargsort 保持一致.

在 2.3 版本发生变更: 默认值现在为 -1 .

返回:
out: ndarray (Ni…, J, Nk…)

索引结果.

参见

take

沿轴获取,对每个 1d 切片使用相同的索引

put_along_axis

通过匹配一维索引和数据切片,将值放入目标数组.

注释

这等效于(但比)以下使用 ndindexs_ 的方法更快,该方法将 iikk 都设置为索引元组:

Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis+1:]
J = indices.shape[axis]  # Need not equal M
out = np.empty(Ni + (J,) + Nk)

for ii in ndindex(Ni):
    for kk in ndindex(Nk):
        a_1d       = a      [ii + s_[:,] + kk]
        indices_1d = indices[ii + s_[:,] + kk]
        out_1d     = out    [ii + s_[:,] + kk]
        for j in range(J):
            out_1d[j] = a_1d[indices_1d[j]]

等效地,消除内部循环,最后两行将是:

out_1d[:] = a_1d[indices_1d]

示例

>>> import numpy as np

对于此示例数组

>>> a = np.array([[10, 30, 20], [60, 40, 50]])

我们可以直接使用 sort 进行排序,或者使用 argsort 和此函数进行排序

>>> np.sort(a, axis=1)
array([[10, 20, 30],
       [40, 50, 60]])
>>> ai = np.argsort(a, axis=1)
>>> ai
array([[0, 2, 1],
       [1, 2, 0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[10, 20, 30],
       [40, 50, 60]])

如果使用 keepdims 保持微不足道的维度,则同样适用于 max 和 min:

>>> np.max(a, axis=1, keepdims=True)
array([[30],
       [60]])
>>> ai = np.argmax(a, axis=1, keepdims=True)
>>> ai
array([[1],
       [0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[30],
       [60]])

如果我们要同时获得最大值和最小值,我们可以先堆叠索引

>>> ai_min = np.argmin(a, axis=1, keepdims=True)
>>> ai_max = np.argmax(a, axis=1, keepdims=True)
>>> ai = np.concatenate([ai_min, ai_max], axis=1)
>>> ai
array([[0, 1],
       [1, 0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[10, 30],
       [40, 60]])