Compare commits

...

63 commits

Author SHA1 Message Date
edbeeching
f08b559e00 slurm 2025-03-20 08:05:07 +00:00
edbeeching
a5668ec040 add default model revision 2025-03-20 08:02:11 +00:00
edbeeching
36c3867811 pushing last changes while I wind down this PR 2025-03-20 08:01:34 +00:00
edbeeching
41f41d80f8 fixes for gradient checkpointing, profiling context 2025-03-18 08:02:11 +00:00
edbeeching
4f94c356b6 adds option for a mock remote model 2025-03-18 07:58:03 +00:00
edbeeching
1b074940bf add remote grpo exp configs 2025-03-18 07:57:10 +00:00
edbeeching
0e4685f6c1 adds option to not use ref model 2025-03-16 19:45:13 +00:00
lewtun
6e275af99e
Fix buffer (#508) 2025-03-16 13:54:12 +01:00
edbeeching
bf4642b4ea style 2025-03-14 08:31:42 +00:00
edbeeching
c6ab52a0cf cleaning up grpo PR 2025-03-14 08:22:11 +00:00
edbeeching
93d8bc5aba Merge branch 'main' into faster-grpo-trainer 2025-03-13 21:59:09 +00:00
edbeeching
e922330623 temp dumping of samples to json 2025-03-13 10:02:25 +00:00
edbeeching
fd62cf290e grpo configs 2025-03-11 07:48:06 +00:00
edbeeching
1c4efd5bd3 add clipping, fix bug with sampler 2025-03-10 23:00:06 +00:00
edbeeching
0c3e50f332 disable map progress bar 2025-03-10 14:31:28 +00:00
edbeeching
c69164a573 move ref logprob 2025-03-10 13:58:05 +00:00
edbeeching
a4004f658e fix lewis bugs 2025-03-10 10:03:37 +00:00
Lewis Tunstall
5fdbcd5a20 Restore Liger 2025-03-09 12:36:55 +00:00
Lewis Tunstall
4fa226ba1e Make checkpoint dir customisable 2025-03-09 11:53:29 +00:00
Lewis Tunstall
96546e7721 Fix 2025-03-08 19:22:58 +00:00
Lewis Tunstall
a3d1f26715 Refactor preparation 2025-03-08 19:16:43 +00:00
lewtun
14d75bf0d5
Merge branch 'main' into faster-grpo-trainer 2025-03-08 15:41:31 +01:00
Lewis Tunstall
f4fe3550b6 Fix reward devices 2025-03-08 13:46:51 +00:00
Lewis Tunstall
b47880b1e2 Add local logging 2025-03-08 12:56:58 +00:00
Lewis Tunstall
10a70dfa42 Refactor model load 2025-03-08 12:49:20 +00:00
Lewis Tunstall
7d470d02d4 Align logging 2025-03-07 19:59:56 +00:00
Lewis Tunstall
dcf0af62e2 Fix evals 2025-03-07 18:59:24 +00:00
Lewis Tunstall
3d0e39d5d6 Add Slurm 2025-03-07 16:37:32 +00:00
Lewis Tunstall
76f8ae7a88 Tidy up 2025-03-07 16:05:17 +00:00
Lewis Tunstall
389befcf3e Add recipes 2025-03-07 15:39:25 +00:00
Lewis Tunstall
2f4f6fe4ef Revert 2025-03-07 15:02:14 +00:00
Lewis Tunstall
48a57bbe34 Revert 2025-03-07 14:56:20 +00:00
Lewis Tunstall
54afde3b67 Tune 2025-03-07 14:55:54 +00:00
Lewis Tunstall
731caf57ed Update slurm 2025-03-07 11:08:06 +00:00
Lewis Tunstall
95bac8aacf Add deepscaler recipe 2025-03-07 10:47:19 +00:00
Lewis Tunstall
dcc33dc710 Reduce batch size 2025-03-06 13:44:41 +00:00
Lewis Tunstall
8c7d764e5b Clean 2025-03-06 13:36:39 +00:00
Lewis Tunstall
ada8cecd54 Add Qwen config 2025-03-06 09:41:11 +00:00
Lewis Tunstall
9b7aa79b0e Add remote script 2025-03-04 15:52:16 +00:00
Lewis Tunstall
25e0b07feb Tune params 2025-03-04 15:41:23 +00:00
Lewis Tunstall
db2501d531 Fix checkpoint 2025-03-04 10:34:25 +00:00
lewtun
def83fc6af
Merge branch 'main' into faster-grpo-trainer 2025-03-03 17:34:00 +01:00
edbeeching
00b1f61c01 fix job launcher 2025-02-28 21:34:47 +00:00
edbeeching
67fb66af13 adds smollm grpo config for replication 2025-02-28 21:09:40 +00:00
edbeeching
9de588449a Merge branch 'main' into faster-grpo-trainer 2025-02-27 14:45:05 +00:00
edbeeching
0030447af5 save configs wip 2025-02-26 09:07:57 +00:00
edbeeching
c775de3fd0 save wip 2025-02-26 09:06:38 +00:00
edbeeching
0db1912bdc adds logging to remote grpo 2025-02-22 22:01:16 +00:00
edbeeching
09628da5d3 Merge branch 'main' into faster-grpo-trainer 2025-02-22 21:11:36 +00:00
edbeeching
55a451a813 sampler 2025-02-22 12:31:11 +00:00
edbeeching
f68c27bdf3 save wip 2025-02-21 15:06:58 +00:00
edbeeching
f50658e7c8 save WIP 2025-02-21 15:06:26 +00:00
edbeeching
875628838a adds ref model offload 2025-02-21 09:59:39 +00:00
edbeeching
3e37bf1361 changes the completion attn mask 2025-02-21 09:29:44 +00:00
edbeeching
420d72a7da adds liger kernel support 2025-02-21 08:16:55 +00:00
edbeeching
382a0c7890 style 2025-02-20 21:55:35 +00:00
edbeeching
5d213c48b1 adds remote model, job launch, switch to sglang server with weight sync 2025-02-20 21:44:08 +00:00
edbeeching
12be29f08b save WIP 2025-02-20 11:04:26 +00:00
edbeeching
b1394e542e hacky fixes for memory issues 2025-02-20 07:34:24 +00:00
edbeeching
fbe3b07d56 patches to get vllm working in ddp setting 2025-02-19 21:28:10 +00:00
edbeeching
dcdcebaaac adds licence, world size patch for vllm 2025-02-19 10:52:45 +00:00
edbeeching
ed9554ff54 adds weight sync 2025-02-19 10:18:23 +00:00
edbeeching
38e350d3a2 make it run 2025-02-19 09:48:00 +00:00
39 changed files with 3998 additions and 11 deletions

BIN
.litellm_cache/cache.db Normal file

Binary file not shown.

View file

@ -0,0 +1,67 @@
# Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant>'}}{% endif %}"
dataset_name: agentica-org/DeepScaleR-Preview-Dataset
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: DeepSeek-R1-Distill-Qwen-1.5B-GRPO
hub_strategy: every_save
learning_rate: 5.0e-07
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: -1
num_generations: 8
num_train_epochs: 1
output_dir: data/DeepSeek-R1-Distill-Qwen-1.5B-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 8
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- tag_count
- format
reward_weights:
- 1.0
- 1.0
- 1.0
save_strategy: "steps"
save_steps: 0.2
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

View file

@ -0,0 +1,67 @@
# Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
dataset_configs:
- all
dataset_train_split: train
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.8
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 32
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/DeepSeek-R1-Distill-Qwen-1.5B-v00.00
hub_strategy: every_save
learning_rate: 1.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 16000
max_steps: -1
num_train_epochs: 0.1
num_generations: 16
output_dir: data/open-r1/DeepSeek-R1-Distill-Qwen-1.5B-v00.00
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 2
push_to_hub: true
beta: 0.04
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.1
use_liger_kernel: true
report_to:
- wandb
wandb_entity: huggingface
wandb_project: open-r1
log_completions: true
seed: 42
warmup_ratio: 0.1
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 100
# callbacks:
# - push_to_hub_revision
# benchmarks:
# - math_500_8k
# - aime24_8k
# - gsm8k_8k

View file

@ -0,0 +1,67 @@
# Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
dataset_configs:
- all
dataset_train_split: train
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.8
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 2
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/DeepSeek-R1-Distill-Qwen-1.5B-v00.01
hub_strategy: every_save
learning_rate: 1.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 16000
max_steps: -1
num_train_epochs: 0.1
num_generations: 14
output_dir: data/open-r1/DeepSeek-R1-Distill-Qwen-1.5B-v00.01
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 2
push_to_hub: true
beta: 0.04
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.1
use_liger_kernel: true
report_to:
- wandb
wandb_entity: huggingface
wandb_project: open-r1
log_completions: true
seed: 42
warmup_ratio: 0.1
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 100
# callbacks:
# - push_to_hub_revision
# benchmarks:
# - math_500_8k
# - aime24_8k
# - gsm8k_8k

View file

@ -0,0 +1,67 @@
# Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
dataset_configs:
- all
dataset_train_split: train
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.8
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/DeepSeek-R1-Distill-Qwen-1.5B-RGRPO-v00.01
hub_strategy: every_save
learning_rate: 1.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 32768
max_steps: -1
num_train_epochs: 1.0
num_generations: 16
output_dir: data/open-r1/DeepSeek-R1-Distill-Qwen-1.5B-RGRPO-v00.01
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 2
push_to_hub: true
beta: 0.04
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.1
use_liger_kernel: true
report_to:
- wandb
wandb_entity: huggingface
wandb_project: open-r1
log_completions: true
seed: 42
warmup_ratio: 0.1
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 100
# callbacks:
# - push_to_hub_revision
# benchmarks:
# - math_500_8k
# - aime24_8k
# - gsm8k_8k

View file

@ -0,0 +1,43 @@
# Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: AI-MO/NuminaMath-TIR
dataset_configs:
- all
num_processes: 1
ddp_find_unused_parameters: false
# GRPO trainer config
# use_vllm: true
bf16: true
# ref_model_url: http://127.0.0.1:8000
do_eval: false
eval_strategy: "no"
eval_steps: 100
gradient_accumulation_steps: 2
gradient_checkpointing: false
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: DeepSeek-R1-Distill-Qwen-1.5-GRPO
hub_strategy: every_save
learning_rate: 1.0e-06
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 512
max_steps: -1
num_train_epochs: 1
output_dir: data/DeepSeek-R1-Distill-Qwen-1.5-GRPO-v001
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1

View file

@ -0,0 +1,64 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-1.5B-Instruct-GRPO
hub_model_revision: v00.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: 1000
num_generations: 16
num_train_epochs: 1
output_dir: data/Qwen2.5-1.5B-Instruct-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

View file

@ -0,0 +1,65 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
gradient_accumulation_steps: 14
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-1.5B-Instruct-RGRPO
hub_model_revision: v01.02
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: -1
num_generations: 14
num_train_epochs: 1
output_dir: data/Qwen2.5-1.5B-Instruct-RGRPO_v01.02
overwrite_output_dir: true
per_device_train_batch_size: 4
remote_gen_model_url: 26.0.164.45
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

View file

@ -0,0 +1,64 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: agentica-org/DeepScaleR-Preview-Dataset
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-7B-Instruct-GRPO
hub_model_revision: v01.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: 1000
num_generations: 16
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Instruct-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

View file

@ -0,0 +1,64 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
gradient_accumulation_steps: 14
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-7B-Instruct-GRPO
hub_model_revision: v00.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: -1
num_generations: 14
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Instruct-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

View file

@ -0,0 +1,64 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
gradient_accumulation_steps: 14
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-7B-Instruct-RGRPO
hub_model_revision: v01.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: -1
num_generations: 14
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Instruct-RGRPO_v01.00
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

View file

@ -0,0 +1,64 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.0
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
gradient_accumulation_steps: 14
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-7B-Instruct-RGRPO
hub_model_revision: v01.01
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: -1
num_generations: 14
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Instruct-RGRPO_v01.01
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

View file

@ -0,0 +1,64 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.0
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
gradient_accumulation_steps: 14
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-7B-Instruct-RGRPO
hub_model_revision: v01.02
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: -1
num_generations: 14
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Instruct-RGRPO_v01.02
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

View file

@ -0,0 +1,60 @@
# Model arguments
model_name_or_path: open-r1/Qwen2.5-Coder-7B-Instruct-SFT
model_revision: v06.11-step-000004005
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/verifiable-coding-problems-python_decontaminated
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- lcb
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO
hub_model_revision: v00.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 8192
max_steps: 1000
num_generations: 16
num_train_epochs: 1
output_dir: data/Qwen2.5-Coder-7B-Instruct-SFT-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- code
reward_weights:
- 1.0
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

View file

@ -0,0 +1,59 @@
# Model arguments
model_name_or_path: open-r1/Qwen2.5-Coder-7B-Instruct-SFT
model_revision: v00.08-step-000001280
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/verifiable-coding-problems-python-10k_decontaminated
dataset_configs:
- all
num_processes: 7
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO
hub_model_revision: v00.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 31744
max_steps: -1
num_train_epochs: 5
num_generations: 7
output_dir: data/open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO-v00.00
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 2
push_to_hub: true
report_to:
- wandb
seed: 42
warmup_ratio: 0.1
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 25
callbacks:
- push_to_hub_revision
benchmarks:
- lcb
reward_funcs:
# - code
- code_format
reward_weights:
# - 1.0
- 0.1

View file

@ -0,0 +1,59 @@
# Model arguments
model_name_or_path: open-r1/Qwen2.5-Coder-7B-Instruct-SFT
model_revision: v00.08-step-000001280
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/verifiable-coding-problems-python-10k_decontaminated
dataset_configs:
- all
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO
hub_model_revision: v00.00_remote
hub_strategy: every_save
learning_rate: 1.0e-06
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 30700
max_steps: -1
num_train_epochs: 5
num_generations: 16
output_dir: data/open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO-v00.00_remote
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
seed: 42
warmup_ratio: 0.1
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 25
callbacks:
- push_to_hub_revision
benchmarks:
- lcb
use_liger_kernel: true
reward_funcs:
- code
- code_format
reward_weights:
- 1.0
- 0.1

View file

@ -0,0 +1,64 @@
# Model arguments
model_name_or_path: open-r1/Qwen2.5-Coder-7B-Instruct-SFT
model_revision: v02.12-step-000003170
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/verifiable-coding-problems-python_decontaminated-tested
dataset_configs:
- all
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-SFT-RGRPO
hub_model_revision: v01.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 30000
max_steps: -1
num_train_epochs: 1
num_generations: 16
output_dir: data/open-r1/Qwen2.5-Coder-7B-Instruct-SFT-RGRPO-v01.00
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
seed: 42
warmup_ratio: 0.1
beta: 0.01
remote_gen_model_url: 26.0.165.131
num_iterations: 1
# Saving and eval callbacks
# save_strategy: "steps"
# save_steps: 25
# callbacks:
# - push_to_hub_revision
# benchmarks:
# - lcb
reward_funcs:
- code
- code_format
reward_weights:
- 1.0
- 0.1

View file

@ -0,0 +1,65 @@
# Model arguments
model_name_or_path: HuggingFaceTB/SmolLM2-1.7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
dataset_configs:
- all
dataset_train_split: train
num_processes: 7
ddp_find_unused_parameters: false
# GRPO trainer config
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 64
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/SmolLM2-1.7B-Instruct-FGRPO
hub_model_revision: v05.00
hub_strategy: every_save
learning_rate: 1.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 7168
max_steps: -1
num_train_epochs: 0.5
num_generations: 16
output_dir: data/open-r1/SmolLM2-1.7B-Instruct-FGRPO-v05.00
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 2
push_to_hub: true
beta: 0.04
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.1
use_liger_kernel: true
report_to:
- wandb
wandb_entity: huggingface
wandb_project: open-r1
log_completions: true
seed: 42
warmup_ratio: 0.02
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 10
# callbacks:
# - push_to_hub_revision
# benchmarks:
# - math_500_8k
# - aime24_8k
# - gsm8k_8k

View file

@ -0,0 +1,63 @@
# Model arguments
model_name_or_path: open-r1/SMOLLM_I8k-GR2-deepseek
model_revision: main-step-000000300
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
dataset_configs:
- all
dataset_train_split: train
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: open-r1/SmolLM2-1.7B-Instruct-GRPO-v05.01
hub_strategy: every_save
learning_rate: 1.0e-04
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 7168
max_steps: -1
num_train_epochs: 0.5
num_generations: 16
output_dir: data/open-r1/SmolLM2-1.7B-Instruct-GRPO-v05.01
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 8
push_to_hub: true
beta: 0.0
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.1
use_liger_kernel: true
report_to:
- wandb
wandb_entity: huggingface
wandb_project: open-r1
log_completions: true
seed: 42
warmup_ratio: 0.02
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 10
callbacks:
- push_to_hub_revision

View file

@ -0,0 +1,70 @@
# Model arguments
model_name_or_path: HuggingFaceTB/SmolLM2-1.7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
dataset_configs:
- all
dataset_train_split: train
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: open-r1/SmolLM2-1.7B-Instruct-GRPO-v05.19
hub_strategy: every_save
learning_rate: 1.0e-04
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 7172
max_steps: -1
num_train_epochs: 1.0
num_generations: 16
output_dir: data/open-r1/SmolLM2-1.7B-Instruct-GRPO-v05.19
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 8
push_to_hub: true
beta: 0.01
remote_gen_model_url: 26.0.160.225
num_iterations: 4
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.1
use_liger_kernel: true
report_to:
- wandb
wandb_entity: huggingface
wandb_project: open-r1
log_completions: true
seed: 42
warmup_ratio: 0.02
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 10
# callbacks:
# - push_to_hub_revision
# benchmarks:
# - math_500
# - aime24
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"

View file

@ -0,0 +1,67 @@
# Model arguments
model_name_or_path: HuggingFaceTB/SmolLM2-135M-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
dataset_configs:
- all
dataset_train_split: train
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.8
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: open-r1/SmolLM2-135M-Instruct-GRPO-v00.01
hub_strategy: every_save
learning_rate: 1.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 2048
max_steps: -1
num_train_epochs: 0.1
num_generations: 4
output_dir: data/open-r1/SmolLM2-135M-Instruct-GRPO-v00.01
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 8
push_to_hub: true
beta: 0.04
remote_gen_model_url: 0.0.0.0
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.1
use_liger_kernel: true
report_to:
- wandb
wandb_entity: huggingface
wandb_project: open-r1
log_completions: true
seed: 42
warmup_ratio: 0.1
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 100
# callbacks:
# - push_to_hub_revision
# benchmarks:
# - math_500_8k
# - aime24_8k
# - gsm8k_8k

View file

@ -0,0 +1,47 @@
{
"bf16": {
"enabled": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"sub_group_size": 1e9,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": "auto"
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false,
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": true,
"contiguous_memory_optimization": false,
"number_checkpoints": null,
"synchronize_checkpoint_boundary": false,
"profile": false
}
}

View file

@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: cpu
offload_param_device: cpu
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View file

@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_config_file: recipes/accelerate_configs/deepspeed3_offload.json
zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false

View file

@ -0,0 +1,30 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
offload_param:
device: cpu
pin_memory: true
activation_checkpointing:
partition_activations: true
contiguous_memory_optimization: false
cpu_checkpointing: true
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View file

@ -0,0 +1,46 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-Math-7B
model_revision: main
torch_dtype: bfloat16
# Data training arguments
dataset_name: DigitalLearningGmbH/MATH-lighteval
dataset_configs:
- train
# Num processes is less by 1 as vLLM is using 1 GPU
num_processes: 7
# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: open-r1/Qwen-2.5-7B_Base_Math_smalllr_remote_model
hub_strategy: every_save
learning_rate: 3.0e-06
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_train_epochs: 1
output_dir: data/Qwen-2.5-7B_Base_Math_smalllr_remote_model
overwrite_output_dir: true
per_device_eval_batch_size: 1
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
ref_model_url: http://26.0.163.127:30010

250
scripts/faster_grpo.py Normal file
View file

@ -0,0 +1,250 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
from dataclasses import dataclass, field
import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import GRPOConfig
from open_r1.rewards import (
accuracy_reward,
format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
reasoning_steps_reward,
)
from open_r1.trainers.faster_grpo_trainer import FastGRPOTrainer, FastGRPOConfig
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
logger = logging.getLogger(__name__)
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'.
        cosine_min_value_wrong (`float`):
            Minimum reward for cosine scaling for wrong answers.
        cosine_max_value_wrong (`float`):
            Maximum reward for cosine scaling for wrong answers.
        cosine_min_value_correct (`float`):
            Minimum reward for cosine scaling for correct answers.
        cosine_max_value_correct (`float`):
            Maximum reward for cosine scaling for correct answers.
        cosine_max_len (`int`):
            Maximum length for cosine scaling.
        repetition_n_grams (`int`):
            Number of n-grams for the repetition penalty reward.
        repetition_max_penalty (`float`):
            Maximum (negative) penalty for the repetition penalty reward.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={
            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
        },
    )
    cosine_min_value_wrong: float = field(
        default=0.0,
        metadata={"help": "Minimum reward for wrong answers"},
    )
    cosine_max_value_wrong: float = field(
        default=-0.5,
        metadata={"help": "Maximum reward for wrong answers"},
    )
    cosine_min_value_correct: float = field(
        default=0.5,
        metadata={"help": "Minimum reward for correct answers"},
    )
    cosine_max_value_correct: float = field(
        default=1.0,
        metadata={"help": "Maximum reward for correct answers"},
    )
    cosine_max_len: int = field(
        default=1000,
        metadata={"help": "Maximum length for scaling"},
    )
    repetition_n_grams: int = field(
        default=3,
        metadata={"help": "Number of n-grams for repetition penalty reward"},
    )
    repetition_max_penalty: float = field(
        default=-1.0,
        # Fixed the duplicated word ("for for") in the user-facing help text.
        metadata={"help": "Maximum (negative) penalty for repetition penalty reward"},
    )
# System message placed at the start of every training prompt (see `make_conversation` in `main`).
# It asks the model to wrap its reasoning in <think> tags and its final answer in <answer> tags.
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)
def main(script_args, training_args, model_args):
    """Run GRPO training with `FastGRPOTrainer` from parsed arguments.

    Seeds the RNGs, configures logging, loads the dataset and turns each example into a
    system + user conversation, resolves the requested reward functions, trains (resuming
    from a checkpoint when one is given or found in the output directory), saves the model
    and metrics, and optionally pushes the result to the Hub.

    Args:
        script_args: Parsed `GRPOScriptArguments` (dataset and reward settings).
        training_args: Parsed `FastGRPOConfig` training configuration.
        model_args: Parsed `ModelConfig` describing the model to load.

    Raises:
        KeyError: If `script_args.reward_funcs` names a reward function that is not registered.
    """
    # Set seed for reproducibility
    set_seed(training_args.seed)

    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process a small summary
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Script parameters {script_args}")
    logger.info(f"Training parameters {training_args}")

    # Check for last checkpoint
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    if "wandb" in training_args.report_to:
        init_wandb_training(training_args)

    # Load the dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Get reward functions
    REWARD_FUNCS_REGISTRY = {
        "accuracy": accuracy_reward,
        "format": format_reward,
        "reasoning_steps": reasoning_steps_reward,
        "cosine": get_cosine_scaled_reward(
            min_value_wrong=script_args.cosine_min_value_wrong,
            max_value_wrong=script_args.cosine_max_value_wrong,
            min_value_correct=script_args.cosine_min_value_correct,
            max_value_correct=script_args.cosine_max_value_correct,
            max_len=script_args.cosine_max_len,
        ),
        "repetition_penalty": get_repetition_penalty_reward(
            ngram_size=script_args.repetition_n_grams,
            max_penalty=script_args.repetition_max_penalty,
        ),
        "length": len_reward,
    }
    # Fail with an actionable message (still a KeyError) instead of a bare key lookup error.
    unknown_funcs = [func for func in script_args.reward_funcs if func not in REWARD_FUNCS_REGISTRY]
    if unknown_funcs:
        raise KeyError(
            f"Unknown reward function(s): {unknown_funcs}. Available: {sorted(REWARD_FUNCS_REGISTRY)}"
        )
    reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    dataset = dataset.map(make_conversation)
    # Drop any pre-existing "messages" column so only the freshly built "prompt" is used.
    for split in dataset:
        if "messages" in dataset[split].column_names:
            dataset[split] = dataset[split].remove_columns("messages")

    logger.info("*** Initializing model kwargs ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        # The KV cache is disabled whenever gradient checkpointing is enabled.
        use_cache=not training_args.gradient_checkpointing,
    )
    training_args.model_init_kwargs = model_kwargs

    ##################################
    # Initialize the Fast GRPO trainer
    ##################################
    trainer = FastGRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        callbacks=get_callbacks(training_args, model_args),
    )

    ###############
    # Training loop
    ###############
    logger.info("*** Train ***")
    # An explicitly requested checkpoint wins over one auto-detected in the output directory.
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    # Save everything else on main process
    kwargs = {
        "dataset_name": script_args.dataset_name,
        "tags": ["open-r1"],
    }
    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    #############
    # push to hub
    #############
    if training_args.push_to_hub:
        logger.info("Pushing to hub...")
        trainer.push_to_hub(**kwargs)
if __name__ == "__main__":
    # Build the three argument objects via `parse_args_and_config`, then hand them to `main`.
    script_args, training_args, model_args = TrlParser(
        (GRPOScriptArguments, FastGRPOConfig, ModelConfig)
    ).parse_args_and_config()
    main(script_args, training_args, model_args)

301
scripts/remote_grpo.py Normal file
View file

@ -0,0 +1,301 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GRPO trainer to train on N + 1 nodes, with 1 node allocated for generation.
Usage:
For training, run:
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml scripts/remote_grpo.py \
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_remote.yaml
This will automatically spin up an SGLang server on a separate node and use it for generation.
For development, first spin up an SGLang server on a separate node:
python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-1.5B-Instruct --port=30010 --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size=8
Then run training by providing the IP address of the server:
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml scripts/remote_grpo.py \
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_remote.yaml \
--remote_gen_model_url ip-26-0-160-103
"""
import logging
import os
import sys
from dataclasses import dataclass, field
import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint
from open_r1.rewards import (
accuracy_reward,
code_reward,
format_reward,
get_code_format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
reasoning_steps_reward,
tag_count_reward,
)
from open_r1.utils import get_tokenizer
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import ModelConfig, ScriptArguments, TrlParser
from open_r1.trainers.remote_grpo_trainer import RemoteGRPOTrainer, RemoteGRPOConfig
logger = logging.getLogger(__name__)
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'tag_count', 'code', 'code_format'.
        cosine_min_value_wrong (`float`):
            Minimum reward for cosine scaling for wrong answers.
        cosine_max_value_wrong (`float`):
            Maximum reward for cosine scaling for wrong answers.
        cosine_min_value_correct (`float`):
            Minimum reward for cosine scaling for correct answers.
        cosine_max_value_correct (`float`):
            Maximum reward for cosine scaling for correct answers.
        cosine_max_len (`int`):
            Maximum length for cosine scaling.
        repetition_n_grams (`int`):
            Number of n-grams for the repetition penalty reward.
        repetition_max_penalty (`float`):
            Maximum (negative) penalty for the repetition penalty reward.
        code_language (`str`):
            Language for code format reward.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format", "tag_count"],
        metadata={
            # The listed names mirror the keys of the reward registry built in `main`:
            # 'format_deepseek' was advertised but is not registered, and 'tag_count' lacked its opening quote.
            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'tag_count', 'code', 'code_format'"
        },
    )
    cosine_min_value_wrong: float = field(
        default=0.0,
        metadata={"help": "Minimum reward for wrong answers"},
    )
    cosine_max_value_wrong: float = field(
        default=-0.5,
        metadata={"help": "Maximum reward for wrong answers"},
    )
    cosine_min_value_correct: float = field(
        default=0.5,
        metadata={"help": "Minimum reward for correct answers"},
    )
    cosine_max_value_correct: float = field(
        default=1.0,
        metadata={"help": "Maximum reward for correct answers"},
    )
    cosine_max_len: int = field(
        default=1000,
        metadata={"help": "Maximum length for scaling"},
    )
    repetition_n_grams: int = field(
        default=3,
        metadata={"help": "Number of n-grams for repetition penalty reward"},
    )
    repetition_max_penalty: float = field(
        default=-1.0,
        # Fixed the duplicated word ("for for") in the user-facing help text.
        metadata={"help": "Maximum (negative) penalty for repetition penalty reward"},
    )
    code_language: str = field(
        default="python",
        metadata={
            "help": "Language for code format reward. Based on E2B supported languages https://e2b.dev/docs/code-interpreting/supported-languages",
            "choices": ["python", "javascript", "r", "java", "bash"],
        },
    )
def main(script_args, training_args, model_args):
    """Run GRPO training with `RemoteGRPOTrainer` from parsed arguments.

    Seeds the RNGs, configures logging, loads the dataset and tokenizer, turns each example
    into a conversation (with an optional system prompt), resolves the requested reward
    functions, trains (resuming from a checkpoint when one is given or found in the output
    directory), saves the model and metrics, optionally evaluates, and optionally pushes
    the result to the Hub.

    Args:
        script_args: Parsed `GRPOScriptArguments` (dataset and reward settings).
        training_args: Parsed `RemoteGRPOConfig` training configuration.
        model_args: Parsed `ModelConfig` describing the model to load.

    Raises:
        KeyError: If `script_args.reward_funcs` names a reward function that is not registered.
    """
    # Set seed for reproducibility
    set_seed(training_args.seed)

    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process a small summary
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Script parameters {script_args}")
    logger.info(f"Training parameters {training_args}")

    # Check for last checkpoint
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    if "wandb" in training_args.report_to:
        init_wandb_training(training_args)

    # Load the dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    ################
    # Load tokenizer
    ################
    tokenizer = get_tokenizer(model_args, training_args)

    # Get reward functions
    REWARD_FUNCS_REGISTRY = {
        "accuracy": accuracy_reward,
        "format": format_reward,
        "reasoning_steps": reasoning_steps_reward,
        "cosine": get_cosine_scaled_reward(
            min_value_wrong=script_args.cosine_min_value_wrong,
            max_value_wrong=script_args.cosine_max_value_wrong,
            min_value_correct=script_args.cosine_min_value_correct,
            max_value_correct=script_args.cosine_max_value_correct,
            max_len=script_args.cosine_max_len,
        ),
        "repetition_penalty": get_repetition_penalty_reward(
            ngram_size=script_args.repetition_n_grams,
            max_penalty=script_args.repetition_max_penalty,
        ),
        "length": len_reward,
        "code": code_reward,
        "code_format": get_code_format_reward(language=script_args.code_language),
        "tag_count": tag_count_reward,
    }
    # Fail with an actionable message (still a KeyError) instead of a bare key lookup error.
    unknown_funcs = [func for func in script_args.reward_funcs if func not in REWARD_FUNCS_REGISTRY]
    if unknown_funcs:
        raise KeyError(
            f"Unknown reward function(s): {unknown_funcs}. Available: {sorted(REWARD_FUNCS_REGISTRY)}"
        )
    reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

    # Format into conversation
    def make_conversation(example):
        prompt = []
        # The system turn is only added when a system prompt is configured.
        if training_args.system_prompt is not None:
            prompt.append({"role": "system", "content": training_args.system_prompt})
        prompt.append({"role": "user", "content": example["problem"]})
        return {"prompt": prompt}

    dataset = dataset.map(make_conversation)
    # Drop any pre-existing "messages" column so only the freshly built "prompt" is used.
    for split in dataset:
        if "messages" in dataset[split].column_names:
            dataset[split] = dataset[split].remove_columns("messages")

    logger.info("*** Initializing model kwargs ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        # The KV cache is disabled whenever gradient checkpointing is enabled.
        use_cache=not training_args.gradient_checkpointing,
    )
    training_args.model_init_kwargs = model_kwargs

    #############################
    # Initialize the GRPO trainer
    #############################
    trainer = RemoteGRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        callbacks=get_callbacks(training_args, model_args),
        processing_class=tokenizer,
    )

    ###############
    # Training loop
    ###############
    logger.info("*** Train ***")
    # An explicitly requested checkpoint wins over one auto-detected in the output directory.
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    # Save everything else on main process
    kwargs = {
        "dataset_name": script_args.dataset_name,
        "tags": ["open-r1"],
    }
    if trainer.accelerator.is_main_process:
        # trainer.create_model_card(**kwargs) # Bug: needs fixing with TRL helper methods
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    ##########
    # Evaluate
    ##########
    # NOTE(review): no `eval_dataset` is passed to the trainer above — confirm that
    # `RemoteGRPOTrainer.evaluate()` has a dataset to work with before enabling `do_eval`.
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    #############
    # push to hub
    #############
    if training_args.push_to_hub:
        logger.info("Pushing to hub...")
        trainer.push_to_hub(**kwargs)
if __name__ == "__main__":
    # Build the three argument objects via `parse_args_and_config`, then hand them to `main`.
    script_args, training_args, model_args = TrlParser(
        (GRPOScriptArguments, RemoteGRPOConfig, ModelConfig)
    ).parse_args_and_config()
    main(script_args, training_args, model_args)

View file

@ -66,8 +66,8 @@ _deps = [
"safetensors>=0.3.3",
"sentencepiece>=0.1.99",
"torch==2.5.1",
"transformers==4.49.0",
"trl @ git+https://github.com/huggingface/trl.git@69ad852e5654a77f1695eb4c608906fe0c7e8624",
"transformers==4.48.3", # Must pin for SGLang
"trl @ git+https://github.com/huggingface/trl.git@e3244d2d096ff1e2e248c931d06d39e165e20623",
"vllm==0.7.2",
"wandb>=0.19.1",
]
@ -107,6 +107,7 @@ install_requires = [
deps["math-verify"],
deps["liger_kernel"],
deps["packaging"], # utilities from PyPA to e.g., compare versions
deps["peft"],
deps["safetensors"],
deps["sentencepiece"],
deps["transformers"],

23
slurm/launch_sglang.slurm Normal file
View file

@ -0,0 +1,23 @@
#!/bin/bash
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1
#SBATCH --partition=hopper-prod
#SBATCH --output=/fsx/open-r1/logs/%x-%j.out
#SBATCH --err=/fsx/open-r1/logs/%x-%j.err

# Launch an SGLang server for one model revision on the allocated node.
#
# Usage:
#   sbatch slurm/launch_sglang.slurm <MODEL_ID> <REVISION> <PORT>
#
# Specific configuration optimized for the Hugging Face Compute Cluster
# Be ye warned this may not work on other clusters!

set -x -e

source ~/.bashrc
source openr1/bin/activate
module load cuda/12.4

echo "Starting sglang server..."

# All three positional arguments are required: abort with a usage hint instead of
# handing a malformed command line to the server when one is missing.
USAGE="usage: launch_sglang.slurm MODEL_ID REVISION PORT"
MODEL_ID=${1:?$USAGE}
REVISION=${2:?$USAGE}
PORT=${3:?$USAGE}

# One data-parallel replica per GPU reported by nvidia-smi.
NUM_GPUS=$(nvidia-smi --list-gpus | wc -l)

python3 -m sglang.launch_server --model-path "$MODEL_ID" --revision "$REVISION" --port="$PORT" --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size="$NUM_GPUS"

93
slurm/train_remote.slurm Normal file
View file

@ -0,0 +1,93 @@
#!/bin/bash
#SBATCH --job-name=open-r1-sft
#SBATCH --ntasks-per-node=1
#SBATCH --exclusive
#SBATCH --gres=gpu:8
#SBATCH --partition=hopper-prod # Adjust this for your cluster
#SBATCH --output=./logs/%x-%j.out
#SBATCH --err=./logs/%x-%j.err

# Usage:
#   sbatch --job-name=remote-grpo --nodes=1 slurm/train_remote.slurm Qwen2.5-1.5B-Instruct grpo remote zero3
#
# Positional arguments: MODEL TASK CONFIG_SUFFIX ACCELERATOR [OPTIONAL_ARGS]
# (The usage text must be `#` comments: a Python-style triple-quoted block is executed by bash as a command.)

# Specific configuration optimized for the Hugging Face Compute Cluster
# Be ye warned this may not work on other clusters!
module load cuda/12.4

set -x -e

source ~/.bashrc
source openr1/bin/activate
echo "START TIME: $(date)"

MODEL=$1
TASK=$2
CONFIG_SUFFIX=$3
ACCELERATOR=$4
OPTIONAL_ARGS=$5

# Training setup
NUM_NODES=$SLURM_NNODES
GPUS_PER_NODE=8
WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE))

# Due to conflicts between Accelerate's DeepSpeed configs and Transformers' TrainingArguments, we need to parse the gradient accumulation steps from the config file to ensure they match
CONFIG_FILE=recipes/$MODEL/$TASK/config_$CONFIG_SUFFIX.yaml
GRAD_ACC_STEPS=$(grep 'gradient_accumulation_steps' "$CONFIG_FILE" | awk '{print $2}')

# Split the string into individual arguments
IFS=' ' read -ra ARGS <<< "$OPTIONAL_ARGS"

# Loop through the arguments and find the one with "--gradient_accumulation_steps"
for arg in "${ARGS[@]}"; do
    if [[ "$arg" == "--gradient_accumulation_steps="* ]]; then
        # Extract the value after the equals sign
        GRAD_ACC_STEPS="${arg#*=}"
        break # Exit the loop once we find the desired argument
    fi
done

echo "Gradient accumulation steps: $GRAD_ACC_STEPS"

# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000

export CMD=" \
    scripts/remote_grpo.py --config $CONFIG_FILE $OPTIONAL_ARGS
    "

export LAUNCHER="HF_HUB_ENABLE_HF_TRANSFER=1 ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \
    --config_file recipes/accelerate_configs/$ACCELERATOR.yaml \
    --gradient_accumulation_steps $GRAD_ACC_STEPS \
    --num_machines $NUM_NODES \
    --num_processes $WORLD_SIZE \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$SLURM_PROCID \
    --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \
    --max_restarts 1 \
    --role \$(hostname -s): \
    --tee 3 \
    "

# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1

# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
SRUN_ARGS=" \
    --wait=60 \
    --kill-on-bad-exit=1 \
    "

clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD" 2>&1

echo "END TIME: $(date)"

View file

@ -423,10 +423,12 @@ def run_async_from_sync(scripts: list[str], language: str) -> list[float]:
async def run_async(scripts: list[str], language: str) -> list[float]:
# Create the sandbox by hand, currently there's no context manager for this version
sbx = await AsyncSandbox.create(timeout=30, request_timeout=3)
sbx = await AsyncSandbox.create(timeout=60, request_timeout=5)
# Create a list of tasks for running scripts concurrently
tasks = [run_script(sbx, script, language) for script in scripts]
MAX_TASKS_PER_PROCESS = 2 # E2B has a limit of 20 concurrent requests, assume 1 noe, 8 processes, this is 2 per process (20//8 = 2)
semaphore = asyncio.Semaphore(MAX_TASKS_PER_PROCESS)
tasks = [run_script(sbx, script, language, semaphore) for script in scripts]
# Wait for all tasks to complete and gather their results as they finish
results = await asyncio.gather(*tasks)
@ -438,9 +440,10 @@ async def run_async(scripts: list[str], language: str) -> list[float]:
return rewards
async def run_script(sbx: AsyncSandbox, script: str, language: str) -> float:
execution = await sbx.run_code(script, language=language)
try:
return float(execution.text)
except (TypeError, ValueError):
return 0.0
async def run_script(sbx: AsyncSandbox, script: str, language: str, semaphore) -> float:
    """Run one script in the sandbox and return its textual result as a float.

    Args:
        sbx: Sandbox whose ``run_code`` coroutine executes the script.
        script: Source code to execute.
        language: Language identifier forwarded to ``run_code``.
        semaphore: Async context manager used to cap how many scripts run concurrently.

    Returns:
        ``float(execution.text)``, or ``0.0`` when that text is missing or not numeric.
    """
    async with semaphore:  # Limit concurrency
        outcome = await sbx.run_code(script, language=language)
        try:
            reward = float(outcome.text)
        except (TypeError, ValueError):
            # No output (None) or non-numeric output counts as zero reward.
            reward = 0.0
        return reward

View file

@ -0,0 +1,673 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# great reference: https://github.com/vllm-project/vllm/issues/11400
import contextlib
import functools
import gc
import math
import os
import tempfile
import time
from collections import defaultdict
from dataclasses import dataclass, field
from multiprocessing import reduction
from typing import Callable, Optional, Union
from unittest.mock import patch
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorWithPadding,
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
TrainerControl,
TrainerState,
is_wandb_available,
)
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils import is_liger_kernel_available
import trl
from accelerate import Accelerator
from accelerate.utils import gather_object
from open_r1.trainers.job_launcher import SGLangSlurmJobLauncher
from open_r1.trainers.remote_model import RemoteModel
from trl.data_utils import is_conversational, maybe_apply_chat_template
from trl.trainer.utils import pad, selective_log_softmax
from vllm import LLM, SamplingParams
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
if is_wandb_available():
import wandb
@contextlib.contextmanager
def profiling_context(instance, name):
    """
    Time the enclosed block and report the duration to Weights & Biases.

    The duration is logged only when "wandb" is listed in `instance.args.report_to`, a wandb run is
    active and `instance.accelerator.is_main_process` is true. Nothing is logged if the enclosed
    block raises, because the code after `yield` is then skipped.
    """
    started_at = time.perf_counter()
    yield
    duration = time.perf_counter() - started_at

    should_log = (
        "wandb" in instance.args.report_to and wandb.run is not None and instance.accelerator.is_main_process
    )
    if should_log:
        wandb.log({f"profiling/Time taken: {instance.__class__.__name__}.{name}": duration})
def profiling_decorator(func):
    """
    Wrap a method so that each call is timed with `profiling_context`.

    The wrapped callable must take the instance as its first argument; the profiling entry is
    named after `func.__name__`.
    """

    @functools.wraps(func)
    def timed(self, *args, **kwargs):
        with profiling_context(self, func.__name__):
            outcome = func(self, *args, **kwargs)
            return outcome

    return timed
from accelerate import Accelerator
if is_wandb_available():
import wandb
def exact_div(a, b, custom_error_message=""):
    """Return `a // b`, raising `ValueError` when `b` does not divide `a` exactly."""
    quotient = a // b
    if quotient * b == a:
        return quotient
    raise ValueError(f"{custom_error_message}, inexact division: {a} / {b} = {a / b}")
# TODO: add the shared options with a mixin to reduce code duplication
@dataclass
class FastGRPOConfig(trl.GRPOConfig):
    """
    `trl.GRPOConfig` extended with benchmark/callback options, Hub and wandb settings, and the
    location of the remote generation server used by `FastGRPOTrainer`.
    """

    benchmarks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
    )
    callbacks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
    )
    system_prompt: Optional[str] = field(
        default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
    )
    hub_model_revision: Optional[str] = field(
        default="main", metadata={"help": "The Hub model branch to push the model to."}
    )
    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": ("The entity to store runs under.")},
    )
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": ("The project to store runs under.")},
    )
    # The trainer compares this value against None, so the annotation must admit None.
    # NOTE(review): the default is a hard-coded address; it is kept for backward compatibility.
    remote_gen_model_url: Optional[str] = field(
        default="26.0.165.24",
        metadata={
            "help": (
                "Address of the remote generation server. When unset (None), the trainer submits an "
                "SGLang Slurm job and uses the address of the node it runs on."
            )
        },
    )
    remote_gen_model_port: str = field(
        default="30010",
        metadata={"help": "Port of the remote generation server."},
    )
    # Annotated as int to match the integer default; the value ends up in `--gres=gpu:<n>`.
    remote_gen_model_n_gpus: int = field(
        default=8,
        metadata={"help": "Number of GPUs requested for the SGLang Slurm job launched by the trainer."},
    )
class FastGRPOTrainer(Trainer):
_tag_names = ["trl", "fast_grpo"]
def __init__(
    self,
    model: str,  # only accept str for now
    reward_funcs: Union[RewardFunc, list[RewardFunc]],
    args: FastGRPOConfig,
    train_dataset: Dataset,
    processing_class: Optional[PreTrainedTokenizerBase] = None,
    data_collator: Optional[DataCollatorWithPadding] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
) -> None:
    """
    Build the trainer: load the policy and reference models, set up the optimizer, callbacks and
    prompt dataloader, and connect to the remote generation model.

    Args:
        model: Model id or path; loaded twice (trained model and reference model).
        reward_funcs: One reward function or a list of them.
        args: Training configuration.
        train_dataset: Dataset of prompts.
        processing_class: Tokenizer; loaded from the model path with left padding when omitted.
        data_collator: Must be None; batches are passed through as lists of features.
        callbacks: Extra callbacks appended to the default and reporting callbacks.
        optimizers: Optional (optimizer, lr_scheduler) pair.

    Raises:
        ValueError: If the number of reward weights does not match the reward functions, if
            `torch_dtype` is invalid, if a data collator is supplied, or if the per-device batch
            is not a multiple of `num_generations`.
        ImportError: If `use_liger_kernel` is set but liger-kernel is not installed.
    """
    self.args = args

    # The signature allows a single reward function; normalise it so len()/enumerate() work.
    if not isinstance(reward_funcs, (list, tuple)):
        reward_funcs = [reward_funcs]
    self.reward_funcs = reward_funcs

    # Reward weights (move this logic to post_init of config?)
    if args.reward_weights is not None:
        if len(args.reward_weights) != len(reward_funcs):
            raise ValueError(
                f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
                f"functions ({len(reward_funcs)})"
            )
        self.reward_weights = args.reward_weights
    else:
        # One weight per reward function (a flat list, indexed by reward function position).
        self.reward_weights = [1.0] * len(reward_funcs)

    # start the remote model so it has time to warmup while we load the local model(s)
    launched_remote_model = self.args.remote_gen_model_url is None
    if launched_remote_model:
        self.sglang_job_launcher = SGLangSlurmJobLauncher(
            model, num_gpus=self.args.remote_gen_model_n_gpus, sglang_port=self.args.remote_gen_model_port
        )
        self.sglang_job_launcher.submit_job()

    # Trained model
    model_init_kwargs = args.model_init_kwargs or {}
    if isinstance(model, str):
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
            pass  # torch_dtype is already a torch.dtype or "auto" or None
        elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
            torch_dtype = getattr(torch, torch_dtype)
            model_init_kwargs["torch_dtype"] = torch_dtype
        else:
            raise ValueError(
                "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
            )
        # Disable caching if gradient checkpointing is enabled (not supported)
        model_init_kwargs["use_cache"] = (
            False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
        )
    model_str = model
    model = AutoModelForCausalLM.from_pretrained(model_str, **model_init_kwargs)
    # offload to cpu
    ref_model = AutoModelForCausalLM.from_pretrained(model_str, **model_init_kwargs)  # .to("cpu")
    self.model = model
    self.ref_model = ref_model

    if self.args.use_liger_kernel:
        if is_liger_kernel_available():
            from liger_kernel.transformers import _apply_liger_kernel_to_instance

            _apply_liger_kernel_to_instance(model=self.model)
            _apply_liger_kernel_to_instance(model=self.ref_model)
        else:
            raise ImportError(
                "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. "
                "Please install it with `pip install liger-kernel`"
            )

    # Processing class
    if processing_class is None:
        processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
    self.processing_class = processing_class
    self.train_dataset = train_dataset

    if data_collator is not None:
        raise ValueError(
            "`data_collator` is not supported by FastGRPOTrainer: batches are passed through uncollated."
        )

    def identity_collator(features):  # No data collation is needed in GRPO
        return features

    self.data_collator = identity_collator

    # Number of distinct prompts per dataloader batch; each prompt yields `num_generations` samples.
    local_dataloader_batch_size = exact_div(
        args.per_device_train_batch_size * args.gradient_accumulation_steps,
        args.num_generations,
        "per_device_train_batch_size * gradient_accumulation_steps must be divisible by num_generations "
        "to remain on policy",
    )
    self.optimizer, self.lr_scheduler = optimizers
    self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47
    self.accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
    self.train_dataset_len = len(self.train_dataset)
    num_total_samples = int(self.args.num_train_epochs * self.train_dataset_len)
    self.total_steps_per_device = num_total_samples // (
        local_dataloader_batch_size * self.accelerator.num_processes
    )
    self.create_optimizer_and_scheduler(num_training_steps=self.total_steps_per_device)

    #########
    ### trainer specifics
    #########
    default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
    self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
    self.callback_handler = CallbackHandler(
        self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
    )
    self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
    self.control = TrainerControl()
    self.state = TrainerState(
        is_local_process_zero=self.is_local_process_zero(),
        is_world_process_zero=self.is_world_process_zero(),
        stateful_callbacks=[
            cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
        ],
    )
    self.current_flos = 0
    self.hp_search_backend = None
    self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
    self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

    # Create distant repo and output directory if needed
    self.hub_model_id = None
    if self.args.push_to_hub:
        self.init_hf_repo()
    if self.args.should_save:
        os.makedirs(self.args.output_dir, exist_ok=True)
    self.backup_model = None

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    #########
    ### setup dataloader
    #########
    self.dataloader = DataLoader(
        self.train_dataset,
        batch_size=local_dataloader_batch_size,
        shuffle=True,
        collate_fn=self.data_collator,
        drop_last=True,
    )
    torch.manual_seed(args.seed)

    # Enable gradient checkpointing if requested
    if args.gradient_checkpointing:
        self.model = self._enable_gradient_checkpointing(self.model, self.args)

    self.model, self.optimizer, self.dataloader = self.accelerator.prepare(
        self.model, self.optimizer, self.dataloader
    )
    self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

    # connect to a remote sglang model
    if launched_remote_model:
        # The launcher exposes wait_for_job_to_start()/get_node_ip(); use those to locate the server.
        self.sglang_job_launcher.wait_for_job_to_start()
        self.args.remote_gen_model_url = self.sglang_job_launcher.get_node_ip()
    self.remote_model = RemoteModel(
        self.args.remote_gen_model_url, self.args.remote_gen_model_port, self.processing_class.eos_token_id
    )
    if launched_remote_model:
        # A job in the running state is not necessarily serving yet; block until the server answers.
        self.remote_model.wait_for_server()
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: FastGRPOConfig) -> PreTrainedModel:
    """Turn on gradient checkpointing for `model` and return the same model object."""
    # The KV cache must be off while gradient checkpointing is active.
    model.config.use_cache = False
    model.gradient_checkpointing_enable()

    checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
    # Reentrant checkpointing is assumed unless the kwargs explicitly set a falsy `use_reentrant`.
    if checkpointing_kwargs.get("use_reentrant", True):
        model.enable_input_require_grads()
    return model
def print_gpu_memory_usage(self):
    """Print the CUDA memory currently allocated and reserved, in GB, or a notice when CUDA is absent."""
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return
    gpu_memory_allocated = torch.cuda.memory_allocated()
    gpu_memory_reserved = torch.cuda.memory_reserved()
    print(f"GPU memory allocated: {gpu_memory_allocated / (1024**3):.2f} GB")
    print(f"GPU memory reserved: {gpu_memory_reserved / (1024**3):.2f} GB")
# Get the per-token log probabilities for the completions for the model and the reference model
@profiling_decorator
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
    """Return the log-probability `model` assigns to each of the last `logits_to_keep` input tokens."""
    # Ask for one extra position: the final logit predicts a token beyond the input and is dropped below.
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1)
    shifted_logits = outputs.logits[:, :-1, :]  # (B, L-1, V)
    target_ids = input_ids[:, -logits_to_keep:]
    # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
    # See https://github.com/huggingface/trl/issues/2770
    completion_logits = shifted_logits[:, -logits_to_keep:]
    return selective_log_softmax(completion_logits, target_ids)  # logprobs of the input tokens
@torch.no_grad()
@profiling_decorator
def _prepare_batch(self, batch):
    """
    Generate and score completions for a batch of prompts.

    This will:
    - generate `num_generations` samples for each prompt with the remote model
    - call every reward function on the samples and weight the results
    - normalise the summed rewards within each prompt's group to obtain advantages

    Args:
        batch: List of dataset examples, each with at least a "prompt" entry. All other entries
            (except "completion") are forwarded to the reward functions as keyword arguments.

    Returns:
        One dict per generated sample with the keys "prompt", "prompt_ids", "completion",
        "completion_ids", "advantages" and "rewards".
    """
    num_generations = self.args.num_generations
    conversational = is_conversational(batch[0])

    prompts = [x["prompt"] for x in batch]
    prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in batch]
    prompt_ids = self.processing_class(prompts_text)["input_ids"]

    # TODO: add cuda clear cache here and a sleep
    all_outputs = self.remote_model.generate(
        prompt_ids,
        max_new_tokens=self.args.max_completion_length,
        temperature=self.args.temperature,
        num_generations=num_generations,
    )
    completion_ids = [example["completion_ids"] for example in all_outputs]

    # Decode the generated completions
    completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)

    # Repeat every prompt once per generation. A trailing assistant turn is a "bootstrap" prefix:
    # it is split off once per prompt (without mutating the dataset example) so that every one of
    # the prompt's generations gets the same prefix, not just the first.
    repeated_prompts = []
    repeated_bootstraps = []
    for prompt in prompts:
        bootstrap = ""
        if conversational and prompt and prompt[-1]["role"] == "assistant":
            bootstrap = prompt[-1]["content"]
            prompt = prompt[:-1]
        repeated_prompts.extend([prompt] * num_generations)
        repeated_bootstraps.extend([bootstrap] * num_generations)

    repeated_prompt_texts = []
    for prompt_text in prompts_text:
        repeated_prompt_texts.extend([prompt_text] * num_generations)

    if conversational:
        completions = [
            [{"role": "assistant", "content": bootstrap + completion}]
            for bootstrap, completion in zip(repeated_bootstraps, completions_text, strict=True)
        ]
    else:
        completions = completions_text

    # Repeat all input columns (but "prompt" and "completion") to match the number of generations.
    # This does not depend on the reward function, so it is built once.
    keys = [key for key in batch[0] if key not in ["prompt", "completion"]]
    reward_kwargs = {key: [] for key in keys}
    for example in batch:
        for key in keys:
            reward_kwargs[key].extend([example[key]] * num_generations)

    rewards = torch.zeros(len(repeated_prompts), len(self.reward_funcs))
    for i, reward_func in enumerate(self.reward_funcs):
        output_reward_func = reward_func(prompts=repeated_prompts, completions=completions, **reward_kwargs)
        rewards[:, i] = torch.tensor(output_reward_func, dtype=torch.float32) * self.reward_weights[i]

    # calculate the advantages, the prompt is all on the same device to no need to gather here
    grouped_rewards = rewards.sum(-1).view(len(prompts), num_generations)
    EPS = 1e-4
    grouped_advantages = (grouped_rewards - grouped_rewards.mean(-1, keepdim=True)) / (
        grouped_rewards.std(-1, keepdim=True) + EPS
    )
    advantages = grouped_advantages.flatten().tolist()

    # build batch as list of dicts
    examples = []
    for i, prompt in enumerate(repeated_prompt_texts):
        example = {
            "prompt": prompt,
            "prompt_ids": prompt_ids[i // num_generations],
            "completion": completions_text[i],
            "completion_ids": completion_ids[i],
            "advantages": advantages[i],
            "rewards": rewards[i],
        }
        examples.append(example)
    return examples
@profiling_decorator
def _sync_weights(self):
    """Save the current policy weights to a temporary directory and have the remote model load them."""
    accelerator = self.accelerator
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        began = time.time()
        # NOTE(review): the staging directory is a hard-coded, user-specific path.
        with tempfile.TemporaryDirectory(dir="/fsx/edward/work/open-r1/data/") as export_dir:
            accelerator.unwrap_model(self.model).save_pretrained(export_dir)
            self.remote_model.load_weights_from_path(export_dir)
        print("weight sync took: ", time.time() - began)
    # Hold every process until the main process has finished pushing the weights.
    accelerator.wait_for_everyone()
# Run the training loop. Each step draws a batch of prompts, generates and scores completions via
# `_prepare_batch`, precomputes reference log-probabilities, optimises over mini-batches, logs
# metrics and pushes the updated weights to the remote generation model.
# NOTE(review): `resume_from_checkpoint` is accepted but not referenced anywhere in this body.
# NOTE(review): this listing has lost its indentation, so block nesting (for example whether
# `self.lr_scheduler.step()` sits inside the mini-batch loop) cannot be read from it; confirm
# against a properly formatted copy before changing statement order.
def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
):
start_step = 1 # todo, set this when we resume + load model, opt state etc
# A value below 1 for logging_steps / save_steps is treated as a fraction of max_steps.
# NOTE(review): `self.state.max_steps` is read here but only assigned further down in this
# method; unless it is set elsewhere first, the fractional form is computed from a stale value.
if self.args.logging_steps is not None:
if self.args.logging_steps < 1:
self.state.logging_steps = math.ceil(self.state.max_steps * self.args.logging_steps)
else:
self.state.logging_steps = self.args.logging_steps
if self.args.save_steps is not None:
if self.args.save_steps < 1:
self.state.save_steps = math.ceil(self.state.max_steps * self.args.save_steps)
else:
self.state.save_steps = self.args.save_steps
self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)
self.state.global_step = 0
self.state.max_steps = self.total_steps_per_device
self.state.num_train_epochs = self.args.num_train_epochs
# Endless iterator over the prompt dataloader; it restarts the dataloader whenever it is exhausted.
def repeat_generator():
while True:
yield from self.dataloader
iter_dataloader = iter(repeat_generator())
self.model.train()
# Collate generated samples into padded tensors for one optimisation mini-batch.
@torch.no_grad()
def mini_batch_collator(examples):
device = self.accelerator.device
prompt_ids = [torch.LongTensor(example["prompt_ids"]) for example in examples]
completion_ids = [torch.LongTensor(example["completion_ids"]) for example in examples]
ref_per_token_logps = [torch.Tensor(example["ref_per_token_logps"]) for example in examples]
for logps, completion_id in zip(ref_per_token_logps, completion_ids):
assert len(logps) == len(completion_id), (
f"len(logps)={len(logps)} != len(completion_id)={len(completion_id)}"
)
pad_token_id = self.processing_class.pad_token_id
# Prompts are padded on the left and completions on the right.
padded_prompt_ids = pad(prompt_ids, padding_value=pad_token_id, padding_side="left")
padded_completion_ids = pad(completion_ids, padding_value=pad_token_id, padding_side="right")
padd_ref_per_token_logps = pad(ref_per_token_logps, padding_value=0.0, padding_side="right")
if self.args.max_prompt_length is not None:
# Keep only the last `max_prompt_length` prompt tokens.
padded_prompt_ids = padded_prompt_ids[:, -self.args.max_prompt_length :]
# compute the masks
prompt_mask = (padded_prompt_ids != pad_token_id).long()
# Mask everything after the first EOS token
is_eos = padded_completion_ids == self.processing_class.eos_token_id
# eos_idx defaults to the sequence length for rows without EOS, so those rows stay fully unmasked.
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
# `<=` keeps the first EOS token itself inside the mask.
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
advantages = torch.Tensor([example["advantages"] for example in examples])
return {
"prompt_ids": padded_prompt_ids.to(device),
"prompt_mask": prompt_mask.to(device),
"completion_ids": padded_completion_ids.to(device),
"completion_mask": completion_mask.to(device),
"advantages": advantages.to(device),
"ref_per_token_logps": padd_ref_per_token_logps.to(device),
}
device = self.accelerator.device
for step in range(start_step, self.total_steps_per_device + 1):
batch = next(iter_dataloader)
batch = self._prepare_batch(batch)
# TODO: log completions, rewards, etc
gen_dataset = Dataset.from_list(batch)
# Batched `Dataset.map` function that adds a "ref_per_token_logps" column computed with the reference model.
@torch.no_grad()
def compute_ref_logps(examples):
device = self.accelerator.device
prompt_ids = [torch.LongTensor(prompt_id) for prompt_id in examples["prompt_ids"]]
completion_ids = [torch.LongTensor(completion_id) for completion_id in examples["completion_ids"]]
completion_lengths = [len(c) for c in completion_ids]
pad_token_id = self.processing_class.pad_token_id
padded_prompt_ids = pad(prompt_ids, padding_value=pad_token_id, padding_side="left")
padded_completion_ids = pad(completion_ids, padding_value=pad_token_id, padding_side="right")
input_ids = torch.cat([padded_prompt_ids, padded_completion_ids], dim=1)
attention_mask = torch.cat(
[padded_prompt_ids != pad_token_id, padded_completion_ids != pad_token_id], dim=1
)
# NOTE(review): the tensor assigned on the next line is overwritten immediately by the padded
# completion length, so the first assignment has no effect.
logits_to_keep = torch.tensor(completion_lengths).to(device)
logits_to_keep = padded_completion_ids.size(1)
with torch.inference_mode():
ref_per_token_logps = self._get_per_token_logps(
self.ref_model, input_ids.to(device), attention_mask.to(device), logits_to_keep
)
ref_per_token_logps = ref_per_token_logps.to("cpu")
# Trim each row back to its unpadded completion length.
examples["ref_per_token_logps"] = [
logprobs[:length] for logprobs, length in zip(ref_per_token_logps, completion_lengths)
]
return examples
# The reference model is moved to the accelerator device only for this precomputation.
self.ref_model = self.ref_model.to(device)
# precompute the ref logprobs and offload the model to cpu
gen_dataset = gen_dataset.map(
compute_ref_logps, batched=True, batch_size=self.args.per_device_train_batch_size
)
self.ref_model = self.ref_model.to("cpu")
# we could add some optimizations here like sorting the dataset by length to improve throughput, but we will keep it simple for now
mini_batch_dataloader = DataLoader(
gen_dataset,
batch_size=self.args.per_device_train_batch_size,
shuffle=True, # we technically don't need to shuffle due to grad acc, but we may move to clipped loss later
drop_last=True,
collate_fn=mini_batch_collator,
)
# optimization
# stats for logging
losses = []
kls = []
with profiling_context(self, "train_step"):
for mini_batch in mini_batch_dataloader:
loss_metric, kl_metric = self._optimization_step(mini_batch)
losses.append(loss_metric)
kls.append(kl_metric)
self.lr_scheduler.step()
self.state.global_step += 1
self.state.epoch = step / self.total_steps_per_device # TODO, this is not correct
# logging stats
metrics = {}
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
metrics["loss"] = self.accelerator.gather_for_metrics(torch.Tensor(losses).to(device)).mean().item()
metrics["kl"] = self.accelerator.gather_for_metrics(torch.Tensor(kls).to(device)).mean().item()
# completions stats
completion_lengths = [len(c) for c in gen_dataset["completion_ids"]]
gathered_completion_lengths = self.accelerator.gather_for_metrics(
torch.Tensor(completion_lengths).to(device)
)
metrics["mean_completion_lengths"] = gathered_completion_lengths.mean().item()
metrics["max_completion_lengths"] = gathered_completion_lengths.max().item()
metrics["min_completion_lengths"] = gathered_completion_lengths.min().item()
# reward stats
rewards = gen_dataset["rewards"]
gathered_rewards = self.accelerator.gather_for_metrics(torch.Tensor(rewards).to(device))
# Mean over samples gives one value per reward function; their sum is the overall mean reward.
reward_per_func = gathered_rewards.mean(0)
for i, reward_func in enumerate(self.reward_funcs):
reward_func_name = reward_func.__name__
metrics[f"rewards/{reward_func_name}"] = reward_per_func[i].item()
metrics["reward"] = reward_per_func.sum().item()
self.log(metrics)
if self.args.log_completions and "wandb" in self.args.report_to:
import pandas as pd
prompts = gather_object(gen_dataset["prompt"])
completions = gather_object(gen_dataset["completion"])
# For logging
table = {
"step": [str(self.state.global_step)] * len(prompts),
"prompts": prompts,
"completion": completions,
"reward": gathered_rewards.sum(1).tolist(),
}
df = pd.DataFrame(table)
if wandb.run is not None and self.accelerator.is_main_process:
wandb.log({"completions": wandb.Table(dataframe=df)})
# sync weights to remote server
self._sync_weights()
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(self.model, trial=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(self.model, trial=None, metrics=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
def _optimization_step(self, mini_batch) -> tuple[float, float]:
prompt_ids, prompt_mask = mini_batch["prompt_ids"], mini_batch["prompt_mask"]
completion_ids, completion_mask = mini_batch["completion_ids"], mini_batch["completion_mask"]
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
ref_per_token_logps = mini_batch["ref_per_token_logps"]
with self.accelerator.accumulate(self.model):
per_token_logps = self._get_per_token_logps(self.model, input_ids, attention_mask, logits_to_keep)
per_token_kl = (
torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
)
advantages = mini_batch["advantages"]
# TODO: convert to clipped loss so we can multiple GRPO epochs
per_token_loss = -torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = per_token_loss + self.args.beta * per_token_kl
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
self.accelerator.backward(loss)
self.optimizer.step()
self.optimizer.zero_grad()
del per_token_logps, per_token_kl, per_token_loss, loss
# force garbage collection and empty cache
gc.collect()
torch.cuda.empty_cache()
return loss.detach().item(), per_token_kl.mean().item()

View file

@ -0,0 +1,171 @@
import atexit
import os
import re
import subprocess
import time
# We need a special environment setup to launch vLLM from within Slurm training jobs.
# - Reference code: https://github.com/huggingface/brrr/blob/c55ba3505686d690de24c7ace6487a5c1426c0fd/brrr/lighteval/one_job_runner.py#L105
# - Slack thread: https://huggingface.slack.com/archives/C043JTYE1MJ/p1726566494958269
# NOTE(review): the comment above mentions vLLM, but the launcher in this module submits an SGLang job.
user_home_directory = os.path.expanduser("~")
# argv prefix for `subprocess.run`: `env -i` starts bash with an empty environment, the profile
# scripts are sourced again, HOME is restored, and the last element is the beginning of an
# `sbatch` command line to which the caller appends its own arguments.
# NOTE(review): the --output/--err log directory is a hard-coded cluster path.
SLURM_PREFIX = [
"env",
"-i",
"bash",
"-c",
f"for f in /etc/profile.d/*.sh; do source $f; done; export HOME={user_home_directory}; sbatch --qos=high --output=/fsx/h4/logs/%x-%j.out --err=/fsx/h4/logs/%x-%j.err ",
]
class SGLangSlurmJobLauncher:
    """Submit an SGLang server as a Slurm job, wait for it to run and locate its node.

    The job is cancelled with `scancel` when the process exits or the launcher is destroyed.
    """

    def __init__(
        self,
        model_id_or_path,
        model_revision="main",
        num_gpus=1,
        sglang_port=30010,
        slurm_script="slurm/launch_sglang.slurm",
        check_interval=5,
    ):
        """
        Initialize the job launcher.

        :param model_id_or_path: Model identifier or path, passed to the SLURM script.
        :param model_revision: Model revision, passed to the SLURM script.
        :param num_gpus: Number of GPUs requested through `--gres=gpu:<num_gpus>`.
        :param sglang_port: Port passed to the SLURM script.
        :param slurm_script: Path to the SLURM script.
        :param check_interval: Time interval (seconds) to check job status.
        """
        self.slurm_script = slurm_script
        self.job_id = None
        self.node_name = None
        # Set by launch(); defined here so the attribute always exists.
        self.ip_address = None
        self.check_interval = check_interval
        self.model_id_or_path = model_id_or_path
        self.model_revision = model_revision
        self.num_gpus = num_gpus
        self.sglang_port = sglang_port
        # Register cleanup function to cancel job on exit
        atexit.register(self.cleanup)

    def submit_job(self):
        """Submits the SLURM job and extracts the job ID.

        :raises subprocess.CalledProcessError: If `sbatch` exits with a non-zero status.
        :raises RuntimeError: If the job ID cannot be found in the `sbatch` output.
        """
        cmd = SLURM_PREFIX.copy()
        cmd_args = [
            f"--gres=gpu:{self.num_gpus}",
            self.slurm_script,
            self.model_id_or_path,
            self.model_revision,
            str(self.sglang_port),
        ]
        # The last prefix element is the shell command string; append the sbatch arguments to it.
        cmd[-1] += " " + " ".join(cmd_args)
        try:
            result = subprocess.run(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                check=True,
            )
        except subprocess.CalledProcessError as e:
            print(f"Error submitting job: {e.stderr}")
            raise

        match = re.search(r"Submitted batch job (\d+)", result.stdout)
        if not match:
            raise RuntimeError("Failed to retrieve job ID from sbatch output.")
        self.job_id = match.group(1)
        print(f"Job submitted with ID: {self.job_id}")

    def get_job_status(self):
        """Checks the job status using squeue.

        :return: The job state column, or None when the job is no longer listed.
        :raises ValueError: If no job has been submitted yet.
        """
        if not self.job_id:
            raise ValueError("Job ID is not set. Submit the job first.")
        result = subprocess.run(
            ["squeue", "--job", self.job_id, "--noheader"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
        )
        if not result.stdout.strip():
            return None  # Job is no longer in queue
        status = result.stdout.split()[4]  # Typically, state is the 5th column
        return status

    def wait_for_job_to_start(self):
        """Waits for the job to start running and fetches its node.

        :raises RuntimeError: If the job disappears from the queue before it runs.
        """
        print("Waiting for job to start...")
        while True:
            status = self.get_job_status()
            if status is None:
                raise RuntimeError("Job disappeared from queue, it may have failed.")
            if status == "R":  # Running
                print("Job is running. Fetching node information...")
                self.node_name = self.get_node_name()
                return
            time.sleep(self.check_interval)

    def get_node_name(self):
        """Gets the node where the job is running.

        :raises RuntimeError: If squeue reports no node for the job.
        """
        result = subprocess.run(
            ["squeue", "--job", self.job_id, "--noheader", "--format=%N"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        node_name = result.stdout.strip()
        if not node_name:
            raise RuntimeError("Failed to retrieve node name.")
        return node_name

    def get_node_ip(self):
        """Retrieves the IP address of the node running the job.

        :raises ValueError: If the node name is not known yet.
        :raises RuntimeError: If `scontrol` output contains no `NodeAddr` entry.
        """
        if not self.node_name:
            raise ValueError("Node name is not set. Wait for the job to start first.")
        result = subprocess.run(
            ["scontrol", "show", "node", self.node_name], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
        )
        match = re.search(r"NodeAddr=(\S+)", result.stdout)
        if not match:
            raise RuntimeError("Failed to retrieve node IP address.")
        return match.group(1)

    def launch(self):
        """Launches the job, waits for it to start, and retrieves the node IP."""
        self.submit_job()
        self.wait_for_job_to_start()
        ip_address = self.get_node_ip()
        print(f"Job is running on {self.node_name} with IP: {ip_address}")
        self.ip_address = ip_address
        return ip_address

    def cleanup(self):
        """Cancels the SLURM job if one was submitted and has not been cancelled yet."""
        if self.job_id is None:
            return
        # cleanup() is reachable from both atexit and __del__; clear the id first so the job is
        # cancelled only once.
        job_id, self.job_id = self.job_id, None
        print(f"Cleaning up: Cancelling job {job_id}...")
        subprocess.run(["scancel", job_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        print("Job cancelled.")

    def __del__(self):
        """Ensure job cleanup when the instance is destroyed."""
        self.cleanup()
if __name__ == "__main__":
    # Manual smoke test: launch a small model, wait for the server and print one generation.
    from open_r1.trainers.remote_model import RemoteModel

    launcher = SGLangSlurmJobLauncher("HuggingFaceTB/SmolLM2-135M-Instruct")
    ip_address = launcher.launch()
    time.sleep(15)
    remote_model = RemoteModel(f"{ip_address}", 30010)
    remote_model.wait_for_server()
    result = remote_model.generate([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
    # The leftover `assert 0` debugging stop was removed so the result is actually printed.
    print(result)

View file

@ -0,0 +1,697 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import tempfile
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Callable, Iterator, Optional, Union
from trl.trainer.utils import disable_dropout_in_model
import torch
import transformers
from datasets import Dataset, IterableDataset, disable_progress_bars, enable_progress_bars
from datasets.utils.logging import set_verbosity_error, set_verbosity_info
from packaging import version
from torch.utils.data import DataLoader, Sampler
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_liger_kernel_available
import deepspeed
import trl
from accelerate.utils import broadcast_object_list, gather_object, is_peft_model
from open_r1.trainers.job_launcher import SGLangSlurmJobLauncher
from open_r1.trainers.remote_model import RemoteModel
from trl.data_utils import is_conversational, maybe_apply_chat_template
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import is_rich_available
from trl.models import create_reference_model, prepare_deepspeed
from trl.trainer.callbacks import SyncRefModelCallback
from trl.trainer.utils import exact_div, pad, print_prompt_completions_sample, selective_log_softmax
if is_liger_kernel_available():
from liger_kernel.transformers import AutoLigerKernelForCausalLM
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
if is_wandb_available():
import wandb
class RepeatBatchRandomSampler(Sampler):
    """Sampler that yields shuffled, fixed-size chunks of indices, each chunk repeated in a row.

    The dataset indices are permuted, cut into chunks of `batch_size * num_processes`, and every
    full chunk is emitted `repeat_count` times back to back. A trailing chunk that is not full is
    dropped.

    Args:
        data_source: Sized collection to draw indices for.
        batch_size: Per-process batch size.
        num_processes: Number of processes; together with `batch_size` it fixes the chunk size.
        repeat_count: How many times each chunk is emitted.
        seed: Optional seed for the sampler's private random generator.
    """

    def __init__(
        self,
        data_source,
        batch_size: int = 1,
        num_processes: int = 1,
        repeat_count: int = 1,
        seed: Optional[int] = None,
    ):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_processes = num_processes
        self.repeat_count = repeat_count
        self.num_samples = len(data_source)
        self.seed = seed
        self.generator = torch.Generator()  # Create a local random generator
        if seed is not None:
            self.generator.manual_seed(seed)

    def __iter__(self):
        indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
        all_process_batch_size = self.batch_size * self.num_processes
        # Only full chunks are emitted; the incomplete tail of the permutation is dropped.
        usable = (self.num_samples // all_process_batch_size) * all_process_batch_size
        for start in range(0, usable, all_process_batch_size):
            chunk = indices[start : start + all_process_batch_size]
            for _ in range(self.repeat_count):
                yield from chunk

    def __len__(self) -> int:
        # Must match what __iter__ yields: full chunks only, each repeated `repeat_count` times.
        all_process_batch_size = self.batch_size * self.num_processes
        full_chunks = self.num_samples // all_process_batch_size
        return full_chunks * all_process_batch_size * self.repeat_count
@dataclass
class RemoteGRPOConfig(trl.GRPOConfig):
    """
    `trl.GRPOConfig` extended with args for callbacks, benchmarks, Hub/W&B bookkeeping and the remote
    generation server used by `RemoteGRPOTrainer`.
    """

    benchmarks: list[str] = field(default_factory=list, metadata={"help": "The benchmarks to run after training."})
    callbacks: list[str] = field(default_factory=list, metadata={"help": "The callbacks to run during training."})
    chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
    checkpoint_dir: Optional[str] = field(
        default="/fsx/h4/tmp/", metadata={"help": "The directory to save temporary checkpoints to."}
    )
    system_prompt: Optional[str] = field(
        default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
    )
    hub_model_revision: Optional[str] = field(
        default="main", metadata={"help": "The Hub model branch to push the model to."}
    )
    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": ("The entity to store runs under.")},
    )
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": ("The project to store runs under.")},
    )
    remote_gen_model_url: Optional[str] = field(
        default=None,
        metadata={
            "help": "Host of the remote generation server. When unset, the trainer launches a server job itself."
        },
    )
    remote_gen_model_port: str = field(
        default="30010",
        metadata={"help": "Port of the remote generation server."},
    )
    # The default is an integer GPU count, so the field is typed as `int` (it was annotated `str` before)
    remote_gen_model_n_gpus: int = field(
        default=8,
        metadata={"help": "Number of GPUs requested for a generation server launched by the trainer."},
    )
    use_liger: bool = field(
        default=True,
        metadata={"help": "Whether to use Liger kernel for training."},
    )
class RemoteGRPOTrainer(Trainer):
    """GRPO trainer whose completions are sampled by a remote generation server.

    Prompts are tokenized locally and sent to a `RemoteModel`. The returned completions are scored with the
    reward functions, converted into group-normalised advantages and buffered as mini-batches, which
    `compute_loss` consumes with a clipped surrogate objective (plus a KL penalty against a reference model
    when `args.beta` is non-zero). Before every generation round the current policy weights are written to
    disk and reloaded by the server (see `_sync_weights`).
    """

    _tag_names = ["trl", "grpo"]

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: Optional[RemoteGRPOConfig] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
    ):
        """Build the policy (and optional reference) model, the reward setup and the remote generation client.

        Args:
            model: Model identifier/path to load, or an already instantiated model.
            reward_funcs: One reward function (or a list of them).
            args: Trainer configuration. Required in practice: it is read before the parent constructor runs.
            train_dataset: Dataset providing a `"prompt"` column (other columns are forwarded to the rewards).
            eval_dataset: Optional evaluation dataset(s).
            processing_class: Tokenizer used for prompts and completions.
            reward_processing_classes: Optional tokenizer(s) for model-based reward functions.
            callbacks: Extra trainer callbacks.
            optimizers: Optional `(optimizer, scheduler)` pair passed to the parent `Trainer`.

        Raises:
            ValueError: If `args` is missing, if `model_init_kwargs` is combined with an instantiated model, or
                if the reward weights / reward processing classes do not match the reward functions.
            NotImplementedError: If a reference model is needed for a PEFT model.
        """
        if args is None:
            raise ValueError("`args` must be provided to `RemoteGRPOTrainer`.")
        self.args = args
        # Initialize the metrics
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        self.log_completions = args.log_completions

        # Models
        # Trained model
        model_init_kwargs = args.model_init_kwargs or {}
        if isinstance(model, str):
            model_id = model
            model = self._create_model_from_path(model_id, args)
            disable_dropout_in_model(model)
        else:
            model_id = model.config._name_or_path
            if args.model_init_kwargs is not None:
                raise ValueError(
                    "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
                    "This argument can only be used when the `model` argument is a string."
                )

        # Enable gradient checkpointing if requested
        if args.gradient_checkpointing:
            model = self._enable_gradient_checkpointing(model, args)

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
        # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
        # This acts as a flag to indicate that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Reference model
        if self.args.beta == 0.0:
            # If beta is 0.0, the reference model is not needed
            self.ref_model = None
        elif is_deepspeed_zero3_enabled():
            self.ref_model = self._create_model_from_path(model_id, args)
            disable_dropout_in_model(self.ref_model)
        elif is_peft_model(model):
            raise NotImplementedError("Peft is not supported")
        else:
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model)

        # Reward functions
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        self.reward_funcs = reward_funcs

        # Reward weights
        if args.reward_weights is not None:
            if len(args.reward_weights) != len(reward_funcs):
                raise ValueError(
                    f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
                    f"functions ({len(reward_funcs)})"
                )
            self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
        else:
            self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)

        # Reward processing class
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]
        else:
            if len(reward_processing_classes) != len(reward_funcs):
                raise ValueError("The number of reward processing classes must match the number of reward functions.")

        # TODO: test RMS and also wrap them in deepspeed
        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = reward_processing_class.eos_token
                # The reward model computes the reward for the latest non-padded token in the input sequence.
                # So it's important to set the pad token ID to the padding token ID of the processing class.
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        def data_collator(features):  # No data collation is needed in GRPO
            return features

        # Mini-batches produced by one generation round, consumed one per training step
        self.batch_buffer = []

        super().__init__(
            model,
            args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
            data_collator=data_collator,
        )

        ip_address = self.args.remote_gen_model_url
        if self.args.remote_gen_model_url is None and self.accelerator.is_main_process:
            # we launch a job from here, get the ip on main process and broadcast to others
            # it would be better to move this to the start so the server warms up while the local model is being loaded
            # `model_init_kwargs` is the `or {}` view: `args.model_init_kwargs` itself may be None
            model_revision = model_init_kwargs.get("revision", "main")
            self.sglang_job_launcher = SGLangSlurmJobLauncher(
                model_id,
                model_revision,
                num_gpus=self.args.remote_gen_model_n_gpus,
                sglang_port=self.args.remote_gen_model_port,
            )
            ip_address = self.sglang_job_launcher.launch()

        # get the ip from main process and broadcast to others
        gather_ip_address = broadcast_object_list([ip_address], 0)
        self.args.remote_gen_model_url = gather_ip_address[0]
        self.remote_model = RemoteModel(
            self.args.remote_gen_model_url, self.args.remote_gen_model_port, self.processing_class.eos_token_id
        )
        self.remote_model.wait_for_server()

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        # Add tags to the model
        self.model.add_model_tags(self._tag_names)

        if self.ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        if args.sync_ref_model:
            self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))

    def _get_train_sampler(self) -> Sampler:
        """
        Return the train sampler.

        Every global batch of prompts is replayed `num_generations * num_iterations` times.

        Returns:
            Sampler: The train sampler.

        Raises:
            ValueError: If `dataloader_num_workers` is not 0.
        """
        if self.args.dataloader_num_workers != 0:
            raise ValueError("dataloader_num_workers should not be greater than 0 for remote training")
        return RepeatBatchRandomSampler(
            data_source=self.train_dataset,
            batch_size=self._train_batch_size,
            repeat_count=self.args.num_generations * self.args.num_iterations,
            num_processes=self.accelerator.num_processes,
            seed=self.args.seed,
        )

    def _create_model_from_path(self, model_path: str, args: RemoteGRPOConfig) -> PreTrainedModel:
        """Creates a model from a path or model identifier.

        Raises:
            ValueError: If `model_init_kwargs["torch_dtype"]` is neither a `torch.dtype`, `"auto"`, `None` nor a
                string naming a `torch.dtype`.
        """
        # Work on a copy so that resolving the dtype / disabling the cache does not rewrite the user's config
        model_init_kwargs = dict(args.model_init_kwargs or {})
        # Handle torch dtype
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
            pass  # torch_dtype is already a torch.dtype or "auto" or None
        elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
            torch_dtype = getattr(torch, torch_dtype)
            model_init_kwargs["torch_dtype"] = torch_dtype
        else:
            raise ValueError(
                "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
            )
        # Disable caching if gradient checkpointing is enabled (not supported)
        if args.gradient_checkpointing:
            model_init_kwargs["use_cache"] = False
        # Create model
        model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
        return model

    def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: RemoteGRPOConfig) -> PreTrainedModel:
        """Enables gradient checkpointing for the model."""
        # Ensure use_cache is disabled
        model.config.use_cache = False

        # Enable gradient checkpointing on the base model for PEFT
        if is_peft_model(model):
            model.base_model.gradient_checkpointing_enable()
        # Enable gradient checkpointing for non-PEFT models
        else:
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)

        gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
        # Re-entrant checkpointing is assumed unless it is explicitly switched off
        use_reentrant = (
            "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
        )
        if use_reentrant:
            model.enable_input_require_grads()

        return model

    def _generate_and_score_completions(self, inputs: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Sample completions remotely, score them and compute group-normalised advantages.

        Args:
            inputs: Batch of dataset rows; each row has a `"prompt"` entry, every other column is forwarded to
                the reward functions.

        Returns:
            One dict per (prompt, generation) pair with the keys `prompt`, `prompt_ids`, `completion`,
            `completion_ids`, `advantages` and `rewards`.

        Raises:
            ValueError: If the server does not return `num_generations` completions per prompt.
        """
        num_generations = self.args.num_generations
        prompts_to_log = [x["prompt"] for x in inputs]
        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
        prompt_inputs = self.processing_class(prompts_text)
        prompt_ids = prompt_inputs["input_ids"]

        # The generation server must sample from the current policy, so push the weights first
        self._sync_weights()
        with profiling_context(self, "remote_generate"):
            all_outputs = self.remote_model.generate(
                prompt_ids,
                max_new_tokens=self.args.max_completion_length,
                temperature=self.args.temperature,
                num_generations=num_generations,
            )
        completion_ids = [example["completion_ids"] for example in all_outputs]
        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)

        repeated_prompts = []
        for prompt in prompts_to_log:
            repeated_prompts.extend([prompt] * num_generations)
        repeated_prompt_texts = []
        for prompt in prompts_text:
            repeated_prompt_texts.extend([prompt] * num_generations)

        if len(completions_text) != len(repeated_prompts):
            raise ValueError(
                f"Expected {len(repeated_prompts)} completions ({num_generations} per prompt), "
                f"but received {len(completions_text)}."
            )

        if is_conversational(inputs[0]):
            completions_to_log = []
            for group_index, prompt in enumerate(prompts_to_log):
                # `repeated_prompts` holds the same list object `num_generations` times, so the assistant prefill
                # is popped once per prompt and then prepended to every completion of that prompt's group.
                bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
                group = completions_text[group_index * num_generations : (group_index + 1) * num_generations]
                for completion in group:
                    completions_to_log.append([{"role": "assistant", "content": bootstrap + completion}])
        else:
            completions_to_log = completions_text

        rewards = torch.zeros(len(repeated_prompts), len(self.reward_funcs))
        with profiling_context(self, "rewards"):
            # Repeat all input columns (but "prompt" and "completion") to match the number of generations.
            # The repeated columns are the same for every reward function, so they are built once.
            keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
            reward_kwargs = defaultdict(list)
            for example in inputs:
                for key in keys:
                    reward_kwargs[key].extend([example[key]] * num_generations)
            for i, reward_func in enumerate(self.reward_funcs):
                output_reward_func = reward_func(
                    prompts=repeated_prompts, completions=completions_to_log, **reward_kwargs
                )
                rewards[:, i] = torch.tensor(output_reward_func, dtype=torch.float32) * self.reward_weights[i]

        # calculate the advantages, the prompt is all on the same device so no need to gather here
        grouped_rewards = rewards.sum(-1).view(len(prompts_to_log), num_generations)
        EPS = 1e-4
        # NOTE: the sample standard deviation is NaN for a group of size 1, so num_generations should be > 1
        grouped_advantages = (grouped_rewards - grouped_rewards.mean(-1, keepdim=True)) / (
            grouped_rewards.std(-1, keepdim=True) + EPS
        )
        advantages = grouped_advantages.flatten().tolist()

        examples = []
        for i, prompt in enumerate(repeated_prompt_texts):
            example = {
                "prompt": prompt,
                "prompt_ids": prompt_ids[i // num_generations],
                "completion": completions_text[i],
                "completion_ids": completion_ids[i],
                "advantages": advantages[i],
                "rewards": rewards[i],
            }
            examples.append(example)

        # Instead of logging metrics here, collect them
        mode = "eval" if getattr(self, "control", None) and self.control.should_evaluate else "train"
        device = self.accelerator.device

        # Collect completion length metrics
        completion_lengths = [len(example["completion_ids"]) for example in examples]
        gathered_completion_lengths = self.accelerator.gather_for_metrics(torch.Tensor(completion_lengths).to(device))
        self._metrics[mode]["mean_completion_lengths"].append(gathered_completion_lengths.mean().item())
        self._metrics[mode]["max_completion_lengths"].append(gathered_completion_lengths.max().item())
        self._metrics[mode]["min_completion_lengths"].append(gathered_completion_lengths.min().item())

        # Collect reward metrics
        rewards = torch.stack(
            [
                example["rewards"].to(device)
                if isinstance(example["rewards"], torch.Tensor)
                else torch.tensor(example["rewards"], device=device)
                for example in examples
            ]
        )
        gathered_rewards = self.accelerator.gather_for_metrics(rewards)
        reward_per_func = gathered_rewards.mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            reward_func_name = reward_func.__name__
            self._metrics[mode][f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
        self._metrics[mode]["reward"].append(reward_per_func.sum().item())

        if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
            prompts_to_log = gather_object([example["prompt"] for example in examples])
            completions_to_log = gather_object([example["completion"] for example in examples])
            if self.accelerator.is_main_process:
                # TODO: enable num_samples in TRL to avoid clogging logs, then restore the rich console print
                if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
                    import pandas as pd

                    # For logging
                    table = {
                        "step": [str(self.state.global_step)] * len(prompts_to_log),
                        "prompts": prompts_to_log,
                        "completion": completions_to_log,
                        "reward": gathered_rewards.sum(1).tolist(),
                    }
                    df = pd.DataFrame(table)
                    wandb.log({"completions": wandb.Table(dataframe=df)})
        return examples

    @profiling_decorator
    def _prepare_inputs(self, inputs: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Return the next training mini-batch, generating a fresh round of completions when the buffer is empty.

        A generation round scores the completions, stores the reference (if any) and "old" policy per-token
        log-probabilities alongside each example, and fills `self.batch_buffer` with `num_iterations` shuffled
        passes over the resulting mini-batches.
        """
        if len(self.batch_buffer) > 0:
            return self.batch_buffer.pop(0)

        inputs = self._generate_and_score_completions(inputs)
        gen_dataset = Dataset.from_list(inputs)
        # Used as a sanity check only: raises if the generated examples cannot be split into full mini-batches
        exact_div(
            len(gen_dataset), self.args.per_device_train_batch_size, "len(gen_dataset) is not divisible by batch size"
        )

        def get_logprobs(example, model, output_name):
            # dict of lists to list of dicts
            examples = [dict(zip(example.keys(), values)) for values in zip(*example.values())]
            input_ids, attention_mask, completion_mask, completion_ids = self._get_padded_inputs_and_attn_mask(
                examples
            )
            logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
            per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep).detach()
            lengths = [len(example["completion_ids"]) for example in examples]
            # Strip the completion padding
            per_token_logps = per_token_logps.to("cpu").tolist()
            per_token_logps = [logps[:length] for logps, length in zip(per_token_logps, lengths)]
            example[output_name] = per_token_logps
            return example

        logprob_batch_size = self.args.per_device_train_batch_size * 2
        set_verbosity_error()
        disable_progress_bars()
        try:
            with torch.no_grad():
                if self.ref_model is not None:
                    self.ref_model.eval()
                    gen_dataset = gen_dataset.map(
                        get_logprobs,
                        batched=True,
                        batch_size=logprob_batch_size,
                        fn_kwargs={"model": self.ref_model, "output_name": "ref_per_token_logps"},
                    )
                self.model.eval()
                gen_dataset = gen_dataset.map(
                    get_logprobs,
                    batched=True,
                    batch_size=logprob_batch_size,
                    fn_kwargs={"model": self.model, "output_name": "old_per_token_logps"},
                )
        finally:
            # Restore the training mode and the datasets logging state even if the log-prob pass failed
            self.model.train()
            enable_progress_bars()
            set_verbosity_info()

        def mini_batch_collator(mini_batch):
            return mini_batch

        mini_batch_dataloader = DataLoader(
            gen_dataset,
            batch_size=self.args.per_device_train_batch_size,
            shuffle=True,  # we technically don't need to shuffle due to grad acc, but we may move to clipped loss later
            drop_last=True,
            collate_fn=mini_batch_collator,
        )
        for _ in range(self.args.num_iterations):
            for mini_batch in mini_batch_dataloader:
                self.batch_buffer.append(mini_batch)
        return self.batch_buffer.pop(0)

    @profiling_decorator
    def _sync_weights(self):
        """Write the current policy weights to a temporary directory and have the remote server reload them.

        Does nothing for a mock remote model. The temporary directory is created under `args.checkpoint_dir`;
        presumably that location must be readable by the generation server (verify against the deployment).
        """
        if self.remote_model.is_mock:
            return
        self.accelerator.wait_for_everyone()
        start = time.time()
        # would be better if this was a ram disk + separate thread for writing
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        if is_deepspeed_zero3_enabled():
            state_dict = {}
            for name, param in unwrapped_model.named_parameters():
                if name in state_dict:
                    # sometimes the embed table is duplicated so no need to regather it
                    continue
                with deepspeed.zero.GatheredParameters(param, modifier_rank=0):
                    state_dict[name] = param.cpu().detach().clone()
        else:
            state_dict = self.accelerator.get_state_dict(self.model)

        if self.accelerator.is_main_process:
            with tempfile.TemporaryDirectory(dir=self.args.checkpoint_dir) as temp_dir_path:
                self._save(temp_dir_path, state_dict=state_dict)
                self.remote_model.load_weights_from_path(temp_dir_path)
            print(f"Weight sync took: {time.time() - start:.2f}s")
        self.accelerator.wait_for_everyone()

    # Get the per-token log probabilities for the completions for the model and the reference model
    @profiling_decorator
    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
        """Return the log-probability of each of the last `logits_to_keep` tokens of `input_ids` under `model`."""
        # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
        logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
        logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

        input_ids = input_ids[:, -logits_to_keep:]
        # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
        # See https://github.com/huggingface/trl/issues/2770
        logits = logits[:, -logits_to_keep:]
        return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens

    def _get_padded_inputs_and_attn_mask(self, inputs):
        """Pad prompts (left) and completions (right) and build the matching masks.

        Args:
            inputs: Examples with `prompt_ids` and `completion_ids` token lists.

        Returns:
            `(input_ids, attention_mask, completion_mask, completion_ids)`; the first three are moved to the
            accelerator device, `completion_ids` is the padded completion tensor as built.
        """
        device = self.accelerator.device
        prompt_ids = [torch.LongTensor(example["prompt_ids"]) for example in inputs]
        completion_ids = [torch.LongTensor(example["completion_ids"]) for example in inputs]
        pad_token_id = self.processing_class.pad_token_id

        # The masks come from the true sequence lengths rather than from `ids != pad_token_id`: when the pad
        # token doubles as EOS (or otherwise occurs in the text) a value-based mask would hide real tokens.
        prompt_mask = pad(
            [torch.ones(len(ids), dtype=torch.long) for ids in prompt_ids], padding_value=0, padding_side="left"
        )
        completion_lengths = torch.tensor([len(ids) for ids in completion_ids], dtype=torch.long)

        prompt_ids = pad(prompt_ids, padding_value=pad_token_id, padding_side="left")
        completion_ids = pad(completion_ids, padding_value=pad_token_id, padding_side="right")

        if self.args.max_prompt_length is not None:
            # Keep the rightmost tokens; the mask is truncated identically
            prompt_ids = prompt_ids[:, -self.args.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.args.max_prompt_length :]

        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
        up_to_first_eos = sequence_indices <= eos_idx.unsqueeze(1)
        # ... and never count padding: with pad == eos the first pad position would otherwise look like an EOS
        within_length = sequence_indices < completion_lengths.unsqueeze(1)
        completion_mask = (up_to_first_eos & within_length).int()

        input_ids = torch.cat([prompt_ids, completion_ids], dim=1).to(device)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1).to(device)
        completion_mask = completion_mask.to(device)
        return input_ids, attention_mask, completion_mask, completion_ids

    @profiling_decorator
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Clipped GRPO surrogate loss, averaged over unmasked completion tokens.

        Args:
            model: The policy being trained.
            inputs: Mini-batch of examples carrying `advantages`, `old_per_token_logps` and, when a reference
                model is used, `ref_per_token_logps`.
            return_outputs: Accepted for `Trainer` compatibility; not used.
            num_items_in_batch: Accepted for `Trainer` compatibility; not used.

        Returns:
            Scalar loss tensor.
        """
        device = self.accelerator.device
        advantages = torch.Tensor([example["advantages"] for example in inputs]).to(device)

        input_ids, attention_mask, completion_mask, completion_ids = self._get_padded_inputs_and_attn_mask(inputs)
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        # Pad the stored log-probs with 0.0 (not with a token id); padded positions are masked out of the loss
        old_per_token_logps = [torch.Tensor(example["old_per_token_logps"]) for example in inputs]
        pad_old_per_token_logps = pad(old_per_token_logps, padding_value=0.0, padding_side="right").to(device)

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        if self.ref_model is not None:
            ref_per_token_logps = [torch.Tensor(example["ref_per_token_logps"]) for example in inputs]
            pad_ref_per_token_logps = pad(ref_per_token_logps, padding_value=0.0, padding_side="right").to(device)
            clamped_diff = torch.clamp(pad_ref_per_token_logps - per_token_logps, -10.0, 10.0)  # for numerical stability
            per_token_kl = torch.exp(clamped_diff) - clamped_diff - 1

        # clipped loss
        coef_1 = torch.exp(per_token_logps - pad_old_per_token_logps)
        coef_2 = torch.clamp(coef_1, 1 - self.args.epsilon, 1 + self.args.epsilon)
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

        if self.ref_model is not None:
            per_token_loss = per_token_loss + self.args.beta * per_token_kl

        # Guard the denominator so a mini-batch without any unmasked completion token yields 0 instead of NaN
        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1)
        torch.cuda.empty_cache()
        return loss

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """Merge the averaged metrics collected since the last call into `logs` and forward to `Trainer.log`."""
        mode = "eval" if self.control.should_evaluate else "train"
        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()}  # average the metrics

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
        if mode == "eval":
            metrics = {f"eval_{key}": val for key, val in metrics.items()}

        logs = {**logs, **metrics}
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            super().log(logs, start_time)
        else:  # transformers<=4.46
            super().log(logs)
        self._metrics[mode].clear()

View file

@ -0,0 +1,162 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# great reference: https://github.com/vllm-project/vllm/issues/11400
import time
import random
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import requests
class RemoteModel:
    """
    HTTP client for a remote SGLang generation server (`/health`, `/generate`, `/update_weights_from_disk`).

    Passing `"mock"` as the URL gives an offline stand-in that returns random completion ids and never
    touches the network.

    launch with:
    export LD_LIBRARY_PATH=$(python -c "import site; print(site.getsitepackages()[0] + '/nvidia/nvjitlink/lib')"):$LD_LIBRARY_PATH
    python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port=30010 --skip-tokenizer-init --mem-fraction-static 0.4
    python3 -m sglang.launch_server --model-path HuggingFaceTB/SmolLM2-135M-Instruct --port=30010 --skip-tokenizer-init --mem-fraction-static 0.4 --host=0.0.0.0
    # on a separate node
    python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port=30010 --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size=8
    python3 -m sglang.launch_server --model-path HuggingFaceTB/SmolLM2-1.7B-Instruct --port=30010 --skip-tokenizer-init --mem-fraction-static 0.6 --host=0.0.0.0 --dp-size=8
    python3 -m sglang.launch_server --model-path open-r1/Qwen2.5-Coder-7B-Instruct-SFT --revision v00.08-step-000001280 --port=30010 --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size=8
    python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-1.5B-Instruct --port=30010 --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size=8
    """

    def __init__(self, remote_model_url, remote_model_port, stop_token_id=None):
        """Store the server address and the optional token id that stops generation."""
        self.remote_model_url = remote_model_url
        self.remote_model_port = remote_model_port
        self.stop_token_id = stop_token_id
        if self.is_mock:
            print("Using mock remote model")

    @property
    def is_mock(self):
        """True when the client was created with the `"mock"` URL."""
        return self.remote_model_url == "mock"

    def _endpoint(self, route: str) -> str:
        """Build the full URL of a server route."""
        return f"http://{self.remote_model_url}:{self.remote_model_port}/{route}"

    def is_healthy(self, timeout=5):
        """Checks if the remote model server is up and running."""
        if self.is_mock:
            return True
        try:
            response = requests.get(self._endpoint("health"), timeout=timeout)
            return response.status_code == 200
        except requests.RequestException:
            return False

    def wait_for_server(self, max_retries=120, delay=5):
        """Waits for the server to become available before proceeding.

        Raises:
            RuntimeError: If the server is still not healthy after `max_retries` attempts.
        """
        for attempt in range(max_retries):
            if self.is_healthy():
                print("Remote model server is healthy!")
                return True
            print(f"Waiting for server to start... (Attempt {attempt + 1}/{max_retries})")
            time.sleep(delay)
        raise RuntimeError("Remote model server did not start in time.")

    def generate(
        self, input_ids: list[list[int]], max_new_tokens=256, temperature=0.8, num_generations=2
    ) -> list[dict[str, list[int]]]:
        """Sample `num_generations` completions for every prompt.

        Args:
            input_ids: Token ids of each prompt.
            max_new_tokens: Maximum number of tokens generated per completion.
            temperature: Sampling temperature.
            num_generations: Number of completions per prompt.

        Returns:
            A flat list with one `{"prompt_ids", "completion_ids"}` dict per completion, grouped by prompt
            (the completions of prompt `k` occupy positions `k * num_generations` onwards).
        """
        if self.is_mock:
            return [
                {
                    "prompt_ids": prompt_ids,
                    "completion_ids": random.choices(range(10, 1000), k=max_new_tokens),
                    # "prompt_log_probs": None, # TODO, not used for now
                    # "completion_log_probs": None,
                }
                for prompt_ids in input_ids
                for _ in range(num_generations)
            ]

        # Prepare the request body
        sampling_params = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "n": num_generations,
        }
        if self.stop_token_id is not None:
            # Only send a stop token when one is configured (avoids posting `[None]`)
            sampling_params["stop_token_ids"] = [self.stop_token_id]
        request_body = {
            "input_ids": input_ids,
            "sampling_params": sampling_params,
            "stream": False,
            # "return_logprob": True, # disabled as we occasionally see https://github.com/sgl-project/sglang/issues/4097
            # "logprob_start_len": 0,
        }

        # Send the POST request to the server
        # add a few retries?
        response = requests.post(self._endpoint("generate"), json=request_body)
        # Fail loudly on an HTTP error instead of misreading the error payload as results
        response.raise_for_status()
        response_json = response.json()

        examples = []
        for i, result in enumerate(response_json):
            # Results are ordered prompt by prompt, `num_generations` entries each
            example = {
                "prompt_ids": input_ids[i // num_generations],
                "completion_ids": result["output_ids"],
                # "prompt_log_probs": [prob[0] for prob in result["meta_info"]["input_token_logprobs"]],
                # "completion_log_probs": [prob[0] for prob in result["meta_info"]["output_token_logprobs"]],
            }
            examples.append(example)
        return examples

    def load_weights_from_path(self, path: str):
        """Ask the server to reload its weights from `path`.

        Raises:
            RuntimeError: If the server does not report success.
        """
        if self.is_mock:
            return
        response = requests.post(self._endpoint("update_weights_from_disk"), json={"model_path": path})
        print(response.text)
        # Explicit check instead of `assert`, which is stripped when Python runs with -O
        if response.json().get("success") is not True:
            raise RuntimeError(f"Remote model failed to load weights from {path}: {response.text}")
if __name__ == "__main__":
    # Manual smoke test: needs a generation server listening on 0.0.0.0:30010 (see the RemoteModel docstring)
    url = "0.0.0.0"
    port = 30010
    MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    remote_model = RemoteModel(url, port, tokenizer.eos_token_id)
    dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train")
    dataloader = DataLoader(dataset, batch_size=4)

    # Only the first two batches are sent
    for _, batch in zip(range(2), dataloader):
        problems = batch["problem"]
        ids = tokenizer(problems)
        # `generate` returns one dict per completion (log-probs are currently not returned)
        examples = remote_model.generate(ids["input_ids"])
        print([example["completion_ids"] for example in examples])

View file

@ -0,0 +1,36 @@
from typing import Iterator
from torch.utils.data import RandomSampler
class RepeatBatchRandomSampler(RandomSampler):
    """`RandomSampler` that yields every batch of `batch_size` indices `num_generations` times in a row.

    Indices left over after the last full batch are dropped.

    Args:
        *args: Positional arguments forwarded to `RandomSampler`.
        num_generations: How many times each batch of indices is repeated.
        batch_size: Number of indices per batch.
        **kwargs: Keyword arguments forwarded to `RandomSampler`.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """

    def __init__(
        self,
        *args,
        num_generations: int = 1,
        batch_size: int = 3,
        **kwargs,
    ) -> None:
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.num_generations = num_generations
        self.batch_size = batch_size
        super().__init__(*args, **kwargs)

    def __len__(self) -> int:
        # Must agree with __iter__: only full batches are emitted
        full_batches = super().__len__() // self.batch_size
        return full_batches * self.batch_size * self.num_generations

    def __iter__(self) -> Iterator[int]:
        batch_indices = []
        for idx in super().__iter__():
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                yield from batch_indices * self.num_generations
                batch_indices = []
if __name__ == "__main__":
    # Demo: print the index stream for 12 items, each batch of indices repeated twice
    demo_sampler = RepeatBatchRandomSampler(num_generations=2, data_source=range(12), replacement=False)
    for index in demo_sampler:
        print(index)

View file

@ -0,0 +1,148 @@
from collections import defaultdict
import torch
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from open_r1.configs import GRPOConfig
from open_r1.trainers.remote_model import RemoteModel
from trl.data_utils import is_conversational, maybe_apply_chat_template
class RemoteGRPODataloader(DataLoader):
def __init__(
    self, *args, config: GRPOConfig, remote_model=None, processing_class=None, reward_funcs=None, **kwargs
):
    """Create the dataloader and remember what is needed to generate and score completions.

    Args:
        *args: Positional arguments forwarded to `DataLoader`.
        config: GRPO configuration stored on the loader.
        remote_model: Client used to sample completions.
        processing_class: Tokenizer for prompts and completions.
        reward_funcs: Reward functions; `None` is treated as "no reward functions".
        **kwargs: Keyword arguments forwarded to `DataLoader`.
    """
    super().__init__(*args, **kwargs)
    self.config = config
    self.remote_model = remote_model
    self.processing_class = processing_class
    # The default `None` used to crash on `len(None)`; fall back to an empty list
    self.reward_funcs = reward_funcs if reward_funcs is not None else []
    self.reward_weights = [1.0] * len(self.reward_funcs)  # TODO: make this configurable
def __len__(self):
    """Length of the underlying dataloader multiplied by `config.num_generations`."""
    prompt_batches = super().__len__()
    return prompt_batches * self.config.num_generations
def __iter__(self):
for batch in super().__iter__():
batch = self._prepare_batch(batch)
gen_dataset = Dataset.from_list(batch)
mini_batch_dataloader = DataLoader(
gen_dataset,
batch_size=self.config.per_device_train_batch_size,
shuffle=True, # we technically don#t need to shuffle due to grad acc, but we may move to clipped loss later
drop_last=True,
collate_fn=self.collate_fn,
)
for mini_batch in mini_batch_dataloader:
yield mini_batch
def _prepare_batch(self, batch):
prompts = [x["prompt"] for x in batch]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in batch]
prompt_inputs = self.processing_class(prompts_text)
prompt_ids = prompt_inputs["input_ids"]
# add cuda clear cache here and a sleep
all_outputs = self.remote_model.generate(
prompt_ids,
max_new_tokens=self.config.max_completion_length,
temperature=self.config.temperature,
num_generations=self.config.num_generations,
)
completion_ids = [example["completion_ids"] for example in all_outputs]
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
repeated_prompts = []
for prompt in prompts:
repeated_prompts.extend([prompt] * self.config.num_generations)
repeated_prompt_texts = []
for prompt in prompts_text:
repeated_prompt_texts.extend([prompt] * self.config.num_generations)
if is_conversational(batch[0]):
completions = []
for prompt, completion in zip(repeated_prompts, completions_text, strict=True):
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:
completions = completions_text
rewards = torch.zeros(len(repeated_prompts), len(self.reward_funcs))
for i, reward_func in enumerate(self.reward_funcs):
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
keys = [key for key in batch[0] if key not in ["prompt", "completion"]]
reward_kwargs = defaultdict(list)
for example in batch:
for key in keys:
reward_kwargs[key].extend([example[key]] * self.config.num_generations)
output_reward_func = reward_func(prompts=repeated_prompts, completions=completions, **reward_kwargs)
rewards[:, i] = torch.tensor(output_reward_func, dtype=torch.float32) * self.reward_weights[i]
grouped_rewards = rewards.sum(-1).view(len(prompts), self.config.num_generations)
EPS = 1e-4
grouped_advantages = (grouped_rewards - grouped_rewards.mean(-1, keepdim=True)) / (
grouped_rewards.std(-1, keepdim=True) + EPS
)
advantages = grouped_advantages.flatten().tolist()
# build batch as list of dicts
examples = []
for i, prompt in enumerate(repeated_prompt_texts):
example = {
"prompt": prompt,
"prompt_ids": prompt_ids[i // self.config.num_generations],
"completion": completions_text[i],
"completion_ids": completion_ids[i],
"advantages": advantages[i],
"rewards": rewards[i],
}
examples.append(example)
return examples
if __name__ == "__main__":
    # Manual smoke test: stream a few prompts through the dataloader against a
    # generation server expected at 0.0.0.0:30010, using constant rewards.

    def make_conversation(example):
        # Wrap the raw problem statement as a single-turn chat prompt.
        return {"prompt": [{"role": "user", "content": example["problem"]}]}

    def collate_fn(batch):
        # Identity collation: keep each batch as a plain list of example dicts.
        return batch

    def reward_func(prompts, completions, **kwargs):
        # Constant score for every completion.
        return [0.5] * len(prompts)

    dataset = load_dataset("open-r1/OpenR1-Math-cn_k12-86k", split="train").select(range(32))
    dataset = dataset.map(make_conversation)
    dataset = dataset.remove_columns("messages")

    MODEL = "HuggingFaceTB/SmolLM2-135M-Instruct"
    processing_class = AutoTokenizer.from_pretrained(MODEL)
    remote_model = RemoteModel("0.0.0.0", 30010, processing_class.eos_token_id)
    config = GRPOConfig()
    reward_funcs = [reward_func, reward_func]

    data_loader = RemoteGRPODataloader(
        dataset,
        remote_model=remote_model,
        processing_class=processing_class,
        reward_funcs=reward_funcs,
        batch_size=2,
        num_workers=0,
        collate_fn=collate_fn,
        config=config,
    )

    print(len(data_loader))
    for step, mini_batch in enumerate(data_loader):
        print(step, len(mini_batch))
        print(mini_batch)

View file

@ -88,7 +88,7 @@ def run_lighteval_job(
f"{model_args.trust_remote_code}",
]
if training_args.system_prompt is not None:
cmd_args.append(f"--system_prompt={training_args.system_prompt}")
cmd_args.append(f"'{training_args.system_prompt}'")
cmd[-1] += " " + " ".join(cmd_args)
subprocess.run(cmd, check=True)