Issue
I am stuck tracing a PyTorch model. This specific module fails with the following error:
RuntimeError: 0 INTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/alias_analysis.cpp":611, please report a bug to PyTorch. We don't have an op for aten::fill_ but it isn't a special case. Argument types: Tensor, bool,
Candidates:
aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> (Tensor(a!))
aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> (Tensor(a!))
Here is the reduced code example to reproduce the bug:
import torch
import torch.nn.functional as F
import torch.nn as nn

class SurroundPattern(nn.Module):
    def __init__(self, crop_size=1./2):
        super(SurroundPattern, self).__init__()
        self.crop_size = crop_size

    def forward(self, x, s):
        H, W = x.shape[2:]
        crop_h = (int(H / 2 - self.crop_size / 2 * H), int(H / 2 + self.crop_size / 2 * H))
        crop_w = (int(W / 2 - self.crop_size / 2 * W), int(W / 2 + self.crop_size / 2 * W))
        x_mask = torch.zeros(H, W, device=x.device, dtype=torch.bool)
        x_mask[crop_h[0] : crop_h[1], crop_w[0] : crop_w[1]] = True
        inside_indices = torch.where(x_mask)
        inside_part = x[:, :, inside_indices[0], inside_indices[1]]
        inside_feat = inside_part.mean(2)
        outside_indices = torch.where(~x_mask)
        outside_part = x[:, :, outside_indices[0], outside_indices[1]]
        outside_feat = outside_part.mean(2)
        fused = torch.stack([inside_feat, outside_feat], dim=2).unsqueeze(3)
        if s is None:
            return fused
        SH, SW = s.shape[2:]
        crop_sh = (int(SH / 2 - self.crop_size / 2 * SH), int(SH / 2 + self.crop_size / 2 * SH))
        crop_sw = (int(SW / 2 - self.crop_size / 2 * SW), int(SW / 2 + self.crop_size / 2 * SW))
        s_mask = torch.zeros(SH, SW, device=s.device, dtype=torch.bool)
        s_mask[crop_sh[0] : crop_sh[1], crop_sw[0] : crop_sw[1]] = True
        s_inside_indices = torch.where(s_mask)
        inside_sal = s[:, :, s_inside_indices[0], s_inside_indices[1]].flatten(1)
        s_outside_indices = torch.where(~s_mask)
        outside_sal = s[:, :, s_outside_indices[0], s_outside_indices[1]].flatten(1)
        if outside_sal.shape != inside_sal.shape:
            outside_sal = F.adaptive_max_pool1d(outside_sal.unsqueeze(1), output_size=784)
            outside_sal = outside_sal.squeeze(1)
        fused_sal = torch.stack([inside_sal, outside_sal], dim=2).unsqueeze(3)
        return fused, fused_sal

x = torch.randn(2, 512, 7, 7)
s = torch.randn(2, 1, 56, 56)
patt = SurroundPattern()
traced_cell = torch.jit.trace(patt, (x, s))
print(traced_cell)
How can I figure out exactly where the problem is? Is there a way to fix it using other functions? Thank you!
Solution
The problem is that you are filling a bool tensor in place (the slice assignment of True), which is apparently not yet supported by the JIT (or is a bug).
Replacing this:
x_mask = torch.zeros(H, W, device=x.device, dtype=torch.bool)
x_mask[crop_h[0] : crop_h[1], crop_w[0] : crop_w[1]] = True
with:
x_mask = torch.zeros(H, W, device=x.device)
x_mask[crop_h[0] : crop_h[1], crop_w[0] : crop_w[1]] = 1
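should resolve the error. The same change applies to s_mask. This is of course suboptimal, because the mask is now a float tensor rather than the intended bool type, but you should be able to perform most of the operations you would use with a torch.BoolTensor. One exception in the code above is ~x_mask, which is not defined for float tensors; use x_mask == 0 there instead, or convert the mask with x_mask.bool() after the assignment.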
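For illustration, here is a minimal sketch of the workaround applied to a trimmed-down version of the module. The class name CenterSurroundMean is hypothetical and only the x branch is kept. It assumes that casting the float mask with .bool() after the fill is traceable, which avoids the in-place fill on a bool tensor while keeping ~mask usable:

```python
import torch
import torch.nn as nn


class CenterSurroundMean(nn.Module):
    """Hypothetical, trimmed-down SurroundPattern (x branch only)."""

    def __init__(self, crop_size=1. / 2):
        super().__init__()
        self.crop_size = crop_size

    def forward(self, x):
        H, W = x.shape[2:]
        top = int(H / 2 - self.crop_size / 2 * H)
        bottom = int(H / 2 + self.crop_size / 2 * H)
        left = int(W / 2 - self.crop_size / 2 * W)
        right = int(W / 2 + self.crop_size / 2 * W)

        # Workaround: build the mask as float and fill with 1 instead of True.
        mask = torch.zeros(H, W, device=x.device)
        mask[top:bottom, left:right] = 1
        # Assumption: the cast happens after the fill, so no bool fill_ is traced.
        mask = mask.bool()

        inside = torch.where(mask)
        outside = torch.where(~mask)
        inside_feat = x[:, :, inside[0], inside[1]].mean(2)
        outside_feat = x[:, :, outside[0], outside[1]].mean(2)
        return torch.stack([inside_feat, outside_feat], dim=2).unsqueeze(3)


x = torch.randn(2, 512, 7, 7)
module = CenterSurroundMean()
# Tracing may emit TracerWarnings about converting shapes to Python ints.
traced = torch.jit.trace(module, (x,))
out = traced(x)
print(out.shape)
```

Note that the crop bounds are computed from the example input's shape, so the traced module is tied to that spatial size.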
Answered By - Proko