Issue
I have this CSV file with my data points:
index,0,1,2,3,4,5,6,7,8,9
0,0.1979067325592041,0.18397437781095505,0.17194999009370804,0.1615455374121666,0.15237035602331161,0.1441698521375656,0.13674741983413696,0.12997239083051682,0.1237289234995842,0.11788441985845566
1,0.11262453719973564,0.10770953074097633,0.10310722887516022,0.098851528018713,0.0949881486594677,0.09146557748317719,0.0882694236934185,0.08538785949349403,0.08278434351086617,0.08050092682242393
2,0.1496530845761299,0.1432030200958252,0.1372092440724373,0.13438178598880768,0.12992391735315323,0.13074920699000359,0.1275181733071804,0.1273222416639328,0.12758294492959976,0.12424736469984055
I would like to use Matplotlib to plot these three data rows as three lines in one line chart. The column names should be on the x-axis and each row's values on the y-axis. I wrote the program below to do this, but its output is not what I want.
import matplotlib.pyplot as plt
import pandas as pd
if __name__ == "__main__":
    test_losses = pd.read_csv(filepath_or_buffer="data.csv", index_col="index")
    plt.plot(test_losses)
    plt.title("Test Loss")
    plt.ylabel("MSE Loss")
    plt.xlabel("Epoch")
    plt.show()
How can I tell Matplotlib to draw each row as its own line in one plot, with the column names as the x-values and the row's elements as the y-values?
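A minimal sketch of why the program above draws the wrong lines, assuming pandas and a recent Matplotlib. `pd.read_csv` returns a DataFrame with one row per graph and one column per epoch. When `plot` receives a 2-D input it draws one line per column, using the DataFrame's row index as x, so you get one line per epoch instead of one per row. Transposing with `.T` swaps rows and columns. The inline `csv_text` is a hypothetical stand-in for `data.csv`, trimmed to four epochs.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical inline stand-in for data.csv (first four epoch columns only)
csv_text = """index,0,1,2,3
0,0.1979067325592041,0.18397437781095505,0.17194999009370804,0.1615455374121666
1,0.11262453719973564,0.10770953074097633,0.10310722887516022,0.098851528018713
2,0.1496530845761299,0.1432030200958252,0.1372092440724373,0.13438178598880768
"""
test_losses = pd.read_csv(io.StringIO(csv_text), index_col="index")
print(test_losses.shape)  # (3, 4): 3 rows (graphs) x 4 columns (epochs)

# Original approach: one line per *column*, x taken from the row index 0..2
fig, ax = plt.subplots()
wrong_lines = ax.plot(test_losses)
print(len(wrong_lines))  # one line per epoch column

# Transposed: each CSV row becomes a column, so each row becomes one line.
# Column names are read as strings ("0", "1", ...), so convert them to numbers for x.
epochs = test_losses.columns.astype(int).to_numpy()
fig, ax = plt.subplots()
right_lines = ax.plot(epochs, test_losses.T.to_numpy())
print(len(right_lines))  # one line per CSV row
```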
Solution
One way to produce this plot is shown below.
data = [
['index',0,1,2,3,4,5,6,7,8,9],
[0,0.1979067325592041,0.18397437781095505,0.17194999009370804,0.1615455374121666,0.15237035602331161,0.1441698521375656,0.13674741983413696,0.12997239083051682,0.1237289234995842,0.11788441985845566],
[1,0.11262453719973564,0.10770953074097633,0.10310722887516022,0.098851528018713,0.0949881486594677,0.09146557748317719,0.0882694236934185,0.08538785949349403,0.08278434351086617,0.08050092682242393],
[2,0.1496530845761299,0.1432030200958252,0.1372092440724373,0.13438178598880768,0.12992391735315323,0.13074920699000359,0.1275181733071804,0.1273222416639328,0.12758294492959976,0.12424736469984055]
]
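import matplotlib.pyplot as plt
# data[0] holds the column names (the epochs); the first element of every
# other row is its index, not a loss value, so skip it with [1:]
epochs = data[0][1:]
fig, axes = plt.subplots(1, 1, figsize=(6, 4))
axes.plot(epochs, data[1][1:], marker=".", label="Graph 1", color='tab:blue')
axes.plot(epochs, data[2][1:], marker=".", label="Graph 2", color='tab:red')
axes.plot(epochs, data[3][1:], marker=".", label="Graph 3", color='tab:green')
axes.legend()  # uses the label= of each line
# axes.set_yscale('log')  # uncomment for a logarithmic y-axis
axes.set_title("Test Loss")
axes.set_xlabel("Epoch")
axes.set_ylabel("MSE Loss")
plt.show()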
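The answer above embeds a copy of the data and hard-codes one `plot` call per row. A sketch that reads the CSV with pandas and loops over the rows instead, so it works for any number of rows, assuming pandas is installed. `csv_text` is a hypothetical inline stand-in for `data.csv` (trimmed to three epochs), and the `Row N` legend labels are an illustrative choice.

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line to get an interactive window
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical inline stand-in for data.csv (first three epoch columns only)
csv_text = """index,0,1,2
0,0.1979067325592041,0.18397437781095505,0.17194999009370804
1,0.11262453719973564,0.10770953074097633,0.10310722887516022
2,0.1496530845761299,0.1432030200958252,0.1372092440724373
"""
test_losses = pd.read_csv(io.StringIO(csv_text), index_col="index")

# Column names are read as strings ("0", "1", ...); convert them to numbers for the x-axis
epochs = test_losses.columns.astype(int).to_numpy()

fig, ax = plt.subplots(figsize=(6, 4))
# One line per CSV row, labelled with that row's index value
for row_label, losses in test_losses.iterrows():
    ax.plot(epochs, losses.to_numpy(), marker=".", label=f"Row {row_label}")

ax.set_title("Test Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE Loss")
ax.legend()
# With an interactive backend, call plt.show() here to open the window
```

With the real file, pass `"data.csv"` to `pd.read_csv` instead of `io.StringIO(csv_text)`, drop the `matplotlib.use("Agg")` line, and finish with `plt.show()`.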
Answered By - Jiho Choi