Issue
Problem:
When I load a PyTorch model (i.e. load its state_dict from disk) in a subprocess, a cmd window pops up for a couple of milliseconds, which makes other programs lose focus. This is annoying when I am working on something else.
I have traced the cause to two lines, each of which causes it under some circumstances, and I managed to reproduce it for one of them (the second one is model.to(device)).
main.py
import torch
from testing.agent.torch_model import Net
from testing.agent.AgentOpenSim_Process import Open_Agent_Sim

model_path = 'testing\\agent\\model_test.pth'

if __name__ == '__main__':
    # create model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # I have CUDA available
    m = Net()
    m.to(device)
    # save it
    torch.save(m.state_dict(), model_path)
    # open it in a subprocess
    p = Open_Agent_Sim(p=model_path, msgLogger=None)
    p.start()
    # keep the parent alive: the child is a daemon and is terminated when the parent exits
    p.join()
torch_model.py
(source: PyTorch docs, https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html)
import torch.nn as nn
import torch.nn.functional as F  # F.relu lives in torch.nn.functional, not torch.functional

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
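The hard-coded 16 * 5 * 5 in fc1 and x.view(...) only holds for one input size. Assuming 32x32 RGB inputs (an assumption: the state-dict recipe never runs a forward pass, but this is the image size of the CIFAR-10 classifier tutorial that uses the same Net), the arithmetic can be checked with the standard output-size formula for Conv2d/MaxPool2d with dilation 1. conv_out and net_feature_size are illustrative helper names.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Output side length of Conv2d / MaxPool2d (dilation 1, floor rounding).
    return (size + 2 * padding - kernel) // stride + 1

def net_feature_size(side: int = 32) -> int:
    """Flattened feature count entering fc1 for an assumed square side x side input."""
    side = conv_out(side, kernel=5)            # conv1: 5x5 kernel, stride 1
    side = conv_out(side, kernel=2, stride=2)  # pool: 2x2 window, stride 2
    side = conv_out(side, kernel=5)            # conv2: 5x5 kernel, stride 1
    side = conv_out(side, kernel=2, stride=2)  # pool again
    return 16 * side * side                    # conv2 produces 16 channels

print(net_feature_size(32))  # 400 == 16 * 5 * 5 (32 -> 28 -> 14 -> 10 -> 5)
```

Any other input size changes the flattened length, so x.view(-1, 16 * 5 * 5) would either fail or silently regroup the batch.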
AgentOpenSim_Process.py
from multiprocessing import Queue, Process
import os, time, torch
from testing.agent.torch_model import Net

class Open_Agent_Sim(Process):
    def __init__(self, p: str, **kwargs):
        super().__init__(daemon=True)
        self.path = p
        msg_logger = kwargs.get('msgLogger')
        self._msgLogger = msg_logger if msg_logger is not None else Queue()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device_cpu = torch.device("cpu")

    def __print(self, msg: str, verbosity):
        # personal tool for debugging multiprocessing: it sends messages to a Queue, and the
        # main process reads them in a thread and prints them to the console (and yes, I know
        # about the multiprocessing logger; it's used as well)
        self._msgLogger.put(('Open_Agent_Sim: ' + msg, verbosity))

    def run(self):
        self.__pid = os.getpid()
        try:
            self.__print('opening model', 0)
            self.init_agent()
            self.__print('opening model - done', 0)
        except Exception as e:
            # solved by custom exception wrapper
            pass
        else:
            self.__print('has ended', 0)
        return

    def init_agent(self):
        # init instance
        self.__print('0a', 0)
        m = Net()
        self.__print('0b', 0)
        time.sleep(2)
        # load state dict
        self.__print('1a', 0)
        l = torch.load(self.path, map_location=self.device_cpu)
        self.__print('1b', 0)
        time.sleep(2)
        self.__print('2a', 0)
        m.load_state_dict(l)
        # set to device
        self.__print('2b', 0)
        time.sleep(2)
        try:
            self.__print('3a', 0)
            m.to(self.device)  # ----> this line pops up the cmd window
            self.__print('3b', 0)
        except RuntimeError as e:
            self.__print(str(e), 0)
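The __print helper only puts (message, verbosity) tuples on a Queue; per its comment, the main process drains them in a thread and prints them. A minimal stdlib-only sketch of that reader side, assuming a None sentinel to stop the thread. start_log_reader, worker and STOP are hypothetical names, not the author's actual code.

```python
import multiprocessing as mp
import threading

STOP = None  # assumed sentinel that tells the reader thread to exit

def start_log_reader(queue, sink=print):
    """Drain (msg, verbosity) tuples from `queue` in a daemon thread of the main process."""
    def _drain():
        while True:
            item = queue.get()
            if item is STOP:
                break
            msg, verbosity = item
            sink(f"[{verbosity}] {msg}")

    reader = threading.Thread(target=_drain, daemon=True)
    reader.start()
    return reader

def worker(queue):
    # Runs in the child: it only enqueues messages and never writes to the console itself.
    queue.put(("Open_Agent_Sim: opening model", 0))
    queue.put(("Open_Agent_Sim: opening model - done", 0))

if __name__ == "__main__":
    q = mp.Queue()
    reader = start_log_reader(q)
    p = mp.Process(target=worker, args=(q,))
    p.start()
    p.join()
    q.put(STOP)  # sent only after the child exits, so its messages are already in the pipe
    reader.join()
```

Keeping all console output in the main process means the child's only visible side effects are the ones it cannot control, such as whatever native code it calls.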
When debugging visually where the cmd window pops up, it is always in step 1 (m.load_state_dict(torch.load(self.path, map_location=self.device))).
I have tried disabling console output, which didn't work:
import contextlib
with contextlib.redirect_stdout(None):
...
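That this did not help is expected: contextlib.redirect_stdout only rebinds the Python-level sys.stdout object. Anything native code or a child process writes to the underlying file descriptor bypasses it, and it has no say over whether the OS opens a console window. A small stdlib-only demonstration (capture_python_stdout and demo are illustrative names):

```python
import contextlib
import io
import os

def capture_python_stdout(fn) -> str:
    # Temporarily rebind sys.stdout to an in-memory buffer while fn runs.
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        fn()
    return buf.getvalue()

def demo() -> None:
    print("via print()")        # goes through sys.stdout -> captured
    os.write(1, b"via fd 1\n")  # writes to OS file descriptor 1 -> not captured

captured = capture_python_stdout(demo)
print(repr(captured))  # 'via print()\n'; "via fd 1" went straight to the real stdout
```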
Adding an if __name__ == '__main__': guard makes no difference. Also, all of this runs in a lower-level subprocess of an application that uses multiprocessing heavily.
Update
I traced the problem to switching devices. If I use torch.load(self.path, map_location=self.device_cpu) and later .to(self.device) (the CUDA device), the cmd window pops up on the line with .to(...). If I instead use torch.load(self.path, map_location=self.device), it pops up on the torch.load line. Another thing to note is that it does not matter on which device the model was saved.
I am open to any workaround.
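The two variants from the update can be written side by side. This is a sketch: load_cpu_then_move and load_directly_to are hypothetical helper names, and any nn.Module works in place of Net. Because load_state_dict copies values into the model's existing parameters, a model created on the CPU still needs .to(device) in both variants; the difference is only where the first CUDA tensors are created (inside torch.load in the second variant, at .to(...) in the first). That the pop-up moves along with it hints, without proving it, that the flash is tied to the first CUDA operation in the subprocess.

```python
import os
import tempfile

import torch
import torch.nn as nn

def load_cpu_then_move(path: str, model: nn.Module, device: torch.device) -> nn.Module:
    # Variant 1: deserialize every tensor onto the CPU ...
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)
    # ... then move the module; with a CUDA device, CUDA tensors are first created here.
    return model.to(device)

def load_directly_to(path: str, model: nn.Module, device: torch.device) -> nn.Module:
    # Variant 2: torch.load creates the tensors on `device` while deserializing.
    state = torch.load(path, map_location=device)
    # load_state_dict copies into the model's existing (here CPU) parameters,
    # so the module itself still has to be moved.
    model.load_state_dict(state)
    return model.to(device)

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model_test.pth")
        torch.save(nn.Linear(4, 2).state_dict(), path)
        m1 = load_cpu_then_move(path, nn.Linear(4, 2), device)
        m2 = load_directly_to(path, nn.Linear(4, 2), device)
        print(m1.weight.device, torch.equal(m1.weight, m2.weight))
```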
Solution
Updating PyTorch using the install command from the official website solved the issue.
Answered By - Tomas Trdla
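The thread does not say which PyTorch versions were involved, and the install command depends on the OS, package manager and CUDA version selected on pytorch.org, so it is not reproduced here. A small sketch for recording which build is installed before and after upgrading (torch_build_info is a hypothetical helper; the attributes it reads are standard PyTorch ones):

```python
import torch

def torch_build_info() -> dict:
    return {
        "torch": torch.__version__,                   # installed PyTorch version string
        "cuda_build": torch.version.cuda,             # CUDA version of the build, None if CPU-only
        "cuda_available": torch.cuda.is_available(),  # whether a usable GPU/driver was found
    }

if __name__ == "__main__":
    for key, value in torch_build_info().items():
        print(f"{key}: {value}")
```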