Issue
I wrote a module that needs an extra loss term, e.g.:
    class MyModule(nn.Module):
        def forward(self, x):
            out = f(x)
            extra_loss = loss_f(self.parameters(), x)
            return out, extra_loss
I can't figure out how to make this module embeddable, for example, into a Sequential model: any regular module like Linear put after this one will fail because extra_loss causes the input to Linear to be a tuple, which Linear does not support.
So what I am looking for is a way to extract that extra loss after running the model forward:
    my_module = MyModule()
    model = Sequential(
        my_module,
        Linear(my_module_outputs, 1)
    )
    output = model(x)
    my_module_loss = ????
    loss = mse(label, output) + my_module_loss
Does module composability support this scenario?
Solution
IMHO, hooks are overkill here. Provided extra_loss is additive, we can accumulate it in a class-level (effectively global) variable like this:
    class MyModule(nn.Module):
        extra_loss = 0  # class-level accumulator shared by all instances

        def forward(self, x):
            out = f(x)
            MyModule.extra_loss += loss_f(self.parameters(), x)
            return out

    output = model(x)
    loss = mse(label, output) + MyModule.extra_loss
    MyModule.extra_loss = 0  # reset before the next forward pass
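For reference, here is a minimal runnable sketch of the approach above, assuming PyTorch. The inner Linear layer with a ReLU stands in for f, an L2 penalty on the module's parameters stands in for loss_f, and the layer sizes, the `weight` scale, and the random data are all illustrative assumptions rather than anything from the question.

```python
import torch
from torch import nn


class MyModule(nn.Module):
    # Class-level accumulator shared by every MyModule instance.
    extra_loss = 0

    def __init__(self, in_features, out_features, weight=1e-3):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.weight = weight  # hypothetical scale for the extra term

    def forward(self, x):
        out = torch.relu(self.linear(x))  # stand-in for f(x)
        # Stand-in for loss_f: an L2 penalty on this module's parameters.
        penalty = sum(p.pow(2).sum() for p in self.parameters())
        # Out-of-place add, so the accumulated tensor is never modified in place.
        MyModule.extra_loss = MyModule.extra_loss + self.weight * penalty
        return out  # a single tensor, so the next layer can consume it


torch.manual_seed(0)
model = nn.Sequential(MyModule(4, 8), nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
mse = nn.MSELoss()

x = torch.randn(16, 4)      # illustrative random inputs
label = torch.randn(16, 1)  # illustrative random targets

for step in range(3):
    optimizer.zero_grad()
    output = model(x)
    extra = MyModule.extra_loss        # accumulated during this forward pass
    loss = mse(output, label) + extra
    loss.backward()
    optimizer.step()
    MyModule.extra_loss = 0            # reset before the next forward pass
```

Because the accumulator lives on the class, it sums the terms of every MyModule in the model, and forgetting the reset would carry a stale graph into the next step.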
Answered By - Alexey Birukov