mirror of
https://github.com/huggingface/open-r1.git
synced 2026-06-24 01:54:06 +00:00
Compare commits
2 commits
main
...
vllm-serve
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e128cd5edc | ||
|
|
4c77b66ba4 |
6 changed files with 18 additions and 18 deletions
12
README.md
12
README.md
|
|
@ -169,15 +169,21 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
|
|||
--config recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml
|
||||
```
|
||||
|
||||
and in another shell session:
|
||||
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=7 trl vllm-serve --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> The chat template used in the distilled DeepSeek models omits the contents of the reasoning block within the `<think>` and `</think>` tags. It also prefills the assistant response with `<think>` which interferes with the format reward function. To handle that, it is important to override the chat template as done in e.g. [recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml](./recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml).
|
||||
|
||||
|
||||
We provide a minimal reproducible experiment using GRPO for mathematical reasoning, referencing the approach from [SimpleRL-Reason](https://hkust-nlp.notion.site/simplerl-reason) which uses a 7B model trained on 8K examples. Running this on 8 H100 80G GPU takes about 3 hours:
|
||||
We provide a minimal reproducible experiment using GRPO for mathematical reasoning, referencing the approach from [SimpleRL-Reason](https://hkust-nlp.notion.site/simplerl-reason) which uses a 7B model trained on 8K examples. Running this on 8 H100 80G GPU takes about 2 hours:
|
||||
|
||||
```shell
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
|
||||
--num_processes=7 src/open_r1/grpo.py \
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
|
||||
--num_processes=4 src/open_r1/grpo.py \
|
||||
--config recipes/Qwen2.5-Math-7B/grpo/config_simple_rl.yaml
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -13,8 +13,6 @@ system_prompt: "You are a helpful AI Assistant that provides well-reasoned and d
|
|||
# GRPO trainer config
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.7
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
|
|
|
|||
|
|
@ -11,8 +11,6 @@ system_prompt: "You are a helpful AI Assistant that provides well-reasoned and d
|
|||
# GRPO trainer config
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.7
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
|
|
|
|||
|
|
@ -12,8 +12,6 @@ system_prompt: "You are a helpful AI Assistant that provides well-reasoned and d
|
|||
beta: 0.01
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.9
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
|
|
|
|||
|
|
@ -12,8 +12,6 @@ system_prompt: "You are a helpful AI Assistant, designed to provided well-reason
|
|||
# GRPO trainer config
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.7
|
||||
do_eval: true
|
||||
eval_strategy: steps
|
||||
eval_steps: 100
|
||||
|
|
@ -24,7 +22,7 @@ gradient_checkpointing_kwargs:
|
|||
hub_model_id: Qwen-2.5-7B-Simple-RL
|
||||
hub_strategy: every_save
|
||||
learning_rate: 3.0e-06
|
||||
log_completions: true
|
||||
log_completions: false
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 5
|
||||
|
|
@ -37,8 +35,8 @@ num_generations: 7
|
|||
num_train_epochs: 1
|
||||
output_dir: data/Qwen-2.5-7B-Simple-RL
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 16
|
||||
per_device_train_batch_size: 16
|
||||
per_device_eval_batch_size: 28
|
||||
per_device_train_batch_size: 28
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from datasets import load_dataset
|
|||
from transformers import set_seed
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from accelerate import PartialState
|
||||
from open_r1.configs import GRPOConfig
|
||||
from open_r1.rewards import (
|
||||
accuracy_reward,
|
||||
|
|
@ -189,11 +190,12 @@ def main(script_args, training_args, model_args):
|
|||
prompt.append({"role": "user", "content": example["problem"]})
|
||||
return {"prompt": prompt}
|
||||
|
||||
dataset = dataset.map(make_conversation)
|
||||
with PartialState().main_process_first():
|
||||
dataset = dataset.map(make_conversation, desc="Formatting conversation")
|
||||
|
||||
for split in dataset:
|
||||
if "messages" in dataset[split].column_names:
|
||||
dataset[split] = dataset[split].remove_columns("messages")
|
||||
for split in dataset:
|
||||
if "messages" in dataset[split].column_names:
|
||||
dataset[split] = dataset[split].remove_columns("messages")
|
||||
|
||||
logger.info("*** Initializing model kwargs ***")
|
||||
torch_dtype = (
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue