Compare commits

63 commits

Author SHA1 Message Date
edbeeching
f08b559e00 slurm 2025-03-20 08:05:07 +00:00
edbeeching
a5668ec040 add default model revision 2025-03-20 08:02:11 +00:00
edbeeching
36c3867811 pushing last changes while I windown this PR 2025-03-20 08:01:34 +00:00
edbeeching
41f41d80f8 fixes for gradient checkpointing, profiling context 2025-03-18 08:02:11 +00:00
edbeeching
4f94c356b6 adds option for a mock remote model 2025-03-18 07:58:03 +00:00
edbeeching
1b074940bf add remote grpo exp configs 2025-03-18 07:57:10 +00:00
edbeeching
0e4685f6c1 adds option to not use ref model 2025-03-16 19:45:13 +00:00
lewtun
6e275af99e Fix buffer (#508) 2025-03-16 13:54:12 +01:00
edbeeching
bf4642b4ea style 2025-03-14 08:31:42 +00:00
edbeeching
c6ab52a0cf cleaning up grpo PR 2025-03-14 08:22:11 +00:00
edbeeching
93d8bc5aba Merge branch 'main' into faster-grpo-trainer 2025-03-13 21:59:09 +00:00
edbeeching
e922330623 temp dumping of samples to json 2025-03-13 10:02:25 +00:00
edbeeching
fd62cf290e grpo configs 2025-03-11 07:48:06 +00:00
edbeeching
1c4efd5bd3 add clipping, fix bug with sampler 2025-03-10 23:00:06 +00:00
edbeeching
0c3e50f332 disable map progress bar 2025-03-10 14:31:28 +00:00
edbeeching
c69164a573 move ref logprob 2025-03-10 13:58:05 +00:00
edbeeching
a4004f658e fix lewis bugs 2025-03-10 10:03:37 +00:00
Lewis Tunstall
5fdbcd5a20 Restore Liger 2025-03-09 12:36:55 +00:00
Lewis Tunstall
4fa226ba1e Make checkpoint dir customisatble 2025-03-09 11:53:29 +00:00
Lewis Tunstall
96546e7721 Fix 2025-03-08 19:22:58 +00:00
Lewis Tunstall
a3d1f26715 Refactor preparation 2025-03-08 19:16:43 +00:00
lewtun
14d75bf0d5 Merge branch 'main' into faster-grpo-trainer 2025-03-08 15:41:31 +01:00
Lewis Tunstall
f4fe3550b6 Fix reward devices 2025-03-08 13:46:51 +00:00
Lewis Tunstall
b47880b1e2 Add local logging 2025-03-08 12:56:58 +00:00
Lewis Tunstall
10a70dfa42 Refactor model load 2025-03-08 12:49:20 +00:00
Lewis Tunstall
7d470d02d4 Align logging 2025-03-07 19:59:56 +00:00
Lewis Tunstall
dcf0af62e2 Fix evals 2025-03-07 18:59:24 +00:00
Lewis Tunstall
3d0e39d5d6 Add Slurm 2025-03-07 16:37:32 +00:00
Lewis Tunstall
76f8ae7a88 Tidy up 2025-03-07 16:05:17 +00:00
Lewis Tunstall
389befcf3e Add recipes 2025-03-07 15:39:25 +00:00
Lewis Tunstall
2f4f6fe4ef Revert 2025-03-07 15:02:14 +00:00
Lewis Tunstall
48a57bbe34 Revert 2025-03-07 14:56:20 +00:00
Lewis Tunstall
54afde3b67 Tune 2025-03-07 14:55:54 +00:00
Lewis Tunstall
731caf57ed Update slurm 2025-03-07 11:08:06 +00:00
Lewis Tunstall
95bac8aacf Add deepscaler recipe 2025-03-07 10:47:19 +00:00
Lewis Tunstall
dcc33dc710 Reduce batch size 2025-03-06 13:44:41 +00:00
Lewis Tunstall
8c7d764e5b Clean 2025-03-06 13:36:39 +00:00
Lewis Tunstall
ada8cecd54 Add Qwen config 2025-03-06 09:41:11 +00:00
Lewis Tunstall
9b7aa79b0e Add remote script 2025-03-04 15:52:16 +00:00
Lewis Tunstall
25e0b07feb Tune params 2025-03-04 15:41:23 +00:00
Lewis Tunstall
db2501d531 Fix checkpoint 2025-03-04 10:34:25 +00:00
lewtun
def83fc6af Merge branch 'main' into faster-grpo-trainer 2025-03-03 17:34:00 +01:00
edbeeching
00b1f61c01 fix job launcher 2025-02-28 21:34:47 +00:00
edbeeching
67fb66af13 adds smollm grpo config for replication 2025-02-28 21:09:40 +00:00
edbeeching
9de588449a Merge branch 'main' into faster-grpo-trainer 2025-02-27 14:45:05 +00:00
edbeeching
0030447af5 save configs wip 2025-02-26 09:07:57 +00:00
edbeeching
c775de3fd0 save wip 2025-02-26 09:06:38 +00:00
edbeeching
0db1912bdc adds logging to remote grpo 2025-02-22 22:01:16 +00:00
edbeeching
09628da5d3 Merge branch 'main' into faster-grpo-trainer 2025-02-22 21:11:36 +00:00
edbeeching
55a451a813 sampler 2025-02-22 12:31:11 +00:00
edbeeching
f68c27bdf3 save wip 2025-02-21 15:06:58 +00:00
edbeeching
f50658e7c8 save WIP 2025-02-21 15:06:26 +00:00
edbeeching
875628838a adds ref model offload 2025-02-21 09:59:39 +00:00
edbeeching
3e37bf1361 changes the completion attn mask 2025-02-21 09:29:44 +00:00
edbeeching
420d72a7da adds liger kernel support 2025-02-21 08:16:55 +00:00
edbeeching
382a0c7890 style 2025-02-20 21:55:35 +00:00
edbeeching
5d213c48b1 adds remote model, job launch, switch to sglang server with weight sync 2025-02-20 21:44:08 +00:00
edbeeching
12be29f08b save WIP 2025-02-20 11:04:26 +00:00
edbeeching
b1394e542e hacky fixes for memory issues 2025-02-20 07:34:24 +00:00
edbeeching
fbe3b07d56 patches to get vllm working in ddp setting 2025-02-19 21:28:10 +00:00
edbeeching
dcdcebaaac adds licence, world size patch for vllm 2025-02-19 10:52:45 +00:00
edbeeching
ed9554ff54 adds weight sync 2025-02-19 10:18:23 +00:00
edbeeching
38e350d3a2 make it run 2025-02-19 09:48:00 +00:00
103 changed files with 4932 additions and 5078 deletions


@ -16,9 +16,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
uses: actions/checkout@v4
- name: Setup Python environment
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
uses: actions/setup-python@v5
with:
python-version: 3.10.10
- name: Install dependencies

3
.gitignore vendored

@ -177,5 +177,4 @@ logs/
eval_results/
results/
.vscode/
.python-version
.vscode/

BIN
.litellm_cache/cache.db Normal file

Binary file not shown.


@ -8,11 +8,10 @@ check_dirs := src tests
# dev dependencies
install:
uv venv openr1 --python 3.11
. openr1/bin/activate && uv pip install --upgrade pip && \
uv pip install vllm==0.8.5.post1 && \
uv pip install setuptools && \
uv pip install flash-attn --no-build-isolation && \
uv venv openr1 --python 3.11 && . openr1/bin/activate && uv pip install --upgrade pip
uv pip install vllm==0.7.2
uv pip install setuptools
uv pip install flash-attn --no-build-isolation
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e ".[dev]"
style:
@ -47,7 +46,8 @@ evaluate:
--use-chat-template \
--output-dir data/evals/$(MODEL); \
else \
lighteval vllm $$MODEL_ARGS "lighteval|$(TASK)|0|0" \
lighteval vllm $$MODEL_ARGS "custom|$(TASK)|0|0" \
--custom-tasks src/open_r1/evaluate.py \
--use-chat-template \
--output-dir data/evals/$(MODEL); \
fi

457
README.md

@ -21,9 +21,10 @@
The goal of this repo is to build the missing pieces of the R1 pipeline such that everybody can reproduce and build on top of it. The project is simple by design and mostly consists of:
- `src/open_r1`: contains the scripts to train models as well as generate synthetic data:
- `src/open_r1`: contains the scripts to train and evaluate models as well as generate synthetic data:
- `grpo.py`: trains a model with GRPO on a given dataset.
- `sft.py`: performs a simple SFT of a model on a dataset.
- `evaluate.py`: evaluates a model on the R1 benchmarks.
- `generate.py`: generates synthetic data from a model using [Distilabel](https://github.com/argilla-io/distilabel).
- `Makefile`: contains easy-to-run commands for each step in the R1 pipeline leveraging the scripts above.
@ -41,7 +42,6 @@ We will use the DeepSeek-R1 [tech report](https://github.com/deepseek-ai/DeepSee
## News 🗞️
* **🧑‍🍳 [2025/05/26] (Step 1 completed!)** We release [**Mixture-of-Thoughts**](https://huggingface.co/datasets/open-r1/Mixture-of-Thoughts)--a curated reasoning dataset of 350k verified traces distilled from R1. The dataset spans tasks in mathematics, coding, and science, and is designed to teach language models to reason step-by-step. We also provide a recipe to train [OpenR1-Distill-7B](https://huggingface.co/open-r1/OpenR1-Distill-7B), which replicates the reasoning capabilities of [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) and marks the completion of step 1 in the Open R1 project.
* **⚡️ [2025/03/11] [(update #3)](https://huggingface.co/blog/open-r1/update-3):** We release the [**CodeForces-CoTs**](https://huggingface.co/datasets/open-r1/codeforces-cots) dataset of 10k competitive programming problems and 100k solutions distilled from R1. We also release IOI24: a new benchmark of _very_ hard problems from international olympiads. A 7B Qwen model trained on CodeForces-CoTs can outperform Claude 3.7 Sonnet on IOI24, while a 32B model can outperform R1 itself.
* **∞ [2025/02/10] [(update #2)](https://huggingface.co/blog/open-r1/update-2):** We release the [**OpenR1-Math-220k**](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k) dataset of 220k traces distilled from R1 on a new version of NuminaMath. Models trained on this dataset match the performance of DeepSeek's distilled ones.
* **🔥 [2025/02/02] [(update #1)](https://huggingface.co/blog/open-r1/update-1):** We implement the first parts of the [training](https://github.com/huggingface/open-r1?tab=readme-ov-file#training-models), [inference](https://github.com/huggingface/open-r1?tab=readme-ov-file#data-generation), and [evaluation](https://github.com/huggingface/open-r1?tab=readme-ov-file#reproducing-deepseeks-evaluation-results) pipelines. Let's go!
@ -69,11 +69,11 @@ uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --u
Next, install vLLM and FlashAttention:
```shell
uv pip install vllm==0.8.5.post1
uv pip install vllm==0.7.2
uv pip install setuptools && uv pip install flash-attn --no-build-isolation
```
This will also install PyTorch `v2.6.0` and it is **very important** to use this version since the vLLM binaries are compiled for it. You can then install the remaining dependencies for your specific use case via `pip install -e .[LIST OF MODES]`. For most contributors, we recommend:
This will also install PyTorch `v2.5.1` and it is **very important** to use this version since the vLLM binaries are compiled for it. You can then install the remaining dependencies for your specific use case via `pip install -e .[LIST OF MODES]`. For most contributors, we recommend:
```shell
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e ".[dev]"
@ -100,30 +100,25 @@ sudo apt-get install git-lfs
## Training models
> [!NOTE]
> The training commands below are configured for a node of 8 x H100s (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.
We support training models with either DDP or DeepSpeed (ZeRO-2 and ZeRO-3). For example, to perform SFT on a dataset distilled from DeepSeek-R1 with reasoning traces such as [open-r1/Mixture-of-Thoughts](https://huggingface.co/datasets/open-r1/Mixture-of-Thoughts), run:
We support training models with either DDP or DeepSpeed (ZeRO-2 and ZeRO-3). For example, to run SFT on a dataset distilled from DeepSeek-R1 with reasoning traces such as [open-r1/OpenR1-Math-220k](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k), run:
```shell
# Train via command line
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path open-r1/Qwen2.5-Math-7B-RoPE-300k \
--dataset_name open-r1/Mixture-of-Thoughts \
--dataset_config all \
--eos_token '<|im_end|>' \
--learning_rate 4.0e-5 \
--num_train_epochs 5 \
--max_seq_length 32768 \
--per_device_train_batch_size 2 \
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
--dataset_name open-r1/OpenR1-Math-220k \
--learning_rate 1.0e-5 \
--num_train_epochs 1 \
--packing \
--max_seq_length 16384 \
--per_device_train_batch_size 16 \
--gradient_checkpointing \
--bf16 \
--use_liger_kernel \
--output_dir data/OpenR1-Distill-7B
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
# Train via YAML config
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--config recipes/OpenR1-Distill-7B/sft/config_distill.yaml
--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml
```
Currently, the following tasks are supported:
@ -137,160 +132,76 @@ Currently, the following tasks are supported:
By default, these scripts will push each model to your Hugging Face Hub username, i.e. `{username}/{model_name}-{task}`. You can override the parameters in each YAML config by appending them to the command as follows:
```shell
# Change the base model to a smaller variant
# Change batch size, number of epochs etc
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--config recipes/OpenR1-Distill-7B/sft/config_distill.yaml \
--model_name_or_path Qwen/Qwen3-0.6B-Base \
--hub_model_id OpenR1-Distill-0.6B \
--output_dir data/OpenR1-Distill-0.6B
--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml
--per_device_train_batch_size=1 --num_train_epochs=5
```
If you also wish to override the Weights and Biases default settings, you can do so as follows:
```shell
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--config recipes/OpenR1-Distill-7B/sft/config_distill.yaml
--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml
--wandb_entity huggingface --wandb_project open-r1 --run_name Qwen2.5-1.5B-GRPO
```
**🚨 WARNING 🚨**
> [!NOTE]
> The training commands below are configured for a node of 8 x H100s (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.
Most base models like `meta-llama/Llama-3.2-1B` do not have a chat template, so we set ChatML as the default during training. However, for Qwen base models like `Qwen/Qwen2.5-1.5B`, a chat template is pre-defined in the tokenizer, so the EOS token must be set accordingly, e.g.
### SFT
```diff
# Align EOS token with chat template for Qwen base models
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path Qwen/Qwen2.5-1.5B \
+ --eos_token '<|im_end|>'
--dataset_name open-r1/Mixture-of-Thoughts \
--dataset_config all \
--learning_rate 4.0e-5 \
--num_train_epochs 1 \
--max_seq_length 32768 \
--per_device_train_batch_size 16 \
--gradient_checkpointing \
--bf16 \
--use_liger_kernel \
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
```
If you wish to use a custom chat template (e.g. Llama or Gemma), then the chat template and associated EOS token must be provided:
```diff
# Align EOS token with custom chat template
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path meta-llama/Llama-3.2-1B \
+ --chat_template "$(cat llama_chat_template.jinja)" \
+ --eos_token '<|eot_id|>' \
--dataset_name open-r1/Mixture-of-Thoughts \
--dataset_config all \
--learning_rate 4.0e-5 \
--num_train_epochs 1 \
--max_seq_length 32768 \
--per_device_train_batch_size 16 \
--gradient_checkpointing \
--bf16 \
--use_liger_kernel \
--output_dir data/Llama-3.2-1B-Open-R1-Distill
```
### SFT distillation
We provide a recipe to reproduce the reasoning capabilities of [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B), starting from the same base model. To do so, run:
To run SFT on a dataset distilled from DeepSeek-R1 with reasoning traces such as [open-r1/OpenR1-Math-220k](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k), run:
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
src/open_r1/sft.py \
--config recipes/OpenR1-Distill-7B/sft/config_distill.yaml
--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml
```
The result will be a model like [open-r1/OpenR1-Distill-7B](https://huggingface.co/open-r1/OpenR1-Distill-7B), with the following downstream performance:
| Model | AIME 2024 | MATH-500 | GPQA Diamond | LiveCodeBench v5 |
|-----------------------------|-----------|----------|--------------|------------------|
| OpenR1-Distill-7B | 52.7 | 89.0 | 52.8 | 39.4 |
| DeepSeek-R1-Distill-Qwen-7B | 51.3 | 93.5 | 52.4 | 37.4 |
You can adjust the YAML config to train on a different base model or dataset.
### GRPO
We use TRL's [vLLM backend](https://huggingface.co/docs/trl/speeding_up_training?vllm+examples=GRPO#vllm-for-fast-generation-in-online-methods) to scale training to large models across multiple nodes. For single-node training of smol models across 8 GPUs, use `vllm_mode="colocate"` to run vLLM in the same process as the training script:
To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, on a node with 8 GPUs, set `--num_processes` to override the default value in the `accelerate` configs:
```shell
ACCELERATE_LOG_LEVEL=info \
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
src/open_r1/grpo.py --config recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml \
--vllm_mode colocate
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
--num_processes=7 src/open_r1/grpo.py \
--config recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml
```
> [!WARNING]
> The chat template used in the distilled DeepSeek models omits the contents of the reasoning block within the `<think>` and `</think>` tags. It also prefills the assistant response with `<think>` which interferes with the format reward function. To handle that, it is important to override the chat template as done in e.g. [recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml](./recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml).
For multi-node training on N+1 nodes, with 1 node running the vLLM server and N nodes running training, we provide an example Slurm script. For example, to run the above example on 1+1 nodes with data parallelism, run:
We provide a minimal reproducible experiment using GRPO for mathematical reasoning, referencing the approach from [SimpleRL-Reason](https://hkust-nlp.notion.site/simplerl-reason) which uses a 7B model trained on 8K examples. Running this on 8 H100 (80GB) GPUs takes about 3 hours:
```shell
sbatch --nodes=2 slurm/train.slurm --model Qwen2.5-1.5B-Instruct --task grpo --config demo --accelerator zero2 --dp 8 --tp 1
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
--num_processes=7 src/open_r1/grpo.py \
--config recipes/Qwen2.5-Math-7B/grpo/config_simple_rl.yaml
```
See the [Launching jobs on a Slurm cluster](#launching-jobs-on-a-slurm-cluster) section for more details.
#### GRPO dataset filtering
We provide support for filtering datasets by generating and computing the pass rate on verifiable tasks; see this [README](scripts/pass_rate_filtering/README.md).
Our final [model](https://huggingface.co/Dongwei/Qwen-2.5-7B_Base_Math_smalllr), while using different learning rates, loss functions and reward structures, achieves 69.4% accuracy on MATH-500, demonstrating a 17%+ improvement over the base model.
#### 👨‍💻 Training with a code interpreter
We provide a `code` reward function for executing code generated by the policy during training. Currently, this reward function targets code contests like [Codeforces](https://codeforces.com), where solutions are executed against a set of test cases and the overall success rate is returned as the final reward. To ensure safe execution, we support multiple sandbox providers:
1. [E2B](https://e2b.dev) - Fast, cloud-based sandboxes with focus on Python execution
2. [Morph](https://cloud.morph.so/web/) - Cloud-based sandboxes with broader language support - Python/JS/C++/Rust
To use the code reward function, first install the necessary dependencies:
We provide a `code` reward function for executing code generated by the policy during training. Currently, this reward function targets code contests like [Codeforces](https://codeforces.com), where solutions are executed against a set of test cases and the overall success rate is returned as the final reward. To ensure safe execution, we use [E2B](https://e2b.dev) sandboxes, which are fast and cheap to run. To use this reward function, first install the necessary dependencies:
```shell
uv pip install -e '.[code]'
```
##### E2B Provider
To use E2B sandboxes, create a `.env` file and add your E2B API token:
Then create a `.env` file and place an API token from E2B within it:
```
E2B_API_KEY="e2b_xxx"
```
##### Morph Provider
To use Morph, first install the morphcloud package:
```shell
pip install morphcloud
```
Then add your Morph API token to the `.env` file:
```
MORPH_API_KEY="YOUR_MORPH_API_KEY"
```
To specify which provider to use, add the `provider_type` parameter in your configuration:
```yaml
# For E2B
provider_type: e2b
# For Morph
provider_type: morph
```
##### Dataset Requirements
Make sure your dataset contains a `verification_info` column with the following schema (adopted from PrimeIntellect's excellent [datasets](https://huggingface.co/collections/PrimeIntellect/synthetic-1-67a2c399cfdd6c9f7fae0c37) of verifiable problems):
Then make sure your dataset contains a `verification_info` column with the following schema (adopted from PrimeIntellect's excellent [datasets](https://huggingface.co/collections/PrimeIntellect/synthetic-1-67a2c399cfdd6c9f7fae0c37) of verifiable problems):
```python
{
"language": "python", # Morph supports more languages including C++, Java, etc.
"language": "python",
"test_cases": [
{
"input": "4\n4\n0001\n1000\n0011\n0111\n3\n010\n101\n0\n2\n00000\n00001\n4\n01\n001\n0001\n00001\n",
@ -301,94 +212,43 @@ Make sure your dataset contains a `verification_info` column with the following
}
```
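As a rough illustration of how the overall success rate over `test_cases` could become a reward, here is a minimal sketch. It is not the repository's `code` reward implementation: the `run_solution` callable stands in for whatever sandboxed executor is used, and the `"output"` key is an assumption about the part of the schema not shown above.

```python
from typing import Callable


def success_rate_reward(
    run_solution: Callable[[str], str],
    test_cases: list[dict],
) -> float:
    """Return the fraction of test cases whose output matches (0.0 if there are none)."""
    if not test_cases:
        return 0.0
    passed = 0
    for case in test_cases:
        try:
            actual = run_solution(case["input"])
        except Exception:
            # A crash or timeout in the (hypothetical) sandbox counts as a failed case.
            continue
        # Assumed key "output"; surrounding whitespace is ignored when comparing.
        if actual.strip() == case["output"].strip():
            passed += 1
    return passed / len(test_cases)


# Toy stand-in for a sandboxed program that doubles every number it reads.
def fake_runner(stdin: str) -> str:
    return "\n".join(str(2 * int(line)) for line in stdin.split())


cases = [
    {"input": "1\n2\n", "output": "2\n4\n"},
    {"input": "3\n", "output": "7\n"},
]
print(success_rate_reward(fake_runner, cases))  # 0.5
```

Passing the executor in as a callable keeps the scoring logic independent of the sandbox provider, so the same function can be exercised locally with a fake runner.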
For example, to train a smol model on Python problems, start the vLLM server:
```shell
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-1.5B-Instruct
```
Then run training with:
```shell
CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 ACCELERATE_LOG_LEVEL=info \
accelerate launch --config_file recipes/accelerate_configs/zero2.yaml --num_processes=7 \
src/open_r1/grpo.py --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
```
##### Using Router Services
It is possible to be rate limited when too many scripts are executed on sandbox services. For both providers, we offer router scripts that can be launched on a CPU node:
For E2B:
```shell
sbatch slurm/e2b_router.slurm
```
For Morph:
```shell
sbatch slurm/morph_router.slurm
```
Then add the router URL in your training YAML config:
```yaml
# For E2B
e2b_router_url: 1.2.3.4:8000
# For Morph
morph_router_url: 1.2.3.4:8000
```
The port should match the one used when launching the router.
All training jobs can share the same router IP which will ensure parallel executions are properly managed.
#### Competitive Programming problems: IOI & CodeForces
We provide `ioi_code_reward` and `cf_code_reward` reward functions for executing problems from [IOI](https://hf.co/datasets/open-r1/ioi) and [CodeForces](https://huggingface.co/datasets/open-r1/codeforces), respectively. You can use either [piston](https://github.com/engineer-man/piston) or Morph (currently IOI only) as your execution provider.
##### Piston
To use Piston:
1. Get piston workers running, see [slurm/piston/README.md](./slurm/piston/README.md)
2. Set your environment variable `PISTON_ENDPOINTS` to `slurm` or to a list of piston worker endpoints
For IOI:
3. In your configuration, use `ioi_provider: "piston"`
For CodeForces:
3. Download the generated (hard) test cases:
```
# change PATH_TO_SAVE_TESTCASES. Increase --max-workers according to your machine's capacity
huggingface-cli download open-r1/codeforces --repo-type=dataset --include='generated_tests/*.parquet' --max-workers=8 --local-dir PATH_TO_SAVE_TESTCASES
```
4. Save the path in .env:
```
CF_TESTS_FOLDER=PATH_TO_SAVE_TESTCASES
```
##### Morph
Morph is a cloud-based solution that provides sandboxed environments for running code. To use it:
1. Install the Morph client: `pip install morphcloud`
2. Add your Morph API key to the `.env` file: `MORPH_API_KEY="your_key_here"`
3. In your configuration, use `ioi_provider: "morph"`
##### Example recipes
For IOI:
See the [example recipe](./recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code_ioi.yaml) for how to use the IOI reward function:
For example, to train a smol model on Python problems, run:
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
--num_processes=7 src/open_r1/grpo.py \
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code_ioi.yaml
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
```
For CodeForces:
#### Data decontamination
Following [s1: Simple test-time scaling](https://arxiv.org/abs/2501.19393), the data can be decontaminated using the script at [scripts/decontaminate.py](./scripts/decontaminate.py), which decontaminates a dataset using 8-grams and deduplicates the data. Sample run:
```shell
sbatch --job-name=cf-grpo --nodes=2 slurm/train.slurm --model Qwen2.5-Coder-7B-Instruct --task grpo --config codeforces --accelerator zero3 --dp 8 --tp 1
python scripts/decontaminate.py \
--dataset "open-r1/verifiable-coding-problems-python" \
--problem_column problem \
--cleanup
```
It will decontaminate against the benchmark datasets and remove the contaminated samples afterwards. If no `--new_dataset_name` argument is provided, the same dataset name is reused with a `_decontaminated` suffix appended. It runs against the prompt, which for this dataset is the column `problem`, but a different one can be provided.
Arguments for the script:
```shell
usage: decontaminate.py [-h] --dataset DATASET [--split SPLIT] [--ngram_size NGRAM_SIZE] [--problem_column PROBLEM_COLUMN] [--cleanup] [--new_dataset_name NEW_DATASET_NAME]
options:
-h, --help show this help message and exit
--dataset DATASET Name of the dataset to check for contamination.
--split SPLIT Split to check for contamination, defaults to `train`.
--ngram_size NGRAM_SIZE
Size of n-grams to build, defaults to 8.
--problem_column PROBLEM_COLUMN
Name of the column containing the problem (prompt).
--cleanup Whether to remove the contaminated rows before pushing the dataset.
--new_dataset_name NEW_DATASET_NAME
New name for the dataset. If not provided, will reuse the name and add a `_decontaminated` to the name.
```
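The idea behind 8-gram decontamination can be sketched in a few lines. This is only an illustration under stated assumptions, not the logic of `scripts/decontaminate.py`: the function names are hypothetical, and lowercasing plus whitespace tokenisation is an assumed normalisation.

```python
def ngrams(text: str, n: int = 8) -> set[tuple[str, ...]]:
    """Return the set of word-level n-grams (assumed: lowercase, whitespace split)."""
    tokens = text.lower().split()
    return {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}


def decontaminate(problems: list[str], benchmark_prompts: list[str], n: int = 8) -> list[str]:
    """Drop exact duplicates and any problem sharing an n-gram with a benchmark prompt."""
    benchmark_index: set[tuple[str, ...]] = set()
    for prompt in benchmark_prompts:
        benchmark_index |= ngrams(prompt, n)

    kept, seen = [], set()
    for problem in problems:
        if problem in seen:
            continue  # exact-match deduplication
        seen.add(problem)
        if ngrams(problem, n) & benchmark_index:
            continue  # at least one shared n-gram: treat as contaminated
        kept.append(problem)
    return kept


benchmark = ["find the sum of all positive integers less than one hundred that are divisible by three"]
problems = [
    "Please find the sum of all positive integers less than one hundred quickly",
    "compute the area of a circle with radius two",
    "compute the area of a circle with radius two",
]
print(decontaminate(problems, benchmark))
```

Texts shorter than `n` tokens produce no n-grams in this sketch, so they are never flagged as contaminated.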
### Launching jobs on a Slurm cluster
@ -396,76 +256,48 @@ sbatch --job-name=cf-grpo --nodes=2 slurm/train.slurm --model Qwen2.5-Coder-7B-I
If you have access to a Slurm cluster, we provide a `slurm/train.slurm` script that will automatically queue training jobs for you. Here's how you can use it:
```shell
sbatch --job-name=open_r1 --nodes=1 slurm/train.slurm --model {model_name} --task {task} --config {config_suffix} --accelerator {accelerator}
sbatch --job-name=open_r1 --nodes=1 slurm/train.slurm {model_name} {task} {config_suffix} {accelerator}
```
Here `{model_name}` and `{task}` are defined as above, while `{config_suffix}` refers to the specific config and `{accelerator}` refers to the choice of 🤗 Accelerate config in `recipes/accelerate_configs`. If you wish to override the default config parameters, you can provide them by appending a space-separated string like `'--arg1=value1 --arg2=value2'`. Here's a concrete example to run SFT on 1 node of 8 GPUs:
```shell
sbatch --job-name=open_r1 --nodes=1 slurm/train.slurm --model OpenR1-Distill-7B --task sft --config distill --accelerator zero3
# Launch on Slurm and override default hyperparameters
sbatch --job-name=open_r1 --nodes=1 slurm/train.slurm Qwen2.5-1.5B-Instruct sft demo zero3 '--per_device_train_batch_size=1 --num_train_epochs=5'
```
You can scale the number of nodes by increasing the `--nodes` flag.
For GRPO, we use 1 node for the vLLM server and N nodes for training. For example, to run GRPO on 1+1 nodes with mixed data and tensor parallelism, run:
```shell
sbatch --job-name=open_r1 --nodes=2 slurm/train.slurm --model Qwen2.5-1.5B-Instruct --task grpo --config demo --accelerator zero2 --dp 4 --tp 2
```
> [!NOTE]
> The configuration in `slurm/train.slurm` is optimised for the Hugging Face Compute Cluster and may require tweaking to be adapted to your own compute nodes.
### Customising the dataset mixture
To combine multiple datasets as a single training mixture, you can specify the `dataset_mixture` parameter in the YAML config file. Here's a template for how to do this:
```yaml
dataset_mixture:
datasets: # List of datasets to include in the mixture
- id: dataset_1 # Hub dataset ID
config: config_name_1 # Name of the dataset config
split: split_1 # Split to use from the dataset
columns: # Columns to keep
- column_1
- column_2
weight: 0.25 # Fraction of dataset to use
- id: dataset_2
config: config_name_2
split: split_2
columns:
- column_1
- column_2
weight: 0.5
seed: 42 # Seed for shuffling the combined dataset
test_split_size: 0.1 # Fraction of mixture to use for a test split
```
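To make the roles of `weight`, `seed` and `test_split_size` concrete, here is a minimal sketch using plain Python lists. It is an assumption-laden illustration rather than the repository's loader: `build_mixture` is a hypothetical name, and taking the first `weight` fraction of each dataset before shuffling is one possible reading of "fraction of dataset to use".

```python
import random


def build_mixture(datasets, weights, seed=42, test_split_size=0.1):
    """Subsample each dataset by its weight, shuffle the union, and split off a test set."""
    mixture = []
    for rows, weight in zip(datasets, weights):
        keep = int(len(rows) * weight)  # fraction of this dataset to use
        mixture.extend(rows[:keep])
    random.Random(seed).shuffle(mixture)  # seeded, so the order is reproducible
    n_test = int(len(mixture) * test_split_size)
    return mixture[n_test:], mixture[:n_test]  # (train, test)


dataset_1 = list(range(100))
dataset_2 = list(range(100, 200))
train, test = build_mixture([dataset_1, dataset_2], weights=[0.25, 0.5])
print(len(train), len(test))  # 68 7
```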
## Evaluating models
We use `lighteval` to evaluate models. For models which fit on a single GPU, run:
We use `lighteval` to evaluate models, with custom tasks defined in `src/open_r1/evaluate.py`. For models which fit on a single GPU, run:
```shell
export VLLM_WORKER_MULTIPROC_METHOD=spawn # Required for vLLM
MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
MODEL_ARGS="model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
OUTPUT_DIR=data/evals/$MODEL
# AIME 2024
TASK=aime24
lighteval vllm $MODEL_ARGS "lighteval|$TASK|0|0" \
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks src/open_r1/evaluate.py \
--use-chat-template \
--output-dir $OUTPUT_DIR
# MATH-500
TASK=math_500
lighteval vllm $MODEL_ARGS "lighteval|$TASK|0|0" \
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks src/open_r1/evaluate.py \
--use-chat-template \
--output-dir $OUTPUT_DIR
# GPQA Diamond
TASK=gpqa:diamond
lighteval vllm $MODEL_ARGS "lighteval|$TASK|0|0" \
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks src/open_r1/evaluate.py \
--use-chat-template \
--output-dir $OUTPUT_DIR
@ -475,18 +307,22 @@ lighteval vllm $MODEL_ARGS "extended|lcb:codegeneration|0|0" \
--output-dir $OUTPUT_DIR
```
> [!IMPORTANT]
> You must set `max_model_length=32768` in the `vllm` command to align with the `max_new_tokens` we define per eval. Without this, `lighteval` will throw an error.
To increase throughput across multiple GPUs, use _data parallel_ as follows:
```shell
NUM_GPUS=8
MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
MODEL_ARGS="model_name=$MODEL,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
TASK=aime24
OUTPUT_DIR=data/evals/$MODEL
lighteval vllm $MODEL_ARGS "lighteval|$TASK|0|0" \
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks src/open_r1/evaluate.py \
--use-chat-template \
--output-dir $OUTPUT_DIR
--output-dir $OUTPUT_DIR
```
For large models which require sharding across GPUs, use _tensor parallel_ and run:
@ -494,14 +330,15 @@ For large models which require sharding across GPUs, use _tensor parallel_ and r
```shell
NUM_GPUS=8
MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B
MODEL_ARGS="model_name=$MODEL,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
TASK=aime24
OUTPUT_DIR=data/evals/$MODEL
export VLLM_WORKER_MULTIPROC_METHOD=spawn
lighteval vllm $MODEL_ARGS "lighteval|$TASK|0|0" \
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks src/open_r1/evaluate.py \
--use-chat-template \
--output-dir $OUTPUT_DIR
--output-dir $OUTPUT_DIR
```
You can also launch an evaluation with `make evaluate`, specifying the model, task, and optionally the parallelism technique and number of GPUs.
@ -526,40 +363,32 @@ make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLE
## Reproducing DeepSeek's evaluation results
The DeepSeek-R1 paper uses sampling with 4-64 responses per query to estimate `pass@1` accuracy, but does not state the exact number of responses used for each benchmark. In the tables below, we estimate `pass@1` accuracy with the following number of responses per query:
| Benchmark | Number of responses per query |
|:-------------:|:-----------------------------:|
| AIME 2024 | 64 |
| MATH-500 | 4 |
| GPQA Diamond | 8 |
| LiveCodeBench | 16 |
Note that for benchmarks like AIME24, it is important to sample many responses as there are only 30 problems and this can introduce high variance across repeated runs. The choice of how many responses to sample per prompt likely explains the small differences between our evaluation results and those reported by DeepSeek.
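As a rough illustration of the estimate described above, `pass@1` with several responses per query can be computed as the mean per-problem fraction of correct responses. The sketch below is not the LightEval implementation; the function names and the boolean input layout are assumptions made for illustration. The second helper uses a simple binomial approximation to show why a 30-problem benchmark is noisy when only one response is sampled.

```python
from math import sqrt
from statistics import mean


def pass_at_1(correct):
    """Estimate pass@1 from sampled responses.

    `correct` is a hypothetical layout: one list per problem, holding one
    boolean per sampled response (True if the response was judged correct).
    """
    # Fraction of correct responses for each problem, then averaged over problems.
    per_problem = [mean(1.0 if c else 0.0 for c in samples) for samples in correct]
    return mean(per_problem)


def single_sample_stderr(accuracy, num_problems):
    """Binomial standard error of accuracy when each problem is sampled once."""
    return sqrt(accuracy * (1.0 - accuracy) / num_problems)


if __name__ == "__main__":
    scores = [[True, False, True, True], [False, False, True, False]]
    print(f"pass@1 = {pass_at_1(scores):.3f}")
    print(f"stderr at 50% accuracy on 30 problems = {single_sample_stderr(0.5, 30):.3f}")
```

Under this approximation, a model with a true accuracy near 50% on 30 problems has a single-sample standard error of about 9 percentage points, which is why averaging over many responses matters for AIME 2024.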
> [!NOTE]
> The DeepSeek-R1 paper uses sampling with 64 responses per query to estimate `pass@1`. Below, we report the results from sampling 1 response per query, which likely explains the small 1-3σ discrepancies between our results and theirs.
### AIME 2024
We are able to reproduce DeepSeek's reported results on the AIME 2024 benchmark within ~1-3 standard deviations:
| Model | AIME 2024 (🤗 LightEval) | AIME 2024 (DeepSeek Reported) |
|:------------------------------|:------------------------:|:-----------------------------:|
| DeepSeek-R1-Distill-Qwen-1.5B | 30.7 | 28.9 |
| DeepSeek-R1-Distill-Qwen-7B | 50.8 | 55.5 |
| DeepSeek-R1-Distill-Qwen-14B | 65.9 | 69.7 |
| DeepSeek-R1-Distill-Qwen-32B | 69.7 | 72.6 |
| DeepSeek-R1-Distill-Llama-8B | 43.9 | 41.7 |
| DeepSeek-R1-Distill-Llama-70B | 63.0 | 70.0 |
|:------------------------------|:-----------------------:|:----------------------------:|
| DeepSeek-R1-Distill-Qwen-1.5B | 26.7 | 28.9 |
| DeepSeek-R1-Distill-Qwen-7B | 56.6 | 55.5 |
| DeepSeek-R1-Distill-Qwen-14B | 60.0 | 69.7 |
| DeepSeek-R1-Distill-Qwen-32B | 73.2 | 72.6 |
| DeepSeek-R1-Distill-Llama-8B | 43.3 | 50.4 |
| DeepSeek-R1-Distill-Llama-70B | 73.3 | 70.0 |
To reproduce these results use the following command:
```shell
NUM_GPUS=1 # Set to 8 for 32B and 70B models
MODEL=deepseek-ai/{model_name}
MODEL_ARGS="model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
OUTPUT_DIR=data/evals/$MODEL
lighteval vllm $MODEL_ARGS "lighteval|aime24|0|0" \
lighteval vllm $MODEL_ARGS "custom|aime24|0|0" \
--custom-tasks src/open_r1/evaluate.py \
--use-chat-template \
--output-dir $OUTPUT_DIR
```
@ -576,23 +405,23 @@ We are able to reproduce Deepseek's reported results on the MATH-500 benchmark w
| Model | MATH-500 (🤗 LightEval) | MATH-500 (DeepSeek Reported) |
|:------------------------------|:-----------------------:|:----------------------------:|
| DeepSeek-R1-Distill-Qwen-1.5B | 83.1 | 83.9 |
| DeepSeek-R1-Distill-Qwen-7B | 94.5 | 92.8 |
| DeepSeek-R1-Distill-Qwen-14B | 94.1 | 93.9 |
| DeepSeek-R1-Distill-Qwen-32B | 95.6 | 94.3 |
| DeepSeek-R1-Distill-Qwen-1.5B | 84.6 | 83.9 |
| DeepSeek-R1-Distill-Qwen-7B | 93.0 | 92.8 |
| DeepSeek-R1-Distill-Qwen-14B | 95.0 | 93.9 |
| DeepSeek-R1-Distill-Qwen-32B | 96.6 | 94.3 |
| DeepSeek-R1-Distill-Llama-8B | 88.6 | 89.1 |
| DeepSeek-R1-Distill-Llama-70B | 95.1 | 94.5 |
| DeepSeek-R1-Distill-Llama-70B | 96.4 | 94.5 |
To reproduce these results use the following command:
```shell
export VLLM_WORKER_MULTIPROC_METHOD=spawn
NUM_GPUS=1 # Set to 8 for 32B and 70B models
MODEL=deepseek-ai/{model_name}
MODEL_ARGS="model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
OUTPUT_DIR=data/evals/$MODEL
lighteval vllm $MODEL_ARGS "lighteval|math_500|0|0" \
lighteval vllm $MODEL_ARGS "custom|math_500|0|0" \
--custom-tasks src/open_r1/evaluate.py \
--use-chat-template \
--output-dir $OUTPUT_DIR
```
@ -609,23 +438,23 @@ We are able to reproduce Deepseek's reported results on the GPQA Diamond benchma
| Model | GPQA Diamond (🤗 LightEval) | GPQA Diamond (DeepSeek Reported) |
|:------------------------------|:---------------------------:|:--------------------------------:|
| DeepSeek-R1-Distill-Qwen-1.5B | 35.8 | 33.8 |
| DeepSeek-R1-Distill-Qwen-1.5B | 34.3 | 33.8 |
| DeepSeek-R1-Distill-Qwen-7B | 50.5 | 49.1 |
| DeepSeek-R1-Distill-Qwen-14B | 61.5 | 59.1 |
| DeepSeek-R1-Distill-Qwen-32B | 63.1 | 62.1 |
| DeepSeek-R1-Distill-Llama-8B | 46.7 | 49.0 |
| DeepSeek-R1-Distill-Llama-70B | 67.4 | 65.2 |
| DeepSeek-R1-Distill-Qwen-14B | 59.6 | 59.1 |
| DeepSeek-R1-Distill-Qwen-32B | 63.6 | 62.1 |
| DeepSeek-R1-Distill-Llama-8B | 52.0 | 49.0 |
| DeepSeek-R1-Distill-Llama-70B | 67.2 | 65.2 |
To reproduce these results use the following command:
```shell
export VLLM_WORKER_MULTIPROC_METHOD=spawn
NUM_GPUS=1 # Set to 8 for 32B and 70B models
MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
MODEL_ARGS="model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL=deepseek-ai/{model_name}
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
OUTPUT_DIR=data/evals/$MODEL
lighteval vllm $MODEL_ARGS "lighteval|gpqa:diamond|0|0" \
lighteval vllm $MODEL_ARGS "custom|gpqa:diamond|0|0" \
--custom-tasks src/open_r1/evaluate.py \
--use-chat-template \
--output-dir $OUTPUT_DIR
```
@ -638,21 +467,21 @@ python scripts/run_benchmarks.py --model-id {model_id} --benchmarks gpqa
We are able to reproduce DeepSeek's reported results on the LiveCodeBench code generation benchmark within ~1-3 standard deviations:
| Model | LiveCodeBench (🤗 LightEval) | LiveCodeBench (DeepSeek Reported) |
|:------------------------------|:----------------------------:|:---------------------------------:|
| DeepSeek-R1-Distill-Qwen-1.5B | 16.1 | 16.9 |
| DeepSeek-R1-Distill-Qwen-7B | 37.4 | 37.6 |
| DeepSeek-R1-Distill-Qwen-14B | 51.3 | 53.1 |
| DeepSeek-R1-Distill-Qwen-32B | 56.0 | 57.2 |
| DeepSeek-R1-Distill-Llama-8B | 37.4 | 39.6 |
| DeepSeek-R1-Distill-Llama-70B | 55.9 | 57.5 |
| Model | LiveCodeBench (🤗 LightEval) | LiveCodeBench (DeepSeek Reported) |
|:------------------------------|:----------------------------:|:--------------------------------:|
| DeepSeek-R1-Distill-Qwen-1.5B | 16.3 | 16.9 |
| DeepSeek-R1-Distill-Qwen-7B | 36.6 | 37.6 |
| DeepSeek-R1-Distill-Qwen-14B | 51.5 | 53.1 |
| DeepSeek-R1-Distill-Qwen-32B | 56.6 | 57.2 |
| DeepSeek-R1-Distill-Llama-8B | 37.0 | 39.6 |
| DeepSeek-R1-Distill-Llama-70B | 54.5 | 57.5 |
To reproduce these results use the following command:
```shell
NUM_GPUS=1 # Set to 8 for 32B and 70B models, or data_parallel_size=8 with the smaller models for speed
MODEL=deepseek-ai/{model_name}
MODEL_ARGS="model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
OUTPUT_DIR=data/evals/$MODEL
lighteval vllm $MODEL_ARGS "extended|lcb:codegeneration|0|0" \
@ -751,38 +580,6 @@ sbatch slurm/generate.slurm \
> [!NOTE]
> While the job is running, you can set up an SSH tunnel through the cluster login node to access the Ray dashboard from your computer by running `ssh -L 8265:ray_ip_head_node:8265 <login_node>` and then browsing to `http://localhost:8265`.
### Data decontamination
Following [s1: Simple test-time scaling](https://huggingface.co/papers/2501.19393), the data can be decontaminated with the script [scripts/decontaminate.py](./scripts/decontaminate.py), which decontaminates a dataset using 8-grams and deduplicates the data. Sample run:
```shell
python scripts/decontaminate.py \
--dataset "open-r1/verifiable-coding-problems-python" \
--problem_column problem \
--cleanup
```
The script decontaminates against the benchmark datasets and, with `--cleanup`, removes the contaminated samples afterwards. If `--new_dataset_name` is not provided, the original dataset name is reused with a `_decontaminated` suffix. The check runs against the prompt column, which for this dataset is `problem`; a different column can be passed with `--problem_column`.
Arguments for the script:
```shell
usage: decontaminate.py [-h] --dataset DATASET [--split SPLIT] [--ngram_size NGRAM_SIZE] [--problem_column PROBLEM_COLUMN] [--cleanup] [--new_dataset_name NEW_DATASET_NAME]
options:
-h, --help show this help message and exit
--dataset DATASET Name of the dataset to check for contamination.
--split SPLIT Split to check for contamination, defaults to `train`.
--ngram_size NGRAM_SIZE
Size of n-grams to build, defaults to 8.
--problem_column PROBLEM_COLUMN
Name of the column containing the problem (prompt).
--cleanup Whether to remove the contaminated rows before pushing the dataset.
--new_dataset_name NEW_DATASET_NAME
New name for the dataset. If not provided, will reuse the name and add a `_decontaminated` to the name.
```
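The n-gram overlap check that the script description refers to can be sketched as follows. This is a simplified illustration, not the code in `scripts/decontaminate.py`: whitespace tokenization, lowercasing, and the helper names `ngrams` and `is_contaminated` are assumptions.

```python
def ngrams(text, n=8):
    """Return the set of word-level n-grams in `text` (lowercased, whitespace-split)."""
    tokens = text.lower().split()
    # Texts shorter than n tokens yield an empty set.
    return {" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}


def is_contaminated(problem, benchmark_ngrams, n=8):
    """Flag a problem that shares at least one n-gram with the benchmark set."""
    return not ngrams(problem, n).isdisjoint(benchmark_ngrams)


if __name__ == "__main__":
    benchmark = "Find the number of ordered pairs of positive integers whose sum is ten"
    benchmark_ngrams = ngrams(benchmark)

    train = [
        "Please find the number of ordered pairs of positive integers below one hundred",
        "Compute the area of a triangle with sides three four and five units long",
    ]
    # Keep only the rows that share no 8-gram with the benchmark.
    clean = [p for p in train if not is_contaminated(p, benchmark_ngrams)]
    print(len(clean))
```

A set lookup keeps the check cheap per row; the n-gram size trades recall against false positives, and 8 is the default named in the script's help text.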
## Contributing
Contributions are welcome. Please refer to https://github.com/huggingface/open-r1/issues/23.
@ -799,7 +596,7 @@ If you find this project is useful in your own work, please consider citing as f
@misc{openr1,
title = {Open R1: A fully open reproduction of DeepSeek-R1},
url = {https://github.com/huggingface/open-r1},
author = {{Hugging Face}},
author = {Hugging Face},
month = {January},
year = {2025}
}


@ -7,22 +7,61 @@ attn_implementation: flash_attention_2
# Data training arguments
# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant>'}}{% endif %}"
dataset_name: open-r1/OpenR1-Math-220k
dataset_prompt_column: problem
dataset_name: agentica-org/DeepScaleR-Preview-Dataset
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# Generation arguments
max_completion_length: 2048
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
do_eval: false
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: DeepSeek-R1-Distill-Qwen-1.5B-GRPO
hub_strategy: every_save
learning_rate: 5.0e-07
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: -1
num_generations: 8
temperature: 0.7
top_p: 0.95
# Reward func arguments
num_train_epochs: 1
output_dir: data/DeepSeek-R1-Distill-Qwen-1.5B-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 8
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- tag_count
- format
reward_weights:
- 1.0
# Filtering arguments. Samples with a pass rate outside the interval `pass_rate_min < x < pass_rate_max` will be filtered.
pass_rate_min: 0.2
pass_rate_max: 0.8
- 1.0
- 1.0
save_strategy: "steps"
save_steps: 0.2
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1


@ -8,12 +8,13 @@ attn_implementation: flash_attention_2
# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant>'}}{% endif %}"
dataset_name: open-r1/OpenR1-Math-220k
dataset_prompt_column: problem
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true
@ -54,5 +55,4 @@ save_strategy: "epoch"
save_total_limit: 1
seed: 42
temperature: 0.7
use_liger_kernel: true
warmup_ratio: 0.1


@ -0,0 +1,67 @@
# Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
dataset_configs:
- all
dataset_train_split: train
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.8
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 32
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: open-r1/DeepSeek-R1-Distill-Qwen-1.5B-v00.00
hub_strategy: every_save
learning_rate: 1.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 16000
max_steps: -1
num_train_epochs: 0.1
num_generations: 16
output_dir: data/open-r1/DeepSeek-R1-Distill-Qwen-1.5B-v00.00
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 2
push_to_hub: true
beta: 0.04
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.1
use_liger_kernel: true
report_to:
- wandb
wandb_entity: huggingface
wandb_project: open-r1
log_completions: true
seed: 42
warmup_ratio: 0.1
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 100
# callbacks:
# - push_to_hub_revision
# benchmarks:
# - math_500_8k
# - aime24_8k
# - gsm8k_8k


@ -0,0 +1,67 @@
# Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
dataset_configs:
- all
dataset_train_split: train
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.8
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 2
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: open-r1/DeepSeek-R1-Distill-Qwen-1.5B-v00.01
hub_strategy: every_save
learning_rate: 1.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 16000
max_steps: -1
num_train_epochs: 0.1
num_generations: 14
output_dir: data/open-r1/DeepSeek-R1-Distill-Qwen-1.5B-v00.01
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 2
push_to_hub: true
beta: 0.04
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.1
use_liger_kernel: true
report_to:
- wandb
wandb_entity: huggingface
wandb_project: open-r1
log_completions: true
seed: 42
warmup_ratio: 0.1
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 100
# callbacks:
# - push_to_hub_revision
# benchmarks:
# - math_500_8k
# - aime24_8k
# - gsm8k_8k


@ -0,0 +1,67 @@
# Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
dataset_configs:
- all
dataset_train_split: train
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.8
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: open-r1/DeepSeek-R1-Distill-Qwen-1.5B-RGRPO-v00.01
hub_strategy: every_save
learning_rate: 1.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 32768
max_steps: -1
num_train_epochs: 1.0
num_generations: 16
output_dir: data/open-r1/DeepSeek-R1-Distill-Qwen-1.5B-RGRPO-v00.01
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 2
push_to_hub: true
beta: 0.04
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.1
use_liger_kernel: true
report_to:
- wandb
wandb_entity: huggingface
wandb_project: open-r1
log_completions: true
seed: 42
warmup_ratio: 0.1
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 100
# callbacks:
# - push_to_hub_revision
# benchmarks:
# - math_500_8k
# - aime24_8k
# - gsm8k_8k


@ -0,0 +1,43 @@
# Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: AI-MO/NuminaMath-TIR
dataset_configs:
- all
num_processes: 1
ddp_find_unused_parameters: false
# GRPO trainer config
# use_vllm: true
bf16: true
# ref_model_url: http://127.0.0.1:8000
do_eval: false
eval_strategy: "no"
eval_steps: 100
gradient_accumulation_steps: 2
gradient_checkpointing: false
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: DeepSeek-R1-Distill-Qwen-1.5-GRPO
hub_strategy: every_save
learning_rate: 1.0e-06
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 512
max_steps: -1
num_train_epochs: 1
output_dir: data/DeepSeek-R1-Distill-Qwen-1.5-GRPO-v001
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1


@ -0,0 +1,42 @@
# To start the training, run the following command:
# sbatch -N 4 --job-name=mistral_sft slurm/train.slurm Mistral-Small-24B-Instruct-2501 sft numina zero3
model_name_or_path: mistralai/Mistral-Small-24B-Instruct-2501
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
# dataset_name: yentinglin/s1K-1.1-trl-format
dataset_name: yentinglin/OpenR1-Math-220k-trl-format
preprocessing_num_workers: 8
# SFT trainer config
bf16: true
do_eval: true
eval_strategy: "no"
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: Mistral-Small-24B-Instruct-2501-Open-R1-Distill
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
packing: true
max_length: 32768
max_steps: -1
num_train_epochs: 5
output_dir: data/Mistral-Small-24B-Instruct-2501-Open-R1-Distill
overwrite_output_dir: true
per_device_eval_batch_size: 1
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
save_strategy: epoch
seed: 42
warmup_ratio: 0.1


@ -45,5 +45,5 @@ save_only_model: true # needed to bypass FSDP errors with saving paged optimizer
save_strategy: epoch
save_total_limit: 1
seed: 42
use_liger_kernel: false # fails on multi-node
use_liger: false # fails on multi-node
warmup_ratio: 0.03


@ -42,5 +42,5 @@ report_to:
save_strategy: epoch
save_total_limit: 1
seed: 42
use_liger_kernel: true
use_liger: true
warmup_ratio: 0.03


@ -1,48 +0,0 @@
# Config for 1 node of 8 x H100s (80GB)
# Model arguments
model_name_or_path: open-r1/Qwen2.5-Math-7B-RoPE-300k
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
chat_template: "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Open-R1, a language model trained by Hugging Face to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Now, try to solve the following question through the above guidelines.' 
}}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Open-R1, a language model trained by Hugging Face to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. 
Now, try to solve the following question through the above guidelines.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n"
dataset_name: open-r1/Mixture-of-Thoughts
dataset_config: all
dataset_num_proc: 12
eos_token: <|im_end|>
# SFT trainer config
bf16: true
do_eval: false
eval_strategy: 'no'
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: OpenR1-Distill-7B
hub_strategy: every_save
learning_rate: 4.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
packing: false
max_grad_norm: 0.2
max_length: 32768
max_steps: -1
num_train_epochs: 5
output_dir: data/OpenR1-Distill-7B
overwrite_output_dir: true
per_device_eval_batch_size: 1
per_device_train_batch_size: 2
push_to_hub: true
report_to:
- wandb
save_strategy: epoch
save_total_limit: 1
seed: 42
use_liger_kernel: true
warmup_ratio: 0.03


@ -0,0 +1,48 @@
# Model arguments
# You need to download the model and manually change the rope to 300k and max_position_embeddings to 32768
# the config file should match https://huggingface.co/open-r1/OpenR1-Qwen-7B/blob/main/config.json
model_name_or_path: Qwen/Qwen2.5-Math-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: open-r1/OpenR1-Math-220k
dataset_num_proc: 48
# SFT hyperparameters
max_length: 32768
weight_decay: 0.0001
optim: adamw_torch
lr_scheduler_type: linear
warmup_ratio: 0.1
learning_rate: 5.0e-05
gradient_accumulation_steps: 2
per_device_eval_batch_size: 1
per_device_train_batch_size: 1
# SFT trainer config
max_steps: -1
num_train_epochs: 3
bf16: true
do_eval: false
use_liger_kernel: true
eval_strategy: 'no'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: OpenR1-Qwen-7B-SFT
hub_strategy: every_save
log_level: info
logging_steps: 5
logging_strategy: steps
packing: true
output_dir: data/OpenR1-Qwen-7B-SFT
overwrite_output_dir: true
push_to_hub: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 500
save_total_limit: 1
seed: 42


@ -6,12 +6,13 @@ attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-220k
dataset_prompt_column: problem
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true


@ -6,13 +6,14 @@ attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/verifiable-coding-problems-python
dataset_prompt_column: problem_statement
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
beta: 0.01
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true

@ -0,0 +1,64 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
do_eval: false
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-1.5B-Instruct-GRPO
hub_model_revision: v00.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: 1000
num_generations: 16
num_train_epochs: 1
output_dir: data/Qwen2.5-1.5B-Instruct-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1
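
To make the `reward_funcs` / `reward_weights` pairing in this config concrete, here is a minimal sketch of how the per-function scores could be folded into one scalar reward per completion. It assumes the trainer takes a plain weighted sum; the `combine_rewards` helper and the example scores are hypothetical and not part of this repository.

```python
from typing import Sequence


def combine_rewards(scores: Sequence[float], weights: Sequence[float]) -> float:
    """Weighted sum of the per-function reward scores for one completion."""
    if len(scores) != len(weights):
        raise ValueError("need exactly one weight per reward function")
    return sum(s * w for s, w in zip(scores, weights))


# Mirrors the config: accuracy weighted 1.0, format weighted 0.2.
REWARD_WEIGHTS = [1.0, 0.2]

# Hypothetical (accuracy, format) scores for three completions.
correct_and_formatted = combine_rewards([1.0, 1.0], REWARD_WEIGHTS)
correct_only = combine_rewards([1.0, 0.0], REWARD_WEIGHTS)
formatted_only = combine_rewards([0.0, 1.0], REWARD_WEIGHTS)

print(correct_and_formatted, correct_only, formatted_only)
```

With these weights a well-formatted but wrong answer earns far less than a correct one, which is presumably the intent of keeping the format weight small.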

@ -5,57 +5,61 @@ torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/ioi
dataset_prompt_column: problem
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
beta: 0.01
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.001
bf16: true
use_vllm: true
do_eval: false
gradient_accumulation_steps: 4
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: false
gradient_accumulation_steps: 14
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: Qwen2.5-1.5B-Open-R1-Code-GRPO
hub_model_id: open-r1/Qwen2.5-1.5B-Instruct-RGRPO
hub_model_revision: v01.02
hub_strategy: every_save
learning_rate: 5.0e-06
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
  min_lr_rate: 0.1
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 2048
max_steps: 500
max_completion_length: 4096
max_steps: -1
num_generations: 14
num_train_epochs: 1
output_dir: data/Qwen2.5-1.5B-Open-R1-Code-GRPO
output_dir: data/Qwen2.5-1.5B-Instruct-RGRPO_v01.02
overwrite_output_dir: true
per_device_train_batch_size: 16
per_device_train_batch_size: 4
remote_gen_model_url: 26.0.164.45
push_to_hub: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 50
save_total_limit: 1
seed: 42
temperature: 1.0
warmup_ratio: 0.03
# ioi specific config
code_language: cpp
reward_funcs:
- ioi_code
- code_format
- accuracy
- format
reward_weights:
- 1.0
- 0.1
- 0.1
# For each generation, evaluate this many test cases in parallel, then check whether any of them failed (score 0). If so, stop evaluating;
# otherwise continue with the next batch of test cases. This avoids overloading the eval server and saves time on wrong solutions.
code_eval_test_batch_size: 3
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

@ -0,0 +1,44 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-220k
dataset_num_proc: 48
# SFT trainer config
bf16: true
do_eval: false
eval_strategy: 'no'
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: Qwen2.5-1.5B-Open-R1-Distill
hub_strategy: every_save
learning_rate: 5.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
  min_lr_rate: 0.1
packing: true
max_length: 16384
max_steps: -1
num_train_epochs: 1
output_dir: data/Qwen2.5-1.5B-Open-R1-Distill
overwrite_output_dir: true
per_device_eval_batch_size: 16
per_device_train_batch_size: 16
push_to_hub: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
use_liger: true
warmup_ratio: 0.05

@ -0,0 +1,64 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: agentica-org/DeepScaleR-Preview-Dataset
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
do_eval: false
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-7B-Instruct-GRPO
hub_model_revision: v01.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: 1000
num_generations: 16
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Instruct-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1
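
Some values in this config are ratios instead of step counts. As a rough sketch, assuming the Hugging Face `Trainer` convention that a `save_steps` below 1 is read as a fraction of the total training steps and that `warmup_ratio` is likewise multiplied by the total, the settings above with `max_steps: 1000` resolve as follows. The `resolve_steps` helper is hypothetical, and rounding up is an assumption about the trainer's behavior.

```python
import math


def resolve_steps(value: float, total_steps: int) -> int:
    """Turn a step setting into an absolute step count.

    Values below 1 are treated as a ratio of total_steps (rounded up);
    values of 1 or more are taken as an absolute number of steps.
    """
    if value < 1:
        return math.ceil(total_steps * value)
    return int(value)


MAX_STEPS = 1000
save_every = resolve_steps(0.1, MAX_STEPS)    # save_steps: 0.1
warmup_steps = resolve_steps(0.1, MAX_STEPS)  # warmup_ratio: 0.1

print(f"checkpoint every {save_every} steps, warm up for {warmup_steps} steps")
```

Note that configs with `max_steps: -1` leave the total to be derived from the dataset size and epoch count, so the same ratio can resolve to a different number of steps there.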

@ -0,0 +1,64 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: false
gradient_accumulation_steps: 14
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-7B-Instruct-GRPO
hub_model_revision: v00.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: -1
num_generations: 14
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Instruct-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

@ -0,0 +1,64 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: false
gradient_accumulation_steps: 14
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-7B-Instruct-RGRPO
hub_model_revision: v01.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: -1
num_generations: 14
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Instruct-RGRPO_v01.00
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

@ -0,0 +1,64 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.0
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: false
gradient_accumulation_steps: 14
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-7B-Instruct-RGRPO
hub_model_revision: v01.01
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: -1
num_generations: 14
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Instruct-RGRPO_v01.01
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

@ -0,0 +1,64 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- math_500
- aime24
beta: 0.0
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: false
gradient_accumulation_steps: 14
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-7B-Instruct-RGRPO
hub_model_revision: v01.02
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: -1
num_generations: 14
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Instruct-RGRPO_v01.02
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

@ -1,80 +0,0 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-Coder-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/codeforces
dataset_prompt_column: prompt
dataset_config: verifiable-prompts
dataset_test_split: test
dataset_train_split: train
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- lcb_v4
beta: 0.0
loss_type: dr_grpo
scale_rewards: false
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
gradient_accumulation_steps: 32
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-Codeforces-GRPO
hub_model_revision: v01.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 2000
max_completion_length: 8192
max_steps: -1
num_generations: 16
# aiming for 1k optimization steps
# total_samples_per_batch = num_gpus * grad_accumulation_steps * per_device_batch_size = 8 * 32 * 4 = 1024
# unique_prompts_per_batch = total_samples_per_batch / num_generations = 1024 / 16 = 64
# #dataset ~= 16k (8k * 2, for python and cpp)
# global_steps_per_epoch = #dataset / unique_prompts_per_batch = 16k / 64 ~= 250
# epochs_for_1k_steps = 1000/250 = 4 epochs
num_train_epochs: 4
output_dir: data/Qwen2.5-Coder-7B-Instruct-Codeforces-GRPO_v01.00
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- cf_code
- code_format
reward_weights:
- 1.0
- 0.1
save_strategy: "steps"
save_steps: 0.05
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1
mask_truncated_completions: true
# For each generation, evaluate this many test cases in parallel, then check whether any of them failed (score 0). If so, stop evaluating;
# otherwise continue with the next batch of test cases. This avoids overloading the eval server and saves time on wrong solutions.
code_eval_test_batch_size: -1
code_eval_scoring_mode: weighted_sum
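
The step-count comment in this config can be reproduced with a few lines of arithmetic. This is a sketch only: the figure of 8 GPUs and the dataset size of roughly 16k prompts are taken from the comment itself, and the `grpo_schedule` function is a hypothetical helper, not something this repository defines.

```python
def grpo_schedule(num_gpus, grad_accum, per_device_bs, num_generations,
                  dataset_size, target_steps):
    """Reproduce the step-count arithmetic from the config comment."""
    # Completions consumed per optimization step across all devices.
    samples_per_batch = num_gpus * grad_accum * per_device_bs
    if samples_per_batch % num_generations != 0:
        raise ValueError("batch must hold a whole number of generation groups")
    # Each prompt contributes num_generations completions.
    prompts_per_batch = samples_per_batch // num_generations
    steps_per_epoch = dataset_size // prompts_per_batch
    epochs = target_steps / steps_per_epoch
    return samples_per_batch, prompts_per_batch, steps_per_epoch, epochs


schedule = grpo_schedule(
    num_gpus=8,             # assumed from the comment
    grad_accum=32,          # gradient_accumulation_steps
    per_device_bs=4,        # per_device_train_batch_size
    num_generations=16,
    dataset_size=16_000,    # "~16k" in the comment
    target_steps=1000,
)
print(schedule)
```

Changing `num_generations` or the GPU count shifts the number of unique prompts per step, so `num_train_epochs` would need to be recomputed to keep the same step budget.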

@ -0,0 +1,60 @@
# Model arguments
model_name_or_path: open-r1/Qwen2.5-Coder-7B-Instruct-SFT
model_revision: v06.11-step-000004005
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/verifiable-coding-problems-python_decontaminated
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- lcb
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
do_eval: false
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO
hub_model_revision: v00.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 8192
max_steps: 1000
num_generations: 16
num_train_epochs: 1
output_dir: data/Qwen2.5-Coder-7B-Instruct-SFT-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- code
reward_weights:
- 1.0
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1

@ -0,0 +1,59 @@
# Model arguments
model_name_or_path: open-r1/Qwen2.5-Coder-7B-Instruct-SFT
model_revision: v00.08-step-000001280
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/verifiable-coding-problems-python-10k_decontaminated
dataset_configs:
- all
num_processes: 7
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO
hub_model_revision: v00.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 31744
max_steps: -1
num_train_epochs: 5
num_generations: 7
output_dir: data/open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO-v00.00
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 2
push_to_hub: true
report_to:
- wandb
seed: 42
warmup_ratio: 0.1
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 25
callbacks:
- push_to_hub_revision
benchmarks:
- lcb
reward_funcs:
# - code
- code_format
reward_weights:
# - 1.0
- 0.1

@ -0,0 +1,59 @@
# Model arguments
model_name_or_path: open-r1/Qwen2.5-Coder-7B-Instruct-SFT
model_revision: v00.08-step-000001280
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/verifiable-coding-problems-python-10k_decontaminated
dataset_configs:
- all
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO
hub_model_revision: v00.00_remote
hub_strategy: every_save
learning_rate: 1.0e-06
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 30700
max_steps: -1
num_train_epochs: 5
num_generations: 16
output_dir: data/open-r1/Qwen2.5-Coder-7B-Instruct-SFT-GRPO-v00.00_remote
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
seed: 42
warmup_ratio: 0.1
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 25
callbacks:
- push_to_hub_revision
benchmarks:
- lcb
use_liger: true
reward_funcs:
- code
- code_format
reward_weights:
- 1.0
- 0.1

@ -0,0 +1,64 @@
# Model arguments
model_name_or_path: open-r1/Qwen2.5-Coder-7B-Instruct-SFT
model_revision: v02.12-step-000003170
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/verifiable-coding-problems-python_decontaminated-tested
dataset_configs:
- all
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-SFT-RGRPO
hub_model_revision: v01.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 30000
max_steps: -1
num_train_epochs: 1
num_generations: 16
output_dir: data/open-r1/Qwen2.5-Coder-7B-Instruct-SFT-RGRPO-v01.00
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
seed: 42
warmup_ratio: 0.1
beta: 0.01
remote_gen_model_url: 26.0.165.131
num_iterations: 1
# Saving and eval callbacks
# save_strategy: "steps"
# save_steps: 25
# callbacks:
# - push_to_hub_revision
# benchmarks:
# - lcb
reward_funcs:
- code
- code_format
reward_weights:
- 1.0
- 0.1

@ -0,0 +1,53 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-Math-7B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: DigitalLearningGmbH/MATH-lighteval
dataset_config: default
system_prompt: "You are a helpful AI Assistant, designed to provided well-reasoned and detailed responses. You FIRST think about the reasoning process as an internal monologue and then provide the user with the answer. The reasoning process MUST BE enclosed within <think> and </think> tags."
# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: Qwen-2.5-7B-Simple-RL
hub_strategy: every_save
learning_rate: 3.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_generations: 7
num_train_epochs: 1
output_dir: data/Qwen-2.5-7B-Simple-RL
overwrite_output_dir: true
per_device_eval_batch_size: 16
per_device_train_batch_size: 16
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: "no"
seed: 42
warmup_ratio: 0.1

@ -1,23 +1,15 @@
# Post-training recipes
## OpenR1 Distill 7B
To train the OpenR1 Distill 7B model, run:
```
sbatch --nodes=1 slurm/train.slurm --model OpenR1-Distill-7B --task sft --config distill --accelerator zero3
```
## OlympicCoder
To train the OlympicCoder models, run:
```
# 7B
sbatch --nodes=1 slurm/train.slurm --model OlympicCoder-7B --task sft --config v00.00 --accelerator zero3
sbatch --nodes=1 slurm/train.slurm OlympicCoder-7B sft v00.00 zero3
# 32B
sbatch --nodes=16 slurm/train.slurm --model OlympicCoder-32B --task sft --config v00.00 --accelerator fsdp
sbatch --nodes=16 slurm/train.slurm OlympicCoder-32B sft v00.00 fsdp
```
Note that we found it necessary to switch to FSDP1 and paged AdamW 8-bit for the 32B model in order to fit the largest possible context size.

@ -0,0 +1,65 @@
# Model arguments
model_name_or_path: HuggingFaceTB/SmolLM2-1.7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
dataset_configs:
- all
dataset_train_split: train
num_processes: 7
ddp_find_unused_parameters: false
# GRPO trainer config
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 64
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/SmolLM2-1.7B-Instruct-FGRPO
hub_model_revision: v05.00
hub_strategy: every_save
learning_rate: 1.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 7168
max_steps: -1
num_train_epochs: 0.5
num_generations: 16
output_dir: data/open-r1/SmolLM2-1.7B-Instruct-FGRPO-v05.00
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 2
push_to_hub: true
beta: 0.04
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.1
use_liger_kernel: true
report_to:
- wandb
wandb_entity: huggingface
wandb_project: open-r1
log_completions: true
seed: 42
warmup_ratio: 0.02
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 10
# callbacks:
# - push_to_hub_revision
# benchmarks:
# - math_500_8k
# - aime24_8k
# - gsm8k_8k

@ -0,0 +1,63 @@
# Model arguments
model_name_or_path: open-r1/SMOLLM_I8k-GR2-deepseek
model_revision: main-step-000000300
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
dataset_configs:
- all
dataset_train_split: train
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/SmolLM2-1.7B-Instruct-GRPO-v05.01
hub_strategy: every_save
learning_rate: 1.0e-04
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 7168
max_steps: -1
num_train_epochs: 0.5
num_generations: 16
output_dir: data/open-r1/SmolLM2-1.7B-Instruct-GRPO-v05.01
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 8
push_to_hub: true
beta: 0.0
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.1
use_liger_kernel: true
report_to:
- wandb
wandb_entity: huggingface
wandb_project: open-r1
log_completions: true
seed: 42
warmup_ratio: 0.02
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 10
callbacks:
- push_to_hub_revision

@ -0,0 +1,70 @@
# Model arguments
model_name_or_path: HuggingFaceTB/SmolLM2-1.7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
dataset_configs:
- all
dataset_train_split: train
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/SmolLM2-1.7B-Instruct-GRPO-v05.19
hub_strategy: every_save
learning_rate: 1.0e-04
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 7172
max_steps: -1
num_train_epochs: 1.0
num_generations: 16
output_dir: data/open-r1/SmolLM2-1.7B-Instruct-GRPO-v05.19
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 8
push_to_hub: true
beta: 0.01
remote_gen_model_url: 26.0.160.225
num_iterations: 4
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.1
use_liger_kernel: true
report_to:
- wandb
wandb_entity: huggingface
wandb_project: open-r1
log_completions: true
seed: 42
warmup_ratio: 0.02
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 10
# callbacks:
# - push_to_hub_revision
# benchmarks:
# - math_500
# - aime24
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"

@ -0,0 +1,46 @@
# Model arguments
# For a longer context, download the model and manually change the RoPE base to 300k/500k and max_position_embeddings to 32768 in its config.json
model_name_or_path: HuggingFaceTB/SmolLM2-1.7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: open-r1/OpenR1-Math-220k
dataset_num_proc: 48
# SFT hyperparameters
max_length: 8192 # Can be raised to 32768, but only after changing the RoPE settings in the model's config.json
weight_decay: 0.0001
optim: adamw_torch
lr_scheduler_type: linear
warmup_ratio: 0.1
learning_rate: 5.0e-05
gradient_accumulation_steps: 2
per_device_eval_batch_size: 4
per_device_train_batch_size: 4 # Adjust to the model's context length so that the global batch size (GBS) stays at 500M.
# SFT trainer config
max_steps: -1
num_train_epochs: 3
bf16: true
do_eval: false
eval_strategy: 'no'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: OpenR1-Qwen-7B-SFT
hub_strategy: every_save
log_level: info
logging_steps: 5
logging_strategy: steps
packing: true
output_dir: data/OpenR1-Qwen-7B-SFT
overwrite_output_dir: true
push_to_hub: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 500
save_total_limit: 1
seed: 42

@ -0,0 +1,46 @@
# Model arguments
# For a longer context, download the model and manually change the RoPE base to 300k/500k and max_position_embeddings to 32768 in its config.json
model_name_or_path: HuggingFaceTB/SmolLM2-1.7B
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: open-r1/OpenR1-Math-220k
dataset_num_proc: 48
# SFT hyperparameters
max_length: 8192 # Can be raised to 32768, but only after changing the RoPE settings in the model's config.json
weight_decay: 0.0001
optim: adamw_torch
lr_scheduler_type: linear
warmup_ratio: 0.1
learning_rate: 5.0e-05
gradient_accumulation_steps: 2
per_device_eval_batch_size: 4
per_device_train_batch_size: 4 # Adjust to the model's context length so that the global batch size (GBS) stays at 500M.
# SFT trainer config
max_steps: -1
num_train_epochs: 3
bf16: true
do_eval: false
eval_strategy: 'no'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: OpenR1-Qwen-7B-SFT
hub_strategy: every_save
log_level: info
logging_steps: 5
logging_strategy: steps
packing: true
output_dir: data/OpenR1-Qwen-7B-SFT
overwrite_output_dir: true
push_to_hub: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 500
save_total_limit: 1
seed: 42

@ -0,0 +1,67 @@
# Model arguments
model_name_or_path: HuggingFaceTB/SmolLM2-135M-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
dataset_configs:
- all
dataset_train_split: train
num_processes: 8
ddp_find_unused_parameters: false
# GRPO trainer config
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.8
bf16: true
do_eval: false
eval_strategy: "no"
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/SmolLM2-135M-Instruct-GRPO-v00.01
hub_strategy: every_save
learning_rate: 1.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 1024
max_completion_length: 2048
max_steps: -1
num_train_epochs: 0.1
num_generations: 4
output_dir: data/open-r1/SmolLM2-135M-Instruct-GRPO-v00.01
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 8
push_to_hub: true
beta: 0.04
remote_gen_model_url: 0.0.0.0
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.1
use_liger_kernel: true
report_to:
- wandb
wandb_entity: huggingface
wandb_project: open-r1
log_completions: true
seed: 42
warmup_ratio: 0.1
# Saving and eval callbacks
save_strategy: "steps"
save_steps: 100
# callbacks:
# - push_to_hub_revision
# benchmarks:
# - math_500_8k
# - aime24_8k
# - gsm8k_8k
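
The `reward_funcs` and `reward_weights` lists in this config are paired by position (`accuracy` with 1.0, `format` with 0.1). A minimal sketch of how such weights are commonly applied, assuming a plain weighted sum per completion; the trainer code is not shown in this file, so `combine_rewards` is a hypothetical helper, not the trainer's own function:

```python
def combine_rewards(per_function_rewards, weights):
    """Return the weighted sum of the per-function rewards for one completion."""
    if len(per_function_rewards) != len(weights):
        raise ValueError("Each reward function needs exactly one weight.")
    return sum(r * w for r, w in zip(per_function_rewards, weights))


# Hypothetical completion: correct answer (accuracy = 1.0) but wrong format (format = 0.0).
total = combine_rewards([1.0, 0.0], [1.0, 0.1])
print(total)
```

With these weights the format reward can shift the total by at most 0.1, so accuracy dominates the training signal.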

View file

@ -0,0 +1,47 @@
{
"bf16": {
"enabled": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"sub_group_size": 1e9,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": "auto"
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false,
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": true,
"contiguous_memory_optimization": false,
"number_checkpoints": null,
"synchronize_checkpoint_boundary": false,
"profile": false
}
}

View file

@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View file

@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_config_file: recipes/accelerate_configs/deepspeed3_offload.json
  zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false

View file

@ -0,0 +1,30 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
  offload_param:
    device: cpu
    pin_memory: true
  activation_checkpointing:
    partition_activations: true
    contiguous_memory_optimization: false
    cpu_checkpointing: true
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View file

@ -1,28 +0,0 @@
# Model arguments
model_name_or_path: open-r1/R1-Distill-Qwen-Math-7B
model_revision: v03.00-step-000008190
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works
dataset_name: open-r1/DAPO-Math-17k-Processed
dataset_config: all
dataset_split: train
# Generation arguments
max_completion_length: 32000
num_generations: 8
temperature: 1.0
# Reward func arguments
reward_funcs:
- accuracy
reward_weights:
- 1.0
# Filtering arguments. Samples with mean reward outside of low / high will be filtered
pass_rate_min: 0.1
pass_rate_max: 0.6
output_dataset_name: open-r1/DAPO-Math-17k-Processed-R1-Distill-Qwen-Math-7B-v03.00-step-000008190-filter
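
The filtering comment in this config says that samples whose mean reward falls outside the low / high bounds are dropped. A minimal sketch of that rule, assuming binary rewards and inclusive bounds; both are assumptions, and `pass_rate` and `keep_sample` are hypothetical helpers, not functions from the repository:

```python
def pass_rate(rewards):
    """Mean reward over all generations sampled for one prompt."""
    return sum(rewards) / len(rewards)


def keep_sample(rewards, pass_rate_min=0.1, pass_rate_max=0.6):
    """Keep a prompt only if its pass rate lies inside [pass_rate_min, pass_rate_max]."""
    return pass_rate_min <= pass_rate(rewards) <= pass_rate_max


# Hypothetical accuracy rewards for num_generations = 8.
too_easy = [1, 1, 1, 1, 1, 1, 1, 1]  # pass rate 1.0, filtered out
too_hard = [0, 0, 0, 0, 0, 0, 0, 0]  # pass rate 0.0, filtered out
useful = [1, 0, 0, 1, 0, 0, 0, 1]    # pass rate 0.375, kept

for name, rewards in [("too_easy", too_easy), ("too_hard", too_hard), ("useful", useful)]:
    print(name, pass_rate(rewards), keep_sample(rewards))
```

Prompts the model always or never solves give every generation the same reward, so a group-relative method such as GRPO gets no learning signal from them; that is the usual motivation for this kind of filter.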

View file

@ -1,26 +0,0 @@
# Model arguments
model_name_or_path: open-r1/R1-Distill-Qwen-Math-7B-Merges
model_revision: v00.00-step-000003660_v01.00-step-000002600_weights-0.50-0.50
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works
dataset_name: open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled
dataset_prompt_column: problem
# Generation arguments
max_completion_length: 16000
num_generations: 8
temperature: 0.7
# Reward func arguments
reward_funcs:
- binary_code
reward_weights:
- 1.0
e2b_router_url: ip-10-53-85-92:8000
# Filtering arguments. Samples with mean reward outside of low / high will be filtered
pass_rate_min: 0.1
pass_rate_max: 0.6

View file

@ -0,0 +1,46 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-Math-7B
model_revision: main
torch_dtype: bfloat16
# Data training arguments
dataset_name: DigitalLearningGmbH/MATH-lighteval
dataset_configs:
- train
# num_processes is one less than the number of GPUs because vLLM occupies 1 GPU
num_processes: 7
# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: open-r1/Qwen-2.5-7B_Base_Math_smalllr_remote_model
hub_strategy: every_save
learning_rate: 3.0e-06
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_train_epochs: 1
output_dir: data/Qwen-2.5-7B_Base_Math_smalllr_remote_model
overwrite_output_dir: true
per_device_eval_batch_size: 1
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
ref_model_url: http://26.0.163.127:30010

View file

@ -1,85 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Benchmark script for the code_reward function with E2B.
This script measures the performance of the code_reward function with varying numbers
of samples and parallelization levels.
Each sample is a CodeForces problem with a gold standard solution that is executed against a set of public test cases.
"""
from datasets import load_dataset
import time
from tqdm.auto import tqdm

from dotenv import load_dotenv
load_dotenv()

from open_r1.rewards import code_reward


def benchmark_code_reward(example):
    start_time = time.time()
    test_completions = [[{"content": example["gold_standard_solution"]}]]
    reward_kwargs = {"verification_info": [example["verification_info"]]}
    rewards = code_reward(test_completions, **reward_kwargs)
    end_time = time.time()
    example["test_reward"] = rewards[0]
    example["reward_time"] = end_time - start_time
    return example


if __name__ == "__main__":
    parallel_dict = {
        16: [1, 4, 16],
        64: [4, 16, 64],
        256: [16, 64, 96],  # cap at 96 as PRO account is limited to 100
    }
    # Store results for table formatting
    results = []
    for num_samples in tqdm([16, 64, 256], desc="Benchmarking samples"):
        for num_parallel in parallel_dict[num_samples]:
            code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated")
            code_dataset = code_dataset["train"].shuffle(seed=42).select(range(num_samples))
            test_completions = [[{"content": example["gold_standard_solution"]}] for example in code_dataset]
            reward_kwargs = {"verification_info": [example["verification_info"] for example in code_dataset]}

            start_time = time.time()
            rewards = code_reward(test_completions, num_parallel=num_parallel, **reward_kwargs)
            execution_time = time.time() - start_time

            # Calculate some statistics about rewards
            mean_reward = sum(rewards) / len(rewards)
            min_reward = min(rewards)
            max_reward = max(rewards)

            # Store results
            results.append({
                "num_samples": num_samples,
                "num_parallel": num_parallel,
                "execution_time": execution_time,
                "mean_reward": mean_reward,
                "min_reward": min_reward,
                "max_reward": max_reward
            })

    print("\n## Benchmark Results\n")
    print("| Sample Size | Parallelization | Execution Time (s) | Mean Reward | Min Reward | Max Reward |")
    print("|:-----------:|:---------------:|------------------:|:-----------:|:-----------:|:-----------:|")
    for result in results:
        print(f"| {result['num_samples']:^11} | {result['num_parallel']:^15} | {result['execution_time']:17.2f} | {result['mean_reward']:^11.4f} | {result['min_reward']:^11.4f} | {result['max_reward']:^11.4f} |")

View file

@ -15,7 +15,7 @@
# limitations under the License.
"""
This script is used to decontaminate a dataset by checking for n-gram overlap with other datasets.
It uses the same approach presented in https://huggingface.co/papers/2501.19393,
It uses the same approach presented in https://arxiv.org/abs/2501.19393,
as found in: https://github.com/simplescaling/s1/blob/main/data/decontaminate_util.py
Usage:

View file

@ -1,161 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
from typing import Optional

import uvicorn
from dotenv import load_dotenv
from e2b_code_interpreter import AsyncSandbox
from e2b_code_interpreter.models import Execution
from fastapi import FastAPI, Request
from pydantic import BaseModel, ConfigDict
load_dotenv()
class BatchRequest(BaseModel):
"""
BatchRequest is a data model representing a batch processing request.
Attributes:
scripts (list[str]): The source code of each script to be executed.
languages (list[str]): The programming languages for each script in the list.
timeout (int): The maximum allowed execution time for each script in seconds.
request_timeout (int): The maximum allowed time for the entire batch request in seconds.
"""
scripts: list[str]
languages: list[str]
timeout: int
request_timeout: int
class ScriptResult(BaseModel):
"""
ScriptResult is a Pydantic model that represents the result of a script execution.
Attributes:
execution (Optional[Execution]): An optional instance of the `Execution` class
that contains details about the script's execution, such as status, output,
or any other relevant metadata.
exception_str (Optional[str]): An optional string that captures the exception
message or details if an error occurred during the script's execution.
model_config (ConfigDict): A configuration dictionary that allows arbitrary
types to be used within the Pydantic model. This is necessary to support
custom types like `Execution` within the model.
"""
execution: Optional[Execution]
exception_str: Optional[str]
# required to allow arbitrary types in pydantic models such as Execution
model_config = ConfigDict(arbitrary_types_allowed=True)
def create_app(args):
    """
    Creates and configures a FastAPI application instance.

    Args:
        args: An object containing configuration parameters for the application.
            - max_num_sandboxes (int): The maximum number of concurrent sandboxes allowed.

    Returns:
        FastAPI: A configured FastAPI application instance.

    The application includes the following endpoints:
        1. GET /health:
            - Returns the health status of the application.
            - Response: {"status": "ok"}
        2. POST /execute_batch:
            - Executes a batch of scripts in an isolated sandbox environment.
            - Request Body: BatchRequest object containing:
                - languages (list[str]): The programming languages of the scripts (python or javascript).
                - timeout (int): The maximum execution time for each script.
                - request_timeout (int): The timeout for the request itself.
                - scripts (list[str]): A list of scripts to execute.
            - Response: A list of ScriptResult objects for each script, containing:
                - execution: The result of the script execution.
                - exception_str: Any exception encountered during execution.

    Notes:
        - A semaphore is used to limit the number of concurrent sandboxes.
        - Each script execution is wrapped in a timeout to prevent hanging.
        - Sandboxes are cleaned up after execution, even in case of errors.
    """
    app = FastAPI()

    # Instantiate semaphore and attach it to app state
    app.state.sandbox_semaphore = asyncio.Semaphore(args.max_num_sandboxes)

    @app.get("/health")
    async def health():
        return {"status": "ok"}

    @app.post("/execute_batch")
    async def execute_batch(batch: BatchRequest, request: Request):
        semaphore = request.app.state.sandbox_semaphore
        languages = batch.languages
        timeout = batch.timeout
        request_timeout = batch.request_timeout
        asyncio_timeout = batch.timeout + 1

        async def run_script(script: str, language: str) -> ScriptResult:
            # Stays None if sandbox creation fails, so cleanup knows there is nothing to kill
            sandbox = None
            async with semaphore:
                try:
                    sandbox = await AsyncSandbox.create(
                        timeout=timeout,
                        request_timeout=request_timeout,
                    )
                    execution = await asyncio.wait_for(
                        sandbox.run_code(script, language=language),
                        timeout=asyncio_timeout,
                    )
                    return ScriptResult(execution=execution, exception_str=None)
                except Exception as e:
                    return ScriptResult(execution=None, exception_str=str(e))
                finally:
                    if sandbox is not None:
                        try:
                            await sandbox.kill()
                        except Exception:
                            pass

        tasks = [run_script(script, lang) for script, lang in zip(batch.scripts, batch.languages)]
        return await asyncio.gather(*tasks)

    return app
def parse_args():
"""
Parse command-line arguments for the e2b_router script.
Arguments:
--host (str): The hostname or IP address to bind the server to. Defaults to "0.0.0.0" (binds to all interfaces).
--port (int): The port number on which the server will listen. Defaults to 8000.
--max_num_sandboxes (int): The maximum number of sandboxes that can be created or managed simultaneously. Defaults to 20.
Returns:
argparse.Namespace: Parsed command-line arguments as an object.
"""
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--max_num_sandboxes", type=int, default=20)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
app = create_app(args)
uvicorn.run(app, host=args.host, port=args.port)
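
The `/execute_batch` endpoint described in the `create_app` docstring expects a JSON body with the four `BatchRequest` fields. A minimal client-side sketch using only the standard library; the base URL `http://localhost:8000` and the helper name `build_batch_request` are assumptions for illustration, and the request is only built here, not sent:

```python
import json
import urllib.request


def build_batch_request(base_url, scripts, language="python", timeout=30, request_timeout=30):
    """Build a POST request for /execute_batch with one language entry per script."""
    payload = {
        "scripts": list(scripts),
        "languages": [language] * len(scripts),
        "timeout": timeout,
        "request_timeout": request_timeout,
    }
    return urllib.request.Request(
        url=f"{base_url}/execute_batch",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# Assumed local router address; the port matches the script's default of 8000.
request = build_batch_request("http://localhost:8000", ["print(1 + 1)", "print('hi')"])
print(request.full_url)
print(request.data.decode("utf-8"))
# To send it against a running router:
#   with urllib.request.urlopen(request, timeout=60) as response:
#       results = json.loads(response.read())
```

The response is expected to hold one result per script, in the same order as `scripts`, because the server gathers the tasks in input order.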

250
scripts/faster_grpo.py Normal file
View file

@ -0,0 +1,250 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
from dataclasses import dataclass, field
import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import GRPOConfig
from open_r1.rewards import (
accuracy_reward,
format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
reasoning_steps_reward,
)
from open_r1.trainers.faster_grpo_trainer import FastGRPOTrainer, FastGRPOConfig
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
logger = logging.getLogger(__name__)
@dataclass
class GRPOScriptArguments(ScriptArguments):
"""
Script arguments for the GRPO training script.
Args:
reward_funcs (`list[str]`):
List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'.
cosine_min_value_wrong (`float`):
Minimum reward for cosine scaling for wrong answers.
cosine_max_value_wrong (`float`):
Maximum reward for cosine scaling for wrong answers.
cosine_min_value_correct (`float`):
Minimum reward for cosine scaling for correct answers.
cosine_max_value_correct (`float`):
Maximum reward for cosine scaling for correct answers.
cosine_max_len (`int`):
Maximum length for cosine scaling.
"""
reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format"],
metadata={
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
},
)
cosine_min_value_wrong: float = field(
default=0.0,
metadata={"help": "Minimum reward for wrong answers"},
)
cosine_max_value_wrong: float = field(
default=-0.5,
metadata={"help": "Maximum reward for wrong answers"},
)
cosine_min_value_correct: float = field(
default=0.5,
metadata={"help": "Minimum reward for correct answers"},
)
cosine_max_value_correct: float = field(
default=1.0,
metadata={"help": "Maximum reward for correct answers"},
)
cosine_max_len: int = field(
default=1000,
metadata={"help": "Maximum length for scaling"},
)
repetition_n_grams: int = field(
default=3,
metadata={"help": "Number of n-grams for repetition penalty reward"},
)
repetition_max_penalty: float = field(
default=-1.0,
metadata={"help": "Maximum (negative) penalty for the repetition penalty reward"},
)
SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>"
)
def main(script_args, training_args, model_args):
# Set seed for reproducibility
set_seed(training_args.seed)
###############
# Setup logging
###############
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process a small summary
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Training parameters {training_args}")
# Check for last checkpoint
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
if "wandb" in training_args.report_to:
init_wandb_training(training_args)
# Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
# Get reward functions
REWARD_FUNCS_REGISTRY = {
"accuracy": accuracy_reward,
"format": format_reward,
"reasoning_steps": reasoning_steps_reward,
"cosine": get_cosine_scaled_reward(
min_value_wrong=script_args.cosine_min_value_wrong,
max_value_wrong=script_args.cosine_max_value_wrong,
min_value_correct=script_args.cosine_min_value_correct,
max_value_correct=script_args.cosine_max_value_correct,
max_len=script_args.cosine_max_len,
),
"repetition_penalty": get_repetition_penalty_reward(
ngram_size=script_args.repetition_n_grams,
max_penalty=script_args.repetition_max_penalty,
),
"length": len_reward,
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
# Format into conversation
def make_conversation(example):
return {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["problem"]},
],
}
dataset = dataset.map(make_conversation)
for split in dataset:
if "messages" in dataset[split].column_names:
dataset[split] = dataset[split].remove_columns("messages")
logger.info("*** Initializing model kwargs ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
)
training_args.model_init_kwargs = model_kwargs
#############################
# Initialize the Async GRPO trainer
#############################
trainer = FastGRPOTrainer(
model=model_args.model_name_or_path,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
callbacks=get_callbacks(training_args, model_args),
)
###############
# Training loop
###############
logger.info("*** Train ***")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
##################################
# Save model and create model card
##################################
logger.info("*** Save model ***")
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")
# Save everything else on main process
kwargs = {
"dataset_name": script_args.dataset_name,
"tags": ["open-r1"],
}
if trainer.accelerator.is_main_process:
trainer.create_model_card(**kwargs)
# Restore k,v cache for fast inference
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(training_args.output_dir)
#############
# push to hub
#############
if training_args.push_to_hub:
logger.info("Pushing to hub...")
trainer.push_to_hub(**kwargs)
if __name__ == "__main__":
parser = TrlParser((GRPOScriptArguments, FastGRPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)
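
The `cosine_*` script arguments above configure a length-dependent reward, and the defaults for wrong answers (`cosine_min_value_wrong=0.0`, `cosine_max_value_wrong=-0.5`) can look inverted at first glance. A sketch of one formulation that makes sense of them, assuming a half-cosine schedule in which the two bounds are swapped for wrong answers. The real `get_cosine_scaled_reward` lives in `open_r1.rewards` and is not shown in this diff, so the function below is a hypothetical stand-in, and the clamping of `length` is an added assumption:

```python
import math


def cosine_scaled_reward(length, max_len, is_correct,
                         min_value_wrong, max_value_wrong,
                         min_value_correct, max_value_correct):
    """Scale a reward along a half cosine as the completion length grows."""
    progress = min(length, max_len) / max_len
    cosine = math.cos(progress * math.pi)  # 1.0 at length 0, -1.0 at max_len
    if is_correct:
        low, high = min_value_correct, max_value_correct
    else:
        # Assumed swap: for wrong answers the schedule runs from
        # min_value_wrong (length 0) to max_value_wrong (max_len).
        low, high = max_value_wrong, min_value_wrong
    return low + 0.5 * (high - low) * (1.0 + cosine)


# Values taken from the defaults in GRPOScriptArguments.
defaults = dict(min_value_wrong=0.0, max_value_wrong=-0.5,
                min_value_correct=0.5, max_value_correct=1.0)
for length in (0, 500, 1000):
    print(length,
          cosine_scaled_reward(length, 1000, True, **defaults),
          cosine_scaled_reward(length, 1000, False, **defaults))
```

Under this sketch a correct answer earns the most when it is short and decays toward `min_value_correct` as it approaches `cosine_max_len`.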

View file

@ -1,28 +0,0 @@
import argparse
from transformers import AutoConfig
from math import gcd


def get_tensor_parallel_size(model_name: str, revision: str = None, default_tp: int = 8) -> int:
    try:
        config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)
        num_heads = getattr(config, 'num_attention_heads', None)

        if num_heads is not None and num_heads % default_tp != 0:
            tp = gcd(num_heads, default_tp)
            return max(tp, 1)
        else:
            return default_tp
    except Exception as e:
        print(f"Warning: Failed to fetch config for {model_name}@{revision}: {e}")
        return default_tp


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True, help="Hugging Face model name or path")
    parser.add_argument("--revision", type=str, default=None, help="Model revision if applicable")
    parser.add_argument("--default_tp", type=int, default=8, help="Default TP size (usually GPUs per node)")

    args = parser.parse_args()

    tp = get_tensor_parallel_size(args.model_name, args.revision, args.default_tp)
    print(tp)
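
The fallback in `get_tensor_parallel_size` is easiest to see with concrete head counts. The sketch below isolates the arithmetic so that it runs without downloading a model config; `pick_tensor_parallel_size` is a hypothetical name and the head counts are made-up inputs:

```python
from math import gcd


def pick_tensor_parallel_size(num_heads, default_tp=8):
    """Return default_tp if it divides num_heads, else their greatest common divisor."""
    if num_heads % default_tp == 0:
        return default_tp
    return max(gcd(num_heads, default_tp), 1)


for heads in (32, 28, 12, 9):
    print(heads, pick_tensor_parallel_size(heads))
```

With 28 attention heads and 8 GPUs, for example, 8 does not divide 28, so the size falls back to gcd(28, 8) = 4, which keeps the heads evenly split across ranks.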

View file

@ -1,173 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import os
from typing import List, Optional

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from pydantic import BaseModel, ConfigDict
load_dotenv()
class BatchRequest(BaseModel):
"""
BatchRequest is a data model representing a batch processing request.
Attributes:
scripts (List[str]): The source code of each script to be executed.
languages (List[str]): The programming languages for each script in the list.
timeout (int): The maximum allowed execution time for each script in seconds.
request_timeout (int): The maximum allowed time for the entire batch request in seconds.
"""
scripts: List[str]
languages: List[str]
timeout: int
request_timeout: int
class ScriptResult(BaseModel):
"""
ScriptResult is a Pydantic model that represents the result of a script execution.
Attributes:
text (Optional[str]): The output text from the script execution.
exception_str (Optional[str]): An optional string that captures the exception
message or details if an error occurred during the script's execution.
model_config (ConfigDict): A configuration dictionary that allows arbitrary
types to be used within the Pydantic model.
"""
text: Optional[str]
exception_str: Optional[str]
model_config = ConfigDict(arbitrary_types_allowed=True)
def create_app(args):
    """
    Creates and configures a FastAPI application instance for the MorphCloud router.

    Args:
        args: An object containing configuration parameters for the application.
            - max_num_sandboxes (int): The maximum number of concurrent sandboxes allowed.
            - api_key (str): The MorphCloud API key to use.

    Returns:
        FastAPI: A configured FastAPI application instance.
    """
    app = FastAPI()

    from morphcloud.api import MorphCloudClient
    from morphcloud.sandbox import Sandbox

    app.state.client = MorphCloudClient(api_key=args.api_key)
    app.state.Sandbox = Sandbox

    app.state.sandbox_semaphore = asyncio.Semaphore(args.max_num_sandboxes)

    @app.get("/health")
    async def health():
        return {"status": "ok"}

    @app.post("/execute_batch")
    async def execute_batch(batch: BatchRequest, request: Request):
        semaphore = request.app.state.sandbox_semaphore
        client = request.app.state.client
        Sandbox = request.app.state.Sandbox

        languages = batch.languages
        timeout = batch.timeout
        request_timeout = batch.request_timeout
        asyncio_timeout = batch.timeout + 1

        async def run_script(script: str, language: str) -> ScriptResult:
            sandbox = None
            sandbox_id = "unknown"

            async with semaphore:
                try:
                    sandbox = await asyncio.to_thread(
                        Sandbox.new,
                        client=client,
                        ttl_seconds=timeout
                    )

                    sandbox_id = getattr(sandbox, 'id', None) or getattr(sandbox._instance, 'id', 'unknown')

                    execution = await asyncio.wait_for(
                        asyncio.to_thread(
                            sandbox.run_code,
                            script,
                            language=language,
                            timeout=timeout * 1000
                        ),
                        timeout=asyncio_timeout,
                    )

                    if hasattr(execution, 'text') and execution.text:
                        return ScriptResult(text=execution.text, exception_str=None)
                    elif hasattr(execution, 'stdout') and execution.stdout:
                        return ScriptResult(text=execution.stdout, exception_str=None)
                    else:
                        return ScriptResult(text="", exception_str="No output from execution")

                except Exception as e:
                    return ScriptResult(text=None, exception_str=str(e))

                finally:
                    if sandbox:
                        try:
                            await asyncio.to_thread(sandbox.close)
                            await asyncio.to_thread(sandbox.shutdown)
                        except Exception:
                            pass

        tasks = [run_script(script, lang) for script, lang in zip(batch.scripts, batch.languages)]
        return await asyncio.gather(*tasks)

    return app
def parse_args():
"""
Parse command-line arguments for the morph_router script.
Arguments:
--host (str): The hostname or IP address to bind the server to. Defaults to "0.0.0.0".
--port (int): The port number on which the server will listen. Defaults to 8001.
--max_num_sandboxes (int): The maximum number of sandboxes that can be created simultaneously. Defaults to 20.
--api_key (str): The MorphCloud API key. If not provided, it will be read from the MORPH_API_KEY environment variable.
Returns:
argparse.Namespace: Parsed command-line arguments as an object.
"""
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--max_num_sandboxes", type=int, default=20)
parser.add_argument("--api_key", default=os.getenv("MORPH_API_KEY"))
args = parser.parse_args()
if not args.api_key:
raise ValueError("MorphCloud API key not provided. Please set MORPH_API_KEY environment variable or use --api_key.")
return args
if __name__ == "__main__":
args = parse_args()
app = create_app(args)
print(f"Starting MorphCloud Router on {args.host}:{args.port}")
uvicorn.run(app, host=args.host, port=args.port)
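
Both router scripts rely on the same concurrency pattern: a semaphore caps how many sandboxes run at once, `asyncio.wait_for` bounds each script, and `asyncio.gather` returns results in input order. A self-contained sketch of that pattern with fake jobs in place of real sandboxes; `run_batch`, `fake_script` and `slow_script` are hypothetical names, and the dictionaries only mimic the shape of `ScriptResult`:

```python
import asyncio


async def run_batch(jobs, max_concurrent, per_job_timeout):
    """Run all jobs, at most `max_concurrent` at a time; return results and peak concurrency."""
    semaphore = asyncio.Semaphore(max_concurrent)
    running = 0
    peak = 0

    async def run_one(job):
        nonlocal running, peak
        async with semaphore:
            running += 1
            peak = max(peak, running)
            try:
                value = await asyncio.wait_for(job(), timeout=per_job_timeout)
                return {"text": value, "exception_str": None}
            except Exception as exc:
                # repr() is used because str() of a timeout error is empty
                return {"text": None, "exception_str": repr(exc)}
            finally:
                running -= 1

    results = await asyncio.gather(*(run_one(job) for job in jobs))
    return results, peak


async def fake_script():
    await asyncio.sleep(0.01)
    return "ok"


async def slow_script():
    await asyncio.sleep(1.0)
    return "never returned"


if __name__ == "__main__":
    results, peak = asyncio.run(run_batch([fake_script] * 10, max_concurrent=3, per_job_timeout=1.0))
    print(len(results), "results, peak concurrency", peak)
```

Because failures are converted into result objects inside `run_one`, one timed-out script does not cancel the rest of the batch.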

View file

@ -1,36 +0,0 @@
# Pass rate filtering
We provide support for filtering datasets by generating completions and computing the pass rate on verifiable tasks.

See `scripts/pass_rate_filtering/compute_pass_rate.py` and `scripts/pass_rate_filtering/launch_filtering.sh` (hardcoded for DAPO at the moment).

By default the script processes the dataset in chunks. The chunks can be merged with the following snippet (example for DAPO):

```python
from datasets import load_dataset, concatenate_datasets

name = "open-r1/DAPO-Math-17k-Processed-R1-Distill-Qwen-Math-7B-Merges-v00.02-v01.02-0.3-0.7-filter"

gen_datasets = []
filt_datasets = []
for start in range(0, 17400, 200):
    end = start + 200
    if start == 17200:
        end = 17398
    gen_config_name = f"gen-{start}-{end}"
    gen_dataset = load_dataset(name, gen_config_name, revision="gen", split="train")
    gen_datasets.append(gen_dataset)

    filt_config_name = f"filt-0.1-0.6-{start}-{end}"
    filt_dataset = load_dataset(name, filt_config_name, revision="pass_rate", split="train")
    filt_datasets.append(filt_dataset)

gen_dataset = concatenate_datasets(gen_datasets)
gen_dataset.push_to_hub(name, config_name="gen", split="train")
print(gen_dataset)

filt_dataset = concatenate_datasets(filt_datasets)
filt_dataset.push_to_hub(name, config_name="default", split="train")
print(filt_dataset)
```
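
The merge snippet hardcodes the final chunk boundary (17398) for DAPO. As a hedged sketch, the same `(start, end)` pairs can be derived from the dataset size, which avoids the special case; `chunk_ranges` is a hypothetical helper and is not part of the repository scripts:

```python
def chunk_ranges(num_rows, chunk_size):
    """Yield (start, end) pairs covering range(num_rows); the last chunk may be shorter."""
    for start in range(0, num_rows, chunk_size):
        yield start, min(start + chunk_size, num_rows)


# 17398 rows in chunks of 200, matching the numbers used in the snippet above.
ranges = list(chunk_ranges(17398, 200))
config_names = [f"gen-{start}-{end}" for start, end in ranges]
print(len(ranges), config_names[0], config_names[-1])
```

The resulting names follow the same `gen-{start}-{end}` scheme as the merge loop, so they could be passed to `load_dataset` in its place.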

View file

@ -1,205 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# example usage python scripts/filter_dataset.py --config recipes/dataset_filtering/config_demo.yaml
import logging
from dataclasses import dataclass
from typing import Optional
import torch
import sys
import datasets
import transformers
from datasets import load_dataset
from transformers import set_seed
from open_r1.configs import GRPOConfig, GRPOScriptArguments
from open_r1.rewards import get_reward_funcs
from open_r1.utils import get_tokenizer
from trl import ModelConfig, TrlParser
from trl.data_utils import apply_chat_template
from vllm import LLM, SamplingParams
logger = logging.getLogger(__name__)
@dataclass
class PassRateScriptArguments(GRPOScriptArguments):
# we can be lazy and just use the same script args as GRPO
output_dataset_name: Optional[str] = None
pass_rate_min: float = 0.1
pass_rate_max: float = 0.9
dataset_start_index: Optional[int] = None
dataset_end_index: Optional[int] = None
dataset_split: str = "train"
def main(script_args, training_args, model_args):
# Set seed for reproducibility
set_seed(training_args.seed)
###############
# Setup logging
###############
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Training parameters {training_args}")
# Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_split)
if script_args.dataset_start_index is not None and script_args.dataset_end_index is not None:
dataset = dataset.select(range(script_args.dataset_start_index, script_args.dataset_end_index))
# Get reward functions from the registry
reward_funcs = get_reward_funcs(script_args)
# Format into conversation
    def make_conversation(example, prompt_column: str = script_args.dataset_prompt_column):
        # Validate the column before reading it, so a missing column raises the intended error
        if prompt_column not in example:
            raise ValueError(f"Dataset Question Field Error: {prompt_column} is not supported.")
        example["prompt_backup"] = example[prompt_column]
        prompt = []
        if training_args.system_prompt is not None:
            prompt.append({"role": "system", "content": training_args.system_prompt})
        prompt.append({"role": "user", "content": example[prompt_column]})
        return {"prompt": prompt}
dataset = dataset.map(make_conversation)
tokenizer = get_tokenizer(model_args, training_args)
if "messages" in dataset.column_names:
dataset = dataset.remove_columns("messages")
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
llm = LLM(
model=model_args.model_name_or_path,
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
)
sampling_params=SamplingParams(
temperature=training_args.temperature,
top_p=training_args.top_p,
top_k=training_args.top_k,
n=training_args.num_generations,
max_tokens=training_args.max_completion_length,
)
def batch_score(examples):
prompts = examples["prompt"]
outputs = llm.generate(
prompts,
sampling_params=sampling_params,
use_tqdm=False,
)
repeated_prompts = []
reward_completions = []
grouped_completions = []
for output in outputs:
prompt = output.prompt
group = []
for completion in output.outputs:
text = completion.text
group.append(text)
message = [{"role": "assistant", "content": text}]
repeated_prompts.append(prompt)
reward_completions.append(message)
grouped_completions.append(group)
def repeat_each_element_k_times(list_to_repeat: list, k: int) -> list:
return [element for item in list_to_repeat for element in [item] * k]
rewards_per_func = torch.zeros(len(repeated_prompts), len(reward_funcs))
for i, reward_func in enumerate(reward_funcs):
keys = [key for key in examples.data.keys() if key not in ["prompt", "completion"]]
reward_kwargs = {key: repeat_each_element_k_times(examples[key], training_args.num_generations) for key in keys}
output_reward_func = reward_func(prompts=repeated_prompts, completions=reward_completions, **reward_kwargs)
# Convert None values to NaN
output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32)
reshaped_rewards = rewards_per_func.view(-1, training_args.num_generations)
examples["pass_rate_generations"] = grouped_completions
examples["pass_rate_rewards"] = reshaped_rewards.tolist()
return examples
dataset = dataset.map(batch_score, batched=True, batch_size=64)
# we need to restore the prompt for the final dataset
def restore_prompt(example):
example["prompt"] = example["prompt_backup"]
return example
dataset = dataset.map(restore_prompt)
dataset = dataset.remove_columns("prompt_backup")
if script_args.output_dataset_name is not None:
output_dataset_name = script_args.output_dataset_name
else:
model_name = model_args.model_name_or_path
if "/" in model_name:
model_name = model_name.split("/")[-1]
model_revision = model_args.model_revision
output_dataset_name = f"{script_args.dataset_name}-{model_name}-{model_revision}-gen"
config_name="default"
filtered_config_name = f"filt-{script_args.pass_rate_min}-{script_args.pass_rate_max}"
if script_args.dataset_start_index is not None and script_args.dataset_end_index is not None:
config_name = f"gen-{script_args.dataset_start_index}-{script_args.dataset_end_index}"
filtered_config_name = f"{filtered_config_name}-{script_args.dataset_start_index}-{script_args.dataset_end_index}"
dataset.push_to_hub(output_dataset_name, config_name=config_name, revision="gen")
    def filter_func(example):
        rewards = example["pass_rate_rewards"]
        # get the mean of the rewards, ignoring NaN entries (None rewards were converted to NaN above)
        mean_reward = torch.nanmean(torch.tensor(rewards, dtype=torch.float32))
        return script_args.pass_rate_min < mean_reward < script_args.pass_rate_max
logger.info(f"Filtering dataset with low reward threshold {script_args.pass_rate_min} and high reward threshold {script_args.pass_rate_max}")
logger.info(f"Dataset size before filtering: {dataset}")
dataset = dataset.filter(filter_func)
logger.info(f"Dataset size after filtering: {dataset}")
dataset.push_to_hub(output_dataset_name, config_name=filtered_config_name, revision="pass_rate")
if __name__ == "__main__":
parser = TrlParser((PassRateScriptArguments, GRPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)
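
The filtering step in the script above keeps an example only when its mean reward lies strictly between `pass_rate_min` and `pass_rate_max`, and it skips rewards that a reward function returned as `None`. The following is a minimal sketch of that rule in plain Python, without `torch`. The function names are hypothetical, and the defaults of 0.1 and 0.9 are copied from `PassRateScriptArguments`.

```python
import math
from typing import Optional, Sequence


def mean_ignoring_missing(rewards: Sequence[Optional[float]]) -> float:
    """Mean of the rewards, skipping None and NaN entries; NaN if nothing is left."""
    valid = [r for r in rewards if r is not None and not math.isnan(r)]
    if not valid:
        return math.nan
    return sum(valid) / len(valid)


def keep_example(
    rewards: Sequence[Optional[float]],
    pass_rate_min: float = 0.1,
    pass_rate_max: float = 0.9,
) -> bool:
    """Keep prompts that are neither always solved nor never solved."""
    mean_reward = mean_ignoring_missing(rewards)
    # Comparisons with NaN are False, so prompts without any valid reward are dropped
    return pass_rate_min < mean_reward < pass_rate_max
```

Because both bounds are strict, a prompt whose pass rate equals a threshold exactly is filtered out.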

View file

@ -1,15 +0,0 @@
# a bash for loop over start indices 0 to 17,000 in steps of 200 (chunks 0-200 through 17000-17200)
for i in {0..17000..200}
do
    START=$i
    END=$((i + 200))
    echo "Processing chunk from $START to $END"
    # Submit the job to SLURM
    sbatch slurm/compute_pass_rate.slurm recipes/dataset_filtering/filter_dapo.yaml $START $END
done
# The final chunk is shorter (17200 to 17398), so it is submitted separately
sbatch slurm/compute_pass_rate.slurm recipes/dataset_filtering/filter_dapo.yaml 17200 17398

301
scripts/remote_grpo.py Normal file
View file

@ -0,0 +1,301 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GRPO trainer to train on N + 1 nodes, with 1 node allocated for generation.
Usage:
For training, run:
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml scripts/remote_grpo.py \
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_remote.yaml
This will automatically spin up an SGLang server on a separate node and use it for generation.
For development, first spin up an SGLang server on a separate node:
python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-1.5B-Instruct --port=30010 --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size=8
Then run training by providing the IP address of the server:
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml scripts/remote_grpo.py \
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_remote.yaml \
--remote_gen_model_url ip-26-0-160-103
"""
import logging
import os
import sys
from dataclasses import dataclass, field
import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint
from open_r1.rewards import (
accuracy_reward,
code_reward,
format_reward,
get_code_format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
reasoning_steps_reward,
tag_count_reward,
)
from open_r1.utils import get_tokenizer
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import ModelConfig, ScriptArguments, TrlParser
from open_r1.trainers.remote_grpo_trainer import RemoteGRPOTrainer, RemoteGRPOConfig
logger = logging.getLogger(__name__)
@dataclass
class GRPOScriptArguments(ScriptArguments):
"""
Script arguments for the GRPO training script.
Args:
reward_funcs (`list[str]`):
List of reward functions. Possible values: 'accuracy', 'format', 'format_deepseek', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'tag_count', 'code', 'code_format'.
cosine_min_value_wrong (`float`):
Minimum reward for cosine scaling for wrong answers.
cosine_max_value_wrong (`float`):
Maximum reward for cosine scaling for wrong answers.
cosine_min_value_correct (`float`):
Minimum reward for cosine scaling for correct answers.
cosine_max_value_correct (`float`):
Maximum reward for cosine scaling for correct answers.
cosine_max_len (`int`):
Maximum length for cosine scaling.
code_language (`str`):
Language for code format reward.
"""
reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format", "tag_count"],
metadata={
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'format_deepseek', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'tag_count', 'code', 'code_format'"
},
)
cosine_min_value_wrong: float = field(
default=0.0,
metadata={"help": "Minimum reward for wrong answers"},
)
cosine_max_value_wrong: float = field(
default=-0.5,
metadata={"help": "Maximum reward for wrong answers"},
)
cosine_min_value_correct: float = field(
default=0.5,
metadata={"help": "Minimum reward for correct answers"},
)
cosine_max_value_correct: float = field(
default=1.0,
metadata={"help": "Maximum reward for correct answers"},
)
cosine_max_len: int = field(
default=1000,
metadata={"help": "Maximum length for scaling"},
)
repetition_n_grams: int = field(
default=3,
metadata={"help": "Number of n-grams for repetition penalty reward"},
)
repetition_max_penalty: float = field(
default=-1.0,
metadata={"help": "Maximum (negative) penalty for the repetition penalty reward"},
)
code_language: str = field(
default="python",
metadata={
"help": "Language for code format reward. Based on E2B supported languages https://e2b.dev/docs/code-interpreting/supported-languages",
"choices": ["python", "javascript", "r", "java", "bash"],
},
)
def main(script_args, training_args, model_args):
# Set seed for reproducibility
set_seed(training_args.seed)
###############
# Setup logging
###############
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process a small summary
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Training parameters {training_args}")
# Check for last checkpoint
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
if "wandb" in training_args.report_to:
init_wandb_training(training_args)
# Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
################
# Load tokenizer
################
tokenizer = get_tokenizer(model_args, training_args)
# Get reward functions
REWARD_FUNCS_REGISTRY = {
"accuracy": accuracy_reward,
"format": format_reward,
"reasoning_steps": reasoning_steps_reward,
"cosine": get_cosine_scaled_reward(
min_value_wrong=script_args.cosine_min_value_wrong,
max_value_wrong=script_args.cosine_max_value_wrong,
min_value_correct=script_args.cosine_min_value_correct,
max_value_correct=script_args.cosine_max_value_correct,
max_len=script_args.cosine_max_len,
),
"repetition_penalty": get_repetition_penalty_reward(
ngram_size=script_args.repetition_n_grams,
max_penalty=script_args.repetition_max_penalty,
),
"length": len_reward,
"code": code_reward,
"code_format": get_code_format_reward(language=script_args.code_language),
"tag_count": tag_count_reward,
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
# Format into conversation
def make_conversation(example):
prompt = []
if training_args.system_prompt is not None:
prompt.append({"role": "system", "content": training_args.system_prompt})
prompt.append({"role": "user", "content": example["problem"]})
return {"prompt": prompt}
dataset = dataset.map(make_conversation)
for split in dataset:
if "messages" in dataset[split].column_names:
dataset[split] = dataset[split].remove_columns("messages")
logger.info("*** Initializing model kwargs ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=not training_args.gradient_checkpointing,
)
training_args.model_init_kwargs = model_kwargs
#############################
# Initialize the GRPO trainer
#############################
trainer = RemoteGRPOTrainer(
model=model_args.model_name_or_path,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
callbacks=get_callbacks(training_args, model_args),
processing_class=tokenizer,
)
###############
# Training loop
###############
logger.info("*** Train ***")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
##################################
# Save model and create model card
##################################
logger.info("*** Save model ***")
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")
# Save everything else on main process
kwargs = {
"dataset_name": script_args.dataset_name,
"tags": ["open-r1"],
}
if trainer.accelerator.is_main_process:
# trainer.create_model_card(**kwargs) # Bug: needs fixing with TRL helper methods
# Restore k,v cache for fast inference
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(training_args.output_dir)
##########
# Evaluate
##########
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
#############
# push to hub
#############
if training_args.push_to_hub:
logger.info("Pushing to hub...")
trainer.push_to_hub(**kwargs)
if __name__ == "__main__":
parser = TrlParser((GRPOScriptArguments, RemoteGRPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)
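
In the script above, reward functions are resolved with a plain dictionary lookup on `REWARD_FUNCS_REGISTRY`. A name that appears in the help text but has no registry entry (for example `format_deepseek`, which this script's registry does not define) therefore fails with a bare `KeyError`. The following is a minimal sketch of a lookup that reports unknown names up front. `resolve_reward_funcs` and the two stand-in reward functions are hypothetical and only illustrate the idea; they are not the `open_r1` implementations.

```python
from typing import Callable


def resolve_reward_funcs(names: list[str], registry: dict[str, Callable]) -> list[Callable]:
    """Map reward names to callables, failing early with a readable message."""
    unknown = [name for name in names if name not in registry]
    if unknown:
        raise ValueError(
            f"Unknown reward function(s): {unknown}. Available: {sorted(registry)}"
        )
    return [registry[name] for name in names]


# Stand-in reward functions (assumptions for illustration only)
def dummy_accuracy(completions, **kwargs):
    return [1.0 for _ in completions]


def dummy_format(completions, **kwargs):
    return [0.0 for _ in completions]


registry = {"accuracy": dummy_accuracy, "format": dummy_format}
funcs = resolve_reward_funcs(["accuracy", "format"], registry)
```

Validating the whole list before building it means a misspelled name is reported before any model or dataset is loaded.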

View file

@ -44,21 +44,19 @@ _deps = [
"accelerate==1.4.0",
"bitsandbytes>=0.43.0",
"datasets>=3.2.0",
"deepspeed==0.16.8",
"deepspeed==0.15.4",
"distilabel[vllm,ray,openai]>=1.5.2",
"e2b-code-interpreter>=1.0.5",
"einops>=0.8.0",
"flake8>=6.0.0",
"hf_transfer>=0.1.4",
"huggingface-hub[cli,hf_xet]>=0.30.2,<1.0",
"huggingface-hub[cli]>=0.19.2,<1.0",
"isort>=5.12.0",
"jieba", # Needed for Chinese language support
"langdetect", # Needed for LightEval's extended tasks
"latex2sympy2_extended>=1.0.6",
"liger-kernel>=0.5.10",
"lighteval @ git+https://github.com/huggingface/lighteval.git@d3da6b9bbf38104c8b5e1acc86f83541f9a502d1", # Critical bug fix for tokenizer revisions: https://github.com/huggingface/lighteval/pull/721
"liger_kernel==0.5.3",
"lighteval @ git+https://github.com/huggingface/lighteval.git@ed084813e0bd12d82a06d9f913291fdbee774905",
"math-verify==0.5.2", # Used for math verification in grpo
"morphcloud==0.1.67",
"packaging>=23.0",
"parameterized>=0.9.0",
"peft>=0.14.0",
@ -67,13 +65,11 @@ _deps = [
"ruff>=0.9.0",
"safetensors>=0.3.3",
"sentencepiece>=0.1.99",
"torch==2.6.0",
"transformers==4.52.3",
"trl[vllm]==0.18.0",
"torch==2.5.1",
"transformers==4.48.3", # Must pin for SGLang
"trl @ git+https://github.com/huggingface/trl.git@e3244d2d096ff1e2e248c931d06d39e165e20623",
"vllm==0.7.2",
"wandb>=0.19.1",
"async-lru>=2.0.5",
"aiofiles>=24.1.0",
"pandas>=2.2.3",
]
# this is a lookup table with items like:
@ -90,12 +86,12 @@ def deps_list(*pkgs):
extras = {}
extras["tests"] = deps_list("pytest", "parameterized", "math-verify", "jieba")
extras["tests"] = deps_list("pytest", "parameterized", "math-verify")
extras["torch"] = deps_list("torch")
extras["quality"] = deps_list("ruff", "isort", "flake8")
extras["code"] = deps_list("e2b-code-interpreter", "python-dotenv", "morphcloud", "jieba", "pandas", "aiofiles")
extras["code"] = deps_list("e2b-code-interpreter", "python-dotenv")
extras["eval"] = deps_list("lighteval", "math-verify")
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"] + extras["code"]
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]
# core dependencies shared across the whole project - keep this to a bare minimum :)
install_requires = [
@ -109,14 +105,14 @@ install_requires = [
deps["langdetect"],
deps["latex2sympy2_extended"],
deps["math-verify"],
deps["liger-kernel"],
deps["liger_kernel"],
deps["packaging"], # utilities from PyPA to e.g., compare versions
deps["peft"],
deps["safetensors"],
deps["sentencepiece"],
deps["transformers"],
deps["trl"],
deps["wandb"],
deps["async-lru"],
]
setup(

View file

@ -5,7 +5,7 @@
conda create -n sglang124 python=3.11
conda activate sglang124
pip install torch==2.5.1 --index-url https://download.pytorch.org/whl/cu124
pip install torch=2.5.1 --index-url https://download.pytorch.org/whl/cu124
pip install sgl-kernel --force-reinstall --no-deps
pip install "sglang[all]>=0.4.2.post4" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/

View file

@ -1,20 +0,0 @@
#!/bin/bash
#SBATCH --job-name=open-r1-compute-pass-rate
#SBATCH --partition=hopper-prod
#SBATCH --qos=normal
#SBATCH --nodes=1
#SBATCH --gpus-per-node=1
#SBATCH --output=./logs/%x-%j.out
#SBATCH --error=./logs/%x-%j.err
#SBATCH --time=01-00:00:00
#SBATCH --requeue
# example usage: sbatch slurm/dataset_filter.slurm recipes/dataset_filtering/filter_dapo.yaml 0 500
set -x -e
source ~/.bashrc
source openr1/bin/activate
python scripts/pass_rate_filtering/compute_pass_rate.py --config $1 --dataset_start_index $2 --dataset_end_index $3

View file

@ -1,17 +0,0 @@
#!/bin/bash
#SBATCH --partition=hopper-cpu
#SBATCH --mem=16g
#SBATCH --cpus-per-task=16
#SBATCH --output=/fsx/open-r1/logs/e2b_router/%x-%j.out
#SBATCH --error=/fsx/open-r1/logs/e2b_router/%x-%j.err
#SBATCH --requeue
#SBATCH --time=7-00:00:00
echo "Starting job"
set -x -e
source ~/.bashrc
source openr1/bin/activate
srun python scripts/e2b_router.py

View file

@ -3,22 +3,13 @@
#SBATCH --gres=gpu:8
#SBATCH --partition=hopper-prod
#SBATCH --output=./logs/%x-%j.out
#SBATCH --error=./logs/%x-%j.err
#SBATCH --err=./logs/%x-%j.err
#SBATCH --requeue
#SBATCH --time=1-00:00:00
# Specific configuration optimized for the Hugging Face Compute Cluster
# Be ye warned this may not work on other clusters!
module load cuda/12.4
# Refresh Weka on h4 cache
echo "Refreshing Weka filesystem..."
find -L /fsx/h4/ -type f | xargs -d '\n' -r -n512 -P64 weka fs tier fetch
# Needed for vLLM
export VLLM_WORKER_MULTIPROC_METHOD=spawn
set -x -e
source ~/.bashrc
@ -34,11 +25,14 @@ MODEL_REVISION=$4
# $7 is reserved for system_prompt, see line 51
NUM_GPUS=$(nvidia-smi -L | wc -l)
# Use TP to shard model across GPUs
# Set whether to use tensor parallelism or data parallelism
if [ "$TENSOR_PARALLEL" = "True" ]; then
MODEL_ARGS="model_name=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
# use TP to shard model across NUM_GPUS
export VLLM_WORKER_MULTIPROC_METHOD=spawn
# FIXME: lighteval now requires us to manually pass the generation params
MODEL_ARGS="pretrained=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
else
MODEL_ARGS="model_name=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="pretrained=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
fi
LM_EVAL_REPO_ID="open-r1/open-r1-eval-leaderboard"
@ -47,14 +41,27 @@ DETAILS_REPO_ID="open-r1/details-$MODEL_NAME"
OUTPUT_DIR="eval_results/$MODEL_ID/$MODEL_REVISION/$TASK_NAME"
# We need this flag because this script is launched from training jobs that use DeepSpeed, and the propagated env vars cause errors during evaluation
ACCELERATE_USE_DEEPSPEED=false
# Enable fast downloads
HF_HUB_ENABLE_HF_TRANSFER=1
echo "Running lighteval script ..."
echo "Eval results will be saved to $OUTPUT_DIR"
lighteval vllm "$MODEL_ARGS" $TASKS \
# Check if "custom" is a substring of TASKS
if [[ $TASKS == *"custom"* ]]; then
echo "Custom task detected. Running custom task evaluation script ..."
lighteval vllm "$MODEL_ARGS" $TASKS \
--custom-tasks "src/open_r1/evaluate.py" \
--use-chat-template \
--output-dir $OUTPUT_DIR \
--save-details \
${7:+--system-prompt "$(echo "$7" | base64 --decode)"}
${7:+--system-prompt "$7"}
else
lighteval vllm "$MODEL_ARGS" $TASKS \
--use-chat-template \
--output-dir $OUTPUT_DIR \
--save-details \
${7:+--system-prompt "$7"}
fi
OUTPUT_FILEPATHS=$(find $OUTPUT_DIR/results/ -type f \( -name "*.json" \))
for filepath in $OUTPUT_FILEPATHS; do

View file

@ -6,7 +6,7 @@
#SBATCH --exclusive
#SBATCH --gpus-per-node=8
#SBATCH --output=./logs/%x-%j.out
#SBATCH --error=./logs/%x-%j.err
#SBATCH --err=./logs/%x-%j.err
#SBATCH --time=04-00:00:00
# Parse command line arguments

23
slurm/launch_sglang.slurm Normal file
View file

@ -0,0 +1,23 @@
#!/bin/bash
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1
#SBATCH --partition=hopper-prod
#SBATCH --output=/fsx/open-r1/logs/%x-%j.out
#SBATCH --err=/fsx/open-r1/logs/%x-%j.err
# Specific configuration optimized for the Hugging Face Compute Cluster
# Be ye warned this may not work on other clusters!
set -x -e
source ~/.bashrc
source openr1/bin/activate
module load cuda/12.4
echo Starting sglang server...
MODEL_ID=$1
REVISION=$2
PORT=$3
NUM_GPUS=$(nvidia-smi --list-gpus | wc -l)
python3 -m sglang.launch_server --model-path $MODEL_ID --revision $REVISION --port=$PORT --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size=$NUM_GPUS

View file

@ -1,18 +0,0 @@
#!/bin/bash
#SBATCH --partition=hopper-cpu
#SBATCH --mem=16g
#SBATCH --cpus-per-task=16
#SBATCH --output=/fsx/open-r1/logs/morph_router/%x-%j.out
#SBATCH --err=/fsx/open-r1/logs/morph_router/%x-%j.err
#SBATCH --requeue
#SBATCH --time=7-00:00:00
echo "Starting job"
set -x -e
source ~/.bashrc
source openr1/bin/activate
srun python scripts/morph_router.py --port 8001 --max_num_sandboxes 20

View file

@ -1,81 +0,0 @@
# Piston workers (slurm)
We have built a [piston](https://github.com/engineer-man/piston) package to run IOI problems.
To launch a fleet of piston workers on a slurm cluster, you can adapt the paths in `launch_piston_workers.sh` and `launch_single_piston.sh` and run:
```bash
slurm/piston/launch_piston_workers.sh (number of workers to launch)
```
This command will launch a slurm job for each worker, which will be called `piston-worker-<port>`, where `<port>` is the port where the worker will be listening.
## First time setup
You will need to install the [IOI package](https://github.com/guipenedo/piston/tree/master/packages/cms_ioi/1.0.0) in the workers.
1. Launch a single worker:
```bash
slurm/piston/launch_piston_workers.sh 1
```
2. Assuming it's running on `ip-10-53-86-146:1234`, send the package install request:
For IOI:
```bash
curl -X POST http://ip-10-53-86-146:1234/api/v2/packages -H "Content-Type: application/json" -d '{"language": "cms_ioi", "version": "1.0.0"}'
```
For CodeForces:
```bash
curl -X POST http://ip-10-53-86-146:1234/api/v2/packages -H "Content-Type: application/json" -d '{"language": "codeforces", "version": "1.0.0"}'
```
3. You can now launch more workers and due to the shared mounted packages directory, they should already have the package installed.
To have the main script find the workers automatically, you can export the following environment variable:
```bash
export PISTON_ENDPOINTS=slurm
```
Alternatively, you can add `PISTON_ENDPOINTS=slurm` to your .env file.
You can also change `PISTON_MAX_REQUESTS_PER_ENDPOINT`, which tries to limit how many simultaneous requests each worker will handle (1 by default). Keep in mind that this is a local limit and in distributed setups, as there is no global limit, workers might sometimes be overwhelmed when some processes hit the same worker.
If you would like to adapt the code to run without piston, please see the [ioi repo](https://github.com/huggingface/ioi).
For CodeForces, you should implement the [`run`](https://github.com/guipenedo/piston/blob/master/packages/codeforces/1.0.0/run) and [`compile`](https://github.com/guipenedo/piston/blob/master/packages/codeforces/1.0.0/compile) scripts.
# Piston workers (local docker)
This will launch a single worker in a docker container. Consider launching multiple workers for better scalability. Replace 2000 with the port you want to use.
Make sure to change `/path/to/local/packages` to the path you want to persist for package installs.
```bash
docker run -d \
--name piston_worker \
-v /path/to/local/packages:/piston/packages \
-e PORT=2000 \
-e PISTON_COMPILE_TIMEOUT=60000 \
-e PISTON_RUN_TIMEOUT=60000 \
-e PISTON_OUTPUT_MAX_SIZE=1000000000 \
-e PISTON_MAX_FILE_SIZE=1000000000 \
-e PISTON_DISABLE_NETWORKING=true \
-e PISTON_REPO_URL=https://github.com/guipenedo/piston/releases/download/pkgs/index \
-p 2000:2000 \
--entrypoint /bin/bash \
ghcr.io/engineer-man/piston@sha256:63b5654156a89c5a2ad281aface21416615d62ec056d88efe8fcd307ce73575a \
-c "sed -i '/app.use(body_parser.urlencoded/c\ app.use(body_parser.urlencoded({ extended: true, limit: \"512mb\" }));' src/index.js && \
sed -i '/app.use(body_parser.json/c\ app.use(body_parser.json({ limit: \"512mb\" }));' src/index.js && \
node src"
```
Install the package:
For IOI:
```bash
curl -X POST http://localhost:2000/api/v2/packages -H "Content-Type: application/json" -d '{"language": "cms_ioi", "version": "1.0.0"}'
```
For CodeForces:
```bash
curl -X POST http://localhost:2000/api/v2/packages -H "Content-Type: application/json" -d '{"language": "codeforces", "version": "1.0.0"}'
```
Remember to set `PISTON_ENDPOINTS`:
```bash
export PISTON_ENDPOINTS=http://localhost:2000/api/v2,http://localhost:2001/api/v2,http://localhost:2002/api/v2
```
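
The README above describes two forms of `PISTON_ENDPOINTS`: the special value `slurm`, which asks the main script to discover workers automatically, and a comma-separated list of worker URLs. As a rough sketch, assuming exactly those two forms, the value could be interpreted as follows. `parse_piston_endpoints` is a hypothetical helper and the actual parsing logic in the repository may differ.

```python
def parse_piston_endpoints(raw: str) -> tuple[str, list[str]]:
    """Classify a PISTON_ENDPOINTS value as slurm discovery or a static URL list."""
    raw = raw.strip()
    if raw == "slurm":
        # Workers are discovered elsewhere (e.g. from slurm job names), so no URLs here
        return "slurm", []
    urls = [part.strip() for part in raw.split(",") if part.strip()]
    return "static", urls


mode, urls = parse_piston_endpoints(
    "http://localhost:2000/api/v2,http://localhost:2001/api/v2"
)
```

In practice the raw string would come from the environment, for example `os.environ["PISTON_ENDPOINTS"]`.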

View file

@ -1,16 +0,0 @@
#!/bin/bash
# this simple script will launch a bunch of piston workers on the HF science cluster
N_INSTANCES=${1:-5} # Default to 5 instances
for i in $(seq 1 $N_INSTANCES); do
# Find random (hopefully) available port
PORT=$(comm -23 <(seq 2000 10000 | sort) <(ss -tan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n1)
# the job name format is important for the code to then be able to get a list of workers. `piston-worker-<port>`
sbatch \
--job-name="piston-worker-$PORT" \
--export=ALL,PORT=$PORT \
slurm/piston/launch_single_piston.sh
done

View file

@ -1,31 +0,0 @@
#!/bin/bash
#SBATCH --job-name=piston_worker
#SBATCH --output=/fsx/open-r1/logs/piston/worker-logs/%x-%j.out
#SBATCH --error=/fsx/open-r1/logs/piston/worker-logs/%x-%j.out # Redirect error logs to .out
#SBATCH --cpus-per-task=2
#SBATCH --mem-per-cpu=1950M
#SBATCH --partition=hopper-cpu
#SBATCH --time=48:00:00
# sometimes if a bunch of workers start at the same time pyxis dies
sleep $(( RANDOM % 20 ))
# mounting the packages folder lets us not have to manually install the package on each instance
# we use 63b5654156a89c5a2ad281aface21416615d62ec056d88efe8fcd307ce73575a as the latest image requires isolate, which does not work on the HF science cluster (cgroups incompatibility)
# feel free to try with the latest image
# the code you see below increases the very constrained piston default limits, and sets the repo url to the one hosting our IOI package
srun --container-mounts=/fsx/guilherme/ioi2024/piston_files/packages:/piston/packages --container-image "ghcr.io#engineer-man/piston:sha256:63b5654156a89c5a2ad281aface21416615d62ec056d88efe8fcd307ce73575a" \
bash -c "
export PISTON_COMPILE_TIMEOUT=60000
export PISTON_RUN_TIMEOUT=60000
export PISTON_OUTPUT_MAX_SIZE=1000000000
export PISTON_MAX_FILE_SIZE=1000000000
export PISTON_DISABLE_NETWORKING=true
export PISTON_REPO_URL=https://github.com/guipenedo/piston/releases/download/pkgs/index
sed -i '/app.use(body_parser.urlencoded/c\ app.use(body_parser.urlencoded({ extended: true, limit: \"512mb\" }));' src/index.js
sed -i '/app.use(body_parser.json/c\ app.use(body_parser.json({ limit: \"512mb\" }));' src/index.js
# Start the server (runs in the foreground and keeps the job alive)
node src
"

View file

@ -1,102 +1,52 @@
#!/bin/bash
#SBATCH --job-name=open_r1
#SBATCH --job-name=open-r1-sft
#SBATCH --ntasks-per-node=1
#SBATCH --exclusive
#SBATCH --gres=gpu:8
#SBATCH --partition=hopper-prod # Adjust this for your cluster
#SBATCH --output=./logs/%x-%j.out
#SBATCH --error=./logs/%x-%j.err
#SBATCH --err=./logs/%x-%j.err
#SBATCH --requeue
#SBATCH --time=3-00:00:00
if [[ "$*" == *"--help"* ]]; then
echo "Usage: sbatch slurm/train.slurm [options]"
echo "Options:"
echo " --model MODEL Model name"
echo " --task TASK Task name (e.g. sft, grpo)"
echo " --config SUFFIX Configuration suffix (e.g. demo, v00.00)"
echo " --accelerator CONFIG Accelerator configuration name (e.g. zero3)"
echo " --dp N Data parallelism for vLLM server (default: 1)"
echo " --tp N Tensor parallelism for vLLM server (default: 1)"
echo " --args \"ARGS\" Optional arguments to pass to the training script"
exit 0
fi
# Specific configuration optimized for the Hugging Face Compute Cluster
# Be ye warned this may not work on other clusters!
module load cuda/12.4
set -x -e
source ~/.bashrc
source openr1/bin/activate
START_TIME=$(date +%s)
echo "START TIME: $(date)"
# Refresh Weka on h4 cache
echo "Refreshing Weka filesystem..."
find -L /fsx/h4/ -type f | xargs -d '\n' -r -n512 -P64 weka fs tier fetch
# Default values
MODEL=""
TASK=""
CONFIG_SUFFIX=""
ACCELERATOR=""
DP=1
TP=1
OPTIONAL_ARGS=""
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL="$2"
shift 2
;;
--task)
TASK="$2"
shift 2
;;
--config)
CONFIG_SUFFIX="$2"
shift 2
;;
--accelerator)
ACCELERATOR="$2"
shift 2
;;
--dp)
DP="$2"
shift 2
;;
--tp)
TP="$2"
shift 2
;;
--args)
OPTIONAL_ARGS="$2"
shift 2
;;
*)
echo "Unknown option: $1"
echo "Use --help for usage information"
exit 1
;;
esac
done
# Validate required arguments
if [[ -z "$MODEL" || -z "$TASK" || -z "$CONFIG_SUFFIX" || -z "$ACCELERATOR" ]]; then
echo "Error: Missing required arguments"
echo "Run with --help for usage information"
exit 1
fi
MODEL=$1
TASK=$2
CONFIG_SUFFIX=$3
ACCELERATOR=$4
OPTIONAL_ARGS=$5
# Training setup
NUM_NODES=$SLURM_NNODES
GPUS_PER_NODE=8
WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE))
# Due to conflicts between Accelerate's DeepSpeed configs and Transformers' TrainingArguments, we need to parse the gradient accumulation steps from the config file to ensure they match
CONFIG_FILE=recipes/$MODEL/$TASK/config_$CONFIG_SUFFIX.yaml
GRAD_ACC_STEPS=$(grep 'gradient_accumulation_steps' $CONFIG_FILE | awk '{print $2}')
# Check if we are running vLLM during training to adjust the world size
if grep -q 'use_vllm:\s*true' "$CONFIG_FILE"; then
USE_VLLM="true"
else
USE_VLLM="false"
fi
if [[ "$USE_VLLM" == "true" ]]; then
WORLD_SIZE=$(($WORLD_SIZE - 1))
fi
# Split the string into individual arguments
IFS=' ' read -ra ARGS <<< "$OPTIONAL_ARGS"
# Loop through the arguments and find the one with "--gradient_accumulation_steps"
for arg in "${ARGS[@]}"; do
if [[ "$arg" == "--gradient_accumulation_steps="* ]]; then
@ -107,33 +57,27 @@ for arg in "${ARGS[@]}"; do
done
echo "Gradient accumulation steps: $GRAD_ACC_STEPS"
MODEL=$(grep 'model_name_or_path:' $CONFIG_FILE | awk '{print $2}')
REVISION=$(grep 'model_revision:' $CONFIG_FILE | head -n 1 | awk '{print $2}')
# Distributed configuration
NUM_NODES=$SLURM_NNODES
GPUS_PER_NODE=8
WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE))
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
MASTER_ADDR=${NODELIST[0]} # First node for main process
# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000
TRAIN_NODES=("${NODELIST[@]}")
USE_VLLM="false"
if [[ -f "$CONFIG_FILE" ]] && grep -qE '^\s*use_vllm:\s*true' "$CONFIG_FILE"; then
USE_VLLM="true"
fi
# if using vllm
if [[ "$USE_VLLM" == "true" ]]; then
TRAIN_NODES=("${NODELIST[@]:0:$((NUM_NODES - 1))}")
VLLM_NODE=${NODELIST[-1]} # Last node
WORLD_SIZE=$((WORLD_SIZE - GPUS_PER_NODE))
NUM_NODES=$((NUM_NODES - 1))
srun --nodes=1 --ntasks=1 --nodelist=$VLLM_NODE trl vllm-serve --model $MODEL --revision $REVISION --tensor_parallel_size $TP --data_parallel_size $DP &
export CMD=" \
src/open_r1/$TASK.py --config $CONFIG_FILE $OPTIONAL_ARGS
"
OPTIONAL_ARGS="$OPTIONAL_ARGS --vllm_server_host=$VLLM_NODE"
fi
export LAUNCHER="HF_HUB_ENABLE_HF_TRANSFER=1 ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \
--config_file recipes/accelerate_configs/$ACCELERATOR.yaml \
--gradient_accumulation_steps $GRAD_ACC_STEPS \
--num_machines $NUM_NODES \
--num_processes $WORLD_SIZE \
--main_process_ip $MASTER_ADDR \
--main_process_port $MASTER_PORT \
--machine_rank \$SLURM_PROCID \
--rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \
--max_restarts 1 \
--role \$(hostname -s): \
--tee 3 \
"
# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
@ -143,40 +87,14 @@ export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1
export CMD=" \
src/open_r1/$TASK.py --config $CONFIG_FILE $OPTIONAL_ARGS
"
export LAUNCHER="ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \
--config_file recipes/accelerate_configs/$ACCELERATOR.yaml \
--gradient_accumulation_steps $GRAD_ACC_STEPS \
--num_machines $NUM_NODES \
--num_processes $WORLD_SIZE \
--main_process_ip $MASTER_ADDR \
--main_process_port $MASTER_PORT \
--machine_rank $SLURM_PROCID \
--rdzv_backend=c10d \
--max_restarts 1 \
--tee 3 \
"
# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
NODELIST=$(IFS=,; echo "${TRAIN_NODES[*]}")
SRUN_ARGS=" \
--wait=60 \
--kill-on-bad-exit=1 \
--nodes=$NUM_NODES \
--ntasks=$NUM_NODES \
--nodelist=$NODELIST
"
srun $SRUN_ARGS bash -c "$LAUNCHER $CMD" 2>&1
END_TIME=$(date +%s)
echo "END TIME: $(date)"
ELAPSED_SECONDS=$((END_TIME - START_TIME))
HOURS=$((ELAPSED_SECONDS / 3600))
MINUTES=$(( (ELAPSED_SECONDS % 3600) / 60 ))
SECONDS=$((ELAPSED_SECONDS % 60))
echo "TOTAL JOB TIME: ${HOURS}h ${MINUTES}m ${SECONDS}s (${ELAPSED_SECONDS} seconds)"
clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD" 2>&1
echo "END TIME: $(date)"
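
The node bookkeeping in the vLLM branch of the script above (reserve the last node for the vLLM server, train on the rest, and shrink the world size by one node's worth of GPUs) can be sketched in Python. This is only an illustration and not part of the PR: `split_nodes` is a hypothetical helper, the host names are invented, and the two-node check is an added assumption that the shell script does not perform.

```python
from typing import List, Optional, Tuple


def split_nodes(
    nodelist: List[str], gpus_per_node: int = 8, use_vllm: bool = False
) -> Tuple[List[str], Optional[str], int]:
    """Return (train_nodes, vllm_node, world_size) for a Slurm allocation."""
    if not nodelist:
        raise ValueError("nodelist must contain at least one node")
    if not use_vllm:
        # Every node trains; world size is nodes * GPUs per node.
        return list(nodelist), None, len(nodelist) * gpus_per_node
    if len(nodelist) < 2:
        # Assumption added for this sketch: one node cannot both serve and train.
        raise ValueError("need at least two nodes when one is reserved for vLLM")
    train_nodes = list(nodelist[:-1])  # all nodes except the last
    vllm_node = nodelist[-1]  # last node hosts the vLLM server
    return train_nodes, vllm_node, len(train_nodes) * gpus_per_node


if __name__ == "__main__":
    # Hypothetical three-node allocation.
    print(split_nodes(["node-a", "node-b", "node-c"], use_vllm=True))
```

With three nodes of eight GPUs each, the training world size drops from 24 to 16, which corresponds to `WORLD_SIZE=$((WORLD_SIZE - GPUS_PER_NODE))` in the script.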

93
slurm/train_remote.slurm Normal file

@ -0,0 +1,93 @@
#!/bin/bash
#SBATCH --job-name=open-r1-sft
#SBATCH --ntasks-per-node=1
#SBATCH --exclusive
#SBATCH --gres=gpu:8
#SBATCH --partition=hopper-prod # Adjust this for your cluster
#SBATCH --output=./logs/%x-%j.out
#SBATCH --err=./logs/%x-%j.err
# Usage:
# sbatch --job-name=remote-grpo --nodes=1 slurm/train_remote.slurm Qwen2.5-1.5B-Instruct grpo remote zero3
# Specific configuration optimized for the Hugging Face Compute Cluster
# Be ye warned this may not work on other clusters!
module load cuda/12.4
set -x -e
source ~/.bashrc
source openr1/bin/activate
echo "START TIME: $(date)"
MODEL=$1
TASK=$2
CONFIG_SUFFIX=$3
ACCELERATOR=$4
OPTIONAL_ARGS=$5
# Training setup
NUM_NODES=$SLURM_NNODES
GPUS_PER_NODE=8
WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE))
# Due to conflicts between Accelerate's DeepSpeed configs and Transformers' TrainingArguments, we need to parse the gradient accumulation steps from the config file to ensure they match
CONFIG_FILE=recipes/$MODEL/$TASK/config_$CONFIG_SUFFIX.yaml
GRAD_ACC_STEPS=$(grep 'gradient_accumulation_steps' $CONFIG_FILE | awk '{print $2}')
# Split the string into individual arguments
IFS=' ' read -ra ARGS <<< "$OPTIONAL_ARGS"
# Loop through the arguments and find the one with "--gradient_accumulation_steps"
for arg in "${ARGS[@]}"; do
if [[ "$arg" == "--gradient_accumulation_steps="* ]]; then
# Extract the value after the equals sign
GRAD_ACC_STEPS="${arg#*=}"
break # Exit the loop once we find the desired argument
fi
done
echo "Gradient accumulation steps: $GRAD_ACC_STEPS"
# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000
export CMD=" \
scripts/remote_grpo.py --config $CONFIG_FILE $OPTIONAL_ARGS
"
export LAUNCHER="HF_HUB_ENABLE_HF_TRANSFER=1 ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \
--config_file recipes/accelerate_configs/$ACCELERATOR.yaml \
--gradient_accumulation_steps $GRAD_ACC_STEPS \
--num_machines $NUM_NODES \
--num_processes $WORLD_SIZE \
--main_process_ip $MASTER_ADDR \
--main_process_port $MASTER_PORT \
--machine_rank \$SLURM_PROCID \
--rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \
--max_restarts 1 \
--role \$(hostname -s): \
--tee 3 \
"
# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1
# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
SRUN_ARGS=" \
--wait=60 \
--kill-on-bad-exit=1 \
"
clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD" 2>&1
echo "END TIME: $(date)"
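
The gradient accumulation handling in this script (read the value from the recipe YAML, then let a `--gradient_accumulation_steps=N` entry in the optional arguments override it) can be sketched as follows. This is a rough Python equivalent for illustration only: `resolve_grad_acc_steps` is a hypothetical name, and the regular expression is a stricter stand-in for the `grep | awk` pipeline, which matches any line containing the key.

```python
import re
from typing import Optional


def resolve_grad_acc_steps(config_text: str, optional_args: str = "") -> Optional[str]:
    """Return the config value unless the optional arguments override it."""
    value = None
    # Rough equivalent of: grep 'gradient_accumulation_steps' $CONFIG_FILE | awk '{print $2}'
    for line in config_text.splitlines():
        match = re.match(r"\s*gradient_accumulation_steps:\s*(\S+)", line)
        if match:
            value = match.group(1)
            break
    # Mirror the shell loop: only the --key=value spelling is recognised.
    prefix = "--gradient_accumulation_steps="
    for arg in optional_args.split():
        if arg.startswith(prefix):
            value = arg[len(prefix):]
            break
    return value


if __name__ == "__main__":
    cfg = "learning_rate: 1e-6\ngradient_accumulation_steps: 4\n"
    print(resolve_grad_acc_steps(cfg, "--gradient_accumulation_steps=16"))
```

As in the shell version, an override written with a space (`--gradient_accumulation_steps 16`) is not detected, so the launcher would still receive the value from the config file.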


@ -14,112 +14,11 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Literal, Optional
from typing import Optional
import trl
@dataclass
class DatasetConfig:
"""Configuration for a dataset in a mixture."""
id: str
config: Optional[str] = None
split: str = "train"
columns: Optional[list[str]] = None
weight: Optional[float] = None
@dataclass
class DatasetMixtureConfig:
"""Configuration for a mixture of datasets."""
datasets: list[DatasetConfig]
seed: int = 0
test_split_size: Optional[float] = None
@dataclass
class ScriptArguments(trl.ScriptArguments):
"""
Extended version of ScriptArguments with support for dataset mixtures.
Args:
dataset_mixture (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
Configuration for creating dataset mixtures with advanced options.
Format:
dataset_mixture:
datasets:
- id: dataset_id1
config: config_name
columns:
- col1
- col2
weight: 0.5
- id: dataset_id2
config: config_name
columns:
- col1
- col2
weight: 0.5
seed: 42
test_split_size: 0.1
"""
# Override the dataset_name to make it optional
dataset_name: Optional[str] = field(
default=None, metadata={"help": "Dataset name. Can be omitted if using dataset_mixture."}
)
dataset_mixture: Optional[dict[str, Any]] = field(
default=None,
metadata={"help": "Configuration for creating dataset mixtures with advanced options like shuffling."},
)
def __post_init__(self):
if self.dataset_name is None and self.dataset_mixture is None:
raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided")
if self.dataset_mixture is not None:
if not isinstance(self.dataset_mixture, dict) or "datasets" not in self.dataset_mixture:
raise ValueError(
"dataset_mixture must be a dictionary with a 'datasets' key. "
"Expected format: {'datasets': [...], 'seed': int}"
)
datasets_list = []
datasets_data = self.dataset_mixture.get("datasets", [])
if isinstance(datasets_data, list):
for dataset_config in datasets_data:
datasets_list.append(
DatasetConfig(
id=dataset_config.get("id"),
config=dataset_config.get("config"),
split=dataset_config.get("split", "train"),
columns=dataset_config.get("columns"),
weight=dataset_config.get("weight", 1.0),
)
)
else:
raise ValueError("'datasets' must be a list of dataset configurations")
self.dataset_mixture = DatasetMixtureConfig(
datasets=datasets_list,
seed=self.dataset_mixture.get("seed", 0),
test_split_size=self.dataset_mixture.get("test_split_size", None),
)
# Check that column names are consistent across all dataset configs
columns_sets = [set(dataset.columns) for dataset in datasets_list if dataset.columns is not None]
if columns_sets:
first_columns = columns_sets[0]
if not all(columns == first_columns for columns in columns_sets):
raise ValueError(
"Column names must be consistent across all dataset configurations in a mixture. "
f"Found different column sets: {[list(cols) for cols in columns_sets]}"
)
# TODO: add the shared options with a mixin to reduce code duplication
@dataclass
class GRPOConfig(trl.GRPOConfig):
@ -128,30 +27,21 @@ class GRPOConfig(trl.GRPOConfig):
"""
benchmarks: list[str] = field(
default_factory=lambda: [],
metadata={"help": "The benchmarks to run after training."},
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
)
callbacks: list[str] = field(
default_factory=lambda: [],
metadata={"help": "The callbacks to run during training."},
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
)
chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
hub_model_revision: Optional[str] = field(
default="main", metadata={"help": "The Hub model branch to push the model to."}
)
num_completions_to_print: int = field(default=0, metadata={"help": "Number of completions to print."})
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
system_prompt: Optional[str] = field(
default=None,
metadata={"help": "The optional system prompt to use."},
)
wandb_log_unique_prompts: bool = field(
default=True,
metadata={
"help": ("Whether to log the unique prompts to wandb. This will create a new run for each unique prompt.")
},
hub_model_revision: Optional[str] = field(
default="main", metadata={"help": "The Hub model branch to push the model to."}
)
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
wandb_entity: Optional[str] = field(
default=None,
metadata={"help": ("The entity to store runs under.")},
@ -160,10 +50,6 @@ class GRPOConfig(trl.GRPOConfig):
default=None,
metadata={"help": ("The project to store runs under.")},
)
wandb_run_group: Optional[str] = field(
default=None,
metadata={"help": ("The group to store runs under.")},
)
@dataclass
@ -173,12 +59,10 @@ class SFTConfig(trl.SFTConfig):
"""
benchmarks: list[str] = field(
default_factory=lambda: [],
metadata={"help": "The benchmarks to run after training."},
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
)
callbacks: list[str] = field(
default_factory=lambda: [],
metadata={"help": "The callbacks to run during training."},
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
)
chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
system_prompt: Optional[str] = field(
@ -199,133 +83,3 @@ class SFTConfig(trl.SFTConfig):
default=None,
metadata={"help": ("The project to store runs under.")},
)
wandb_run_group: Optional[str] = field(
default=None,
metadata={"help": ("The group to store runs under.")},
)
@dataclass
class GRPOScriptArguments(ScriptArguments):
"""
Script arguments for the GRPO training script.
Args:
reward_funcs (`list[str]`):
List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'tag_count', 'code', 'ioi_code', 'code_format', 'soft_overlong_punishment'.
cosine_min_value_wrong (`float`):
Minimum reward for cosine scaling for wrong answers.
cosine_max_value_wrong (`float`):
Maximum reward for cosine scaling for wrong answers.
cosine_min_value_correct (`float`):
Minimum reward for cosine scaling for correct answers.
cosine_max_value_correct (`float`):
Maximum reward for cosine scaling for correct answers.
cosine_max_len (`int`):
Maximum length for cosine scaling.
code_language (`str`):
Language for code format reward.
max_completion_len (`int`):
Maximum number of tokens in completion.
soft_punish_cache (`int`):
Minimum number of tokens in completion.
"""
reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format", "tag_count"],
metadata={
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'tag_count', 'code', 'code_format'"
},
)
cosine_min_value_wrong: float = field(
default=0.0,
metadata={"help": "Minimum reward for wrong answers"},
)
cosine_max_value_wrong: float = field(
default=-0.5,
metadata={"help": "Maximum reward for wrong answers"},
)
cosine_min_value_correct: float = field(
default=0.5,
metadata={"help": "Minimum reward for correct answers"},
)
cosine_max_value_correct: float = field(
default=1.0,
metadata={"help": "Maximum reward for correct answers"},
)
cosine_max_len: int = field(
default=1000,
metadata={"help": "Maximum length for scaling"},
)
repetition_n_grams: int = field(
default=3,
metadata={"help": "Number of n-grams for repetition penalty reward"},
)
repetition_max_penalty: float = field(
default=-1.0,
metadata={"help": "Maximum (negative) penalty for the repetition penalty reward"},
)
code_language: str = field(
default="python",
# '(?:python|cpp)'
metadata={
"help": "Language for code format reward. Based on E2B supported languages https://e2b.dev/docs/code-interpreting/supported-languages",
"choices": ["python", "javascript", "r", "java", "bash", "cpp"],
},
)
code_eval_test_batch_size: int = field(
default=1,
metadata={
"help": "For each generation, evaluate this many test cases in parallel, then check whether any of them failed (0 score): if so, stop evaluating; otherwise continue with the next batch of test cases. Useful to avoid overloading the eval server and to save time on wrong solutions."
},
)
code_eval_scoring_mode: Literal["pass_fail", "partial", "weighted_sum"] = field(
default="weighted_sum",
metadata={"help": "use fraction of passed test cases as reward. If false, use 0/1 scoring."},
)
parallel_code_exec_per_proc: int = field(
default=2,
metadata={
"help": "Number of parallel E2B code executions per process. Default of 2 is suitable for the Free Hobby tier of E2B with 8 GPUs used for training."
},
)
dataset_prompt_column: str = field(
default="prompt",
metadata={"help": "Column to use as prompts for training."},
)
e2b_router_url: Optional[str] = field(
default=None,
metadata={"help": "URL for the E2B router. See scripts/e2b_router.py"},
)
morph_router_url: Optional[str] = field(
default=None,
metadata={"help": "URL for the MorphCloud router. See scripts/morph_router.py"},
)
code_provider: Optional[str] = field(
default="e2b",
metadata={
"help": "Provider for code execution. Options: 'e2b', 'local', 'morph'.",
"choices": ["e2b", "local", "morph"],
},
)
ioi_provider: Optional[str] = field(
default="piston",
metadata={
"help": "Provider for IOI code execution. Options: 'piston', 'morph'.",
"choices": ["piston", "morph"],
},
)
max_completion_len: int = field(
default=16384,
metadata={"help": "Maximum number of characters in completion."},
)
soft_punish_cache: int = field(
default=4096,
metadata={"help": "Minimum number of characters in completion."},
)
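
The column consistency rule enforced in `ScriptArguments.__post_init__` above (every dataset in a mixture that declares `columns` must declare the same set, regardless of order) can be isolated into a small standard-library sketch. `check_consistent_columns` is a hypothetical helper written for illustration; it is not a function from the repository.

```python
from typing import List, Optional, Sequence


def check_consistent_columns(columns_per_dataset: Sequence[Optional[List[str]]]) -> None:
    """Raise ValueError if datasets that declare columns disagree on the set."""
    # Datasets without an explicit column list are skipped, as in the original check.
    column_sets = [set(cols) for cols in columns_per_dataset if cols is not None]
    if column_sets and any(cols != column_sets[0] for cols in column_sets):
        raise ValueError(
            "Column names must be consistent across all dataset configurations in a mixture. "
            f"Found different column sets: {[sorted(cols) for cols in column_sets]}"
        )


if __name__ == "__main__":
    # Same columns in a different order, plus one dataset with no column list: accepted.
    check_consistent_columns([["prompt", "solution"], ["solution", "prompt"], None])
    print("consistent")
```

Because the comparison is on sets, column order does not matter, and a dataset with `columns: null` never triggers the error.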

185
src/open_r1/evaluate.py Normal file

@ -0,0 +1,185 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom evaluation tasks for LightEval."""
import random
from lighteval.metrics.dynamic_metrics import (
ExprExtractionConfig,
IndicesExtractionConfig,
LatexExtractionConfig,
multilingual_extractive_match_metric,
)
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
# Prompt template adapted from
# - simple-evals: https://github.com/openai/simple-evals/blob/6e84f4e2aed6b60f6a0c7b8f06bbbf4bfde72e58/math_eval.py#L17
# - Llama 3: https://huggingface.co/datasets/meta-llama/Llama-3.2-1B-Instruct-evals/viewer/Llama-3.2-1B-Instruct-evals__math__details?views%5B%5D=llama_32_1b_instruct_evals__math__details
# Note that it is important to have the final answer in a box for math-verify to work correctly
MATH_QUERY_TEMPLATE = """
Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering.
{Question}
""".strip()
# Prompt template from simple-evals: https://github.com/openai/simple-evals/blob/83ed7640a7d9cd26849bcb3340125002ef14abbe/common.py#L14
GPQA_QUERY_TEMPLATE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
{Question}
A) {A}
B) {B}
C) {C}
D) {D}
""".strip()
latex_gold_metric = multilingual_extractive_match_metric(
language=Language.ENGLISH,
fallback_mode="first_match",
precision=5,
gold_extraction_target=(LatexExtractionConfig(),),
# Match boxed first before trying other regexes
pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
aggregation_function=max,
)
expr_gold_metric = multilingual_extractive_match_metric(
language=Language.ENGLISH,
fallback_mode="first_match",
precision=5,
gold_extraction_target=(ExprExtractionConfig(),),
# Match boxed first before trying other regexes
pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
aggregation_function=max,
)
gpqa_metric = multilingual_extractive_match_metric(
language=Language.ENGLISH,
gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
precision=5,
)
def math_prompt_fn(line, task_name: str = None):
return Doc(
task_name=task_name,
query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
choices=[line["solution"]],
gold_index=0,
)
def aime_prompt_fn(line, task_name: str = None):
return Doc(
task_name=task_name,
query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
choices=[line["answer"]],
gold_index=0,
)
def gpqa_prompt_fn(line, task_name: str = None):
gold_index = random.randint(0, 3)
choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]]
choices.insert(gold_index, line["Correct Answer"])
query = GPQA_QUERY_TEMPLATE.format(
A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=line["Question"]
)
return Doc(
task_name=task_name,
query=query,
choices=["A", "B", "C", "D"],
gold_index=gold_index,
instruction=query,
)
# Define tasks
aime24 = LightevalTaskConfig(
name="aime24",
suite=["custom"],
prompt_function=aime_prompt_fn,
hf_repo="HuggingFaceH4/aime_2024",
hf_subset="default",
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[expr_gold_metric],
version=1,
)
aime25 = LightevalTaskConfig(
name="aime25",
suite=["custom"],
prompt_function=aime_prompt_fn,
hf_repo="yentinglin/aime_2025",
hf_subset="default",
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[expr_gold_metric],
version=1,
)
math_500 = LightevalTaskConfig(
name="math_500",
suite=["custom"],
prompt_function=math_prompt_fn,
hf_repo="HuggingFaceH4/MATH-500",
hf_subset="default",
hf_avail_splits=["test"],
evaluation_splits=["test"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[latex_gold_metric],
version=1,
)
gpqa_diamond = LightevalTaskConfig(
name="gpqa:diamond",
suite=["custom"],
prompt_function=gpqa_prompt_fn,
hf_repo="Idavidrein/gpqa",
hf_subset="gpqa_diamond",
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768, # needed for reasoning models like R1
metric=[gpqa_metric],
stop_sequence=[], # no stop sequence, will use eos token
trust_dataset=True,
version=1,
)
# Add tasks to the table
TASKS_TABLE = []
TASKS_TABLE.append(aime24)
TASKS_TABLE.append(aime25)
TASKS_TABLE.append(math_500)
TASKS_TABLE.append(gpqa_diamond)
# MODULE LOGIC
if __name__ == "__main__":
print([t.name for t in TASKS_TABLE])
print(len(TASKS_TABLE))
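
The answer placement in `gpqa_prompt_fn` above (insert the correct answer at a random position among the three incorrect ones and record that position as the gold index) can be sketched on its own. This is a minimal illustration: `place_gold_answer` is a hypothetical helper, and the optional seeded `random.Random` is an assumption added so the sketch is reproducible; the original calls the module-level `random.randint`.

```python
import random
from typing import List, Optional, Tuple


def place_gold_answer(
    correct: str, incorrect: List[str], rng: Optional[random.Random] = None
) -> Tuple[List[str], int]:
    """Insert the correct answer at a random index and return (choices, gold_index)."""
    rng = rng or random.Random()
    # randint is inclusive on both ends, so the gold answer can land in any slot.
    gold_index = rng.randint(0, len(incorrect))
    choices = list(incorrect)  # copy so the caller's list is not mutated
    choices.insert(gold_index, correct)
    return choices, gold_index


if __name__ == "__main__":
    choices, gold_index = place_gold_answer("right", ["w1", "w2", "w3"], random.Random(0))
    print("ABCD"[gold_index], choices)
```

Since the position is drawn per example, the letter of the gold answer changes between runs unless the random state is fixed.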


@ -53,7 +53,7 @@ def build_distilabel_pipeline(
generation_kwargs=generation_kwargs,
),
template=prompt_template,
input_mappings=({"instruction": prompt_column} if prompt_column is not None else {}),
input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
input_batch_size=input_batch_size,
num_generations=num_generations,
group_generations=True,


@ -15,23 +15,101 @@
import logging
import os
import sys
from dataclasses import dataclass, field
import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import GRPOConfig, GRPOScriptArguments
from open_r1.rewards import get_reward_funcs
from open_r1.utils import get_dataset, get_model, get_tokenizer
from open_r1.configs import GRPOConfig
from open_r1.rewards import (
accuracy_reward,
code_reward,
format_reward,
get_code_format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
reasoning_steps_reward,
tag_count_reward,
)
from open_r1.utils import get_tokenizer
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import GRPOTrainer, ModelConfig, TrlParser, get_peft_config
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
logger = logging.getLogger(__name__)
@dataclass
class GRPOScriptArguments(ScriptArguments):
"""
Script arguments for the GRPO training script.
Args:
reward_funcs (`list[str]`):
List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'tag_count', 'code', 'code_format'.
cosine_min_value_wrong (`float`):
Minimum reward for cosine scaling for wrong answers.
cosine_max_value_wrong (`float`):
Maximum reward for cosine scaling for wrong answers.
cosine_min_value_correct (`float`):
Minimum reward for cosine scaling for correct answers.
cosine_max_value_correct (`float`):
Maximum reward for cosine scaling for correct answers.
cosine_max_len (`int`):
Maximum length for cosine scaling.
code_language (`str`):
Language for code format reward.
"""
reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format", "tag_count"],
metadata={
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'tag_count', 'code', 'code_format'"
},
)
cosine_min_value_wrong: float = field(
default=0.0,
metadata={"help": "Minimum reward for wrong answers"},
)
cosine_max_value_wrong: float = field(
default=-0.5,
metadata={"help": "Maximum reward for wrong answers"},
)
cosine_min_value_correct: float = field(
default=0.5,
metadata={"help": "Minimum reward for correct answers"},
)
cosine_max_value_correct: float = field(
default=1.0,
metadata={"help": "Maximum reward for correct answers"},
)
cosine_max_len: int = field(
default=1000,
metadata={"help": "Maximum length for scaling"},
)
repetition_n_grams: int = field(
default=3,
metadata={"help": "Number of n-grams for repetition penalty reward"},
)
repetition_max_penalty: float = field(
default=-1.0,
metadata={"help": "Maximum (negative) penalty for the repetition penalty reward"},
)
code_language: str = field(
default="python",
metadata={
"help": "Language for code format reward. Based on E2B supported languages https://e2b.dev/docs/code-interpreting/supported-languages",
"choices": ["python", "javascript", "r", "java", "bash"],
},
)
def main(script_args, training_args, model_args):
# Set seed for reproducibility
set_seed(training_args.seed)
@ -71,33 +149,44 @@ def main(script_args, training_args, model_args):
init_wandb_training(training_args)
# Load the dataset
dataset = get_dataset(script_args)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
################
# Load tokenizer
################
tokenizer = get_tokenizer(model_args, training_args)
##############
# Load model #
##############
logger.info("*** Loading model ***")
model = get_model(model_args, training_args)
# Get reward functions from the registry
reward_funcs = get_reward_funcs(script_args)
# Get reward functions
REWARD_FUNCS_REGISTRY = {
"accuracy": accuracy_reward,
"format": format_reward,
"reasoning_steps": reasoning_steps_reward,
"cosine": get_cosine_scaled_reward(
min_value_wrong=script_args.cosine_min_value_wrong,
max_value_wrong=script_args.cosine_max_value_wrong,
min_value_correct=script_args.cosine_min_value_correct,
max_value_correct=script_args.cosine_max_value_correct,
max_len=script_args.cosine_max_len,
),
"repetition_penalty": get_repetition_penalty_reward(
ngram_size=script_args.repetition_n_grams,
max_penalty=script_args.repetition_max_penalty,
),
"length": len_reward,
"code": code_reward,
"code_format": get_code_format_reward(language=script_args.code_language),
"tag_count": tag_count_reward,
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
# Format into conversation
def make_conversation(example, prompt_column: str = script_args.dataset_prompt_column):
def make_conversation(example):
prompt = []
if training_args.system_prompt is not None:
prompt.append({"role": "system", "content": training_args.system_prompt})
if prompt_column not in example:
raise ValueError(f"Dataset Question Field Error: {prompt_column} is not supported.")
prompt.append({"role": "user", "content": example[prompt_column]})
prompt.append({"role": "user", "content": example["problem"]})
return {"prompt": prompt}
dataset = dataset.map(make_conversation)
@ -106,15 +195,28 @@ def main(script_args, training_args, model_args):
if "messages" in dataset[split].column_names:
dataset[split] = dataset[split].remove_columns("messages")
logger.info("*** Initializing model kwargs ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
)
training_args.model_init_kwargs = model_kwargs
#############################
# Initialize the GRPO trainer
#############################
trainer = GRPOTrainer(
model=model,
model=model_args.model_name_or_path,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None),
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
peft_config=get_peft_config(model_args),
callbacks=get_callbacks(training_args, model_args),
processing_class=tokenizer,
@ -140,9 +242,6 @@ def main(script_args, training_args, model_args):
# Save model and create model card
##################################
logger.info("*** Save model ***")
# Align the model's generation config with the tokenizer's eos token
# to avoid unbounded generation in the transformers `pipeline()` function
trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")
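
The registry pattern used in `main` above (map reward names to callables, then select the ones listed in `script_args.reward_funcs`) can be sketched without the training dependencies. The names here are hypothetical: `select_reward_funcs` and `toy_length_reward` are stand-ins written for this example, and the explicit error for unknown names is an addition; the dictionary lookup in the script would raise a bare `KeyError` instead.

```python
from typing import Callable, Dict, List

RewardFunc = Callable[..., List[float]]


def select_reward_funcs(registry: Dict[str, RewardFunc], names: List[str]) -> List[RewardFunc]:
    """Look up reward functions by name, failing early on unknown names."""
    unknown = [name for name in names if name not in registry]
    if unknown:
        raise ValueError(f"Unknown reward function(s): {unknown}. Available: {sorted(registry)}")
    return [registry[name] for name in names]


def toy_length_reward(completions: List[str], **kwargs) -> List[float]:
    """Stand-in reward: 1.0 for a non-empty completion, 0.0 otherwise."""
    return [1.0 if completion else 0.0 for completion in completions]


if __name__ == "__main__":
    funcs = select_reward_funcs({"length": toy_length_reward}, ["length"])
    print(funcs[0](["some completion", ""]))
```

Keeping the lookup separate from the registry definition makes it straightforward to report a misspelled reward name before any model is loaded.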


@ -1,43 +1,27 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reward functions for GRPO training."""
import asyncio
import json
import math
import re
from functools import partial, update_wrapper
from typing import Callable, Dict, Literal, Optional
from typing import Dict
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from .utils.code_providers import get_provider
from .utils.competitive_programming import (
SubtaskResult,
add_includes,
get_morph_client_from_env,
get_piston_client_from_env,
)
from .utils.competitive_programming import patch_code as cf_patch_code
from .utils.competitive_programming import score_submission as cf_score_submission
from .utils.competitive_programming import score_subtask
from .utils import is_e2b_available
def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
if is_e2b_available():
from dotenv import load_dotenv
from e2b_code_interpreter import AsyncSandbox
load_dotenv()
else:
AsyncSandbox = None
def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
@ -45,6 +29,7 @@ def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str]
gold_parsed = parse(
sol,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
@ -67,15 +52,15 @@ def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str]
],
extraction_mode="first_match",
)
# Compute binary rewards if verifiable, `None` otherwise to skip this example
# Reward 1 if the content is the same as the ground truth, 0 otherwise
try:
reward = float(verify(gold_parsed, answer_parsed))
reward = float(verify(answer_parsed, gold_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
reward = None
reward = 0.0
else:
# If the gold solution is not parseable, we assign `None` to skip this example
reward = None
# If the gold solution is not parseable, we reward 1 to skip this example
reward = 1.0
print("Failed to parse gold solution: ", sol)
rewards.append(reward)
@ -132,7 +117,7 @@ def reasoning_steps_reward(completions, **kwargs):
def len_reward(completions: list[Dict[str, str]], solution: list[str], **kwargs) -> float:
"""Compute length-based rewards to discourage overthinking and promote token efficiency.
Taken from the Kimi 1.5 tech report: https://huggingface.co/papers/2501.12599
Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599
Args:
completions: List of model completions
@ -230,11 +215,7 @@ def get_cosine_scaled_reward(
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(
sol,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
if len(gold_parsed) == 0:
rewards.append(1.0) # Skip unparseable examples
print("Failed to parse gold solution: ", sol)
@ -282,41 +263,21 @@ def get_cosine_scaled_reward(
return cosine_scaled_reward
def get_repetition_penalty_reward(ngram_size: int, max_penalty: float, language: str = "en"):
def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
"""
Computes N-gram repetition penalty as described in Appendix C.2 of https://huggingface.co/papers/2502.03373.
Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373.
Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
Args:
ngram_size: size of the n-grams
max_penalty: Maximum (negative) penalty for wrong answers
language: Language of the text, defaults to `en`. Used to choose the way to split the text into n-grams.
"""
if max_penalty > 0:
raise ValueError(f"max_penalty {max_penalty} should not be positive")
if language == "en":
def zipngram(text: str, ngram_size: int):
words = text.lower().split()
return zip(*[words[i:] for i in range(ngram_size)]), words
elif language == "zh":
from transformers.utils.import_utils import _is_package_available
if not _is_package_available("jieba"):
raise ValueError("Please install jieba to use Chinese language")
def zipngram(text: str, ngram_size: int):
import jieba
seg_list = list(jieba.cut(text))
return zip(*[seg_list[i:] for i in range(ngram_size)]), seg_list
else:
raise ValueError(
f"Word splitting for language `{language}` is not yet implemented. Please implement your own zip-ngram function."
)
def zipngram(text: str, ngram_size: int):
words = text.lower().split()
return zip(*[words[i:] for i in range(ngram_size)])
def repetition_penalty_reward(completions, **kwargs) -> float:
"""
@ -333,16 +294,13 @@ def get_repetition_penalty_reward(ngram_size: int, max_penalty: float, language:
if completion == "":
rewards.append(0.0)
continue
ngrams = set()
total = 0
ngram_array, words = zipngram(completion, ngram_size)
if len(words) < ngram_size:
if len(completion.split()) < ngram_size:
rewards.append(0.0)
continue
for ng in ngram_array:
ngrams = set()
total = 0
for ng in zipngram(completion, ngram_size):
ngrams.add(ng)
total += 1
@ -354,178 +312,26 @@ def get_repetition_penalty_reward(ngram_size: int, max_penalty: float, language:
return repetition_penalty_reward
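For reviewers, the n-gram penalty can be exercised in isolation with a minimal sketch like the one below. The scaling step falls outside the hunks shown in this diff, so the sketch assumes the penalty is `(1 - unique_ngrams / total_ngrams) * max_penalty`, as in the reference implementation linked in the docstring. The standalone `repetition_penalty` helper is hypothetical and only mirrors the per-completion logic.

```python
def zipngram(text: str, ngram_size: int):
    """Yield word-level n-grams of a lowercased, whitespace-split text."""
    words = text.lower().split()
    return zip(*[words[i:] for i in range(ngram_size)])


def repetition_penalty(completion: str, ngram_size: int, max_penalty: float) -> float:
    """Hypothetical standalone version of the per-completion penalty."""
    if max_penalty > 0:
        raise ValueError(f"max_penalty {max_penalty} should not be positive")
    # Empty or too-short completions contain no n-grams and are not penalized.
    if completion == "" or len(completion.split()) < ngram_size:
        return 0.0

    ngrams = set()
    total = 0
    for ng in zipngram(completion, ngram_size):
        ngrams.add(ng)
        total += 1

    # ASSUMPTION: this scaling is not visible in the diff above; it follows
    # the referenced implementation (fraction of repeated n-grams).
    scaling = 1 - len(ngrams) / total
    return scaling * max_penalty


no_repeats = repetition_penalty("the cat sat on the mat", 2, -1.0)
all_repeats = repetition_penalty("a a a a", 2, -1.0)
```

Under that assumption, a completion with no repeated bigrams is not penalized, while four copies of the same word produce one unique bigram out of three and therefore two thirds of `max_penalty`.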
def _init_event_loop():
"""Initialize or get the current event loop."""
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
def ioi_code_reward(completions, test_batch_size: int = 1, provider_type: str = "piston", **kwargs) -> list[float]:
"""Reward function that evaluates IOI problems using a specified execution client.
Assumes the dataset has the same format as hf.co/datasets/open-r1/ioi
Args:
completions: List of model completions to evaluate
test_batch_size: Evaluate this many test cases in parallel, then check whether any of them failed (0 score);
if so, stop evaluating, otherwise continue with the next batch of test cases.
provider_type: The execution provider to use (default: "piston"). Supported values: "piston", "morph"
**kwargs: Additional arguments passed from the dataset
"""
# Get the appropriate client based on provider_type
if provider_type == "morph":
execution_client = get_morph_client_from_env()
else:
# for info on setting up piston workers, see slurm/piston/README.md
execution_client = get_piston_client_from_env()
code_snippets = [
# note: grading is automatically skipped if no code is extracted
add_includes(extract_code(completion[-1]["content"], "cpp"), problem_id)
for completion, problem_id in zip(completions, kwargs["id"])
]
async def run_catch_exceptions(task):
try:
return await task
except Exception as e:
print(f"Error from {provider_type} worker: {e}")
return SubtaskResult()
problems_data = [dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values())]
loop = _init_event_loop()
evals = [
loop.create_task(
run_catch_exceptions(
score_subtask(
execution_client,
problem_data,
code,
test_batch_size=test_batch_size,
)
)
)
for problem_data, code in zip(problems_data, code_snippets)
]
results = loop.run_until_complete(asyncio.gather(*evals))
return [result.score for result in results]
def cf_code_reward(
completions,
test_batch_size: int = 1,
patch_code: bool = False,
scoring_mode: Literal["pass_fail", "partial", "weighted_sum"] = "weighted_sum",
**kwargs,
) -> list[float]:
"""Reward function that evaluates Codeforces problems using Piston+our CF package.
Assumes the dataset has the same format as hf.co/datasets/open-r1/codeforces (verifiable-prompts subset)
test_batch_size: evaluate this many test cases in parallel, then check whether any of them failed (0 score); if so, stop evaluating, otherwise continue with the next batch of test cases.
"""
# for info on setting up piston workers, see slurm/piston/README.md
piston_client = get_piston_client_from_env()
languages = kwargs["language"] if "language" in kwargs else [None] * len(completions)
code_snippets = [
# note: grading is automatically skipped if a problem has no tests
cf_patch_code(extract_code(completion[-1]["content"], language), language)
if patch_code
else extract_code(completion[-1]["content"], language)
for completion, language in zip(completions, languages)
]
async def run_catch_exceptions(task):
try:
return await task
except Exception as e:
print(f"Error from Piston worker: {e}")
return None
# load problem data. undo separating kwargs by column
problems_data = [dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values())]
loop = _init_event_loop()
evals = [
loop.create_task(
run_catch_exceptions(
cf_score_submission(
piston_client,
problem_data,
code,
test_batch_size=test_batch_size,
scoring_mode=scoring_mode,
submission_language=problem_data.get("language", None),
)
)
)
for problem_data, code in zip(problems_data, code_snippets)
]
results = loop.run_until_complete(asyncio.gather(*evals))
return results
def extract_code(completion: str, language: str | None = "python") -> str:
if language is None:
return ""
pattern = re.compile(rf"```{language}\n(.*?)```", re.DOTALL)
def extract_code(completion: str) -> str:
pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
matches = pattern.findall(completion)
extracted_answer = matches[-1] if len(matches) >= 1 else ""
return extracted_answer
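The behaviour of `extract_code` is easy to misread from the regex alone: it returns the last fenced block for the requested language, not the first. The sketch below mirrors the language-aware variant in this diff. The `FENCE` constant and the sample completion are illustrative assumptions, used only to avoid writing literal fences inside the example.

```python
import re
from typing import Optional

FENCE = "`" * 3  # three backticks, built programmatically for readability


def extract_code(completion: str, language: Optional[str] = "python") -> str:
    """Return the body of the last fenced block tagged with `language`."""
    if language is None:
        return ""
    # Non-greedy match so that each fenced block is captured separately.
    pattern = re.compile(FENCE + language + r"\n(.*?)" + FENCE, re.DOTALL)
    matches = pattern.findall(completion)
    return matches[-1] if matches else ""


# Hypothetical completion containing two Python blocks.
completion = (
    "First attempt:\n" + FENCE + "python\nprint(1)\n" + FENCE + "\n"
    "Final answer:\n" + FENCE + "python\nprint(2)\n" + FENCE + "\n"
)

last_block = extract_code(completion)
missing_language = extract_code(completion, "cpp")
```

Note that the language tag is interpolated into the pattern unescaped, so a tag containing regex metacharacters would need `re.escape` before being used this way.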
def binary_code_reward(
completions,
num_parallel: int = 2,
provider_type: str = "e2b",
enforce_same_language: bool = False,
**kwargs,
) -> list[float]:
rewards = code_reward(
completions,
num_parallel=num_parallel,
provider_type=provider_type,
enforce_same_language=enforce_same_language,
**kwargs,
)
BINARY_THRESHOLD = 0.99
output = []
for reward in rewards:
if reward is None:
output.append(None)
else:
output.append(1.0 if reward > BINARY_THRESHOLD else 0.0)
return output
def code_reward(
completions,
num_parallel: int = 2,
provider_type: str = "e2b",
enforce_same_language: bool = False,
**kwargs,
) -> list[float]:
"""Reward function that evaluates code snippets using a code execution provider.
def code_reward(completions, **kwargs) -> list[float]:
"""Reward function that evaluates code snippets using the E2B code interpreter.
Assumes the dataset contains a `verification_info` column with test cases.
Args:
completions: List of model completions to evaluate
num_parallel: Number of parallel code executions (default: 2)
provider_type: Which code execution provider to use (default: "e2b")
enforce_same_language: If True, verify all problems use the same language (default: False)
**kwargs: Additional arguments passed to the verification
"""
if not is_e2b_available():
raise ImportError(
"E2B is not available and required for this reward function. Please install E2B with "
"`pip install e2b-code-interpreter` and add an API key to a `.env` file."
)
# TODO: add support for other languages in E2B: https://e2b.dev/docs/code-interpreting/supported-languages
"""Returns a reward function that evaluates code snippets in a sandbox."""
evaluation_script_template = """
import subprocess
import json
@ -565,31 +371,25 @@ def code_reward(
evaluate_code(code_snippet, test_cases)
"""
code_snippets = [extract_code(completion[-1]["content"]) for completion in completions]
verification_info = kwargs["verification_info"]
template = evaluation_script_template
scripts = [
template.format(code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"])))
evaluation_script_template.format(code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"])))
for code, info in zip(code_snippets, verification_info)
]
language = verification_info[0]["language"]
if enforce_same_language:
all_same_language = all(v["language"] == language for v in verification_info)
if not all_same_language:
raise ValueError("All verification_info must have the same language", verification_info)
if not all(v["language"] == language for v in verification_info):
raise ValueError("All verification_info must have the same language", verification_info)
try:
rewards = run_async_from_sync(scripts, language)
execution_provider = get_provider(
provider_type=provider_type,
num_parallel=num_parallel,
**kwargs,
)
except Exception as e:
print(f"Error from E2B executor: {e}")
rewards = [0.0] * len(completions)
return execution_provider.execute_scripts(scripts, ["python"] * len(scripts))
return rewards
def get_code_format_reward(language: str = "python"):
@ -598,109 +398,52 @@ def get_code_format_reward(language: str = "python"):
Args:
language: Programming language supported by E2B https://e2b.dev/docs/code-interpreting/supported-languages
"""
pattern = rf"^<think>\n.*?\n</think>\n<answer>\n.*?```{language}.*?```.*?\n</answer>$"
def code_format_reward(completions, **kwargs):
# if there is a language field, use it instead of the default language. This way we can have mixed language training.
languages = kwargs["language"] if "language" in kwargs else [language] * len(completions)
completion_contents = [completion[0]["content"] for completion in completions]
matches = [
re.match(
rf"^<think>\n.*?\n</think>\n<answer>\n.*?```{sample_language}.*?```.*?\n</answer>$",
content,
re.DOTALL | re.MULTILINE,
)
for content, sample_language in zip(completion_contents, languages)
]
matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]
return code_format_reward
def get_soft_overlong_punishment(max_completion_len, soft_punish_cache):
"""
Reward function that penalizes overlong completions. It is used to penalize overlong completions,
but not to reward shorter completions. Reference: Eq. (13) from the DAPO paper (https://huggingface.co/papers/2503.14476)
def run_async_from_sync(scripts: list[str], language: str) -> list[float]:
"""Function wrapping the `run_async` function."""
# Create a new event loop and set it
try:
# Run the async function and get the result
rewards = asyncio.run(run_async(scripts, language))
except Exception as e:
print(f"Error from E2B executor async: {e}")
raise e
Args:
max_completion_len: Maximum length of the completion
soft_punish_cache: Minimum length of the completion. If set to 0, no minimum length is applied.
"""
def soft_overlong_punishment_reward(completion_ids: list[list[int]], **kwargs) -> list[float]:
"""Reward function that penalizes overlong completions."""
rewards = []
for ids in completion_ids:
completion_length = len(ids)
if completion_length <= max_completion_len - soft_punish_cache:
rewards.append(0.0)
elif max_completion_len - soft_punish_cache < completion_length <= max_completion_len:
rewards.append((max_completion_len - soft_punish_cache - completion_length) / soft_punish_cache)
else:
rewards.append(-1.0)
return rewards
return soft_overlong_punishment_reward
return rewards
def get_reward_funcs(script_args) -> list[Callable]:
REWARD_FUNCS_REGISTRY = {
"accuracy": accuracy_reward,
"format": format_reward,
"reasoning_steps": reasoning_steps_reward,
"cosine": get_cosine_scaled_reward(
min_value_wrong=script_args.cosine_min_value_wrong,
max_value_wrong=script_args.cosine_max_value_wrong,
min_value_correct=script_args.cosine_min_value_correct,
max_value_correct=script_args.cosine_max_value_correct,
max_len=script_args.cosine_max_len,
),
"repetition_penalty": get_repetition_penalty_reward(
ngram_size=script_args.repetition_n_grams,
max_penalty=script_args.repetition_max_penalty,
),
"length": len_reward,
"code": update_wrapper(
partial(
code_reward,
num_parallel=script_args.parallel_code_exec_per_proc,
provider_type=script_args.code_provider,
enforce_same_language=getattr(script_args, "enforce_same_language", False),
),
code_reward,
),
"binary_code": update_wrapper(
partial(
binary_code_reward,
num_parallel=script_args.parallel_code_exec_per_proc,
provider_type=script_args.code_provider,
enforce_same_language=getattr(script_args, "enforce_same_language", False),
),
binary_code_reward,
),
"ioi_code": update_wrapper(
partial(
ioi_code_reward,
test_batch_size=script_args.code_eval_test_batch_size,
provider_type=getattr(script_args, "ioi_provider", "piston"),
),
ioi_code_reward,
),
"cf_code": update_wrapper(
partial(
cf_code_reward,
test_batch_size=script_args.code_eval_test_batch_size,
scoring_mode=script_args.code_eval_scoring_mode,
),
cf_code_reward,
),
"code_format": get_code_format_reward(language=script_args.code_language),
"tag_count": tag_count_reward,
"soft_overlong_punishment": get_soft_overlong_punishment(
max_completion_len=script_args.max_completion_len,
soft_punish_cache=script_args.soft_punish_cache,
),
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
async def run_async(scripts: list[str], language: str) -> list[float]:
# Create the sandbox by hand, currently there's no context manager for this version
sbx = await AsyncSandbox.create(timeout=60, request_timeout=5)
return reward_funcs
# Create a list of tasks for running scripts concurrently
MAX_TASKS_PER_PROCESS = 2 # E2B has a limit of 20 concurrent requests; assuming 1 node with 8 processes, this is 2 per process (20 // 8 = 2)
semaphore = asyncio.Semaphore(MAX_TASKS_PER_PROCESS)
tasks = [run_script(sbx, script, language, semaphore) for script in scripts]
# Wait for all tasks to complete and gather their results as they finish
results = await asyncio.gather(*tasks)
rewards = list(results) # collect results
# Kill the sandbox after all the tasks are complete
await sbx.kill()
return rewards
async def run_script(sbx: AsyncSandbox, script: str, language: str, semaphore) -> float:
async with semaphore: # Limit concurrency
execution = await sbx.run_code(script, language=language)
try:
return float(execution.text)
except (TypeError, ValueError):
return 0.0
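The piecewise penalty in `get_soft_overlong_punishment` above can be checked with a few concrete lengths. This is a minimal sketch of the same three branches for a single completion. The helper name `soft_overlong_punishment` and the numbers (a 100-token limit with a 20-token soft zone) are assumptions chosen for illustration.

```python
def soft_overlong_punishment(completion_length: int, max_completion_len: int, soft_punish_cache: int) -> float:
    """Hypothetical single-completion version of the reward shown in the diff."""
    threshold = max_completion_len - soft_punish_cache
    if completion_length <= threshold:
        # Short enough: no penalty (and no bonus for being shorter).
        return 0.0
    if completion_length <= max_completion_len:
        # Inside the soft zone: penalty grows linearly from 0 towards -1.
        return (threshold - completion_length) / soft_punish_cache
    # Beyond the hard limit: full penalty.
    return -1.0


# Illustrative values: limit of 100 tokens, soft zone covering the last 20.
rewards = [soft_overlong_punishment(n, 100, 20) for n in (50, 80, 90, 100, 101)]
```

With these values the penalty starts just after 80 tokens, reaches -0.5 at 90 tokens, and is -1.0 at 100 tokens and beyond.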

View file

@ -19,18 +19,20 @@ Usage:
# On 1 node of 8 x H100s
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path open-r1/Qwen2.5-Math-7B-RoPE-300k \
--dataset_name open-r1/Mixture-of-Thoughts \
--dataset_config all \
--eos_token '<|im_end|>' \
--learning_rate 4.0e-5 \
--num_train_epochs 5 \
--max_seq_length 32768 \
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
--dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
--learning_rate 2.0e-5 \
--num_train_epochs 1 \
--packing \
--max_seq_length 4096 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--bf16 \
--use_liger_kernel \
--output_dir data/OpenR1-Distill-7B
--logging_steps 5 \
--eval_strategy steps \
--eval_steps 100 \
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
"""
import logging
@ -38,21 +40,32 @@ import os
import sys
import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import ScriptArguments, SFTConfig
from open_r1.utils import get_dataset, get_model, get_tokenizer
from open_r1.configs import SFTConfig
from open_r1.utils import get_tokenizer
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import ModelConfig, SFTTrainer, TrlParser, get_peft_config, setup_chat_format
from trl import (
ModelConfig,
ScriptArguments,
SFTTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
logger = logging.getLogger(__name__)
def main(script_args, training_args, model_args):
# Set seed for reproducibility
set_seed(training_args.seed)
###############
@ -84,25 +97,44 @@ def main(script_args, training_args, model_args):
if "wandb" in training_args.report_to:
init_wandb_training(training_args)
######################################
# Load dataset, tokenizer, and model #
######################################
dataset = get_dataset(script_args)
tokenizer = get_tokenizer(model_args, training_args)
model = get_model(model_args, training_args)
################
# Load datasets
################
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
if tokenizer.chat_template is None:
logger.info("No chat template provided, defaulting to ChatML.")
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
################
# Load tokenizer
################
tokenizer = get_tokenizer(model_args, training_args)
tokenizer.pad_token = tokenizer.eos_token
###################
# Model init kwargs
###################
logger.info("*** Initializing model kwargs ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
training_args.model_init_kwargs = model_kwargs
############################
# Initialize the SFT Trainer
############################
trainer = SFTTrainer(
model=model,
model=model_args.model_name_or_path,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None),
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
processing_class=tokenizer,
peft_config=get_peft_config(model_args),
callbacks=get_callbacks(training_args, model_args),
@ -128,9 +160,6 @@ def main(script_args, training_args, model_args):
# Save model and create model card
##################################
logger.info("*** Save model ***")
# Align the model's generation config with the tokenizer's eos token
# to avoid unbounded generation in the transformers `pipeline()` function
trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")

View file

@ -0,0 +1,673 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# great reference: https://github.com/vllm-project/vllm/issues/11400
import contextlib
import functools
import gc
import math
import os
import tempfile
import time
from collections import defaultdict
from dataclasses import dataclass, field
from multiprocessing import reduction
from typing import Callable, Optional, Union
from unittest.mock import patch
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorWithPadding,
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
TrainerControl,
TrainerState,
is_wandb_available,
)
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils import is_liger_kernel_available
import trl
from accelerate import Accelerator
from accelerate.utils import gather_object
from open_r1.trainers.job_launcher import SGLangSlurmJobLauncher
from open_r1.trainers.remote_model import RemoteModel
from trl.data_utils import is_conversational, maybe_apply_chat_template
from trl.trainer.utils import pad, selective_log_softmax
from vllm import LLM, SamplingParams
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
if is_wandb_available():
import wandb
@contextlib.contextmanager
def profiling_context(instance, name):
"""
A context manager function for profiling a block of code.
Can also be used as a decorator.
"""
start_time = time.perf_counter()
yield
end_time = time.perf_counter()
duration = end_time - start_time
if "wandb" in instance.args.report_to and wandb.run is not None and instance.accelerator.is_main_process:
wandb.log({f"profiling/Time taken: {instance.__class__.__name__}.{name}": duration})
def profiling_decorator(func):
"""
Decorator to profile a function and log execution time using profiling_context.
"""
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
with profiling_context(self, func.__name__):
return func(self, *args, **kwargs)
return wrapper
def exact_div(a, b, custom_error_message=""):
q = a // b
if a != q * b:
raise ValueError(f"{custom_error_message}, inexact division: {a} / {b} = {a / b}")
return q
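`exact_div` is used later in `__init__` to derive how many prompts each dataloader batch holds. A minimal sketch of that arithmetic follows. The batch-size values are hypothetical, and the error message is reworded here for illustration.

```python
def exact_div(a, b, custom_error_message=""):
    """Integer division that raises if `a` is not a multiple of `b`."""
    q = a // b
    if a != q * b:
        raise ValueError(f"{custom_error_message}, inexact division: {a} / {b} = {a / b}")
    return q


# Hypothetical training settings.
per_device_train_batch_size = 2
gradient_accumulation_steps = 8
num_generations = 8

# Each prompt is expanded into `num_generations` completions, so the number
# of prompts per dataloader batch is the effective batch size divided by it.
prompts_per_batch = exact_div(
    per_device_train_batch_size * gradient_accumulation_steps,
    num_generations,
    "effective batch size must be a multiple of num_generations",
)
```

With these settings, 16 completions per optimizer step correspond to 2 prompts per dataloader batch. A value such as `num_generations = 6` would raise a `ValueError` instead of silently truncating.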
# TODO: add the shared options with a mixin to reduce code duplication
@dataclass
class FastGRPOConfig(trl.GRPOConfig):
"""
args for callbacks, benchmarks etc
"""
benchmarks: list[str] = field(
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
)
callbacks: list[str] = field(
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
)
system_prompt: Optional[str] = field(
default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
)
hub_model_revision: Optional[str] = field(
default="main", metadata={"help": "The Hub model branch to push the model to."}
)
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
wandb_entity: Optional[str] = field(
default=None,
metadata={"help": ("The entity to store runs under.")},
)
wandb_project: Optional[str] = field(
default=None,
metadata={"help": ("The project to store runs under.")},
)
remote_gen_model_url: str = field(
default="26.0.165.24",
)
remote_gen_model_port: str = field(
default="30010",
)
remote_gen_model_n_gpus: int = field(
default=8,
)
class FastGRPOTrainer(Trainer):
_tag_names = ["trl", "fast_grpo"]
def __init__(
self,
model: str, # only accept str for now
reward_funcs: Union[RewardFunc, list[RewardFunc]],
args: FastGRPOConfig,
train_dataset: Dataset,
processing_class: Optional[PreTrainedTokenizerBase] = None,
data_collator: Optional[DataCollatorWithPadding] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
) -> None:
self.args = args
self.reward_funcs = reward_funcs
# Reward weights (move this logic to post_init of config?)
if args.reward_weights is not None:
if len(args.reward_weights) != len(reward_funcs):
raise ValueError(
f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
f"functions ({len(reward_funcs)})"
)
self.reward_weights = args.reward_weights
else:
self.reward_weights = [1.0] * len(reward_funcs)
# start the remote model so it has time to warm up while we load the local model(s)
if self.args.remote_gen_model_url is None:
self.sglang_job_launcher = SGLangSlurmJobLauncher(
model, num_gpus=self.args.remote_gen_model_n_gpus, sglang_port=self.args.remote_gen_model_port
)
self.sglang_job_launcher.submit_job()
# Trained model
model_init_kwargs = args.model_init_kwargs or {}
if isinstance(model, str):
torch_dtype = model_init_kwargs.get("torch_dtype")
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
pass # torch_dtype is already a torch.dtype or "auto" or None
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
torch_dtype = getattr(torch, torch_dtype)
model_init_kwargs["torch_dtype"] = torch_dtype
else:
raise ValueError(
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
)
# Disable caching if gradient checkpointing is enabled (not supported)
model_init_kwargs["use_cache"] = (
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
)
model_str = model
model = AutoModelForCausalLM.from_pretrained(model_str, **model_init_kwargs)
# offload to cpu
ref_model = AutoModelForCausalLM.from_pretrained(model_str, **model_init_kwargs) # .to("cpu")
self.model = model
self.ref_model = ref_model
if self.args.use_liger_kernel:
if is_liger_kernel_available():
from liger_kernel.transformers import _apply_liger_kernel_to_instance
_apply_liger_kernel_to_instance(model=self.model)
_apply_liger_kernel_to_instance(model=self.ref_model)
else:
raise ImportError(
"You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. "
"Please install it with `pip install liger-kernel`"
)
# Processing class
if processing_class is None:
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
self.processing_class = processing_class
self.train_dataset = train_dataset
if data_collator is not None:
raise ValueError("FastGRPOTrainer does not support a custom `data_collator`.")
def data_collator(features): # No data collation is needed in GRPO
return features
self.data_collator = data_collator
local_dataloader_batch_size = exact_div(
args.per_device_train_batch_size * args.gradient_accumulation_steps,
args.num_generations,
"per_device_train_batch_size * gradient_accumulation_steps must be a multiple of num_generations to remain on policy",
)
self.optimizer, self.lr_scheduler = optimizers
self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
self.accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
self.train_dataset_len = len(self.train_dataset)
num_total_samples = int(self.args.num_train_epochs * self.train_dataset_len)
self.total_steps_per_device = num_total_samples // (
local_dataloader_batch_size * self.accelerator.num_processes
)
self.create_optimizer_and_scheduler(num_training_steps=self.total_steps_per_device)
#########
### trainer specifics
#########
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
self.callback_handler = CallbackHandler(
self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
)
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
self.control = TrainerControl()
self.state = TrainerState(
is_local_process_zero=self.is_local_process_zero(),
is_world_process_zero=self.is_world_process_zero(),
stateful_callbacks=[
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
],
)
self.current_flos = 0
self.hp_search_backend = None
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
# Create distant repo and output directory if needed
self.hub_model_id = None
if self.args.push_to_hub:
self.init_hf_repo()
if self.args.should_save:
os.makedirs(self.args.output_dir, exist_ok=True)
self.backup_model = None
# Add tags for models that have been loaded with the correct transformers version
if hasattr(self.model, "add_model_tags"):
self.model.add_model_tags(self._tag_names)
#########
### setup dataloader
#########
self.dataloader = DataLoader(
self.train_dataset,
batch_size=local_dataloader_batch_size,
shuffle=True,
collate_fn=self.data_collator,
drop_last=True,
)
torch.manual_seed(args.seed)
# Enable gradient checkpointing if requested
if args.gradient_checkpointing:
self.model = self._enable_gradient_checkpointing(self.model, self.args)
self.model, self.optimizer, self.dataloader = self.accelerator.prepare(
self.model, self.optimizer, self.dataloader
)
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
# connect to a remote sglang model
if self.args.remote_gen_model_url is None:
self.sglang_job_launcher.wait_for_server()
self.args.remote_gen_model_url = self.sglang_job_launcher.get_remote_ip()
self.remote_model = RemoteModel(
self.args.remote_gen_model_url, self.args.remote_gen_model_port, self.processing_class.eos_token_id
)
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: FastGRPOConfig) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# Ensure use_cache is disabled
model.config.use_cache = False
model.gradient_checkpointing_enable()
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
)
if use_reentrant:
model.enable_input_require_grads()
return model
def print_gpu_memory_usage(self):
if torch.cuda.is_available():
gpu_memory_allocated = torch.cuda.memory_allocated()
gpu_memory_reserved = torch.cuda.memory_reserved()
print(f"GPU memory allocated: {gpu_memory_allocated / (1024**3):.2f} GB")
print(f"GPU memory reserved: {gpu_memory_reserved / (1024**3):.2f} GB")
else:
print("CUDA is not available.")
# Get the per-token log probabilities for the completions for the model and the reference model
@profiling_decorator
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
# We add 1 to `logits_to_keep` because the last logit of the sequence is excluded below
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
@torch.no_grad()
@profiling_decorator
def _prepare_batch(self, batch):
"""
This will:
- generate k samples for each problem
- score each sample with the configured reward function(s)
"""
device = self.accelerator.device
prompts = [x["prompt"] for x in batch]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in batch]
prompt_inputs = self.processing_class(prompts_text)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
# add cuda clear cache here and a sleep
all_outputs = self.remote_model.generate(
prompt_ids,
max_new_tokens=self.args.max_completion_length,
temperature=self.args.temperature,
num_generations=self.args.num_generations,
)
# all_outputs = self.gen_vllm.generate(prompts_text, sampling_params=self.sampling_params, use_tqdm=True)
completion_ids = [example["completion_ids"] for example in all_outputs]
# completion_ids = []
# for outputs in all_outputs:
# for output in outputs.outputs:
# completion_ids.append(output.token_ids)
# Decode the generated completions
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
repeated_prompts = []
for prompt in prompts:
repeated_prompts.extend([prompt] * self.args.num_generations)
repeated_prompt_texts = []
for prompt in prompts_text:
repeated_prompt_texts.extend([prompt] * self.args.num_generations)
if is_conversational(batch[0]):
completions = []
for prompt, completion in zip(repeated_prompts, completions_text, strict=True):
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:
completions = completions_text
rewards = torch.zeros(len(repeated_prompts), len(self.reward_funcs))
for (
i,
reward_func,
) in enumerate(self.reward_funcs):
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
keys = [key for key in batch[0] if key not in ["prompt", "completion"]]
reward_kwargs = defaultdict(list)
for example in batch:
for key in keys:
reward_kwargs[key].extend([example[key]] * self.args.num_generations)
output_reward_func = reward_func(prompts=repeated_prompts, completions=completions, **reward_kwargs)
rewards[:, i] = torch.tensor(output_reward_func, dtype=torch.float32) * self.reward_weights[i]
# Calculate the advantages. All generations for a prompt are on the same device, so there is no need to gather here
grouped_rewards = rewards.sum(-1).view(len(prompts), self.args.num_generations)
EPS = 1e-4
grouped_advantages = (grouped_rewards - grouped_rewards.mean(-1, keepdim=True)) / (
grouped_rewards.std(-1, keepdim=True) + EPS
)
advantages = grouped_advantages.flatten().tolist()
# build batch as list of dicts
examples = []
for i, prompt in enumerate(repeated_prompt_texts):
example = {
"prompt": prompt,
"prompt_ids": prompt_ids[i // self.args.num_generations],
"completion": completions_text[i],
"completion_ids": completion_ids[i],
"advantages": advantages[i],
"rewards": rewards[i],
}
examples.append(example)
return examples
@profiling_decorator
def _sync_weights(self):
self.accelerator.wait_for_everyone()
if self.accelerator.is_main_process:
start = time.time()
with tempfile.TemporaryDirectory(dir="/fsx/edward/work/open-r1/data/") as temp_dir_path:
unwrapped_model = self.accelerator.unwrap_model(self.model)
unwrapped_model.save_pretrained(temp_dir_path)
self.remote_model.load_weights_from_path(temp_dir_path)
print("weight sync took: ", time.time() - start)
self.accelerator.wait_for_everyone()
def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
):
start_step = 1 # todo, set this when we resume + load model, opt state etc
# Set the step counters first: the ratio-based logging/save intervals below depend on `max_steps`
self.state.global_step = 0
self.state.max_steps = self.total_steps_per_device
self.state.num_train_epochs = self.args.num_train_epochs
if self.args.logging_steps is not None:
if self.args.logging_steps < 1:
self.state.logging_steps = math.ceil(self.state.max_steps * self.args.logging_steps)
else:
self.state.logging_steps = self.args.logging_steps
if self.args.save_steps is not None:
if self.args.save_steps < 1:
self.state.save_steps = math.ceil(self.state.max_steps * self.args.save_steps)
else:
self.state.save_steps = self.args.save_steps
self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)
def repeat_generator():
while True:
yield from self.dataloader
iter_dataloader = iter(repeat_generator())
self.model.train()
@torch.no_grad()
def mini_batch_collator(examples):
device = self.accelerator.device
prompt_ids = [torch.LongTensor(example["prompt_ids"]) for example in examples]
completion_ids = [torch.LongTensor(example["completion_ids"]) for example in examples]
ref_per_token_logps = [torch.Tensor(example["ref_per_token_logps"]) for example in examples]
for logps, completion_id in zip(ref_per_token_logps, completion_ids):
assert len(logps) == len(completion_id), (
f"len(logps)={len(logps)} != len(completion_id)={len(completion_id)}"
)
pad_token_id = self.processing_class.pad_token_id
padded_prompt_ids = pad(prompt_ids, padding_value=pad_token_id, padding_side="left")
padded_completion_ids = pad(completion_ids, padding_value=pad_token_id, padding_side="right")
padd_ref_per_token_logps = pad(ref_per_token_logps, padding_value=0.0, padding_side="right")
if self.args.max_prompt_length is not None:
padded_prompt_ids = padded_prompt_ids[:, -self.args.max_prompt_length :]
# compute the masks
prompt_mask = (padded_prompt_ids != pad_token_id).long()
# Mask everything after the first EOS token
is_eos = padded_completion_ids == self.processing_class.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
advantages = torch.Tensor([example["advantages"] for example in examples])
return {
"prompt_ids": padded_prompt_ids.to(device),
"prompt_mask": prompt_mask.to(device),
"completion_ids": padded_completion_ids.to(device),
"completion_mask": completion_mask.to(device),
"advantages": advantages.to(device),
"ref_per_token_logps": padd_ref_per_token_logps.to(device),
}
device = self.accelerator.device
for step in range(start_step, self.total_steps_per_device + 1):
batch = next(iter_dataloader)
batch = self._prepare_batch(batch)
# TODO: log completions, rewards, etc
gen_dataset = Dataset.from_list(batch)
@torch.no_grad()
def compute_ref_logps(examples):
device = self.accelerator.device
prompt_ids = [torch.LongTensor(prompt_id) for prompt_id in examples["prompt_ids"]]
completion_ids = [torch.LongTensor(completion_id) for completion_id in examples["completion_ids"]]
completion_lengths = [len(c) for c in completion_ids]
pad_token_id = self.processing_class.pad_token_id
padded_prompt_ids = pad(prompt_ids, padding_value=pad_token_id, padding_side="left")
padded_completion_ids = pad(completion_ids, padding_value=pad_token_id, padding_side="right")
input_ids = torch.cat([padded_prompt_ids, padded_completion_ids], dim=1)
attention_mask = torch.cat(
[padded_prompt_ids != pad_token_id, padded_completion_ids != pad_token_id], dim=1
)
# Keep logits for the full padded completion width; per-example lengths are trimmed below
logits_to_keep = padded_completion_ids.size(1)
with torch.inference_mode():
ref_per_token_logps = self._get_per_token_logps(
self.ref_model, input_ids.to(device), attention_mask.to(device), logits_to_keep
)
ref_per_token_logps = ref_per_token_logps.to("cpu")
examples["ref_per_token_logps"] = [
logprobs[:length] for logprobs, length in zip(ref_per_token_logps, completion_lengths)
]
return examples
self.ref_model = self.ref_model.to(device)
# precompute the ref logprobs and offload the model to cpu
gen_dataset = gen_dataset.map(
compute_ref_logps, batched=True, batch_size=self.args.per_device_train_batch_size
)
self.ref_model = self.ref_model.to("cpu")
# we could add some optimizations here like sorting the dataset by length to improve throughput, but we will keep it simple for now
mini_batch_dataloader = DataLoader(
gen_dataset,
batch_size=self.args.per_device_train_batch_size,
shuffle=True, # we technically don't need to shuffle due to grad acc, but we may move to clipped loss later
drop_last=True,
collate_fn=mini_batch_collator,
)
# optimization
# stats for logging
losses = []
kls = []
with profiling_context(self, "train_step"):
for mini_batch in mini_batch_dataloader:
loss_metric, kl_metric = self._optimization_step(mini_batch)
losses.append(loss_metric)
kls.append(kl_metric)
self.lr_scheduler.step()
self.state.global_step += 1
self.state.epoch = step / self.total_steps_per_device # TODO, this is not correct
# logging stats
metrics = {}
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
metrics["loss"] = self.accelerator.gather_for_metrics(torch.Tensor(losses).to(device)).mean().item()
metrics["kl"] = self.accelerator.gather_for_metrics(torch.Tensor(kls).to(device)).mean().item()
# completions stats
completion_lengths = [len(c) for c in gen_dataset["completion_ids"]]
gathered_completion_lengths = self.accelerator.gather_for_metrics(
torch.Tensor(completion_lengths).to(device)
)
metrics["mean_completion_lengths"] = gathered_completion_lengths.mean().item()
metrics["max_completion_lengths"] = gathered_completion_lengths.max().item()
metrics["min_completion_lengths"] = gathered_completion_lengths.min().item()
# reward stats
rewards = gen_dataset["rewards"]
gathered_rewards = self.accelerator.gather_for_metrics(torch.Tensor(rewards).to(device))
reward_per_func = gathered_rewards.mean(0)
for i, reward_func in enumerate(self.reward_funcs):
reward_func_name = reward_func.__name__
metrics[f"rewards/{reward_func_name}"] = reward_per_func[i].item()
metrics["reward"] = reward_per_func.sum().item()
self.log(metrics)
if self.args.log_completions and "wandb" in self.args.report_to:
import pandas as pd
prompts = gather_object(gen_dataset["prompt"])
completions = gather_object(gen_dataset["completion"])
# For logging
table = {
"step": [str(self.state.global_step)] * len(prompts),
"prompts": prompts,
"completion": completions,
"reward": gathered_rewards.sum(1).tolist(),
}
df = pd.DataFrame(table)
if wandb.run is not None and self.accelerator.is_main_process:
wandb.log({"completions": wandb.Table(dataframe=df)})
# sync weights to remote server
self._sync_weights()
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(self.model, trial=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(self.model, trial=None, metrics=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
def _optimization_step(self, mini_batch) -> tuple[float, float]:
prompt_ids, prompt_mask = mini_batch["prompt_ids"], mini_batch["prompt_mask"]
completion_ids, completion_mask = mini_batch["completion_ids"], mini_batch["completion_mask"]
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
ref_per_token_logps = mini_batch["ref_per_token_logps"]
with self.accelerator.accumulate(self.model):
per_token_logps = self._get_per_token_logps(self.model, input_ids, attention_mask, logits_to_keep)
per_token_kl = (
torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
)
advantages = mini_batch["advantages"]
# TODO: convert to clipped loss so we can run multiple GRPO epochs
per_token_loss = -torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = per_token_loss + self.args.beta * per_token_kl
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
self.accelerator.backward(loss)
self.optimizer.step()
self.optimizer.zero_grad()
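# Capture scalar metrics before the tensors are deleted; reading them after `del` raises a NameError
loss_metric = loss.detach().item()
kl_metric = per_token_kl.mean().item()
del per_token_logps, per_token_kl, per_token_loss, loss
# force garbage collection and empty cache
gc.collect()
torch.cuda.empty_cache()
return loss_metric, kl_metric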
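
For reference, the two per-sample quantities used above (the group-normalised advantages in `_prepare_batch` and the per-token KL penalty in `_optimization_step`) can be sketched in plain Python. This is a minimal illustration rather than the trainer's code: `group_advantages` and `k3_kl` are hypothetical helper names, lists stand in for tensors, and a single-sample group is given zero spread here, whereas `torch.std` would return NaN.

```python
import math
import statistics

EPS = 1e-4  # same stabiliser the trainer adds to the group standard deviation


def group_advantages(rewards, num_generations):
    """Normalise rewards within each consecutive group of `num_generations` samples."""
    if len(rewards) % num_generations != 0:
        raise ValueError("len(rewards) must be a multiple of num_generations")
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start : start + num_generations]
        mean = statistics.fmean(group)
        # Sample standard deviation (Bessel-corrected), matching torch.std's default.
        # Assumption for this sketch: a group of one gets zero spread.
        std = statistics.stdev(group) if len(group) > 1 else 0.0
        advantages.extend((r - mean) / (std + EPS) for r in group)
    return advantages


def k3_kl(ref_logp, logp):
    """Per-token KL estimate exp(d) - d - 1 with d = ref_logp - logp; never negative."""
    delta = ref_logp - logp
    return math.exp(delta) - delta - 1.0
```

A group whose rewards are all equal yields zero advantages, so such prompts contribute no policy-gradient signal; only the KL term remains in the loss for them.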

@ -0,0 +1,171 @@
import atexit
import os
import re
import subprocess
import time
# We need a special environment setup to launch vLLM from within Slurm training jobs.
# - Reference code: https://github.com/huggingface/brrr/blob/c55ba3505686d690de24c7ace6487a5c1426c0fd/brrr/lighteval/one_job_runner.py#L105
# - Slack thread: https://huggingface.slack.com/archives/C043JTYE1MJ/p1726566494958269
user_home_directory = os.path.expanduser("~")
SLURM_PREFIX = [
"env",
"-i",
"bash",
"-c",
f"for f in /etc/profile.d/*.sh; do source $f; done; export HOME={user_home_directory}; sbatch --qos=high --output=/fsx/h4/logs/%x-%j.out --err=/fsx/h4/logs/%x-%j.err ",
]
class SGLangSlurmJobLauncher:
def __init__(
self,
model_id_or_path,
model_revision="main",
num_gpus=1,
sglang_port=30010,
slurm_script="slurm/launch_sglang.slurm",
check_interval=5,
):
"""
Initialize the job launcher.
:param slurm_script: Path to the SLURM script.
:param check_interval: Time interval (seconds) to check job status.
"""
self.slurm_script = slurm_script
self.job_id = None
self.node_name = None
self.check_interval = check_interval
self.model_id_or_path = model_id_or_path
self.model_revision = model_revision
self.num_gpus = num_gpus
self.sglang_port = sglang_port
# Register cleanup function to cancel job on exit
atexit.register(self.cleanup)
def submit_job(self):
"""Submits the SLURM job and extracts the job ID."""
cmd = SLURM_PREFIX.copy()
cmd_args = [
f"--gres=gpu:{self.num_gpus}",
self.slurm_script,
self.model_id_or_path,
self.model_revision,
str(self.sglang_port),
]
cmd[-1] += " " + " ".join(cmd_args)
try:
result = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
check=True,
)
match = re.search(r"Submitted batch job (\d+)", result.stdout)
if match:
self.job_id = match.group(1)
print(f"Job submitted with ID: {self.job_id}")
else:
raise RuntimeError("Failed to retrieve job ID from sbatch output.")
except subprocess.CalledProcessError as e:
print(f"Error submitting job: {e.stderr}")
raise
def get_job_status(self):
"""Checks the job status using squeue."""
if not self.job_id:
raise ValueError("Job ID is not set. Submit the job first.")
result = subprocess.run(
["squeue", "--job", self.job_id, "--noheader"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
if not result.stdout.strip():
return None # Job is no longer in queue
status = result.stdout.split()[4] # Typically, state is the 5th column
return status
def wait_for_job_to_start(self):
"""Waits for the job to start running and fetches its node."""
print("Waiting for job to start...")
while True:
status = self.get_job_status()
if status is None:
raise RuntimeError("Job disappeared from queue, it may have failed.")
if status == "R": # Running
print("Job is running. Fetching node information...")
self.node_name = self.get_node_name()
return
time.sleep(self.check_interval)
def get_node_name(self):
"""Gets the node where the job is running."""
result = subprocess.run(
["squeue", "--job", self.job_id, "--noheader", "--format=%N"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
if result.stdout.strip():
return result.stdout.strip()
else:
raise RuntimeError("Failed to retrieve node name.")
def get_node_ip(self):
"""Retrieves the IP address of the node running the job."""
if not self.node_name:
raise ValueError("Node name is not set. Wait for the job to start first.")
result = subprocess.run(
["scontrol", "show", "node", self.node_name], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
match = re.search(r"NodeAddr=(\S+)", result.stdout)
if match:
return match.group(1)
else:
raise RuntimeError("Failed to retrieve node IP address.")
def launch(self):
"""Launches the job, waits for it to start, and retrieves the node IP."""
self.submit_job()
self.wait_for_job_to_start()
ip_address = self.get_node_ip()
print(f"Job is running on {self.node_name} with IP: {ip_address}")
self.ip_address = ip_address
return ip_address
def cleanup(self):
"""Cancels the SLURM job if it is still running."""
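if self.job_id is not None:
print(f"Cleaning up: Cancelling job {self.job_id}...")
subprocess.run(["scancel", self.job_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# Reset the ID so the job is not cancelled twice when both atexit and __del__ call cleanup
self.job_id = None
print("Job cancelled.")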
def __del__(self):
"""Ensure job cleanup when the instance is destroyed."""
self.cleanup()
if __name__ == "__main__":
from open_r1.trainers.remote_model import RemoteModel
launcher = SGLangSlurmJobLauncher("HuggingFaceTB/SmolLM2-135M-Instruct")
ip_address = launcher.launch()
time.sleep(15)
remote_model = RemoteModel(f"{ip_address}", 30010)
remote_model.wait_for_server()
result = remote_model.generate([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
print(result)
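
The launcher's parsing steps can be exercised without a cluster. The following is a minimal sketch, assuming Slurm output shaped like the sample strings in the usage note (the job ID, node name and address there are made up for illustration). `parse_job_id` and `parse_node_addr` are hypothetical helpers that reuse the regular expressions from `submit_job` and `get_node_ip`.

```python
import re


def parse_job_id(sbatch_stdout: str) -> str:
    """Extract the job ID from sbatch's confirmation line."""
    match = re.search(r"Submitted batch job (\d+)", sbatch_stdout)
    if match is None:
        raise RuntimeError("Failed to retrieve job ID from sbatch output.")
    return match.group(1)


def parse_node_addr(scontrol_stdout: str) -> str:
    """Extract the NodeAddr field from `scontrol show node` output."""
    match = re.search(r"NodeAddr=(\S+)", scontrol_stdout)
    if match is None:
        raise RuntimeError("Failed to retrieve node IP address.")
    return match.group(1)
```

For example, `parse_job_id("Submitted batch job 12345\n")` and `parse_node_addr("NodeName=node-1 Arch=x86_64\n   NodeAddr=10.0.0.7 NodeHostName=node-1\n")` pull out the job ID and the address respectively. Keeping the parsing separate from the `subprocess.run` calls makes the failure paths easy to unit-test.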

@ -0,0 +1,697 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import tempfile
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Callable, Iterator, Optional, Union
from trl.trainer.utils import disable_dropout_in_model
import torch
import transformers
from datasets import Dataset, IterableDataset, disable_progress_bars, enable_progress_bars
from datasets.utils.logging import set_verbosity_error, set_verbosity_info
from packaging import version
from torch.utils.data import DataLoader, Sampler
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_liger_kernel_available
import deepspeed
import trl
from accelerate.utils import broadcast_object_list, gather_object, is_peft_model
from open_r1.trainers.job_launcher import SGLangSlurmJobLauncher
from open_r1.trainers.remote_model import RemoteModel
from trl.data_utils import is_conversational, maybe_apply_chat_template
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import is_rich_available
from trl.models import create_reference_model, prepare_deepspeed
from trl.trainer.callbacks import SyncRefModelCallback
from trl.trainer.utils import exact_div, pad, print_prompt_completions_sample, selective_log_softmax
if is_liger_kernel_available():
from liger_kernel.transformers import AutoLigerKernelForCausalLM
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
if is_wandb_available():
import wandb
class RepeatBatchRandomSampler(Sampler):
def __init__(
self,
data_source,
batch_size: int = 1,
num_processes: int = 1,
repeat_count: int = 1,
seed: Optional[int] = None,
):
self.data_source = data_source
self.batch_size = batch_size
self.num_processes = num_processes
self.repeat_count = repeat_count
self.num_samples = len(data_source)
self.seed = seed
self.generator = torch.Generator() # Create a local random generator
if seed is not None:
self.generator.manual_seed(seed)
def __iter__(self):
indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
all_process_batch_size = self.batch_size * self.num_processes
indices = [indices[i : i + all_process_batch_size] for i in range(0, len(indices), all_process_batch_size)]
indices = [chunk for chunk in indices if len(chunk) == all_process_batch_size]
for chunk in indices:
for _ in range(self.repeat_count):
for index in chunk:
yield index
def __len__(self) -> int:
return self.num_samples * self.repeat_count
@dataclass
class RemoteGRPOConfig(trl.GRPOConfig):
"""
args for callbacks, benchmarks etc
"""
benchmarks: list[str] = field(
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
)
callbacks: list[str] = field(
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
)
chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
checkpoint_dir: Optional[str] = field(
default="/fsx/h4/tmp/", metadata={"help": "The directory to save temporary checkpoints to."}
)
system_prompt: Optional[str] = field(
default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
)
hub_model_revision: Optional[str] = field(
default="main", metadata={"help": "The Hub model branch to push the model to."}
)
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
wandb_entity: Optional[str] = field(
default=None,
metadata={"help": ("The entity to store runs under.")},
)
wandb_project: Optional[str] = field(
default=None,
metadata={"help": ("The project to store runs under.")},
)
remote_gen_model_url: Optional[str] = field(
default=None,
)
remote_gen_model_port: str = field(
default="30010",
)
remote_gen_model_n_gpus: int = field(
default=8,
metadata={"help": "The number of GPUs to request for the remote generation server."},
)
use_liger: bool = field(
default=True,
metadata={"help": "Whether to use Liger kernel for training."},
)
class RemoteGRPOTrainer(Trainer):
_tag_names = ["trl", "grpo"]
def __init__(
self,
model: Union[str, PreTrainedModel],
reward_funcs: Union[RewardFunc, list[RewardFunc]],
args: Optional[RemoteGRPOConfig] = None,
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
processing_class: Optional[PreTrainedTokenizerBase] = None,
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
):
self.args = args
# Initialize the metrics
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
self.log_completions = args.log_completions
# Models
# Trained model
model_init_kwargs = args.model_init_kwargs or {}
if isinstance(model, str):
model_id = model
model = self._create_model_from_path(model_id, args)
disable_dropout_in_model(model)
else:
model_id = model.config._name_or_path
if args.model_init_kwargs is not None:
raise ValueError(
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
"This argument can only be used when the `model` argument is a string."
)
# Enable gradient checkpointing if requested
if args.gradient_checkpointing:
model = self._enable_gradient_checkpointing(model, args)
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
# "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
# This acts as a flag to indicate that the warning has already been issued.
model.warnings_issued["estimate_tokens"] = True
# Reference model
if self.args.beta == 0.0:
# If beta is 0.0, the reference model is not needed
self.ref_model = None
elif is_deepspeed_zero3_enabled():
self.ref_model = self._create_model_from_path(model_id, args)
disable_dropout_in_model(self.ref_model)
elif is_peft_model(model):
raise NotImplementedError("Peft is not supported")
else:
# If PEFT configuration is not provided, create a reference model based on the initial model.
self.ref_model = create_reference_model(model)
# Reward functions
if not isinstance(reward_funcs, list):
reward_funcs = [reward_funcs]
self.reward_funcs = reward_funcs
# Reward weights
if args.reward_weights is not None:
if len(args.reward_weights) != len(reward_funcs):
raise ValueError(
f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
f"functions ({len(reward_funcs)})"
)
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
else:
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
# Reward processing class
if reward_processing_classes is None:
reward_processing_classes = [None] * len(reward_funcs)
elif not isinstance(reward_processing_classes, list):
reward_processing_classes = [reward_processing_classes]
else:
if len(reward_processing_classes) != len(reward_funcs):
raise ValueError("The number of reward processing classes must match the number of reward functions.")
# TODO: test RMS and also wrap them in deepspeed
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
if isinstance(reward_func, PreTrainedModel):
if reward_processing_class is None:
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
if reward_processing_class.pad_token_id is None:
reward_processing_class.pad_token = reward_processing_class.eos_token
# The reward model computes the reward for the latest non-padded token in the input sequence.
# So it's important to set the pad token ID to the padding token ID of the processing class.
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
reward_processing_classes[i] = reward_processing_class
self.reward_processing_classes = reward_processing_classes
def data_collator(features): # No data collation is needed in GRPO
return features
self.batch_buffer = []
super().__init__(
model,
args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
callbacks=callbacks,
optimizers=optimizers,
data_collator=data_collator,
)
ip_address = self.args.remote_gen_model_url
if self.args.remote_gen_model_url is None and self.accelerator.is_main_process:
# we launch a job from here, get the ip on main process and broadcast to others
# it would be better to move this to the start so the server warms up while the local model is being loaded
# use the local `model_init_kwargs`, which falls back to {} when `args.model_init_kwargs` is None
model_revision = model_init_kwargs.get("revision", "main")
self.sglang_job_launcher = SGLangSlurmJobLauncher(
model_id,
model_revision,
num_gpus=self.args.remote_gen_model_n_gpus,
sglang_port=self.args.remote_gen_model_port,
)
ip_address = self.sglang_job_launcher.launch()
# get the ip from main process and broadcast to others
gather_ip_address = broadcast_object_list([ip_address], 0)
self.args.remote_gen_model_url = gather_ip_address[0]
self.remote_model = RemoteModel(
self.args.remote_gen_model_url, self.args.remote_gen_model_port, self.processing_class.eos_token_id
)
self.remote_model.wait_for_server()
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
# self.model_accepts_loss_kwargs to False to enable scaling.
self.model_accepts_loss_kwargs = False
# Add tags to the model
self.model.add_model_tags(self._tag_names)
if self.ref_model is not None:
if self.is_deepspeed_enabled:
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
if args.sync_ref_model:
self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
def _get_train_sampler(self) -> Sampler:
"""
Return the train sampler.
Returns:
Sampler: The train sampler.
"""
if self.args.dataloader_num_workers != 0:
raise ValueError("dataloader_num_workers should not be greater than 0 for remote training")
return RepeatBatchRandomSampler(
data_source=self.train_dataset,
batch_size=self._train_batch_size,
repeat_count=self.args.num_generations * self.args.num_iterations,
num_processes=self.accelerator.num_processes,
seed=self.args.seed,
)
def _create_model_from_path(self, model_path: str, args: RemoteGRPOConfig) -> PreTrainedModel:
"""Creates a model from a path or model identifier."""
model_init_kwargs = args.model_init_kwargs or {}
# Handle torch dtype
torch_dtype = model_init_kwargs.get("torch_dtype")
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
pass # torch_dtype is already a torch.dtype or "auto" or None
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
torch_dtype = getattr(torch, torch_dtype)
model_init_kwargs["torch_dtype"] = torch_dtype
else:
raise ValueError(
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
)
# Disable caching if gradient checkpointing is enabled (not supported)
if args.gradient_checkpointing:
model_init_kwargs["use_cache"] = False
# Create model
model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
return model
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: RemoteGRPOConfig) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# Ensure use_cache is disabled
model.config.use_cache = False
# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(model):
model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
else:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
)
if use_reentrant:
model.enable_input_require_grads()
return model
def _generate_and_score_completions(
self, inputs: dict[str, Union[torch.Tensor, Any]]
) -> dict[str, Union[torch.Tensor, Any]]:
prompts_to_log = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
prompt_inputs = self.processing_class(prompts_text)
prompt_ids = prompt_inputs["input_ids"]
# sync weights here?
self._sync_weights()
with profiling_context(self, "remote_generate"):
all_outputs = self.remote_model.generate(
prompt_ids,
max_new_tokens=self.args.max_completion_length,
temperature=self.args.temperature,
num_generations=self.args.num_generations,
)
completion_ids = [example["completion_ids"] for example in all_outputs]
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
repeated_prompts = []
for prompt in prompts_to_log:
repeated_prompts.extend([prompt] * self.args.num_generations)
repeated_prompt_texts = []
for prompt in prompts_text:
repeated_prompt_texts.extend([prompt] * self.args.num_generations)
if is_conversational(inputs[0]):
completions_to_log = []
for prompt, completion in zip(repeated_prompts, completions_text, strict=True):
# Read the assistant prefill without mutating `prompt`: the repeated prompts share one list object
bootstrap = prompt[-1]["content"] if prompt[-1]["role"] == "assistant" else ""
completions_to_log.append([{"role": "assistant", "content": bootstrap + completion}])
else:
completions_to_log = completions_text
rewards = torch.zeros(len(repeated_prompts), len(self.reward_funcs))
with profiling_context(self, "rewards"):
for i, reward_func in enumerate(self.reward_funcs):
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
reward_kwargs = defaultdict(list)
for example in inputs:
for key in keys:
reward_kwargs[key].extend([example[key]] * self.args.num_generations)
output_reward_func = reward_func(prompts=repeated_prompts, completions=completions_to_log, **reward_kwargs)
rewards[:, i] = torch.tensor(output_reward_func, dtype=torch.float32) * self.reward_weights[i]
# if i == 0 and self.accelerator.is_main_process: # dump generations to a text file for debugging
# with open("python_code_completions2.jsonl", "a") as f:
# for i,(p, c) in enumerate(zip(repeated_prompts, completions_to_log)):
# data = {
# "prompt": p,
# "completion": c,
# }
# for k in reward_kwargs.keys():
# data[k] = reward_kwargs[k][i]
# f.write(json.dumps(data) + "\n")
# calculate the advantages; all generations for a prompt are on the same device, so no need to gather here
grouped_rewards = rewards.sum(-1).view(len(prompts_to_log), self.args.num_generations)
EPS = 1e-4
grouped_advantages = (grouped_rewards - grouped_rewards.mean(-1, keepdim=True)) / (
grouped_rewards.std(-1, keepdim=True) + EPS
)
advantages = grouped_advantages.flatten().tolist()
examples = []
for i, prompt in enumerate(repeated_prompt_texts):
example = {
"prompt": prompt,
"prompt_ids": prompt_ids[i // self.args.num_generations],
"completion": completions_text[i],
"completion_ids": completion_ids[i],
"advantages": advantages[i],
"rewards": rewards[i],
}
examples.append(example)
# Instead of logging metrics here, collect them
mode = "eval" if getattr(self, "control", None) and self.control.should_evaluate else "train"
device = self.accelerator.device
# Collect completion length metrics
completion_lengths = [len(example["completion_ids"]) for example in examples]
gathered_completion_lengths = self.accelerator.gather_for_metrics(torch.Tensor(completion_lengths).to(device))
self._metrics[mode]["mean_completion_lengths"].append(gathered_completion_lengths.mean().item())
self._metrics[mode]["max_completion_lengths"].append(gathered_completion_lengths.max().item())
self._metrics[mode]["min_completion_lengths"].append(gathered_completion_lengths.min().item())
# Collect reward metrics
rewards = torch.stack(
[
example["rewards"].to(device)
if isinstance(example["rewards"], torch.Tensor)
else torch.tensor(example["rewards"], device=device)
for example in examples
]
)
gathered_rewards = self.accelerator.gather_for_metrics(rewards)
reward_per_func = gathered_rewards.mean(0)
for i, reward_func in enumerate(self.reward_funcs):
reward_func_name = reward_func.__name__
self._metrics[mode][f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
self._metrics[mode]["reward"].append(reward_per_func.sum().item())
if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
prompts_to_log = gather_object([example["prompt"] for example in examples])
completions_to_log = gather_object([example["completion"] for example in examples])
if self.accelerator.is_main_process:
# if is_rich_available():
# # TODO: enable num_samples in TRL to avoid clogging logs
# print_prompt_completions_sample(
# prompts_to_log[:5],
# completions_to_log[:5],
# gathered_rewards.sum(1).tolist()[:5],
# self.state.global_step,
# )
if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
import pandas as pd
# For logging
table = {
"step": [str(self.state.global_step)] * len(prompts_to_log),
"prompts": prompts_to_log,
"completion": completions_to_log,
"reward": gathered_rewards.sum(1).tolist(),
}
df = pd.DataFrame(table)
if wandb.run is not None and self.accelerator.is_main_process:
wandb.log({"completions": wandb.Table(dataframe=df)})
return examples
@profiling_decorator
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
if len(self.batch_buffer) > 0:
return self.batch_buffer.pop(0)
inputs = self._generate_and_score_completions(inputs)
gen_dataset = Dataset.from_list(inputs)
exact_div(
len(gen_dataset), self.args.per_device_train_batch_size, "len(gen_dataset) is not divisible by batch size"
)
def get_logprobs(example, model, output_name):
# dict of lists to list of dicts
examples = [dict(zip(example.keys(), values)) for values in zip(*example.values())]
input_ids, attention_mask, completion_mask, completion_ids = self._get_padded_inputs_and_attn_mask(
examples
)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep).detach()
lengths = [len(example["completion_ids"]) for example in examples]
# Strip the completion padding
per_token_logps = per_token_logps.to("cpu").tolist()
per_token_logps = [logps[:length] for logps, length in zip(per_token_logps, lengths)]
example[output_name] = per_token_logps
return example
with torch.no_grad():
set_verbosity_error()
disable_progress_bars()
if self.ref_model is not None:
self.ref_model.eval()
gen_dataset = gen_dataset.map(
get_logprobs,
batched=True,
batch_size=self.args.per_device_train_batch_size * 2,
fn_kwargs={"model": self.ref_model, "output_name": "ref_per_token_logps"},
)
self.model.eval()
gen_dataset = gen_dataset.map(
get_logprobs,
batched=True,
batch_size=self.args.per_device_train_batch_size * 2,
fn_kwargs={"model": self.model, "output_name": "old_per_token_logps"},
)
self.model.train()
enable_progress_bars()
set_verbosity_info()
def mini_batch_collator(mini_batch):
return mini_batch
mini_batch_dataloader = DataLoader(
gen_dataset,
batch_size=self.args.per_device_train_batch_size,
shuffle=True, # we technically don't need to shuffle due to grad acc, but we may move to clipped loss later
drop_last=True,
collate_fn=mini_batch_collator,
)
for num_iters in range(self.args.num_iterations):
for mini_batch in mini_batch_dataloader:
self.batch_buffer.append(mini_batch)
return self.batch_buffer.pop(0)
@profiling_decorator
def _sync_weights(self):
if self.remote_model.is_mock:
return
self.accelerator.wait_for_everyone()
# if self.accelerator.is_main_process:
start = time.time()
# would be better if this was a ram disk + separate thread for writing
unwrapped_model = self.accelerator.unwrap_model(self.model)
if is_deepspeed_zero3_enabled():
state_dict = {}
for name, param in unwrapped_model.named_parameters():
if name in state_dict.keys():
# sometimes the embed table is duplicated so no need to regather it
continue
with deepspeed.zero.GatheredParameters(param, modifier_rank=0):
state_dict[name] = param.cpu().detach().clone()
# if is_fsdp_managed_module(self.model):
# state_dict = self.model.state_dict()
# trainer.save_model(output_dir)
else:
state_dict = self.accelerator.get_state_dict(self.model)
# if self.accelerator.is_main_process:
# with tempfile.TemporaryDirectory(dir=self.args.checkpoint_dir) as temp_dir_path:
# self.save_model(temp_dir_path)
# state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
with tempfile.TemporaryDirectory(dir=self.args.checkpoint_dir) as temp_dir_path:
self._save(temp_dir_path, state_dict=state_dict)
self.remote_model.load_weights_from_path(temp_dir_path)
print(f"Weight sync took: {time.time() - start:.2f}s")
self.accelerator.wait_for_everyone()
# Get the per-token log probabilities for the completions for the model and the reference model
@profiling_decorator
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
# We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
def _get_padded_inputs_and_attn_mask(self, inputs):
device = self.accelerator.device
prompt_ids = [torch.LongTensor(example["prompt_ids"]) for example in inputs]
completion_ids = [torch.LongTensor(example["completion_ids"]) for example in inputs]
pad_token_id = self.processing_class.pad_token_id
prompt_ids = pad(prompt_ids, padding_value=pad_token_id, padding_side="left")
completion_ids = pad(completion_ids, padding_value=pad_token_id, padding_side="right")
# padd_ref_per_token_logps = pad(ref_per_token_logps, padding_value=0.0, padding_side="right")
if self.args.max_prompt_length is not None:
prompt_ids = prompt_ids[:, -self.args.max_prompt_length :]
# compute the masks
prompt_mask = (prompt_ids != pad_token_id).long()
# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
input_ids = torch.cat([prompt_ids, completion_ids], dim=1).to(device)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1).to(device)
completion_mask = completion_mask.to(device)
return input_ids, attention_mask, completion_mask, completion_ids
@profiling_decorator
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
device = self.accelerator.device
advantages = torch.Tensor([example["advantages"] for example in inputs]).to(device)
input_ids, attention_mask, completion_mask, completion_ids = self._get_padded_inputs_and_attn_mask(inputs)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
old_per_token_logps = [torch.Tensor(example["old_per_token_logps"]) for example in inputs]
pad_token_id = self.processing_class.pad_token_id
# pad the old logps with 0.0 (a log-prob, not a token id); padded positions are masked out of the loss
pad_old_per_token_logps = pad(old_per_token_logps, padding_value=0.0, padding_side="right").to(device)
# model.eval()
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
if self.ref_model is not None:
ref_per_token_logps = [torch.Tensor(example["ref_per_token_logps"]) for example in inputs]
pad_ref_per_token_logps = pad(ref_per_token_logps, padding_value=0.0, padding_side="right").to(device)
# clamp the log-ratio for numerical stability
clamped_diff = torch.clamp(pad_ref_per_token_logps - per_token_logps, -10.0, 10.0)
per_token_kl = torch.exp(clamped_diff) - clamped_diff - 1
# del inputs, input_ids, attention_mask # free up memory
# clipped loss
coef_1 = torch.exp(per_token_logps - pad_old_per_token_logps)
coef_2 = torch.clamp(coef_1, 1 - self.args.epsilon, 1 + self.args.epsilon)
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
if self.ref_model is not None:
per_token_loss = per_token_loss + self.args.beta * per_token_kl
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
torch.cuda.empty_cache()
return loss
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
mode = "eval" if self.control.should_evaluate else "train"
metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
if mode == "eval":
metrics = {f"eval_{key}": val for key, val in metrics.items()}
logs = {**logs, **metrics}
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
super().log(logs, start_time)
else: # transformers<=4.46
super().log(logs)
self._metrics[mode].clear()
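
The group-relative advantages computed in `_generate_and_score_completions` and the clipped objective in `compute_loss` can be illustrated on plain floats. What follows is a minimal sketch rather than the trainer's implementation: the helper names `group_advantages` and `clipped_token_loss` are hypothetical, the `epsilon` and `beta` defaults are placeholder values (not the config defaults), and it assumes at least two generations per prompt.

```python
import math
import statistics

EPS = 1e-4  # same stabiliser constant as in the trainer code


def group_advantages(rewards, num_generations):
    """Normalise rewards within each group of `num_generations` completions."""
    if num_generations < 2 or len(rewards) % num_generations != 0:
        raise ValueError("rewards must split evenly into groups of at least 2")
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start : start + num_generations]
        mean = statistics.fmean(group)
        std = statistics.stdev(group)  # sample std (n - 1), like torch.std's default
        advantages.extend((r - mean) / (std + EPS) for r in group)
    return advantages


def clipped_token_loss(logp, old_logp, ref_logp, advantage, epsilon=0.2, beta=0.04):
    """Per-token clipped surrogate loss plus a KL penalty (hypothetical helper)."""
    # Clamp the log-ratio for numerical stability, as the trainer does
    diff = max(-10.0, min(10.0, ref_logp - logp))
    per_token_kl = math.exp(diff) - diff - 1
    ratio = math.exp(logp - old_logp)
    clipped_ratio = max(1 - epsilon, min(1 + epsilon, ratio))
    surrogate = min(ratio * advantage, clipped_ratio * advantage)
    return -surrogate + beta * per_token_kl


advantages = group_advantages([1.0, 0.0, 0.5, 0.5], num_generations=2)
on_policy_loss = clipped_token_loss(-1.0, -1.0, -1.0, advantage=2.0)
clipped_loss = clipped_token_loss(math.log(2.0) - 1.0, -1.0, math.log(2.0) - 1.0, advantage=1.0)
print(advantages, on_policy_loss, clipped_loss)
```

When every completion in a group receives the same reward, the advantages are zero and that group contributes only the KL term to the loss.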


@ -0,0 +1,162 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# great reference: https://github.com/vllm-project/vllm/issues/11400
import time
import random
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import requests
class RemoteModel:
"""
launch with:
export LD_LIBRARY_PATH=$(python -c "import site; print(site.getsitepackages()[0] + '/nvidia/nvjitlink/lib')"):$LD_LIBRARY_PATH
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port=30010 --skip-tokenizer-init --mem-fraction-static 0.4
python3 -m sglang.launch_server --model-path HuggingFaceTB/SmolLM2-135M-Instruct --port=30010 --skip-tokenizer-init --mem-fraction-static 0.4 --host=0.0.0.0
# on a separate node
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port=30010 --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size=8
python3 -m sglang.launch_server --model-path HuggingFaceTB/SmolLM2-1.7B-Instruct --port=30010 --skip-tokenizer-init --mem-fraction-static 0.6 --host=0.0.0.0 --dp-size=8
python3 -m sglang.launch_server --model-path open-r1/Qwen2.5-Coder-7B-Instruct-SFT --revision v00.08-step-000001280 --port=30010 --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size=8
python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-1.5B-Instruct --port=30010 --skip-tokenizer-init --mem-fraction-static 0.7 --host=0.0.0.0 --dp-size=8
"""
def __init__(self, remote_model_url, remote_model_port, stop_token_id=None):
self.remote_model_url = remote_model_url
self.remote_model_port = remote_model_port
self.stop_token_id = stop_token_id
if self.remote_model_url == "mock":
print("Using mock remote model")
@property
def is_mock(self):
return self.remote_model_url == "mock"
def is_healthy(self, timeout=5):
"""Checks if the remote model server is up and running."""
if self.remote_model_url == "mock":
return True
try:
url = f"http://{self.remote_model_url}:{self.remote_model_port}/health"
response = requests.get(url, timeout=timeout)
return response.status_code == 200
except requests.RequestException:
return False
def wait_for_server(self, max_retries=120, delay=5):
"""Waits for the server to become available before proceeding."""
for attempt in range(max_retries):
if self.is_healthy():
print("Remote model server is healthy!")
return True
print(f"Waiting for server to start... (Attempt {attempt + 1}/{max_retries})")
time.sleep(delay)
raise RuntimeError("Remote model server did not start in time.")
def generate(
self, input_ids: list[list[int]], max_new_tokens=256, temperature=0.8, num_generations=2
) -> list[dict[str, list[int]]]:
# Prepare the request body
if self.remote_model_url == "mock":
examples = []
for prompt_ids in input_ids:
for j in range(num_generations):
example = {
"prompt_ids": prompt_ids,
"completion_ids": random.choices(range(10, 1000), k=max_new_tokens),
# "prompt_log_probs": None, # TODO, not used for now
# "completion_log_probs": None,
}
examples.append(example)
return examples
request_body = {
"input_ids": input_ids,
"sampling_params": {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"stop_token_ids": [self.stop_token_id],
"n": num_generations,
},
"stream": False,
# "return_logprob": True, # disabled as we occasionally see https://github.com/sgl-project/sglang/issues/4097
# "logprob_start_len": 0,
}
# Send the POST request to the server
# add a few retries?
response = requests.post(
f"http://{self.remote_model_url}:{self.remote_model_port}/generate", json=request_body
)
response_json = response.json()
examples = []
for i, result in enumerate(response_json):
prompt_index = i // num_generations
prompt_ids = input_ids[prompt_index]
completion_ids = result["output_ids"]
# prompt_log_probs = [prob[0] for prob in result["meta_info"]["input_token_logprobs"]]
# completion_log_probs = [prob[0] for prob in result["meta_info"]["output_token_logprobs"]]
example = {
"prompt_ids": prompt_ids,
"completion_ids": completion_ids,
# "prompt_log_probs": prompt_log_probs,
# "completion_log_probs": completion_log_probs,
}
examples.append(example)
return examples
def load_weights_from_path(self, path: str):
if self.remote_model_url == "mock":
return
url = f"http://{self.remote_model_url}:{self.remote_model_port}/update_weights_from_disk"
data = {"model_path": path}
response = requests.post(url, json=data)
print(response.text)
assert response.json()["success"] is True
if __name__ == "__main__":
from datasets import load_dataset
url = "0.0.0.0"
port = 30010
MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
remote_model = RemoteModel(url, port, tokenizer.eos_token_id)
dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train")
dataloader = DataLoader(dataset, batch_size=4)
for i, batch in zip(range(2), dataloader):
problems = batch["problem"]
ids = tokenizer(problems)
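# generate() returns a list of dicts, one per completion, not a (ids, logprobs) tuple
examples = remote_model.generate(ids["input_ids"])
print([example["completion_ids"] for example in examples])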
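
`generate` relies on the server returning `num_generations` results per prompt, in prompt order, so that result `i` belongs to prompt `i // num_generations`. The sketch below shows only that index mapping, with no network call. `attach_prompts` is a hypothetical helper, and the ordering is the same assumption the code above already makes about the server response.

```python
def attach_prompts(input_ids, output_ids, num_generations):
    """Pair each completion with the prompt it was generated from."""
    if len(output_ids) != len(input_ids) * num_generations:
        raise ValueError("expected num_generations completions per prompt")
    examples = []
    for i, completion_ids in enumerate(output_ids):
        # Results are assumed to be grouped by prompt, in prompt order
        prompt_ids = input_ids[i // num_generations]
        examples.append({"prompt_ids": prompt_ids, "completion_ids": completion_ids})
    return examples


prompts = [[1, 2], [3]]
completions = [[10], [11], [12], [13]]  # two completions per prompt
paired = attach_prompts(prompts, completions, num_generations=2)
print(paired)
```

The length check makes a short or malformed response fail loudly instead of silently pairing completions with the wrong prompt.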


@ -0,0 +1,36 @@
from typing import Iterator
from torch.utils.data import RandomSampler
class RepeatBatchRandomSampler(RandomSampler):
def __init__(
self,
*args,
num_generations: int = 1,
batch_size: int = 3,
**kwargs,
) -> None:
self.num_generations = num_generations
self.batch_size = batch_size
super().__init__(*args, **kwargs)
def __len__(self) -> int:
return super().__len__() * self.num_generations
def __iter__(self) -> Iterator[int]:
batch_indices = []
for idx in super().__iter__():
batch_indices.append(idx)
if len(batch_indices) == self.batch_size:
batch_indices = batch_indices * self.num_generations
yield from batch_indices
batch_indices = []
if __name__ == "__main__":
sampler = RepeatBatchRandomSampler(num_generations=2, data_source=range(12), replacement=False)
# print(list(sampler))
for sample in sampler:
print(sample)
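
The repetition logic of `RepeatBatchRandomSampler.__iter__` can be followed more easily without the shuffling. This is a minimal sketch over an already ordered index sequence. The generator name `repeat_batches` is hypothetical, and it takes the indices directly instead of drawing them from `RandomSampler`.

```python
from typing import Iterable, Iterator


def repeat_batches(indices: Iterable[int], batch_size: int, num_generations: int) -> Iterator[int]:
    """Yield each full batch of indices `num_generations` times in a row."""
    batch = []
    for idx in indices:
        batch.append(idx)
        if len(batch) == batch_size:
            yield from batch * num_generations
            batch = []
    # A trailing partial batch is dropped, as in the sampler above


order = list(repeat_batches(range(5), batch_size=2, num_generations=2))
print(order)  # [0, 1, 0, 1, 2, 3, 2, 3]
```

Index 4 never appears because it does not fill a batch. In the class above, `__len__` still counts it (dataset length times `num_generations`), so the reported length can exceed the number of indices actually yielded.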


@ -0,0 +1,148 @@
from collections import defaultdict
import torch
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from open_r1.configs import GRPOConfig
from open_r1.trainers.remote_model import RemoteModel
from trl.data_utils import is_conversational, maybe_apply_chat_template
class RemoteGRPODataloader(DataLoader):
def __init__(
self, *args, config: GRPOConfig, remote_model=None, processing_class=None, reward_funcs=None, **kwargs
):
super().__init__(*args, **kwargs)
self.config = config
self.remote_model = remote_model
self.processing_class = processing_class
self.reward_funcs = reward_funcs
self.reward_weights = [1.0] * len(reward_funcs) # TODO: make this configurable
def __len__(self):
return super().__len__() * self.config.num_generations
def __iter__(self):
for batch in super().__iter__():
batch = self._prepare_batch(batch)
gen_dataset = Dataset.from_list(batch)
mini_batch_dataloader = DataLoader(
gen_dataset,
batch_size=self.config.per_device_train_batch_size,
shuffle=True, # we technically don't need to shuffle due to grad acc, but we may move to clipped loss later
drop_last=True,
collate_fn=self.collate_fn,
)
for mini_batch in mini_batch_dataloader:
yield mini_batch
def _prepare_batch(self, batch):
prompts = [x["prompt"] for x in batch]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in batch]
prompt_inputs = self.processing_class(prompts_text)
prompt_ids = prompt_inputs["input_ids"]
# add cuda clear cache here and a sleep
all_outputs = self.remote_model.generate(
prompt_ids,
max_new_tokens=self.config.max_completion_length,
temperature=self.config.temperature,
num_generations=self.config.num_generations,
)
completion_ids = [example["completion_ids"] for example in all_outputs]
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
repeated_prompts = []
for prompt in prompts:
repeated_prompts.extend([prompt] * self.config.num_generations)
repeated_prompt_texts = []
for prompt in prompts_text:
repeated_prompt_texts.extend([prompt] * self.config.num_generations)
if is_conversational(batch[0]):
completions = []
for prompt, completion in zip(repeated_prompts, completions_text, strict=True):
# Read the assistant prefill without mutating `prompt`: the repeated prompts share one list object
bootstrap = prompt[-1]["content"] if prompt[-1]["role"] == "assistant" else ""
completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:
completions = completions_text
rewards = torch.zeros(len(repeated_prompts), len(self.reward_funcs))
for i, reward_func in enumerate(self.reward_funcs):
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
keys = [key for key in batch[0] if key not in ["prompt", "completion"]]
reward_kwargs = defaultdict(list)
for example in batch:
for key in keys:
reward_kwargs[key].extend([example[key]] * self.config.num_generations)
output_reward_func = reward_func(prompts=repeated_prompts, completions=completions, **reward_kwargs)
rewards[:, i] = torch.tensor(output_reward_func, dtype=torch.float32) * self.reward_weights[i]
grouped_rewards = rewards.sum(-1).view(len(prompts), self.config.num_generations)
EPS = 1e-4
grouped_advantages = (grouped_rewards - grouped_rewards.mean(-1, keepdim=True)) / (
grouped_rewards.std(-1, keepdim=True) + EPS
)
advantages = grouped_advantages.flatten().tolist()
# build batch as list of dicts
examples = []
for i, prompt in enumerate(repeated_prompt_texts):
example = {
"prompt": prompt,
"prompt_ids": prompt_ids[i // self.config.num_generations],
"completion": completions_text[i],
"completion_ids": completion_ids[i],
"advantages": advantages[i],
"rewards": rewards[i],
}
examples.append(example)
return examples
if __name__ == "__main__":
dataset = load_dataset("open-r1/OpenR1-Math-cn_k12-86k", split="train").select(range(32))
def make_conversation(example):
prompt = []
prompt.append({"role": "user", "content": example["problem"]})
return {"prompt": prompt}
dataset = dataset.map(make_conversation)
def collate_fn(batch):
return batch
dataset = dataset.remove_columns("messages")
def reward_func(prompts, completions, **kwargs):
return [0.5] * len(prompts)
reward_funcs = [reward_func, reward_func]
MODEL = "HuggingFaceTB/SmolLM2-135M-Instruct"
processing_class = AutoTokenizer.from_pretrained(MODEL)
remote_model = RemoteModel("0.0.0.0", 30010, processing_class.eos_token_id)
config = GRPOConfig()
data_loader = RemoteGRPODataloader(
dataset,
remote_model=remote_model,
processing_class=processing_class,
reward_funcs=reward_funcs,
batch_size=2,
num_workers=0,
collate_fn=collate_fn,
config=config,
)
print(len(data_loader))
for i, batch in enumerate(data_loader):
print(i, len(batch))
print(batch)
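
Before the reward functions are called in `_prepare_batch`, every extra dataset column is repeated so that it lines up with the `num_generations` completions produced for each prompt. A minimal sketch of that step in isolation follows. `repeat_columns` is a hypothetical helper and the `solution` column is an invented example.

```python
from collections import defaultdict


def repeat_columns(batch, num_generations, skip=("prompt", "completion")):
    """Repeat each non-skipped column value once per generation."""
    keys = [key for key in batch[0] if key not in skip]
    repeated = defaultdict(list)
    for example in batch:
        for key in keys:
            repeated[key].extend([example[key]] * num_generations)
    return dict(repeated)


batch = [
    {"prompt": "a", "solution": "1"},
    {"prompt": "b", "solution": "2"},
]
reward_kwargs = repeat_columns(batch, num_generations=2)
print(reward_kwargs)  # {'solution': ['1', '1', '2', '2']}
```

The repeated values follow the same prompt-major order as the completions, so a reward function can compare `completions[i]` against `reward_kwargs["solution"][i]` directly.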


@ -1,6 +1,5 @@
from .data import get_dataset
from .import_utils import is_e2b_available, is_morph_available
from .model_utils import get_model, get_tokenizer
from .import_utils import is_e2b_available
from .model_utils import get_tokenizer
__all__ = ["get_tokenizer", "is_e2b_available", "is_morph_available", "get_model", "get_dataset"]
__all__ = ["get_tokenizer", "is_e2b_available"]


@ -44,13 +44,7 @@ class PushToHubRevisionCallback(TrainerCallback):
def __init__(self, model_config) -> None:
self.model_config = model_config
def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if state.is_world_process_zero:
global_step = state.global_step


@ -1,366 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code execution providers for executing and evaluating code snippets."""
import abc
import asyncio
from typing import List, Optional
from ..utils import is_e2b_available, is_morph_available
if is_e2b_available():
from e2b_code_interpreter import AsyncSandbox
from e2b_code_interpreter.models import Execution
from .routed_sandbox import RoutedSandbox
else:
AsyncSandbox = None
Execution = None
RoutedSandbox = None
if is_morph_available():
from morphcloud.api import MorphCloudClient
from morphcloud.sandbox import Sandbox
from .routed_morph import RoutedMorphSandbox
else:
MorphCloudClient = None
Sandbox = None
RoutedMorphSandbox = None
class CodeExecutionProvider(abc.ABC):
"""Abstract base class for code execution providers."""
@abc.abstractmethod
def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[float]:
"""Execute multiple scripts and return their reward values.
Args:
scripts: List of code scripts to execute
languages: The programming language of each script
Returns:
List of float rewards (one per script)
"""
pass
class E2BProvider(CodeExecutionProvider):
"""Provider that executes code using E2B sandboxes."""
def __init__(self, num_parallel: int = 2, e2b_router_url: Optional[str] = None):
"""Initialize the E2B provider.
Args:
num_parallel: Number of parallel sandboxes to use
e2b_router_url: URL for the E2B router (if using router mode)
"""
if not is_e2b_available():
raise ImportError(
"E2B is not available and required for this provider. Please install E2B with "
"`pip install e2b-code-interpreter` and add an API key to a `.env` file."
)
self.num_parallel = num_parallel
self.e2b_router_url = e2b_router_url
def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[float]:
"""Execute scripts using E2B sandboxes.
If e2b_router_url is provided, uses the RoutedSandbox for batch processing.
Otherwise, uses direct AsyncSandbox with parallelization.
"""
if self.e2b_router_url is not None:
routed_sandbox = RoutedSandbox(router_url=self.e2b_router_url)
executions = routed_sandbox.run_code(
scripts=scripts,
languages=languages,
timeout=30,
request_timeout=28,
)
rewards = []
for execution in executions:
try:
reward = float(execution.text)
rewards.append(reward)
except Exception:
rewards.append(None)
return rewards
try:
rewards = self._run_async_from_sync(scripts, languages, self.num_parallel)
except Exception as e:
print(f"Error from E2B executor: {e}")
rewards = [0.0] * len(scripts)
return rewards
def _run_async_from_sync(self, scripts: List[str], languages: List[str], num_parallel: int) -> List[float]:
"""Function wrapping the `_run_async` function."""
try:
rewards = asyncio.run(self._run_async(scripts, languages, num_parallel))
except Exception as e:
print(f"Error from E2B executor async: {e}")
raise e
return rewards
async def _run_async(self, scripts: List[str], languages: List[str], num_parallel: int) -> List[float]:
semaphore = asyncio.Semaphore(num_parallel)
tasks = [self._run_script(script, languages, semaphore) for script in scripts]
results = await asyncio.gather(*tasks)
rewards = list(results)
return rewards
async def _run_script(self, script: str, languages: List[str], semaphore: asyncio.Semaphore) -> float:
# We set a timeout margin, as the AsyncSandbox timeout does not seem to work
# These values are based on running 256 examples with the gold solution
# from open-r1/verifiable-coding-problems-python_decontaminated
# see scripts/benchmark_e2b.py
SANDBOX_TIMEOUT = 30
MARGIN = 2
REQUEST_TIMEOUT = SANDBOX_TIMEOUT - MARGIN
ASYNCIO_TIMEOUT = SANDBOX_TIMEOUT + MARGIN
async with semaphore:
try:
sandbox = await AsyncSandbox.create(timeout=SANDBOX_TIMEOUT, request_timeout=REQUEST_TIMEOUT)
execution = await asyncio.wait_for(
sandbox.run_code(script, languages=languages),
timeout=ASYNCIO_TIMEOUT,
)
return float(execution.text)
except (TypeError, ValueError):
return 0.0
except asyncio.TimeoutError:
print("Operation timed out")
return 0.0
except Exception as e:
print(f"Error in `_run_script` from E2B sandbox ID {sandbox.sandbox_id} : {e}")
return 0.0
finally:
try:
await sandbox.kill()
except Exception as e:
print(f"Error from E2B executor kill with sandbox ID {sandbox.sandbox_id} : {e}")
class MorphProvider(CodeExecutionProvider):
"""Provider that executes code using MorphCloud's Sandbox API."""
def __init__(self, num_parallel: int = 2, morph_router_url: Optional[str] = None):
"""Initialize the Morph provider.
Args:
num_parallel: Number of parallel executions to use
morph_router_url: URL for the MorphCloud router (if using router mode)
"""
if not is_morph_available():
raise ImportError(
"MorphCloud is not available and required for this provider. Please install MorphCloud with "
"`pip install morphcloud` and add an API key to a `.env` file."
)
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
print("Warning: python-dotenv not installed. Environment variables must be set directly.")
self.num_parallel = num_parallel
self.morph_router_url = morph_router_url
if self.morph_router_url is not None:
self.routed_sandbox = RoutedMorphSandbox(router_url=self.morph_router_url)
return
import os
self.api_key = os.getenv("MORPH_API_KEY")
if not self.api_key:
raise ValueError("MorphCloud API key not found. Please set the MORPH_API_KEY environment variable.")
try:
self.client = MorphCloudClient(api_key=self.api_key)
self.Sandbox = Sandbox
except ImportError as e:
raise ImportError(f"Required MorphCloud dependencies not installed: {e}")
def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[float]:
"""Execute scripts using MorphCloud Sandbox API.
Args:
scripts: List of Python scripts to execute
languages: Programming languages passed to the sandbox
Returns:
List of float rewards (one per script)
"""
if hasattr(self, "routed_sandbox"):
try:
results = self.routed_sandbox.run_code(
scripts=scripts,
languages=languages,
timeout=90,
request_timeout=96,
)
rewards = []
for result in results:
try:
reward = float(result.text)
rewards.append(reward)
except (ValueError, AttributeError):
rewards.append(0.0)
return rewards
except Exception as e:
print(f"Error from MorphCloud router: {e}")
return [0.0] * len(scripts)
import asyncio
try:
rewards = asyncio.run(self._run_async(scripts, languages, self.num_parallel))
except Exception as e:
print(f"Error from MorphCloud executor: {e}")
rewards = [0.0] * len(scripts)
return rewards
async def _run_async(self, scripts: List[str], languages: List[str], num_parallel: int) -> List[float]:
"""Run multiple scripts concurrently with limited parallelism.
Args:
scripts: List of scripts to execute
languages: Programming languages passed to the sandbox
num_parallel: Maximum number of concurrent executions
Returns:
List of rewards
"""
semaphore = asyncio.Semaphore(num_parallel)
tasks = [self._run_script(script, languages, semaphore) for script in scripts]
results = await asyncio.gather(*tasks)
return list(results)
async def _run_script(self, script: str, languages: List[str], semaphore: asyncio.Semaphore) -> float:
"""Execute a single script in a MorphCloud Sandbox.
Args:
script: The script to execute
languages: Programming languages passed to the sandbox
semaphore: Semaphore to limit concurrency
Returns:
Float reward from script execution
"""
SANDBOX_TIMEOUT = 90
MARGIN = 6
ASYNCIO_TIMEOUT = SANDBOX_TIMEOUT + MARGIN
sandbox = None
async with semaphore:
try:
sandbox = await asyncio.to_thread(self.Sandbox.new, client=self.client, ttl_seconds=SANDBOX_TIMEOUT)
result = await asyncio.wait_for(
asyncio.to_thread(
sandbox.run_code,
script,
languages=languages,
timeout=SANDBOX_TIMEOUT,
),
timeout=ASYNCIO_TIMEOUT,
)
reward = 0.0
try:
if hasattr(result, "text") and result.text:
lines = result.text.strip().split("\n")
if lines:
try:
reward = float(lines[-1])
except ValueError:
try:
reward = float(result.text.strip())
except ValueError:
pass
elif hasattr(result, "stdout") and result.stdout:
lines = result.stdout.strip().split("\n")
if lines:
try:
reward = float(lines[-1])
except ValueError:
pass
except (ValueError, AttributeError):
pass
return reward
except asyncio.TimeoutError:
return 0.0
except Exception:
return 0.0
finally:
if sandbox:
try:
await asyncio.to_thread(sandbox.close)
await asyncio.to_thread(sandbox.shutdown)
except Exception:
pass
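
The reward parsing in `MorphProvider._run_script` reads the last line of the sandbox output as a float and falls back to 0.0 otherwise. A minimal standalone sketch of that rule, assuming the output is available as a plain string (`parse_reward` is a hypothetical helper, not part of the module):

```python
def parse_reward(text: str) -> float:
    """Return the float on the last line of `text`, or 0.0 if it cannot be parsed."""
    if not text:
        return 0.0
    # Scripts may print debug output first; only the final line carries the reward.
    lines = text.strip().split("\n")
    try:
        return float(lines[-1])
    except ValueError:
        return 0.0


reward_with_noise = parse_reward("debug output\n0.75\n")
reward_invalid = parse_reward("Traceback (most recent call last)")
reward_empty = parse_reward("")
```

Treating every unparsable output as 0.0 keeps one misbehaving script from failing the whole batch, at the cost of hiding the reason for the failure.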
def get_provider(provider_type: str = "e2b", **kwargs) -> CodeExecutionProvider:
"""Factory function to get the appropriate code execution provider.
Args:
provider_type: Type of provider to use ("e2b", "morph")
**kwargs: Additional arguments to pass to the provider
Returns:
An instance of CodeExecutionProvider
"""
num_parallel = kwargs.pop("num_parallel", 2)
if provider_type == "e2b":
# Extract E2B-specific arguments
e2b_router_url = kwargs.pop("e2b_router_url", None)
return E2BProvider(
num_parallel=num_parallel,
e2b_router_url=e2b_router_url,
)
elif provider_type == "morph":
# Extract Morph-specific arguments
morph_router_url = kwargs.pop("morph_router_url", None)
return MorphProvider(
num_parallel=num_parallel,
morph_router_url=morph_router_url,
)
else:
raise ValueError(f"Unknown provider type: {provider_type}")
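
Both providers follow the same concurrency pattern: a semaphore bounds the number of simultaneous executions, each execution is wrapped in `asyncio.wait_for`, and a timeout yields a reward of 0.0. The sketch below illustrates that pattern with the standard library only. `run_one` and `run_all` are hypothetical names, and `asyncio.sleep` stands in for the remote sandbox call.

```python
import asyncio
from typing import List


async def run_one(delay: float, semaphore: asyncio.Semaphore, timeout: float) -> float:
    """Hypothetical stand-in for `_run_script`: sleeps instead of calling a sandbox."""
    async with semaphore:
        try:
            # Simulated remote execution; a real provider would await the sandbox here.
            await asyncio.wait_for(asyncio.sleep(delay), timeout=timeout)
            return 1.0
        except asyncio.TimeoutError:
            # Same fallback as the providers: a timed-out script scores 0.0.
            return 0.0


async def run_all(delays: List[float], num_parallel: int = 2, timeout: float = 0.05) -> List[float]:
    """Run all simulated scripts with at most `num_parallel` in flight at once."""
    semaphore = asyncio.Semaphore(num_parallel)
    tasks = [run_one(delay, semaphore, timeout) for delay in delays]
    # gather() returns results in the order of `tasks`, so rewards line up with scripts.
    return list(await asyncio.gather(*tasks))


rewards = asyncio.run(run_all([0.0, 0.5, 0.0]))
```

Because `asyncio.gather` preserves input order, the returned list can be matched to the scripts by position even when executions finish out of order.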

@ -1,19 +0,0 @@
from .cf_scoring import score_submission
from .code_patcher import patch_code
from .ioi_scoring import SubtaskResult, score_subtask, score_subtasks
from .ioi_utils import add_includes
from .morph_client import get_morph_client_from_env
from .piston_client import get_piston_client_from_env, get_slurm_piston_endpoints
__all__ = [
"get_piston_client_from_env",
"get_slurm_piston_endpoints",
"get_morph_client_from_env",
"patch_code",
"score_submission",
"score_subtask",
"score_subtasks",
"add_includes",
"SubtaskResult",
]

@ -1,146 +0,0 @@
import asyncio
import os
from io import BytesIO
from typing import Literal
from async_lru import alru_cache
from .piston_client import PistonClient
from .utils import batched
async def score_single_test_case(
client: PistonClient,
problem_data: dict,
test_input: str,
test_output: str,
submission: str,
submission_language: str = "cpp",
) -> dict | bool:  # Piston response dict, or False if the request failed
if submission_language not in ["python", "cpp"]:
raise ValueError(f"Invalid submission language: {submission_language}")
try:
result = await client.send_execute(
{
"files": [
{"name": f"main.{submission_language}", "content": submission},
*(
[{"name": "checker.py", "content": problem_data["generated_checker"]}]
if problem_data["generated_checker"]
else []
),
{"name": "input.txt", "content": test_input},
{"name": "correct_output.txt", "content": test_output},
{
"name": "grader_config",
"content": "\n".join(
f"{key}={value}"
for key, value in {
"TIME_LIMIT": problem_data["time_limit"],
"MEMORY_LIMIT": problem_data["memory_limit"],
"INPUT_MODE": problem_data["input_mode"],
}.items()
),
},
],
"run_timeout": (problem_data["time_limit"] + 10) * 1000,
# +10 seconds hard limit. time limits are handled by the codeforces script
},
language="cf_python3" if submission_language == "python" else "c++17",
)
except Exception as e:
print(f"Error scoring submission: {e}")
return False
return result
@alru_cache(maxsize=32) # TODO make this configurable
async def get_generated_contest_tests(contest_id: str) -> dict[str, list[dict]]:
import pandas as pd
import aiofiles
import aiofiles.os
tests_folder = os.environ.get("CF_TESTS_FOLDER", None)
if not tests_folder:
raise ValueError(
"CF_TESTS_FOLDER environment variable not set! Please download the codeforces generated tests and set CF_TESTS_FOLDER to the folder path. See https://huggingface.co/datasets/open-r1/codeforces for more information."
)
if not await aiofiles.os.path.exists(tests_folder):
raise ValueError(
f"CF_TESTS_FOLDER path '{tests_folder}' does not exist! Please download the codeforces generated tests and set CF_TESTS_FOLDER to the folder path. See https://huggingface.co/datasets/open-r1/codeforces for more information."
)
parquet_path = os.path.join(tests_folder, f"test_cases_{int(contest_id):04d}.parquet")
if not await aiofiles.os.path.exists(parquet_path):
return {}
# Read parquet file asynchronously
async with aiofiles.open(parquet_path, "rb") as f:
content = await f.read()
df = pd.read_parquet(BytesIO(content))
# Group by problem_id and convert to dictionary of lists
grouped_tests = df.groupby("problem_id").apply(lambda x: x[["input", "output"]].to_dict("records")).to_dict()
return grouped_tests
async def get_generated_tests(problem_id: str) -> list[dict]:
contest_id = problem_id.split("/")[0]
return (await get_generated_contest_tests(contest_id)).get(problem_id, [])
async def score_submission(
client: PistonClient,
problem_data: dict,
submission: str,
test_batch_size: int = 1,
scoring_mode: Literal["pass_fail", "partial", "weighted_sum"] = "weighted_sum",
no_compile_reward: float = -0.1,
no_submission_reward: float = -1.0,
submission_language: str = "cpp",
) -> float | None:
if submission_language not in ["python", "cpp"]:
raise ValueError(f"Invalid submission language: {submission_language}")
test_cases = problem_data["official_tests"] + (await get_generated_tests(problem_data["id"]))
# invalid/not a coding problem
if test_cases is None or len(test_cases) == 0:
return None
# no code extracted
if not submission:
return no_submission_reward
passed_test_cases = 0
# run one batch, check if any of them failed (0 score): if so stop evaluating (assuming non partial score); otherwise continue with the next batch of test cases.
for test_batch_to_run in batched(test_cases, test_batch_size) if test_batch_size >= 1 else [test_cases]:
results = await asyncio.gather(
*[
asyncio.create_task(
score_single_test_case(
client, problem_data, test_case["input"], test_case["output"], submission, submission_language
)
)
for test_case in test_batch_to_run
]
)
if any(result and result["compile"]["code"] != 0 for result in results):
return no_compile_reward
tests_passed_results = [
result and result["run"]["code"] == 0 and result["run"]["stdout"].strip() == "1" for result in results
]
if scoring_mode == "pass_fail" and any(not test_passed for test_passed in tests_passed_results):
break
passed_test_cases += sum(1 for test_passed in tests_passed_results if test_passed)
pass_fail_score = 1.0 if passed_test_cases == len(test_cases) else 0.0
if scoring_mode == "pass_fail":
return pass_fail_score
elif scoring_mode == "partial":
return passed_test_cases / len(test_cases)
elif scoring_mode == "weighted_sum":
return pass_fail_score + 0.1 * (passed_test_cases / len(test_cases))
else:
raise ValueError(f"Invalid scoring mode: {scoring_mode}")
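
The three scoring modes at the end of `score_submission` can be illustrated in isolation. The sketch below assumes the number of passed tests is already known; `combine_score` is a hypothetical helper that mirrors only the final branch of the function.

```python
def combine_score(passed: int, total: int, scoring_mode: str = "weighted_sum") -> float:
    """Combine a passed-test count into a reward, mirroring the three scoring modes."""
    pass_fail_score = 1.0 if passed == total else 0.0
    if scoring_mode == "pass_fail":
        return pass_fail_score
    if scoring_mode == "partial":
        return passed / total
    if scoring_mode == "weighted_sum":
        # Full credit for solving everything, plus a small bonus for partial progress.
        return pass_fail_score + 0.1 * (passed / total)
    raise ValueError(f"Invalid scoring mode: {scoring_mode}")


partial_score = combine_score(3, 4, "partial")
weighted_partial = combine_score(3, 4, "weighted_sum")
weighted_full = combine_score(4, 4, "weighted_sum")
```

Note that in `pass_fail` mode the original loop stops at the first failing batch, so the passed count is only complete when every test passes.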

@ -1,123 +0,0 @@
import re
def fix_python3_imports(source_code):
"""
Fix common import and function changes between Python 3 versions
Args:
source_code (str): The Python source code to update
Returns:
str: The updated source code
"""
# List of (pattern, replacement) pairs, applied in order
replacements = [
# Fix collections.abc imports (changed in Python 3.3+)
(
r"from collections import (Mapping|Sequence|Set|Container|MutableMapping|MutableSet|MutableSequence)",
r"from collections.abc import \1",
),
# Fix imp module deprecation (deprecated in 3.4)
(r"\bimport imp\b", r"import importlib"),
# Fix asyncio.async() to asyncio.ensure_future() (renamed in 3.4.4)
(r"asyncio\.async\(", r"asyncio.ensure_future("),
# Fix inspect.getargspec to inspect.getfullargspec (deprecated in 3.5)
(r"inspect\.getargspec", r"inspect.getfullargspec"),
# Fix array.array 'c' type code to 'b' (the 'c' type code was removed in Python 3.0)
(r"array\.array\('c'", r"array.array('b'"),
# Fix backslash line continuation with multiple newlines (Python-specific issue)
(r"\\(\r\n|\r|\n)+", "\\\n"),
# some solutions use getlogin() to check if they are debugging or on an actual submission
(r"(?:os\s*\.\s*)?getlogin\s*\(\s*\)", "False"),
# Fix usage of fractions.gcd (moved to math in 3.5)
# 1. Fix direct usage: fractions.gcd -> math.gcd
(r"\bfractions\.gcd\b", r"math.gcd"),
# 2. Fix 'from fractions import gcd, X' -> 'from fractions import X' (start/middle)
(r"(from\s+fractions\s+import\s+(?:\([^)]*)?)\bgcd\s*,\s*", r"\1"),
# 3. Fix 'from fractions import X, gcd' -> 'from fractions import X' (end)
(r"(from\s+fractions\s+import\s+.*?\S)\s*,\s*\bgcd(\s*\)?\s*(?:#.*)?)", r"\1\2"),
# 4. Remove standalone 'from fractions import gcd'; 'from math import gcd' is added to the import section before the replacements run
(r"from\s+fractions\s+import\s+\(?\s*gcd\s*\)?", r""),
]
lines = source_code.splitlines()
last_import = max(
[
i
for i, line in enumerate(lines)
if line.strip().startswith("import") or (line.strip().startswith("from") and "import" in line)
],
default=0,
)
import_section = "\n".join(lines[: last_import + 1])
# Start after the last import line so it is not duplicated in both sections
main_source = "\n".join(lines[last_import + 1 :])
if "fractions.gcd" in source_code and "import math" not in source_code:
import_section += "\nimport math"
elif "gcd" in source_code and "from math import gcd" not in source_code:
import_section += "\nfrom math import gcd"
if "set_int_max_str_digits" not in source_code:
import_section += "\nimport sys\nsys.set_int_max_str_digits(0)"
source_code = import_section + "\n" + main_source
# Apply each replacement
for pattern, replacement in replacements:
source_code = re.sub(pattern, replacement, source_code)
source_code = source_code.rstrip("\\")
return source_code
def fix_cpp_includes(source_code):
# has most of the useful functions
code_header = "#include <bits/stdc++.h>\n"
# use namespace std since models forget std:: often
if "using namespace std;" not in source_code and "std::" not in source_code:
code_header += "\nusing namespace std;\n\n"
return code_header + source_code
def is_patchable(lang):
return lang in ("python", "python3", "Python 3", "PyPy 3", "PyPy 3-64", "cpp") or "C++" in lang
def patch_code(text, lang):
if not text:
return text
if lang in ("python", "python3", "Python 3", "PyPy 3", "PyPy 3-64"):
return fix_python3_imports(text)
elif "cpp" in lang or "C++" in lang:
return fix_cpp_includes(text)
return text
tests = [
"""read = lambda: map(int, input().split())
n, m, z = read()
from fractions import gcd
ans = z // (n * m // gcd(n, m))
print(ans)""",
"""from fractions import Fraction,gcd
a,b,c,d = [int(x) for x in input().split()]
if a*d > b*c:
num = a*d-b*c
denom = a*d
else:
num = b*c-a*d
denom = b*c
div = gcd(num,denom)
print('%d/%d'%(num//div,denom//div))""",
]
if __name__ == "__main__":
for test in tests:
print("ORIGINAL:", test, sep="\n\n")
print("PATCHED:", patch_code(test, "Python 3"), sep="\n\n")
print("=" * 50)
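
One pitfall with regex-based patching of this kind is that an unanchored pattern also matches prefixes of longer names. The sketch below compares a plain `import imp` pattern with a word-boundary variant on a made-up three-line input. The input and the variable names are illustrative only.

```python
import re

source = "import imp\nimport importlib\nimport impala"

# Unanchored: also rewrites the prefix of `import importlib` and `import impala`.
loose = re.sub(r"import imp", "import importlib", source)

# Word boundaries restrict the match to the standalone `imp` module.
strict = re.sub(r"\bimport imp\b", "import importlib", source)

print(loose)
print(strict)
```

With `\b` on both sides, only the first line is rewritten and the other two imports are left untouched.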

@ -1,335 +0,0 @@
import asyncio
from dataclasses import asdict, dataclass, field
from typing import Union
from .ioi_utils import load_ioi_tests
from .piston_client import PistonClient, PistonError
from .utils import batched
@dataclass
class TestResult:
"""
Represents the result of a single test case execution.
Attributes:
test_name: Name of the test case
score: Score achieved for this test (0.0 to 1.0)
status: Status code of the test result (e.g., 'AC', 'WA', 'TLE')
feedback: Detailed feedback message from the judge or an error message
"""
test_name: str
score: float = 0.0
status: str = "SKIPPED"
feedback: str = None
@dataclass
class SubtaskResult:
"""
Represents the result of a subtask containing multiple test cases.
Attributes:
problem: Problem identifier
subtask: Subtask identifier
points: Maximum points available for this subtask
score_precision: Number of decimal places for score rounding
test_results: List of individual test case results
"""
problem: str = None
subtask: str = None
points: float = 0.0
score_precision: int = 2
test_results: list[TestResult] = field(default_factory=list)
@property
def status(self):
"""
Determines the overall status of the subtask based on the worst status among test results.
Status priorities are ordered from worst to best.
Returns:
str: The status with the highest priority (lowest value)
"""
status_prios = {"CE": -1, "RE": 0, "WA": 1, "MLE": 2, "TLE": 3, "PA": 4, "AC": 5, "SKIPPED": 999}
return min([x.status for x in self.test_results], key=lambda x: status_prios[x])
@property
def score(self):
"""
Calculates the raw score for the subtask as the minimum score across all test results.
Returns:
float: The rounded minimum score
"""
return (
0
if not self.test_results
else round(min([test_result.score for test_result in self.test_results]), self.score_precision)
)
@property
def weighted_score(self):
"""
Calculates the weighted score by multiplying the raw score by the available points.
Returns:
float: The rounded weighted score
"""
return (
0
if not self.test_results
else round(
min([test_result.score for test_result in self.test_results]) * self.points, self.score_precision
)
)
def to_dict(self):
"""
Converts the SubtaskResult to a dictionary representation.
Returns:
dict: Dictionary containing all subtask result data
"""
return {
"problem": self.problem,
"subtask": self.subtask,
"score": self.score,
"weighted_score": self.weighted_score,
"points": self.points,
"score_precision": self.score_precision,
"status": self.status,
"test_results": [asdict(test_result) for test_result in self.test_results],
}
def _extract_single_status(score: float, feedback: str) -> str:
"""
Determines the status code based on the score and feedback message.
Args:
score: The numeric score (0.0 to 1.0)
feedback: The feedback message from the execution
Returns:
str: Status code ('CE', 'MLE', 'TLE', 'WA', 'RE', 'AC', or 'PA')
"""
if score == 0.0:
if "Compilation error" in feedback:
return "CE"
elif "Memory limit exceeded" in feedback:
return "MLE"
elif "Time limit exceeded" in feedback:
return "TLE"
elif "Output isn't correct" in feedback:
return "WA"
else:
return "RE"
elif score == 1.0:
return "AC"
else:
return "PA"
async def score_single_test_case(
client: PistonClient, subtask: dict, test_name: str, test_input: str, test_output: str, submission: str
) -> TestResult:
"""
Scores a single test case by running the submission against the provided input and output.
Args:
client: PistonClient instance for executing code
subtask: Dictionary containing subtask configuration
test_name: Name of the test case
test_input: Input data for the test case
test_output: Expected output for the test case
submission: Source code of the submission
Returns:
TestResult: Result of the test case execution
"""
# Run submission for this test case
score, feedback = await run_submission(client, subtask, test_input, submission, test_output)
score = float(score)
return TestResult(
test_name=test_name, score=score, status=_extract_single_status(score, feedback), feedback=feedback
)
async def score_subtask(
client: PistonClient,
subtask: dict,
submission: str,
test_case_run_cache: Union[dict, None] = None,
test_batch_size: int = 1,
) -> SubtaskResult:
"""
Scores all test cases in a subtask.
Args:
client: PistonClient instance for executing code
subtask: Dictionary containing subtask configuration; if it has no "test_cases" entry, the tests are loaded with load_ioi_tests
submission: Source code of the submission
test_case_run_cache: Optional cache of previously run test cases
test_batch_size: Number of test cases to evaluate in parallel per batch. After each batch, evaluation stops if any test case failed (score 0.0); otherwise it continues with the next batch.
-1 to evaluate all test cases in parallel
Returns:
SubtaskResult: Result of the subtask evaluation
"""
subtask_result = SubtaskResult(
problem=subtask["id"],
subtask=subtask["subtask"],
points=subtask["score"],
score_precision=subtask["score_precision"],
test_results=[],
)
# tests that are not cached
tests_to_run = [
(ti, test_name)
for ti, test_name in enumerate(subtask["test_names"])
if test_case_run_cache is None or test_name not in test_case_run_cache
]
# initialize test results with cached results or empty (SKIPPED) TestResult objects
subtask_result.test_results = [
test_case_run_cache[test_name]
if test_case_run_cache is not None and test_name in test_case_run_cache
else TestResult(test_name=test_name)
for test_name in subtask["test_names"]
]
# we skip submissions where no code was extracted
# no need to do anything, as we have a failed cached result
if not submission or any(
test_result.status != "SKIPPED" and test_result.score == 0.0 for test_result in subtask_result.test_results
):
return subtask_result
if "test_cases" in subtask:
test_cases = subtask["test_cases"]
if isinstance(subtask["test_cases"], list):
test_cases = {test_name: test for test_name, test in zip(subtask["test_names"], subtask["test_cases"])}
else:
test_cases = load_ioi_tests(subtask["year"], subtask["id"])
# run one batch, check if any of them failed (0 score): if so stop evaluating; otherwise continue with the next batch of test cases.
for test_batch_to_run in batched(tests_to_run, test_batch_size):
results = await asyncio.gather(
*[
asyncio.create_task(
score_single_test_case(
client, subtask, test_name, test_cases[test_name][0], test_cases[test_name][1], submission
)
)
for _, test_name in test_batch_to_run
]
)
for (ti, test_name), test_result in zip(test_batch_to_run, results):
if test_case_run_cache is not None:
test_case_run_cache[test_name] = test_result
subtask_result.test_results[ti] = test_result
# Stop early if it failed
if any(test_result.score == 0.0 for test_result in results):
break
return subtask_result
async def score_subtasks(
client: PistonClient, subtasks: list[dict], submission: str, skip_mode: bool = True
) -> list[SubtaskResult]:
"""
Scores multiple subtasks for a submission.
Args:
client: PistonClient instance for executing code
subtasks: List of dictionaries containing subtask configurations
submission: Source code of the submission
skip_mode: If True, evaluates test by test and stops after the first failure. Otherwise, runs all tests in parallel. Should be True when evaluating a large number of submissions.
Returns:
list[SubtaskResult]: Results for all subtasks
"""
# avoid rerunning tests present in multiple subtasks
test_case_run_cache = {}
return [await score_subtask(client, subtask, submission, test_case_run_cache, skip_mode) for subtask in subtasks]
async def run_submission(
client: PistonClient, problem: dict, test_input: str, submission: str, test_output: str | None = None
) -> tuple[str, str]:
"""
Executes a submission against a test case using the Piston execution environment.
Args:
client: PistonClient instance for executing code
problem: Dictionary containing problem configuration
test_input: Input data for the test case
submission: Source code of the submission
test_output: Optional expected output for the test case
Returns:
tuple[str, str]: A tuple containing (score, feedback)
"""
data = {
"files": [
# the actual submission
{"name": f"graders/{problem['id'].lower()}.cpp", "content": submission},
# pass the input
{"name": "input.txt", "content": test_input},
# pass the expected output
*([{"name": "correct_output.txt", "content": test_output}] if test_output else []),
# grader files
*({"name": name, "content": content} for name, content in problem["grader_files"] if content),
],
"run_timeout": round(
(problem["time_limit"] + 3) * 1000
), # +3 seconds hard limit. time limits are handled by the ioi script
"run_memory_limit": problem["memory_limit"],
}
return await execute_ioi(client, data)
async def execute_ioi(client, data) -> tuple[str, str]:
"""
Requests to the IOI package return the score as a float in the stdout, as well as optional feedback/errors in stderr.
Returns a tuple of (score, feedback).
"""
response = await client.send_execute(data)
if "message" in response:
raise PistonError(response["message"])
if "compile" in response and response["compile"]["code"] != 0:
return "0", "Compilation error exit code " + str(response["compile"]["code"]) + "\n" + response["compile"][
"stderr"
]
if "run" not in response:
raise PistonError(response)
if response["run"]["code"] == 1 and "MemoryError" in response["run"]["stderr"]:
return "0", "Memory limit exceeded"
# successful result
if response["run"]["stdout"]:
return response["run"]["stdout"], response["run"]["stderr"]
if response["run"]["signal"] == "SIGKILL":
return "0", "Time limit exceeded"
# other issues
if response["run"]["code"] != 0:
raise PistonError(
f"language={response['language']}, version={response['version']}, exit code={response['run']['code']}, stderr={response['run']['stderr']}, signal={response['run']['signal']}"
)
return "0", "Unknown error"
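
The aggregation rule used by `SubtaskResult` (worst status wins, and the subtask score is the minimum test score times the available points) can be sketched on its own. `MiniResult` and `summarize` are hypothetical names for this illustration. The priority table is copied from the `status` property, and the sketch assumes a non-empty result list.

```python
from dataclasses import dataclass

# Lower value = worse status, copied from SubtaskResult.status
STATUS_PRIORITY = {"CE": -1, "RE": 0, "WA": 1, "MLE": 2, "TLE": 3, "PA": 4, "AC": 5, "SKIPPED": 999}


@dataclass
class MiniResult:
    score: float
    status: str


def summarize(results: list[MiniResult], points: float, precision: int = 2) -> tuple[str, float]:
    """Return (worst status, weighted score) for a non-empty list of test results."""
    worst_status = min((r.status for r in results), key=STATUS_PRIORITY.__getitem__)
    # A subtask is only worth as much as its weakest test case.
    weighted = round(min(r.score for r in results) * points, precision)
    return worst_status, weighted


summary = summarize([MiniResult(1.0, "AC"), MiniResult(0.5, "PA"), MiniResult(1.0, "AC")], points=20)
failed = summarize([MiniResult(1.0, "AC"), MiniResult(0.0, "WA")], points=20)
```

Taking the minimum rather than the mean is what makes stopping early after the first failed test safe: once one test scores 0.0, the remaining tests cannot change the subtask score.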

@ -1,41 +0,0 @@
from collections import defaultdict
from functools import lru_cache
from datasets import load_dataset
def add_includes(code: str, problem_id: str) -> str:
"""
Fix common compilation errors for IOI problems.
"""
if not code:
return code
# has most of the useful functions
code_header = "#include <bits/stdc++.h>\n"
# include the problem header
problem_header_include = f'#include "{problem_id}.h"'
if problem_header_include not in code:
code_header += problem_header_include + "\n"
# use namespace std since models forget std:: often
if "using namespace std;" not in code and "std::" not in code:
code_header += "\nusing namespace std;\n\n"
return code_header + code
@lru_cache
def load_ioi_tests_for_year(year: int) -> dict[str, dict[str, tuple[str, str]]]:
"""
Load IOI tests for a given year.
"""
tests_dataset = load_dataset("open-r1/ioi-test-cases", name=f"{year}", split="train")
test_cases = defaultdict(dict)
for test_case in tests_dataset:
test_cases[test_case["problem_id"]][test_case["test_name"]] = test_case["test_input"], test_case["test_output"]
return test_cases
def load_ioi_tests(year: int, problem_id: str) -> dict[str, tuple[str, str]]:
"""
Load IOI tests for a given year and problem id.
"""
return load_ioi_tests_for_year(year)[problem_id]

@ -1,742 +0,0 @@
import asyncio
import json
import logging
import os
import tempfile
from typing import Any, Dict, Optional, Tuple
from dotenv import load_dotenv
from open_r1.utils.import_utils import is_morph_available
# Replace direct imports with conditional imports
if is_morph_available():
from morphcloud.api import Instance, InstanceExecResponse, MorphCloudClient
else:
Instance = None
InstanceExecResponse = None
MorphCloudClient = None
# Silence verbose logs from dependencies
logging.getLogger("paramiko").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
class MorphCloudError(Exception):
pass
class MorphCloudExecutionClient:
def __init__(
self,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
spans_log_path: Optional[str] = None,
):
"""
Initialize the MorphCloud execution client.
Args:
api_key: Optional API key for MorphCloud. If not provided, will use MORPH_API_KEY env var.
base_url: Optional base URL for MorphCloud API. If not provided, will use default.
spans_log_path: Path to log API call spans to. Defaults to 'logs/morph_api_spans.jsonl'.
"""
self.client = MorphCloudClient(api_key=api_key, base_url=base_url)
self._snapshot_lock = asyncio.Lock()
async def _prepare_instance(self, snapshot_id=None) -> Instance:
"""
Prepare and start a MorphCloud instance.
Args:
snapshot_id: Optional snapshot ID to use. If None, will get or create base snapshot.
Returns:
Instance: The ready-to-use MorphCloud instance
Raises:
TimeoutError: If instance fails to start or become ready
"""
if not snapshot_id:
snapshot = await self._get_or_create_base_snapshot()
snapshot_id = snapshot.id
instance = None  # ensures the except block can safely check `if instance`
try:
instance = await self.client.instances.astart(
snapshot_id, ttl_seconds=600
) # Auto-terminate after 10 minutes
await instance.await_until_ready(timeout=300)
return instance
except asyncio.TimeoutError as e:
print(f"Timeout while preparing instance: {str(e)}")
if instance:
try:
await instance.astop()
except Exception:
pass
raise
async def _prepare_files(self, data: Dict[str, Any], temp_dir: str) -> Tuple[str, Dict[str, Any], Dict[str, str]]:
"""
Process files, determine problem ID, and prepare configuration.
Args:
data: Dictionary containing file information
temp_dir: Local temporary directory for file operations
Returns:
tuple: (problem_id, grader_config, local_files)
Raises:
ValueError: If problem ID cannot be determined
"""
# Extract problem ID
problem_id = None
graders_files = []
for file in data["files"]:
if file["name"].startswith("graders/") and file["name"].endswith(".cpp"):
potential_id = os.path.basename(file["name"]).split(".")[0]
if potential_id not in ["grader", "manager", "stub"]:
problem_id = potential_id
if file["name"].startswith("graders/"):
graders_files.append(file)
if not problem_id:
raise ValueError("Could not determine problem ID from files")
grader_config = {
"task_type": "Batch",
"code": problem_id,
"time_limit": data["run_timeout"] / 1000,
"memory_limit": data["run_memory_limit"] * 1024 * 1024,
}
for file in graders_files:
if "manager.cpp" in file["name"]:
grader_config["task_type"] = "Communication"
grader_config["task_type_parameters_Communication_num_processes"] = 1
grader_config["task_type_parameters_Communication_user_io"] = "std_io"
break
config_path = os.path.join(temp_dir, "grader_config.json")
with open(config_path, "w") as f:
json.dump(grader_config, f)
local_files = {"grader_config.json": config_path}
for file in data["files"]:
local_path = os.path.join(temp_dir, os.path.basename(file["name"]))
with open(local_path, "w") as f:
f.write(file["content"])
local_files[file["name"]] = local_path
return problem_id, grader_config, local_files
async def _upload_files(self, instance: Instance, local_files: Dict[str, str]) -> bool:
"""
Upload all necessary files to the instance.
Args:
instance: The MorphCloud instance
local_files: Dictionary mapping remote paths to local file paths
Returns:
bool: True if all uploads were successful
Raises:
TimeoutError: If uploads time out
"""
for remote_name, local_path in local_files.items():
target_path = f"/workspace/{remote_name}"
dir_path = os.path.dirname(target_path)
if dir_path != "/workspace":
await instance.aexec(f"mkdir -p {dir_path}")
await instance.aupload(local_path, target_path)
await instance.aupload(local_files["grader_config.json"], "/workspace/graders/grader_config.json")
return True
async def _compile_code(self, instance: Instance) -> InstanceExecResponse:
"""
Compile the code on the instance.
Args:
instance: The MorphCloud instance
Returns:
InstanceExecResponse: Result of compilation
Raises:
RuntimeError: If compilation fails
"""
compile_result = await instance.aexec("cd /workspace && ./compile")
if compile_result.exit_code != 0:
raise RuntimeError(f"Compilation error exit code {compile_result.exit_code}\n{compile_result.stderr}")
return compile_result
async def _run_tests(self, instance: Instance, data: Dict[str, Any]) -> Tuple[str, str]:
"""
Run tests and evaluate results.
Args:
instance: The MorphCloud instance
data: Dictionary containing runtime parameters
Returns:
tuple: (score, feedback)
Raises:
TimeoutError: If test execution times out
"""
hard_timeout = data["run_timeout"] / 1000 + 3
run_command = f"cd /workspace && timeout {hard_timeout}s ./run"
run_result = await instance.aexec(run_command)
if run_result.exit_code == 124 or run_result.exit_code == 137 or run_result.exit_code == 143:
return "0", "Time limit exceeded"
if run_result.exit_code != 0 and "Memory limit exceeded" in run_result.stderr:
return "0", "Memory limit exceeded"
if run_result.stdout:
return run_result.stdout.strip(), run_result.stderr.strip()
if run_result.exit_code != 0:
return (
"0",
f"Runtime error with exit code {run_result.exit_code}\n{run_result.stderr}",
)
return "0", "Unknown error"
async def _execute_with_instance(self, instance: Instance, data: Dict[str, Any], temp_dir: str) -> Tuple[str, str]:
"""Execute code using a prepared instance.
Args:
instance: Ready MorphCloud instance
data: Execution data
temp_dir: Temporary directory for file operations
Returns:
Tuple of (score, feedback)
Raises:
Exception: Passes through exceptions for retry handling
"""
await instance.await_until_ready(timeout=300)
problem_id, grader_config, local_files = await self._prepare_files(data, temp_dir)
await self._upload_files(instance, local_files)
try:
await self._compile_code(instance)
except RuntimeError as e:
return "0", str(e)
score, feedback = await self._run_tests(instance, data)
return score, feedback
async def _execute(self, data: Dict[str, Any]) -> Tuple[str, str]:
"""
Internal implementation of execute with no retry logic.
Args:
data: Dictionary containing execution data
Returns:
Tuple of (score, feedback)
Raises:
Exception: If execution fails
"""
instance = None
# Set timeouts to ensure we don't block indefinitely
# INSTANCE_TIMEOUT = 300 # 5 minutes for instance operations
TOTAL_EXECUTION_TIMEOUT = 600 # 10 minutes total execution time
with tempfile.TemporaryDirectory(prefix="morph_exec_") as temp_dir:
snapshot = await self._get_or_create_base_snapshot()
instance = await self.client.instances.astart(
snapshot.id, ttl_seconds=600
) # Auto-terminate after 10 minutes
async with instance:
# Use asyncio.wait_for to add overall timeout to the execution process
return await asyncio.wait_for(
self._execute_with_instance(instance, data, temp_dir),
timeout=TOTAL_EXECUTION_TIMEOUT,
)
async def execute(self, data: Dict[str, Any]) -> Tuple[str, str]:
"""
Execute code on MorphCloud based on the provided data, retrying the whole pipeline on failure.
Each attempt runs the following steps; if an attempt raises, it is retried from the start with exponential backoff:
1. Prepare an instance
2. Set up workspace
3. Prepare and upload files
4. Compile code
5. Run tests
Args:
data: Dictionary containing:
- files: List of file objects with name and content fields
- run_timeout: Timeout in milliseconds
- run_memory_limit: Memory limit in MB
Returns:
Tuple of (score, feedback) where:
- score is a string representation of a float between 0.0 and 1.0
- feedback is a string with execution details
"""
# TODO: would be faster to pass info about the subtask as well to create a snapshot per subtask
# would cache the uploads of all files other than the submission: input.txt, correct_output.txt, grader files
# rather than reusing the snapshot that only has the compile/run scripts on it
# currently, run_submission -> client.execute(data) does not easily pass subtask info
# Retry configuration
max_retries = 4
base_delay = 1.0
# Try execution with retries and exponential backoff
for attempt in range(max_retries + 1):
try:
return await self._execute(data)
except asyncio.TimeoutError:
if attempt < max_retries:
print(f"Execution timed out, retrying ({attempt + 1}/{max_retries})")
else:
return "0", "Execution timed out after multiple retries"
except Exception as e:
# Calculate exponential backoff
if attempt < max_retries:
retry_delay = min(base_delay * (2**attempt), 30) # Exponential backoff, capped at 30 seconds
print(
f"Execution failed with {type(e).__name__}: {str(e)}, retrying in {retry_delay:.2f}s ({attempt + 1}/{max_retries})"
)
await asyncio.sleep(retry_delay)
else:
print(f"Execution failed after {max_retries} retries: {type(e).__name__}: {str(e)}")
return "0", f"Execution failed after multiple retries: {str(e)}"
async def _get_or_create_base_snapshot(self):
"""Get or create a snapshot with the necessary dependencies and scripts for evaluation."""
async with self._snapshot_lock:
base_snapshots = await self.client.snapshots.alist(digest="ioi-evaluation-morph")
if not base_snapshots:
print("Creating base snapshot with build-essential cmake and g++")
# Create base snapshot with minimal specs
base_snapshot = await self.client.snapshots.acreate(
vcpus=2,
memory=4096,
disk_size=10240,
metadata={"purpose": "ioi_evaluation"},
)
# Start a temporary instance from the base snapshot
temp_instance = await self.client.instances.astart(
base_snapshot.id, ttl_seconds=900
) # Auto-terminate after 15 minutes
try:
# Wait for the instance to be ready
await temp_instance.await_until_ready(timeout=300)
# Get script contents
compile_script = await self._get_compile_script()
run_script = await self._get_run_script()
# Use temporary directory to store scripts
with tempfile.TemporaryDirectory(prefix="morph_setup_") as temp_dir:
# Create paths for script files
compile_path = os.path.join(temp_dir, "compile.sh")
run_path = os.path.join(temp_dir, "run.sh")
# Write scripts to temp files
with open(compile_path, "w") as f:
f.write(compile_script)
with open(run_path, "w") as f:
f.write(run_script)
async with temp_instance:
# Install dependencies
await temp_instance.aexec("apt-get update && apt-get install -y build-essential cmake g++")
# Create workspace directory
await temp_instance.aexec(
"mkdir -p /workspace && mkdir -p /workspace/graders && chmod 777 /workspace"
)
# Upload scripts to instance
await temp_instance.aupload(compile_path, "/workspace/compile")
await temp_instance.aupload(run_path, "/workspace/run")
# Make scripts executable
await temp_instance.aexec("chmod +x /workspace/compile /workspace/run")
# Create snapshot from the prepared instance
final_snapshot = await temp_instance.asnapshot(digest="ioi-evaluation-morph")
except Exception as e:
# Ensure instance is stopped if anything fails
await temp_instance.astop()
raise
else:
final_snapshot = base_snapshots[0]
return final_snapshot
async def _get_compile_script(self):
"""Get the compile script content."""
return """#!/bin/bash
manager_files=() # Array to store manager filenames
current_dir="$(pwd)"
# Checker compilation path
checker_dir="$current_dir/checker"
checker_src="$checker_dir/checker.cpp"
if [ -e "$checker_src" ]; then
echo "Compiling checker"
checker_exe="$checker_dir/checker"
g++ -x c++ -std=gnu++17 -O2 -o "$checker_exe" "$checker_src"
# Check the compiler's exit status before doing anything else
if [ $? -ne 0 ]; then
echo "Could not compile checker" >&2
exit 1
fi
chmod +x "$checker_exe"
echo "Compiled checker"
else
echo "No checker found at $checker_src"
fi
# Graders path
graders_dir="$current_dir/graders"
if [ ! -e "$graders_dir" ]; then
echo "Grader folder was not found" >&2
exit 1
fi
# Find and compile manager if it exists
manager_src="$graders_dir/manager.cpp"
if [ -e "$manager_src" ]; then
echo "Compiling manager"
manager_exe="$graders_dir/manager"
g++ -x c++ -std=gnu++17 -O2 -o "$manager_exe" "$manager_src"
# Check the compiler's exit status before doing anything else
if [ $? -ne 0 ]; then
echo "Could not compile manager" >&2
exit 1
fi
chmod +x "$manager_exe"
manager_files+=("manager")
fi
# Process other graders
graders_list=($(ls "$graders_dir" | grep -v 'manager.cpp'))
for grader_name in "${graders_list[@]}"; do
manager_files+=("$grader_name")
done
# Extract problem name and compile necessary files
problem_name='?'
for file in "${manager_files[@]}"; do
if [[ "$file" == *.h && "$file" != "testlib.h" ]]; then
problem_name="${file%.h}"
echo "Problem name: $problem_name"
break
fi
done
files_to_compile=("graders/$problem_name.cpp")
[ -e graders/grader.cpp ] && files_to_compile+=("graders/grader.cpp")
[ -e graders/stub.cpp ] && files_to_compile+=("graders/stub.cpp")
g++ -DEVAL -std=gnu++17 -O2 -pipe -s -o graders/"$problem_name" "${files_to_compile[@]}"
if [ $? -ne 0 ]; then
echo "Failed to compile $problem_name" >&2
exit 1
fi
chmod +x graders/"$problem_name"
echo "Compiled $problem_name from ${files_to_compile[@]} successfully"
echo "Manager files: ${manager_files[@]}"
"""
async def _get_run_script(self):
"""Get the run script content."""
return """#!/usr/bin/env bash
# disable stack limit so you don't get RE with recursion
ulimit -s unlimited
# some problems have 10MB+ input/output files in their test cases and you might get RE. uncomment if needed
# ulimit -f 2097152
# Check if grader_config.json exists
if [ ! -f "graders/grader_config.json" ]; then
echo "Error: graders/grader_config.json not found" >&2
echo "Current directory contents:" >&2
find . -type f -o -type d | sed -e 's/[^-][^\/]*\// |/g' -e 's/|\([^ ]\)/|-\1/' >&2
exit 1
fi
# Read task type, code, and time limit from grader_config.json using grep and sed
TASK_TYPE=$(grep -o '"task_type":[^,}]*' graders/grader_config.json | sed 's/"task_type":\\s*"\\([^"]*\\)"/\\1/')
TASK_NAME=$(grep -o '"code":[^,}]*' graders/grader_config.json | sed 's/"code":\\s*"\\([^"]*\\)"/\\1/')
TIME_LIMIT=$(grep -o '"time_limit":[^,}]*' graders/grader_config.json | sed 's/"time_limit":\\s*\\([^,}]*\\)/\\1/')
MEMORY_LIMIT=$(grep -o '"memory_limit":[^,}]*' graders/grader_config.json | sed 's/"memory_limit":\\s*\\([^,}]*\\)/\\1/')
TASK_EXECUTABLE="graders/$TASK_NAME"
# Set memory limit in KB (convert from bytes)
MEMORY_LIMIT_KB=0
if [ -n "$MEMORY_LIMIT" ]; then
MEMORY_LIMIT_KB=$(($MEMORY_LIMIT / 1024))
# Set the memory limit for the entire script and all child processes
ulimit -v $MEMORY_LIMIT_KB
fi
# "Securely" handle the correct output file
CORRECT_OUTPUT=""
if [ -f "correct_output.txt" ]; then
# Read the content and immediately remove the file
CORRECT_OUTPUT=$(cat correct_output.txt)
rm -f correct_output.txt
fi
# Create a temporary file for solution output
SOLUTION_OUTPUT=$(mktemp)
# Global variables for process tracking
declare -a ALL_PIDS
declare -a FIFO_DIRS
# Define cleanup function - simplified assuming timeout exists
function cleanup {
# Kill all tracked processes silently
exec 2>/dev/null
for pid in "${ALL_PIDS[@]:-}"; do
kill -9 "$pid" 2>/dev/null || true
done
# Clean up FIFO directories
for dir in "${FIFO_DIRS[@]:-}"; do
[ -d "$dir" ] && rm -rf "$dir"
done
# Clean up temporary files
rm -f "$SOLUTION_OUTPUT" || true
exec 2>&2
}
# Set up signal handling
trap cleanup EXIT INT TERM
# Function to handle exit codes consistently across task types
function handle_exit_code {
local exit_code=$1
# Check for known timeout exit codes:
# - 124: standard timeout exit code
# - 137: SIGKILL (128+9), used for hard timeouts
# - 143: SIGTERM (128+15), can also be used for timeouts
if [ $exit_code -eq 124 ] || [ $exit_code -eq 137 ] || [ $exit_code -eq 143 ]; then
echo "0"
echo "Time limit exceeded (${TIME_LIMIT}s)" >&2
return 124
# All other non-zero exit codes should be treated as runtime errors
elif [ $exit_code -ne 0 ]; then
echo "0"
echo "Runtime error with exit code $exit_code" >&2
return $exit_code
fi
# Success case - return 0
return 0
}
# Function to run a command with timeout (simplified assuming timeout exists)
function run_with_timeout {
local soft_limit=$1; shift
timeout --preserve-status "$soft_limit" "$@"
return $?
}
case "$TASK_TYPE" in
"Batch")
# Simple batch execution with timeout
run_with_timeout "$TIME_LIMIT" ./$TASK_EXECUTABLE < input.txt > "$SOLUTION_OUTPUT"
exit_code=$?
# Handle non-zero exit codes
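handle_exit_code $exit_code
# Save the status first: "$?" is overwritten by the "[" test below
handled_code=$?
if [ $handled_code -ne 0 ]; then
exit $handled_code
fi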
# Check the output if we have a correct output
if [ -n "$CORRECT_OUTPUT" ]; then
# Restore the correct output file
echo "$CORRECT_OUTPUT" > correct_output.txt
# Check if there's a custom checker
if [ -f "checker/checker" ]; then
# Let the checker handle everything
./checker/checker input.txt correct_output.txt "$SOLUTION_OUTPUT"
exit $?
else
# Simple diff-based checking
if diff -bq <(echo "$CORRECT_OUTPUT") "$SOLUTION_OUTPUT" >/dev/null; then
echo "1"
echo "Output is correct (diff)" >&2
else
echo "0"
echo "Output isn't correct (diff)" >&2
exit 0
fi
fi
else
# If no correct output was provided, just output the solution's output
cat "$SOLUTION_OUTPUT"
fi
;;
"Communication")
# Read Communication-specific parameters
NUM_PROCESSES=$(grep -o '"task_type_parameters_Communication_num_processes":[^,}]*' graders/grader_config.json | sed 's/.*:\\s*\\([0-9]*\\)/\\1/' || true)
if [ -z "$NUM_PROCESSES" ]; then
NUM_PROCESSES=1
fi
USER_IO=$(grep -o '"task_type_parameters_Communication_user_io":[^,}]*' graders/grader_config.json | sed 's/.*:\\s*"\\([^"]*\\)"/\\1/' || echo "std_io")
# Read custom manager arguments if they exist
MANAGER_CUSTOM_ARGS=""
if grep -q '"task_type_parameters_Communication_manager_args"' graders/grader_config.json; then
MANAGER_CUSTOM_ARGS=$(grep -o '"task_type_parameters_Communication_manager_args":[^,}]*' graders/grader_config.json | sed 's/.*:\\s*"\\([^"]*\\)"/\\1/')
fi
# Create temporary directories for FIFOs
for i in $(seq 0 $((NUM_PROCESSES-1))); do
FIFO_DIRS[$i]=$(mktemp -d)
# Create FIFOs for this process
mkfifo "${FIFO_DIRS[$i]}/u${i}_to_m"
mkfifo "${FIFO_DIRS[$i]}/m_to_u${i}"
chmod 755 "${FIFO_DIRS[$i]}"
chmod 666 "${FIFO_DIRS[$i]}/u${i}_to_m" "${FIFO_DIRS[$i]}/m_to_u${i}"
done
# Prepare manager arguments
MANAGER_ARGS=""
for i in $(seq 0 $((NUM_PROCESSES-1))); do
MANAGER_ARGS="$MANAGER_ARGS ${FIFO_DIRS[$i]}/u${i}_to_m ${FIFO_DIRS[$i]}/m_to_u${i}"
done
# Add custom manager arguments if specified
if [ -n "$MANAGER_CUSTOM_ARGS" ]; then
MANAGER_ARGS="$MANAGER_ARGS $MANAGER_CUSTOM_ARGS"
fi
# Start all user processes first
for i in $(seq 0 $((NUM_PROCESSES-1))); do
if [ "$USER_IO" = "fifo_io" ]; then
# Pass FIFOs as arguments
ARGS="${FIFO_DIRS[$i]}/m_to_u${i} ${FIFO_DIRS[$i]}/u${i}_to_m"
if [ "$NUM_PROCESSES" -ne 1 ]; then
ARGS="$ARGS $i"
fi
./$TASK_EXECUTABLE $ARGS &
ALL_PIDS+=($!)
else
# Use stdin/stdout redirection
if [ "$NUM_PROCESSES" -ne 1 ]; then
./$TASK_EXECUTABLE "$i" < "${FIFO_DIRS[$i]}/m_to_u${i}" > "${FIFO_DIRS[$i]}/u${i}_to_m" 2>/dev/null &
ALL_PIDS+=($!)
else
./$TASK_EXECUTABLE < "${FIFO_DIRS[$i]}/m_to_u${i}" > "${FIFO_DIRS[$i]}/u${i}_to_m" 2>/dev/null &
ALL_PIDS+=($!)
fi
fi
done
# Run the manager with timeout using direct pipe from input.txt
run_with_timeout "$TIME_LIMIT" ./graders/manager $MANAGER_ARGS < input.txt > "$SOLUTION_OUTPUT"
exit_code=$?
# Handle non-zero exit codes
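handle_exit_code $exit_code
# Save the status first: "$?" is overwritten by the "[" test below
handled_code=$?
if [ $handled_code -ne 0 ]; then
exit $handled_code
fi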
# Check the output if we have a correct output AND there's a checker (otherwise we assume the manager handles everything)
if [ -n "$CORRECT_OUTPUT" ] && [ -f "checker/checker" ]; then
# Restore the correct output file
echo "$CORRECT_OUTPUT" > correct_output.txt
# Let the checker handle it
./checker/checker input.txt correct_output.txt "$SOLUTION_OUTPUT"
exit $?
else
# we assume the manager handles it
cat "$SOLUTION_OUTPUT"
fi
;;
*)
echo "0"
echo "Unsupported task type \"$TASK_TYPE\"" >&2
exit 1
;;
esac
"""
def get_morph_client_from_env(session=None) -> MorphCloudExecutionClient:
"""
Creates a MorphCloudExecutionClient instance using environment variables.
Environment variables:
MORPH_API_KEY: API key for MorphCloud
Args:
session: Optional aiohttp.ClientSession to use for HTTP requests
Returns:
MorphCloudExecutionClient: A configured MorphCloud execution client
"""
if not is_morph_available():
raise ImportError(
"MorphCloud is not available and required for this function. Please install MorphCloud with "
"`pip install morphcloud` and add an API key to a `.env` file."
)
load_dotenv()
api_key = os.environ.get("MORPH_API_KEY")
if not api_key:
raise ValueError("MORPH_API_KEY environment variable is required")
return MorphCloudExecutionClient(api_key=api_key)
# noqa: W293
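
Note on the retry loop in `execute`: the delay before each retry is `min(base_delay * 2**attempt, 30)`. The sketch below only reproduces that arithmetic so the schedule is easy to see. `backoff_delays` is a hypothetical helper that does not exist in the file, and its defaults are copied from the constants in `execute`.

```python
def backoff_delays(max_retries: int = 4, base_delay: float = 1.0, cap: float = 30.0) -> list[float]:
    """Return the sleep (in seconds) taken before each retry.

    Mirrors the expression used in `execute`: base_delay * 2**attempt, capped at `cap`.
    """
    return [min(base_delay * (2**attempt), cap) for attempt in range(max_retries)]


if __name__ == "__main__":
    print(backoff_delays())  # [1.0, 2.0, 4.0, 8.0]
```

With the defaults, the four retries sleep for at most 15 seconds in total. The timeout branch in `execute` retries without sleeping, so this schedule only applies to the generic `Exception` branch.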

@ -1,224 +0,0 @@
import asyncio
import os
import random
import re
import subprocess
from collections import Counter
from functools import lru_cache
import aiohttp
class PistonError(Exception):
pass
@lru_cache(maxsize=1)
def get_piston_client_from_env(session=None):
piston_endpoints = os.getenv("PISTON_ENDPOINTS")
if piston_endpoints is None:
raise ValueError(
"For IOI/CF problems Piston endpoints running our IOI package are required. Please add a list of valid Piston endpoints to a PISTON_ENDPOINTS variable in a `.env` file."
)
piston_endpoints = sorted(
piston_endpoints.split(",") if piston_endpoints != "slurm" else get_slurm_piston_endpoints()
)
gpu_nb = int(os.getenv("LOCAL_RANK", 0)) # per-GPU index (local rank)
world = int(os.getenv("WORLD_SIZE", 1)) # total GPUs
if world > 1:
print(f"Using a subset of piston endpoints for GPU#{gpu_nb}")
piston_endpoints = piston_endpoints[gpu_nb::world]
random.shuffle(piston_endpoints)
max_requests_per_endpoint = os.getenv("PISTON_MAX_REQUESTS_PER_ENDPOINT", "1")
return PistonClient(piston_endpoints, session, max_requests_per_endpoint=int(max_requests_per_endpoint))
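
Note on the endpoint split above: `piston_endpoints[gpu_nb::world]` is a strided slice, so each rank gets a disjoint subset of the sorted list. A minimal sketch of that slicing follows. `shard_endpoints` is a hypothetical name used only for illustration, and the host names are made up.

```python
def shard_endpoints(endpoints: list[str], rank: int, world_size: int) -> list[str]:
    """Give each rank every `world_size`-th endpoint, starting at index `rank`."""
    return sorted(endpoints)[rank::world_size]


if __name__ == "__main__":
    endpoints = [f"http://host{i}:3223/api/v2" for i in range(5)]
    for rank in range(2):
        print(rank, shard_endpoints(endpoints, rank, 2))
```

One caveat, stated as an observation rather than a confirmed problem in this setup: under `torchrun`, `LOCAL_RANK` is the rank within a node while `WORLD_SIZE` counts all processes, so on multi-node jobs two nodes would pick the same subset and part of the list would go unused.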
class PistonClient:
"""
A client that will automatically load balance across multiple Piston (https://github.com/engineer-man/piston) workers.
This assumes piston is running our custom cms_ioi package: https://github.com/guipenedo/piston/releases/
We recommend starting the instances with the following script as otherwise some IOI problems will hit default limits:
```
export PISTON_COMPILE_TIMEOUT=60000
export PISTON_RUN_TIMEOUT=60000
export PISTON_OUTPUT_MAX_SIZE=1000000000
export PISTON_MAX_FILE_SIZE=1000000000
export PISTON_DISABLE_NETWORKING=true
export PISTON_REPO_URL=https://github.com/guipenedo/piston/releases/download/pkgs/index
mkdir /piston
sed -i '/app.use(body_parser.urlencoded/c\ app.use(body_parser.urlencoded({ extended: true, limit: \"512mb\" }));' src/index.js
sed -i '/app.use(body_parser.json/c\ app.use(body_parser.json({ limit: \"512mb\" }));' src/index.js
# Start server in background
node src
```
Piston docs for API usage: https://piston.readthedocs.io/en/latest/api-v2/
"""
def __init__(
self,
base_endpoint: str | list[str] = "http://ip-10-53-80-65:3223/api/v2",
session=None,
max_requests_per_endpoint=1,
):
self.max_requests_per_endpoint = max_requests_per_endpoint
self.base_endpoints = [base_endpoint] if isinstance(base_endpoint, str) else base_endpoint
if len(self.base_endpoints) == 0:
raise ValueError("No Piston endpoints provided. Please check your PISTON_ENDPOINTS environment variable.")
self.endpoint_ids = {endpoint: i for i, endpoint in enumerate(self.base_endpoints)}
self._session = session
self.endpoint_tokens = asyncio.Queue(maxsize=max_requests_per_endpoint * len(self.base_endpoints))
for _ in range(max_requests_per_endpoint):
for base_endpoint in self.base_endpoints:
self.endpoint_tokens.put_nowait(base_endpoint)
self._endpoint_failures = Counter()
self._unhealthy_endpoints = set()
self._endpoint_failures_lock = asyncio.Lock()
@property
def session(self):
if self._session is None:
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(sock_read=30),
connector=aiohttp.TCPConnector(
limit=self.max_requests_per_endpoint * len(self.base_endpoints),
ttl_dns_cache=300,
keepalive_timeout=5 * 60,
),
)
return self._session
async def _wait_for_endpoint(self):
endpoint = await self.endpoint_tokens.get()
return endpoint
async def _release_endpoint(self, endpoint):
await self.endpoint_tokens.put(endpoint)
async def _send_request(self, endpoint, route, data=None, method="post"):
async with self.session.request(
method, f"{endpoint.rstrip('/')}/{route}", json=data, headers={"Content-Type": "application/json"}
) as response:
return await response.json(content_type=None)
async def _send_to_all(self, route, data=None, method="post"):
return await asyncio.gather(
*[self._send_request(endpoint, route, data, method) for endpoint in self.base_endpoints]
)
async def _send_to_one(self, endpoint, route, data=None, method="post"):
return await self._send_request(endpoint, route, data, method)
async def install_package(self, language, version):
return await self._send_to_all("packages", {"language": language, "version": version}, method="post")
async def uninstall_package(self, language, version):
return await self._send_to_all("packages", {"language": language, "version": version}, method="delete")
async def get_supported_runtimes(self):
return await self._send_to_all("runtimes", method="get")
async def _check_failed_endpoint(self, endpoint):
async with self._endpoint_failures_lock:
if endpoint in self._unhealthy_endpoints:
return
try:
await asyncio.sleep(5)
# Probe only the endpoint that failed, not every endpoint
await self._send_to_one(endpoint, "runtimes", method="get")
except Exception as e:
print(f"Error checking endpoint {endpoint}, dropping it ({e})")
self._unhealthy_endpoints.add(endpoint)
if len(self._unhealthy_endpoints) >= len(self.base_endpoints):
raise PistonError("All endpoints are unhealthy. Please check your Piston workers.")
async def send_execute(self, data, language="cms_ioi", max_retries=5):
data = data | {
"language": language,
"version": "*",
}
base_delay = 1.0
status = None
endpoint = None
for attempt in range(max_retries + 1):
try:
endpoint = await self._wait_for_endpoint()
if attempt > 0:
await asyncio.sleep(1)
async with self.session.post(
f"{endpoint.rstrip('/')}/execute", json=data, headers={"Content-Type": "application/json"}
) as response:
status = response.status
res_json = await response.json(content_type=None)
if status != 200:
raise PistonError(f"Server error. status={status}. {res_json}")
if res_json is None:
raise PistonError(f"Empty response. status={status}")
# piston overloaded
if "run" in res_json and "Resource temporarily unavailable" in res_json["run"].get("stderr", ""):
raise PistonError(f"Piston overloaded: {res_json['run']['stderr']}")
return res_json
except (PistonError, asyncio.TimeoutError, aiohttp.ClientConnectionError, RuntimeError) as e:
# Only retry if we haven't reached max retries yet
if attempt < max_retries:
# Calculate backoff with jitter
delay = min(base_delay * (2**attempt), 10) # Exponential backoff, capped at 10 seconds
jitter = delay * 0.2 * (2 * asyncio.get_event_loop().time() % 1 - 0.5) # Add ±10% jitter
retry_delay = delay + jitter
print(f"Retrying in {retry_delay:.2f} seconds [{self.endpoint_ids[endpoint]}] {endpoint} - {e}")
# special case: worker died
if isinstance(e, aiohttp.ClientConnectionError) and "Connect call failed" in str(e):
await self._check_failed_endpoint(endpoint)
else:
# hopefully we won't get this one again
await self._release_endpoint(endpoint)
endpoint = None
await asyncio.sleep(retry_delay)
else:
await self._check_failed_endpoint(endpoint)
except Exception as e:
print(f"Propagating exception {type(e)}: {e}")
raise e
finally:
# Ensure endpoint is always released, even if an exception occurs
if endpoint is not None:
try:
await self._release_endpoint(endpoint)
except Exception as e:
print(f"Error releasing endpoint {endpoint}: {e}")
endpoint = None
def get_slurm_piston_endpoints():
"""Get list of active piston worker endpoints from squeue output"""
# Run squeue command to get job name, hostname and status, filtering for RUNNING state
result = subprocess.run(
["squeue", '--format="%j %N %T"', "--noheader", "--states=RUNNING"], capture_output=True, text=True
)
# Split output into lines (no header is printed because of --noheader)
lines = result.stdout.strip().split("\n")
endpoints = []
for line in lines:
# Parse job name and hostname from squeue output, skipping blank or malformed lines
fields = line.split()
if len(fields) < 2:
continue
job_name = fields[0].strip('"') # Remove quotes
hostname = fields[1]
# Extract port if job name matches pattern
match = re.match(r"piston-worker-(\d+)", job_name)
if match:
port = match.group(1)
endpoints.append(f"http://{hostname}:{port}/api/v2")
return endpoints
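
Note on `get_slurm_piston_endpoints`: the parsing can be separated from the `squeue` call so it is testable without a cluster. The sketch below assumes the output format implied by `--format="%j %N %T"` (job name, hostname, state, with the literal quotes kept because no shell is involved). `parse_piston_endpoints` and the sample lines are hypothetical.

```python
import re


def parse_piston_endpoints(squeue_output: str) -> list[str]:
    """Turn `squeue` lines into Piston endpoint URLs.

    A job named piston-worker-<port> running on <hostname> maps to
    http://<hostname>:<port>/api/v2. Blank or malformed lines are skipped.
    """
    endpoints = []
    for line in squeue_output.splitlines():
        fields = line.strip().strip('"').split()
        if len(fields) < 2:
            continue
        job_name, hostname = fields[0], fields[1]
        match = re.match(r"piston-worker-(\d+)", job_name)
        if match:
            endpoints.append(f"http://{hostname}:{match.group(1)}/api/v2")
    return endpoints


if __name__ == "__main__":
    sample = '"piston-worker-3223 ip-10-53-80-65 RUNNING"\n"train-job ip-10-53-80-66 RUNNING"\n'
    print(parse_piston_endpoints(sample))
```

Empty `squeue` output yields an empty list here, which `PistonClient.__init__` then rejects with its "No Piston endpoints provided" error.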

@ -1,11 +0,0 @@
from itertools import islice
def batched(iterable, n):
"Batch data into lists of length n. The last batch may be shorter."
# batched('ABCDEFG', 3) --> ABC DEF G
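if n < 1:
# A bare `return iterable` inside a generator would silently yield nothing
raise ValueError("n must be at least one")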
it = iter(iterable)
while batch := list(islice(it, n)):
yield batch
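
Usage note for `batched`: a short, self-contained example of what the helper yields. The function is repeated so the snippet runs on its own. The `ValueError` guard for `n < 1` is an assumption that follows the behaviour of `itertools.batched` in Python 3.12, which yields tuples rather than lists.

```python
from itertools import islice


def batched(iterable, n):
    """Batch data into lists of length n. The last batch may be shorter."""
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    # islice pulls up to n items per pass; an empty list ends the loop
    while batch := list(islice(it, n)):
        yield batch


if __name__ == "__main__":
    print(list(batched("ABCDEFG", 3)))  # [['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]
```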

@ -1,65 +0,0 @@
import logging
import datasets
from datasets import DatasetDict, concatenate_datasets
from ..configs import ScriptArguments
logger = logging.getLogger(__name__)
def get_dataset(args: ScriptArguments) -> DatasetDict:
"""Load a dataset or a mixture of datasets based on the configuration.
Args:
args (ScriptArguments): Script arguments containing dataset configuration.
Returns:
DatasetDict: The loaded datasets.
"""
if args.dataset_name and not args.dataset_mixture:
logger.info(f"Loading dataset: {args.dataset_name}")
return datasets.load_dataset(args.dataset_name, args.dataset_config)
elif args.dataset_mixture:
logger.info(f"Creating dataset mixture with {len(args.dataset_mixture.datasets)} datasets")
seed = args.dataset_mixture.seed
datasets_list = []
for dataset_config in args.dataset_mixture.datasets:
logger.info(f"Loading dataset for mixture: {dataset_config.id} (config: {dataset_config.config})")
ds = datasets.load_dataset(
dataset_config.id,
dataset_config.config,
split=dataset_config.split,
)
if dataset_config.columns is not None:
ds = ds.select_columns(dataset_config.columns)
if dataset_config.weight is not None:
ds = ds.shuffle(seed=seed).select(range(int(len(ds) * dataset_config.weight)))
logger.info(
f"Subsampled dataset '{dataset_config.id}' (config: {dataset_config.config}) with weight={dataset_config.weight} to {len(ds)} examples"
)
datasets_list.append(ds)
if datasets_list:
combined_dataset = concatenate_datasets(datasets_list)
combined_dataset = combined_dataset.shuffle(seed=seed)
logger.info(f"Created dataset mixture with {len(combined_dataset)} examples")
if args.dataset_mixture.test_split_size is not None:
combined_dataset = combined_dataset.train_test_split(
test_size=args.dataset_mixture.test_split_size, seed=seed
)
logger.info(
f"Split dataset into train and test sets with test size: {args.dataset_mixture.test_split_size}"
)
return combined_dataset
else:
return DatasetDict({"train": combined_dataset})
else:
raise ValueError("No datasets were loaded from the mixture configuration")
else:
raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided")
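
Note on the mixture logic in `get_dataset`: each source is shuffled and truncated to `int(len(ds) * weight)` examples, then the sources are concatenated and shuffled again. The sketch below shows the same arithmetic with plain Python lists standing in for `datasets.Dataset` objects. `subsample` and `mix` are hypothetical names, and `random.Random` does not reproduce the shuffling order of the `datasets` library.

```python
import random
from typing import Optional


def subsample(examples: list, weight: Optional[float], seed: int) -> list:
    """Keep a `weight` fraction of the examples after a seeded shuffle."""
    if weight is None:
        return list(examples)
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    return shuffled[: int(len(shuffled) * weight)]


def mix(sources: list[tuple[list, Optional[float]]], seed: int) -> list:
    """Concatenate the subsampled sources and shuffle the result."""
    combined = [x for examples, weight in sources for x in subsample(examples, weight, seed)]
    random.Random(seed).shuffle(combined)
    return combined


if __name__ == "__main__":
    mixture = mix([(list(range(10)), 0.5), (list("abcd"), None)], seed=0)
    print(len(mixture))  # 9
```

Because of the `int(...)` truncation, a weight of 0.5 on an 11-example source keeps 5 examples, not 6.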

@ -7,7 +7,6 @@ from .hub import get_gpu_count_for_vllm, get_param_count_from_repo_id
if TYPE_CHECKING:
from trl import GRPOConfig, SFTConfig, ModelConfig
import base64
import os
@ -25,11 +24,7 @@ VLLM_SLURM_PREFIX = [
def register_lighteval_task(
configs: Dict[str, str],
eval_suite: str,
task_name: str,
task_list: str,
num_fewshot: int = 0,
configs: Dict[str, str], eval_suite: str, task_name: str, task_list: str, num_fewshot: int = 0
):
"""Registers a LightEval task configuration.
@ -51,10 +46,10 @@ def register_lighteval_task(
LIGHTEVAL_TASKS = {}
register_lighteval_task(LIGHTEVAL_TASKS, "lighteval", "math_500", "math_500", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "lighteval", "aime24", "aime24", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "lighteval", "aime25", "aime25", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "lighteval", "gpqa", "gpqa:diamond", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "math_500", "math_500", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime24", "aime24", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime25", "aime25", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "gpqa", "gpqa:diamond", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "extended", "lcb", "lcb:codegeneration", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "extended", "lcb_v4", "lcb:codegeneration_v4", 0)
@ -67,9 +62,7 @@ SUPPORTED_BENCHMARKS = get_lighteval_tasks()
def run_lighteval_job(
benchmark: str,
training_args: Union["SFTConfig", "GRPOConfig"],
model_args: "ModelConfig",
benchmark: str, training_args: Union["SFTConfig", "GRPOConfig"], model_args: "ModelConfig"
) -> None:
task_list = LIGHTEVAL_TASKS[benchmark]
model_name = training_args.hub_model_id
@ -79,7 +72,7 @@ def run_lighteval_job(
if get_param_count_from_repo_id(model_name) >= 30_000_000_000:
tensor_parallel = True
else:
num_gpus = 2 # Hack while cluster is full
num_gpus = 8
tensor_parallel = False
cmd = VLLM_SLURM_PREFIX.copy()
@ -95,10 +88,7 @@ def run_lighteval_job(
f"{model_args.trust_remote_code}",
]
if training_args.system_prompt is not None:
# encode to base64 to avoid issues with special characters
# we decode in the sbatch script
prompt_encoded = base64.b64encode(training_args.system_prompt.encode()).decode()
cmd_args.append(prompt_encoded)
cmd_args.append(f"'{training_args.system_prompt}'")
cmd[-1] += " " + " ".join(cmd_args)
subprocess.run(cmd, check=True)
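
Note on the system prompt argument: the two variants in this hunk differ in how they survive a shell. Wrapping the prompt in single quotes breaks as soon as the prompt itself contains an apostrophe, whereas base64 output only uses characters that need no quoting. A minimal sketch of the round trip follows. The prompt text is made up, and the decoding step stands in for what the sbatch script is described as doing.

```python
import base64
import shlex

prompt = "You're a helpful assistant; reply with $ANSWER only."

# Encode: the result contains only A-Z, a-z, 0-9, '+', '/' and '='
encoded = base64.b64encode(prompt.encode()).decode()

# Decode: recovers the prompt exactly
decoded = base64.b64decode(encoded).decode()

if __name__ == "__main__":
    print(decoded == prompt)  # True
    # shlex.quote returns its input unchanged when no quoting is needed
    print(shlex.quote(encoded) == encoded)  # True
    print(shlex.quote(prompt) == prompt)  # False
```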

@ -76,8 +76,7 @@ def check_hub_revision_exists(training_args: SFTConfig | GRPOConfig):
# If the revision exists, we next check it has a README file
if training_args.hub_model_revision in revisions:
repo_files = list_repo_files(
repo_id=training_args.hub_model_id,
revision=training_args.hub_model_revision,
repo_id=training_args.hub_model_id, revision=training_args.hub_model_revision
)
if "README.md" in repo_files and training_args.overwrite_hub_revision is False:
raise ValueError(

@ -21,10 +21,3 @@ _e2b_available = _is_package_available("e2b")
def is_e2b_available() -> bool:
return _e2b_available
_morph_available = _is_package_available("morphcloud")
def is_morph_available() -> bool:
return _morph_available

@ -1,12 +1,16 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
from transformers import AutoTokenizer, PreTrainedTokenizer
from trl import ModelConfig, get_kbit_device_map, get_quantization_config
from trl import ModelConfig
from ..configs import GRPOConfig, SFTConfig
def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> PreTrainedTokenizer:
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
def get_tokenizer(
model_args: ModelConfig, training_args: SFTConfig | GRPOConfig, auto_set_chat_template: bool = True
) -> PreTrainedTokenizer:
"""Get the tokenizer for the model."""
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
@ -16,27 +20,7 @@ def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig
if training_args.chat_template is not None:
tokenizer.chat_template = training_args.chat_template
elif auto_set_chat_template and tokenizer.get_chat_template() is None:
tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
return tokenizer
def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoModelForCausalLM:
"""Get the model"""
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
**model_kwargs,
)
return model

@ -1,120 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import requests
class RoutedMorphSandbox:
"""
Client for the MorphCloud router service that mimics the API of MorphCloud's Sandbox.
This class provides a simple interface to execute code via a central MorphCloud router,
which manages sandbox creation and cleanup. It allows batch processing of multiple scripts
in a single request for improved efficiency.
Attributes:
router_url (str): The URL of the MorphCloud router service.
timeout (int): Execution timeout in seconds.
        request_timeout (int): HTTP request timeout in seconds.
    """

    def __init__(self, router_url: str, timeout: int = 300, request_timeout: int = 60):
        """
        Initialize the routed MorphCloud sandbox client.

        Args:
            router_url: The URL of the MorphCloud router, including host and port.
            timeout: Default execution timeout in seconds.
            request_timeout: Default HTTP request timeout in seconds.
        """
        self.router_url = router_url
        self.timeout = timeout
        self.request_timeout = request_timeout

    def run_code(
        self,
        scripts: List[str],
        languages: Optional[List[str]] = None,
        timeout: Optional[int] = None,
        request_timeout: Optional[int] = None,
    ) -> List:
        """
        Execute multiple scripts using MorphCloud via the router.

        Args:
            scripts: List of code scripts to execute.
            languages: List of programming languages for each script. If None, defaults to Python for all scripts.
            timeout: Execution timeout in seconds. If None, uses the instance timeout.
            request_timeout: HTTP request timeout in seconds. If None, uses the instance request_timeout.

        Returns:
            List of execution results with text and exception_str properties.
        """
        actual_timeout = timeout if timeout is not None else self.timeout
        actual_request_timeout = request_timeout if request_timeout is not None else self.request_timeout
        # Default to Python for all scripts if languages is not provided
        if languages is None:
            languages = ["python"] * len(scripts)
        payload = {
            "scripts": scripts,
            "languages": languages,
            "timeout": actual_timeout,
            "request_timeout": actual_request_timeout,
        }
        try:
            endpoint = f"http://{self.router_url}/execute_batch"
            response = requests.post(endpoint, json=payload, timeout=actual_request_timeout)
            if response.status_code != 200:
                error = f"Request to MorphCloud router failed with status code: {response.status_code}"
                print(error)
                results = []
                for _ in scripts:
                    results.append(type("obj", (object,), {"text": None, "exception_str": error}))
                return results
            response_data = response.json()
            results = []
            for item in response_data:
                # Log the response data to see what we're getting
                # print(f"RoutedMorphSandbox: Got response item: {item}")
                result = type(
                    "obj",
                    (object,),
                    {
                        "text": item.get("text"),
                        "exception_str": item.get("exception_str"),
                    },
                )
                results.append(result)
            return results
        except Exception as e:
            error = f"Error communicating with MorphCloud router: {str(e)}"
            print(error)
            results = []
            for _ in scripts:
                results.append(type("obj", (object,), {"text": None, "exception_str": error}))
            return results
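
Review note: both failure paths in `run_code` return one placeholder per input script, so callers can keep indexing results by script position even when the router is unreachable. The sketch below isolates that contract together with the timeout fallback. The helper names `resolve_timeout` and `failed_results` are hypothetical, and `types.SimpleNamespace` is swapped in for the ad hoc `type("obj", ...)` classes purely for illustration; it is not what this diff does.

```python
from types import SimpleNamespace
from typing import List, Optional


def resolve_timeout(override: Optional[int], default: int) -> int:
    """Return the per-call override if one was given, else the instance default."""
    # Compare against None explicitly so that an override of 0 is respected.
    return override if override is not None else default


def failed_results(scripts: List[str], error: str) -> List[SimpleNamespace]:
    """Build one placeholder result per script, all carrying the same error."""
    return [SimpleNamespace(text=None, exception_str=error) for _ in scripts]


# Hypothetical usage: two scripts, one shared router error.
results = failed_results(["print(1)", "print(2)"], "router unreachable")
```

Because the placeholders expose the same `text` and `exception_str` attributes as successful results, downstream reward code would not need a separate branch for transport errors.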

@ -1,109 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional

import requests
from e2b_code_interpreter.models import Execution, ExecutionError, Result


class RoutedSandbox:
    """
    A sandbox environment that routes code execution requests to the E2B Router.
    This class is designed for batched execution of scripts, primarily for Python code.
    It mimics the usage of 'Sandbox' from 'e2b_code_interpreter', but adds support for batch processing.

    Attributes:
        router_url (str): The URL of the E2B Router to which code execution requests are sent.
    """

    def __init__(self, router_url: str):
        """
        Initializes the RoutedSandbox with the specified router URL.

        Args:
            router_url (str): The URL of the E2B Router.
        """
        self.router_url = router_url

    def run_code(
        self,
        scripts: list[str],
        languages: Optional[List[str]] = None,
        timeout: Optional[int] = None,
        request_timeout: Optional[int] = None,
    ) -> list[Execution]:
        """
        Executes a batch of scripts in the sandbox environment.

        Args:
            scripts (list[str]): A list of code scripts to execute.
            languages (list[str], optional): List of programming languages for each script. If None, defaults to Python for all scripts.
            timeout (Optional[int], optional): The maximum execution time for each script in seconds. Defaults to 300 seconds.
            request_timeout (Optional[int], optional): The timeout for the HTTP request in seconds. Defaults to 30 seconds.

        Returns:
            list[Execution]: A list of Execution objects containing the results, logs, and errors (if any) for each script.
        """
        # Set default values for timeouts if not provided
        if timeout is None:
            timeout = 300  # Default to 5 minutes
        if request_timeout is None:
            request_timeout = 30  # Default to 30 seconds
        # Default to Python for all scripts if languages is not provided
        if languages is None:
            languages = ["python"] * len(scripts)
        # Prepare the payload for the HTTP POST request
        payload = {
            "scripts": scripts,
            "languages": languages,
            "timeout": timeout,
            "request_timeout": request_timeout,
        }
        # Send the request to the E2B Router
        response = requests.post(f"http://{self.router_url}/execute_batch", json=payload)
        if not response.ok:
            print(f"Request failed with status code: {response.status_code}")
        # Parse the response and construct Execution objects
        results = response.json()
        output = []
        for result in results:
            if result["execution"] is None:
                # If execution is None, create an empty Execution object
                # This can happen when a script times out or fails to execute
                execution = Execution()
            else:
                execution = Execution(
                    results=[Result(**r) for r in result["execution"]["results"]],
                    logs=result["execution"]["logs"],
                    error=(ExecutionError(**result["execution"]["error"]) if result["execution"]["error"] else None),
                    execution_count=result["execution"]["execution_count"],
                )
            output.append(execution)
        return output


if __name__ == "__main__":
    # for local testing launch an E2B router with: python scripts/e2b_router.py
    sbx = RoutedSandbox(router_url="0.0.0.0:8000")
    codes = ["print('hello world')", "print('hello world)"]
    executions = sbx.run_code(codes)  # Execute Python inside the sandbox
    print(executions)
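
Review note: the request body sent to `/execute_batch` in the removed file is a plain dictionary with the four keys shown above. A minimal sketch of that payload construction as a standalone function follows, assuming the defaults documented in the docstring (300 seconds and 30 seconds). The function name `build_payload` and the length check are hypothetical additions for illustration and do not appear in the removed code.

```python
from typing import Any, Dict, List, Optional


def build_payload(
    scripts: List[str],
    languages: Optional[List[str]] = None,
    timeout: Optional[int] = None,
    request_timeout: Optional[int] = None,
) -> Dict[str, Any]:
    """Assemble the JSON body for a batched execution request."""
    # Default to Python for every script when no languages are given.
    if languages is None:
        languages = ["python"] * len(scripts)
    # Assumption: one language per script; the removed code does not check this.
    if len(languages) != len(scripts):
        raise ValueError("languages must have one entry per script")
    return {
        "scripts": scripts,
        "languages": languages,
        "timeout": 300 if timeout is None else timeout,
        "request_timeout": 30 if request_timeout is None else request_timeout,
    }


# Same two scripts as the local test in the removed file.
payload = build_payload(["print('hello world')", "print('hello world)"])
```

Keeping the payload construction separate from the HTTP call would make the defaults checkable without a running router.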

@ -9,5 +9,3 @@ def init_wandb_training(training_args):
os.environ["WANDB_ENTITY"] = training_args.wandb_entity
if training_args.wandb_project is not None:
os.environ["WANDB_PROJECT"] = training_args.wandb_project
if training_args.wandb_run_group is not None:
os.environ["WANDB_RUN_GROUP"] = training_args.wandb_run_group

Some files were not shown because too many files have changed in this diff.