Reshaping data for sklearn's PolynomialFeatures completely scrambled up the original data
The first plot below shows a scatterplot of the original data with a simple linear regression line fit to it. From what I understand, in order to use sklearn's PolynomialFeatures the original x data first has to be reshaped into a 2-D column array (a short sketch of what I mean follows).
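For illustration, here is a minimal sketch with made-up x values (not the data from this question) showing the 2-D shape that PolynomialFeatures expects:

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

x = np.array([0, 1, 2, 3, 4])   # 1-D array, shape (5,)
X = x[:, np.newaxis]            # 2-D column array, shape (5, 1); x.reshape(-1, 1) is equivalent
xp = PolynomialFeatures(degree=2).fit_transform(X)
print(xp.shape)                 # (5, 3): the columns are [1, x, x**2]
```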
I wanted to do a polynomial fit to the data, but I ended up with this:
Notice that the scatterplot has been scrambled after reshaping the data, even though it comes from the same dataset. What is extremely puzzling is that the same code works perfectly fine with a different dataset.
Can anyone figure out what exactly the issue is here? Below is a minimal working example of my code. Note that one set of data has been commented out; to switch between the two datasets, just comment/uncomment the corresponding lines:
```python
import matplotlib.pyplot as plt
import statsmodels.api as sm
import numpy as np
import copy
from sklearn.preprocessing import PolynomialFeatures
BRforscatter = [0, 1, 0, 2, 1, 3, 5, 3, 7, 6, 6, 6, 7, 5, 8, 6, 7, 7, 8, 7, 6, 5, 2, 4, 7, 7, 0, 13, 3, 3, 3, 3, 2, 4, 4, 6, 5, 3, 5, 4, 6, 3, 2, 2, 4, 6, 7, 7, 6, 6, 7, 6, 8, 7, 0, 4, 7, 6, 8, 8, 10, 9, 8, 4, 7, 8, 6, 8, 8, 7, 8, 5, 9, 7, 9, 6, 4, 9, 10, 10, 7, 7, 7, 8, 7, 8, 9, 11, 5, 7, 9, 8, 8, 7, 7, 6, 6, 8, 9, 7, 7, 10, 8, 11, 9, 9, 11, 0]
RTforscatter = [0.517, 0.433, 0.417, 0.4, 0.433, 0.633, 0.717, 0.733, 0.667, 0.7, 0.45, 0.467, 0.5, 0.567, 0.533, 0.45, 0.433, 0.416, 0.433, 0.45, 0.5, 0.483, 0.483, 0.467, 0.451, 0.4, 0.5, 0.517, 0.467, 0.45, 0.567, 0.417, 0.533, 0.5, 0.5, 0.5, 0.483, 0.5, 0.483, 0.483, 0.417, 0.483, 0.467, 0.533, 0.45, 0.516, 0.483, 0.5, 0.5, 0.483, 0.5, 0.533, 0.533, 0.417, 0.467, 0.7, 0.483, 0.667, 0.467, 0.433, 0.5, 0.433, 0.501, 0.55, 0.483, 0.533, 0.483, 0.567, 0.483, 0.5, 0.45, 0.467, 0.483, 0.748, 0.433, 0.5, 0.514, 0.533, 0.45, 0.483, 0.5, 0.55, 0.8, 0.501, 0.45, 0.483, 0.533, 0.467, 0.5, 0.533, 0.533, 0.45, 0.533, 0.467, 0.5, 0.533, 0.45, 0.533, 0.55, 0.467, 0.5, 0.65, 0.5, 0.517, 0.483, 0.45, 0.55, 0.533]
#BRforscatter = [1.559, 2.951, 4.565, 4.763, 5.716, 7.735, 9.014, 9.63, 11.053, 12.744, 14.092, 18.792, 21.419, 23.726, 24.914, 25.588, 26.731, 27.0, 29.213, 30.952, 36.083, 39.809, 41.17, 46.825, 55.648, 58.439]
#RTforscatter = [204.213, 216.411, 194.764, 180.967, 220.363, 216.88, 205.671, 206.404, 191.447, 172.676, 212.343, 192.509, 184.937, 190.604, 186.404, 197.046, 201.1, 173.751, 189.407, 173.324, 179.159, 196.855, 178.001, 207.638, 200.469, 188.234]
def simple_regplot(
    x, y, p=False, n_std=2, n_pts=100, ax=None, scatter_kws=None, line_kws=None, ci_kws=None
):
    """Draw a regression line with error interval."""
    ax = plt.gca() if ax is None else ax

    # calculate best-fit line and interval
    x_fit = sm.add_constant(x)
    fit_results = sm.OLS(y, x_fit).fit()
    eval_x = sm.add_constant(np.linspace(np.min(x), np.max(x), n_pts))
    pred = fit_results.get_prediction(eval_x)

    # draw the fit line and error interval
    ci_kws = {} if ci_kws is None else ci_kws
    ax.fill_between(
        eval_x[:, 1],
        pred.predicted_mean - n_std * pred.se_mean,
        pred.predicted_mean + n_std * pred.se_mean,
        alpha=0.5,
        **ci_kws,
    )
    line_kws = {} if line_kws is None else line_kws
    h = ax.plot(eval_x[:, 1], pred.predicted_mean, **line_kws)

    # draw the scatterplot
    scatter_kws = {} if scatter_kws is None else scatter_kws
    ax.scatter(x, y, c=h[0].get_color(), **scatter_kws, alpha=0.2)

    if p:
        print(fit_results.summary())
    return fit_results


simple_regplot(BRforscatter, RTforscatter)
plt.show()
## ---------------------
allBlinksFull = copy.deepcopy(BRforscatter)
allRTFull = copy.deepcopy(RTforscatter)

# reshape x into a 2-D column array, then sort it
allBlinksFull = np.array(allBlinksFull)
allBlinksFull = allBlinksFull[:, np.newaxis]
inds = allBlinksFull.ravel().argsort()
allBlinksFull = allBlinksFull.ravel()[inds].reshape(-1, 1)

# reshape y into a 2-D column array, then reorder it by an argsort of allBlinksFull
allRTFull = np.array(allRTFull)
allRTFull = allRTFull[:, np.newaxis]
inds2 = allBlinksFull.ravel().argsort()
allRTFull = allRTFull.ravel()[inds2].reshape(-1, 1)

# build degree-2 polynomial features and fit them with OLS
polynomial_features = PolynomialFeatures(degree=2)
xp = polynomial_features.fit_transform(allBlinksFull)
model = sm.OLS(allRTFull, xp).fit()
ypred = model.predict(xp)

plt.scatter(allBlinksFull, allRTFull)
plt.plot(allBlinksFull, ypred)
plt.show()
```
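A likely explanation, based purely on reading the snippet above rather than on any posted solution: `inds` is computed from the original, unsorted `allBlinksFull` and is used to sort it, but `inds2` is then computed by calling `argsort()` on the already sorted array. The argsort of a sorted array is (up to ties) just the identity permutation, so `allRTFull` stays essentially in its original order while `allBlinksFull` has been reordered; the (x, y) pairs no longer correspond, which is exactly the "scrambling" visible in the scatterplot. It would also explain why the commented-out dataset works fine: its x values are already in ascending order, so both argsorts reduce to the identity and the pairs stay matched. Below is a minimal sketch of one way to fix it, assuming the only goal is to have the points sorted by x before fitting; variable names follow the original code:

```python
# Sketch of a possible fix: compute the sort order once, from the original x,
# and apply the same permutation to both x and y so the (x, y) pairs stay together.
import numpy as np
import statsmodels.api as sm
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures

allBlinksFull = np.array(BRforscatter)
allRTFull = np.array(RTforscatter)

order = allBlinksFull.argsort()                      # permutation that sorts x
allBlinksFull = allBlinksFull[order].reshape(-1, 1)  # sorted x as a column array
allRTFull = allRTFull[order].reshape(-1, 1)          # y reordered with the same indices

xp = PolynomialFeatures(degree=2).fit_transform(allBlinksFull)
model = sm.OLS(allRTFull, xp).fit()
ypred = model.predict(xp)

plt.scatter(allBlinksFull, allRTFull)
plt.plot(allBlinksFull, ypred)
plt.show()
```

Note that `PolynomialFeatures` includes a bias column by default, so `xp` already carries the intercept term and `sm.add_constant` is not needed here, just as in the original code.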
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow