
Commit

update multi gpu infer.
shibing624 committed Sep 19, 2023
1 parent 9d82044 commit 21ac3ee
Showing 4 changed files with 195 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -228,6 +228,7 @@ Embedding shape: (768,)
The model is automatically downloaded to the local path: `~/.cache/huggingface/transformers`
- `w2v-light-tencent-chinese` is a Word2Vec model loaded via gensim. It computes word vectors from the Tencent word-embedding file `Tencent_AILab_ChineseEmbedding.tar.gz`, and sentence vectors are obtained by averaging the word vectors. The model is automatically downloaded to the local path: `~/.text2vec/datasets/light_Tencent_AILab_ChineseEmbedding.bin`
- `text2vec` supports multi-GPU inference (computing text embeddings): [examples/computing_embeddings_multi_gpu_demo.py](https://github.com/shibing624/text2vec/blob/master/examples/computing_embeddings_multi_gpu_demo.py)

#### Usage (HuggingFace Transformers)
Without [text2vec](https://github.com/shibing624/text2vec), you can use the model like this:
37 changes: 37 additions & 0 deletions examples/computing_embeddings_multi_gpu_demo.py
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description:
This example starts multiple processes (1 per GPU), which encode
sentences in parallel. This gives a near linear speed-up
when encoding large text collections.
This basic example loads a pre-trained model from the web and uses it to
generate sentence embeddings for a given list of sentences.
"""

import sys

sys.path.append('..')
from text2vec import SentenceModel


def main():
    # Create a large list of sentences
    sentences = ["This is sentence {}".format(i) for i in range(10000)]
    model = SentenceModel("shibing624/text2vec-base-chinese")
    print(f"Sentences size: {len(sentences)}, model: {model}")

    # Start the multi-process pool on all available CUDA devices
    pool = model.start_multi_process_pool()

    # Compute the embeddings using the multi-process pool
    emb = model.encode_multi_process(sentences, pool)
    print(f"Embeddings computed. Shape: {emb.shape}")

    # Optional: stop the processes in the pool
    model.stop_multi_process_pool(pool)


if __name__ == "__main__":
    main()
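
The demo uses all visible CUDA devices by default. The pool can also be pinned to specific devices via the `target_devices` argument of `start_multi_process_pool` added in this commit; below is a minimal sketch of that variant, assuming a machine with two GPUs (the device ids and chunk size are illustrative):

# Minimal sketch, assuming two visible GPUs; device ids and chunk_size are illustrative.
from text2vec import SentenceModel

sentences = ["This is sentence {}".format(i) for i in range(10000)]
model = SentenceModel("shibing624/text2vec-base-chinese")

pool = model.start_multi_process_pool(['cuda:0', 'cuda:1'])
emb = model.encode_multi_process(sentences, pool, batch_size=32, chunk_size=500)
print(emb.shape)  # (10000, 768)
model.stop_multi_process_pool(pool)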
35 changes: 35 additions & 0 deletions tests/test_multi_process.py
@@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description:
code copied from sentence-transformers: tests/test_multi_process.py
"""

import sys
import unittest

sys.path.append('..')
from text2vec import SentenceModel
import numpy as np


class ComputeMultiProcessTest(unittest.TestCase):
    def setUp(self):
        self.model = SentenceModel()

    def test_multi_gpu_encode(self):
        # Start the multi-process pool with two CPU workers, so the test
        # also runs on machines without CUDA devices
        pool = self.model.start_multi_process_pool(['cpu', 'cpu'])

        sentences = ["This is sentence {}".format(i) for i in range(1000)]

        # Compute the embeddings using the multi-process pool
        emb = self.model.encode_multi_process(sentences, pool, chunk_size=50)
        assert emb.shape == (len(sentences), 768)

        emb_normal = self.model.encode(sentences)

        diff = np.max(np.abs(emb - emb_normal))
        print("Max multi proc diff", diff)
        assert diff < 0.001
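
Passing `['cpu', 'cpu']` lets the test exercise the multi-process path on machines without GPUs. As a sketch, assuming the repository root is the working directory, the module can be run on its own with the standard unittest loader:

# Sketch: run only this test module; assumes the repo root as working directory.
import unittest

suite = unittest.defaultTestLoader.discover('tests', pattern='test_multi_process.py')
unittest.TextTestRunner(verbosity=2).run(suite)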
125 changes: 122 additions & 3 deletions text2vec/sentence_model.py
@@ -4,14 +4,18 @@
@description: Base sentence model function, add encode function.
Parts of this file is adapted from the sentence-transformers: https://github.com/UKPLab/sentence-transformers
"""
+import math
import os
+import queue
from enum import Enum
-from typing import List, Union, Optional
+from typing import List, Union, Optional, Dict

import numpy as np
import torch
+import torch.multiprocessing as mp
from loguru import logger
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
from tqdm import tqdm
from tqdm.autonotebook import trange
from transformers import AutoTokenizer, AutoModel
@@ -141,7 +145,7 @@ def get_sentence_embeddings(self, input_ids, attention_mask, token_type_ids=None
    def encode(
            self,
            sentences: Union[str, List[str]],
-           batch_size: int = 64,
+           batch_size: int = 32,
            show_progress_bar: bool = False,
            convert_to_numpy: bool = True,
            convert_to_tensor: bool = False,
@@ -288,3 +292,118 @@ def save_model(self, output_dir, model, results=None):
        with open(output_eval_file, "w") as writer:
            for key in sorted(results.keys()):
                writer.write("{} = {}\n".format(key, str(results[key])))

    def start_multi_process_pool(self, target_devices: List[str] = None):
        """
        Starts a pool of several independent processes that encode sentences in parallel.
        This method is recommended if you want to encode on multiple GPUs; it is advised
        to start only one process per GPU. It works together with encode_multi_process.
        :param target_devices: PyTorch target devices, e.g. cuda:0, cuda:1... If None, all available CUDA devices will be used
        :return: Returns a dict with the target processes, an input queue and an output queue.
        """
        if target_devices is None:
            if torch.cuda.is_available():
                target_devices = ['cuda:{}'.format(i) for i in range(torch.cuda.device_count())]
            else:
                logger.info("CUDA is not available. Starting 4 CPU workers")
                target_devices = ['cpu'] * 4

logger.info("Start multi-process pool on devices: {}".format(', '.join(map(str, target_devices))))

ctx = mp.get_context('spawn')
input_queue = ctx.Queue()
output_queue = ctx.Queue()
processes = []

for cuda_id in target_devices:
p = ctx.Process(
target=SentenceModel._encode_multi_process_worker,
args=(cuda_id, self, input_queue, output_queue),
daemon=True
)
p.start()
processes.append(p)

return {'input': input_queue, 'output': output_queue, 'processes': processes}

    @staticmethod
    def stop_multi_process_pool(pool):
        """
        Stops all processes started with start_multi_process_pool.
        """
        for p in pool['processes']:
            p.terminate()

        for p in pool['processes']:
            p.join()
            p.close()

        pool['input'].close()
        pool['output'].close()

    def encode_multi_process(
            self,
            sentences: List[str],
            pool: Dict[str, object],
            batch_size: int = 32,
            normalize_embeddings: bool = False,
            chunk_size: int = None
    ):
        """
        Runs encode() on multiple GPUs. The sentences are chunked into smaller packages
        and sent to individual processes, which encode them on the different GPUs.
        This method is only suitable for encoding large sets of sentences.
        :param sentences: List of sentences
        :param pool: A pool of workers started with start_multi_process_pool
        :param batch_size: Batch size used for encoding the sentences
        :param normalize_embeddings: bool, whether to normalize embeddings before returning them
        :param chunk_size: Sentences are chunked and sent to the individual processes. If None, a sensible size is computed automatically.
        :return: Numpy matrix with all embeddings
        """
        if chunk_size is None:
            chunk_size = min(math.ceil(len(sentences) / len(pool["processes"]) / 10), 5000)

        logger.debug(f"Chunk data into {math.ceil(len(sentences) / chunk_size)} packages of size {chunk_size}")

        input_queue = pool['input']
        last_chunk_id = 0
        chunk = []

        for sentence in sentences:
            chunk.append(sentence)
            if len(chunk) >= chunk_size:
                input_queue.put([last_chunk_id, batch_size, chunk, normalize_embeddings])
                last_chunk_id += 1
                chunk = []

        if len(chunk) > 0:
            input_queue.put([last_chunk_id, batch_size, chunk, normalize_embeddings])
            last_chunk_id += 1

        output_queue = pool['output']
        results_list = sorted([output_queue.get() for _ in range(last_chunk_id)], key=lambda x: x[0])
        embeddings = np.concatenate([result[1] for result in results_list])
        return embeddings

    @staticmethod
    def _encode_multi_process_worker(target_device: str, model, input_queue, results_queue):
        """
        Internal worker process that encodes sentences in the multi-process setup.
        """
        while True:
            try:
                # the chunk id keeps results ordered when they are collected from the output queue
                chunk_id, batch_size, sentences, normalize_embeddings = input_queue.get()
                embeddings = model.encode(
                    sentences,
                    device=target_device,
                    show_progress_bar=False,
                    convert_to_numpy=True,
                    batch_size=batch_size,
                    normalize_embeddings=normalize_embeddings,
                )
                results_queue.put([chunk_id, embeddings])
            except queue.Empty:
                break
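
For reference, the default `chunk_size` heuristic in encode_multi_process targets roughly ten chunks per worker, capped at 5000. A worked example, assuming the demo's 10000 sentences on a two-process pool:

import math

num_sentences = 10000  # as in the demo above
num_processes = 2      # e.g. a pool on ['cuda:0', 'cuda:1']

# min(ceil(10000 / 2 / 10), 5000) = min(500, 5000) = 500
chunk_size = min(math.ceil(num_sentences / num_processes / 10), 5000)
num_chunks = math.ceil(num_sentences / chunk_size)
print(chunk_size, num_chunks)  # 500 20 -> twenty packages of 500 sentences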
