rápido python numpy donde la funcionalidad?

Estoy usando numpy donde funciona muchas veces dentro de varios bucles, pero se vuelve demasiado lento. ¿Hay alguna forma de realizar esta funcionalidad más rápido? Leí que debería intentar hacer bucles en línea, así como hacer variables locales para funciones antes de los bucles, pero nada parece mejorar la velocidad en gran medida (<1%). El len(UNIQ_IDS) ~ 800. emiss_data y obj_data son una serie de ndarrays con forma = (2600,5200). He usado el import profile para controlar dónde se encuentran los cuellos de botella y where se encuentran los bucles for uno grande.

 import numpy as np max = np.max where = np.where MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS)] 

Puede usar np.unique con return_index :

 def using_sort(): #UNIQ_IDS,uind=np.unique(obj_data, return_inverse=True) uind= np.digitize(obj_data.ravel(), UNIQ_IDS) - 1 vals=uind.argsort() count=np.bincount(uind) start=0 end=0 out=np.empty(count.shape[0]) for ind,x in np.ndenumerate(count): end+=x out[ind]=np.max(np.take(emiss_data,vals[start:end])) start+=x return out 

Usando la respuesta de @unutbu como base para shape = (2600,5200) :

 np.allclose(using_loop(),using_sort()) True %timeit using_loop() 1 loops, best of 3: 12.3 s per loop #With np.unique inside the definition %timeit using_sort() 1 loops, best of 3: 9.06 s per loop #With np.unique outside the definition %timeit using_sort() 1 loops, best of 3: 2.75 s per loop #Using @Jamie's suggestion for uind %timeit using_sort() 1 loops, best of 3: 6.74 s per loop 

Resulta que un bucle de Python puro puede ser mucho más rápido que la indexación NumPy (o llamadas a np.where) en este caso.

Considera las siguientes alternativas:

 import numpy as np import collections import itertools as IT shape = (2600,5200) # shape = (26,52) emiss_data = np.random.random(shape) obj_data = np.random.random_integers(1, 800, size=shape) UNIQ_IDS = np.unique(obj_data) def using_where(): max = np.max where = np.where MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS] return MAX_EMISS def using_index(): max = np.max MAX_EMISS = [max(emiss_data[obj_data == i]) for i in UNIQ_IDS] return MAX_EMISS def using_max(): MAX_EMISS = [(emiss_data[obj_data == i]).max() for i in UNIQ_IDS] return MAX_EMISS def using_loop(): result = collections.defaultdict(list) for val, idx in IT.izip(emiss_data.ravel(), obj_data.ravel()): result[idx].append(val) return [max(result[idx]) for idx in UNIQ_IDS] def using_sort(): uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1 vals = uind.argsort() count = np.bincount(uind) start = 0 end = 0 out = np.empty(count.shape[0]) for ind, x in np.ndenumerate(count): end += x out[ind] = np.max(np.take(emiss_data, vals[start:end])) start += x return out def using_split(): uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1 vals = uind.argsort() count = np.bincount(uind) return [np.take(emiss_data, item).max() for item in np.split(vals, count.cumsum())[:-1]] for func in (using_index, using_max, using_loop, using_sort, using_split): assert using_where() == func() 

Aquí están los puntos de referencia, con shape = (2600,5200) :

 In [57]: %timeit using_loop() 1 loops, best of 3: 9.15 s per loop In [90]: %timeit using_sort() 1 loops, best of 3: 9.33 s per loop In [91]: %timeit using_split() 1 loops, best of 3: 9.33 s per loop In [61]: %timeit using_index() 1 loops, best of 3: 63.2 s per loop In [62]: %timeit using_max() 1 loops, best of 3: 64.4 s per loop In [58]: %timeit using_where() 1 loops, best of 3: 112 s per loop 

Por using_loop tanto, using_loop (Python puro) resulta ser más de using_where más rápido que using_where .

No estoy completamente seguro de por qué Python puro es más rápido que NumPy aquí. Mi conjetura es que la versión de Python pura se desliza (sí, intencionalmente) a través de ambas matrices una vez. Aprovecha el hecho de que, a pesar de toda la indexación elegante, solo queremos visitar cada valor una vez . Por lo tanto, emiss_data el problema de tener que determinar exactamente a qué grupo emiss_data cada valor en emiss_data . Pero esto es solo una vaga especulación. No sabía que iba a ser más rápido hasta que hice un punto de referencia.

Creo que la forma más rápida de lograr esto es usar las operaciones groupby() en el paquete pandas . En comparación con la función using_sort() de using_sort() , Pandas es un factor de 10 más rápido:

 import numpy as np import pandas as pd shape = (2600,5200) emiss_data = np.random.random(shape) obj_data = np.random.random_integers(1, 800, size=shape) UNIQ_IDS = np.unique(obj_data) def using_sort(): #UNIQ_IDS,uind=np.unique(obj_data, return_inverse=True) uind= np.digitize(obj_data.ravel(), UNIQ_IDS) - 1 vals=uind.argsort() count=np.bincount(uind) start=0 end=0 out=np.empty(count.shape[0]) for ind,x in np.ndenumerate(count): end+=x out[ind]=np.max(np.take(emiss_data,vals[start:end])) start+=x return out def using_pandas(): return pd.Series(emiss_data.ravel()).groupby(obj_data.ravel()).max() print('same results:', np.allclose(using_pandas(), using_sort())) # same results: True %timeit using_sort() # 1 loops, best of 3: 3.39 s per loop %timeit using_pandas() # 1 loops, best of 3: 397 ms per loop 

No puedes simplemente hacer

 emiss_data[obj_data == i] 

? No estoy seguro de por qué estás usando where .

Asignar una tupla es mucho más rápido que asignar una lista, de acuerdo con ¿Son las tuplas más eficientes que las listas en Python? , así que quizás solo construyendo una tupla en lugar de una lista, esto mejorará la eficiencia.

Si obj_data consiste en enteros relativamente pequeños, puede usar numpy.maximum.at (desde v1.8.0):

 def using_maximumat(): n = np.max(UNIQ_IDS) + 1 temp = np.full(n, -np.inf) np.maximum.at(temp, obj_data, emiss_data) return temp[UNIQ_IDS]