Issue
I have an objective function where I am looking to maximize a loss w.r.t. f1 and f2, which are encoders, while at the same time minimizing it w.r.t. g, which is a bijective convolution. X is just an image.
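For readers unfamiliar with the setup, here is a minimal sketch of what a bijective convolution g might look like, assuming a Glow-style invertible 1x1 convolution that mixes the RGB channels of X. The class `Invertible1x1Conv` and the helper `split_r_gb` are hypothetical and are not the question's `self.infomin` code. The random orthogonal initialization only guarantees invertibility at the start of training; keeping g invertible while it is trained needs extra care (Glow, for instance, offers an LU parameterization).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Invertible1x1Conv(nn.Module):
    """Hypothetical stand-in for g: a bijective 1x1 convolution over channels."""

    def __init__(self, channels: int = 3):
        super().__init__()
        # Random orthogonal matrix, so the map is invertible at initialization
        w, _ = torch.linalg.qr(torch.randn(channels, channels))
        self.weight = nn.Parameter(w.contiguous())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Per-pixel channel mixing: out[:, i] = sum_j W[i, j] * x[:, j]
        c = self.weight.shape[0]
        return F.conv2d(x, self.weight.reshape(c, c, 1, 1))

    def inverse(self, y: torch.Tensor) -> torch.Tensor:
        # Undo the mixing with the inverse matrix
        c = self.weight.shape[0]
        w_inv = torch.linalg.inv(self.weight)
        return F.conv2d(y, w_inv.reshape(c, c, 1, 1))


def split_r_gb(x: torch.Tensor):
    # Hypothetical helper: channel 0 is the view for f1, channels 1-2 the view for f2
    return x[:, :1], x[:, 1:]
```

Usage: `view1, view23 = split_r_gb(Invertible1x1Conv()(torch.rand(2, 3, 8, 8)))` yields a 1-channel view for f1 and a 2-channel view for f2.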
Here is how I assume it's supposed to be done.
obs_ch1, obs_ch23 = self.infomin.split_RGB_into_R_GB(obs=obs_RGB)
obs_ch1, obs_ch23 = self.infomin.R_GB_to_frame_stacked_R_GB(
    obs_R=obs_ch1, obs_GB=obs_ch23)

# minimize wrt g
logits = self.infomin.compute_logits(anchor=obs_ch1, pos=obs_ch23)
# positive pairs sit on the diagonal, so sample i's label is i
# (labels must exist before the first loss is computed)
labels = torch.arange(logits.shape[0]).long().to(self.device)
loss = self.cross_entropy_loss(logits, labels)
self.infomin_discrim_optimizer.zero_grad()
loss.backward()
self.infomin_discrim_optimizer.step()

# maximize wrt f1, f2
# detach g's outputs so this backward pass stops at the encoders
logits = self.infomin.compute_logits(
    anchor=obs_ch1.detach(), pos=obs_ch23.detach())
loss = self.cross_entropy_loss(logits, labels)
self.infomin_encoders_optimizer.zero_grad()
(-loss).backward()  # negate the loss to ascend on it
self.infomin_encoders_optimizer.step()
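The same two-step update can be reproduced end to end on toy modules. This is a sketch under stated assumptions: `g`, `f1`, `f2`, `compute_logits` and `train_step` are hypothetical stand-ins (a 1x1 convolution for g, tiny convolutional encoders, cosine-similarity logits), not the question's `self.infomin` implementation, and the signs follow the question (descend on the cross-entropy w.r.t. g, ascend on it w.r.t. f1/f2).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy modules: g mixes the RGB channels, f1/f2 embed the two views
g = nn.Conv2d(3, 3, kernel_size=1, bias=False)
f1 = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten())
f2 = nn.Sequential(nn.Conv2d(2, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten())

g_opt = torch.optim.Adam(g.parameters(), lr=1e-3)
enc_opt = torch.optim.Adam(list(f1.parameters()) + list(f2.parameters()), lr=1e-3)


def compute_logits(anchor, pos):
    # Cosine similarity between every anchor and every positive: shape (N, N)
    return F.normalize(f1(anchor), dim=1) @ F.normalize(f2(pos), dim=1).t()


def train_step(x):
    out = g(x)
    view1, view23 = out[:, :1], out[:, 1:]
    labels = torch.arange(x.shape[0])  # matching pairs lie on the diagonal

    # Step 1: descend w.r.t. g. Gradients also reach f1/f2, but only g_opt steps.
    loss_g = F.cross_entropy(compute_logits(view1, view23), labels)
    g_opt.zero_grad()
    loss_g.backward()
    g_opt.step()

    # Step 2: ascend w.r.t. f1/f2. Detaching keeps this pass out of g's graph.
    loss_f = F.cross_entropy(compute_logits(view1.detach(), view23.detach()), labels)
    enc_opt.zero_grad()  # clears the encoder gradients left over from step 1
    (-loss_f).backward()
    enc_opt.step()
    return loss_g.item(), loss_f.item()
```

Because the encoders have not been stepped yet when the second loss is computed, `loss_f` has the same value as `loss_g`; what differs between the two passes is which parameters receive the update.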
So f1 and f2 share the same optimizer, infomin_encoders_optimizer. I calculate the NCE loss w.r.t. just g, then I detach the output of g and calculate the NCE loss again w.r.t. f1 and f2. I take the negative of that second value and backpropagate the two losses separately. I do it this way because I can't do them together: they push the objective in opposite directions. I also have to detach what comes out of g, because otherwise I get an error about computing gradients through g a second time.
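The detach requirement can be reproduced with scalars. In this hypothetical sketch, `w` stands in for g's parameters and `v` for the encoders'. By default `backward()` frees the intermediate values autograd saved for the graph it just traversed, so a second backward pass that still runs through g's output raises a `RuntimeError`, while detaching that output keeps the second pass on the encoders only. This is a likely explanation of the error described above, not a reproduction of its exact message.

```python
import torch


def two_step_update(detach_g_output: bool):
    w = torch.tensor(2.0, requires_grad=True)  # stands in for g's parameters
    v = torch.tensor(1.0, requires_grad=True)  # stands in for f1/f2's parameters
    x = torch.tensor(3.0)                      # stands in for the image X

    h = w * x                # "output of g"; autograd saves x to compute dh/dw
    loss_g = (v * h) ** 2    # first objective, used to update w
    loss_g.backward()        # frees the saved tensors of this whole graph

    v.grad = None            # start the encoder step from clean gradients
    src = h.detach() if detach_g_output else h
    loss_f = -(v * src) ** 2  # second objective, negated to ascend on v
    loss_f.backward()         # without detach this re-enters h's freed graph
    return w.grad.item(), v.grad.item()
```

With `detach_g_output=True` the call returns `(36.0, -72.0)`: w keeps only the gradient from the first loss, and v gets only the gradient of the negated second loss.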
How are you supposed to minimize and maximize an objective at the same time w.r.t. different parameters?
Solution
This is actually the proper way to optimize a min-max objective. You can't solve this kind of problem with a single optimization step, simply because both loss functions (for f1/f2 and for g) are based on a result computed with I_NCE, and the two parameter groups want to push that result in opposite directions. This means you are required to infer twice: first to compute the objective for g, then a second time to compute the objective for f1/f2.
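The alternating scheme can be seen on a toy saddle-point problem. In this sketch (the function `saddle` and both scalars are hypothetical), `x` minimizes and `y` maximizes `L(x, y) = 0.5*x**2 + x*y - 0.5*y**2`. Each iteration runs two forward passes, one per player, and each optimizer only steps its own parameter, the same structure as the question's code.

```python
import torch


def saddle(x, y):
    # Convex in x (the minimizing player), concave in y (the maximizing player)
    return 0.5 * x ** 2 + x * y - 0.5 * y ** 2


x = torch.tensor(1.0, requires_grad=True)  # plays the role of g
y = torch.tensor(1.0, requires_grad=True)  # plays the role of f1/f2
opt_x = torch.optim.SGD([x], lr=0.1)
opt_y = torch.optim.SGD([y], lr=0.1)

for _ in range(500):
    # First forward pass: objective for the minimizing player
    loss = saddle(x, y)
    opt_x.zero_grad()
    loss.backward()
    opt_x.step()

    # Second forward pass with the freshly updated x: objective for the maximizer
    loss = saddle(x, y)
    opt_y.zero_grad()
    (-loss).backward()  # descending on -loss is ascending on loss
    opt_y.step()

print(x.item(), y.item())  # both approach the saddle point at (0, 0)
```

Alternating the two updates converges to (0, 0) here because the toy problem is strongly convex in x and strongly concave in y; real min-max problems come with no such guarantee.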
Note that this procedure is very similar, if not identical, to training a generative adversarial network (GAN).
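For comparison, here is a minimal GAN step written the same way, a sketch with hypothetical linear `G`/`D` modules. Structurally, g corresponds to the generator (it produces the inputs) and f1/f2 to the discriminator (they score them): the discriminator step detaches the generator's output, and the generator step goes back through the discriminator while only the generator's optimizer steps. The generator here ascends on the discriminator's loss for fake samples, the literal min-max form; in practice GANs often use the non-saturating generator loss instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

G = nn.Linear(4, 2)  # hypothetical generator: noise -> fake sample
D = nn.Linear(2, 1)  # hypothetical discriminator: sample -> real/fake logit
opt_G = torch.optim.SGD(G.parameters(), lr=0.1)
opt_D = torch.optim.SGD(D.parameters(), lr=0.1)


def gan_step(real):
    n = real.shape[0]
    fake = G(torch.randn(n, 4))
    ones, zeros = torch.ones(n, 1), torch.zeros(n, 1)

    # Discriminator step: fake.detach() keeps this backward pass out of G
    d_loss = (F.binary_cross_entropy_with_logits(D(real), ones)
              + F.binary_cross_entropy_with_logits(D(fake.detach()), zeros))
    opt_D.zero_grad()
    d_loss.backward()
    opt_D.step()

    # Generator step: a second pass through the updated D; only opt_G steps
    g_loss = -F.binary_cross_entropy_with_logits(D(fake), zeros)
    opt_G.zero_grad()
    g_loss.backward()
    opt_G.step()
    return d_loss.item(), g_loss.item()
```

Usage: calling `gan_step(torch.randn(8, 2) + 3.0)` once updates both modules, each through its own optimizer.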
Answered By - Ivan