Issue
I have found and adapted the following code snippets for generating diagnostic plots for linear regression. This is currently done using the following functions:
def residual_plot(some_values):
    # NOTE(review): question code, shown as posted.
    # - The function has no return statement, so callers always get None.
    # - It calls plt.show() itself, so each plot is displayed on its own and
    #   cannot be gathered into a shared grid afterwards.
    plot_lm_1 = plt.figure(1)
    # NOTE(review): this rebinds plot_lm_1, discarding the figure created
    # above. residplot is also called without any data: ``some_values`` is
    # never used in this function.
    plot_lm_1 = sns.residplot()
    plot_lm_1.axes[0].set_title('title')
    plot_lm_1.axes[0].set_xlabel('label')
    plot_lm_1.axes[0].set_ylabel('label')
    plt.show()
def qq_plot(residuals):
    # NOTE(review): question code, shown as posted.
    # - ProbPlot is not imported in this snippet; presumably it is
    #   statsmodels' ProbPlot. Confirm.
    # - There is no return statement (callers get None), and plt.show() is
    #   called here, so the plot cannot be placed into a shared grid later.
    QQ = ProbPlot(residuals)
    # Presumably qqplot() creates its own figure and returns it, which is
    # why the result is indexed with .axes[0] below. Verify against the
    # library in use.
    plot_lm_2 = QQ.qqplot()
    plot_lm_2.axes[0].set_title('title')
    plot_lm_2.axes[0].set_xlabel('label')
    plot_lm_2.axes[0].set_ylabel('label')
    plt.show()
which are called with something like:
# NOTE(review): residual_plot and qq_plot (as defined in the question) have
# no return statement, so plot1 ... plot4 are all bound to None. Each call
# has already drawn and shown its own plot by the time it returns.
plot1 = residual_plot(value_set1)
plot2 = qq_plot(value_set1)
plot3 = residual_plot(value_set2)
plot4 = qq_plot(value_set2)
How can I create subplots so that these four plots are displayed in a 2x2 grid?
I have tried using:
fig, axes = plt.subplots(2,2)
# ``axes[0,0].plot1`` looks up an attribute literally named "plot1" on the
# Axes object. It does not refer to the variable plot1, which is why the
# AttributeError below is raised. (plot1 is None in any case, because
# residual_plot returns nothing.)
axes[0,0].plot1
axes[0,1].plot2
axes[1,0].plot3
axes[1,1].plot4
plt.show()
but I receive the error:
AttributeError: 'AxesSubplot' object has no attribute 'plot1'
Should I set up the axes from within the functions, or somewhere else?
Solution
Create a single figure with four subplot axes, and pass each of those axes into your custom plot functions. Each function then draws into the axes it is given instead of creating its own figure, as follows:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import probplot
def residual_plot(x, y, axes=None):
    """Draw the residuals of a simple linear regression of ``y`` on ``x``.

    Parameters
    ----------
    x, y : array-like
        Predictor and response values, forwarded to ``seaborn.residplot``.
    axes : matplotlib Axes, optional
        Axes to draw into, e.g. one cell of a ``plt.subplots`` grid. When
        omitted, a new figure with a single axes is created so the function
        also works stand-alone.

    Returns
    -------
    Whatever ``seaborn.residplot`` returns (per the seaborn docs, the Axes
    that was drawn on).
    """
    if axes is None:
        axes = plt.figure().add_subplot(1, 1, 1)
    # Pass x/y by keyword: seaborn >= 0.12 makes them keyword-only (passing
    # them positionally raises TypeError), and keywords work on older
    # releases too.
    p = sns.residplot(x=x, y=y, ax=axes)
    axes.set_xlabel("Data")
    axes.set_ylabel("Residual")
    axes.set_title("Residuals")
    return p
def qq_plot(x, axes=None, xlim=(-3, 3)):
    """Draw a Q-Q (probability) plot of ``x`` with ``scipy.stats.probplot``.

    Parameters
    ----------
    x : array-like
        Sample values (e.g. regression residuals).
    axes : matplotlib Axes, optional
        Axes to draw into, e.g. one cell of a ``plt.subplots`` grid. When
        omitted, a new figure with a single axes is created.
    xlim : (float, float) or None, optional
        Limits for the theoretical-quantile (x) axis. The default
        ``(-3, 3)`` keeps the previous behaviour. Pass ``None`` to keep
        matplotlib's automatic limits, so points are not clipped when the
        sample is large enough for its extreme theoretical quantiles to
        fall outside +/-3.

    Returns
    -------
    The ``((osm, osr), (slope, intercept, r))`` tuple from ``probplot``.
    """
    if axes is None:
        axes = plt.figure().add_subplot(1, 1, 1)
    # probplot draws the ordered sample and its fitted line onto ``axes``
    # and sets the title and axis labels itself.
    p = probplot(x, plot=axes)
    if xlim is not None:
        axes.set_xlim(*xlim)
    return p
if __name__ == "__main__":
    # Synthetic data: one linear signal plus two independent noise draws.
    x = np.arange(100)
    y = 0.5 * x
    y1 = y + np.random.randn(100)
    y2 = y + np.random.randn(100)

    # One figure holding a 2x2 grid of axes. Each helper draws into the
    # axes it is handed instead of creating a figure of its own.
    fig, axs = plt.subplots(2, 2, figsize=(8, 8), facecolor="white")

    # Row 0 shows the first data set, row 1 the second:
    # residuals in the left column, Q-Q plot in the right column.
    for row, y_noisy in enumerate((y1, y2)):
        residual_plot(y, y_noisy, axs[row, 0])
        qq_plot(y_noisy, axs[row, 1])

    fig.tight_layout()
    # plt.show() runs the GUI event loop and blocks until the window is
    # closed. Figure.show() does not, so in a plain script the window may
    # appear only briefly, or not at all.
    plt.show()
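I do not know where your ProbPlot comes from, so I used SciPy's `probplot` instead. (If it is statsmodels' `ProbPlot`, its `qqplot` method also accepts an `ax` argument, so the same approach applies.)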
Answered By - Kefeng91
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.