Issue
I am using torch.multiprocessing to parallelize my for-loop as follows:
import torch
import torch.multiprocessing as mp

torch.set_num_threads(1)

seqs = [...]  # List of strings

def func1(x):
    # Some numpy calculation
    return a, b

def func2(a, b):
    # Run initial computations with torch
    # Run forward inference of nn.Module (wrapper of transformers model)
    # Perform additional computations with torch
    return c

def main_func(x):
    a, b = func1(x)
    c = func2(a, b)
    return c

if __name__ == "__main__":
    # The "spawn" start method re-imports this module in every worker,
    # so the pool must be created under the main guard.
    mp.set_start_method("spawn", force=True)
    pool = mp.Pool(processes=10)
    results = pool.map(main_func, seqs)
In this code, func1 performs some numpy calculations, while func2 runs the forward call of an nn.Module, which wraps the EsmForMaskedLM model.
Specifically, the mentioned nn.Module can be simplified as follows:
import torch
from torch import nn
from transformers import AutoTokenizer, EsmForMaskedLM, BatchEncoding

class ESM2(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = EsmForMaskedLM.from_pretrained("facebook/esm2_t12_35M_UR50D")
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
        self.model.to("cpu")

    def forward(self, inputs: BatchEncoding) -> torch.Tensor:
        # inputs has already been tokenized and moved to the CPU
        results = self.model(**inputs)
        return results
When I run the code sequentially in a for-loop, everything works as expected. However, when I parallelize it using torch.multiprocessing, it seems to occupy all CPU cores on my machine, even though I have limited the number of processes. Upon debugging, I found that this happens only when calling func2 (i.e., the forward function). I suspect there might be a problem with the forward function, but I'm not certain. Can someone please help me with this issue? Thank you very much!
P.S. If you need any additional information, please let me know; I am happy to provide more.
Solution
Apparently, PyTorch uses a parallelization library called OpenMP (link), and according to this answer:
OpenMP does multi-threading within a process, and the default number of threads is typically the number that the CPU can actually run simultaneously
...
So, what happens on that quad-core CPU if you run a multiprocessing program that runs 4 Python processes, and each calls an OpenMP function runs 4 threads? You end up running 16 threads on 4 cores
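The arithmetic in the quoted answer can be sketched in a few lines. This is only an illustration: the helper name oversubscription is hypothetical, and the second calculation assumes that each worker process defaults to one thread per core, as the quote describes.

```python
import os


def oversubscription(num_processes: int, threads_per_process: int, num_cores: int) -> float:
    """Return how many runnable threads compete for each core."""
    total_threads = num_processes * threads_per_process
    return total_threads / num_cores


# The quoted example: 4 processes x 4 OpenMP threads on a quad-core CPU,
# i.e. 16 threads on 4 cores.
quad_core_ratio = oversubscription(4, 4, 4)

# The setup from the question: a pool of 10 workers, assuming each worker
# starts one thread per core (an assumption, not a measured value).
cores = os.cpu_count() or 1
pool_ratio = oversubscription(10, cores, cores)

print(quad_core_ratio, pool_ratio)
```

Under that assumption, a pool of 10 processes asks for ten times as many threads as there are cores, which matches the symptom of every core being busy.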
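So when I set os.environ["OMP_NUM_THREADS"] = "1" (OMP_NUM_THREADS is the variable name that OpenMP reads, and it should be set before torch is imported), it performs as expected.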
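A minimal sketch of that fix follows. OMP_NUM_THREADS is the variable name defined by the OpenMP specification. The helper limit_native_threads is hypothetical and not part of PyTorch, and also setting MKL_NUM_THREADS is an extra assumption that only matters if the PyTorch build uses Intel MKL.

```python
import os


def limit_native_threads(n: int = 1) -> dict:
    """Set thread-count variables that native math libraries read at start-up.

    Hypothetical helper for illustration only.
    """
    names = ("OMP_NUM_THREADS", "MKL_NUM_THREADS")
    for name in names:
        os.environ[name] = str(n)
    return {name: os.environ[name] for name in names}


# Call this at the very top of the script, before torch is imported, so the
# OpenMP runtime sees the value when it initializes.
applied = limit_native_threads(1)
print(applied)

# import torch                          # import only after the variables are set
# import torch.multiprocessing as mp
```

Worker processes started by the pool inherit the parent's environment, so setting the variables once in the parent before creating the pool should be enough for the workers as well.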
Other resources: HF discussion
Answered By - JonnyJack