Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ws rebase #2

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add mp distributed executor
  • Loading branch information
MengqingCao committed Oct 10, 2024
commit 13c2158b580401a3133de2e264aecfac4a1a70fc
16 changes: 14 additions & 2 deletions examples/offline_inference_npu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
import gc
import torch
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import destroy_model_parallel, destroy_distributed_environment

def clean_up():
destroy_model_parallel()
destroy_distributed_environment()
gc.collect()
torch.npu.empty_cache()

# Sample prompts.
prompts = [
Expand All @@ -15,8 +24,8 @@
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)

# Create an LLM.
# llm = LLM(model="facebook/opt-125m")
llm = LLM(model="Qwen/Qwen2-7B-Instruct")
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2, distributed_executor_backend="mp")
# llm = LLM(model="Qwen/Qwen2-7B-Instruct")
# llm = LLM(model="/workspace/cmq/models/LLM-Research/Meta-Llama-3-8B-Instruct")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
Expand All @@ -26,3 +35,6 @@
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

del llm
clean_up()
8 changes: 5 additions & 3 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,9 +701,11 @@ def _get_executor_cls(
raise RuntimeError(
"Not supported distributed execution model on XPU device.")
elif engine_config.device_config.device_type == "npu":
if distributed_executor_backend == "ray":
# TODO
pass
if distributed_executor_backend == "mp":
from vllm.executor.multiproc_npu_executor import MultiprocessingNPUExecutorAsync
executor_class = MultiprocessingNPUExecutorAsync
elif distributed_executor_backend == "ray":
raise NotImplementedError("ray is not implemented in Ascend NPU currently")
else:
from vllm.executor.npu_executor import NPUExecutorAsync
executor_class = NPUExecutorAsync
Expand Down
17 changes: 10 additions & 7 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,15 @@ def _get_executor_cls(cls,
else:
from vllm.executor.xpu_executor import XPUExecutor
executor_class = XPUExecutor
elif engine_config.device_config.device_type == "npu":
if distributed_executor_backend == "mp":
from vllm.executor.multiproc_npu_executor import MultiprocessingNPUExecutorAsync
executor_class = MultiprocessingNPUExecutorAsync
elif distributed_executor_backend == "ray":
raise NotImplementedError("ray is not implemented in Ascend NPU currently")
else:
from vllm.executor.npu_executor import NPUExecutorAsync
executor_class = NPUExecutorAsync
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
Expand All @@ -537,13 +546,7 @@ def _get_executor_cls(cls,
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
executor_class = MultiprocessingGPUExecutor
elif engine_config.device_config.device_type == "npu":
if distributed_executor_backend == "ray":
# TODO
pass
else:
from vllm.executor.npu_executor import NPUExecutor
executor_class = NPUExecutor

else:
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor
Expand Down
1 change: 1 addition & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
VLLM_USE_TRITON_FLASH_ATTN: bool = False
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
ASCEND_RT_VISIBLE_DEVICES: Optional[str] = None
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
VLLM_API_KEY: Optional[str] = None
S3_ACCESS_KEY_ID: Optional[str] = None
Expand Down
264 changes: 264 additions & 0 deletions vllm/executor/multiproc_npu_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
import os, asyncio
import torch, torch_npu # noqa
import signal
import threading
import weakref
from functools import partial
from typing import Any, List, Optional

from vllm.executor.npu_executor import NPUExecutor
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.gpu_executor import create_worker
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock,
get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async,
update_environment_variables)

logger = init_logger(__name__)


class MultiprocessingNPUExecutor(DistributedGPUExecutor, NPUExecutor):
"""Python multiprocessing-based multi-NPU executor

Do not directly inherit the class MultiprocessingGPUExecutor
because triton cannot be installed on NPU machine, rasing import error.
"""

uses_ray: bool = False

def _init_executor(self) -> None:
self._check_executor_parameters()

# Create the parallel NPU workers.
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size

# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()

# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
# Helps to avoid CPU contention. The default of spawning a thread per
# core combined with multiprocessing for each npu can have a negative
# impact on performance. The contention is amplified when running in a
# container where CPU limits can cause throttling.
default_omp_num_threads = 1
if "OMP_NUM_THREADS" not in os.environ and (
current_parallelism :=
torch.get_num_threads()) > default_omp_num_threads:
logger.warning(
"Reducing Torch parallelism from %d threads to %d to avoid "
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
"external environment to tune this value as needed.",
current_parallelism, default_omp_num_threads)
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
torch.set_num_threads(default_omp_num_threads)


# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# 127.0.0.1 for communication.
distributed_init_method = get_distributed_init_method(
"127.0.0.1", get_open_port())

self.workers: List[ProcessWorkerWrapper] = []
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: List[ProcessWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers: List[ProcessWorkerWrapper] = []

if world_size == 1:
self.worker_monitor = None
else:
result_handler = ResultHandler()
for rank in range(1, world_size):
worker = ProcessWorkerWrapper(
result_handler,
partial(
create_worker,
**self._get_create_worker_kwargs(
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
)))
self.workers.append(worker)
if rank % tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
else:
self.non_driver_workers.append(worker)

self.worker_monitor = WorkerMonitor(self.workers, result_handler)
result_handler.start()
self.worker_monitor.start()

# Set up signal handlers to shutdown the executor cleanly
# sometimes gc does not work well

# Use weakref to avoid holding a reference to self
ref = weakref.ref(self)

def shutdown(signum, frame):
if executor := ref():
executor.shutdown()

if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGINT, shutdown)
signal.signal(signal.SIGTERM, shutdown)

self.driver_worker = self._create_worker(
distributed_init_method=distributed_init_method)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)

def _check_executor_parameters(self):
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size

# Set ASCEND_RT_VISIBLE_DEVICES for the driver, inherited by workers
if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ:
update_environment_variables({
"ASCEND_RT_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})

npu_device_count = torch.npu.device_count()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= npu_device_count, (
f"please set tensor_parallel_size ({tensor_parallel_size}) "
f"to less than max local Ascend npu count ({npu_device_count})")

assert world_size <= npu_device_count, (
f"please ensure that world_size ({world_size}) "
f"is less than than max local Ascend npu count ({npu_device_count})")

def shutdown(self):
if (worker_monitor := getattr(self, "worker_monitor",
None)) is not None:
worker_monitor.close()

def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
"""Run execute_model in the driver worker.

Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return self.driver_worker.execute_model(execute_model_req)

def _run_workers(
self,
method: str,
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers.

Args:
async_run_tensor_parallel_workers_only: If True the method will be
run only in the remote TP workers, not the driver worker.
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
"""

if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")

if async_run_tensor_parallel_workers_only:
# Run only non-driver workers and just return futures.
return [
worker.execute_method(method, *args, **kwargs)
for worker in self.non_driver_workers
]

# Start all remote workers first.
worker_outputs = [
worker.execute_method(method, *args, **kwargs)
for worker in self.workers
]

driver_worker_method = getattr(self.driver_worker, method)
driver_worker_output = driver_worker_method(*args, **kwargs)

# Get the results of the workers.
return [driver_worker_output
] + [output.get() for output in worker_outputs]

def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""
if self.worker_monitor is not None and not self.worker_monitor.is_alive(
):
raise RuntimeError("Worker processes are not running")

def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
for result in parallel_worker_tasks:
result.get()


class MultiprocessingNPUExecutorAsync(MultiprocessingNPUExecutor,
DistributedGPUExecutorAsync):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_model = make_async(self.driver_worker.execute_model)
self.pp_locks: Optional[List[asyncio.Lock]] = None

async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
if not self.tp_driver_workers:
return await self.driver_exec_model(execute_model_req)

if self.pp_locks is None:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
# We create the locks here to avoid creating them in the constructor
# which uses a different asyncio loop.
self.pp_locks = [
asyncio.Lock()
for _ in range(self.parallel_config.pipeline_parallel_size)
]

tasks = [
asyncio.create_task(
_run_task_with_lock(self.driver_exec_model, self.pp_locks[0],
execute_model_req))
]
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
start=1):
tasks.append(
asyncio.create_task(
_run_task_with_lock(driver_worker.execute_method_async,
self.pp_locks[pp_rank],
"execute_model", execute_model_req)))
results = await asyncio.gather(*tasks)

# Only the last PP stage has the final results.
return results[-1]

async def _start_worker_execution_loop(self):
coros = [
worker.execute_method_async("start_worker_execution_loop")
for worker in self.non_driver_workers
]
return await asyncio.gather(*coros)
Loading