Encuentra filas coincidentes en una matriz numpy de 2 dimensiones

Me gustaría obtener el índice de una matriz Numpy bidimensional que coincide con una fila. Por ejemplo, mi matriz es esta:

vals = np.array([[0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [2, 1], [0, 2], [1, 2], [2, 2], [0, 3], [1, 3], [2, 3], [0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [2, 1], [0, 2], [1, 2], [2, 2], [0, 3], [1, 3], [2, 3]]) 

Me gustaría obtener el índice que coincide con la fila [0, 1] que es el índice 3 y 15. Cuando hago algo como numpy.where(vals == [0 ,1]) obtengo …

 (array([ 0, 3, 3, 4, 5, 6, 9, 12, 15, 15, 16, 17, 18, 21]), array([0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0])) 

Quiero matriz de índice ([3, 15]).

Necesita la función np.where para obtener los índices:

 >>> np.where((vals == (0, 1)).all(axis=1)) (array([ 3, 15]),) 

O, como indica la documentación:

Si solo se da una condición, return condition.nonzero()

Puede llamar directamente a .nonzero() en la matriz devuelta por .all :

 >>> (vals == (0, 1)).all(axis=1).nonzero() (array([ 3, 15]),) 

Para desmontar eso:

 >>> vals == (0, 1) array([[ True, False], [False, False], ... [ True, False], [False, False], [False, False]], dtype=bool) 

y llamar al método .all en esa matriz (con axis=1 ) le da True donde ambos son verdaderos:

 >>> (vals == (0, 1)).all(axis=1) array([False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, True, False, False, False, False, False, False, False, False], dtype=bool) 

y para obtener cuales índices son True :

 >>> np.where((vals == (0, 1)).all(axis=1)) (array([ 3, 15]),) 

o

 >>> (vals == (0, 1)).all(axis=1).nonzero() (array([ 3, 15]),) 

Encuentro mi solución un poco más legible, pero como señala Unutbu, lo siguiente puede ser más rápido y devuelve el mismo valor que (vals == (0, 1)).all(axis=1) :

 >>> (vals[:, 0] == 0) & (vals[:, 1] == 1) 
 In [5]: np.where((vals[:,0] == 0) & (vals[:,1]==1))[0] Out[5]: array([ 3, 15]) 

No estoy seguro de por qué, pero esto es significativamente más rápido que
np.where((vals == (0, 1)).all(axis=1)) :

 In [34]: vals2 = np.tile(vals, (1000,1)) In [35]: %timeit np.where((vals2 == (0, 1)).all(axis=1))[0] 1000 loops, best of 3: 808 µs per loop In [36]: %timeit np.where((vals2[:,0] == 0) & (vals2[:,1]==1))[0] 10000 loops, best of 3: 152 µs per loop 

Usando el paquete numpy_indexed , simplemente puede escribir:

 import numpy_indexed as npi print(np.flatnonzero(npi.contains([[0, 1]], vals)))