Issue
I cannot get torch.jit.trace to run despite my best efforts; it fails with RuntimeError: Input, output and indices must be on the current device
I have a (fairly complex) model which I have already put on GPU, along with a set of inputs, also on GPU. I can verify that all input tensors and model parameters & buffers are on the same device:
(Pdb) {p.device for p in self.parameters()}
{device(type='cuda', index=0)}
(Pdb) {p.device for p in self.buffers()}
{device(type='cuda', index=0)}
(Pdb) in_ = (<several tensors here>)
(Pdb) {p.device for p in in_}
{device(type='cuda', index=0)}
(Pdb) torch.cuda.current_device()
0
I can confirm that the model runs and that the output is on the correct device:
(Pdb) self(*in_).device
device(type='cuda', index=0)
Despite all this, tracing fails:
(Pdb) generator_script = torch.jit.trace(self, example_inputs=in_)
*** RuntimeError: Input, output and indices must be on the current device
- I understand about inputs and outputs, but what are these "indices" that must also be on the same device?
- What other elements that I am not accounting for could be causing trace to fail?
Solution
If you're not yet mapping the device during the loading process, doing so could be the solution.[1] That is, the device mapping should happen during jit.load, not as a simple call to .to(device) after jit.load has already finished. See this page for more info.
As an example of what to do:
import torch
model = torch.jit.load("your_traced_model.pt", map_location=torch.device("cuda"))
This is different from how it works for typical/non-JIT models, where you can simply do:
model = some_model_creation_function()
_ = model.to(torch.device("cuda"))
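To make the load-time mapping concrete, here is a minimal sketch of a full save-and-reload round trip. It rests on several assumptions that are not part of the original answer: TinyModel is a hypothetical stand-in for a real model, an in-memory buffer replaces "your_traced_model.pt", and the device falls back to CPU when CUDA is unavailable so that the snippet runs anywhere.

```python
import io

import torch


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for a real model; it performs an embedding lookup."""

    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(10, 4)

    def forward(self, idx):
        # (batch, seq) -> (batch, seq, 4) -> (batch, 4)
        return self.embed(idx).sum(dim=1)


def round_trip(device):
    """Trace on CPU, save, then reload directly onto `device` via map_location."""
    model = TinyModel().eval()
    example = torch.tensor([[1, 2, 3]])
    traced = torch.jit.trace(model, example)

    buffer = io.BytesIO()  # in-memory stand-in for "your_traced_model.pt"
    torch.jit.save(traced, buffer)
    buffer.seek(0)

    # The device mapping happens here, during loading, not via .to() afterwards.
    loaded = torch.jit.load(buffer, map_location=device)
    return loaded, example.to(device)


# Assumption: fall back to CPU so the sketch also runs without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded, example = round_trip(device)
output = loaded(example)
print({p.device.type for p in loaded.parameters()}, output.device.type)
```

On a machine with a GPU, both the parameters and the output should report cuda; on a CPU-only machine they report cpu.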
[1] This does not currently work for the MPS device.
Answered By - carbocation