Issue
Here is some sample REINFORCE code found in the PyTorch distributions docs:
from torch.distributions import Categorical

probs = policy_network(state)          # forward pass of the policy network
m = Categorical(probs)                 # categorical distribution over actions
action = m.sample()                    # sample an action from the distribution
next_state, reward = env.step(action)  # take the action in the environment
loss = -m.log_prob(action) * reward    # REINFORCE loss
loss.backward()
I don't understand why this loss is differentiable. In particular, how does m.log_prob(action) maintain the computational path back to the network output probs? How are m.log_prob(action) and probs 'connected'?
Edit: I looked at the implementation of log_prob, and it doesn't even seem to reference self.probs anywhere; it only uses self.logits.
Solution
As @lejlot noted in the comments, if a Categorical object is constructed with probs rather than logits, then logits is later defined in terms of probs. Hence, when logits is used in log_prob, the gradients from probs are propagated. I missed this connection between logits and probs because it is not made in __init__; instead, logits is defined as a lazy_property.
Answered By - Archie Gertsman