Parcela modelo lineal en 3d con matplotlib.

Estoy tratando de crear una gráfica 3D de un ajuste de modelo lineal para un conjunto de datos. Pude hacerlo con relativa facilidad en R, pero estoy realmente luchando para hacer lo mismo en Python. Esto es lo que he hecho en R:

Trama 3d

Esto es lo que he hecho en Python:

from mpl_toolkits.mplot3d import Axes3D import matplotlib.pyplot as plt import numpy as np import pandas as pd import statsmodels.formula.api as sm csv = pd.read_csv('http://www-bcf.usc.edu/~gareth/ISL/Advertising.csv', index_col=0) model = sm.ols(formula='Sales ~ TV + Radio', data = csv) fit = model.fit() fit.summary() fig = plt.figure() ax = fig.add_subplot(111, projection='3d') ax.scatter(csv['TV'], csv['Radio'], csv['Sales'], c='r', marker='o') xx, yy = np.meshgrid(csv['TV'], csv['Radio']) # Not what I expected :( # ax.plot_surface(xx, yy, fit.fittedvalues) ax.set_xlabel('TV') ax.set_ylabel('Radio') ax.set_zlabel('Sales') plt.show() 

¿Qué estoy haciendo mal y qué debo hacer en su lugar?

Gracias.

Tenías razón al suponer que plot_surface quiere un meshgrid de coordenadas para trabajar, pero predice que quiere una estructura de datos como la que utilizaste (el “exog”).

 exog = pd.core.frame.DataFrame({'TV':xx.ravel(),'Radio':yy.ravel()}) out = fit.predict(exog=exog) ax.plot_surface(xx, yy, out.reshape(xx.shape), color='None') 

¡Lo tengo!

El problema al que me refiero en los comentarios a la respuesta de mdurant es que la superficie no se representa como un patrón cuadrado agradable como estos. Combinando el diagtwig de dispersión con el gráfico de la superficie .

Me di cuenta de que el problema era mi meshgrid , así que meshgrid ambos rangos ( y ) y usé pasos proporcionales para np.arange .

¡Esto me permitió usar el código provisto por la respuesta de mdurant y funcionó perfectamente!

Aquí está el resultado:

Diagrama de dispersión 3D con plano OLS

Y aquí está el código:

 from mpl_toolkits.mplot3d import Axes3D import matplotlib.pyplot as plt import numpy as np import pandas as pd import statsmodels.formula.api as sm from matplotlib import cm csv = pd.read_csv('http://www-bcf.usc.edu/~gareth/ISL/Advertising.csv', index_col=0) model = sm.ols(formula='Sales ~ TV + Radio', data = csv) fit = model.fit() fit.summary() fig = plt.figure() ax = fig.add_subplot(111, projection='3d') x_surf = np.arange(0, 350, 20) # generate a mesh y_surf = np.arange(0, 60, 4) x_surf, y_surf = np.meshgrid(x_surf, y_surf) exog = pd.core.frame.DataFrame({'TV': x_surf.ravel(), 'Radio': y_surf.ravel()}) out = fit.predict(exog = exog) ax.plot_surface(x_surf, y_surf, out.reshape(x_surf.shape), rstride=1, cstride=1, color='None', alpha = 0.4) ax.scatter(csv['TV'], csv['Radio'], csv['Sales'], c='blue', marker='o', alpha=1) ax.set_xlabel('TV') ax.set_ylabel('Radio') ax.set_zlabel('Sales') plt.show()