Issue
I have learned that a forward hook function has the form hook_fn(m, x, y), where m refers to the module, x to the input and y to the output. I want to write a forward hook function for nn.Transformer.
However, a transformer takes two inputs, src and tgt, for example >>> out = transformer_model(src, tgt). So how can I tell these inputs apart?
Solution
Your callback receives the module's positional inputs packed into a tuple x, together with the module's output y. This is described in the documentation of torch.nn.Module.register_forward_hook (though it doesn't spell out the types of x and y very explicitly):
The input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. [...].
Build the model and some dummy inputs:

import torch
from torch import nn

model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)
Define your callback:
def hook(module, x, y):
    # x is the tuple of positional arguments passed to forward()
    print(f'is tuple={isinstance(x, tuple)} - length={len(x)}')
    src, tgt = x
    print(f'src: {src.shape}')
    print(f'tgt: {tgt.shape}')
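If you want to keep the inputs around rather than just print them, a common variant (a sketch; the captured dict and capture_hook name are my own) stashes them from inside the hook:

captured = {}

def capture_hook(module, x, y):
    # unpack the positional inputs and keep them for later inspection
    captured['src'], captured['tgt'] = x
    captured['out'] = y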
Register the hook on your nn.Module:
>>> model.register_forward_hook(hook)
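Note that register_forward_hook returns a removable handle; if you keep it (instead of the bare call above), you can detach the hook once you're done:

handle = model.register_forward_hook(hook)
# ... run inferences while the hook is active ...
handle.remove()  # the hook no longer fires after this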
Run an inference:
>>> out = model(src, tgt)
is tuple=True - length=2
src: torch.Size([10, 32, 512])
tgt: torch.Size([20, 32, 512])
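One caveat, per the documentation excerpt above: if you pass tgt as a keyword argument, it won't show up in x at all. Recent PyTorch releases (2.0 and later, to the best of my knowledge) let the hook receive keyword arguments too via with_kwargs=True. A minimal sketch, using a fresh model_kw (a name of my own) so the print hook registered above doesn't fire and fail on a single positional input:

model_kw = nn.Transformer(nhead=16, num_encoder_layers=12)

def hook_kw(module, args, kwargs, output):
    # args holds the positional inputs, kwargs the keyword ones
    print(f'positional: {len(args)} - keyword: {list(kwargs)}')

handle_kw = model_kw.register_forward_hook(hook_kw, with_kwargs=True)
out = model_kw(src, tgt=tgt)   # prints: positional: 1 - keyword: ['tgt']
handle_kw.remove()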
Answered By - Ivan