Create a Function in Python to Create a 3D Scatterplot from Training Data
I'm new to Matplotlib and to writing functions, but I'm trying to create a function called plot3Ddata that accepts a Pandas DataFrame (composed of 3 spatial coordinates) and uses scatter3D() to plot only the training data (the x and y feature columns in my dataset below, with z as the label). I'm able to produce a 3D plot, but it's not a scatterplot, and it isn't limited to my training data. I'd appreciate any help, as I'm not sure how to proceed.
The first 5 rows of the dataset, which I read in as a DataFrame:
|   | x | y | z |
|---|---|---|---|
| 0 | 6.550561 | 0.918746 | 5.056359 |
| 1 | 11.314821 | 0.675399 | 2.048655 |
| 2 | 0.001797 | 0.250325 | 4.429342 |
| 3 | 4.749025 | -0.644546 | 0.565993 |
| 4 | 2.305234 | 0.024039 | 6.768186 |
## Here's my function called plot3Ddata, which accepts a DataFrame:
```python
import numpy as np
import pandas as pd

def plot3Ddata(df: pd.DataFrame):
    # Transform Pandas data into a format that's compatible with
    # Matplotlib's surface and wireframe plotting.
    index = df.index
    columns = df.columns
    x, y = np.meshgrid(np.arange(len(columns)), np.arange(len(index)))
    z = df.to_numpy()  # heights: one value per (row, column) cell
    xticks = dict(ticks=np.arange(len(columns)), labels=columns)
    yticks = dict(ticks=np.arange(len(index)), labels=index)
    return x, y, z, xticks, yticks
```
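To see why this transform produces a surface rather than points, it may help to look at what np.meshgrid returns for a frame shaped like the sample above. A minimal sketch, assuming a 5-row, 3-column DataFrame holding the sample values: the grids contain column and row *positions* (0 to 2 and 0 to 4), not the x and y coordinates, so every cell of the DataFrame becomes a height over an integer grid.

```python
import numpy as np
import pandas as pd

# The five sample rows shown above
df = pd.DataFrame(
    {
        "x": [6.550561, 11.314821, 0.001797, 4.749025, 2.305234],
        "y": [0.918746, 0.675399, 0.250325, -0.644546, 0.024039],
        "z": [5.056359, 2.048655, 4.429342, 0.565993, 6.768186],
    }
)

# Same grid construction as plot3Ddata: positions, not coordinate values
gx, gy = np.meshgrid(np.arange(len(df.columns)), np.arange(len(df.index)))

print(gx.shape)   # (5, 3): one entry per DataFrame cell
print(gx[0])      # [0 1 2]: column positions
print(gy[:, 0])   # [0 1 2 3 4]: row positions
```

A scatterplot instead needs one point per row, using the x, y and z values themselves as coordinates.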
Here I combined my feature columns (x and y) into a numpy.ndarray and also specified my label (z):
```python
# Features: .values converts the two-column DataFrame to a 2-D numpy.ndarray;
# .reshape(-1, 2) is a no-op here because the array already has shape (n, 2)
X = data[['x', 'y']].values.reshape(-1, 2)
# Label: the last column (z) as a 1-D array
z = data.iloc[:, -1].values.ravel()
```
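A quick shape check on the sample rows, assuming data is the DataFrame shown above: selecting two columns and converting already gives a 2-D array of shape (n, 2), so the extra reshape changes nothing, and data.iloc[:, -1] picks out z as a 1-D array. This sketch uses to_numpy(), which pandas recommends over .values.

```python
import numpy as np
import pandas as pd

data = pd.DataFrame(
    {
        "x": [6.550561, 11.314821, 0.001797, 4.749025, 2.305234],
        "y": [0.918746, 0.675399, 0.250325, -0.644546, 0.024039],
        "z": [5.056359, 2.048655, 4.429342, 0.565993, 6.768186],
    }
)

X = data[["x", "y"]].to_numpy()   # 2-D array, one row per sample
z = data.iloc[:, -1].to_numpy()   # 1-D array of labels (the z column)

print(X.shape)                              # (5, 2)
print(z.shape)                              # (5,)
print(np.array_equal(X, X.reshape(-1, 2)))  # True: reshape is a no-op here
```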
Then I split the data using train_test_split:
```python
from sklearn.model_selection import train_test_split

X_train, X_test, z_train, z_test = train_test_split(X, z, test_size=0.20, random_state=42)
```
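With only the five sample rows, test_size=0.20 leaves four training rows and one test row. The sketch below, assuming scikit-learn is installed and data holds the sample rows, also checks that the split keeps each z value paired with its own (x, y) row, which is what makes it possible to plot the training points on their own afterwards.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

data = pd.DataFrame(
    {
        "x": [6.550561, 11.314821, 0.001797, 4.749025, 2.305234],
        "y": [0.918746, 0.675399, 0.250325, -0.644546, 0.024039],
        "z": [5.056359, 2.048655, 4.429342, 0.565993, 6.768186],
    }
)
X = data[["x", "y"]].to_numpy()
z = data["z"].to_numpy()

X_train, X_test, z_train, z_test = train_test_split(X, z, test_size=0.20, random_state=42)

print(X_train.shape, X_test.shape)  # (4, 2) (1, 2)

# Every training row still carries its original label
for (xv, yv), zv in zip(X_train, z_train):
    match = data[(data["x"] == xv) & (data["y"] == yv)]
    assert match["z"].iloc[0] == zv
```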
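Then I plot the data:
```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Put the training features and labels side by side as named columns.
# (pd.DataFrame(X_train, z_train) would use z_train as the row index instead.)
train_df = pd.DataFrame(np.column_stack([X_train, z_train]), columns=['x', 'y', 'z'])

# Transform to Matplotlib-friendly format.
x, y, z, xticks, yticks = plot3Ddata(train_df)

# Set up 3D axes and put the data on a surface.
# (Recent Matplotlib versions no longer accept gca(projection='3d'); add_subplot does the same job.)
axes = plt.figure().add_subplot(projection='3d')
axes.plot_surface(x, y, z)
axes.set_zlim3d(bottom=0)
plt.xticks(**xticks)
plt.yticks(**yticks)
plt.show()
```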
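One way to get a scatterplot of only the training points is to skip the meshgrid transform entirely and pass the three columns straight to scatter3D(). A minimal sketch, not a definitive implementation: it assumes the DataFrame has columns named x, y and z, builds train_df with np.column_stack so that z_train becomes a column (passing it as the second positional argument to pd.DataFrame would make it the index instead), and returns the axes so the caller can adjust labels or limits.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for scripts; drop this line to see a window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


def plot3Ddata(df: pd.DataFrame):
    """Scatter-plot the x, y, z columns of df on a new set of 3D axes."""
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.scatter3D(df["x"], df["y"], df["z"])  # one marker per row
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    return ax


data = pd.DataFrame(
    {
        "x": [6.550561, 11.314821, 0.001797, 4.749025, 2.305234],
        "y": [0.918746, 0.675399, 0.250325, -0.644546, 0.024039],
        "z": [5.056359, 2.048655, 4.429342, 0.565993, 6.768186],
    }
)
X = data[["x", "y"]].to_numpy()
z = data["z"].to_numpy()
X_train, X_test, z_train, z_test = train_test_split(X, z, test_size=0.20, random_state=42)

# Features and label side by side as named columns, training rows only
train_df = pd.DataFrame(np.column_stack([X_train, z_train]), columns=["x", "y", "z"])

ax = plot3Ddata(train_df)
plt.show()
```

Because the function only reads three named columns, a test-set frame built the same way from X_test and z_test can be passed to it as well, for example to compare the two sets visually.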
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow

