Devuelve eficientemente el índice del primer valor que cumple la condición en la matriz

Necesito encontrar el índice del primer valor en una matriz NumPy 1d, o series numéricas de Pandas, que satisfagan una condición. La matriz es grande y el índice puede estar cerca del inicio o final de la matriz, o la condición puede no cumplirse en absoluto. No puedo decir de antemano cuál es más probable. Si la condición no se cumple, el valor de retorno debe ser -1 . He considerado algunos enfoques.

Intento 1

 # func(arr) returns a Boolean array idx = next(iter(np.where(func(arr))[0]), -1) 

Pero esto suele ser demasiado lento, ya que func(arr) aplica una función vectorizada en toda la matriz en lugar de detenerse cuando se cumple la condición. Específicamente, es costoso cuando la condición se cumple cerca del inicio de la matriz.

Intento 2

np.argmax es ligeramente más rápido, pero no identifica cuando una condición nunca se cumple:

 np.random.seed(0) arr = np.random.rand(10**7) assert next(iter(np.where(arr > 0.999999)[0]), -1) == np.argmax(arr > 0.999999) %timeit next(iter(np.where(arr > 0.999999)[0]), -1) # 21.2 ms %timeit np.argmax(arr > 0.999999) # 17.7 ms 

np.argmax(arr > 1.0) devuelve 0 , es decir, una instancia cuando la condición no se cumple.

Intento 3

 # func(arr) returns a Boolean scalar idx = next((idx for idx, val in enumerate(arr) if func(arr)), -1) 

Pero esto es demasiado lento cuando la condición se cumple cerca del final de la matriz. Presumiblemente esto se debe a que la expresión del generador tiene una sobrecarga costosa de un gran número de llamadas __next__ .

¿Es esto siempre un compromiso o hay una manera, para la func genérica, de extraer el primer índice de manera eficiente?

Benchmarking

Para la evaluación comparativa, suponga que func encuentra el índice cuando un valor es mayor que una constante dada:

 # Python 3.6.5, NumPy 1.14.3, Numba 0.38.0 import numpy as np np.random.seed(0) arr = np.random.rand(10**7) m = 0.9 n = 0.999999 # Start of array benchmark %timeit next(iter(np.where(arr > m)[0]), -1) # 43.5 ms %timeit next((idx for idx, val in enumerate(arr) if val > m), -1) # 2.5 µs # End of array benchmark %timeit next(iter(np.where(arr > n)[0]), -1) # 21.4 ms %timeit next((idx for idx, val in enumerate(arr) if val > n), -1) # 39.2 ms 

numba

Con numba es posible optimizar ambos escenarios. Sintácticamente, solo necesitas construir una función con un simple bucle for :

 from numba import njit @njit def get_first_index_nb(A, k): for i in range(len(A)): if A[i] > k: return i return -1 idx = get_first_index_nb(A, 0.9) 

Numba mejora el rendimiento al comstackr el código JIT (“Just In Time”) y aprovechar las optimizaciones a nivel de la CPU . Un bucle normal for sin el decorador @njit suele ser más lento que los métodos que ya ha probado en el caso en que la condición se cumple tarde.

Para una serie numérica de Pandas df['data'] , simplemente puede enviar la representación NumPy a la función comstackda JIT:

 idx = get_first_index_nb(df['data'].values, 0.9) 

Generalización

Dado que numba permite funciones como argumentos , y suponiendo que la función pasada también se puede comstackr con JIT, puede llegar a un método para calcular el índice nth donde se cumple una condición para una func arbitraria.

 @njit def get_nth_index_count(A, func, count): c = 0 for i in range(len(A)): if func(A[i]): c += 1 if c == count: return i return -1 @njit def func(val): return val > 0.9 # get index of 3rd value where func evaluates to True idx = get_nth_index_count(arr, func, 3) 

Para el tercer último valor, puede alimentar el reverso, arr[::-1] , y negar el resultado de len(arr) - 1 , el - 1 necesario para tener en cuenta la indexación de 0.

Evaluación comparativa del rendimiento

 # Python 3.6.5, NumPy 1.14.3, Numba 0.38.0 np.random.seed(0) arr = np.random.rand(10**7) m = 0.9 n = 0.999999 @njit def get_first_index_nb(A, k): for i in range(len(A)): if A[i] > k: return i return -1 def get_first_index_np(A, k): for i in range(len(A)): if A[i] > k: return i return -1 %timeit get_first_index_nb(arr, m) # 375 ns %timeit get_first_index_np(arr, m) # 2.71 µs %timeit next(iter(np.where(arr > m)[0]), -1) # 43.5 ms %timeit next((idx for idx, val in enumerate(arr) if val > m), -1) # 2.5 µs %timeit get_first_index_nb(arr, n) # 204 µs %timeit get_first_index_np(arr, n) # 44.8 ms %timeit next(iter(np.where(arr > n)[0]), -1) # 21.4 ms %timeit next((idx for idx, val in enumerate(arr) if val > n), -1) # 39.2 ms