Compare commits

...

2 commits

Author SHA1 Message Date
Quentin Gallouédec
e128cd5edc it takes 2 hours 2025-03-24 23:43:28 +00:00
Quentin Gallouédec
4c77b66ba4 remove unsued params 2025-03-24 23:40:17 +00:00
6 changed files with 18 additions and 18 deletions

View file

@ -169,15 +169,21 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
--config recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml
```
and in another shell session:
```shell
CUDA_VISIBLE_DEVICES=7 trl vllm-serve --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
```
> [!WARNING]
> The chat template used in the distilled DeepSeek models omits the contents of the reasoning block within the `<think>` and `</think>` tags. It also prefills the assistant response with `<think>` which interferes with the format reward function. To handle that, it is important to override the chat template as done in e.g. [recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml](./recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml).
We provide a minimal reproducible experiment using GRPO for mathematical reasoning, referencing the approach from [SimpleRL-Reason](https://hkust-nlp.notion.site/simplerl-reason) which uses a 7B model trained on 8K examples. Running this on 8 H100 80G GPU takes about 3 hours:
We provide a minimal reproducible experiment using GRPO for mathematical reasoning, referencing the approach from [SimpleRL-Reason](https://hkust-nlp.notion.site/simplerl-reason) which uses a 7B model trained on 8K examples. Running this on 8 H100 80G GPU takes about 2 hours:
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
--num_processes=7 src/open_r1/grpo.py \
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
--num_processes=4 src/open_r1/grpo.py \
--config recipes/Qwen2.5-Math-7B/grpo/config_simple_rl.yaml
```

View file

@ -13,8 +13,6 @@ system_prompt: "You are a helpful AI Assistant that provides well-reasoned and d
# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true

View file

@ -11,8 +11,6 @@ system_prompt: "You are a helpful AI Assistant that provides well-reasoned and d
# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true

View file

@ -12,8 +12,6 @@ system_prompt: "You are a helpful AI Assistant that provides well-reasoned and d
beta: 0.01
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true

View file

@ -12,8 +12,6 @@ system_prompt: "You are a helpful AI Assistant, designed to provided well-reason
# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: true
eval_strategy: steps
eval_steps: 100
@ -24,7 +22,7 @@ gradient_checkpointing_kwargs:
hub_model_id: Qwen-2.5-7B-Simple-RL
hub_strategy: every_save
learning_rate: 3.0e-06
log_completions: true
log_completions: false
log_level: info
logging_first_step: true
logging_steps: 5
@ -37,8 +35,8 @@ num_generations: 7
num_train_epochs: 1
output_dir: data/Qwen-2.5-7B-Simple-RL
overwrite_output_dir: true
per_device_eval_batch_size: 16
per_device_train_batch_size: 16
per_device_eval_batch_size: 28
per_device_train_batch_size: 28
push_to_hub: true
report_to:
- wandb

View file

@ -24,6 +24,7 @@ from datasets import load_dataset
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint
from accelerate import PartialState
from open_r1.configs import GRPOConfig
from open_r1.rewards import (
accuracy_reward,
@ -189,11 +190,12 @@ def main(script_args, training_args, model_args):
prompt.append({"role": "user", "content": example["problem"]})
return {"prompt": prompt}
dataset = dataset.map(make_conversation)
with PartialState().main_process_first():
dataset = dataset.map(make_conversation, desc="Formatting conversation")
for split in dataset:
if "messages" in dataset[split].column_names:
dataset[split] = dataset[split].remove_columns("messages")
for split in dataset:
if "messages" in dataset[split].column_names:
dataset[split] = dataset[split].remove_columns("messages")
logger.info("*** Initializing model kwargs ***")
torch_dtype = (