Alternativas más rápidas a numpy.argmax / argmin que es lento

Estoy usando mucho argmin y argmax en Python.

Desafortunadamente, la función es muy lenta.

He hecho algunas búsquedas alrededor, y lo mejor que puedo encontrar es aquí:

http://lemire.me/blog/archives/2008/12/17/fast-argmax-in-python/

def fastest_argmax(array): array = list( array ) return array.index(max(array)) 

Desafortunadamente, esta solución es solo la mitad de rápida que np.max, y creo que debería poder encontrar algo tan rápido como np.max.

 x = np.random.randn(10) %timeit np.argmax( x ) 10000 loops, best of 3: 21.8 us per loop %timeit fastest_argmax( x ) 10000 loops, best of 3: 20.8 us per loop 

Como nota, estoy aplicando esto a un Pandas DataFrame Groupby

P.EJ

 %timeit grp2[ 'ODDS' ].agg( [ fastest_argmax ] ) 100 loops, best of 3: 8.8 ms per loop %timeit grp2[ 'ODDS' ].agg( [ np.argmax ] ) 100 loops, best of 3: 11.6 ms per loop 

Donde los datos se ven así:

 grp2[ 'ODDS' ].head() Out[60]: EVENT_ID SELECTION_ID 104601100 4367029 682508 3.05 682509 3.15 682510 3.25 682511 3.35 5319660 682512 2.04 682513 2.08 682514 2.10 682515 2.12 682516 2.14 5510310 682520 4.10 682521 4.40 682522 4.50 682523 4.80 682524 5.30 5559264 682526 5.00 682527 5.30 682528 5.40 682529 5.50 682530 5.60 5585869 682533 1.96 682534 1.97 682535 1.98 682536 2.02 682537 2.04 6064546 682540 3.00 682541 2.74 682542 2.76 682543 2.96 682544 3.05 104601200 4916112 682548 2.64 682549 2.68 682550 2.70 682551 2.72 682552 2.74 5315859 682557 2.90 682558 2.92 682559 3.05 682560 3.10 682561 3.15 5356995 682564 2.42 682565 2.44 682566 2.48 682567 2.50 682568 2.52 5465225 682573 1.85 682574 1.89 682575 1.91 682576 1.93 682577 1.94 5773661 682588 5.00 682589 4.40 682590 4.90 682591 5.10 6013187 682592 5.00 682593 4.20 682594 4.30 682595 4.40 682596 4.60 104606300 2489827 683438 4.00 683439 3.90 683440 3.95 683441 4.30 683442 4.40 3602724 683446 2.16 683447 2.32 Name: ODDS, Length: 65, dtype: float64 

Resulta que np.argmax es np.argmax rápido, pero solo con los arreglos nativos de números. Con datos extranjeros, casi todo el tiempo se gasta en la conversión:

 In [194]: print platform.architecture() ('64bit', 'WindowsPE') In [5]: x = np.random.rand(10000) In [57]: l=list(x) In [123]: timeit numpy.argmax(x) 100000 loops, best of 3: 6.55 us per loop In [122]: timeit numpy.argmax(l) 1000 loops, best of 3: 729 us per loop In [134]: timeit numpy.array(l) 1000 loops, best of 3: 716 us per loop 

Llamé a tu función “ineficiente” porque primero convierte todo a la lista, luego lo repite 2 veces (efectivamente, 3 iteraciones + construcción de lista).

Iba a sugerir algo como esto que solo se repite una vez:

 def imax(seq): it=iter(seq) im=0 try: m=it.next() except StopIteration: raise ValueError("the sequence is empty") for i,e in enumerate(it,start=1): if e>m: m=e im=i return im 

Pero, su versión resulta ser más rápida porque itera muchas veces pero lo hace en C, en lugar de Python, el código. C es mucho más rápido, incluso teniendo en cuenta que también se gasta mucho tiempo en la conversión:

 In [158]: timeit imax(x) 1000 loops, best of 3: 883 us per loop In [159]: timeit fastest_argmax(x) 1000 loops, best of 3: 575 us per loop In [174]: timeit list(x) 1000 loops, best of 3: 316 us per loop In [175]: timeit max(l) 1000 loops, best of 3: 256 us per loop In [181]: timeit l.index(0.99991619010758348) #the greatest number in my case, at index 92 100000 loops, best of 3: 2.69 us per loop 

Por lo tanto, el conocimiento clave para acelerar esto aún más es saber qué formato tienen los datos en su secuencia de forma nativa (por ejemplo, si puede omitir el paso de conversión o usar / escribir otra funcionalidad nativa de ese formato).

Por cierto, es probable que obtengas un poco de aceleración al usar el aggregate(max_fn) lugar de agg([max_fn]) .

¿Puedes publicar algún código? Aquí está el resultado en mi pc:

 x = np.random.rand(10000) %timeit np.max(x) %timeit np.argmax(x) 

salida:

 100000 loops, best of 3: 7.43 µs per loop 100000 loops, best of 3: 11.5 µs per loop 

Para aquellos que vinieron por un breve fragmento de código sin números que devuelve el índice del primer valor mínimo:

 def argmin(a): return min(range(len(a)), key=lambda x: a[x]) a = [6, 5, 4, 1, 1, 3, 2] argmin(a) # returns 3