Grupo argmax / argmin sobre particiones de índices en números

Los ufuncios de ufunc tienen un método de reduceat que los ejecuta en particiones contiguas dentro de una matriz. Así que en lugar de escribir:

 import numpy as np a = np.array([4, 0, 6, 8, 0, 9, 8, 5, 4, 9]) split_at = [4, 5] maxima = [max(subarray for subarray in np.split(a, split_at)] 

Puedo escribir:

 maxima = np.maximum.reduceat(a, np.hstack([0, split_at])) 

Ambos devolverán los valores máximos en los segmentos a[0:4] , a[4:5] , a[5:10] , siendo [8, 0, 9] .

Me gustaría una función similar para realizar argmax , señalando que solo me gustaría un solo índice máximo en cada partición: [3, 4, 5] con la anterior a y split_at (a pesar de los índices 5 y 9, ambos obtienen el valor máximo en el último grupo), como sería devuelto por

 np.hstack([0, split_at]) + [np.argmax(subarray) for subarray in np.split(a, split_at)] 

Publicaré una posible solución a continuación, pero me gustaría ver una vectorizada sin crear un índice sobre grupos.

Esta solución implica construir un índice sobre grupos ( [0, 0, 0, 0, 1, 2, 2, 2, 2, 2] en el ejemplo anterior).

 group_lengths = np.diff(np.hstack([0, split_at, len(a)])) n_groups = len(group_lengths) index = np.repeat(np.arange(n_groups), group_lengths) 

Entonces podemos usar:

 maxima = np.maximum.reduceat(a, np.hstack([0, split_at])) all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a) result = np.empty(len(group_lengths), dtype='i') result[index[all_argmax[::-1]]] = all_argmax[::-1] 

Para obtener [3, 4, 5] en el result . Los [::-1] s aseguran que obtengamos el primer argmax en lugar del último argmax en cada grupo.

Esto se basa en el hecho de que el último índice en la asignación de fantasía determina el valor asignado, que @seberg dice que no debería confiar (y se puede lograr una alternativa más segura con result = all_argmax[np.unique(index[all_argmax], return_index=True)[1]] , que implica una ordenación sobre len(maxima) ~ n_groups elementos).

Inspirado por esta pregunta, he agregado funcionalidad argmin / max al paquete numpy_indexed . Aquí es cómo se ve la prueba correspondiente. Tenga en cuenta que las claves pueden estar en cualquier orden (y de cualquier tipo admitido por npi):

 def test_argmin(): keys = [2, 0, 0, 1, 1, 2, 2, 2, 2, 2] values = [4, 5, 6, 8, 0, 9, 8, 5, 4, 9] unique, amin = group_by(keys).argmin(values) npt.assert_equal(unique, [0, 1, 2]) npt.assert_equal(amin, [1, 4, 0])