Fix passing vLLM server URL (huggingface#21)
* Use head node IP as vLLM server URL

* Pass correct server URL

* Add num_generations argument

* Fix style

* Remove `select`

---------

Co-authored-by: plaguss <[email protected]>
gabrielmbmb and plaguss authored Jan 25, 2025
1 parent f844eac commit a90b996
Showing 2 changed files with 69 additions and 23 deletions.
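
For context, a minimal launch sketch for the updated script (not part of this commit). Only `--num-generations` and `--hf-output-dataset` appear verbatim in this diff; the `--model` and `--hf-dataset` flag names are inferred from the required MODEL/HF_DATASET variables below, and all ids are placeholders.

# Hypothetical invocation of slurm/generate.slurm with the new --num-generations flag.
# Model, dataset, and output-repo ids are placeholders; --model/--hf-dataset names are assumed.
sbatch slurm/generate.slurm \
    --model deepseek-ai/DeepSeek-R1 \
    --hf-dataset open-r1/example-prompts \
    --num-generations 4 \
    --hf-output-dataset your-org/deepseek-r1-generations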
56 changes: 44 additions & 12 deletions slurm/generate.slurm
@@ -2,7 +2,7 @@
#SBATCH --job-name=deepseek-r1-generation
#SBATCH --partition=hopper-prod
#SBATCH --qos=normal
#SBATCH --nodes=2
#SBATCH --nodes=4
#SBATCH --exclusive
#SBATCH --gpus-per-node=8
#SBATCH --output=./logs/%x-%j.out
@@ -44,6 +44,10 @@ while [[ $# -gt 0 ]]; do
MAX_NEW_TOKENS="$2"
shift 2
;;
--num-generations)
NUM_GENERATIONS="$2"
shift 2
;;
--hf-output-dataset)
HF_OUTPUT_DATASET="$2"
shift 2
@@ -64,15 +68,32 @@ if [ -z "$MODEL" ] || [ -z "$HF_DATASET" ]; then
exit 1
fi

# Set default values for optional parameters
HF_DATASET_SPLIT=${HF_DATASET_SPLIT:-"train"}
PROMPT_COLUMN=${PROMPT_COLUMN:-"prompt"}
TEMPERATURE=${TEMPERATURE:-0.7}
TOP_P=${TOP_P:-0.9}
MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-8192}
NUM_GENERATIONS=${NUM_GENERATIONS:-1}
PRIVATE=${PRIVATE:-"false"}

# Print all input arguments
echo "Input arguments:"
echo "MODEL: $MODEL"
echo "HF_DATASET: $HF_DATASET"
echo "HF_DATASET_CONFIG: $HF_DATASET_CONFIG"
echo "HF_DATASET_SPLIT: $HF_DATASET_SPLIT"
echo "PROMPT_COLUMN: $PROMPT_COLUMN"
echo "TEMPERATURE: $TEMPERATURE"
echo "TOP_P: $TOP_P"
echo "MAX_NEW_TOKENS: $MAX_NEW_TOKENS"
echo "NUM_GENERATIONS: $NUM_GENERATIONS"
echo "HF_OUTPUT_DATASET: $HF_OUTPUT_DATASET"
echo "PRIVATE: $PRIVATE"
echo "-------------------"

set -ex

module load cuda/12.1

export LD_LIBRARY_PATH=.venv/lib/python3.11/site-packages/nvidia/nvjitlink/lib

echo "SLURM_JOB_ID: $SLURM_JOB_ID"
@@ -127,19 +148,19 @@ RAY_ADDRESS="http://$head_node_ip:8265" ray job submit \
--no-wait \
-- vllm serve $MODEL \
--tensor-parallel-size 8 \
--pipeline-parallel-size 2 \
--max-model-len 32768 \
--pipeline-parallel-size 4 \
--max-model-len 16384 \
--enable-chunked-prefill \
--trust-remote-code \
--distributed-executor-backend ray

# wait for vllm to load the model
echo "Waiting for vLLM (http://localhost:8000) server to be up..."
echo "Waiting for vLLM (http://$head_node_ip:8000) server to be up..."

# wait for vllm to load and serve the model
while true; do
if curl -s -o /dev/null -w "%{http_code}" http://localhost:8000 >/dev/null 2>&1; then
echo "Received response from http://localhost:8000"
if curl -s -o /dev/null -w "%{http_code}" http://$head_node_ip:8000 >/dev/null 2>&1; then
echo "Received response from http://$head_node_ip:8000"
break
else
echo "Still waiting... (Press Ctrl+C to cancel)"
@@ -148,21 +169,32 @@ while true; do
done

echo "Checking available models..."
curl http://localhost:8000/v1/models
curl http://$head_node_ip:8000/v1/models

echo "Executing sanity check..."
curl http://$head_node_ip:8000/v1/completions \
-H "Content-Type: application/json" \
-d "{
\"model\": \"$MODEL\",
\"prompt\": \"<|begin▁of▁sentence|><|User|>hi, how are you?<|Assistant|>\",
\"max_tokens\": 2048,
\"temperature\": 0.6
}"

# Finally submit the job to the cluster
echo "Submitting job to ray cluster..."
RAY_ADDRESS="http://$head_node_ip:8265" ray job submit \
--working-dir pipeline \
--working-dir src/open_r1 \
-- python -u generate.py \
--model "$MODEL" \
--hf-dataset "$HF_DATASET" \
${HF_DATASET_CONFIG:+--hf-dataset-config "$HF_DATASET_CONFIG"} \
--hf-dataset-split "$HF_DATASET_SPLIT" \
--prompt-column "$PROMPT_COLUMN" \
--temperature "$TEMPERATURE" \
--top-p "$TOP_P" \
${TEMPERATURE:+--temperature "$TEMPERATURE"} \
${TOP_P:+--top-p "$TOP_P"} \
--max-new-tokens "$MAX_NEW_TOKENS" \
--num-generations "$NUM_GENERATIONS" \
${HF_OUTPUT_DATASET:+--hf-output-dataset "$HF_OUTPUT_DATASET"} \
${PRIVATE:+--private} \
--vllm-server-url "http://$head_node_ip:8000/v1"
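
Note on the optional flags above: the ${VAR:+...} expansion emits its flag only when the variable is set and non-empty, which is what lets --temperature and --top-p be omitted entirely. A tiny standalone sketch of the pattern (not part of the commit):

# ${VAR:+word} expands to word only when VAR is set and non-empty; otherwise it vanishes.
TEMPERATURE=0.6
unset TOP_P
echo python generate.py ${TEMPERATURE:+--temperature "$TEMPERATURE"} ${TOP_P:+--top-p "$TOP_P"}
# prints: python generate.py --temperature 0.6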
36 changes: 25 additions & 11 deletions src/open_r1/generate.py
@@ -23,23 +23,32 @@ def build_distilabel_pipeline(
model: str,
base_url: str = "http://localhost:8000/v1",
prompt_column: Optional[str] = None,
temperature: float = 0.7,
top_p: float = 0.9,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
max_new_tokens: int = 8192,
num_generations: int = 1,
) -> Pipeline:
generation_kwargs = {"max_new_tokens": max_new_tokens}

if temperature is not None:
generation_kwargs["temperature"] = temperature

if top_p is not None:
generation_kwargs["top_p"] = top_p

with Pipeline().ray() as pipeline:
TextGeneration(
llm=OpenAILLM(
base_url=base_url,
api_key="something",
model=model,
generation_kwargs={
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
},
# thinking can take some time...
timeout=10 * 60,
generation_kwargs=generation_kwargs,
),
input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
input_batch_size=10,
num_generations=num_generations,
)

return pipeline
@@ -85,13 +94,11 @@ def build_distilabel_pipeline(
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Temperature for generation",
)
parser.add_argument(
"--top-p",
type=float,
default=0.9,
help="Top-p value for generation",
)
parser.add_argument(
@@ -100,6 +107,12 @@
default=8192,
help="Maximum number of new tokens to generate",
)
parser.add_argument(
"--num-generations",
type=int,
default=1,
help="Number of generations per problem",
)
parser.add_argument(
"--hf-output-dataset",
type=str,
@@ -120,7 +133,7 @@ def build_distilabel_pipeline(
print()

print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split).select(range(50))
dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split)
print("Dataset loaded!")

pipeline = build_distilabel_pipeline(
Expand All @@ -130,10 +143,11 @@ def build_distilabel_pipeline(
temperature=args.temperature,
top_p=args.top_p,
max_new_tokens=args.max_new_tokens,
num_generations=args.num_generations,
)

print("Running generation pipeline...")
distiset = pipeline.run(dataset=dataset, dataset_batch_size=5000)
distiset = pipeline.run(dataset=dataset, use_cache=False)
print("Generation pipeline finished!")

if args.hf_output_dataset:
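
For quick iteration outside SLURM, the generator can presumably also be run directly against an already-running vLLM server; since the pipeline is built with Pipeline().ray(), a reachable Ray runtime is still assumed. A sketch mirroring the ray job submit call above (ids are placeholders; the flags are those defined by this file's argument parser):

# Sketch only: run generate.py against an existing OpenAI-compatible vLLM endpoint.
# Model and dataset ids are placeholders; a reachable Ray runtime is assumed.
python src/open_r1/generate.py \
    --model deepseek-ai/DeepSeek-R1 \
    --hf-dataset open-r1/example-prompts \
    --hf-dataset-split train \
    --prompt-column prompt \
    --max-new-tokens 8192 \
    --num-generations 4 \
    --vllm-server-url http://localhost:8000/v1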
