Issue
So let's say I have an input X and a sequential network of net A, net B, and net C. If I detach net B and put X through A -> B -> C, do I lose gradient information from A because B is detached? I would assume not? I'm assuming it would just treat B like a constant added to the output of A rather than something differentiable.
Solution
TL;DR: Preventing gradient computation on B won't stop gradients from being computed for the upstream network A.
I think there is some confusion about what you consider "detaching a model". In my opinion, there are three distinct behaviours to keep in mind here:
1. You can detach a tensor, which effectively removes it from the computational graph: if this tensor is then used to compute another tensor requiring gradient, the backpropagation step will not propagate past this "detached" tensor (a minimal sketch of this is shown at the end of this answer).

2. In your way of describing "detaching a model", you can disable gradient computation on given layers of your network by switching requires_grad to False on their parameters. This can be done in a single line at the module level with nn.Module.requires_grad_. So in your case, B.requires_grad_(False) will freeze the parameters of B such that they can't be updated. In other words, the gradients of the parameters of B won't be computed; however, the intermediate gradients used to propagate to A will! Here is a minimal example:

>>> import torch
>>> import torch.nn as nn

>>> A = nn.Linear(10, 10)
>>> B = nn.Linear(10, 10)
>>> C = nn.Linear(10, 10)

# disable gradient computation on B
>>> B.requires_grad_(False)

# dummy input, inference, and backpropagation
>>> x = torch.rand(1, 10, requires_grad=True)
>>> C(B(A(x))).mean().backward()
We can now check that the gradients of C and A have indeed been filled properly:
>>> A.weight.grad.sum()
tensor(0.3281)
>>> C.weight.grad.sum()
tensor(-1.6335)
However, of course, B.weight.grad returns None.
3. Lastly, yet another behaviour appears when using the no_grad context manager: it disables gradient tracking entirely for everything computed inside it. If you do something like:

>>> yA = A(x)
>>> with torch.no_grad():
...     yB = B(yA)
>>> yC = C(yB)
Here the graph is cut at yB: yC is detached from everything upstream of the no_grad block. Calling backward on yC will still fill the gradients of C's parameters, but nothing will propagate back through B to A or x.
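As a quick check, we can continue the session above. Note that the gradients of A, C, and x were already populated by the earlier backward call, so we clear them first:

# clear the gradients left over from the earlier example
>>> A.weight.grad = None
>>> C.weight.grad = None
>>> x.grad = None

>>> yC.mean().backward()

# C's parameters still receive gradients...
>>> C.weight.grad is None
False

# ...but nothing propagates back past the no_grad block
>>> A.weight.grad is None
True
>>> x.grad is None
True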
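And circling back to the first behaviour, here is what detaching a tensor looks like in practice. This is a minimal sketch using fresh, hypothetical layers A2 and C2 (so that no gradients linger from the examples above): cutting the graph with detach means nothing upstream of the cut receives gradients.

# fresh layers, so no gradients linger from the examples above
>>> A2, C2 = nn.Linear(10, 10), nn.Linear(10, 10)
>>> x2 = torch.rand(1, 10, requires_grad=True)

# detach the output of A2: the graph is cut at this tensor
>>> y = A2(x2).detach()
>>> C2(y).mean().backward()

# C2 receives gradients, but A2 and x2 do not
>>> C2.weight.grad is None
False
>>> A2.weight.grad is None
True
>>> x2.grad is None
True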
Answered By - Ivan