Issue
I would like to plot a matrix that contains a combination of float and NaN values. This is a 3D plot where X and Y are the matrix coordinates and Z is the value within the matrix.
NaN values should be ignored. It would be great if matplotlib would fill in the surface between the float values, but it is OK if it won't.
This is an adaptation of the code that I have tried thus far. It should plot the 3 data points that have been assigned manually, but instead, it produces an empty 3D plot.
# Minimal reproduction: a 10x10 matrix that is NaN everywhere except for three
# hand-assigned cells. As reported above, plot_surface shows an empty 3D plot.
# Not referenced by name; older matplotlib releases need this import to
# register projection='3d'.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import matplotlib.pyplot as plt
import numpy  # was missing: every numpy.* call below raised NameError

fig = plt.figure(figsize=(20, 15))
ax = fig.add_subplot(111, projection='3d')

# Index bounds of the matrix region to plot.
X0 = 0
Xmax = 10
Y0 = 0
Ymax = 10
Xfill, Yfill = numpy.meshgrid(range(X0, Xmax), range(Y0, Ymax))

# Start from an all-NaN matrix and set the three known values manually.
data_matrix = numpy.full(shape=[Xmax, Ymax], fill_value=numpy.nan)
data_matrix[5, 5] = 3
data_matrix[1, 8] = 6
data_matrix[7, 2] = 0.5

ax.plot_surface(Xfill, Yfill, data_matrix[X0:Xmax, Y0:Ymax],
                color='blue', rstride=1, cstride=1)
plt.show()
Solution
I've shown the two options I know of below: one uses a scatter plot, and the other draws a surface through an arbitrary set of points.
# Plot only the non-NaN cells of a matrix in 3D: once as a scatter plot and
# once as a triangulated surface that does not need a regular grid.
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(111, projection='3d')

X0 = 0
Xmax = 10
Y0 = 0
Ymax = 10
# indexing='ij' keeps the grids in the same axis order as data_matrix
# (first axis = x, second axis = y); the default 'xy' order is transposed.
Xfill, Yfill = np.meshgrid(range(X0, Xmax), range(Y0, Ymax), indexing='ij')
data_matrix = np.full(shape=[Xmax, Ymax], fill_value=np.nan)
data_matrix[5, 5] = 3
data_matrix[1, 8] = 6
data_matrix[7, 2] = 0.5

# Pull out the non-NaN data points: np.nonzero on the validity mask returns
# the row (x) and column (y) indices together, so the mask is scanned once.
x_valid, y_valid = np.nonzero(~np.isnan(data_matrix))
data_valid = data_matrix[x_valid, y_valid]

# Scatter plot of individual points
ax.scatter(x_valid, y_valid, data_valid, c='tab:red',
           s=60, label='scatter', depthshade=False)
# Also works somewhat:
# ax.scatter(Xfill, Yfill, data_matrix)

# Overlay a surface plot that doesn't require a regular grid.
# NOTE: the triangulation needs at least three non-collinear points.
ax.plot_trisurf(x_valid, y_valid, data_valid,
                cmap='jet', label='trisurf plot', alpha=0.7)
Optional further formatting:
# Some additional flourishes (continues the script above: uses ax, x_valid,
# y_valid and data_valid defined there).
# roll=0 was dropped: 0 is the default, and the keyword is only accepted by
# newer matplotlib releases.
ax.view_init(azim=20, elev=45)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('data')

# Vertical lines from each point: one segment from (x, y, 0) up to (x, y, z).
from mpl_toolkits.mplot3d.art3d import Line3DCollection
lines = [((x, y, 0), (x, y, z))
         for x, y, z in zip(x_valid, y_valid, data_valid)]
ax.add_collection(Line3DCollection(lines, linewidth=3,
                                   color='tab:orange',
                                   label='vertical projection'))
plt.gcf().legend()
# Needed to display the figure when this is run as a plain script.
plt.show()
Answered By - user3128
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.