Issue
I have the following Python code:
import json
import sys
import matplotlib.pyplot as plt
import numpy as np
if len(sys.argv) != 2:
print("Usage: python script_name.py <path_to_input_json>")
sys.exit(1)
filename = sys.argv[1]
with open(filename) as f:
data = json.load(f)
# Set up figure and axes
fig, ax = plt.subplots(figsize=(12, 10))
# Set up data for grid
periods = ['early_morning', 'morning', 'noon', 'afternoon', 'evening', 'night', 'late_night']
days = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
grid_data = np.full((len(periods), len(days)), '', dtype=object)
# Set up axes
ax.set_xticks(np.arange(len(days)))
ax.set_yticks(np.arange(len(periods)))
ax.set_xticklabels(days)
ax.set_yticklabels(periods)
ax.set_title('Weekly Schedule')
ax.set_xlabel('Day of the week')
ax.set_ylabel('Time Periods')
# Rotate x-axis labels to prevent overlap
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# Adjust spacing to prevent label overlap
fig.tight_layout()
# Populate grid
for i, day in enumerate(days):
for j, period in enumerate(periods):
cell_text = []
if day in data and period in data[day]:
if data[day][period]['accuracy_level'] >= 0.7:
for item in data[day][period]['study_material']:
if item['accuracy_level'] >= 0.7:
cell_text.append(item['course'])
cell_text.append(item['subject'])
if data[day][period]['accuracy_level'] >= 0.7:
cell_text.append(data[day][period]['platform_group'])
if cell_text:
ax.text(j, i, cell_text, ha='center', va='center')
# Save figure
plt.savefig('weekly_schedule.png')
It outputs the following result:
Please advise on how a few questions:
- how to align the text into the proper location on the grid, based on the
periods
anddays
values I have from my input data? - is it possible to draw the grid lines?
- is it possible to space 'early_morning' & 'late_night' values on the Y axis and 'Monday' & 'Sunday' on X axis respectively?
Solution
It would be easier to help you if you included a sample of data. That being said. Your code is almost correct. I noticed that you inverted x/y: ax.text(j, i,...)
should be ax.text(i, j,...
. Also there is no need to check twice if data[day][period]['accuracy_level'] >= 0.7
. Displaying cell_text
(a list) isn't readable, you could join its element with a line break for instance: '\n'.join(cell_text)
.
Showing the grid is easy: ax.grid(visible=True)
.
Finally if you want to add columns left/right and rows on top/bottom, you can specify it in ax.set_xticks
and ax.set_xticklabels
(and for y resp.). Note that you'll have to change the above mentionned ax.text(i, j,...
to ax.text(i+1, j+1,...
.
I created a little sample to show you the result:
import json
import sys
import matplotlib.pyplot as plt
import numpy as np
data = {
'Monday': {
'noon': {
'accuracy_level': 0.8,
'study_material': [{
'accuracy_level': 0.8,
'course': 'course1',
'subject': 'subject1'
}],
'platform_group': 'PG1'
},
'evening': {
'accuracy_level': 0.8,
'study_material': [{
'accuracy_level': 0.8,
'course': 'course2',
'subject': 'subject2'
}],
'platform_group': 'PG2'
}
},
'Wednesday': {
'early_morning': {
'accuracy_level': 0.8,
'study_material': [{
'accuracy_level': 0.8,
'course': 'course3',
'subject': 'subject3'
},
{
'accuracy_level': 0.8,
'course': 'course4',
'subject': 'subject4'
}],
'platform_group': 'PG3'
}
}
}
# Set up figure and axes
fig, ax = plt.subplots(figsize=(12, 10))
# Set up data for grid
periods = ['early_morning', 'morning', 'noon', 'afternoon', 'evening', 'night', 'late_night']
days = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
grid_data = np.full((len(periods), len(days)), '', dtype=object)
# Set up axes
ax.set_xticks(np.arange(len(days)+2))
ax.set_yticks(np.arange(len(periods)+2))
ax.set_xticklabels([''] + days + [''])
ax.set_yticklabels([''] + periods + [''])
ax.grid(visible=True)
ax.set_title('Weekly Schedule')
ax.set_xlabel('Day of the week')
ax.set_ylabel('Time Periods')
# Rotate x-axis labels to prevent overlap
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# Populate grid
for i, day in enumerate(days):
for j, period in enumerate(periods):
cell_text = []
if day in data and period in data[day]:
if data[day][period]['accuracy_level'] >= 0.7:
for item in data[day][period]['study_material']:
if item['accuracy_level'] >= 0.7:
cell_text.append(item['course'])
cell_text.append(item['subject'])
if data[day][period]['accuracy_level'] >= 0.7:
cell_text.append(data[day][period]['platform_group'])
if cell_text:
ax.text(i+1, j+1, '\n'.join(cell_text), ha='center', va='center')
plt.show()
Output:
Answered By - Tranbi
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.