¿Cómo utilizar la distancia de mahalanobis en sklearn DistanceMetrics?

Tal vez esto sea elemental, pero no puedo encontrar un buen ejemplo del uso de la distancia de sklearn en sklearn .

Ni siquiera puedo obtener la métrica de esta manera:

 from sklearn.neighbors import DistanceMetric DistanceMetric.get_metric('mahalanobis') 

Esto genera un error: TypeError: 0-dimensional array given. Array must be at least two-dimensional TypeError: 0-dimensional array given. Array must be at least two-dimensional .

Pero, parece que ni siquiera puedo hacer que tome una matriz:

 DistanceMetric.get_metric('mahalanobis', [[0.5],[0.7]]) 

tiros

 TypeError: get_metric() takes exactly 1 positional argument (2 given) 

Revisé los documentos aquí y aquí . Pero, no veo qué tipos de argumentos está esperando.
¿Hay un ejemplo del uso de la distancia de Mahalanobis que puedo ver?

MahalanobisDistance está esperando un parámetro V que es la matriz de covarianza y, opcionalmente, otro parámetro VI que es el inverso de la matriz de covarianza. Además, ambos parámetros se nombran y no son posicionales.

También revise la cadena de documentación para la clase MahalanobisDistance en el archivo scikit-learn/sklearn/neighbors/dist_metrics.pyx en el repository de sklearn .

Ejemplo:

 In [18]: import numpy as np In [19]: from sklearn.datasets import make_classification In [20]: from sklearn.neighbors import DistanceMetric In [21]: X, y = make_classification() In [22]: DistanceMetric.get_metric('mahalanobis', V=np.cov(X)) Out[22]:  

Editar:

Por algunas razones (¿error?), No puede pasar el objeto de distancia al constructor NearestNeighbor , pero necesita usar el nombre de la métrica de distancia. Además, establecer algorithm='auto' (que por defecto es 'ball_tree' ) no parece funcionar; así que dada la X del código anterior puedes hacer:

 In [23]: nn = NearestNeighbors(algorithm='brute', metric='mahalanobis', metric_params={'V': np.cov(X)}) # returns the 5 nearest neighbors of that sample In [24]: nn.fit(X).kneighbors(X[0, :]) Out[24]: (array([[ 0., 3.21120892, 3.81840748, 4.18195987, 4.21977517]]), array([[ 0, 36, 46, 5, 17]]))