mirror of
https://github.com/huggingface/open-r1.git
synced 2026-06-24 01:54:06 +00:00
Compare commits
63 commits
main
...
faster-grp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f08b559e00 | ||
|
|
a5668ec040 | ||
|
|
36c3867811 | ||
|
|
41f41d80f8 | ||
|
|
4f94c356b6 | ||
|
|
1b074940bf | ||
|
|
0e4685f6c1 | ||
|
|
6e275af99e |
||
|
|
bf4642b4ea | ||
|
|
c6ab52a0cf | ||
|
|
93d8bc5aba | ||
|
|
e922330623 | ||
|
|
fd62cf290e | ||
|
|
1c4efd5bd3 | ||
|
|
0c3e50f332 | ||
|
|
c69164a573 | ||
|
|
a4004f658e | ||
|
|
5fdbcd5a20 | ||
|
|
4fa226ba1e | ||
|
|
96546e7721 | ||
|
|
a3d1f26715 | ||
|
|
14d75bf0d5 |
||
|
|
f4fe3550b6 | ||
|
|
b47880b1e2 | ||
|
|
10a70dfa42 | ||
|
|
7d470d02d4 | ||
|
|
dcf0af62e2 | ||
|
|
3d0e39d5d6 | ||
|
|
76f8ae7a88 | ||
|
|
389befcf3e | ||
|
|
2f4f6fe4ef | ||
|
|
48a57bbe34 | ||
|
|
54afde3b67 | ||
|
|
731caf57ed | ||
|
|
95bac8aacf | ||
|
|
dcc33dc710 | ||
|
|
8c7d764e5b | ||
|
|
ada8cecd54 | ||
|
|
9b7aa79b0e | ||
|
|
25e0b07feb | ||
|
|
db2501d531 | ||
|
|
def83fc6af |
||
|
|
00b1f61c01 | ||
|
|
67fb66af13 | ||
|
|
9de588449a | ||
|
|
0030447af5 | ||
|
|
c775de3fd0 | ||
|
|
0db1912bdc | ||
|
|
09628da5d3 | ||
|
|
55a451a813 | ||
|
|
f68c27bdf3 | ||
|
|
f50658e7c8 | ||
|
|
875628838a | ||
|
|
3e37bf1361 | ||
|
|
420d72a7da | ||
|
|
382a0c7890 | ||
|
|
5d213c48b1 | ||
|
|
12be29f08b | ||
|
|
b1394e542e | ||
|
|
fbe3b07d56 | ||
|
|
dcdcebaaac | ||
|
|
ed9554ff54 | ||
|
|
38e350d3a2 |
39 changed files with 3998 additions and 11 deletions
BIN
.litellm_cache/cache.db
Normal file
BIN
.litellm_cache/cache.db
Normal file
Binary file not shown.
|
|
@ -0,0 +1,67 @@
|
|||
# Model arguments
|
||||
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works
|
||||
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}"
|
||||
dataset_name: agentica-org/DeepScaleR-Preview-Dataset
|
||||
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
|
||||
|
||||
# GRPO trainer config
|
||||
callbacks:
|
||||
- push_to_hub_revision
|
||||
benchmarks:
|
||||
- math_500
|
||||
- aime24
|
||||
beta: 0.001
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.9
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 8
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: DeepSeek-R1-Distill-Qwen-1.5B-GRPO
|
||||
hub_strategy: every_save
|
||||
learning_rate: 5.0e-07
|
||||
log_completions: true
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: constant_with_warmup
|
||||
max_grad_norm: 0.2
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 4096
|
||||
max_steps: -1
|
||||
num_generations: 8
|
||||
num_train_epochs: 1
|
||||
output_dir: data/DeepSeek-R1-Distill-Qwen-1.5B-GRPO
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 8
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- tag_count
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 1.0
|
||||
save_strategy: "steps"
|
||||
save_steps: 0.2
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
warmup_ratio: 0.1
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
# Model arguments
|
||||
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
|
||||
dataset_configs:
|
||||
- all
|
||||
dataset_train_split: train
|
||||
num_processes: 8
|
||||
ddp_find_unused_parameters: false
|
||||
# GRPO trainer config
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.8
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
gradient_accumulation_steps: 32
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/DeepSeek-R1-Distill-Qwen-1.5B-v00.00
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-05
|
||||
log_level: info
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 16000
|
||||
max_steps: -1
|
||||
num_train_epochs: 0.1
|
||||
num_generations: 16
|
||||
output_dir: data/open-r1/DeepSeek-R1-Distill-Qwen-1.5B-v00.00
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 2
|
||||
push_to_hub: true
|
||||
beta: 0.04
|
||||
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.1
|
||||
use_liger_kernel: true
|
||||
|
||||
report_to:
|
||||
- wandb
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
log_completions: true
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
|
||||
# Saving and eval callbacks
|
||||
save_strategy: "steps"
|
||||
save_steps: 100
|
||||
# callbacks:
|
||||
# - push_to_hub_revision
|
||||
# benchmarks:
|
||||
# - math_500_8k
|
||||
# - aime24_8k
|
||||
# - gsm8k_8k
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
# Model arguments
|
||||
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
|
||||
dataset_configs:
|
||||
- all
|
||||
dataset_train_split: train
|
||||
num_processes: 8
|
||||
ddp_find_unused_parameters: false
|
||||
# GRPO trainer config
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.8
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
gradient_accumulation_steps: 2
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/DeepSeek-R1-Distill-Qwen-1.5B-v00.01
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-05
|
||||
log_level: info
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 16000
|
||||
max_steps: -1
|
||||
num_train_epochs: 0.1
|
||||
num_generations: 14
|
||||
output_dir: data/open-r1/DeepSeek-R1-Distill-Qwen-1.5B-v00.01
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 2
|
||||
push_to_hub: true
|
||||
beta: 0.04
|
||||
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.1
|
||||
use_liger_kernel: true
|
||||
|
||||
report_to:
|
||||
- wandb
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
log_completions: true
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
|
||||
# Saving and eval callbacks
|
||||
save_strategy: "steps"
|
||||
save_steps: 100
|
||||
# callbacks:
|
||||
# - push_to_hub_revision
|
||||
# benchmarks:
|
||||
# - math_500_8k
|
||||
# - aime24_8k
|
||||
# - gsm8k_8k
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
# Model arguments
|
||||
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
|
||||
dataset_configs:
|
||||
- all
|
||||
dataset_train_split: train
|
||||
num_processes: 8
|
||||
ddp_find_unused_parameters: false
|
||||
# GRPO trainer config
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.8
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
gradient_accumulation_steps: 16
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/DeepSeek-R1-Distill-Qwen-1.5B-RGRPO-v00.01
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-05
|
||||
log_level: info
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 32768
|
||||
max_steps: -1
|
||||
num_train_epochs: 1.0
|
||||
num_generations: 16
|
||||
output_dir: data/open-r1/DeepSeek-R1-Distill-Qwen-1.5B-RGRPO-v00.01
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 2
|
||||
push_to_hub: true
|
||||
beta: 0.04
|
||||
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.1
|
||||
use_liger_kernel: true
|
||||
|
||||
report_to:
|
||||
- wandb
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
log_completions: true
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
|
||||
# Saving and eval callbacks
|
||||
save_strategy: "steps"
|
||||
save_steps: 100
|
||||
# callbacks:
|
||||
# - push_to_hub_revision
|
||||
# benchmarks:
|
||||
# - math_500_8k
|
||||
# - aime24_8k
|
||||
# - gsm8k_8k
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
# Model arguments
|
||||
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
# Data training arguments
|
||||
dataset_name: AI-MO/NuminaMath-TIR
|
||||
dataset_configs:
|
||||
- all
|
||||
num_processes: 1
|
||||
ddp_find_unused_parameters: false
|
||||
# GRPO trainer config
|
||||
# use_vllm: true
|
||||
bf16: true
|
||||
# ref_model_url: http://127.0.0.1:8000
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
eval_steps: 100
|
||||
gradient_accumulation_steps: 2
|
||||
gradient_checkpointing: false
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: DeepSeek-R1-Distill-Qwen-1.5-GRPO
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-06
|
||||
log_level: info
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 512
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
output_dir: data/DeepSeek-R1-Distill-Qwen-1.5-GRPO-v001
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 1
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
save_strategy: "no"
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
64
recipes/Qwen2.5-1.5B-Instruct/grpo/config_remote.yaml
Normal file
64
recipes/Qwen2.5-1.5B-Instruct/grpo/config_remote.yaml
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
# Model arguments
|
||||
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
|
||||
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
|
||||
|
||||
# GRPO trainer config
|
||||
callbacks:
|
||||
- push_to_hub_revision
|
||||
benchmarks:
|
||||
- math_500
|
||||
- aime24
|
||||
beta: 0.001
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.9
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 16
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/Qwen2.5-1.5B-Instruct-GRPO
|
||||
hub_model_revision: v00.00
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-06
|
||||
log_completions: true
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: constant_with_warmup
|
||||
max_grad_norm: 0.2
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 4096
|
||||
max_steps: 1000
|
||||
num_generations: 16
|
||||
num_train_epochs: 1
|
||||
output_dir: data/Qwen2.5-1.5B-Instruct-GRPO
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.2
|
||||
save_strategy: "steps"
|
||||
save_steps: 0.1
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
warmup_ratio: 0.1
|
||||
65
recipes/Qwen2.5-1.5B-Instruct/grpo/config_remote_v01.02.yaml
Normal file
65
recipes/Qwen2.5-1.5B-Instruct/grpo/config_remote_v01.02.yaml
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
# Model arguments
|
||||
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
|
||||
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
|
||||
|
||||
# GRPO trainer config
|
||||
callbacks:
|
||||
- push_to_hub_revision
|
||||
benchmarks:
|
||||
- math_500
|
||||
- aime24
|
||||
beta: 0.001
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.7
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 14
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/Qwen2.5-1.5B-Instruct-RGRPO
|
||||
hub_model_revision: v01.02
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-06
|
||||
log_completions: true
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: constant_with_warmup
|
||||
max_grad_norm: 0.2
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 4096
|
||||
max_steps: -1
|
||||
num_generations: 14
|
||||
num_train_epochs: 1
|
||||
output_dir: data/Qwen2.5-1.5B-Instruct-RGRPO_v01.02
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
remote_gen_model_url: 26.0.164.45
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.2
|
||||
save_strategy: "steps"
|
||||
save_steps: 0.1
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
warmup_ratio: 0.1
|
||||
64
recipes/Qwen2.5-7B-Instruct/grpo/config_deepscaler.yaml
Normal file
64
recipes/Qwen2.5-7B-Instruct/grpo/config_deepscaler.yaml
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
# Model arguments
|
||||
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: agentica-org/DeepScaleR-Preview-Dataset
|
||||
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
|
||||
|
||||
# GRPO trainer config
|
||||
callbacks:
|
||||
- push_to_hub_revision
|
||||
benchmarks:
|
||||
- math_500
|
||||
- aime24
|
||||
beta: 0.001
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.9
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 16
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/Qwen2.5-7B-Instruct-GRPO
|
||||
hub_model_revision: v01.00
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-06
|
||||
log_completions: true
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: constant_with_warmup
|
||||
max_grad_norm: 0.2
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 4096
|
||||
max_steps: 1000
|
||||
num_generations: 16
|
||||
num_train_epochs: 1
|
||||
output_dir: data/Qwen2.5-7B-Instruct-GRPO
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.2
|
||||
save_strategy: "steps"
|
||||
save_steps: 0.1
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
warmup_ratio: 0.1
|
||||
64
recipes/Qwen2.5-7B-Instruct/grpo/config_remote.yaml
Normal file
64
recipes/Qwen2.5-7B-Instruct/grpo/config_remote.yaml
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
# Model arguments
|
||||
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
|
||||
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
|
||||
|
||||
# GRPO trainer config
|
||||
callbacks:
|
||||
- push_to_hub_revision
|
||||
benchmarks:
|
||||
- math_500
|
||||
- aime24
|
||||
beta: 0.001
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.7
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 14
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/Qwen2.5-7B-Instruct-GRPO
|
||||
hub_model_revision: v00.00
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-06
|
||||
log_completions: true
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: constant_with_warmup
|
||||
max_grad_norm: 0.2
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 4096
|
||||
max_steps: -1
|
||||
num_generations: 14
|
||||
num_train_epochs: 1
|
||||
output_dir: data/Qwen2.5-7B-Instruct-GRPO
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.2
|
||||
save_strategy: "steps"
|
||||
save_steps: 0.1
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
warmup_ratio: 0.1
|
||||
64
recipes/Qwen2.5-7B-Instruct/grpo/config_remote_v01.00.yaml
Normal file
64
recipes/Qwen2.5-7B-Instruct/grpo/config_remote_v01.00.yaml
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
# Model arguments
|
||||
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
|
||||
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
|
||||
|
||||
# GRPO trainer config
|
||||
callbacks:
|
||||
- push_to_hub_revision
|
||||
benchmarks:
|
||||
- math_500
|
||||
- aime24
|
||||
beta: 0.001
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.7
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 14
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/Qwen2.5-7B-Instruct-RGRPO
|
||||
hub_model_revision: v01.00
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-06
|
||||
log_completions: true
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: constant_with_warmup
|
||||
max_grad_norm: 0.2
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 4096
|
||||
max_steps: -1
|
||||
num_generations: 14
|
||||
num_train_epochs: 1
|
||||
output_dir: data/Qwen2.5-7B-Instruct-RGRPO_v01.00
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.2
|
||||
save_strategy: "steps"
|
||||
save_steps: 0.1
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
warmup_ratio: 0.1
|
||||
64
recipes/Qwen2.5-7B-Instruct/grpo/config_remote_v01.01.yaml
Normal file
64
recipes/Qwen2.5-7B-Instruct/grpo/config_remote_v01.01.yaml
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
# Model arguments
|
||||
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
|
||||
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
|
||||
|
||||
# GRPO trainer config
|
||||
callbacks:
|
||||
- push_to_hub_revision
|
||||
benchmarks:
|
||||
- math_500
|
||||
- aime24
|
||||
beta: 0.0
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.7
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 14
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/Qwen2.5-7B-Instruct-RGRPO
|
||||
hub_model_revision: v01.01
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-06
|
||||
log_completions: true
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: constant_with_warmup
|
||||
max_grad_norm: 0.2
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 4096
|
||||
max_steps: -1
|
||||
num_generations: 14
|
||||
num_train_epochs: 1
|
||||
output_dir: data/Qwen2.5-7B-Instruct-RGRPO_v01.01
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.2
|
||||
save_strategy: "steps"
|
||||
save_steps: 0.1
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
warmup_ratio: 0.1
|
||||
64
recipes/Qwen2.5-7B-Instruct/grpo/config_remote_v01.02.yaml
Normal file
64
recipes/Qwen2.5-7B-Instruct/grpo/config_remote_v01.02.yaml
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
# Model arguments
|
||||
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
|
||||
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
|
||||
|
||||
# GRPO trainer config
|
||||
callbacks:
|
||||
- push_to_hub_revision
|
||||
benchmarks:
|
||||
- math_500
|
||||
- aime24
|
||||
beta: 0.0
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.7
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 14
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/Qwen2.5-7B-Instruct-RGRPO
|
||||
hub_model_revision: v01.02
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-06
|
||||
log_completions: true
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: constant_with_warmup
|
||||
max_grad_norm: 0.2
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 4096
|
||||
max_steps: -1
|
||||
num_generations: 14
|
||||
num_train_epochs: 1
|
||||
output_dir: data/Qwen2.5-7B-Instruct-RGRPO_v01.02
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.2
|
||||
save_strategy: "steps"
|
||||
save_steps: 0.1
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
warmup_ratio: 0.1
|
||||
60
recipes/Qwen2.5-Coder-7B-Instruct/grpo/config_remote.yaml
Normal file
60
recipes/Qwen2.5-Coder-7B-Instruct/grpo/config_remote.yaml
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
# Model arguments
|
||||
model_name_or_path: open-r1/Qwen2.5-Coder-7B-Instruct-SFT
|
||||
model_revision: v06.11-step-000004005
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/verifiable-coding-problems-python_decontaminated
|
||||
|
||||
# GRPO trainer config
|
||||
callbacks:
|
||||
- push_to_hub_revision
|
||||
benchmarks:
|
||||
- lcb
|
||||
beta: 0.001
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.9
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 16
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO
|
||||
hub_model_revision: v00.00
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-06
|
||||
log_completions: true
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: constant_with_warmup
|
||||
max_grad_norm: 0.2
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 8192
|
||||
max_steps: 1000
|
||||
num_generations: 16
|
||||
num_train_epochs: 1
|
||||
output_dir: data/Qwen2.5-Coder-7B-Instruct-SFT-GRPO
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
reward_funcs:
|
||||
- code
|
||||
reward_weights:
|
||||
- 1.0
|
||||
save_strategy: "steps"
|
||||
save_steps: 0.1
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
warmup_ratio: 0.1
|
||||
59
recipes/Qwen2.5-Coder-7B-Instruct/grpo/config_v00.00.yaml
Normal file
59
recipes/Qwen2.5-Coder-7B-Instruct/grpo/config_v00.00.yaml
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
# Model arguments
|
||||
model_name_or_path: open-r1/Qwen2.5-Coder-7B-Instruct-SFT
|
||||
model_revision: v00.08-step-000001280
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/verifiable-coding-problems-python-10k_decontaminated
|
||||
dataset_configs:
|
||||
- all
|
||||
num_processes: 7
|
||||
ddp_find_unused_parameters: false
|
||||
# GRPO trainer config
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.7
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO
|
||||
hub_model_revision: v00.00
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-06
|
||||
log_level: info
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 31744
|
||||
max_steps: -1
|
||||
num_train_epochs: 5
|
||||
num_generations: 7
|
||||
output_dir: data/open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO-v00.00
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 2
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
|
||||
# Saving and eval callbacks
|
||||
save_strategy: "steps"
|
||||
save_steps: 25
|
||||
callbacks:
|
||||
- push_to_hub_revision
|
||||
benchmarks:
|
||||
- lcb
|
||||
|
||||
reward_funcs:
|
||||
# - code
|
||||
- code_format
|
||||
reward_weights:
|
||||
# - 1.0
|
||||
- 0.1
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
# Model arguments
|
||||
model_name_or_path: open-r1/Qwen2.5-Coder-7B-Instruct-SFT
|
||||
model_revision: v00.08-step-000001280
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/verifiable-coding-problems-python-10k_decontaminated
|
||||
dataset_configs:
|
||||
- all
|
||||
num_processes: 8
|
||||
ddp_find_unused_parameters: false
|
||||
# GRPO trainer config
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.7
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
gradient_accumulation_steps: 16
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO
|
||||
hub_model_revision: v00.00_remote
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-06
|
||||
log_level: info
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 30700
|
||||
max_steps: -1
|
||||
num_train_epochs: 5
|
||||
num_generations: 16
|
||||
output_dir: data/open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO-v00.00_remote
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 1
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
|
||||
# Saving and eval callbacks
|
||||
save_strategy: "steps"
|
||||
save_steps: 25
|
||||
callbacks:
|
||||
- push_to_hub_revision
|
||||
benchmarks:
|
||||
- lcb
|
||||
use_liger: true
|
||||
reward_funcs:
|
||||
- code
|
||||
- code_format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.1
|
||||
64
recipes/Qwen2.5-Coder-7B-Instruct/grpo/config_v01.00.yaml
Normal file
64
recipes/Qwen2.5-Coder-7B-Instruct/grpo/config_v01.00.yaml
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
# Model arguments
|
||||
model_name_or_path: open-r1/Qwen2.5-Coder-7B-Instruct-SFT
|
||||
model_revision: v02.12-step-000003170
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/verifiable-coding-problems-python_decontaminated-tested
|
||||
dataset_configs:
|
||||
- all
|
||||
num_processes: 8
|
||||
ddp_find_unused_parameters: false
|
||||
# GRPO trainer config
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.7
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
gradient_accumulation_steps: 16
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-SFT-RGRPO
|
||||
hub_model_revision: v01.00
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-06
|
||||
log_level: info
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 30000
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
num_generations: 16
|
||||
output_dir: data/open-r1/Qwen2.5-Coder-7B-Instruct-SFT-RGRPO-v01.00
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 1
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
beta: 0.01
|
||||
|
||||
remote_gen_model_url: 26.0.165.131
|
||||
num_iterations: 1
|
||||
|
||||
|
||||
# Saving and eval callbacks
|
||||
# save_strategy: "steps"
|
||||
# save_steps: 25
|
||||
# callbacks:
|
||||
# - push_to_hub_revision
|
||||
# benchmarks:
|
||||
# - lcb
|
||||
|
||||
reward_funcs:
|
||||
- code
|
||||
- code_format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.1
|
||||
65
recipes/SmolLM2-1.7B-Instruct/grpo/config_fast_grpo.yaml
Normal file
65
recipes/SmolLM2-1.7B-Instruct/grpo/config_fast_grpo.yaml
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
# Model arguments
|
||||
model_name_or_path: HuggingFaceTB/SmolLM2-1.7B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
|
||||
dataset_configs:
|
||||
- all
|
||||
dataset_train_split: train
|
||||
num_processes: 7
|
||||
ddp_find_unused_parameters: false
|
||||
# GRPO trainer config
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
gradient_accumulation_steps: 64
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/SmolLM2-1.7B-Instruct-FGRPO
|
||||
hub_model_revision: v05.00
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-05
|
||||
log_level: info
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 7168
|
||||
max_steps: -1
|
||||
num_train_epochs: 0.5
|
||||
num_generations: 16
|
||||
output_dir: data/open-r1/SmolLM2-1.7B-Instruct-FGRPO-v05.00
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 2
|
||||
push_to_hub: true
|
||||
beta: 0.04
|
||||
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.1
|
||||
use_liger_kernel: true
|
||||
|
||||
report_to:
|
||||
- wandb
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
log_completions: true
|
||||
seed: 42
|
||||
warmup_ratio: 0.02
|
||||
|
||||
# Saving and eval callbacks
|
||||
save_strategy: "steps"
|
||||
save_steps: 10
|
||||
# callbacks:
|
||||
# - push_to_hub_revision
|
||||
# benchmarks:
|
||||
# - math_500_8k
|
||||
# - aime24_8k
|
||||
# - gsm8k_8k
|
||||
63
recipes/SmolLM2-1.7B-Instruct/grpo/config_v05.01.yaml
Normal file
63
recipes/SmolLM2-1.7B-Instruct/grpo/config_v05.01.yaml
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
# Model arguments
|
||||
model_name_or_path: open-r1/SMOLLM_I8k-GR2-deepseek
|
||||
model_revision: main-step-000000300
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
|
||||
dataset_configs:
|
||||
- all
|
||||
dataset_train_split: train
|
||||
num_processes: 8
|
||||
ddp_find_unused_parameters: false
|
||||
# GRPO trainer config
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.9
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
gradient_accumulation_steps: 16
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/SmolLM2-1.7B-Instruct-GRPO-v05.01
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-04
|
||||
log_level: info
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 7168
|
||||
max_steps: -1
|
||||
num_train_epochs: 0.5
|
||||
num_generations: 16
|
||||
output_dir: data/open-r1/SmolLM2-1.7B-Instruct-GRPO-v05.01
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 8
|
||||
push_to_hub: true
|
||||
beta: 0.0
|
||||
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.1
|
||||
use_liger_kernel: true
|
||||
|
||||
report_to:
|
||||
- wandb
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
log_completions: true
|
||||
seed: 42
|
||||
warmup_ratio: 0.02
|
||||
|
||||
# Saving and eval callbacks
|
||||
save_strategy: "steps"
|
||||
save_steps: 10
|
||||
callbacks:
|
||||
- push_to_hub_revision
|
||||
70
recipes/SmolLM2-1.7B-Instruct/grpo/config_v05.19.yaml
Normal file
70
recipes/SmolLM2-1.7B-Instruct/grpo/config_v05.19.yaml
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
# Model arguments
|
||||
model_name_or_path: HuggingFaceTB/SmolLM2-1.7B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
|
||||
dataset_configs:
|
||||
- all
|
||||
dataset_train_split: train
|
||||
num_processes: 8
|
||||
ddp_find_unused_parameters: false
|
||||
# GRPO trainer config
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.9
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
gradient_accumulation_steps: 16
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/SmolLM2-1.7B-Instruct-GRPO-v05.19
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-04
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 7172
|
||||
max_steps: -1
|
||||
num_train_epochs: 1.0
|
||||
num_generations: 16
|
||||
output_dir: data/open-r1/SmolLM2-1.7B-Instruct-GRPO-v05.19
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 8
|
||||
push_to_hub: true
|
||||
beta: 0.01
|
||||
remote_gen_model_url: 26.0.160.225
|
||||
num_iterations: 4
|
||||
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.1
|
||||
use_liger_kernel: true
|
||||
|
||||
report_to:
|
||||
- wandb
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
log_completions: true
|
||||
seed: 42
|
||||
warmup_ratio: 0.02
|
||||
|
||||
# Saving and eval callbacks
|
||||
save_strategy: "steps"
|
||||
save_steps: 10
|
||||
# callbacks:
|
||||
# - push_to_hub_revision
|
||||
# benchmarks:
|
||||
# - math_500
|
||||
# - aime24
|
||||
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
|
||||
67
recipes/SmolLM2-135M-Instruct/grpo/config_fast_grpo.yaml
Normal file
67
recipes/SmolLM2-135M-Instruct/grpo/config_fast_grpo.yaml
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
# Model arguments
|
||||
model_name_or_path: HuggingFaceTB/SmolLM2-135M-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
|
||||
dataset_configs:
|
||||
- all
|
||||
dataset_train_split: train
|
||||
num_processes: 8
|
||||
ddp_find_unused_parameters: false
|
||||
# GRPO trainer config
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.8
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: "no"
|
||||
gradient_accumulation_steps: 8
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/SmolLM2-135M-Instruct-GRPO-v00.01
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-05
|
||||
log_level: info
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 2048
|
||||
max_steps: -1
|
||||
num_train_epochs: 0.1
|
||||
num_generations: 4
|
||||
output_dir: data/open-r1/SmolLM2-135M-Instruct-GRPO-v00.01
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 8
|
||||
push_to_hub: true
|
||||
beta: 0.04
|
||||
remote_gen_model_url: 0.0.0.0
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.1
|
||||
use_liger_kernel: true
|
||||
|
||||
report_to:
|
||||
- wandb
|
||||
wandb_entity: huggingface
|
||||
wandb_project: open-r1
|
||||
log_completions: true
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
|
||||
# Saving and eval callbacks
|
||||
save_strategy: "steps"
|
||||
save_steps: 100
|
||||
# callbacks:
|
||||
# - push_to_hub_revision
|
||||
# benchmarks:
|
||||
# - math_500_8k
|
||||
# - aime24_8k
|
||||
# - gsm8k_8k
|
||||
47
recipes/accelerate_configs/deepspeed3_offload.json
Normal file
47
recipes/accelerate_configs/deepspeed3_offload.json
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
{
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"sub_group_size": 1e9,
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": "auto"
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 2000,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false,
|
||||
"activation_checkpointing": {
|
||||
"partition_activations": true,
|
||||
"cpu_checkpointing": true,
|
||||
"contiguous_memory_optimization": false,
|
||||
"number_checkpoints": null,
|
||||
"synchronize_checkpoint_boundary": false,
|
||||
"profile": false
|
||||
}
|
||||
}
|
||||
23
recipes/accelerate_configs/zero3_off.yaml
Normal file
23
recipes/accelerate_configs/zero3_off.yaml
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_multinode_launcher: standard
|
||||
offload_optimizer_device: cpu
|
||||
offload_param_device: cpu
|
||||
zero3_init_flag: true
|
||||
zero3_save_16bit_model: true
|
||||
zero_stage: 3
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
|
||||
16
recipes/accelerate_configs/zero3_off2.yaml
Normal file
16
recipes/accelerate_configs/zero3_off2.yaml
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_config_file: recipes/accelerate_configs/deepspeed3_offload.json
|
||||
zero3_init_flag: false
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: fal
|
||||
30
recipes/accelerate_configs/zero3_offload.yaml
Normal file
30
recipes/accelerate_configs/zero3_offload.yaml
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_multinode_launcher: standard
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: true
|
||||
zero3_save_16bit_model: true
|
||||
zero_stage: 3
|
||||
offload_param:
|
||||
device: cpu
|
||||
pin_memory: true
|
||||
activation_checkpointing:
|
||||
partition_activations: true
|
||||
contiguous_memory_optimization: false
|
||||
cpu_checkpointing: true
|
||||
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
# Model arguments
|
||||
model_name_or_path: Qwen/Qwen2.5-Math-7B
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: DigitalLearningGmbH/MATH-lighteval
|
||||
dataset_configs:
|
||||
- train
|
||||
# Num processes is less by 1 as vLLM is using 1 GPU
|
||||
num_processes: 7
|
||||
|
||||
# GRPO trainer config
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.7
|
||||
do_eval: true
|
||||
eval_strategy: steps
|
||||
eval_steps: 100
|
||||
gradient_accumulation_steps: 16
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: open-r1/Qwen-2.5-7B_Base_Math_smalllr_remote_model
|
||||
hub_strategy: every_save
|
||||
learning_rate: 3.0e-06
|
||||
log_level: info
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 1024
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
output_dir: data/Qwen-2.5-7B_Base_Math_smalllr_remote_model
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 1
|
||||
per_device_train_batch_size: 1
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
save_strategy: "no"
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
ref_model_url: http://26.0.163.127:30010
|
||||
250
scripts/faster_grpo.py
Normal file
250
scripts/faster_grpo.py
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from transformers import set_seed
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from open_r1.configs import GRPOConfig
|
||||
from open_r1.rewards import (
|
||||
accuracy_reward,
|
||||
format_reward,
|
||||
get_cosine_scaled_reward,
|
||||
get_repetition_penalty_reward,
|
||||
len_reward,
|
||||
reasoning_steps_reward,
|
||||
)
|
||||
from open_r1.trainers.faster_grpo_trainer import FastGRPOTrainer, FastGRPOConfig
|
||||
from open_r1.utils.callbacks import get_callbacks
|
||||
from open_r1.utils.wandb_logging import init_wandb_training
|
||||
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class GRPOScriptArguments(ScriptArguments):
|
||||
"""
|
||||
Script arguments for the GRPO training script.
|
||||
Args:
|
||||
reward_funcs (`list[str]`):
|
||||
List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'.
|
||||
cosine_min_value_wrong (`float`):
|
||||
Minimum reward for cosine scaling for wrong answers.
|
||||
cosine_max_value_wrong (`float`):
|
||||
Maximum reward for cosine scaling for wrong answers.
|
||||
cosine_min_value_correct (`float`):
|
||||
Minimum reward for cosine scaling for correct answers.
|
||||
cosine_max_value_correct (`float`):
|
||||
Maximum reward for cosine scaling for correct answers.
|
||||
cosine_max_len (`int`):
|
||||
Maximum length for cosine scaling.
|
||||
"""
|
||||
|
||||
reward_funcs: list[str] = field(
|
||||
default_factory=lambda: ["accuracy", "format"],
|
||||
metadata={
|
||||
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
|
||||
},
|
||||
)
|
||||
cosine_min_value_wrong: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "Minimum reward for wrong answers"},
|
||||
)
|
||||
cosine_max_value_wrong: float = field(
|
||||
default=-0.5,
|
||||
metadata={"help": "Maximum reward for wrong answers"},
|
||||
)
|
||||
cosine_min_value_correct: float = field(
|
||||
default=0.5,
|
||||
metadata={"help": "Minimum reward for correct answers"},
|
||||
)
|
||||
cosine_max_value_correct: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Maximum reward for correct answers"},
|
||||
)
|
||||
cosine_max_len: int = field(
|
||||
default=1000,
|
||||
metadata={"help": "Maximum length for scaling"},
|
||||
)
|
||||
|
||||
repetition_n_grams: int = field(
|
||||
default=3,
|
||||
metadata={"help": "Number of n-grams for repetition penalty reward"},
|
||||
)
|
||||
repetition_max_penalty: float = field(
|
||||
default=-1.0,
|
||||
metadata={"help": "Maximum (negative) penalty for for repetition penalty reward"},
|
||||
)
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
|
||||
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
|
||||
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
|
||||
"<think> reasoning process here </think><answer> answer here </answer>"
|
||||
)
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
# Set seed for reproducibility
|
||||
set_seed(training_args.seed)
|
||||
|
||||
###############
|
||||
# Setup logging
|
||||
###############
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Log on each process a small summary
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Model parameters {model_args}")
|
||||
logger.info(f"Script parameters {script_args}")
|
||||
logger.info(f"Training parameters {training_args}")
|
||||
|
||||
# Check for last checkpoint
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
|
||||
|
||||
if "wandb" in training_args.report_to:
|
||||
init_wandb_training(training_args)
|
||||
|
||||
# Load the dataset
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
|
||||
# Get reward functions
|
||||
REWARD_FUNCS_REGISTRY = {
|
||||
"accuracy": accuracy_reward,
|
||||
"format": format_reward,
|
||||
"reasoning_steps": reasoning_steps_reward,
|
||||
"cosine": get_cosine_scaled_reward(
|
||||
min_value_wrong=script_args.cosine_min_value_wrong,
|
||||
max_value_wrong=script_args.cosine_max_value_wrong,
|
||||
min_value_correct=script_args.cosine_min_value_correct,
|
||||
max_value_correct=script_args.cosine_max_value_correct,
|
||||
max_len=script_args.cosine_max_len,
|
||||
),
|
||||
"repetition_penalty": get_repetition_penalty_reward(
|
||||
ngram_size=script_args.repetition_n_grams,
|
||||
max_penalty=script_args.repetition_max_penalty,
|
||||
),
|
||||
"length": len_reward,
|
||||
}
|
||||
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
|
||||
|
||||
# Format into conversation
|
||||
def make_conversation(example):
|
||||
return {
|
||||
"prompt": [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": example["problem"]},
|
||||
],
|
||||
}
|
||||
|
||||
dataset = dataset.map(make_conversation)
|
||||
for split in dataset:
|
||||
if "messages" in dataset[split].column_names:
|
||||
dataset[split] = dataset[split].remove_columns("messages")
|
||||
|
||||
logger.info("*** Initializing model kwargs ***")
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
)
|
||||
training_args.model_init_kwargs = model_kwargs
|
||||
|
||||
#############################
|
||||
# Initialize the Async GRPO trainer
|
||||
#############################
|
||||
trainer = FastGRPOTrainer(
|
||||
model=model_args.model_name_or_path,
|
||||
reward_funcs=reward_funcs,
|
||||
args=training_args,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
callbacks=get_callbacks(training_args, model_args),
|
||||
)
|
||||
|
||||
###############
|
||||
# Training loop
|
||||
###############
|
||||
logger.info("*** Train ***")
|
||||
checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
checkpoint = training_args.resume_from_checkpoint
|
||||
elif last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
metrics = train_result.metrics
|
||||
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
##################################
|
||||
# Save model and create model card
|
||||
##################################
|
||||
logger.info("*** Save model ***")
|
||||
trainer.save_model(training_args.output_dir)
|
||||
logger.info(f"Model saved to {training_args.output_dir}")
|
||||
|
||||
# Save everything else on main process
|
||||
kwargs = {
|
||||
"dataset_name": script_args.dataset_name,
|
||||
"tags": ["open-r1"],
|
||||
}
|
||||
if trainer.accelerator.is_main_process:
|
||||
trainer.create_model_card(**kwargs)
|
||||
# Restore k,v cache for fast inference
|
||||
trainer.model.config.use_cache = True
|
||||
trainer.model.config.save_pretrained(training_args.output_dir)
|
||||
|
||||
#############
|
||||
# push to hub
|
||||
#############
|
||||
if training_args.push_to_hub:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser((GRPOScriptArguments, FastGRPOConfig, ModelConfig))
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
main(script_args, training_args, model_args)
|
||||
301
scripts/remote_grpo.py
Normal file
301
scripts/remote_grpo.py
Normal file
|
|
@ -0,0 +1,301 @@
|
|||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""GRPO trainer to train on N + 1 nodes, with 1 node allocated for generation.
|
||||
|
||||
Usage:
|
||||
|
||||
For training, run:
|
||||
|
||||
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml scripts/remote_grpo.py \
|
||||
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_remote.yaml
|
||||
|
||||
This will automatically spin up an SGLang server on a separate node and use it for generation.
|
||||
|
||||
For development, first spin up an SGLang sever on a separate node:
|
||||
|
||||
python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-1.5B-Instruct --port=30010 --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size=8
|
||||
|
||||
Then run training by providing the IP address of the server:
|
||||
|
||||
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml scripts/remote_grpo.py \
|
||||
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_remote.yaml \
|
||||
--remote_gen_model_url ip-26-0-160-103
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from transformers import set_seed
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from open_r1.rewards import (
|
||||
accuracy_reward,
|
||||
code_reward,
|
||||
format_reward,
|
||||
get_code_format_reward,
|
||||
get_cosine_scaled_reward,
|
||||
get_repetition_penalty_reward,
|
||||
len_reward,
|
||||
reasoning_steps_reward,
|
||||
tag_count_reward,
|
||||
)
|
||||
from open_r1.utils import get_tokenizer
|
||||
from open_r1.utils.callbacks import get_callbacks
|
||||
from open_r1.utils.wandb_logging import init_wandb_training
|
||||
from trl import ModelConfig, ScriptArguments, TrlParser
|
||||
from open_r1.trainers.remote_grpo_trainer import RemoteGRPOTrainer, RemoteGRPOConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GRPOScriptArguments(ScriptArguments):
|
||||
"""
|
||||
Script arguments for the GRPO training script.
|
||||
|
||||
Args:
|
||||
reward_funcs (`list[str]`):
|
||||
List of reward functions. Possible values: 'accuracy', 'format', 'format_deepseek', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', tag_count', 'code', 'code_format'.
|
||||
cosine_min_value_wrong (`float`):
|
||||
Minimum reward for cosine scaling for wrong answers.
|
||||
cosine_max_value_wrong (`float`):
|
||||
Maximum reward for cosine scaling for wrong answers.
|
||||
cosine_min_value_correct (`float`):
|
||||
Minimum reward for cosine scaling for correct answers.
|
||||
cosine_max_value_correct (`float`):
|
||||
Maximum reward for cosine scaling for correct answers.
|
||||
cosine_max_len (`int`):
|
||||
Maximum length for cosine scaling.
|
||||
code_language (`str`):
|
||||
Language for code format reward.
|
||||
"""
|
||||
|
||||
reward_funcs: list[str] = field(
|
||||
default_factory=lambda: ["accuracy", "format", "tag_count"],
|
||||
metadata={
|
||||
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'format_deepseek', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', tag_count', 'code', 'code_format'"
|
||||
},
|
||||
)
|
||||
cosine_min_value_wrong: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "Minimum reward for wrong answers"},
|
||||
)
|
||||
cosine_max_value_wrong: float = field(
|
||||
default=-0.5,
|
||||
metadata={"help": "Maximum reward for wrong answers"},
|
||||
)
|
||||
cosine_min_value_correct: float = field(
|
||||
default=0.5,
|
||||
metadata={"help": "Minimum reward for correct answers"},
|
||||
)
|
||||
cosine_max_value_correct: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Maximum reward for correct answers"},
|
||||
)
|
||||
cosine_max_len: int = field(
|
||||
default=1000,
|
||||
metadata={"help": "Maximum length for scaling"},
|
||||
)
|
||||
repetition_n_grams: int = field(
|
||||
default=3,
|
||||
metadata={"help": "Number of n-grams for repetition penalty reward"},
|
||||
)
|
||||
repetition_max_penalty: float = field(
|
||||
default=-1.0,
|
||||
metadata={"help": "Maximum (negative) penalty for for repetition penalty reward"},
|
||||
)
|
||||
code_language: str = field(
|
||||
default="python",
|
||||
metadata={
|
||||
"help": "Language for code format reward. Based on E2B supported languages https://e2b.dev/docs/code-interpreting/supported-languages",
|
||||
"choices": ["python", "javascript", "r", "java", "bash"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
# Set seed for reproducibility
|
||||
set_seed(training_args.seed)
|
||||
|
||||
###############
|
||||
# Setup logging
|
||||
###############
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Log on each process a small summary
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Model parameters {model_args}")
|
||||
logger.info(f"Script parameters {script_args}")
|
||||
logger.info(f"Training parameters {training_args}")
|
||||
|
||||
# Check for last checkpoint
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
|
||||
|
||||
if "wandb" in training_args.report_to:
|
||||
init_wandb_training(training_args)
|
||||
|
||||
# Load the dataset
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
|
||||
################
|
||||
# Load tokenizer
|
||||
################
|
||||
tokenizer = get_tokenizer(model_args, training_args)
|
||||
|
||||
# Get reward functions
|
||||
REWARD_FUNCS_REGISTRY = {
|
||||
"accuracy": accuracy_reward,
|
||||
"format": format_reward,
|
||||
"reasoning_steps": reasoning_steps_reward,
|
||||
"cosine": get_cosine_scaled_reward(
|
||||
min_value_wrong=script_args.cosine_min_value_wrong,
|
||||
max_value_wrong=script_args.cosine_max_value_wrong,
|
||||
min_value_correct=script_args.cosine_min_value_correct,
|
||||
max_value_correct=script_args.cosine_max_value_correct,
|
||||
max_len=script_args.cosine_max_len,
|
||||
),
|
||||
"repetition_penalty": get_repetition_penalty_reward(
|
||||
ngram_size=script_args.repetition_n_grams,
|
||||
max_penalty=script_args.repetition_max_penalty,
|
||||
),
|
||||
"length": len_reward,
|
||||
"code": code_reward,
|
||||
"code_format": get_code_format_reward(language=script_args.code_language),
|
||||
"tag_count": tag_count_reward,
|
||||
}
|
||||
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
|
||||
|
||||
# Format into conversation
|
||||
def make_conversation(example):
|
||||
prompt = []
|
||||
|
||||
if training_args.system_prompt is not None:
|
||||
prompt.append({"role": "system", "content": training_args.system_prompt})
|
||||
|
||||
prompt.append({"role": "user", "content": example["problem"]})
|
||||
return {"prompt": prompt}
|
||||
|
||||
dataset = dataset.map(make_conversation)
|
||||
|
||||
for split in dataset:
|
||||
if "messages" in dataset[split].column_names:
|
||||
dataset[split] = dataset[split].remove_columns("messages")
|
||||
|
||||
logger.info("*** Initializing model kwargs ***")
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
)
|
||||
training_args.model_init_kwargs = model_kwargs
|
||||
|
||||
#############################
|
||||
# Initialize the GRPO trainer
|
||||
#############################
|
||||
trainer = RemoteGRPOTrainer(
|
||||
model=model_args.model_name_or_path,
|
||||
reward_funcs=reward_funcs,
|
||||
args=training_args,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
callbacks=get_callbacks(training_args, model_args),
|
||||
processing_class=tokenizer,
|
||||
)
|
||||
|
||||
###############
|
||||
# Training loop
|
||||
###############
|
||||
logger.info("*** Train ***")
|
||||
checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
checkpoint = training_args.resume_from_checkpoint
|
||||
elif last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
|
||||
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
metrics = train_result.metrics
|
||||
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
##################################
|
||||
# Save model and create model card
|
||||
##################################
|
||||
logger.info("*** Save model ***")
|
||||
trainer.save_model(training_args.output_dir)
|
||||
logger.info(f"Model saved to {training_args.output_dir}")
|
||||
|
||||
# Save everything else on main process
|
||||
kwargs = {
|
||||
"dataset_name": script_args.dataset_name,
|
||||
"tags": ["open-r1"],
|
||||
}
|
||||
if trainer.accelerator.is_main_process:
|
||||
# trainer.create_model_card(**kwargs) # Bug: needs fixing with TRL helper methods
|
||||
# Restore k,v cache for fast inference
|
||||
trainer.model.config.use_cache = True
|
||||
trainer.model.config.save_pretrained(training_args.output_dir)
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
#############
|
||||
# push to hub
|
||||
#############
|
||||
if training_args.push_to_hub:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser((GRPOScriptArguments, RemoteGRPOConfig, ModelConfig))
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
main(script_args, training_args, model_args)
|
||||
5
setup.py
5
setup.py
|
|
@ -66,8 +66,8 @@ _deps = [
|
|||
"safetensors>=0.3.3",
|
||||
"sentencepiece>=0.1.99",
|
||||
"torch==2.5.1",
|
||||
"transformers==4.49.0",
|
||||
"trl @ git+https://github.com/huggingface/trl.git@69ad852e5654a77f1695eb4c608906fe0c7e8624",
|
||||
"transformers==4.48.3", # Must pin for SGLang
|
||||
"trl @ git+https://github.com/huggingface/trl.git@e3244d2d096ff1e2e248c931d06d39e165e20623",
|
||||
"vllm==0.7.2",
|
||||
"wandb>=0.19.1",
|
||||
]
|
||||
|
|
@ -107,6 +107,7 @@ install_requires = [
|
|||
deps["math-verify"],
|
||||
deps["liger_kernel"],
|
||||
deps["packaging"], # utilities from PyPA to e.g., compare versions
|
||||
deps["peft"],
|
||||
deps["safetensors"],
|
||||
deps["sentencepiece"],
|
||||
deps["transformers"],
|
||||
|
|
|
|||
23
slurm/launch_sglang.slurm
Normal file
23
slurm/launch_sglang.slurm
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
#!/bin/bash
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --gres=gpu:1
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --output=/fsx/open-r1/logs/%x-%j.out
|
||||
#SBATCH --err=/fsx/open-r1/logs/%x-%j.err
|
||||
|
||||
# Specific configuration optimized for the Hugging Face Compute Cluster
|
||||
# Be ye warned this may not work on other clusters!
|
||||
|
||||
set -x -e
|
||||
|
||||
source ~/.bashrc
|
||||
source openr1/bin/activate
|
||||
module load cuda/12.4
|
||||
echo Starting sglang server...
|
||||
|
||||
MODEL_ID=$1
|
||||
REVISION=$2
|
||||
PORT=$3
|
||||
|
||||
NUM_GPUS=$(nvidia-smi --list-gpus | wc -l)
|
||||
python3 -m sglang.launch_server --model-path $MODEL_ID --revision $REVISION --port=$PORT --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size=$NUM_GPUS
|
||||
93
slurm/train_remote.slurm
Normal file
93
slurm/train_remote.slurm
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
#!/bin/bash
|
||||
#SBATCH --job-name=open-r1-sft
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --gres=gpu:8
|
||||
#SBATCH --partition=hopper-prod # Adjust this for your cluster
|
||||
#SBATCH --output=./logs/%x-%j.out
|
||||
#SBATCH --err=./logs/%x-%j.err
|
||||
|
||||
"""Usage:
|
||||
|
||||
sbatch --job-name=remote-grpo --nodes=1 slurm/train_remote.slurm Qwen2.5-1.5B-Instruct grpo remote zero3
|
||||
"""
|
||||
|
||||
# Specific configuration optimized for the Hugging Face Compute Cluster
|
||||
# Be ye warned this may not work on other clusters!
|
||||
module load cuda/12.4
|
||||
|
||||
|
||||
set -x -e
|
||||
|
||||
source ~/.bashrc
|
||||
source openr1/bin/activate
|
||||
echo "START TIME: $(date)"
|
||||
|
||||
MODEL=$1
|
||||
TASK=$2
|
||||
CONFIG_SUFFIX=$3
|
||||
ACCELERATOR=$4
|
||||
OPTIONAL_ARGS=$5
|
||||
|
||||
# Training setup
|
||||
NUM_NODES=$SLURM_NNODES
|
||||
GPUS_PER_NODE=8
|
||||
WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE))
|
||||
# Due to conflicts between Accelerate's DeepSpeed configs and Transformers' TrainingArguments, we need to parse the gradient accumulation steps from the config file to ensure they match
|
||||
CONFIG_FILE=recipes/$MODEL/$TASK/config_$CONFIG_SUFFIX.yaml
|
||||
GRAD_ACC_STEPS=$(grep 'gradient_accumulation_steps' $CONFIG_FILE | awk '{print $2}')
|
||||
|
||||
# Split the string into individual arguments
|
||||
IFS=' ' read -ra ARGS <<< "$OPTIONAL_ARGS"
|
||||
|
||||
# Loop through the arguments and find the one with "--gradient_accumulation_steps"
|
||||
for arg in "${ARGS[@]}"; do
|
||||
if [[ "$arg" == "--gradient_accumulation_steps="* ]]; then
|
||||
# Extract the value after the equals sign
|
||||
GRAD_ACC_STEPS="${arg#*=}"
|
||||
break # Exit the loop once we find the desired argument
|
||||
fi
|
||||
done
|
||||
|
||||
echo "Gradient accumulation steps: $GRAD_ACC_STEPS"
|
||||
# so processes know who to talk to
|
||||
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
|
||||
MASTER_PORT=6000
|
||||
|
||||
export CMD=" \
|
||||
scripts/remote_grpo.py --config $CONFIG_FILE $OPTIONAL_ARGS
|
||||
"
|
||||
|
||||
export LAUNCHER="HF_HUB_ENABLE_HF_TRANSFER=1 ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \
|
||||
--config_file recipes/accelerate_configs/$ACCELERATOR.yaml \
|
||||
--gradient_accumulation_steps $GRAD_ACC_STEPS \
|
||||
--num_machines $NUM_NODES \
|
||||
--num_processes $WORLD_SIZE \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
--main_process_port $MASTER_PORT \
|
||||
--machine_rank \$SLURM_PROCID \
|
||||
--rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \
|
||||
--max_restarts 1 \
|
||||
--role \$(hostname -s): \
|
||||
--tee 3 \
|
||||
"
|
||||
|
||||
# force crashing on nccl issues like hanging broadcast
|
||||
export NCCL_ASYNC_ERROR_HANDLING=1
|
||||
# export NCCL_DEBUG=INFO
|
||||
# export NCCL_DEBUG_SUBSYS=COLL
|
||||
# export NCCL_SOCKET_NTHREADS=1
|
||||
# export NCCL_NSOCKS_PERTHREAD=1
|
||||
# export CUDA_LAUNCH_BLOCKING=1
|
||||
|
||||
# srun error handling:
|
||||
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
|
||||
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
|
||||
SRUN_ARGS=" \
|
||||
--wait=60 \
|
||||
--kill-on-bad-exit=1 \
|
||||
"
|
||||
|
||||
clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD" 2>&1
|
||||
|
||||
echo "END TIME: $(date)"
|
||||
|
|
@ -423,10 +423,12 @@ def run_async_from_sync(scripts: list[str], language: str) -> list[float]:
|
|||
|
||||
async def run_async(scripts: list[str], language: str) -> list[float]:
|
||||
# Create the sandbox by hand, currently there's no context manager for this version
|
||||
sbx = await AsyncSandbox.create(timeout=30, request_timeout=3)
|
||||
sbx = await AsyncSandbox.create(timeout=60, request_timeout=5)
|
||||
|
||||
# Create a list of tasks for running scripts concurrently
|
||||
tasks = [run_script(sbx, script, language) for script in scripts]
|
||||
MAX_TASKS_PER_PROCESS = 2 # E2B has a limit of 20 concurrent requests, assume 1 noe, 8 processes, this is 2 per process (20//8 = 2)
|
||||
semaphore = asyncio.Semaphore(MAX_TASKS_PER_PROCESS)
|
||||
tasks = [run_script(sbx, script, language, semaphore) for script in scripts]
|
||||
|
||||
# Wait for all tasks to complete and gather their results as they finish
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
|
@ -438,9 +440,10 @@ async def run_async(scripts: list[str], language: str) -> list[float]:
|
|||
return rewards
|
||||
|
||||
|
||||
async def run_script(sbx: AsyncSandbox, script: str, language: str) -> float:
|
||||
execution = await sbx.run_code(script, language=language)
|
||||
try:
|
||||
return float(execution.text)
|
||||
except (TypeError, ValueError):
|
||||
return 0.0
|
||||
async def run_script(sbx: AsyncSandbox, script: str, language: str, semaphore) -> float:
|
||||
async with semaphore: # Limit concurrency
|
||||
execution = await sbx.run_code(script, language=language)
|
||||
try:
|
||||
return float(execution.text)
|
||||
except (TypeError, ValueError):
|
||||
return 0.0
|
||||
|
|
|
|||
673
src/open_r1/trainers/faster_grpo_trainer.py
Normal file
673
src/open_r1/trainers/faster_grpo_trainer.py
Normal file
|
|
@ -0,0 +1,673 @@
|
|||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# great reference: https://github.com/vllm-project/vllm/issues/11400
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import reduction
|
||||
from typing import Callable, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
GenerationConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
is_wandb_available,
|
||||
)
|
||||
from transformers.integrations import get_reporting_integration_callbacks
|
||||
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
|
||||
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
|
||||
from transformers.utils import is_liger_kernel_available
|
||||
|
||||
import trl
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import gather_object
|
||||
from open_r1.trainers.job_launcher import SGLangSlurmJobLauncher
|
||||
from open_r1.trainers.remote_model import RemoteModel
|
||||
from trl.data_utils import is_conversational, maybe_apply_chat_template
|
||||
from trl.trainer.utils import pad, selective_log_softmax
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def profiling_context(instance, name):
|
||||
"""
|
||||
A context manager function for profiling a block of code.
|
||||
Can also be used as a decorator.
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
yield
|
||||
end_time = time.perf_counter()
|
||||
duration = end_time - start_time
|
||||
|
||||
if "wandb" in instance.args.report_to and wandb.run is not None and instance.accelerator.is_main_process:
|
||||
wandb.log({f"profiling/Time taken: {instance.__class__.__name__}.{name}": duration})
|
||||
|
||||
|
||||
def profiling_decorator(func):
|
||||
"""
|
||||
Decorator to profile a function and log execution time using profiling_context.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
with profiling_context(self, func.__name__):
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
|
||||
def exact_div(a, b, custom_error_message=""):
|
||||
q = a // b
|
||||
if a != q * b:
|
||||
raise ValueError(f"{custom_error_message}, inexact division: {a} / {b} = {a / b}")
|
||||
return q
|
||||
|
||||
|
||||
# TODO: add the shared options with a mixin to reduce code duplication
|
||||
@dataclass
|
||||
class FastGRPOConfig(trl.GRPOConfig):
|
||||
"""
|
||||
args for callbacks, benchmarks etc
|
||||
"""
|
||||
|
||||
benchmarks: list[str] = field(
|
||||
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
|
||||
)
|
||||
callbacks: list[str] = field(
|
||||
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
|
||||
)
|
||||
system_prompt: Optional[str] = field(
|
||||
default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
|
||||
)
|
||||
hub_model_revision: Optional[str] = field(
|
||||
default="main", metadata={"help": "The Hub model branch to push the model to."}
|
||||
)
|
||||
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
|
||||
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
|
||||
wandb_entity: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": ("The entity to store runs under.")},
|
||||
)
|
||||
wandb_project: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": ("The project to store runs under.")},
|
||||
)
|
||||
remote_gen_model_url: str = field(
|
||||
default="26.0.165.24",
|
||||
)
|
||||
remote_gen_model_port: str = field(
|
||||
default="30010",
|
||||
)
|
||||
remote_gen_model_n_gpus: str = field(
|
||||
default=8,
|
||||
)
|
||||
|
||||
|
||||
class FastGRPOTrainer(Trainer):
|
||||
_tag_names = ["trl", "fast_grpo"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str, # only accept str for now
|
||||
reward_funcs: Union[RewardFunc, list[RewardFunc]],
|
||||
args: FastGRPOConfig,
|
||||
train_dataset: Dataset,
|
||||
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
||||
data_collator: Optional[DataCollatorWithPadding] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
||||
) -> None:
|
||||
self.args = args
|
||||
self.reward_funcs = reward_funcs
|
||||
# Reward weights (move this logic to post_init of config?)
|
||||
if args.reward_weights is not None:
|
||||
if len(args.reward_weights) != len(reward_funcs):
|
||||
raise ValueError(
|
||||
f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
|
||||
f"functions ({len(reward_funcs)})"
|
||||
)
|
||||
self.reward_weights = args.reward_weights
|
||||
else:
|
||||
self.reward_weights = ([1.0] * len(reward_funcs),)
|
||||
|
||||
# start the remote model so it has time to warmup while we load the local model(s)
|
||||
if self.args.remote_gen_model_url is None:
|
||||
self.sglang_job_launcher = SGLangSlurmJobLauncher(
|
||||
model, num_gpus=self.args.remote_gen_model_n_gpus, sglang_port=self.args.remote_gen_model_port
|
||||
)
|
||||
self.sglang_job_launcher.submit_job()
|
||||
|
||||
# Trained model
|
||||
model_init_kwargs = args.model_init_kwargs or {}
|
||||
if isinstance(model, str):
|
||||
torch_dtype = model_init_kwargs.get("torch_dtype")
|
||||
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
||||
pass # torch_dtype is already a torch.dtype or "auto" or None
|
||||
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
model_init_kwargs["torch_dtype"] = torch_dtype
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
|
||||
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
||||
)
|
||||
# Disable caching if gradient checkpointing is enabled (not supported)
|
||||
model_init_kwargs["use_cache"] = (
|
||||
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
|
||||
)
|
||||
model_str = model
|
||||
model = AutoModelForCausalLM.from_pretrained(model_str, **model_init_kwargs)
|
||||
# offload to cpu
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(model_str, **model_init_kwargs) # .to("cpu")
|
||||
|
||||
self.model = model
|
||||
self.ref_model = ref_model
|
||||
if self.args.use_liger_kernel:
|
||||
if is_liger_kernel_available():
|
||||
from liger_kernel.transformers import _apply_liger_kernel_to_instance
|
||||
|
||||
_apply_liger_kernel_to_instance(model=self.model)
|
||||
_apply_liger_kernel_to_instance(model=self.ref_model)
|
||||
else:
|
||||
raise ImportError(
|
||||
"You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. "
|
||||
"Please install it with `pip install liger-kernel`"
|
||||
)
|
||||
# Processing class
|
||||
if processing_class is None:
|
||||
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
|
||||
self.processing_class = processing_class
|
||||
|
||||
self.train_dataset = train_dataset
|
||||
|
||||
if data_collator is not None:
|
||||
raise ValueError("")
|
||||
|
||||
def data_collator(features): # No data collation is needed in GRPO
|
||||
return features
|
||||
|
||||
self.data_collator = data_collator
|
||||
|
||||
local_dataloader_batch_size = exact_div(
|
||||
args.per_device_train_batch_size * args.gradient_accumulation_steps,
|
||||
args.num_generations,
|
||||
"per_device_train_batch_size * gradient_accumulation_steps must >= num_generations to remain on policy",
|
||||
)
|
||||
self.optimizer, self.lr_scheduler = optimizers
|
||||
self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
|
||||
self.accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
|
||||
|
||||
self.train_dataset_len = len(self.train_dataset)
|
||||
num_total_samples = int(self.args.num_train_epochs * self.train_dataset_len)
|
||||
self.total_steps_per_device = num_total_samples // (
|
||||
local_dataloader_batch_size * self.accelerator.num_processes
|
||||
)
|
||||
self.create_optimizer_and_scheduler(num_training_steps=self.total_steps_per_device)
|
||||
#########
|
||||
### trainer specifics
|
||||
#########
|
||||
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
||||
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
|
||||
self.callback_handler = CallbackHandler(
|
||||
self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
|
||||
)
|
||||
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
|
||||
self.control = TrainerControl()
|
||||
self.state = TrainerState(
|
||||
is_local_process_zero=self.is_local_process_zero(),
|
||||
is_world_process_zero=self.is_world_process_zero(),
|
||||
stateful_callbacks=[
|
||||
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
|
||||
],
|
||||
)
|
||||
|
||||
self.current_flos = 0
|
||||
self.hp_search_backend = None
|
||||
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
||||
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
||||
# Create distant repo and output directory if needed
|
||||
self.hub_model_id = None
|
||||
if self.args.push_to_hub:
|
||||
self.init_hf_repo()
|
||||
if self.args.should_save:
|
||||
os.makedirs(self.args.output_dir, exist_ok=True)
|
||||
self.backup_model = None
|
||||
|
||||
# Add tags for models that have been loaded with the correct transformers version
|
||||
if hasattr(self.model, "add_model_tags"):
|
||||
self.model.add_model_tags(self._tag_names)
|
||||
|
||||
#########
|
||||
### setup dataloader
|
||||
#########
|
||||
self.dataloader = DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=local_dataloader_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=self.data_collator,
|
||||
drop_last=True,
|
||||
)
|
||||
torch.manual_seed(args.seed)
|
||||
# Enable gradient checkpointing if requested
|
||||
if args.gradient_checkpointing:
|
||||
self.model = self._enable_gradient_checkpointing(self.model, self.args)
|
||||
self.model, self.optimizer, self.dataloader = self.accelerator.prepare(
|
||||
self.model, self.optimizer, self.dataloader
|
||||
)
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
# connect to a remote sglang model
|
||||
if self.args.remote_gen_model_url is None:
|
||||
self.sglang_job_launcher.wait_for_server()
|
||||
self.args.remote_gen_model_url = self.sglang_job_launcher.get_remote_ip()
|
||||
self.remote_model = RemoteModel(
|
||||
self.args.remote_gen_model_url, self.args.remote_gen_model_port, self.processing_class.eos_token_id
|
||||
)
|
||||
|
||||
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: FastGRPOConfig) -> PreTrainedModel:
|
||||
"""Enables gradient checkpointing for the model."""
|
||||
# Ensure use_cache is disabled
|
||||
model.config.use_cache = False
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
||||
use_reentrant = (
|
||||
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
|
||||
)
|
||||
if use_reentrant:
|
||||
model.enable_input_require_grads()
|
||||
|
||||
return model
|
||||
|
||||
def print_gpu_memory_usage(self):
|
||||
if torch.cuda.is_available():
|
||||
gpu_memory_allocated = torch.cuda.memory_allocated()
|
||||
gpu_memory_reserved = torch.cuda.memory_reserved()
|
||||
print(f"GPU memory allocated: {gpu_memory_allocated / (1024**3):.2f} GB")
|
||||
print(f"GPU memory reserved: {gpu_memory_reserved / (1024**3):.2f} GB")
|
||||
else:
|
||||
print("CUDA is not available.")
|
||||
|
||||
# Get the per-token log probabilities for the completions for the model and the reference model
|
||||
@profiling_decorator
|
||||
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
|
||||
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
|
||||
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
|
||||
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
||||
|
||||
input_ids = input_ids[:, -logits_to_keep:]
|
||||
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
|
||||
# See https://github.com/huggingface/trl/issues/2770
|
||||
logits = logits[:, -logits_to_keep:]
|
||||
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
|
||||
|
||||
@torch.no_grad()
|
||||
@profiling_decorator
|
||||
def _prepare_batch(self, batch):
|
||||
"""
|
||||
This will:
|
||||
- generate k samples for each problem
|
||||
- using internal reward model(s) to get rewards
|
||||
"""
|
||||
device = self.accelerator.device
|
||||
prompts = [x["prompt"] for x in batch]
|
||||
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in batch]
|
||||
prompt_inputs = self.processing_class(prompts_text)
|
||||
|
||||
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
||||
# add cuda clear cache here and a sleep
|
||||
|
||||
all_outputs = self.remote_model.generate(
|
||||
prompt_ids,
|
||||
max_new_tokens=self.args.max_completion_length,
|
||||
temperature=self.args.temperature,
|
||||
num_generations=self.args.num_generations,
|
||||
)
|
||||
|
||||
# all_outputs = self.gen_vllm.generate(prompts_text, sampling_params=self.sampling_params, use_tqdm=True)
|
||||
|
||||
completion_ids = [example["completion_ids"] for example in all_outputs]
|
||||
|
||||
# completion_ids = []
|
||||
# for outputs in all_outputs:
|
||||
# for output in outputs.outputs:
|
||||
# completion_ids.append(output.token_ids)
|
||||
|
||||
# Decode the generated completions
|
||||
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
||||
|
||||
repeated_prompts = []
|
||||
for prompt in prompts:
|
||||
repeated_prompts.extend([prompt] * self.args.num_generations)
|
||||
|
||||
repeated_prompt_texts = []
|
||||
for prompt in prompts_text:
|
||||
repeated_prompt_texts.extend([prompt] * self.args.num_generations)
|
||||
|
||||
if is_conversational(batch[0]):
|
||||
completions = []
|
||||
for prompt, completion in zip(repeated_prompts, completions_text, strict=True):
|
||||
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
|
||||
completions.append([{"role": "assistant", "content": bootstrap + completion}])
|
||||
else:
|
||||
completions = completions_text
|
||||
|
||||
rewards = torch.zeros(len(repeated_prompts), len(self.reward_funcs))
|
||||
for (
|
||||
i,
|
||||
reward_func,
|
||||
) in enumerate(self.reward_funcs):
|
||||
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
||||
keys = [key for key in batch[0] if key not in ["prompt", "completion"]]
|
||||
reward_kwargs = defaultdict(list)
|
||||
for example in batch:
|
||||
for key in keys:
|
||||
reward_kwargs[key].extend([example[key]] * self.args.num_generations)
|
||||
output_reward_func = reward_func(prompts=repeated_prompts, completions=completions, **reward_kwargs)
|
||||
rewards[:, i] = torch.tensor(output_reward_func, dtype=torch.float32) * self.reward_weights[i]
|
||||
|
||||
# calculate the advantages, the prompt is all on the same device to no need to gather here
|
||||
grouped_rewards = rewards.sum(-1).view(len(prompts), self.args.num_generations)
|
||||
EPS = 1e-4
|
||||
grouped_advantages = (grouped_rewards - grouped_rewards.mean(-1, keepdim=True)) / (
|
||||
grouped_rewards.std(-1, keepdim=True) + EPS
|
||||
)
|
||||
advantages = grouped_advantages.flatten().tolist()
|
||||
|
||||
# build batch as list of dicts
|
||||
examples = []
|
||||
for i, prompt in enumerate(repeated_prompt_texts):
|
||||
example = {
|
||||
"prompt": prompt,
|
||||
"prompt_ids": prompt_ids[i // self.args.num_generations],
|
||||
"completion": completions_text[i],
|
||||
"completion_ids": completion_ids[i],
|
||||
"advantages": advantages[i],
|
||||
"rewards": rewards[i],
|
||||
}
|
||||
examples.append(example)
|
||||
|
||||
return examples
|
||||
|
||||
@profiling_decorator
|
||||
def _sync_weights(self):
|
||||
self.accelerator.wait_for_everyone()
|
||||
if self.accelerator.is_main_process:
|
||||
start = time.time()
|
||||
with tempfile.TemporaryDirectory(dir="/fsx/edward/work/open-r1/data/") as temp_dir_path:
|
||||
unwrapped_model = self.accelerator.unwrap_model(self.model)
|
||||
unwrapped_model.save_pretrained(temp_dir_path)
|
||||
self.remote_model.load_weights_from_path(temp_dir_path)
|
||||
print("weight sync took: ", time.time() - start)
|
||||
self.accelerator.wait_for_everyone()
|
||||
|
||||
def train(
|
||||
self,
|
||||
resume_from_checkpoint: Optional[Union[str, bool]] = None,
|
||||
):
|
||||
start_step = 1 # todo, set this when we resume + load model, opt state etc
|
||||
|
||||
if self.args.logging_steps is not None:
|
||||
if self.args.logging_steps < 1:
|
||||
self.state.logging_steps = math.ceil(self.state.max_steps * self.args.logging_steps)
|
||||
else:
|
||||
self.state.logging_steps = self.args.logging_steps
|
||||
|
||||
if self.args.save_steps is not None:
|
||||
if self.args.save_steps < 1:
|
||||
self.state.save_steps = math.ceil(self.state.max_steps * self.args.save_steps)
|
||||
else:
|
||||
self.state.save_steps = self.args.save_steps
|
||||
|
||||
self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)
|
||||
self.state.global_step = 0
|
||||
self.state.max_steps = self.total_steps_per_device
|
||||
self.state.num_train_epochs = self.args.num_train_epochs
|
||||
|
||||
def repeat_generator():
|
||||
while True:
|
||||
yield from self.dataloader
|
||||
|
||||
iter_dataloader = iter(repeat_generator())
|
||||
|
||||
self.model.train()
|
||||
|
||||
@torch.no_grad()
|
||||
def mini_batch_collator(examples):
|
||||
device = self.accelerator.device
|
||||
|
||||
prompt_ids = [torch.LongTensor(example["prompt_ids"]) for example in examples]
|
||||
completion_ids = [torch.LongTensor(example["completion_ids"]) for example in examples]
|
||||
ref_per_token_logps = [torch.Tensor(example["ref_per_token_logps"]) for example in examples]
|
||||
|
||||
for logps, completion_id in zip(ref_per_token_logps, completion_ids):
|
||||
assert len(logps) == len(completion_id), (
|
||||
f"len(logps)={len(logps)} != len(completion_id)={len(completion_id)}"
|
||||
)
|
||||
|
||||
pad_token_id = self.processing_class.pad_token_id
|
||||
|
||||
padded_prompt_ids = pad(prompt_ids, padding_value=pad_token_id, padding_side="left")
|
||||
padded_completion_ids = pad(completion_ids, padding_value=pad_token_id, padding_side="right")
|
||||
padd_ref_per_token_logps = pad(ref_per_token_logps, padding_value=0.0, padding_side="right")
|
||||
|
||||
if self.args.max_prompt_length is not None:
|
||||
padded_prompt_ids = padded_prompt_ids[:, -self.args.max_prompt_length :]
|
||||
|
||||
# compute the masks
|
||||
prompt_mask = (padded_prompt_ids != pad_token_id).long()
|
||||
|
||||
# Mask everything after the first EOS token
|
||||
is_eos = padded_completion_ids == self.processing_class.eos_token_id
|
||||
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
|
||||
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
||||
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
|
||||
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
||||
|
||||
advantages = torch.Tensor([example["advantages"] for example in examples])
|
||||
|
||||
return {
|
||||
"prompt_ids": padded_prompt_ids.to(device),
|
||||
"prompt_mask": prompt_mask.to(device),
|
||||
"completion_ids": padded_completion_ids.to(device),
|
||||
"completion_mask": completion_mask.to(device),
|
||||
"advantages": advantages.to(device),
|
||||
"ref_per_token_logps": padd_ref_per_token_logps.to(device),
|
||||
}
|
||||
|
||||
device = self.accelerator.device
|
||||
for step in range(start_step, self.total_steps_per_device + 1):
|
||||
batch = next(iter_dataloader)
|
||||
batch = self._prepare_batch(batch)
|
||||
|
||||
# TODO: log completions, rewards, etc
|
||||
gen_dataset = Dataset.from_list(batch)
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_ref_logps(examples):
|
||||
device = self.accelerator.device
|
||||
prompt_ids = [torch.LongTensor(prompt_id) for prompt_id in examples["prompt_ids"]]
|
||||
completion_ids = [torch.LongTensor(completion_id) for completion_id in examples["completion_ids"]]
|
||||
completion_lengths = [len(c) for c in completion_ids]
|
||||
pad_token_id = self.processing_class.pad_token_id
|
||||
padded_prompt_ids = pad(prompt_ids, padding_value=pad_token_id, padding_side="left")
|
||||
padded_completion_ids = pad(completion_ids, padding_value=pad_token_id, padding_side="right")
|
||||
|
||||
input_ids = torch.cat([padded_prompt_ids, padded_completion_ids], dim=1)
|
||||
attention_mask = torch.cat(
|
||||
[padded_prompt_ids != pad_token_id, padded_completion_ids != pad_token_id], dim=1
|
||||
)
|
||||
logits_to_keep = torch.tensor(completion_lengths).to(device)
|
||||
logits_to_keep = padded_completion_ids.size(1)
|
||||
with torch.inference_mode():
|
||||
ref_per_token_logps = self._get_per_token_logps(
|
||||
self.ref_model, input_ids.to(device), attention_mask.to(device), logits_to_keep
|
||||
)
|
||||
ref_per_token_logps = ref_per_token_logps.to("cpu")
|
||||
examples["ref_per_token_logps"] = [
|
||||
logprobs[:length] for logprobs, length in zip(ref_per_token_logps, completion_lengths)
|
||||
]
|
||||
|
||||
return examples
|
||||
|
||||
self.ref_model = self.ref_model.to(device)
|
||||
# precompute the ref logprobs and offload the model to cpu
|
||||
gen_dataset = gen_dataset.map(
|
||||
compute_ref_logps, batched=True, batch_size=self.args.per_device_train_batch_size
|
||||
)
|
||||
self.ref_model = self.ref_model.to("cpu")
|
||||
|
||||
# we could add some optimizations here like sorting the dataset by length to improve throughput, but we will keep it simple for now
|
||||
mini_batch_dataloader = DataLoader(
|
||||
gen_dataset,
|
||||
batch_size=self.args.per_device_train_batch_size,
|
||||
shuffle=True, # we technically don#t need to shuffle due to grad acc, but we may move to clipped loss later
|
||||
drop_last=True,
|
||||
collate_fn=mini_batch_collator,
|
||||
)
|
||||
# optimization
|
||||
# stats for logging
|
||||
losses = []
|
||||
kls = []
|
||||
|
||||
with profiling_context(self, "train_step"):
|
||||
for mini_batch in mini_batch_dataloader:
|
||||
loss_metric, kl_metric = self._optimization_step(mini_batch)
|
||||
losses.append(loss_metric)
|
||||
kls.append(kl_metric)
|
||||
|
||||
self.lr_scheduler.step()
|
||||
self.state.global_step += 1
|
||||
self.state.epoch = step / self.total_steps_per_device # TODO, this is not correct
|
||||
|
||||
# logging stats
|
||||
metrics = {}
|
||||
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
|
||||
metrics["loss"] = self.accelerator.gather_for_metrics(torch.Tensor(losses).to(device)).mean().item()
|
||||
metrics["kl"] = self.accelerator.gather_for_metrics(torch.Tensor(kls).to(device)).mean().item()
|
||||
|
||||
# completions stats
|
||||
completion_lengths = [len(c) for c in gen_dataset["completion_ids"]]
|
||||
gathered_completion_lengths = self.accelerator.gather_for_metrics(
|
||||
torch.Tensor(completion_lengths).to(device)
|
||||
)
|
||||
metrics["mean_completion_lengths"] = gathered_completion_lengths.mean().item()
|
||||
metrics["max_completion_lengths"] = gathered_completion_lengths.max().item()
|
||||
metrics["min_completion_lengths"] = gathered_completion_lengths.min().item()
|
||||
|
||||
# reward stats
|
||||
rewards = gen_dataset["rewards"]
|
||||
gathered_rewards = self.accelerator.gather_for_metrics(torch.Tensor(rewards).to(device))
|
||||
reward_per_func = gathered_rewards.mean(0)
|
||||
for i, reward_func in enumerate(self.reward_funcs):
|
||||
reward_func_name = reward_func.__name__
|
||||
metrics[f"rewards/{reward_func_name}"] = reward_per_func[i].item()
|
||||
|
||||
metrics["reward"] = reward_per_func.sum().item()
|
||||
|
||||
self.log(metrics)
|
||||
if self.args.log_completions and "wandb" in self.args.report_to:
|
||||
import pandas as pd
|
||||
|
||||
prompts = gather_object(gen_dataset["prompt"])
|
||||
completions = gather_object(gen_dataset["completion"])
|
||||
# For logging
|
||||
table = {
|
||||
"step": [str(self.state.global_step)] * len(prompts),
|
||||
"prompts": prompts,
|
||||
"completion": completions,
|
||||
"reward": gathered_rewards.sum(1).tolist(),
|
||||
}
|
||||
df = pd.DataFrame(table)
|
||||
|
||||
if wandb.run is not None and self.accelerator.is_main_process:
|
||||
wandb.log({"completions": wandb.Table(dataframe=df)})
|
||||
|
||||
# sync weights to remote server
|
||||
self._sync_weights()
|
||||
|
||||
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
|
||||
if self.control.should_save:
|
||||
self._save_checkpoint(self.model, trial=None)
|
||||
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
||||
|
||||
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||
if self.control.should_save:
|
||||
self._save_checkpoint(self.model, trial=None, metrics=None)
|
||||
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
||||
|
||||
def _optimization_step(self, mini_batch) -> tuple[float, float]:
|
||||
prompt_ids, prompt_mask = mini_batch["prompt_ids"], mini_batch["prompt_mask"]
|
||||
completion_ids, completion_mask = mini_batch["completion_ids"], mini_batch["completion_mask"]
|
||||
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
||||
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
||||
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
||||
|
||||
ref_per_token_logps = mini_batch["ref_per_token_logps"]
|
||||
|
||||
with self.accelerator.accumulate(self.model):
|
||||
per_token_logps = self._get_per_token_logps(self.model, input_ids, attention_mask, logits_to_keep)
|
||||
per_token_kl = (
|
||||
torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
|
||||
)
|
||||
|
||||
advantages = mini_batch["advantages"]
|
||||
# TODO: convert to clipped loss so we can multiple GRPO epochs
|
||||
per_token_loss = -torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
|
||||
per_token_loss = per_token_loss + self.args.beta * per_token_kl
|
||||
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
|
||||
|
||||
self.accelerator.backward(loss)
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
del per_token_logps, per_token_kl, per_token_loss, loss
|
||||
|
||||
# force garbage collection and empty cache
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return loss.detach().item(), per_token_kl.mean().item()
|
||||
171
src/open_r1/trainers/job_launcher.py
Normal file
171
src/open_r1/trainers/job_launcher.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
import atexit
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
|
||||
# We need a special environment setup to launch vLLM from within Slurm training jobs.
|
||||
# - Reference code: https://github.com/huggingface/brrr/blob/c55ba3505686d690de24c7ace6487a5c1426c0fd/brrr/lighteval/one_job_runner.py#L105
|
||||
# - Slack thread: https://huggingface.slack.com/archives/C043JTYE1MJ/p1726566494958269
|
||||
user_home_directory = os.path.expanduser("~")
|
||||
SLURM_PREFIX = [
|
||||
"env",
|
||||
"-i",
|
||||
"bash",
|
||||
"-c",
|
||||
f"for f in /etc/profile.d/*.sh; do source $f; done; export HOME={user_home_directory}; sbatch --qos=high --output=/fsx/h4/logs/%x-%j.out --err=/fsx/h4/logs/%x-%j.err ",
|
||||
]
|
||||
|
||||
|
||||
class SGLangSlurmJobLauncher:
|
||||
def __init__(
|
||||
self,
|
||||
model_id_or_path,
|
||||
model_revision="main",
|
||||
num_gpus=1,
|
||||
sglang_port=30010,
|
||||
slurm_script="slurm/launch_sglang.slurm",
|
||||
check_interval=5,
|
||||
):
|
||||
"""
|
||||
Initialize the job launcher.
|
||||
|
||||
:param slurm_script: Path to the SLURM script.
|
||||
:param check_interval: Time interval (seconds) to check job status.
|
||||
"""
|
||||
self.slurm_script = slurm_script
|
||||
self.job_id = None
|
||||
self.node_name = None
|
||||
self.check_interval = check_interval
|
||||
self.model_id_or_path = model_id_or_path
|
||||
self.model_revision = model_revision
|
||||
self.num_gpus = num_gpus
|
||||
self.sglang_port = sglang_port
|
||||
|
||||
# Register cleanup function to cancel job on exit
|
||||
atexit.register(self.cleanup)
|
||||
|
||||
def submit_job(self):
|
||||
"""Submits the SLURM job and extracts the job ID."""
|
||||
|
||||
cmd = SLURM_PREFIX.copy()
|
||||
cmd_args = [
|
||||
f"--gres=gpu:{self.num_gpus}",
|
||||
self.slurm_script,
|
||||
self.model_id_or_path,
|
||||
self.model_revision,
|
||||
str(self.sglang_port),
|
||||
]
|
||||
cmd[-1] += " " + " ".join(cmd_args)
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
match = re.search(r"Submitted batch job (\d+)", result.stdout)
|
||||
if match:
|
||||
self.job_id = match.group(1)
|
||||
print(f"Job submitted with ID: {self.job_id}")
|
||||
else:
|
||||
raise RuntimeError("Failed to retrieve job ID from sbatch output.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error submitting job: {e.stderr}")
|
||||
raise
|
||||
|
||||
|
||||
def get_job_status(self):
|
||||
"""Checks the job status using squeue."""
|
||||
if not self.job_id:
|
||||
raise ValueError("Job ID is not set. Submit the job first.")
|
||||
|
||||
result = subprocess.run(
|
||||
["squeue", "--job", self.job_id, "--noheader"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
|
||||
if not result.stdout.strip():
|
||||
return None # Job is no longer in queue
|
||||
status = result.stdout.split()[4] # Typically, state is the 5th column
|
||||
return status
|
||||
|
||||
def wait_for_job_to_start(self):
|
||||
"""Waits for the job to start running and fetches its node."""
|
||||
print("Waiting for job to start...")
|
||||
while True:
|
||||
status = self.get_job_status()
|
||||
if status is None:
|
||||
raise RuntimeError("Job disappeared from queue, it may have failed.")
|
||||
if status == "R": # Running
|
||||
print("Job is running. Fetching node information...")
|
||||
self.node_name = self.get_node_name()
|
||||
return
|
||||
time.sleep(self.check_interval)
|
||||
|
||||
def get_node_name(self):
|
||||
"""Gets the node where the job is running."""
|
||||
result = subprocess.run(
|
||||
["squeue", "--job", self.job_id, "--noheader", "--format=%N"],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
|
||||
if result.stdout.strip():
|
||||
return result.stdout.strip()
|
||||
else:
|
||||
raise RuntimeError("Failed to retrieve node name.")
|
||||
|
||||
def get_node_ip(self):
|
||||
"""Retrieves the IP address of the node running the job."""
|
||||
if not self.node_name:
|
||||
raise ValueError("Node name is not set. Wait for the job to start first.")
|
||||
|
||||
result = subprocess.run(
|
||||
["scontrol", "show", "node", self.node_name], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
|
||||
match = re.search(r"NodeAddr=(\S+)", result.stdout)
|
||||
if match:
|
||||
return match.group(1)
|
||||
else:
|
||||
raise RuntimeError("Failed to retrieve node IP address.")
|
||||
|
||||
def launch(self):
|
||||
"""Launches the job, waits for it to start, and retrieves the node IP."""
|
||||
self.submit_job()
|
||||
self.wait_for_job_to_start()
|
||||
ip_address = self.get_node_ip()
|
||||
print(f"Job is running on {self.node_name} with IP: {ip_address}")
|
||||
self.ip_address = ip_address
|
||||
return ip_address
|
||||
|
||||
def cleanup(self):
|
||||
"""Cancels the SLURM job if it is still running."""
|
||||
if self.job_id is not None:
|
||||
print(f"Cleaning up: Cancelling job {self.job_id}...")
|
||||
subprocess.run(["scancel", self.job_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
print("Job cancelled.")
|
||||
|
||||
def __del__(self):
|
||||
"""Ensure job cleanup when the instance is destroyed."""
|
||||
self.cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from open_r1.trainers.remote_model import RemoteModel
|
||||
|
||||
launcher = SGLangSlurmJobLauncher("HuggingFaceTB/SmolLM2-135M-Instruct")
|
||||
ip_address = launcher.launch()
|
||||
launcher.ip_address
|
||||
time.sleep(15)
|
||||
remote_model = RemoteModel(f"{ip_address}", 30010)
|
||||
remote_model.wait_for_server()
|
||||
|
||||
result = remote_model.generate([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
|
||||
|
||||
assert 0
|
||||
|
||||
print(result)
|
||||
697
src/open_r1/trainers/remote_grpo_trainer.py
Normal file
697
src/open_r1/trainers/remote_grpo_trainer.py
Normal file
|
|
@ -0,0 +1,697 @@
|
|||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import tempfile
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Iterator, Optional, Union
|
||||
from trl.trainer.utils import disable_dropout_in_model
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import Dataset, IterableDataset, disable_progress_bars, enable_progress_bars
|
||||
from datasets.utils.logging import set_verbosity_error, set_verbosity_info
|
||||
from packaging import version
|
||||
from torch.utils.data import DataLoader, Sampler
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
is_wandb_available,
|
||||
)
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.utils import is_liger_kernel_available
|
||||
import deepspeed
|
||||
import trl
|
||||
from accelerate.utils import broadcast_object_list, gather_object, is_peft_model
|
||||
from open_r1.trainers.job_launcher import SGLangSlurmJobLauncher
|
||||
from open_r1.trainers.remote_model import RemoteModel
|
||||
from trl.data_utils import is_conversational, maybe_apply_chat_template
|
||||
from trl.extras.profiling import profiling_context, profiling_decorator
|
||||
from trl.import_utils import is_rich_available
|
||||
from trl.models import create_reference_model, prepare_deepspeed
|
||||
from trl.trainer.callbacks import SyncRefModelCallback
|
||||
from trl.trainer.utils import exact_div, pad, print_prompt_completions_sample, selective_log_softmax
|
||||
|
||||
|
||||
if is_liger_kernel_available():
|
||||
from liger_kernel.transformers import AutoLigerKernelForCausalLM
|
||||
|
||||
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
|
||||
class RepeatBatchRandomSampler(Sampler):
|
||||
def __init__(
|
||||
self,
|
||||
data_source,
|
||||
batch_size: int = 1,
|
||||
num_processes: int = 1,
|
||||
repeat_count: int = 1,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
self.data_source = data_source
|
||||
self.batch_size = batch_size
|
||||
self.num_processes = num_processes
|
||||
self.repeat_count = repeat_count
|
||||
self.num_samples = len(data_source)
|
||||
self.seed = seed
|
||||
self.generator = torch.Generator() # Create a local random generator
|
||||
if seed is not None:
|
||||
self.generator.manual_seed(seed)
|
||||
|
||||
def __iter__(self):
|
||||
indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
|
||||
all_process_batch_size = self.batch_size * self.num_processes
|
||||
indices = [indices[i : i + all_process_batch_size] for i in range(0, len(indices), all_process_batch_size)]
|
||||
|
||||
indices = [chunk for chunk in indices if len(chunk) == all_process_batch_size]
|
||||
|
||||
for chunk in indices:
|
||||
for _ in range(self.repeat_count):
|
||||
for index in chunk:
|
||||
yield index
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples * self.repeat_count
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemoteGRPOConfig(trl.GRPOConfig):
|
||||
"""
|
||||
args for callbacks, benchmarks etc
|
||||
"""
|
||||
|
||||
benchmarks: list[str] = field(
|
||||
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
|
||||
)
|
||||
callbacks: list[str] = field(
|
||||
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
|
||||
)
|
||||
chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
|
||||
checkpoint_dir: Optional[str] = field(
|
||||
default="/fsx/h4/tmp/", metadata={"help": "The directory to save temporary checkpoints to."}
|
||||
)
|
||||
system_prompt: Optional[str] = field(
|
||||
default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
|
||||
)
|
||||
hub_model_revision: Optional[str] = field(
|
||||
default="main", metadata={"help": "The Hub model branch to push the model to."}
|
||||
)
|
||||
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
|
||||
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
|
||||
wandb_entity: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": ("The entity to store runs under.")},
|
||||
)
|
||||
wandb_project: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": ("The project to store runs under.")},
|
||||
)
|
||||
remote_gen_model_url: Optional[str] = field(
|
||||
default=None,
|
||||
)
|
||||
remote_gen_model_port: str = field(
|
||||
default="30010",
|
||||
)
|
||||
remote_gen_model_n_gpus: str = field(
|
||||
default=8,
|
||||
)
|
||||
use_liger: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use Liger kernel for training."},
|
||||
)
|
||||
|
||||
|
||||
class RemoteGRPOTrainer(Trainer):
|
||||
_tag_names = ["trl", "grpo"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[str, PreTrainedModel],
|
||||
reward_funcs: Union[RewardFunc, list[RewardFunc]],
|
||||
args: Optional[RemoteGRPOConfig] = None,
|
||||
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
||||
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
|
||||
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
||||
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
||||
):
|
||||
self.args = args
|
||||
# Initialize the metrics
|
||||
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
|
||||
self.log_completions = args.log_completions
|
||||
|
||||
# Models
|
||||
# Trained model
|
||||
model_init_kwargs = args.model_init_kwargs or {}
|
||||
if isinstance(model, str):
|
||||
model_id = model
|
||||
model = self._create_model_from_path(model_id, args)
|
||||
disable_dropout_in_model(model)
|
||||
else:
|
||||
model_id = model.config._name_or_path
|
||||
if args.model_init_kwargs is not None:
|
||||
raise ValueError(
|
||||
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
|
||||
"This argument can only be used when the `model` argument is a string."
|
||||
)
|
||||
|
||||
# Enable gradient checkpointing if requested
|
||||
if args.gradient_checkpointing:
|
||||
model = self._enable_gradient_checkpointing(model, args)
|
||||
|
||||
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
||||
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
||||
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
|
||||
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
|
||||
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
|
||||
# This acts as a flag to indicate that the warning has already been issued.
|
||||
model.warnings_issued["estimate_tokens"] = True
|
||||
|
||||
# Reference model
|
||||
if self.args.beta == 0.0:
|
||||
# If beta is 0.0, the reference model is not needed
|
||||
self.ref_model = None
|
||||
elif is_deepspeed_zero3_enabled():
|
||||
self.ref_model = self._create_model_from_path(model_id, args)
|
||||
disable_dropout_in_model(self.ref_model)
|
||||
elif is_peft_model(model):
|
||||
raise NotImplementedError("Peft is not supported")
|
||||
else:
|
||||
# If PEFT configuration is not provided, create a reference model based on the initial model.
|
||||
self.ref_model = create_reference_model(model)
|
||||
|
||||
# Reward functions
|
||||
if not isinstance(reward_funcs, list):
|
||||
reward_funcs = [reward_funcs]
|
||||
self.reward_funcs = reward_funcs
|
||||
|
||||
# Reward weights
|
||||
if args.reward_weights is not None:
|
||||
if len(args.reward_weights) != len(reward_funcs):
|
||||
raise ValueError(
|
||||
f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
|
||||
f"functions ({len(reward_funcs)})"
|
||||
)
|
||||
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
|
||||
else:
|
||||
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
|
||||
|
||||
# Reward processing class
|
||||
if reward_processing_classes is None:
|
||||
reward_processing_classes = [None] * len(reward_funcs)
|
||||
elif not isinstance(reward_processing_classes, list):
|
||||
reward_processing_classes = [reward_processing_classes]
|
||||
else:
|
||||
if len(reward_processing_classes) != len(reward_funcs):
|
||||
raise ValueError("The number of reward processing classes must match the number of reward functions.")
|
||||
|
||||
# TODO: test RMS and also wrap them in deepspeed
|
||||
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
if reward_processing_class is None:
|
||||
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
|
||||
if reward_processing_class.pad_token_id is None:
|
||||
reward_processing_class.pad_token = reward_processing_class.eos_token
|
||||
# The reward model computes the reward for the latest non-padded token in the input sequence.
|
||||
# So it's important to set the pad token ID to the padding token ID of the processing class.
|
||||
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
|
||||
reward_processing_classes[i] = reward_processing_class
|
||||
self.reward_processing_classes = reward_processing_classes
|
||||
|
||||
def data_collator(features): # No data collation is needed in GRPO
|
||||
return features
|
||||
|
||||
self.batch_buffer = []
|
||||
|
||||
super().__init__(
|
||||
model,
|
||||
args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
ip_address = self.args.remote_gen_model_url
|
||||
|
||||
if self.args.remote_gen_model_url is None and self.accelerator.is_main_process:
|
||||
# we launch a job from here, get the ip on main process and broadcast to others
|
||||
# it would be better to move this to the start so the server warms up which the local model is being loaded
|
||||
model_revision = args.model_init_kwargs.get("revision", "main")
|
||||
self.sglang_job_launcher = SGLangSlurmJobLauncher(
|
||||
model_id,
|
||||
model_revision,
|
||||
num_gpus=self.args.remote_gen_model_n_gpus,
|
||||
sglang_port=self.args.remote_gen_model_port,
|
||||
)
|
||||
ip_address = self.sglang_job_launcher.launch()
|
||||
|
||||
# get the ip from main process and broadcast to others
|
||||
gather_ip_address = broadcast_object_list([ip_address], 0)
|
||||
self.args.remote_gen_model_url = gather_ip_address[0]
|
||||
|
||||
self.remote_model = RemoteModel(
|
||||
self.args.remote_gen_model_url, self.args.remote_gen_model_port, self.processing_class.eos_token_id
|
||||
)
|
||||
self.remote_model.wait_for_server()
|
||||
|
||||
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
||||
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
||||
# self.model_accepts_loss_kwargs to False to enable scaling.
|
||||
self.model_accepts_loss_kwargs = False
|
||||
|
||||
# Add tags to the model
|
||||
self.model.add_model_tags(self._tag_names)
|
||||
|
||||
if self.ref_model is not None:
|
||||
if self.is_deepspeed_enabled:
|
||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
if args.sync_ref_model:
|
||||
self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
|
||||
|
||||
def _get_train_sampler(self) -> Sampler:
|
||||
"""
|
||||
Return the train sampler.
|
||||
|
||||
Returns:
|
||||
Sampler: The train sampler.
|
||||
"""
|
||||
if self.args.dataloader_num_workers != 0:
|
||||
raise ValueError("dataloader_num_workers should not be greater than 0 for remote training")
|
||||
return RepeatBatchRandomSampler(
|
||||
data_source=self.train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
repeat_count=self.args.num_generations * self.args.num_iterations,
|
||||
num_processes=self.accelerator.num_processes,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
|
||||
def _create_model_from_path(self, model_path: str, args: RemoteGRPOConfig) -> PreTrainedModel:
|
||||
"""Creates a model from a path or model identifier."""
|
||||
model_init_kwargs = args.model_init_kwargs or {}
|
||||
# Handle torch dtype
|
||||
torch_dtype = model_init_kwargs.get("torch_dtype")
|
||||
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
||||
pass # torch_dtype is already a torch.dtype or "auto" or None
|
||||
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
model_init_kwargs["torch_dtype"] = torch_dtype
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
|
||||
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
||||
)
|
||||
# Disable caching if gradient checkpointing is enabled (not supported)
|
||||
if args.gradient_checkpointing:
|
||||
model_init_kwargs["use_cache"] = False
|
||||
|
||||
# Create model
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
||||
return model
|
||||
|
||||
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: RemoteGRPOConfig) -> PreTrainedModel:
|
||||
"""Enables gradient checkpointing for the model."""
|
||||
# Ensure use_cache is disabled
|
||||
model.config.use_cache = False
|
||||
|
||||
# Enable gradient checkpointing on the base model for PEFT
|
||||
if is_peft_model(model):
|
||||
model.base_model.gradient_checkpointing_enable()
|
||||
# Enable gradient checkpointing for non-PEFT models
|
||||
else:
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)
|
||||
|
||||
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
||||
use_reentrant = (
|
||||
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
|
||||
)
|
||||
|
||||
if use_reentrant:
|
||||
model.enable_input_require_grads()
|
||||
|
||||
return model
|
||||
|
||||
def _generate_and_score_completions(
|
||||
self, inputs: dict[str, Union[torch.Tensor, Any]]
|
||||
) -> dict[str, Union[torch.Tensor, Any]]:
|
||||
prompts_to_log = [x["prompt"] for x in inputs]
|
||||
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
|
||||
prompt_inputs = self.processing_class(prompts_text)
|
||||
|
||||
prompt_ids = prompt_inputs["input_ids"]
|
||||
# sync weights here?
|
||||
self._sync_weights()
|
||||
with profiling_context(self, "remote_generate"):
|
||||
all_outputs = self.remote_model.generate(
|
||||
prompt_ids,
|
||||
max_new_tokens=self.args.max_completion_length,
|
||||
temperature=self.args.temperature,
|
||||
num_generations=self.args.num_generations,
|
||||
)
|
||||
completion_ids = [example["completion_ids"] for example in all_outputs]
|
||||
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
||||
|
||||
repeated_prompts = []
|
||||
for prompt in prompts_to_log:
|
||||
repeated_prompts.extend([prompt] * self.args.num_generations)
|
||||
|
||||
repeated_prompt_texts = []
|
||||
for prompt in prompts_text:
|
||||
repeated_prompt_texts.extend([prompt] * self.args.num_generations)
|
||||
|
||||
if is_conversational(inputs[0]):
|
||||
completions_to_log = []
|
||||
for prompt, completion in zip(repeated_prompts, completions_text, strict=True):
|
||||
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
|
||||
completions_to_log.append([{"role": "assistant", "content": bootstrap + completion}])
|
||||
else:
|
||||
completions_to_log = completions_text
|
||||
|
||||
rewards = torch.zeros(len(repeated_prompts), len(self.reward_funcs))
|
||||
with profiling_context(self, "rewards"):
|
||||
for i, reward_func in enumerate(self.reward_funcs):
|
||||
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
||||
keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
|
||||
reward_kwargs = defaultdict(list)
|
||||
for example in inputs:
|
||||
for key in keys:
|
||||
reward_kwargs[key].extend([example[key]] * self.args.num_generations)
|
||||
output_reward_func = reward_func(prompts=repeated_prompts, completions=completions_to_log, **reward_kwargs)
|
||||
rewards[:, i] = torch.tensor(output_reward_func, dtype=torch.float32) * self.reward_weights[i]
|
||||
|
||||
# if i == 0 and self.accelerator.is_main_process: # dump generations to a text file for debugging
|
||||
# with open("python_code_completions2.jsonl", "a") as f:
|
||||
# for i,(p, c) in enumerate(zip(repeated_prompts, completions_to_log)):
|
||||
# data = {
|
||||
# "prompt": p,
|
||||
# "completion": c,
|
||||
# }
|
||||
# for k in reward_kwargs.keys():
|
||||
# data[k] = reward_kwargs[k][i]
|
||||
|
||||
# f.write(json.dumps(data) + "\n")
|
||||
|
||||
# calculate the advantages, the prompt is all on the same device to no need to gather here
|
||||
grouped_rewards = rewards.sum(-1).view(len(prompts_to_log), self.args.num_generations)
|
||||
EPS = 1e-4
|
||||
grouped_advantages = (grouped_rewards - grouped_rewards.mean(-1, keepdim=True)) / (
|
||||
grouped_rewards.std(-1, keepdim=True) + EPS
|
||||
)
|
||||
advantages = grouped_advantages.flatten().tolist()
|
||||
|
||||
examples = []
|
||||
for i, prompt in enumerate(repeated_prompt_texts):
|
||||
example = {
|
||||
"prompt": prompt,
|
||||
"prompt_ids": prompt_ids[i // self.args.num_generations],
|
||||
"completion": completions_text[i],
|
||||
"completion_ids": completion_ids[i],
|
||||
"advantages": advantages[i],
|
||||
"rewards": rewards[i],
|
||||
}
|
||||
examples.append(example)
|
||||
|
||||
# Instead of logging metrics here, collect them
|
||||
mode = "eval" if getattr(self, "control", None) and self.control.should_evaluate else "train"
|
||||
device = self.accelerator.device
|
||||
|
||||
# Collect completion length metrics
|
||||
completion_lengths = [len(example["completion_ids"]) for example in examples]
|
||||
gathered_completion_lengths = self.accelerator.gather_for_metrics(torch.Tensor(completion_lengths).to(device))
|
||||
self._metrics[mode]["mean_completion_lengths"].append(gathered_completion_lengths.mean().item())
|
||||
self._metrics[mode]["max_completion_lengths"].append(gathered_completion_lengths.max().item())
|
||||
self._metrics[mode]["min_completion_lengths"].append(gathered_completion_lengths.min().item())
|
||||
|
||||
# Collect reward metrics
|
||||
rewards = torch.stack(
|
||||
[
|
||||
example["rewards"].to(device)
|
||||
if isinstance(example["rewards"], torch.Tensor)
|
||||
else torch.tensor(example["rewards"], device=device)
|
||||
for example in examples
|
||||
]
|
||||
)
|
||||
gathered_rewards = self.accelerator.gather_for_metrics(rewards)
|
||||
reward_per_func = gathered_rewards.mean(0)
|
||||
|
||||
for i, reward_func in enumerate(self.reward_funcs):
|
||||
reward_func_name = reward_func.__name__
|
||||
self._metrics[mode][f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
|
||||
|
||||
self._metrics[mode]["reward"].append(reward_per_func.sum().item())
|
||||
|
||||
if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
|
||||
prompts_to_log = gather_object([example["prompt"] for example in examples])
|
||||
completions_to_log = gather_object([example["completion"] for example in examples])
|
||||
if self.accelerator.is_main_process:
|
||||
# if is_rich_available():
|
||||
# # TODO: enable num_samples in TRL to avoid clogging logs
|
||||
# print_prompt_completions_sample(
|
||||
# prompts_to_log[:5],
|
||||
# completions_to_log[:5],
|
||||
# gathered_rewards.sum(1).tolist()[:5],
|
||||
# self.state.global_step,
|
||||
# )
|
||||
if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
|
||||
import pandas as pd
|
||||
|
||||
# For logging
|
||||
table = {
|
||||
"step": [str(self.state.global_step)] * len(prompts_to_log),
|
||||
"prompts": prompts_to_log,
|
||||
"completion": completions_to_log,
|
||||
"reward": gathered_rewards.sum(1).tolist(),
|
||||
}
|
||||
df = pd.DataFrame(table)
|
||||
|
||||
if wandb.run is not None and self.accelerator.is_main_process:
|
||||
wandb.log({"completions": wandb.Table(dataframe=df)})
|
||||
|
||||
return examples
|
||||
|
||||
@profiling_decorator
|
||||
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
|
||||
if len(self.batch_buffer) > 0:
|
||||
return self.batch_buffer.pop(0)
|
||||
inputs = self._generate_and_score_completions(inputs)
|
||||
gen_dataset = Dataset.from_list(inputs)
|
||||
exact_div(
|
||||
len(gen_dataset), self.args.per_device_train_batch_size, "len(gen_dataset) is not divisible by batch size"
|
||||
)
|
||||
|
||||
def get_logprobs(example, model, output_name):
|
||||
# dict of lists to list of dicts
|
||||
examples = [dict(zip(example.keys(), values)) for values in zip(*example.values())]
|
||||
input_ids, attention_mask, completion_mask, completion_ids = self._get_padded_inputs_and_attn_mask(
|
||||
examples
|
||||
)
|
||||
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
||||
|
||||
|
||||
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep).detach()
|
||||
|
||||
lengths = [len(example["completion_ids"]) for example in examples]
|
||||
# Strip the completion padding
|
||||
per_token_logps = per_token_logps.to("cpu").tolist()
|
||||
per_token_logps = [logps[:length] for logps, length in zip(per_token_logps, lengths)]
|
||||
example[output_name] = per_token_logps
|
||||
return example
|
||||
|
||||
with torch.no_grad():
|
||||
set_verbosity_error()
|
||||
disable_progress_bars()
|
||||
if self.ref_model is not None:
|
||||
self.ref_model.eval()
|
||||
gen_dataset = gen_dataset.map(
|
||||
get_logprobs,
|
||||
batched=True,
|
||||
batch_size=self.args.per_device_train_batch_size*2,
|
||||
fn_kwargs={"model": self.ref_model, "output_name": "ref_per_token_logps"},
|
||||
)
|
||||
self.model.eval()
|
||||
gen_dataset = gen_dataset.map(
|
||||
get_logprobs,
|
||||
batched=True,
|
||||
batch_size=self.args.per_device_train_batch_size*2 ,
|
||||
fn_kwargs={"model": self.model, "output_name": "old_per_token_logps"},
|
||||
)
|
||||
self.model.train()
|
||||
enable_progress_bars()
|
||||
set_verbosity_info()
|
||||
|
||||
def mini_batch_collator(mini_batch):
|
||||
return mini_batch
|
||||
|
||||
mini_batch_dataloader = DataLoader(
|
||||
gen_dataset,
|
||||
batch_size=self.args.per_device_train_batch_size,
|
||||
shuffle=True, # we technically don't need to shuffle due to grad acc, but we may move to clipped loss later
|
||||
drop_last=True,
|
||||
collate_fn=mini_batch_collator,
|
||||
)
|
||||
for num_iters in range(self.args.num_iterations):
|
||||
for mini_batch in mini_batch_dataloader:
|
||||
self.batch_buffer.append(mini_batch)
|
||||
|
||||
return self.batch_buffer.pop(0)
|
||||
|
||||
@profiling_decorator
|
||||
def _sync_weights(self):
|
||||
if self.remote_model.is_mock:
|
||||
return
|
||||
self.accelerator.wait_for_everyone()
|
||||
# if self.accelerator.is_main_process:
|
||||
start = time.time()
|
||||
# would be better if this was a ram disk + separate thread for writing
|
||||
|
||||
unwrapped_model = self.accelerator.unwrap_model(self.model)
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
state_dict = {}
|
||||
for name, param in unwrapped_model.named_parameters():
|
||||
if name in state_dict.keys():
|
||||
# sometimes the embed table is duplicated so no need to regather it
|
||||
continue
|
||||
with deepspeed.zero.GatheredParameters(param, modifier_rank=0):
|
||||
state_dict[name] = param.cpu().detach().clone()
|
||||
|
||||
# if is_fsdp_managed_module(self.model):
|
||||
# state_dict = self.model.state_dict()
|
||||
# trainer.save_model(output_dir)
|
||||
else:
|
||||
state_dict = self.accelerator.get_state_dict(self.model)
|
||||
# if self.accelerator.is_main_process:
|
||||
# with tempfile.TemporaryDirectory(dir=self.args.checkpoint_dir) as temp_dir_path:
|
||||
# self.save_model(temp_dir_path)
|
||||
|
||||
# state_dict = unwrapped_model.state_dict()
|
||||
if self.accelerator.is_main_process:
|
||||
with tempfile.TemporaryDirectory(dir=self.args.checkpoint_dir) as temp_dir_path:
|
||||
self._save(temp_dir_path, state_dict=state_dict)
|
||||
self.remote_model.load_weights_from_path(temp_dir_path)
|
||||
|
||||
print(f"Weight sync took: {time.time() - start:.2f}s")
|
||||
self.accelerator.wait_for_everyone()
|
||||
|
||||
# Get the per-token log probabilities for the completions for the model and the reference model
|
||||
@profiling_decorator
|
||||
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
|
||||
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
|
||||
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
|
||||
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
||||
|
||||
input_ids = input_ids[:, -logits_to_keep:]
|
||||
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
|
||||
# See https://github.com/huggingface/trl/issues/2770
|
||||
logits = logits[:, -logits_to_keep:]
|
||||
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
|
||||
|
||||
def _get_padded_inputs_and_attn_mask(self, inputs):
|
||||
device = self.accelerator.device
|
||||
prompt_ids = [torch.LongTensor(example["prompt_ids"]) for example in inputs]
|
||||
completion_ids = [torch.LongTensor(example["completion_ids"]) for example in inputs]
|
||||
|
||||
pad_token_id = self.processing_class.pad_token_id
|
||||
|
||||
prompt_ids = pad(prompt_ids, padding_value=pad_token_id, padding_side="left")
|
||||
completion_ids = pad(completion_ids, padding_value=pad_token_id, padding_side="right")
|
||||
# padd_ref_per_token_logps = pad(ref_per_token_logps, padding_value=0.0, padding_side="right")
|
||||
|
||||
if self.args.max_prompt_length is not None:
|
||||
prompt_ids = prompt_ids[:, -self.args.max_prompt_length :]
|
||||
|
||||
# compute the masks
|
||||
prompt_mask = (prompt_ids != pad_token_id).long()
|
||||
|
||||
# Mask everything after the first EOS token
|
||||
is_eos = completion_ids == self.processing_class.eos_token_id
|
||||
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
|
||||
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
||||
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
|
||||
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
||||
|
||||
input_ids = torch.cat([prompt_ids, completion_ids], dim=1).to(device)
|
||||
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1).to(device)
|
||||
completion_mask = completion_mask.to(device)
|
||||
|
||||
return input_ids, attention_mask, completion_mask, completion_ids
|
||||
|
||||
@profiling_decorator
|
||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
|
||||
device = self.accelerator.device
|
||||
advantages = torch.Tensor([example["advantages"] for example in inputs]).to(device)
|
||||
input_ids, attention_mask, completion_mask, completion_ids = self._get_padded_inputs_and_attn_mask(inputs)
|
||||
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
||||
|
||||
old_per_token_logps = [torch.Tensor(example["old_per_token_logps"]) for example in inputs]
|
||||
|
||||
pad_token_id = self.processing_class.pad_token_id
|
||||
|
||||
# padd the ref and old logps
|
||||
pad_old_per_token_logps = pad(old_per_token_logps, padding_value=pad_token_id, padding_side="right").to(device)
|
||||
|
||||
# model.eval()
|
||||
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
|
||||
if self.ref_model is not None:
|
||||
ref_per_token_logps = [torch.Tensor(example["ref_per_token_logps"]) for example in inputs]
|
||||
pad_ref_per_token_logps = pad(ref_per_token_logps, padding_value=pad_token_id, padding_side="right").to(device)
|
||||
clamped_diff= torch.clamp(pad_ref_per_token_logps - per_token_logps,-10.0,10.0) # for numerical stability
|
||||
per_token_kl = (
|
||||
torch.exp(clamped_diff) - clamped_diff - 1
|
||||
)
|
||||
|
||||
# del inputs, input_ids, attention_mask # free up memory
|
||||
# clipped loss
|
||||
coef_1 = torch.exp(per_token_logps - pad_old_per_token_logps)
|
||||
coef_2 = torch.clamp(coef_1, 1 - self.args.epsilon, 1 + self.args.epsilon)
|
||||
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
|
||||
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
|
||||
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
|
||||
|
||||
if self.ref_model is not None:
|
||||
per_token_loss = per_token_loss + self.args.beta * per_token_kl
|
||||
|
||||
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
|
||||
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
mode = "eval" if self.control.should_evaluate else "train"
|
||||
metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
|
||||
|
||||
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
|
||||
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
|
||||
if mode == "eval":
|
||||
metrics = {f"eval_{key}": val for key, val in metrics.items()}
|
||||
|
||||
logs = {**logs, **metrics}
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
super().log(logs, start_time)
|
||||
else: # transformers<=4.46
|
||||
super().log(logs)
|
||||
self._metrics[mode].clear()
|
||||
162
src/open_r1/trainers/remote_model.py
Normal file
162
src/open_r1/trainers/remote_model.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# great reference: https://github.com/vllm-project/vllm/issues/11400
|
||||
|
||||
|
||||
import time
|
||||
import random
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class RemoteModel:
|
||||
"""
|
||||
launch with:
|
||||
export LD_LIBRARY_PATH=$(python -c "import site; print(site.getsitepackages()[0] + '/nvidia/nvjitlink/lib')"):$LD_LIBRARY_PATH
|
||||
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port=30010 --skip-tokenizer-init --mem-fraction-static 0.4
|
||||
python3 -m sglang.launch_server --model-path HuggingFaceTB/SmolLM2-135M-Instruct --port=30010 --skip-tokenizer-init --mem-fraction-static 0.4 --host=0.0.0.0
|
||||
|
||||
# on a separate node
|
||||
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port=30010 --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size=8
|
||||
|
||||
python3 -m sglang.launch_server --model-path HuggingFaceTB/SmolLM2-1.7B-Instruct --port=30010 --skip-tokenizer-init --mem-fraction-static 0.6 --host=0.0.0.0 --dp-size=8
|
||||
|
||||
|
||||
python3 -m sglang.launch_server --model-path open-r1/Qwen2.5-Coder-7B-Instruct-SFT --revision v00.08-step-000001280 --port=30010 --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size=8
|
||||
|
||||
python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-1.5B-Instruct --port=30010 --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size=8
|
||||
"""
|
||||
|
||||
def __init__(self, remote_model_url, remote_model_port, stop_token_id=None):
|
||||
self.remote_model_url = remote_model_url
|
||||
self.remote_model_port = remote_model_port
|
||||
self.stop_token_id = stop_token_id
|
||||
|
||||
if self.remote_model_url == "mock":
|
||||
print("Using mock remote model")
|
||||
|
||||
@property
|
||||
def is_mock(self):
|
||||
return self.remote_model_url == "mock"
|
||||
|
||||
def is_healthy(self, timeout=5):
|
||||
if self.remote_model_url == "mock":
|
||||
return True
|
||||
"""Checks if the remote model server is up and running."""
|
||||
try:
|
||||
url = f"http://{self.remote_model_url}:{self.remote_model_port}/health"
|
||||
response = requests.get(url, timeout=timeout)
|
||||
return response.status_code == 200
|
||||
except requests.RequestException:
|
||||
return False
|
||||
|
||||
def wait_for_server(self, max_retries=120, delay=5):
|
||||
"""Waits for the server to become available before proceeding."""
|
||||
for attempt in range(max_retries):
|
||||
if self.is_healthy():
|
||||
print("Remote model server is healthy!")
|
||||
return True
|
||||
print(f"Waiting for server to start... (Attempt {attempt + 1}/{max_retries})")
|
||||
time.sleep(delay)
|
||||
raise RuntimeError("Remote model server did not start in time.")
|
||||
|
||||
def generate(
|
||||
self, input_ids: list[list[int]], max_new_tokens=256, temperature=0.8, num_generations=2
|
||||
) -> tuple[list[list[int]], list[list[int]]]:
|
||||
# Prepare the request body
|
||||
if self.remote_model_url == "mock":
|
||||
examples = []
|
||||
for prompt_ids in input_ids:
|
||||
for j in range(num_generations):
|
||||
example = {
|
||||
"prompt_ids": prompt_ids,
|
||||
"completion_ids": random.choices(range(10 ,1000), k=max_new_tokens),
|
||||
# "prompt_log_probs": None, # TODO, not used for now
|
||||
# "completion_log_probs": None,
|
||||
}
|
||||
examples.append(example)
|
||||
return examples
|
||||
|
||||
request_body = {
|
||||
"input_ids": input_ids,
|
||||
"sampling_params": {
|
||||
"temperature": temperature,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"stop_token_ids": [self.stop_token_id],
|
||||
"n": num_generations,
|
||||
},
|
||||
"stream": False,
|
||||
# "return_logprob": True, # disabled as we occasiosally see https://github.com/sgl-project/sglang/issues/4097
|
||||
# "logprob_start_len": 0,
|
||||
}
|
||||
|
||||
# Send the POST request to the server
|
||||
# add a few retries?
|
||||
response = requests.post(
|
||||
f"http://{self.remote_model_url}:{self.remote_model_port}/generate", json=request_body
|
||||
)
|
||||
response_json = response.json()
|
||||
|
||||
examples = []
|
||||
|
||||
for i, result in enumerate(response_json):
|
||||
prompt_index = i // num_generations
|
||||
prompt_ids = input_ids[prompt_index]
|
||||
completion_ids = result["output_ids"]
|
||||
# prompt_log_probs = [prob[0] for prob in result["meta_info"]["input_token_logprobs"]]
|
||||
# completion_log_probs = [prob[0] for prob in result["meta_info"]["output_token_logprobs"]]
|
||||
|
||||
example = {
|
||||
"prompt_ids": prompt_ids,
|
||||
"completion_ids": completion_ids,
|
||||
# "prompt_log_probs": prompt_log_probs,
|
||||
# "completion_log_probs": completion_log_probs,
|
||||
}
|
||||
examples.append(example)
|
||||
|
||||
return examples
|
||||
|
||||
def load_weights_from_path(self, path: str):
|
||||
if self.remote_model_url == "mock":
|
||||
return
|
||||
url = f"http://{self.remote_model_url}:{self.remote_model_port}/update_weights_from_disk"
|
||||
data = {"model_path": path}
|
||||
|
||||
response = requests.post(url, json=data)
|
||||
print(response.text)
|
||||
assert response.json()["success"] is True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from datasets import load_dataset
|
||||
|
||||
url = "0.0.0.0"
|
||||
port = 30010
|
||||
MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL)
|
||||
|
||||
remote_model = RemoteModel(url, port, tokenizer.eos_token_id)
|
||||
dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train")
|
||||
dataloader = DataLoader(dataset, batch_size=4)
|
||||
|
||||
for i, batch in zip(range(2), dataloader):
|
||||
problems = batch["problem"]
|
||||
ids = tokenizer(problems)
|
||||
new_ids, logprobs = remote_model.generate(ids["input_ids"])
|
||||
print(new_ids)
|
||||
print(logprobs)
|
||||
36
src/open_r1/trainers/sampler.py
Normal file
36
src/open_r1/trainers/sampler.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
from typing import Iterator
|
||||
|
||||
from torch.utils.data import RandomSampler
|
||||
|
||||
|
||||
class RepeatBatchRandomSampler(RandomSampler):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
num_generations: int = 1,
|
||||
batch_size: int = 3,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.num_generations = num_generations
|
||||
self.batch_size = batch_size
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return super().__len__() * self.num_generations
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
batch_indices = []
|
||||
for idx in super().__iter__():
|
||||
batch_indices.append(idx)
|
||||
if len(batch_indices) == self.batch_size:
|
||||
batch_indices = batch_indices * self.num_generations
|
||||
yield from batch_indices
|
||||
batch_indices = []
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sampler = RepeatBatchRandomSampler(num_generations=2, data_source=range(12), replacement=False)
|
||||
# print(list(sampler))
|
||||
|
||||
for sample in sampler:
|
||||
print(sample)
|
||||
148
src/open_r1/trainers/special_dataloader.py
Normal file
148
src/open_r1/trainers/special_dataloader.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from open_r1.configs import GRPOConfig
|
||||
from open_r1.trainers.remote_model import RemoteModel
|
||||
from trl.data_utils import is_conversational, maybe_apply_chat_template
|
||||
|
||||
|
||||
class RemoteGRPODataloader(DataLoader):
|
||||
def __init__(
|
||||
self, *args, config: GRPOConfig, remote_model=None, processing_class=None, reward_funcs=None, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.config = config
|
||||
self.remote_model = remote_model
|
||||
self.processing_class = processing_class
|
||||
self.reward_funcs = reward_funcs
|
||||
self.reward_weights = [1.0] * len(reward_funcs) # TODO: make this configurable
|
||||
|
||||
def __len__(self):
|
||||
return super().__len__() * self.config.num_generations
|
||||
|
||||
def __iter__(self):
|
||||
for batch in super().__iter__():
|
||||
batch = self._prepare_batch(batch)
|
||||
gen_dataset = Dataset.from_list(batch)
|
||||
mini_batch_dataloader = DataLoader(
|
||||
gen_dataset,
|
||||
batch_size=self.config.per_device_train_batch_size,
|
||||
shuffle=True, # we technically don#t need to shuffle due to grad acc, but we may move to clipped loss later
|
||||
drop_last=True,
|
||||
collate_fn=self.collate_fn,
|
||||
)
|
||||
for mini_batch in mini_batch_dataloader:
|
||||
yield mini_batch
|
||||
|
||||
def _prepare_batch(self, batch):
|
||||
prompts = [x["prompt"] for x in batch]
|
||||
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in batch]
|
||||
prompt_inputs = self.processing_class(prompts_text)
|
||||
|
||||
prompt_ids = prompt_inputs["input_ids"]
|
||||
# add cuda clear cache here and a sleep
|
||||
|
||||
all_outputs = self.remote_model.generate(
|
||||
prompt_ids,
|
||||
max_new_tokens=self.config.max_completion_length,
|
||||
temperature=self.config.temperature,
|
||||
num_generations=self.config.num_generations,
|
||||
)
|
||||
|
||||
completion_ids = [example["completion_ids"] for example in all_outputs]
|
||||
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
||||
|
||||
repeated_prompts = []
|
||||
for prompt in prompts:
|
||||
repeated_prompts.extend([prompt] * self.config.num_generations)
|
||||
|
||||
repeated_prompt_texts = []
|
||||
for prompt in prompts_text:
|
||||
repeated_prompt_texts.extend([prompt] * self.config.num_generations)
|
||||
|
||||
if is_conversational(batch[0]):
|
||||
completions = []
|
||||
for prompt, completion in zip(repeated_prompts, completions_text, strict=True):
|
||||
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
|
||||
completions.append([{"role": "assistant", "content": bootstrap + completion}])
|
||||
else:
|
||||
completions = completions_text
|
||||
|
||||
rewards = torch.zeros(len(repeated_prompts), len(self.reward_funcs))
|
||||
for i, reward_func in enumerate(self.reward_funcs):
|
||||
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
||||
keys = [key for key in batch[0] if key not in ["prompt", "completion"]]
|
||||
reward_kwargs = defaultdict(list)
|
||||
for example in batch:
|
||||
for key in keys:
|
||||
reward_kwargs[key].extend([example[key]] * self.config.num_generations)
|
||||
output_reward_func = reward_func(prompts=repeated_prompts, completions=completions, **reward_kwargs)
|
||||
rewards[:, i] = torch.tensor(output_reward_func, dtype=torch.float32) * self.reward_weights[i]
|
||||
|
||||
grouped_rewards = rewards.sum(-1).view(len(prompts), self.config.num_generations)
|
||||
EPS = 1e-4
|
||||
grouped_advantages = (grouped_rewards - grouped_rewards.mean(-1, keepdim=True)) / (
|
||||
grouped_rewards.std(-1, keepdim=True) + EPS
|
||||
)
|
||||
advantages = grouped_advantages.flatten().tolist()
|
||||
|
||||
# build batch as list of dicts
|
||||
examples = []
|
||||
for i, prompt in enumerate(repeated_prompt_texts):
|
||||
example = {
|
||||
"prompt": prompt,
|
||||
"prompt_ids": prompt_ids[i // self.config.num_generations],
|
||||
"completion": completions_text[i],
|
||||
"completion_ids": completion_ids[i],
|
||||
"advantages": advantages[i],
|
||||
"rewards": rewards[i],
|
||||
}
|
||||
examples.append(example)
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = load_dataset("open-r1/OpenR1-Math-cn_k12-86k", split="train").select(range(32))
|
||||
|
||||
def make_conversation(example):
|
||||
prompt = []
|
||||
|
||||
prompt.append({"role": "user", "content": example["problem"]})
|
||||
return {"prompt": prompt}
|
||||
|
||||
dataset = dataset.map(make_conversation)
|
||||
|
||||
def collate_fn(batch):
|
||||
return batch
|
||||
|
||||
dataset = dataset.remove_columns("messages")
|
||||
|
||||
def reward_func(prompts, completions, **kwargs):
|
||||
return [0.5] * len(prompts)
|
||||
|
||||
reward_funcs = [reward_func, reward_func]
|
||||
|
||||
MODEL = "HuggingFaceTB/SmolLM2-135M-Instruct"
|
||||
processing_class = AutoTokenizer.from_pretrained(MODEL)
|
||||
remote_model = RemoteModel("0.0.0.0", 30010, processing_class.eos_token_id)
|
||||
config = GRPOConfig()
|
||||
data_loader = RemoteGRPODataloader(
|
||||
dataset,
|
||||
remote_model=remote_model,
|
||||
processing_class=processing_class,
|
||||
reward_funcs=reward_funcs,
|
||||
batch_size=2,
|
||||
num_workers=0,
|
||||
collate_fn=collate_fn,
|
||||
config=config,
|
||||
)
|
||||
print(len(data_loader))
|
||||
|
||||
for i, batch in enumerate(data_loader):
|
||||
print(i, len(batch))
|
||||
print(batch)
|
||||
|
|
@ -88,7 +88,7 @@ def run_lighteval_job(
|
|||
f"{model_args.trust_remote_code}",
|
||||
]
|
||||
if training_args.system_prompt is not None:
|
||||
cmd_args.append(f"--system_prompt={training_args.system_prompt}")
|
||||
cmd_args.append(f"'{training_args.system_prompt}'")
|
||||
cmd[-1] += " " + " ".join(cmd_args)
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue