Issue
torchview only shows the input dimensions of my network:
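import torch
from torchview import draw_graph
gpu_ids = [0]  # example GPU indices; adjust to your machine
# Conv2d takes in_channels, out_channels and kernel_size (16 and 3 are example values)
net = torch.nn.Sequential(torch.nn.Conv2d(3, 16, kernel_size=3))
net = torch.nn.DataParallel(net, gpu_ids)
draw_graph(model=net, input_size=(4, 3, 256, 256))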
I expected to see the full graph.
Solution
The problem was the DataParallel wrapper, which prevented torchview from seeing through to the network inside it.
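A quick way to see what the wrapper does is to inspect its children. The minimal sketch below (assuming a recent PyTorch; the layer sizes are illustrative only) shows that DataParallel registers the original network as a single child named module, which is exactly what net.module gives back.

```python
import torch.nn as nn

# A small example network (layer sizes are illustrative only).
inner = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3), nn.ReLU())

# DataParallel registers the original network as its only child, named "module".
# On a machine without CUDA it simply stores the module and runs it directly.
wrapped = nn.DataParallel(inner)

print(type(wrapped).__name__)                          # DataParallel
print([name for name, _ in wrapped.named_children()])  # ['module']
print(wrapped.module is inner)                         # True
```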
By "unpacking" it with net.module, everything works as expected:
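import torch
from torchview import draw_graph
gpu_ids = [0]  # example GPU indices; adjust to your machine
# Conv2d takes in_channels, out_channels and kernel_size (16 and 3 are example values)
net = torch.nn.Sequential(torch.nn.Conv2d(3, 16, kernel_size=3))
net = torch.nn.DataParallel(net, gpu_ids)
draw_graph(model=net.module, input_size=(4, 3, 256, 256))  # note unpacking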
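If the same code sometimes runs with and sometimes without a wrapper (for example DataParallel on one machine, DistributedDataParallel on another), a small helper can strip the wrapper only when it is present. This is a minimal sketch; unwrap_model is a hypothetical helper, not part of torchview or PyTorch, and it assumes the only wrappers you use are DataParallel and DistributedDataParallel, both of which keep the real network in .module.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Wrapper classes that keep the real network in a ".module" attribute.
_WRAPPERS = (nn.DataParallel, DistributedDataParallel)


def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the innermost network, stripping (possibly nested) DP/DDP wrappers."""
    while isinstance(model, _WRAPPERS):
        model = model.module
    return model


inner = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3))
wrapped = nn.DataParallel(inner)

print(unwrap_model(wrapped) is inner)  # True
print(unwrap_model(inner) is inner)    # True: plain modules pass through unchanged
```

With such a helper, the call becomes draw_graph(model=unwrap_model(net), input_size=(4, 3, 256, 256)) regardless of whether net is wrapped.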
Answered By - Klops