Utilizando estimaciones de statsmodel con la validación cruzada de scikit-learn, ¿es posible?

Publiqué esta pregunta en el foro de validación cruzada y luego me di cuenta de que podría encontrar la audiencia adecuada en Stackoverlfow.

Estoy buscando una forma en la que pueda usar el objeto de fit (resultado) que se obtiene de python statsmodel para ingresar al cross_val_score del método scikit-learn cross_validation? El enlace adjunto sugiere que puede ser posible pero no he tenido éxito.

Estoy teniendo el siguiente error

el estimador debe ser un estimador que implementa el método ‘fit’ statsmodels.discrete.discrete_model.BinaryResultsWrapper objeto en 0x7fa6e801c590 fue pasado

Consulte este enlace

De hecho, no puede usar cross_val_score directamente en los objetos statsmodels , debido a la diferente interfaz: en statsmodels

  • Los datos de entrenamiento se pasan directamente al constructor.
  • Un objeto separado contiene el resultado de la estimación del modelo.

Sin embargo, puede escribir un contenedor simple para hacer que los objetos de statsmodels se vean como estimadores sklear :

 import statsmodels.api as sm from sklearn.base import BaseEstimator, RegressorMixin class SMWrapper(BaseEstimator, RegressorMixin): """ A universal sklearn-style wrapper for statsmodels regressors """ def __init__(self, model_class, fit_intercept=True): self.model_class = model_class self.fit_intercept = fit_intercept def fit(self, X, y): if self.fit_intercept: X = sm.add_constant(X) self.model_ = self.model_class(y, X) self.results_ = self.model_.fit() def predict(self, X): if self.fit_intercept: X = sm.add_constant(X) return self.results_.predict(X) 

Esta clase contiene los métodos correctos de fit y predict , y se puede usar con sklear , por ejemplo, con sklear cruzada o incluida en una tubería. Como aquí:

 from sklearn.datasets import make_regression from sklearn.model_selection import cross_val_score from sklearn.linear_model import LinearRegression X, y = make_regression(random_state=1, n_samples=300, noise=100) print(cross_val_score(SMWrapper(sm.OLS), X, y, scoring='r2')) print(cross_val_score(LinearRegression(), X, y, scoring='r2')) 

Puede ver que la salida de dos modelos es idéntica, porque ambos son modelos OLS, con validación cruzada de la misma manera.

 [0.28592315 0.37367557 0.47972639] [0.28592315 0.37367557 0.47972639]