Issue
I am unsure why my code does not plot cos(x) (yes, I am aware PyTorch has a cos(x) function).
import math
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
x = torch.linspace(-math.pi, math.pi, 5000, requires_grad=True)
y = torch.sin(x)
y.backward(x)
x.grad == torch.cos(x) # assert x.grad same as cos(x)
plt.plot(x.detach().numpy(), y.detach().numpy(), label='sin(x)')
plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='cos(x)') # print derivative of sin(x)
Solution
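You need to pass the upstream gradient (all ones in your case), not x, as the argument to y.backward(). Calling y.backward(x) multiplies each derivative by the matching entry of x, so x.grad ends up holding x * cos(x) rather than cos(x).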
Thus
import math
import torch
import matplotlib.pyplot as plt
x = torch.linspace(-math.pi, math.pi, 5000, requires_grad=True)
y = torch.sin(x)
y.backward(torch.ones_like(x))
plt.plot(x.detach().numpy(), y.detach().numpy(), label='sin(x)')
plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='cos(x)') # plot derivative of sin(x)
plt.legend() # show the 'sin(x)' and 'cos(x)' labels
plt.show()
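As a quick sanity check on the explanation above, here is a minimal sketch (not part of the original answer) that compares the two calls numerically. The variable names and the smaller grid of 100 points are assumptions chosen for illustration.

```python
import math
import torch

# Smaller grid than the original 5000 points; the size is an arbitrary choice.
x = torch.linspace(-math.pi, math.pi, 100, requires_grad=True)

# Case 1: pass x as the upstream gradient, as in the question.
torch.sin(x).backward(x)
grad_with_x = x.grad.clone()

# Gradients accumulate in x.grad, so reset before the second backward pass.
x.grad.zero_()

# Case 2: pass all ones as the upstream gradient, as in the solution.
torch.sin(x).backward(torch.ones_like(x))
grad_with_ones = x.grad.clone()

with torch.no_grad():
    # backward(x) should yield x * cos(x); backward(ones) should yield cos(x).
    matches_x_cos = torch.allclose(grad_with_x, x * torch.cos(x))
    matches_cos = torch.allclose(grad_with_ones, torch.cos(x))

print(matches_x_cos, matches_cos)
```

Both flags should come out True if the reasoning holds. Since the upstream gradient is all ones, calling y.sum().backward() should give the same x.grad as y.backward(torch.ones_like(x)).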
Answered By - hkchengrex