Issue
While looking at PyTorch tutorials, I noticed that they call instances of the classes they define as if they were functions.
e.g.
# Making an instance of the class NeuralNetwork
model = NeuralNetwork()
for batch, (X, y) in enumerate(dataloader):
    # Compute prediction and loss
    pred = model(X)
    loss = loss_fn(pred, y)
    # Backpropagation
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
I wasn't aware that this could be done in Python, so I'm confused about what is going on and what is getting called. I know that the forward() function (a function required in any class that inherits from nn.Module) is the movement of data through the neural network, but I'm not sure when it gets called or how model(X) works. Also, while trying to find an answer, I came across this link, where a contributor said "A pytorch model is a function. You provide it with appropriately defined input, and it returns an output. If you just want to visually inspect the output given a specific input image, simply call it:". This left me more confused, because the model is a class, right?
Solution
There are a few layers to this. First, the pure Python layer:
Python objects have various dunder (double underscore) methods such as __init__, __repr__, and others. One such method is __call__. If you call a Python object like a function, Python invokes the __call__ method defined on the object's class.
class MyClass:
    def __call__(self, input):
        return print(input)

my_class = MyClass()
# doing this
my_class(5)
# is the same as
my_class.__call__(5)
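One reason to define __call__ rather than write a plain function is that the object can carry state between calls. Here is a minimal sketch of that idea; the Scaler class is a hypothetical example and does not come from the tutorial.

```python
class Scaler:
    """Hypothetical example: a callable object that remembers a factor."""

    def __init__(self, factor):
        self.factor = factor  # state stored on the instance

    def __call__(self, value):
        # Runs whenever the instance is called like a function
        return value * self.factor


double = Scaler(2)
result = double(5)                       # Python looks up __call__ on type(double)
same = type(double).__call__(double, 5)  # the explicit equivalent
print(result, same, callable(double))
```

A neural network module is the same pattern at a larger scale: the stored state is the weights, and calling the instance applies them to the input.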
Now the PyTorch layer:
PyTorch nn.Module classes use the __call__ method to invoke the module. The full implementation of what happens is in the nn.Module._call_impl function (link).
nn.Module._call_impl makes use of the forward method you implement in your custom module. The full _call_impl logic involves things that happen both before and after forward, such as running any registered forward hooks. PyTorch is designed to abstract these things away from the user: you only need to implement your isolated forward logic, and the rest is handled for you.
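To make the "before and after forward" idea concrete, here is a simplified pure-Python sketch. TinyModule and its two hook lists are hypothetical stand-ins written for illustration. They are not PyTorch's actual code, and the real _call_impl handles more cases than this.

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module; not PyTorch's real code."""

    def __init__(self):
        self.pre_hooks = []   # run before forward, receive the input
        self.post_hooks = []  # run after forward, receive input and output

    def __call__(self, x):
        for hook in self.pre_hooks:
            hook(x)
        out = self.forward(x)  # the only part the user writes
        for hook in self.post_hooks:
            hook(x, out)
        return out

    def forward(self, x):
        raise NotImplementedError("subclasses implement forward")


class AddOne(TinyModule):
    def forward(self, x):
        return x + 1


events = []
model = AddOne()
model.pre_hooks.append(lambda x: events.append(("before", x)))
model.post_hooks.append(lambda x, out: events.append(("after", out)))

pred = model(3)  # __call__ runs the hooks around forward
print(pred, events)
```

This is also why tutorials write model(X) rather than model.forward(X): calling forward directly would skip the surrounding steps.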
Now the ML theory layer:
When someone says "A pytorch model is a function", what they mean is that ML models are mathematical functions. They are not referring to the programming-level class/function distinction.
Now a deeper PyTorch layer:
Most PyTorch classes are actually wrappers around a functional implementation at the C++/CUDA level. For example, look at the forward method of nn.Conv2d. The forward method of the module ends up calling nn.functional.conv2d, which in turn calls the lower-level implementation of the convolution. The functional versions take both the inputs and the weights as arguments, and the module passes in the weights it stores itself.
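The module/functional split can be sketched in plain Python. Here linear and Linear are hypothetical, simplified analogues of a functional op and the module that wraps it. They use lists instead of tensors and are not PyTorch's implementation.

```python
def linear(inputs, weight, bias):
    """Stateless 'functional' version: everything is passed in."""
    return [
        sum(w * x for w, x in zip(row, inputs)) + b
        for row, b in zip(weight, bias)
    ]


class Linear:
    """Stateful wrapper: owns the weights and forwards them to linear()."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def forward(self, inputs):
        # The module supplies its own parameters; the caller supplies inputs
        return linear(inputs, self.weight, self.bias)

    def __call__(self, inputs):
        return self.forward(inputs)


layer = Linear(weight=[[1.0, 2.0], [0.0, -1.0]], bias=[0.5, 0.0])
out = layer([3.0, 4.0])
print(out)
```

The class adds no computation of its own. It only saves the caller from passing the weights on every call.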
Answered By - Karl