Compare commits

...

49 commits

Author SHA1 Message Date
Pauline Bailly-Masson
1416fa0cf2
🔒 pin tests.yml actions to commit SHAs (#721) 2026-04-02 16:03:12 +02:00
Quentin Gallouédec
0e06249d1c
Update README.md 2025-07-17 13:20:00 -07:00
Quentin Gallouédec
7e700c6218
Update citation (#688) 2025-07-07 10:23:08 -07:00
lewtun
b806e1092a
Bump vLLM and TRL (#665)
* Bump vLLM and TRL

* Fix Makefile
2025-05-28 13:47:25 +02:00
lewtun
a6b4f668fb
Fix Weka refresh (#666)
* Fix Weka refresh

* Update evaluate.slurm
2025-05-28 13:45:48 +02:00
lewtun
01b4351c45
Set DP=2 for smol model evals (#664)
* Set DP=2 for smol model evals

Temporary hack while the HF cluster is at max capacity :)

* Style
2025-05-28 09:23:12 +02:00
lewtun
722f144d21
Refresh Weka on Slurm (#662)
* Refresh Weka on Slurm

* Include current working dir
2025-05-27 19:21:15 +02:00
lewtun
33f84def0d
Align EOS token ID between tokenizer and generation config (#663)
* Align EOS token ID between tokenizer and generation config

* Fix
2025-05-27 17:20:13 +02:00
lewtun
9eef995b4d
Bump deps (#656) 2025-05-27 15:38:21 +02:00
lewtun
5ac5971ea5
Add OpenR1-Distill recipe (#661) 2025-05-26 17:57:44 +02:00
lewtun
57e85b522f
Add better logging defaults for GRPO (#657) 2025-05-25 13:24:52 +02:00
Guilherme Penedo
c1e1192294
GRPO with codeforces problems (#627)
* add

* update

* updates

* updates #2

* weighted_sum and python fixes

* bugfix

* merging ioi/cf setups

* integrating the morph changes

* move morph_client

* run style

* small changes for mixed languages training

* revert grpo.py changes

* piston readme

* local test fetching

* bug fixes

* updated readme

* style fixes

* style fixes 2

* deps changes

* import sorting

* fix tests

* Update README.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update README.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update src/open_r1/rewards.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-25 11:55:27 +02:00
lewtun
db2d9b011a
Bump lower bound on liger-kernel (#654)
Related to https://github.com/huggingface/open-r1/pull/653

(I forgot to include this in that PR)
2025-05-22 08:44:13 +02:00
lewtun
8067149e90
Bump DeepSpeed to 0.16.8 to fix OOM on Qwen3 (#653) 2025-05-21 22:25:57 +02:00
lewtun
9366aa2df3
Add dataset mixer (#647)
* Prototype

* Clean up

* Refactor

* Add tests

* Add doc and make scripts work

* Tune doc

* Up

* Tune

* Add column verification

* Fix types

* Fix YAML

* Fix types

* Fix doc

* f

* f
2025-05-20 11:40:42 +02:00
Quentin Gallouédec
5e0c210f9c
use hf papers (#646) 2025-05-19 13:48:14 +02:00
lewtun
ebd5913a85
Bump LightEval (#643) 2025-05-16 10:52:05 +02:00
Edward Beeching
ea5b7edf22
Add dataset filtering script (#637)
* add dataset filtering script

* remove subset selection

* save wip

* save wip

* update filter script

* refactor to run on chunks

* rename script

* cleanup

* update dapo filtering

* fixes

* dapo filt config

* udpate compute pass rate

* clean

* update readme and config

* add merging snippet
2025-05-16 10:26:49 +02:00
lewtun
4fc2a3ff82
Add time to Slurm (#639) 2025-05-09 19:19:51 +02:00
lewtun
c802f00512
Use pass@1 for all evals (#633)
* Use pass@1 for all evals

* Update scores
2025-05-09 17:42:36 +02:00
Edward Beeching
21b48fbe46
soft_overlong_punishment from DAPO paper (#638)
* soft_overlong_punishment_reward

* tests

* doc string updated

* style

* non-sensical import removed

* Update src/open_r1/rewards.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update src/open_r1/rewards.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* max_completion_length set to 3.6

* style

* quality

* test case added for <max_com_len

* style

* max_len +cache len updated based on num chars

* max_len_completion docstring added in cofig

* Update configs.py

* refactor soft overlong penalty to use completion ids

* change decription to be tokens

---------

Co-authored-by: shirinyamani <yamani.shirin@ucalgary.ca>
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-09 17:26:34 +02:00
lewtun
6a0cd5c8ad
Fix style again :) (#636) 2025-05-08 16:29:01 +02:00
Andrei
af81114044
Code Execution using Morph Cloud (#614)
* initial commit for morphcloud sandbox support

* initial

* fixed prints in morph client for ioi

* updated import

* context manager

* removed unnecessary comments

* more intelligent instance/snapshot management

* update

* Add documentation for Morph integration

* Delete MORPH_INTEGRATION.md

* added retry and modularity to morph client

* updates to kwargs and setup.py

* Update setup.py

* added languages codepath + fixed slurm + added morph tests

* make quality formatting fixes

* conditional imports for morph

---------

Co-authored-by: arb8020 <arbeightytwenty@gmail.com>
2025-05-08 08:59:54 +02:00
lewtun
52520a6713
Fix style (#631)
* Fix style

* Fix

* Add jieba
2025-05-05 15:49:10 +02:00
Lewis Tunstall
c8b989109d
Fix style 2025-05-02 14:45:17 +00:00
lewtun
9373ad3055
Update README.md 2025-04-30 22:16:18 +02:00
binary-husky
65211f4824
🦜Enhance repetition penalty reward for language that cannot be split by whitespace (#516)
* Update rewards.py

* add test for repetition reward with language

* Update src/open_r1/rewards.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update src/open_r1/rewards.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-04-30 22:02:59 +02:00
lewtun
75c3999180
Bump LightEval to enable DP>1 (#629)
* Bump LightEval to enable DP>1

* Remove redundant arg

* Update eval scores

* Fix slurm
2025-04-30 22:02:20 +02:00
lewtun
50590a41b9
Enable data and tensor parallelism for GRPO (#626)
* Bump deps

* Fix SLurm

* Fix
2025-04-26 11:50:08 +02:00
Edward Beeching
715c8787fb
add back grad accumulations steps (#612) 2025-04-17 16:41:39 +02:00
lewtun
4c9b0f25d9
Fix TP once and for all :) (#613)
* Update evaluation.py

* Fix import
2025-04-17 15:25:59 +02:00
lewtun
14a81d2bd4
Update evaluation.py (#611) 2025-04-17 11:11:49 +02:00
lewtun
5112bfc401
Fix SFT for base models (#604)
* Fix pad token bug in SFT

* Add ChatML default

* Clean up

* Refactor grpo model load

* Add doc

* Bump deepspeed
2025-04-16 11:45:50 +02:00
lewtun
bcbb1da401
Update evaluation.py (#608) 2025-04-16 10:37:38 +02:00
lewtun
8eb1b7860a
Set DP=1 due to vLLM <> LightEval hanging (#600)
* Update evaluate.slurm

* Disable DP

* Fix
2025-04-16 10:24:33 +02:00
lewtun
8cf42663fd
Clean up recipes (#596) 2025-04-11 20:09:15 +02:00
Edward Beeching
068f13f236
Hotfix bin reward (#597)
* add WIP code GRPO configs

* hotfix bin reward

* remove unwanted files

* remote configs
2025-04-11 17:45:38 +02:00
lewtun
04dbf21989
Bump TRL and vLLM (#595)
* Bump TRL and vLLM

* Fix style

* Bump liger

* Add liger
2025-04-11 16:32:33 +02:00
Edward Beeching
c1eadaa097
E2B Router bug fixes (#592)
* fix eval system prompt

* style

* fix a rare issue where the execution is None

* fixes a bug in the e2b router
2025-04-11 14:04:59 +02:00
Edward Beeching
3a0e89678c
Fix eval system prompt (#591)
* fix eval system prompt

* style
2025-04-11 11:23:06 +02:00
Shenghang Tsai
2a7bb45f05
Update README.md (#590) 2025-04-10 13:11:35 +02:00
lewtun
bf08f56849
[WIP] Bump lighteval with proper pass@1 (#584)
* Bump lighteval with proper pass@1

* Bump lighteval

* Update AIME24
2025-04-08 20:53:34 +02:00
Edward Beeching
1b3bf043dc
Adds a E2B router server that executes batches of scripts (#561)
* adds a dedicated e2b server to handle batches of requests

* fix reward tests

* update slow reward

* style

* updates e2b router to be more generic

* refactor

* refactoring

* licence, cleanup

* update tests

* style

* fix import when e2b not present

* style

* rename sandbox file

* rename to RoutedSandbox

* update readme

* nits

* nits2

* unlimited max time

* update logs path
2025-04-07 21:01:06 +02:00
lewtun
2636a2130f
Add WandB groups to logging (#573) 2025-04-02 15:48:59 +02:00
lewtun
ca8664df1c
Fix missing prompt columns in recipes (#574) 2025-04-02 15:48:48 +02:00
lewtun
4f5b21e21d
Fix accuracy reward for math (#566)
* Fix accuracy reward for math

* Add typing

* Add unit test

* Return None for invalid samples

* Fix order of answers

* Fix type

* Use None for non-verifiable answers
2025-04-01 12:04:26 +02:00
Edward Beeching
9915e06f1e
Async code reward fixes (#546)
* expose num parallel code executions

* add e2b benchmarking script

* adds new parallel code execution with better execption handling

* style

* update default

* increase sandbox timeout

* Add pretty table and Sandbox IDs

* Add Sandbox ID

* fix merge

---------

Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
2025-03-28 14:08:15 +01:00
Zhou Shao
1802bec75f
fix dataset parsing error (#540)
* fix dataset parsing error

support defined question field to fix errors when datasets' question field is not 'problem'

* add question field config

add script_args: question field

* refactor: datasets prompt column

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-03-28 13:17:04 +01:00
lewtun
4ec555b0c8
Restore single-node instructions to run GRPO (#549) 2025-03-27 10:29:07 +01:00
67 changed files with 4159 additions and 997 deletions


@ -16,9 +16,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Setup Python environment
uses: actions/setup-python@v5
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
with:
python-version: 3.10.10
- name: Install dependencies

3
.gitignore vendored

@ -177,4 +177,5 @@ logs/
eval_results/
results/
.vscode/
.vscode/
.python-version


@ -8,10 +8,11 @@ check_dirs := src tests
# dev dependencies
install:
uv venv openr1 --python 3.11 && . openr1/bin/activate && uv pip install --upgrade pip
uv pip install vllm==0.7.2
uv pip install setuptools
uv pip install flash-attn --no-build-isolation
uv venv openr1 --python 3.11
. openr1/bin/activate && uv pip install --upgrade pip && \
uv pip install vllm==0.8.5.post1 && \
uv pip install setuptools && \
uv pip install flash-attn --no-build-isolation && \
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e ".[dev]"
style:
@ -46,8 +47,7 @@ evaluate:
--use-chat-template \
--output-dir data/evals/$(MODEL); \
else \
lighteval vllm $$MODEL_ARGS "custom|$(TASK)|0|0" \
--custom-tasks src/open_r1/evaluate.py \
lighteval vllm $$MODEL_ARGS "lighteval|$(TASK)|0|0" \
--use-chat-template \
--output-dir data/evals/$(MODEL); \
fi

458
README.md

@ -21,10 +21,9 @@
The goal of this repo is to build the missing pieces of the R1 pipeline such that everybody can reproduce and build on top of it. The project is simple by design and mostly consists of:
- `src/open_r1`: contains the scripts to train and evaluate models as well as generate synthetic data:
- `src/open_r1`: contains the scripts to train models as well as generate synthetic data:
- `grpo.py`: trains a model with GRPO on a given dataset.
- `sft.py`: performs a simple SFT of a model on a dataset.
- `evaluate.py`: evaluates a model on the R1 benchmarks.
- `generate.py`: generates synthetic data from a model using [Distilabel](https://github.com/argilla-io/distilabel).
- `Makefile`: contains easy-to-run commands for each step in the R1 pipeline leveraging the scripts above.
@ -42,6 +41,7 @@ We will use the DeepSeek-R1 [tech report](https://github.com/deepseek-ai/DeepSee
## News 🗞️
* **🧑‍🍳 [2025/05/26] (Step 1 completed!)** We release [**Mixture-of-Thoughts**](https://huggingface.co/datasets/open-r1/Mixture-of-Thoughts)--a curated reasoning dataset of 350k verified traces distilled from R1. The dataset spans tasks in mathematics, coding, and science, and is designed to teach language models to reason step-by-step. We also provide a recipe to train [OpenR1-Distill-7B](https://huggingface.co/open-r1/OpenR1-Distill-7B), which replicates the reasoning capabilities of [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) and marks the completion of step 1 in the Open R1 project.
* **⚡️ [2025/03/11] [(update #3)](https://huggingface.co/blog/open-r1/update-3):** We release the [**CodeForces-CoTs**](https://huggingface.co/datasets/open-r1/codeforces-cots) dataset of 10k competitive programming problems and 100k solutions distilled from R1. We also release IOI24: a new benchmark of _very_ hard problems from international olympiads. A 7B Qwen model trained on CodeForces-CoTs can outperform Claude 3.7 Sonnet on IOI24, while a 32B model can outperform R1 itself.
* **∞ [2025/02/10] [(update #2)](https://huggingface.co/blog/open-r1/update-2):** We release the [**OpenR1-Math-220k**](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k) dataset of 220k traces distilled from R1 on a new version of NuminaMath. Models trained on this dataset match the performance of DeepSeek's distilled ones.
* **🔥 [2025/02/02] [(update #1)](https://huggingface.co/blog/open-r1/update-1):** We implement the first parts of the [training](https://github.com/huggingface/open-r1?tab=readme-ov-file#training-models), [inference](https://github.com/huggingface/open-r1?tab=readme-ov-file#data-generation), and [evaluation](https://github.com/huggingface/open-r1?tab=readme-ov-file#reproducing-deepseeks-evaluation-results) pipelines. Let's go!
@ -69,11 +69,11 @@ uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --u
Next, install vLLM and FlashAttention:
```shell
uv pip install vllm==0.7.2
uv pip install vllm==0.8.5.post1
uv pip install setuptools && uv pip install flash-attn --no-build-isolation
```
This will also install PyTorch `v2.5.1` and it is **very important** to use this version since the vLLM binaries are compiled for it. You can then install the remaining dependencies for your specific use case via `pip install -e .[LIST OF MODES]`. For most contributors, we recommend:
This will also install PyTorch `v2.6.0` and it is **very important** to use this version since the vLLM binaries are compiled for it. You can then install the remaining dependencies for your specific use case via `pip install -e .[LIST OF MODES]`. For most contributors, we recommend:
```shell
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e ".[dev]"
@ -100,25 +100,30 @@ sudo apt-get install git-lfs
## Training models
We support training models with either DDP or DeepSpeed (ZeRO-2 and ZeRO-3). For example, to run SFT on a dataset distilled from DeepSeek-R1 with reasoning traces such as [open-r1/OpenR1-Math-220k](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k), run:
> [!NOTE]
> The training commands below are configured for a node of 8 x H100s (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.
We support training models with either DDP or DeepSpeed (ZeRO-2 and ZeRO-3). For example, to perform SFT on a dataset distilled from DeepSeek-R1 with reasoning traces such as [open-r1/Mixture-of-Thoughts](https://huggingface.co/datasets/open-r1/Mixture-of-Thoughts), run:
```shell
# Train via command line
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
--dataset_name open-r1/OpenR1-Math-220k \
--learning_rate 1.0e-5 \
--num_train_epochs 1 \
--packing \
--max_seq_length 16384 \
--per_device_train_batch_size 16 \
--model_name_or_path open-r1/Qwen2.5-Math-7B-RoPE-300k \
--dataset_name open-r1/Mixture-of-Thoughts \
--dataset_config all \
--eos_token '<|im_end|>' \
--learning_rate 4.0e-5 \
--num_train_epochs 5 \
--max_seq_length 32768 \
--per_device_train_batch_size 2 \
--gradient_checkpointing \
--bf16 \
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
--use_liger_kernel \
--output_dir data/OpenR1-Distill-7B
# Train via YAML config
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml
--config recipes/OpenR1-Distill-7B/sft/config_distill.yaml
```
Currently, the following tasks are supported:
@ -132,62 +137,160 @@ Currently, the following tasks are supported:
By default, these scripts will push each model to your Hugging Face Hub username, i.e. `{username}/{model_name}-{task}`. You can override the parameters in each YAML config by appending them to the command as follows:
```shell
# Change batch size, number of epochs etc
# Change the base model to a smaller variant
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml \
--per_device_train_batch_size=1 --num_train_epochs=5
--config recipes/OpenR1-Distill-7B/sft/config_distill.yaml \
--model_name_or_path Qwen/Qwen3-0.6B-Base \
--hub_model_id OpenR1-Distill-0.6B \
--output_dir data/OpenR1-Distill-0.6B
```
If you also wish to override the Weights and Biases default settings, you can do so as follows:
```shell
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml \
--config recipes/OpenR1-Distill-7B/sft/config_distill.yaml \
--wandb_entity huggingface --wandb_project open-r1 --run_name Qwen2.5-1.5B-GRPO
```
> [!NOTE]
> The training commands below are configured for a node of 8 x H100s (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.
**🚨 WARNING 🚨**
### SFT
Most base models like `meta-llama/Llama-3.2-1B` do not have a chat template, so we set ChatML as the default during training. However, for Qwen base models like `Qwen/Qwen2.5-1.5B`, a chat template is pre-defined in the tokenizer, so the EOS token must be set accordingly, e.g.
To run SFT on a dataset distilled from DeepSeek-R1 with reasoning traces such as [open-r1/OpenR1-Math-220k](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k), run:
```diff
# Align EOS token with chat template for Qwen base models
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path Qwen/Qwen2.5-1.5B \
+ --eos_token '<|im_end|>'
--dataset_name open-r1/Mixture-of-Thoughts \
--dataset_config all \
--learning_rate 4.0e-5 \
--num_train_epochs 1 \
--max_seq_length 32768 \
--per_device_train_batch_size 16 \
--gradient_checkpointing \
--bf16 \
--use_liger_kernel \
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
```
If you wish to use a custom chat template (e.g. Llama or Gemma), then the chat template and associated EOS token must be provided:
```diff
# Align EOS token with custom chat template
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path meta-llama/Llama-3.2-1B \
+ --chat_template "$(cat llama_chat_template.jinja)" \
+ --eos_token '<|eot_id|>' \
--dataset_name open-r1/Mixture-of-Thoughts \
--dataset_config all \
--learning_rate 4.0e-5 \
--num_train_epochs 1 \
--max_seq_length 32768 \
--per_device_train_batch_size 16 \
--gradient_checkpointing \
--bf16 \
--use_liger_kernel \
--output_dir data/Llama-3.2-1B-Open-R1-Distill
```
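To illustrate why the EOS token has to match the chat template, here is a minimal sketch that renders a conversation in ChatML by hand. The `render_chatml` and `turn_ends_with_eos` helpers are hypothetical and exist only for this illustration; during training the tokenizer's own chat template does the rendering.

```python
# Hypothetical helpers for illustration; not part of open-r1 or transformers.
def render_chatml(messages):
    """Render a list of {"role", "content"} dicts in ChatML format."""
    return "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )


def turn_ends_with_eos(rendered, eos_token):
    """Check that the final turn in the rendered text is closed by eos_token."""
    return rendered.rstrip("\n").endswith(eos_token)


messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]
text = render_chatml(messages)

# With ChatML, every turn is closed by <|im_end|>, so that is the token
# the model has to learn to emit in order to stop generating.
print(turn_ends_with_eos(text, "<|im_end|>"))     # True
print(turn_ends_with_eos(text, "<|endoftext|>"))  # False
```

If the configured EOS token differs from the token that closes each turn in the template, the model never sees its stop token at the end of a training example, which is the mismatch the `--eos_token` flag above is meant to avoid.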
### SFT distillation
We provide a recipe to reproduce the reasoning capabilities of [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B), starting from the same base model. To do so, run:
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
src/open_r1/sft.py \
--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml
--config recipes/OpenR1-Distill-7B/sft/config_distill.yaml
```
The result will be a model like [open-r1/OpenR1-Distill-7B](https://huggingface.co/open-r1/OpenR1-Distill-7B), with the following downstream performance:
| Model | AIME 2024 | MATH-500 | GPQA Diamond | LiveCodeBench v5 |
|-----------------------------|-----------|----------|--------------|------------------|
| OpenR1-Distill-7B | 52.7 | 89.0 | 52.8 | 39.4 |
| DeepSeek-R1-Distill-Qwen-7B | 51.3 | 93.5 | 52.4 | 37.4 |
You can adjust the YAML config to train on a different base model or dataset.
### GRPO
We use TRL's new distributed vLLM server and GRPOTraining in order to scale to larger >7B models. We provide an example slurm script:
We use TRL's [vLLM backend](https://huggingface.co/docs/trl/speeding_up_training?vllm+examples=GRPO#vllm-for-fast-generation-in-online-methods) to scale training to large models across multiple nodes. For single-node training of smol models across 8 GPUs, use `vllm_mode="colocate"` to run vLLM in the same process as the training script:
```shell
sbatch --job-name=trl-Qwen2.5-Math-7B-config_simple_rl --nodes=2 slurm/train.slurm Qwen2.5-Math-7B grpo config_simple_rl zero3
ACCELERATE_LOG_LEVEL=info \
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
src/open_r1/grpo.py --config recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml \
--vllm_mode colocate
```
You will need to adapt the `slurm/train.slurm` script to match your cluster.
> [!WARNING]
> The chat template used in the distilled DeepSeek models omits the contents of the reasoning block within the `<think>` and `</think>` tags. It also prefills the assistant response with `<think>` which interferes with the format reward function. To handle that, it is important to override the chat template as done in e.g. [recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml](./recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml).
Our final [model](https://huggingface.co/Dongwei/Qwen-2.5-7B_Base_Math_smalllr), while using different learning rates, loss functions and reward structures, achieves 69.4% accuracy on MATH-500, demonstrating a 17%+ improvement over the base model.
For multi-node training on N+1 nodes, with 1 node running the vLLM server and N nodes running training, we provide an example Slurm script. For example, to run the above example on 1+1 nodes with data parallelism, run:
```shell
sbatch --nodes=2 slurm/train.slurm --model Qwen2.5-1.5B-Instruct --task grpo --config demo --accelerator zero2 --dp 8 --tp 1
```
See the [Launching jobs on a Slurm cluster](#launching-jobs-on-a-slurm-cluster) section for more details.
#### GRPO dataset filtering
We provide support for filtering datasets by generating completions and computing their pass rate on verifiable tasks; see this [README](scripts/pass_rate_filtering/README.md).
#### 👨‍💻 Training with a code interpreter
We provide a `code` reward function for executing code generated by the policy during training. Currently, this reward function targets code contests like [Codeforces](https://codeforces.com), where solutions are executed against a set of test cases and the overall success rate is returned as the final reward. To ensure safe execution, we use [E2B](https://e2b.dev) sandboxes, which are fast and cheap to run. To use this reward function, first install the necessary dependencies:
We provide a `code` reward function for executing code generated by the policy during training. Currently, this reward function targets code contests like [Codeforces](https://codeforces.com), where solutions are executed against a set of test cases and the overall success rate is returned as the final reward. To ensure safe execution, we support multiple sandbox providers:
1. [E2B](https://e2b.dev) - Fast, cloud-based sandboxes with focus on Python execution
2. [Morph](https://cloud.morph.so/web/) - Cloud-based sandboxes with broader language support - Python/JS/C++/Rust
To use the code reward function, first install the necessary dependencies:
```shell
uv pip install -e '.[code]'
```
Then create a `.env` file and place an API token from E2B within it:
##### E2B Provider
To use E2B sandboxes, create a `.env` file and add your E2B API token:
```
E2B_API_KEY="e2b_xxx"
```
Then make sure your dataset contains a `verification_info` column with the following schema (adopted from PrimeIntellect's excellent [datasets](https://huggingface.co/collections/PrimeIntellect/synthetic-1-67a2c399cfdd6c9f7fae0c37) of verifiable problems):
##### Morph Provider
To use Morph, first install the morphcloud package:
```shell
pip install morphcloud
```
Then add your Morph API token to the `.env` file:
```
MORPH_API_KEY="YOUR_MORPH_API_KEY"
```
To specify which provider to use, add the `provider_type` parameter in your configuration:
```yaml
# For E2B
provider_type: e2b
# For Morph
provider_type: morph
```
##### Dataset Requirements
Make sure your dataset contains a `verification_info` column with the following schema (adopted from PrimeIntellect's excellent [datasets](https://huggingface.co/collections/PrimeIntellect/synthetic-1-67a2c399cfdd6c9f7fae0c37) of verifiable problems):
```python
{
"language": "python",
"language": "python", # Morph supports more languages including C++, Java, etc.
"test_cases": [
{
"input": "4\n4\n0001\n1000\n0011\n0111\n3\n010\n101\n0\n2\n00000\n00001\n4\n01\n001\n0001\n00001\n",
@ -198,58 +301,94 @@ Then make sure your dataset contains a `verification_info` column with the follo
}
```
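The "overall success rate" reward described above can be sketched in a few lines. This is only an illustration under stated assumptions: the `success_rate` function is hypothetical, the solution is a plain Python callable rather than code executed in a sandbox, and the `output` key of each test case is assumed because the excerpt of the schema shown here only displays `input`.

```python
# Hypothetical sketch; the real reward executes generated code in a sandbox.
def success_rate(solve, test_cases):
    """Return the fraction of test cases where solve(input) matches the expected output."""
    if not test_cases:
        return 0.0
    passed = sum(
        1
        for case in test_cases
        # "output" is an assumed key name for the expected stdout.
        if solve(case["input"]).strip() == case["output"].strip()
    )
    return passed / len(test_cases)


# Toy "solution" that doubles the integer on each input line.
def double_lines(stdin_text):
    return "\n".join(str(2 * int(line)) for line in stdin_text.split())


verification_info = {
    "language": "python",
    "test_cases": [
        {"input": "1\n2\n", "output": "2\n4\n"},
        {"input": "3\n", "output": "6\n"},
        {"input": "5\n", "output": "11\n"},  # deliberately wrong expectation
    ],
}

reward = success_rate(double_lines, verification_info["test_cases"])
print(round(reward, 3))  # 0.667
```

Two of the three cases pass, so the reward is 2/3. Comparing stripped strings is a simplification; a stricter or more lenient comparison may be appropriate depending on the task.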
For example, to train a smol model on Python problems, run:
For example, to train a smol model on Python problems, start the vLLM server:
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
--num_processes=7 src/open_r1/grpo.py \
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-1.5B-Instruct
```
#### IOI problems
Then run training with:
We provide a `ioi_code_reward` reward function for executing problems from [IOI](https://hf.co/datasets/open-r1/ioi) using [piston](https://github.com/engineer-man/piston).
```shell
CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 ACCELERATE_LOG_LEVEL=info \
accelerate launch --config_file recipes/accelerate_configs/zero2.yaml --num_processes=7 \
src/open_r1/grpo.py --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
```
To get piston workers running, see [slurm/piston/README.md](./slurm/piston/README.md).
Set your environment variable `PISTON_ENDPOINTS` to `slurm` or to a list of piston worker endpoints.
##### Using Router Services
It is possible to be rate limited when too many scripts are executed on sandbox services. For both providers, we offer router scripts that can be launched on a CPU node:
For E2B:
```shell
sbatch slurm/e2b_router.slurm
```
For Morph:
```shell
sbatch slurm/morph_router.slurm
```
Then add the router URL in your training YAML config:
```yaml
# For E2B
e2b_router_url: 1.2.3.4:8000
# For Morph
morph_router_url: 1.2.3.4:8000
```
The port should match the one used when launching the router.
All training jobs can share the same router IP which will ensure parallel executions are properly managed.
#### Competitive Programming problems: IOI & CodeForces
We provide `ioi_code_reward` and `cf_code_reward` reward functions for executing problems from [IOI](https://hf.co/datasets/open-r1/ioi) and [CodeForces](https://huggingface.co/datasets/open-r1/codeforces), respectively. You can use either [piston](https://github.com/engineer-man/piston) or Morph (currently IOI only) as your execution provider.
##### Piston
To use Piston:
1. Get piston workers running, see [slurm/piston/README.md](./slurm/piston/README.md)
2. Set your environment variable `PISTON_ENDPOINTS` to `slurm` or to a list of piston worker endpoints
For IOI:
3. In your configuration, use `ioi_provider: "piston"`
For CodeForces:
3. Download the generated (hard) test cases:
```
# change PATH_TO_SAVE_TESTCASES. Increase --max-workers according to your machine's capacity
huggingface-cli download open-r1/codeforces --repo-type=dataset --include='generated_tests/*.parquet' --max-workers=8 --local-dir PATH_TO_SAVE_TESTCASES
```
4. Save the path in .env:
```
CF_TESTS_FOLDER=PATH_TO_SAVE_TESTCASES
```
##### Morph
Morph is a cloud-based solution that provides sandboxed environments for running code. To use it:
1. Install the Morph client: `pip install morphcloud`
2. Add your Morph API key to the `.env` file: `MORPH_API_KEY="your_key_here"`
3. In your configuration, use `ioi_provider: "morph"`
##### Example recipes
For IOI:
See the [example recipe](./recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code_ioi.yaml) for how to use the IOI reward function:
See the [example recipe](./recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code_ioi.yaml) for how to use the reward function:
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
--num_processes=7 src/open_r1/grpo.py \
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code_ioi.yaml
```
#### Data decontamination
Following [s1: Simple test-time scaling](https://arxiv.org/abs/2501.19393), the data can be decontaminated with the script at [scripts/decontaminate.py](./scripts/decontaminate.py), which decontaminates a dataset using 8-grams and deduplicates the data. Sample run:
For CodeForces:
```shell
python scripts/decontaminate.py \
--dataset "open-r1/verifiable-coding-problems-python" \
--problem_column problem \
--cleanup
```
It will decontaminate against the benchmark datasets and remove the contaminated samples afterwards. If the `--new_dataset_name` argument is not provided, the same dataset name is reused with a `_decontaminated` suffix appended. The check runs against the prompt, which for this dataset is the column `problem`, but a different column can be provided.
Arguments for the script:
```shell
usage: decontaminate.py [-h] --dataset DATASET [--split SPLIT] [--ngram_size NGRAM_SIZE] [--problem_column PROBLEM_COLUMN] [--cleanup] [--new_dataset_name NEW_DATASET_NAME]
options:
-h, --help show this help message and exit
--dataset DATASET Name of the dataset to check for contamination.
--split SPLIT Split to check for contamination, defaults to `train`.
--ngram_size NGRAM_SIZE
Size of n-grams to build, defaults to 8.
--problem_column PROBLEM_COLUMN
Name of the column containing the problem (prompt).
--cleanup Whether to remove the contaminated rows before pushing the dataset.
--new_dataset_name NEW_DATASET_NAME
New name for the dataset. If not provided, will reuse the name and add a `_decontaminated` to the name.
sbatch --job-name=cf-grpo --nodes=2 slurm/train.slurm --model Qwen2.5-Coder-7B-Instruct --task grpo --config codeforces --accelerator zero3 --dp 8 --tp 1
```
### Launching jobs on a Slurm cluster
@ -257,48 +396,76 @@ options:
If you have access to a Slurm cluster, we provide a `slurm/train.slurm` script that will automatically queue training jobs for you. Here's how you can use it:
```shell
sbatch --job-name=open_r1 --nodes=1 slurm/train.slurm {model_name} {task} {config_suffix} {accelerator}
sbatch --job-name=open_r1 --nodes=1 slurm/train.slurm --model {model_name} --task {task} --config {config_suffix} --accelerator {accelerator}
```
Here `{model_name}` and `{task}` are defined as above, while `{config_suffix}` refers to the specific config and `{accelerator}` refers to the choice of 🤗 Accelerate config in `recipes/accelerate_configs`. If you wish to override the default config parameters, you can provide them by appending a space-separated string like `'--arg1=value1 --arg2=value2'`. Here's a concrete example to run SFT on 1 node of 8 GPUs:
```shell
# Launch on Slurm and override default hyperparameters
sbatch --job-name=open_r1 --nodes=1 slurm/train.slurm Qwen2.5-1.5B-Instruct sft demo zero3 '--per_device_train_batch_size=1 --num_train_epochs=5'
sbatch --job-name=open_r1 --nodes=1 slurm/train.slurm --model OpenR1-Distill-7B --task sft --config distill --accelerator zero3
```
You can scale the number of nodes by increasing the `--nodes` flag.
For GRPO, we use 1 node for the vLLM server and N nodes for training. For example, to run GRPO on 1+1 nodes with mixed data and tensor parallelism, run:
```shell
sbatch --job-name=open_r1 --nodes=2 slurm/train.slurm --model Qwen2.5-1.5B-Instruct --task grpo --config demo --accelerator zero2 --dp 4 --tp 2
```
> [!NOTE]
> The configuration in `slurm/train.slurm` is optimised for the Hugging Face Compute Cluster and may require tweaking to be adapted to your own compute nodes.
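The flag-style invocation above can be assembled programmatically when launching many jobs. The following is a minimal sketch, not part of the repository: `build_train_command` and `gpus_per_node` are hypothetical names, and the check that `dp * tp` must not exceed the GPUs on the single vLLM node is an assumption based on the 8-GPU examples shown here.

```python
import shlex


def build_train_command(model, task, config, accelerator, nodes=1,
                        dp=None, tp=None, job_name="open_r1", gpus_per_node=8):
    """Build an sbatch command string for slurm/train.slurm (illustrative helper)."""
    if (dp is None) != (tp is None):
        raise ValueError("set both dp and tp, or neither")
    # Assumption: the vLLM server runs on one node, so dp * tp cannot exceed its GPU count.
    if dp is not None and dp * tp > gpus_per_node:
        raise ValueError(f"dp * tp = {dp * tp} exceeds {gpus_per_node} GPUs on the vLLM node")
    parts = [
        "sbatch", f"--job-name={job_name}", f"--nodes={nodes}", "slurm/train.slurm",
        "--model", model, "--task", task, "--config", config, "--accelerator", accelerator,
    ]
    if dp is not None:
        parts += ["--dp", str(dp), "--tp", str(tp)]
    # shlex.join quotes any argument that needs it for the shell.
    return shlex.join(parts)


if __name__ == "__main__":
    print(build_train_command("Qwen2.5-1.5B-Instruct", "grpo", "demo", "zero2",
                              nodes=2, dp=4, tp=2))
```

Only the flags that appear in the examples above (`--model`, `--task`, `--config`, `--accelerator`, `--dp`, `--tp`) are used; any other option would need to be checked against `slurm/train.slurm` itself.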
### Customising the dataset mixture
To combine multiple datasets as a single training mixture, you can specify the `dataset_mixture` parameter in the YAML config file. Here's a template for how to do this:
```yaml
dataset_mixture:
datasets: # List of datasets to include in the mixture
- id: dataset_1 # Hub dataset ID
config: config_name_1 # Name of the dataset config
split: split_1 # Split to use from the dataset
columns: # Columns to keep
- column_1
- column_2
weight: 0.25 # Fraction of dataset to use
- id: dataset_2
config: config_name_2
split: split_2
columns:
- column_1
- column_2
weight: 0.5
seed: 42 # Seed for shuffling the combined dataset
test_split_size: 0.1 # Fraction of mixture to use for a test split
```
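To make the semantics of the template concrete, here is a small standard-library sketch of what the fields describe: keep a `weight` fraction of each dataset, retain only the listed `columns`, shuffle the combined rows with `seed`, and carve off `test_split_size` as a test split. It is an illustration under assumptions, not the repository's implementation: `build_mixture` is a hypothetical name, rows are plain dictionaries rather than Hub datasets, and the choice to subsample randomly (rather than take the first rows) is a guess.

```python
import random


def build_mixture(sources, seed=42, test_split_size=None):
    """Combine several in-memory datasets into one shuffled mixture.

    Each source is a dict with 'rows' (list of dicts), 'columns' (names to keep)
    and 'weight' (fraction of the source's rows to use).
    """
    rng = random.Random(seed)
    mixed = []
    for source in sources:
        rows = list(source["rows"])
        rng.shuffle(rows)  # assumption: the fraction is sampled at random
        keep = int(len(rows) * source["weight"])
        for row in rows[:keep]:
            # Drop every column that is not explicitly listed.
            mixed.append({name: row[name] for name in source["columns"]})
    rng.shuffle(mixed)  # shuffle the combined mixture
    if test_split_size is None:
        return {"train": mixed}
    n_test = int(len(mixed) * test_split_size)
    return {"train": mixed[n_test:], "test": mixed[:n_test]}


if __name__ == "__main__":
    a = [{"column_1": i, "column_2": i * 2, "extra": "x"} for i in range(40)]
    b = [{"column_1": i, "column_2": i * 3, "extra": "y"} for i in range(20)]
    splits = build_mixture(
        [
            {"rows": a, "columns": ["column_1", "column_2"], "weight": 0.25},
            {"rows": b, "columns": ["column_1", "column_2"], "weight": 0.5},
        ],
        seed=42,
        test_split_size=0.1,
    )
    print(len(splits["train"]), len(splits["test"]))
```

Note that, in this reading, `weight` is a fraction of each individual dataset, not a share of the final mixture, which matches the "Fraction of dataset to use" comment in the template.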
## Evaluating models
We use `lighteval` to evaluate models, with custom tasks defined in `src/open_r1/evaluate.py`. For models which fit on a single GPU, run:
We use `lighteval` to evaluate models. For models which fit on a single GPU, run:
```shell
export VLLM_WORKER_MULTIPROC_METHOD=spawn # Required for vLLM
MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
OUTPUT_DIR=data/evals/$MODEL
# AIME 2024
TASK=aime24
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks src/open_r1/evaluate.py \
lighteval vllm $MODEL_ARGS "lighteval|$TASK|0|0" \
--use-chat-template \
--output-dir $OUTPUT_DIR
# MATH-500
TASK=math_500
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks src/open_r1/evaluate.py \
lighteval vllm $MODEL_ARGS "lighteval|$TASK|0|0" \
--use-chat-template \
--output-dir $OUTPUT_DIR
# GPQA Diamond
TASK=gpqa:diamond
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks src/open_r1/evaluate.py \
lighteval vllm $MODEL_ARGS "lighteval|$TASK|0|0" \
--use-chat-template \
--output-dir $OUTPUT_DIR
@ -308,22 +475,18 @@ lighteval vllm $MODEL_ARGS "extended|lcb:codegeneration|0|0" \
--output-dir $OUTPUT_DIR
```
> [!IMPORTANT]
> You must set `max_model_length=32768` in the `vllm` command to align with the `max_new_tokens` we define per eval. Without this, `lighteval` will throw an error.
To increase throughput across multiple GPUs, use _data parallel_ as follows:
```shell
NUM_GPUS=8
MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="model_name=$MODEL,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
TASK=aime24
OUTPUT_DIR=data/evals/$MODEL
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks src/open_r1/evaluate.py \
lighteval vllm $MODEL_ARGS "lighteval|$TASK|0|0" \
--use-chat-template \
--output-dir $OUTPUT_DIR
--output-dir $OUTPUT_DIR
```
For large models which require sharding across GPUs, use _tensor parallel_ and run:
@ -331,15 +494,14 @@ For large models which require sharding across GPUs, use _tensor parallel_ and r
```shell
NUM_GPUS=8
MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="model_name=$MODEL,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
TASK=aime24
OUTPUT_DIR=data/evals/$MODEL
export VLLM_WORKER_MULTIPROC_METHOD=spawn
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks src/open_r1/evaluate.py \
lighteval vllm $MODEL_ARGS "lighteval|$TASK|0|0" \
--use-chat-template \
--output-dir $OUTPUT_DIR
--output-dir $OUTPUT_DIR
```
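The three variants above differ only in which parallelism key is added to `MODEL_ARGS`. A minimal sketch of a helper that assembles the string is shown below; `build_model_args` is a hypothetical name, the keys are copied from the `model_name=` form of the commands in this section, and the length check is an assumption about what it means for `max_model_length` to be aligned with `max_new_tokens`.

```python
def build_model_args(model, num_gpus=1, parallelism=None, max_model_length=32768,
                     max_new_tokens=32768, temperature=0.6, top_p=0.95,
                     gpu_memory_utilization=0.8):
    """Assemble the comma-separated MODEL_ARGS string passed to `lighteval vllm`."""
    if parallelism not in (None, "data", "tensor"):
        raise ValueError("parallelism must be None, 'data' or 'tensor'")
    # Assumption: the context window must be able to hold the generation budget.
    if max_new_tokens > max_model_length:
        raise ValueError("max_new_tokens must not exceed max_model_length")
    fields = [f"model_name={model}", "dtype=bfloat16"]
    if parallelism is not None:
        # data -> data_parallel_size (throughput), tensor -> tensor_parallel_size (sharding)
        fields.append(f"{parallelism}_parallel_size={num_gpus}")
    fields += [
        f"max_model_length={max_model_length}",
        f"gpu_memory_utilization={gpu_memory_utilization}",
        f"generation_parameters={{max_new_tokens:{max_new_tokens},"
        f"temperature:{temperature},top_p:{top_p}}}",
    ]
    return ",".join(fields)


if __name__ == "__main__":
    print(build_model_args("deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
                           num_gpus=8, parallelism="tensor"))
```

Use `parallelism="data"` for models that fit on one GPU and `parallelism="tensor"` for models that must be sharded, mirroring the two shell examples.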
You can also launch an evaluation with `make evaluate`, specifying the model, task, and optionally the parallelism technique and number of GPUs.
@ -364,32 +526,40 @@ make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLE
## Reproducing DeepSeek's evaluation results
> [!NOTE]
> The DeepSeek-R1 paper uses sampling with 64 responses per query to estimate `pass@1`. Below, we report the results from sampling 1 response per query, which likely explains the small 1-3σ discrepancies between our results and theirs.
The DeepSeek-R1 paper uses sampling with 4-64 responses per query to estimate `pass@1` accuracy, but does not state how many responses were used for each benchmark. In the tables below, we estimate `pass@1` accuracy with the following number of responses per query:
| Benchmark | Number of responses per query |
|:-------------:|:-----------------------------:|
| AIME 2024 | 64 |
| MATH-500 | 4 |
| GPQA Diamond | 8 |
| LiveCodeBench | 16 |
Note that for benchmarks like AIME24, it is important to sample many responses as there are only 30 problems and this can introduce high variance across repeated runs. The choice of how many responses to sample per prompt likely explains the small differences between our evaluation results and those reported by DeepSeek.
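The effect of the sample count on variance can be sketched with a few lines of Python. This is an illustration, not the evaluation code used by `lighteval`: `pass_at_1` and `sampling_std_error` are hypothetical helpers, and the standard-error formula assumes each response is an independent Bernoulli draw with a fixed per-problem solve probability.

```python
import math


def pass_at_1(correct_counts, n_samples):
    """Estimate pass@1 as the mean fraction of correct responses per problem."""
    return sum(c / n_samples for c in correct_counts) / len(correct_counts)


def sampling_std_error(solve_probs, n_samples):
    """Standard error of the pass@1 estimate caused by sampling noise alone.

    Assumes independent samples; solve_probs holds the (unknown in practice)
    probability that a single response solves each problem.
    """
    variance = sum(p * (1 - p) / n_samples for p in solve_probs) / len(solve_probs) ** 2
    return math.sqrt(variance)


if __name__ == "__main__":
    # 30 problems, as in AIME 2024, each solved half of the time (worst case for noise).
    probs = [0.5] * 30
    print(f"1 sample per problem:   {sampling_std_error(probs, 1):.4f}")
    print(f"64 samples per problem: {sampling_std_error(probs, 64):.4f}")
```

Under these assumptions, a 30-problem benchmark scored with one response per problem has a sampling standard error of roughly 9 percentage points, which drops to roughly 1 point with 64 responses.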
### AIME 2024
We are able to reproduce DeepSeek's reported results on the AIME 2024 benchmark within ~1-3 standard deviations:
| Model | AIME 2024 (🤗 LightEval) | AIME 2024 (DeepSeek Reported) |
|:------------------------------|:-----------------------:|:----------------------------:|
| DeepSeek-R1-Distill-Qwen-1.5B | 26.7 | 28.9 |
| DeepSeek-R1-Distill-Qwen-7B | 56.6 | 55.5 |
| DeepSeek-R1-Distill-Qwen-14B | 60.0 | 69.7 |
| DeepSeek-R1-Distill-Qwen-32B | 73.2 | 72.6 |
| DeepSeek-R1-Distill-Llama-8B | 43.3 | 50.4 |
| DeepSeek-R1-Distill-Llama-70B | 73.3 | 70.0 |
|:------------------------------|:------------------------:|:-----------------------------:|
| DeepSeek-R1-Distill-Qwen-1.5B | 30.7 | 28.9 |
| DeepSeek-R1-Distill-Qwen-7B | 50.8 | 55.5 |
| DeepSeek-R1-Distill-Qwen-14B | 65.9 | 69.7 |
| DeepSeek-R1-Distill-Qwen-32B | 69.7 | 72.6 |
| DeepSeek-R1-Distill-Llama-8B | 43.9 | 41.7 |
| DeepSeek-R1-Distill-Llama-70B | 63.0 | 70.0 |
To reproduce these results use the following command:
```shell
NUM_GPUS=1 # Set to 8 for 32B and 70B models
MODEL=deepseek-ai/{model_name}
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
OUTPUT_DIR=data/evals/$MODEL
lighteval vllm $MODEL_ARGS "custom|aime24|0|0" \
--custom-tasks src/open_r1/evaluate.py \
lighteval vllm $MODEL_ARGS "lighteval|aime24|0|0" \
--use-chat-template \
--output-dir $OUTPUT_DIR
```
@ -406,23 +576,23 @@ We are able to reproduce Deepseek's reported results on the MATH-500 benchmark w
| Model | MATH-500 (🤗 LightEval) | MATH-500 (DeepSeek Reported) |
|:------------------------------|:-----------------------:|:----------------------------:|
| DeepSeek-R1-Distill-Qwen-1.5B | 84.6 | 83.9 |
| DeepSeek-R1-Distill-Qwen-7B | 93.0 | 92.8 |
| DeepSeek-R1-Distill-Qwen-14B | 95.0 | 93.9 |
| DeepSeek-R1-Distill-Qwen-32B | 96.6 | 94.3 |
| DeepSeek-R1-Distill-Qwen-1.5B | 83.1 | 83.9 |
| DeepSeek-R1-Distill-Qwen-7B | 94.5 | 92.8 |
| DeepSeek-R1-Distill-Qwen-14B | 94.1 | 93.9 |
| DeepSeek-R1-Distill-Qwen-32B | 95.6 | 94.3 |
| DeepSeek-R1-Distill-Llama-8B | 88.6 | 89.1 |
| DeepSeek-R1-Distill-Llama-70B | 96.4 | 94.5 |
| DeepSeek-R1-Distill-Llama-70B | 95.1 | 94.5 |
To reproduce these results use the following command:
```shell
export VLLM_WORKER_MULTIPROC_METHOD=spawn
NUM_GPUS=1 # Set to 8 for 32B and 70B models
MODEL=deepseek-ai/{model_name}
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
OUTPUT_DIR=data/evals/$MODEL
lighteval vllm $MODEL_ARGS "custom|math_500|0|0" \
--custom-tasks src/open_r1/evaluate.py \
lighteval vllm $MODEL_ARGS "lighteval|math_500|0|0" \
--use-chat-template \
--output-dir $OUTPUT_DIR
```
@ -439,23 +609,23 @@ We are able to reproduce Deepseek's reported results on the GPQA Diamond benchma
| Model | GPQA Diamond (🤗 LightEval) | GPQA Diamond (DeepSeek Reported) |
|:------------------------------|:---------------------------:|:--------------------------------:|
| DeepSeek-R1-Distill-Qwen-1.5B | 34.3 | 33.8 |
| DeepSeek-R1-Distill-Qwen-1.5B | 35.8 | 33.8 |
| DeepSeek-R1-Distill-Qwen-7B | 50.5 | 49.1 |
| DeepSeek-R1-Distill-Qwen-14B | 59.6 | 59.1 |
| DeepSeek-R1-Distill-Qwen-32B | 63.6 | 62.1 |
| DeepSeek-R1-Distill-Llama-8B | 52.0 | 49.0 |
| DeepSeek-R1-Distill-Llama-70B | 67.2 | 65.2 |
| DeepSeek-R1-Distill-Qwen-14B | 61.5 | 59.1 |
| DeepSeek-R1-Distill-Qwen-32B | 63.1 | 62.1 |
| DeepSeek-R1-Distill-Llama-8B | 46.7 | 49.0 |
| DeepSeek-R1-Distill-Llama-70B | 67.4 | 65.2 |
To reproduce these results use the following command:
```shell
export VLLM_WORKER_MULTIPROC_METHOD=spawn
NUM_GPUS=1 # Set to 8 for 32B and 70B models
MODEL=deepseek-ai/{model_name}
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
MODEL_ARGS="model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
OUTPUT_DIR=data/evals/$MODEL
lighteval vllm $MODEL_ARGS "custom|gpqa:diamond|0|0" \
--custom-tasks src/open_r1/evaluate.py \
lighteval vllm $MODEL_ARGS "lighteval|gpqa:diamond|0|0" \
--use-chat-template \
--output-dir $OUTPUT_DIR
```
@ -469,20 +639,20 @@ python scripts/run_benchmarks.py --model-id {model_id} --benchmarks gpqa
We are able to reproduce DeepSeek's reported results on the LiveCodeBench code generation benchmark within ~1-3 standard deviations:
| Model | LiveCodeBench (🤗 LightEval) | LiveCodeBench (DeepSeek Reported) |
|:------------------------------|:----------------------------:|:--------------------------------:|
| DeepSeek-R1-Distill-Qwen-1.5B | 16.3 | 16.9 |
| DeepSeek-R1-Distill-Qwen-7B | 36.6 | 37.6 |
| DeepSeek-R1-Distill-Qwen-14B | 51.5 | 53.1 |
| DeepSeek-R1-Distill-Qwen-32B | 56.6 | 57.2 |
| DeepSeek-R1-Distill-Llama-8B | 37.0 | 39.6 |
| DeepSeek-R1-Distill-Llama-70B | 54.5 | 57.5 |
|:------------------------------|:----------------------------:|:---------------------------------:|
| DeepSeek-R1-Distill-Qwen-1.5B | 16.1 | 16.9 |
| DeepSeek-R1-Distill-Qwen-7B | 37.4 | 37.6 |
| DeepSeek-R1-Distill-Qwen-14B | 51.3 | 53.1 |
| DeepSeek-R1-Distill-Qwen-32B | 56.0 | 57.2 |
| DeepSeek-R1-Distill-Llama-8B | 37.4 | 39.6 |
| DeepSeek-R1-Distill-Llama-70B | 55.9 | 57.5 |
To reproduce these results use the following command:
```shell
NUM_GPUS=1 # Set to 8 for 32B and 70B models, or data_parallel_size=8 with the smaller models for speed
MODEL=deepseek-ai/{model_name}
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,data_parallel_size=$NUM_GPUS,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
OUTPUT_DIR=data/evals/$MODEL
lighteval vllm $MODEL_ARGS "extended|lcb:codegeneration|0|0" \
@ -581,6 +751,38 @@ sbatch slurm/generate.slurm \
> [!NOTE]
> While the job is running, you can set up an SSH tunnel through the cluster login node to access the Ray dashboard from your computer by running `ssh -L 8265:ray_ip_head_node:8265 <login_node>`, then browsing to `http://localhost:8265`.
### Data decontamination
Following [s1: Simple test-time scaling](https://huggingface.co/papers/2501.19393), the data can be decontaminated with the script at [scripts/decontaminate.py](./scripts/decontaminate.py), which decontaminates a dataset using 8-grams and deduplicates the data. Sample run:
```shell
python scripts/decontaminate.py \
--dataset "open-r1/verifiable-coding-problems-python" \
--problem_column problem \
--cleanup
```
The script checks the dataset against the benchmark datasets and, with `--cleanup`, removes the contaminated samples afterwards. If `--new_dataset_name` is not provided, the original dataset name is reused with `_decontaminated` appended. The check runs against the prompt, which for this dataset is the `problem` column, but a different column can be provided with `--problem_column`.
Arguments for the script:
```shell
usage: decontaminate.py [-h] --dataset DATASET [--split SPLIT] [--ngram_size NGRAM_SIZE] [--problem_column PROBLEM_COLUMN] [--cleanup] [--new_dataset_name NEW_DATASET_NAME]
options:
-h, --help show this help message and exit
--dataset DATASET Name of the dataset to check for contamination.
--split SPLIT Split to check for contamination, defaults to `train`.
--ngram_size NGRAM_SIZE
Size of n-grams to build, defaults to 8.
--problem_column PROBLEM_COLUMN
Name of the column containing the problem (prompt).
--cleanup Whether to remove the contaminated rows before pushing the dataset.
--new_dataset_name NEW_DATASET_NAME
New name for the dataset. If not provided, will reuse the name and add a `_decontaminated` to the name.
```
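The idea behind n-gram decontamination can be sketched in a few lines. The code below is a simplified illustration and not the contents of `scripts/decontaminate.py`: the function names are hypothetical, tokenisation is plain lowercased whitespace splitting, and deduplication is exact-match only.

```python
def ngrams(text, n=8):
    """Return the set of word n-grams in a text (lowercased, whitespace tokens)."""
    tokens = text.lower().split()
    return {" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}


def find_contaminated(problems, benchmark_texts, n=8):
    """Flag each problem that shares at least one n-gram with any benchmark text."""
    benchmark_ngrams = set()
    for text in benchmark_texts:
        benchmark_ngrams |= ngrams(text, n)
    return [not ngrams(p, n).isdisjoint(benchmark_ngrams) for p in problems]


def decontaminate(problems, benchmark_texts, n=8):
    """Drop contaminated problems and exact duplicates, preserving order."""
    flags = find_contaminated(problems, benchmark_texts, n)
    seen, kept = set(), []
    for problem, contaminated in zip(problems, flags):
        if contaminated or problem in seen:
            continue
        seen.add(problem)
        kept.append(problem)
    return kept


if __name__ == "__main__":
    benchmark = ["Find the sum of all positive integers n such that n squared divides 2024"]
    train = [
        "Compute the sum of all positive integers n such that n squared divides 2024 exactly",
        "Write a function that reverses a linked list in place",
        "Write a function that reverses a linked list in place",
    ]
    print(decontaminate(train, benchmark, n=8))
```

Texts shorter than `n` tokens produce no n-grams and are therefore never flagged in this sketch; a smaller `--ngram_size` makes the check stricter at the cost of more false positives.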
## Contributing
Contributions are welcome. Please refer to https://github.com/huggingface/open-r1/issues/23.
@ -597,7 +799,7 @@ If you find this project is useful in your own work, please consider citing as f
@misc{openr1,
title = {Open R1: A fully open reproduction of DeepSeek-R1},
url = {https://github.com/huggingface/open-r1},
author = {Hugging Face},
author = {{Hugging Face}},
month = {January},
year = {2025}
}

View file

@ -8,6 +8,7 @@ attn_implementation: flash_attention_2
# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant>'}}{% endif %}"
dataset_name: open-r1/OpenR1-Math-220k
dataset_prompt_column: problem
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
@ -53,4 +54,5 @@ save_strategy: "epoch"
save_total_limit: 1
seed: 42
temperature: 0.7
use_liger_kernel: true
warmup_ratio: 0.1

View file

@ -1,42 +0,0 @@
# To start the training, run the following command:
# sbatch -N 4 --job-name=mistral_sft slurm/train.slurm Mistral-Small-24B-Instruct-2501 sft numina zero3
model_name_or_path: mistralai/Mistral-Small-24B-Instruct-2501
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
# dataset_name: yentinglin/s1K-1.1-trl-format
dataset_name: yentinglin/OpenR1-Math-220k-trl-format
preprocessing_num_workers: 8
# SFT trainer config
bf16: true
do_eval: true
eval_strategy: no
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: Mistral-Small-24B-Instruct-2501-Open-R1-Distill
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
packing: true
max_length: 32768
max_steps: -1
num_train_epochs: 5
output_dir: data/Mistral-Small-24B-Instruct-2501-Open-R1-Distill
overwrite_output_dir: true
per_device_eval_batch_size: 1
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
save_strategy: epoch
seed: 42
warmup_ratio: 0.1

View file

@ -45,5 +45,5 @@ save_only_model: true # needed to bypass FSDP errors with saving paged optimizer
save_strategy: epoch
save_total_limit: 1
seed: 42
use_liger: false # fails on multi-node
use_liger_kernel: false # fails on multi-node
warmup_ratio: 0.03

View file

@ -42,5 +42,5 @@ report_to:
save_strategy: epoch
save_total_limit: 1
seed: 42
use_liger: true
use_liger_kernel: true
warmup_ratio: 0.03

View file

@ -0,0 +1,48 @@
# Config for 1 node of 8 x H100s (80GB)
# Model arguments
model_name_or_path: open-r1/Qwen2.5-Math-7B-RoPE-300k
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
chat_template: "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Open-R1, a language model trained by Hugging Face to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Now, try to solve the following question through the above guidelines.' 
}}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Open-R1, a language model trained by Hugging Face to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. 
Now, try to solve the following question through the above guidelines.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n"
dataset_name: open-r1/Mixture-of-Thoughts
dataset_config: all
dataset_num_proc: 12
eos_token: <|im_end|>
# SFT trainer config
bf16: true
do_eval: false
eval_strategy: 'no'
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: OpenR1-Distill-7B
hub_strategy: every_save
learning_rate: 4.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
packing: false
max_grad_norm: 0.2
max_length: 32768
max_steps: -1
num_train_epochs: 5
output_dir: data/OpenR1-Distill-7B
overwrite_output_dir: true
per_device_eval_batch_size: 1
per_device_train_batch_size: 2
push_to_hub: true
report_to:
- wandb
save_strategy: epoch
save_total_limit: 1
seed: 42
use_liger_kernel: true
warmup_ratio: 0.03

View file

@ -1,48 +0,0 @@
# Model arguments
# You need to download the model and manually change the rope to 300k and max_position_embeddings to 32768
# the config file should match https://huggingface.co/open-r1/OpenR1-Qwen-7B/blob/main/config.json
model_name_or_path: Qwen/Qwen2.5-Math-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: open-r1/OpenR1-Math-220k
dataset_num_proc: 48
#SFT hyperparam
max_length: 32768
weight_decay: 0.0001
optim: adamw_torch
lr_scheduler_type: linear
warmup_ratio: 0.1
learning_rate: 5.0e-05
gradient_accumulation_steps: 2
per_device_eval_batch_size: 1
per_device_train_batch_size: 1
# SFT trainer config
max_steps: -1
num_train_epochs: 3
bf16: true
do_eval: false
use_liger_kernel: true
eval_strategy: 'no'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: OpenR1-Qwen-7B-SFT
hub_strategy: every_save
log_level: info
logging_steps: 5
logging_strategy: steps
packing: true
output_dir: data/OpenR1-Qwen-7B-SFT
overwrite_output_dir: true
push_to_hub: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 500
save_total_limit: 1
seed: 42

View file

@ -6,6 +6,7 @@ attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-220k
dataset_prompt_column: problem
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config

View file

@ -6,6 +6,7 @@ attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/verifiable-coding-problems-python
dataset_prompt_column: problem_statement
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config

View file

@ -6,6 +6,7 @@ attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/ioi
dataset_prompt_column: problem
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config

View file

@ -1,44 +0,0 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-220k
dataset_num_proc: 48
# SFT trainer config
bf16: true
do_eval: false
eval_strategy: 'no'
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: Qwen2.5-1.5B-Open-R1-Distill
hub_strategy: every_save
learning_rate: 5.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
packing: true
max_length: 16384
max_steps: -1
num_train_epochs: 1
output_dir: data/Qwen2.5-1.5B-Open-R1-Distill
overwrite_output_dir: true
per_device_eval_batch_size: 16
per_device_train_batch_size: 16
push_to_hub: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
use_liger: true
warmup_ratio: 0.05

View file

@ -1,54 +0,0 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
do_eval: false
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: Qwen2.5-7B-Instruct-GRPO
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: -1
num_generations: 16
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Instruct-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View file

@ -0,0 +1,80 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-Coder-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: open-r1/codeforces
dataset_prompt_column: prompt
dataset_config: verifiable-prompts
dataset_test_split: test
dataset_train_split: train
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# GRPO trainer config
callbacks:
- push_to_hub_revision
benchmarks:
- lcb_v4
beta: 0.0
loss_type: dr_grpo
scale_rewards: false
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
gradient_accumulation_steps: 32
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: open-r1/Qwen2.5-Coder-7B-Instruct-Codeforces-GRPO
hub_model_revision: v01.00
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 2000
max_completion_length: 8192
max_steps: -1
num_generations: 16
# aiming for 1k optimization steps
# total_samples_per_batch = num_gpus * grad_accumulation_steps * per_device_batch_size = 8 * 32 * 4 = 1024
# unique_prompts_per_batch = total_samples_per_batch / num_generations = 1024 / 16 = 64
# #dataset ~= 16k (8k * 2, for python and cpp)
# global_steps_per_epoch = #dataset / unique_prompts_per_batch = 16k / 64 ~= 250
# epochs_for_1k_steps = 1000/250 = 4 epochs
num_train_epochs: 4
output_dir: data/Qwen2.5-Coder-7B-Instruct-Codeforces-GRPO_v01.00
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- cf_code
- code_format
reward_weights:
- 1.0
- 0.1
save_strategy: "steps"
save_steps: 0.05
save_total_limit: 1
seed: 42
temperature: 0.7
wandb_entity: huggingface
wandb_project: open-r1
warmup_ratio: 0.1
mask_truncated_completions: true
# for each generation, evaluate these many test cases in parallel, then check if any of them failed (0 score): if so stop evaluating
# otherwise continue with the next batch of test cases. Useful to avoid overloading the eval server + save time on wrong solutions
code_eval_test_batch_size: -1
code_eval_scoring_mode: weighted_sum
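The step-count arithmetic in the config comments can be checked with a few lines of Python. This is only a sketch: `num_gpus`, `dataset_size` and `target_steps` are not config keys, they are the figures quoted in the comments (8 GPUs, roughly 16k samples, 1k optimization steps).

```python
# Figures quoted in the config comments; num_gpus, dataset_size and target_steps
# are illustrative names, not keys read from the YAML file.
num_gpus = 8
gradient_accumulation_steps = 32
per_device_train_batch_size = 4
num_generations = 16
dataset_size = 16_000  # ~8k problems x 2 languages (python and cpp)
target_steps = 1_000

# Completions processed per optimizer step across all devices.
total_samples_per_batch = num_gpus * gradient_accumulation_steps * per_device_train_batch_size
# Each prompt is sampled num_generations times, so fewer unique prompts fit in a batch.
unique_prompts_per_batch = total_samples_per_batch // num_generations
global_steps_per_epoch = dataset_size // unique_prompts_per_batch
epochs_for_target = target_steps / global_steps_per_epoch

print(total_samples_per_batch, unique_prompts_per_batch, global_steps_per_epoch, epochs_for_target)
```

Changing `num_generations` or the batch settings changes `unique_prompts_per_batch`, so `num_train_epochs` would need to be recomputed to keep the same number of steps.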


@ -1,51 +0,0 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-Math-7B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: DigitalLearningGmbH/MATH-lighteval
dataset_config: default
system_prompt: "You are a helpful AI Assistant, designed to provided well-reasoned and detailed responses. You FIRST think about the reasoning process as an internal monologue and then provide the user with the answer. The reasoning process MUST BE enclosed within <think> and </think> tags."
# GRPO trainer config
bf16: true
use_vllm: true
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: Qwen-2.5-7B-Simple-RL
hub_strategy: every_save
learning_rate: 3.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_generations: 7
num_train_epochs: 1
output_dir: data/Qwen-2.5-7B-Simple-RL
overwrite_output_dir: true
per_device_eval_batch_size: 16
per_device_train_batch_size: 16
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: "no"
seed: 42
warmup_ratio: 0.1


@ -1,15 +1,23 @@
# Post-training recipes
## OpenR1 Distill 7B
To train the OpenR1 Distill 7B model, run:
```
sbatch --nodes=1 slurm/train.slurm --model OpenR1-Distill-7B --task sft --config distill --accelerator zero3
```
## OlympicCoder
To train the OlympicCoder models, run:
```
# 7B
sbatch --nodes=1 slurm/train.slurm OlympicCoder-7B sft v00.00 zero3
sbatch --nodes=1 slurm/train.slurm --model OlympicCoder-7B --task sft --config v00.00 --accelerator zero3
# 32B
sbatch --nodes=16 slurm/train.slurm OlympicCoder-32B sft v00.00 fsdp
sbatch --nodes=16 slurm/train.slurm --model OlympicCoder-32B --task sft --config v00.00 --accelerator fsdp
```
Note that we found it necessary to switch to FSDP1 and paged AdamW 8-bit for the 32B model in order to fit the largest possible context size.


@ -1,46 +0,0 @@
# Model arguments
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
model_name_or_path: HuggingFaceTB/SmolLM2-1.7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: open-r1/OpenR1-Math-220k
dataset_num_proc: 48
#SFT hyperparam
max_length: 8192 # You can set this to 32768 if you change the rope, but you need to change the config.json file
weight_decay: 0.0001
optim: adamw_torch
lr_scheduler_type: linear
warmup_ratio: 0.1
learning_rate: 5.0e-05
gradient_accumulation_steps: 2
per_device_eval_batch_size: 4
per_device_train_batch_size: 4 # Change this depending on the context length of the model to keep a 500M GBS.
# SFT trainer config
max_steps: -1
num_train_epochs: 3
bf16: true
do_eval: false
eval_strategy: 'no'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: OpenR1-Qwen-7B-SFT
hub_strategy: every_save
log_level: info
logging_steps: 5
logging_strategy: steps
packing: true
output_dir: data/OpenR1-Qwen-7B-SFT
overwrite_output_dir: true
push_to_hub: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 500
save_total_limit: 1
seed: 42


@ -1,46 +0,0 @@
# Model arguments
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
model_name_or_path: HuggingFaceTB/SmolLM2-1.7B
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: open-r1/OpenR1-Math-220k
dataset_num_proc: 48
#SFT hyperparam
max_length: 8192 # You can set this to 32768 if you change the rope, but you need to change the config.json file
weight_decay: 0.0001
optim: adamw_torch
lr_scheduler_type: linear
warmup_ratio: 0.1
learning_rate: 5.0e-05
gradient_accumulation_steps: 2
per_device_eval_batch_size: 4
per_device_train_batch_size: 4 # Change this depending on the context length of the model to keep a 500M GBS.
# SFT trainer config
max_steps: -1
num_train_epochs: 3
bf16: true
do_eval: false
eval_strategy: 'no'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: OpenR1-Qwen-7B-SFT
hub_strategy: every_save
log_level: info
logging_steps: 5
logging_strategy: steps
packing: true
output_dir: data/OpenR1-Qwen-7B-SFT
overwrite_output_dir: true
push_to_hub: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 500
save_total_limit: 1
seed: 42


@ -0,0 +1,28 @@
# Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant>'}}{% endif %}"
dataset_name: open-r1/OpenR1-Math-220k
dataset_prompt_column: problem
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
# Generation arguments
max_completion_length: 2048
num_generations: 8
temperature: 0.7
top_p: 0.95
# Reward func arguments
reward_funcs:
- accuracy
reward_weights:
- 1.0
# Filtering arguments. Samples with a pass rate outside the interval `pass_rate_min < x < pass_rate_max` will be filtered.
pass_rate_min: 0.2
pass_rate_max: 0.8


@ -0,0 +1,28 @@
# Model arguments
model_name_or_path: open-r1/R1-Distill-Qwen-Math-7B
model_revision: v03.00-step-000008190
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works
dataset_name: open-r1/DAPO-Math-17k-Processed
dataset_config: all
dataset_split: train
# Generation arguments
max_completion_length: 32000
num_generations: 8
temperature: 1.0
# Reward func arguments
reward_funcs:
- accuracy
reward_weights:
- 1.0
# Filtering arguments. Samples with mean reward outside of low / high will be filtered
pass_rate_min: 0.1
pass_rate_max: 0.6
output_dataset_name: open-r1/DAPO-Math-17k-Processed-R1-Distill-Qwen-Math-7B-v03.00-step-000008190-filter


@ -0,0 +1,26 @@
# Model arguments
model_name_or_path: open-r1/R1-Distill-Qwen-Math-7B-Merges
model_revision: v00.00-step-000003660_v01.00-step-000002600_weights-0.50-0.50
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works
dataset_name: open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled
dataset_prompt_column: problem
# Generation arguments
max_completion_length: 16000
num_generations: 8
temperature: 0.7
# Reward func arguments
reward_funcs:
- binary_code
reward_weights:
- 1.0
e2b_router_url: ip-10-53-85-92:8000
# Filtering arguments. Samples with mean reward outside of low / high will be filtered
pass_rate_min: 0.1
pass_rate_max: 0.6

85
scripts/benchmark_e2b.py Normal file

@ -0,0 +1,85 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Benchmark script for the code_reward function with E2B.
This script measures the performance of the code_reward function with varying numbers
of samples and parallelization levels.
Each sample is a CodeForces problem with a gold standard solution that is executed against a set of public test cases.
"""
from datasets import load_dataset
import time
from tqdm.auto import tqdm
from dotenv import load_dotenv
load_dotenv()
from open_r1.rewards import code_reward
def benchmark_code_reward(example):
    # Score a single example's gold solution and record how long the reward call took
    start_time = time.time()
    test_completions = [[{"content": example["gold_standard_solution"]}]]
    reward_kwargs = {"verification_info": [example["verification_info"]]}
    rewards = code_reward(test_completions, **reward_kwargs)
    end_time = time.time()
    example["test_reward"] = rewards[0]
    example["reward_time"] = end_time - start_time
    return example
if __name__ == "__main__":
    parallel_dict = {
        16: [1, 4, 16],
        64: [4, 16, 64],
        256: [16, 64, 96],  # cap at 96 as PRO account is limited to 100
    }
    # Store results for table formatting
    results = []
    for num_samples in tqdm([16, 64, 256], desc="Benchmarking samples"):
        for num_parallel in parallel_dict[num_samples]:
            code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated")
            code_dataset = code_dataset["train"].shuffle(seed=42).select(range(num_samples))
            test_completions = [[{"content": example["gold_standard_solution"]}] for example in code_dataset]
            reward_kwargs = {"verification_info": [example["verification_info"] for example in code_dataset]}
            start_time = time.time()
            rewards = code_reward(test_completions, num_parallel=num_parallel, **reward_kwargs)
            execution_time = time.time() - start_time
            # Calculate some statistics about rewards
            mean_reward = sum(rewards) / len(rewards)
            min_reward = min(rewards)
            max_reward = max(rewards)
            # Store results
            results.append({
                "num_samples": num_samples,
                "num_parallel": num_parallel,
                "execution_time": execution_time,
                "mean_reward": mean_reward,
                "min_reward": min_reward,
                "max_reward": max_reward
            })
    print("\n## Benchmark Results\n")
    print("| Sample Size | Parallelization | Execution Time (s) | Mean Reward | Min Reward | Max Reward |")
    print("|:-----------:|:---------------:|------------------:|:-----------:|:-----------:|:-----------:|")
    for result in results:
        print(f"| {result['num_samples']:^11} | {result['num_parallel']:^15} | {result['execution_time']:17.2f} | {result['mean_reward']:^11.4f} | {result['min_reward']:^11.4f} | {result['max_reward']:^11.4f} |")


@ -15,7 +15,7 @@
# limitations under the License.
"""
This script is used to decontaminate a dataset by checking for n-gram overlap with other datasets.
It uses the same approach presented in https://arxiv.org/abs/2501.19393,
It uses the same approach presented in https://huggingface.co/papers/2501.19393,
as found in: https://github.com/simplescaling/s1/blob/main/data/decontaminate_util.py
Usage:

161
scripts/e2b_router.py Normal file

@ -0,0 +1,161 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
from typing import Optional
import uvicorn
from dotenv import load_dotenv
from e2b_code_interpreter import AsyncSandbox
from e2b_code_interpreter.models import Execution
from fastapi import FastAPI, Request
from pydantic import BaseModel, ConfigDict
load_dotenv()
class BatchRequest(BaseModel):
    """
    BatchRequest is a data model representing a batch processing request.
    Attributes:
        scripts (list[str]): The source code of each script to be executed.
        languages (list[str]): The programming languages for each script in the list.
        timeout (int): The maximum allowed execution time for each script in seconds.
        request_timeout (int): The maximum allowed time for the entire batch request in seconds.
    """
    scripts: list[str]
    languages: list[str]
    timeout: int
    request_timeout: int
class ScriptResult(BaseModel):
    """
    ScriptResult is a Pydantic model that represents the result of a script execution.
    Attributes:
        execution (Optional[Execution]): An optional instance of the `Execution` class
            that contains details about the script's execution, such as status, output,
            or any other relevant metadata.
        exception_str (Optional[str]): An optional string that captures the exception
            message or details if an error occurred during the script's execution.
        model_config (ConfigDict): A configuration dictionary that allows arbitrary
            types to be used within the Pydantic model. This is necessary to support
            custom types like `Execution` within the model.
    """
    execution: Optional[Execution]
    exception_str: Optional[str]
    # required to allow arbitrary types in pydantic models such as Execution
    model_config = ConfigDict(arbitrary_types_allowed=True)
def create_app(args):
    """
    Creates and configures a FastAPI application instance.
    Args:
        args: An object containing configuration parameters for the application.
            - max_num_sandboxes (int): The maximum number of concurrent sandboxes allowed.
    Returns:
        FastAPI: A configured FastAPI application instance.
    The application includes the following endpoints:
        1. GET /health:
            - Returns the health status of the application.
            - Response: {"status": "ok"}
        2. POST /execute_batch:
            - Executes a batch of scripts in an isolated sandbox environment.
            - Request Body: BatchRequest object containing:
                - languages (list[str]): The programming languages of the scripts (python or javascript).
                - timeout (int): The maximum execution time for each script.
                - request_timeout (int): The timeout for the request itself.
                - scripts (list[str]): A list of scripts to execute.
            - Response: A list of ScriptResult objects for each script, containing:
                - execution: The result of the script execution.
                - exception_str: Any exception encountered during execution.
    Notes:
        - A semaphore is used to limit the number of concurrent sandboxes.
        - Each script execution is wrapped in a timeout to prevent hanging.
        - Sandboxes are cleaned up after execution, even in case of errors.
    """
    app = FastAPI()
    # Instantiate semaphore and attach it to app state
    app.state.sandbox_semaphore = asyncio.Semaphore(args.max_num_sandboxes)

    @app.get("/health")
    async def health():
        return {"status": "ok"}

    @app.post("/execute_batch")
    async def execute_batch(batch: BatchRequest, request: Request):
        semaphore = request.app.state.sandbox_semaphore
        languages = batch.languages
        timeout = batch.timeout
        request_timeout = batch.request_timeout
        asyncio_timeout = batch.timeout + 1

        async def run_script(script: str, language: str) -> ScriptResult:
            async with semaphore:
                try:
                    sandbox = await AsyncSandbox.create(
                        timeout=timeout,
                        request_timeout=request_timeout,
                    )
                    execution = await asyncio.wait_for(
                        sandbox.run_code(script, language=language),
                        timeout=asyncio_timeout,
                    )
                    return ScriptResult(execution=execution, exception_str=None)
                except Exception as e:
                    return ScriptResult(execution=None, exception_str=str(e))
                finally:
                    try:
                        await sandbox.kill()
                    except Exception:
                        pass

        tasks = [run_script(script, lang) for script, lang in zip(batch.scripts, batch.languages)]
        return await asyncio.gather(*tasks)

    return app
def parse_args():
    """
    Parse command-line arguments for the e2b_router script.
    Arguments:
        --host (str): The hostname or IP address to bind the server to. Defaults to "0.0.0.0" (binds to all interfaces).
        --port (int): The port number on which the server will listen. Defaults to 8000.
        --max_num_sandboxes (int): The maximum number of sandboxes that can be created or managed simultaneously. Defaults to 20.
    Returns:
        argparse.Namespace: Parsed command-line arguments as an object.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--max_num_sandboxes", type=int, default=20)
    return parser.parse_args()
if __name__ == "__main__":
    args = parse_args()
    app = create_app(args)
    uvicorn.run(app, host=args.host, port=args.port)
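For reviewers who want to exercise the router by hand, the sketch below builds a request body with the four `BatchRequest` fields and posts it to `/execute_batch` using only the standard library. The helper names `build_batch_payload` and `execute_batch`, the default timeouts of 30 seconds, and the `http://localhost:8000` URL are assumptions made for illustration; they are not part of the script, and the shape of the decoded response is whatever the endpoint returns.

```python
import json
import urllib.request


def build_batch_payload(scripts, language="python", timeout=30, request_timeout=30):
    """Build a JSON-serialisable body with the BatchRequest fields (hypothetical helper)."""
    return {
        "scripts": list(scripts),
        "languages": [language] * len(scripts),  # one language entry per script
        "timeout": timeout,
        "request_timeout": request_timeout,
    }


def execute_batch(router_url, payload, http_timeout=120):
    """POST the payload to <router_url>/execute_batch and return the decoded JSON (hypothetical helper)."""
    request = urllib.request.Request(
        f"{router_url}/execute_batch",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=http_timeout) as response:
        return json.loads(response.read().decode("utf-8"))


payload = build_batch_payload(["print(1 + 1)", "print('hello')"])
print(json.dumps(payload, indent=2))
# With a router running locally, one could then call (assumed URL):
# results = execute_batch("http://localhost:8000", payload)
```

Only the payload is built when the snippet runs; the network call is left commented out so the example does not depend on a live router.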

173
scripts/morph_router.py Normal file

@ -0,0 +1,173 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import os
from typing import List, Optional
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from pydantic import BaseModel, ConfigDict
load_dotenv()
class BatchRequest(BaseModel):
    """
    BatchRequest is a data model representing a batch processing request.
    Attributes:
        scripts (List[str]): The source code of each script to be executed.
        languages (List[str]): The programming languages for each script in the list.
        timeout (int): The maximum allowed execution time for each script in seconds.
        request_timeout (int): The maximum allowed time for the entire batch request in seconds.
    """
    scripts: List[str]
    languages: List[str]
    timeout: int
    request_timeout: int
class ScriptResult(BaseModel):
    """
    ScriptResult is a Pydantic model that represents the result of a script execution.
    Attributes:
        text (Optional[str]): The output text from the script execution.
        exception_str (Optional[str]): An optional string that captures the exception
            message or details if an error occurred during the script's execution.
        model_config (ConfigDict): A configuration dictionary that allows arbitrary
            types to be used within the Pydantic model.
    """
    text: Optional[str]
    exception_str: Optional[str]
    model_config = ConfigDict(arbitrary_types_allowed=True)
def create_app(args):
    """
    Creates and configures a FastAPI application instance for the MorphCloud router.
    Args:
        args: An object containing configuration parameters for the application.
            - max_num_sandboxes (int): The maximum number of concurrent sandboxes allowed.
            - api_key (str): The MorphCloud API key to use.
    Returns:
        FastAPI: A configured FastAPI application instance.
    """
    app = FastAPI()
    from morphcloud.api import MorphCloudClient
    from morphcloud.sandbox import Sandbox
    app.state.client = MorphCloudClient(api_key=args.api_key)
    app.state.Sandbox = Sandbox
    app.state.sandbox_semaphore = asyncio.Semaphore(args.max_num_sandboxes)

    @app.get("/health")
    async def health():
        return {"status": "ok"}

    @app.post("/execute_batch")
    async def execute_batch(batch: BatchRequest, request: Request):
        semaphore = request.app.state.sandbox_semaphore
        client = request.app.state.client
        Sandbox = request.app.state.Sandbox
        languages = batch.languages
        timeout = batch.timeout
        request_timeout = batch.request_timeout
        asyncio_timeout = batch.timeout + 1

        async def run_script(script: str, language: str) -> ScriptResult:
            sandbox = None
            sandbox_id = "unknown"
            async with semaphore:
                try:
                    sandbox = await asyncio.to_thread(
                        Sandbox.new,
                        client=client,
                        ttl_seconds=timeout
                    )
                    sandbox_id = getattr(sandbox, 'id', None) or getattr(sandbox._instance, 'id', 'unknown')
                    execution = await asyncio.wait_for(
                        asyncio.to_thread(
                            sandbox.run_code,
                            script,
                            language=language,
                            timeout=timeout * 1000
                        ),
                        timeout=asyncio_timeout,
                    )
                    if hasattr(execution, 'text') and execution.text:
                        return ScriptResult(text=execution.text, exception_str=None)
                    elif hasattr(execution, 'stdout') and execution.stdout:
                        return ScriptResult(text=execution.stdout, exception_str=None)
                    else:
                        return ScriptResult(text="", exception_str="No output from execution")
                except Exception as e:
                    return ScriptResult(text=None, exception_str=str(e))
                finally:
                    if sandbox:
                        try:
                            await asyncio.to_thread(sandbox.close)
                            await asyncio.to_thread(sandbox.shutdown)
                        except Exception:
                            pass

        tasks = [run_script(script, lang) for script, lang in zip(batch.scripts, batch.languages)]
        return await asyncio.gather(*tasks)

    return app
def parse_args():
    """
    Parse command-line arguments for the morph_router script.
    Arguments:
        --host (str): The hostname or IP address to bind the server to. Defaults to "0.0.0.0".
        --port (int): The port number on which the server will listen. Defaults to 8001.
        --max_num_sandboxes (int): The maximum number of sandboxes that can be created simultaneously. Defaults to 20.
        --api_key (str): The MorphCloud API key. If not provided, it will be read from the MORPH_API_KEY environment variable.
    Returns:
        argparse.Namespace: Parsed command-line arguments as an object.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--max_num_sandboxes", type=int, default=20)
    parser.add_argument("--api_key", default=os.getenv("MORPH_API_KEY"))
    args = parser.parse_args()
    if not args.api_key:
        raise ValueError("MorphCloud API key not provided. Please set MORPH_API_KEY environment variable or use --api_key.")
    return args
if __name__ == "__main__":
    args = parse_args()
    app = create_app(args)
    print(f"Starting MorphCloud Router on {args.host}:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port)
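Both routers rely on the same pattern: one `asyncio.Semaphore` caps how many sandboxes exist at once, while `asyncio.gather` returns results in the order the scripts were submitted. The sketch below isolates that pattern with a fake job in place of sandbox creation and execution. `run_bounded`, `job` and the sleep duration are illustrative names and values, not code from the routers.

```python
import asyncio


async def run_bounded(num_tasks, max_concurrent, work_seconds=0.01):
    """Run num_tasks fake jobs, never letting more than max_concurrent run at once."""
    semaphore = asyncio.Semaphore(max_concurrent)
    running = 0  # jobs currently holding the semaphore
    peak = 0  # highest concurrency observed

    async def job(index):
        nonlocal running, peak
        async with semaphore:
            running += 1
            peak = max(peak, running)
            try:
                # Stand-in for creating a sandbox and running a script in it.
                await asyncio.sleep(work_seconds)
                return index
            finally:
                running -= 1

    # gather preserves submission order regardless of completion order.
    results = await asyncio.gather(*(job(i) for i in range(num_tasks)))
    return results, peak


results, peak = asyncio.run(run_bounded(num_tasks=10, max_concurrent=3))
print(results, peak)
```

The observed peak never exceeds `max_concurrent`, which is the property `--max_num_sandboxes` is meant to provide.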


@ -0,0 +1,36 @@
# Pass rate filtering
We provide support for filtering datasets by generating completions and computing the pass rate on verifiable tasks.
See `scripts/pass_rate_filtering/compute_pass_rate.py` and `scripts/pass_rate_filtering/launch_filtering.sh` (hardcoded for DAPO at the moment).
By default, the script processes the dataset in chunks. The chunks can be merged with the following snippet (example for DAPO):
```python
from datasets import load_dataset, concatenate_datasets
name = "open-r1/DAPO-Math-17k-Processed-R1-Distill-Qwen-Math-7B-Merges-v00.02-v01.02-0.3-0.7-filter"
gen_datasets = []
filt_datasets = []
for start in range(0, 17400, 200):
    end = start + 200
    if start == 17200:
        end = 17398
    gen_config_name = f"gen-{start}-{end}"
    gen_dataset = load_dataset(name, gen_config_name, revision="gen", split="train")
    gen_datasets.append(gen_dataset)
    filt_config_name = f"filt-0.1-0.6-{start}-{end}"
    filt_dataset = load_dataset(name, filt_config_name, revision="pass_rate", split="train")
    filt_datasets.append(filt_dataset)
gen_dataset = concatenate_datasets(gen_datasets)
gen_dataset.push_to_hub(name, config_name="gen", split="train")
print(gen_dataset)
filt_dataset = concatenate_datasets(filt_datasets)
filt_dataset.push_to_hub(name, config_name="default", split="train")
print(filt_dataset)
```
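The filtering rule itself is simple: a prompt is kept only if the mean reward over its generations lies strictly between `pass_rate_min` and `pass_rate_max`, with missing rewards ignored. Below is a minimal pure-Python sketch of that rule. The helper names `mean_ignoring_nan` and `keep_example` are hypothetical, the script itself uses `torch.nanmean`, and the defaults of 0.1 and 0.6 mirror the thresholds in the `filt-0.1-0.6` config names rather than the script's own defaults.

```python
import math


def mean_ignoring_nan(rewards):
    """Mean over the rewards that are not NaN; NaN if no reward is valid."""
    valid = [r for r in rewards if not math.isnan(r)]
    return sum(valid) / len(valid) if valid else math.nan


def keep_example(rewards, pass_rate_min=0.1, pass_rate_max=0.6):
    """Keep a prompt only if its mean reward lies strictly inside the interval."""
    mean_reward = mean_ignoring_nan(rewards)
    # Comparisons with NaN are False, so prompts without any valid reward are dropped.
    return pass_rate_min < mean_reward < pass_rate_max


print(keep_example([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]))  # one pass out of eight
print(keep_example([0.0] * 8))  # never solved
print(keep_example([1.0] * 8))  # always solved
```

Prompts that are never solved or always solved carry no learning signal for GRPO-style training, which is why both ends of the interval are excluded.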


@ -0,0 +1,205 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# example usage python scripts/filter_dataset.py --config recipes/dataset_filtering/config_demo.yaml
import logging
from dataclasses import dataclass
from typing import Optional
import torch
import sys
import datasets
import transformers
from datasets import load_dataset
from transformers import set_seed
from open_r1.configs import GRPOConfig, GRPOScriptArguments
from open_r1.rewards import get_reward_funcs
from open_r1.utils import get_tokenizer
from trl import ModelConfig, TrlParser
from trl.data_utils import apply_chat_template
from vllm import LLM, SamplingParams
logger = logging.getLogger(__name__)
@dataclass
class PassRateScriptArguments(GRPOScriptArguments):
    # we can be lazy and just use the same script args as GRPO
    output_dataset_name: Optional[str] = None
    pass_rate_min: float = 0.1
    pass_rate_max: float = 0.9
    dataset_start_index: Optional[int] = None
    dataset_end_index: Optional[int] = None
    dataset_split: str = "train"
def main(script_args, training_args, model_args):
# Set seed for reproducibility
set_seed(training_args.seed)
###############
# Setup logging
###############
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Training parameters {training_args}")
# Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_split)
if script_args.dataset_start_index is not None and script_args.dataset_end_index is not None:
dataset = dataset.select(range(script_args.dataset_start_index, script_args.dataset_end_index))
# Get reward functions from the registry
reward_funcs = get_reward_funcs(script_args)
# Format into conversation
def make_conversation(example, prompt_column: str = script_args.dataset_prompt_column):
example["prompt_backup"] = example[prompt_column]
prompt = []
if training_args.system_prompt is not None:
prompt.append({"role": "system", "content": training_args.system_prompt})
if prompt_column not in example:
raise ValueError(f"Dataset Question Field Error: {prompt_column} is not supported.")
prompt.append({"role": "user", "content": example[prompt_column]})
return {"prompt": prompt}
dataset = dataset.map(make_conversation)
tokenizer = get_tokenizer(model_args, training_args)
if "messages" in dataset.column_names:
dataset = dataset.remove_columns("messages")
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
llm = LLM(
model=model_args.model_name_or_path,
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
)
sampling_params=SamplingParams(
temperature=training_args.temperature,
top_p=training_args.top_p,
top_k=training_args.top_k,
n=training_args.num_generations,
max_tokens=training_args.max_completion_length,
)
def batch_score(examples):
prompts = examples["prompt"]
outputs = llm.generate(
prompts,
sampling_params=sampling_params,
use_tqdm=False,
)
repeated_prompts = []
reward_completions = []
grouped_completions = []
for output in outputs:
prompt = output.prompt
group = []
for completion in output.outputs:
text = completion.text
group.append(text)
message = [{"role": "assistant", "content": text}]
repeated_prompts.append(prompt)
reward_completions.append(message)
grouped_completions.append(group)
def repeat_each_element_k_times(list_to_repeat: list, k: int) -> list:
return [element for item in list_to_repeat for element in [item] * k]
rewards_per_func = torch.zeros(len(repeated_prompts), len(reward_funcs))
for i, reward_func in enumerate(reward_funcs):
keys = [key for key in examples.data.keys() if key not in ["prompt", "completion"]]
reward_kwargs = {key: repeat_each_element_k_times(examples[key], training_args.num_generations) for key in keys}
output_reward_func = reward_func(prompts=repeated_prompts, completions=reward_completions, **reward_kwargs)
# Convert None values to NaN
output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32)
reshaped_rewards = rewards_per_func.view(-1, training_args.num_generations)
examples["pass_rate_generations"] = grouped_completions
examples["pass_rate_rewards"] = reshaped_rewards.tolist()
return examples
dataset = dataset.map(batch_score, batched=True, batch_size=64)
# we need to restore the prompt for the final dataset
def restore_prompt(example):
example["prompt"] = example["prompt_backup"]
return example
dataset = dataset.map(restore_prompt)
dataset = dataset.remove_columns("prompt_backup")
if script_args.output_dataset_name is not None:
output_dataset_name = script_args.output_dataset_name
else:
model_name = model_args.model_name_or_path
if "/" in model_name:
model_name = model_name.split("/")[-1]
model_revision = model_args.model_revision
output_dataset_name = f"{script_args.dataset_name}-{model_name}-{model_revision}-gen"
config_name="default"
filtered_config_name = f"filt-{script_args.pass_rate_min}-{script_args.pass_rate_max}"
if script_args.dataset_start_index is not None and script_args.dataset_end_index is not None:
config_name = f"gen-{script_args.dataset_start_index}-{script_args.dataset_end_index}"
filtered_config_name = f"{filtered_config_name}-{script_args.dataset_start_index}-{script_args.dataset_end_index}"
dataset.push_to_hub(output_dataset_name, config_name=config_name, revision="gen")
def filter_func(example):
rewards = example["pass_rate_rewards"]
# mean of the rewards, ignoring NaN entries (None rewards were converted to NaN in batch_score)
mean_reward = torch.nanmean(torch.tensor(rewards, dtype=torch.float32))
return script_args.pass_rate_min < mean_reward < script_args.pass_rate_max
logger.info(f"Filtering dataset with low reward threshold {script_args.pass_rate_min} and high reward threshold {script_args.pass_rate_max}")
logger.info(f"Dataset size before filtering: {dataset}")
dataset = dataset.filter(filter_func)
logger.info(f"Dataset size after filtering: {dataset}")
dataset.push_to_hub(output_dataset_name, config_name=filtered_config_name, revision="pass_rate")
if __name__ == "__main__":
parser = TrlParser((PassRateScriptArguments, GRPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)
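The filtering step above keeps an example only when its mean reward lies strictly between `pass_rate_min` and `pass_rate_max`, with missing rewards ignored. A minimal torch-free sketch of that rule follows; the helper names `nan_mean` and `keep_example` are hypothetical and are not part of the script.

```python
import math


def nan_mean(values):
    """Mean of the entries that are not NaN; NaN if nothing remains."""
    kept = [v for v in values if not math.isnan(v)]
    return sum(kept) / len(kept) if kept else math.nan


def keep_example(rewards, pass_rate_min, pass_rate_max):
    """Mirror of filter_func: strict bounds on the NaN-aware mean reward."""
    # None marks a completion for which a reward function returned no score.
    cleaned = [math.nan if r is None else float(r) for r in rewards]
    mean_reward = nan_mean(cleaned)
    # Any comparison with NaN is False, so an all-None example is dropped.
    return pass_rate_min < mean_reward < pass_rate_max


print(keep_example([1.0, 0.0, None, 1.0], 0.1, 0.9))
print(keep_example([1.0, 1.0, 1.0, 1.0], 0.1, 0.9))
```

Because both bounds are strict, prompts that are always solved or never solved are filtered out whenever the thresholds sit inside the open interval (0, 1).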


@ -0,0 +1,15 @@
# a bash for loop from 0 to 17,400 in chunks of 200
for i in {0..17000..200}
do
START=$i
END=$((i + 200))
echo "Processing chunk from $START to $END"
# Submit the job to SLURM
sbatch slurm/compute_pass_rate.slurm recipes/dataset_filtering/filter_dapo.yaml $START $END
done
sbatch slurm/compute_pass_rate.slurm recipes/dataset_filtering/filter_dapo.yaml 17200 17398
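The loop plus the trailing `sbatch` call enumerate fixed-size index windows with a shorter final window. As a sketch, the same bounds can be derived from the dataset length instead of being hard-coded. The total of 17398 is an assumption read off the last command, and `chunk_ranges` is a hypothetical helper.

```python
def chunk_ranges(total, chunk_size):
    """Yield (start, end) pairs that cover 0..total in steps of chunk_size."""
    for start in range(0, total, chunk_size):
        # The last window is clipped to the dataset length.
        yield start, min(start + chunk_size, total)


ranges = list(chunk_ranges(17398, 200))
print(len(ranges), ranges[0], ranges[-1])
```

Each pair corresponds to the `$START $END` arguments passed to `slurm/compute_pass_rate.slurm`.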


@ -44,19 +44,21 @@ _deps = [
"accelerate==1.4.0",
"bitsandbytes>=0.43.0",
"datasets>=3.2.0",
"deepspeed==0.15.4",
"deepspeed==0.16.8",
"distilabel[vllm,ray,openai]>=1.5.2",
"e2b-code-interpreter>=1.0.5",
"einops>=0.8.0",
"flake8>=6.0.0",
"hf_transfer>=0.1.4",
"huggingface-hub[cli]>=0.19.2,<1.0",
"huggingface-hub[cli,hf_xet]>=0.30.2,<1.0",
"isort>=5.12.0",
"jieba", # Needed for Chinese language support
"langdetect", # Needed for LightEval's extended tasks
"latex2sympy2_extended>=1.0.6",
"liger_kernel==0.5.3",
"lighteval @ git+https://github.com/huggingface/lighteval.git@ed084813e0bd12d82a06d9f913291fdbee774905",
"liger-kernel>=0.5.10",
"lighteval @ git+https://github.com/huggingface/lighteval.git@d3da6b9bbf38104c8b5e1acc86f83541f9a502d1", # Critical bug fix for tokenizer revisions: https://github.com/huggingface/lighteval/pull/721
"math-verify==0.5.2", # Used for math verification in grpo
"morphcloud==0.1.67",
"packaging>=23.0",
"parameterized>=0.9.0",
"peft>=0.14.0",
@ -65,11 +67,13 @@ _deps = [
"ruff>=0.9.0",
"safetensors>=0.3.3",
"sentencepiece>=0.1.99",
"torch==2.5.1",
"transformers==4.50.0",
"trl==0.16.0",
"vllm==0.7.2",
"torch==2.6.0",
"transformers==4.52.3",
"trl[vllm]==0.18.0",
"wandb>=0.19.1",
"async-lru>=2.0.5",
"aiofiles>=24.1.0",
"pandas>=2.2.3",
]
# this is a lookup table with items like:
@ -86,12 +90,12 @@ def deps_list(*pkgs):
extras = {}
extras["tests"] = deps_list("pytest", "parameterized", "math-verify")
extras["tests"] = deps_list("pytest", "parameterized", "math-verify", "jieba")
extras["torch"] = deps_list("torch")
extras["quality"] = deps_list("ruff", "isort", "flake8")
extras["code"] = deps_list("e2b-code-interpreter", "python-dotenv")
extras["code"] = deps_list("e2b-code-interpreter", "python-dotenv", "morphcloud", "jieba", "pandas", "aiofiles")
extras["eval"] = deps_list("lighteval", "math-verify")
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"] + extras["code"]
# core dependencies shared across the whole project - keep this to a bare minimum :)
install_requires = [
@ -105,13 +109,14 @@ install_requires = [
deps["langdetect"],
deps["latex2sympy2_extended"],
deps["math-verify"],
deps["liger_kernel"],
deps["liger-kernel"],
deps["packaging"], # utilities from PyPA to e.g., compare versions
deps["safetensors"],
deps["sentencepiece"],
deps["transformers"],
deps["trl"],
deps["wandb"],
deps["async-lru"],
]
setup(


@ -0,0 +1,20 @@
#!/bin/bash
#SBATCH --job-name=open-r1-compute-pass-rate
#SBATCH --partition=hopper-prod
#SBATCH --qos=normal
#SBATCH --nodes=1
#SBATCH --gpus-per-node=1
#SBATCH --output=./logs/%x-%j.out
#SBATCH --error=./logs/%x-%j.err
#SBATCH --time=01-00:00:00
#SBATCH --requeue
# example usage: sbatch slurm/compute_pass_rate.slurm recipes/dataset_filtering/filter_dapo.yaml 0 500
set -x -e
source ~/.bashrc
source openr1/bin/activate
python scripts/pass_rate_filtering/compute_pass_rate.py --config $1 --dataset_start_index $2 --dataset_end_index $3

slurm/e2b_router.slurm Normal file

@ -0,0 +1,17 @@
#!/bin/bash
#SBATCH --partition=hopper-cpu
#SBATCH --mem=16g
#SBATCH --cpus-per-task=16
#SBATCH --output=/fsx/open-r1/logs/e2b_router/%x-%j.out
#SBATCH --error=/fsx/open-r1/logs/e2b_router/%x-%j.err
#SBATCH --requeue
#SBATCH --time=7-00:00:00
echo "Starting job"
set -x -e
source ~/.bashrc
source openr1/bin/activate
srun python scripts/e2b_router.py


@ -3,13 +3,22 @@
#SBATCH --gres=gpu:8
#SBATCH --partition=hopper-prod
#SBATCH --output=./logs/%x-%j.out
#SBATCH --err=./logs/%x-%j.err
#SBATCH --error=./logs/%x-%j.err
#SBATCH --requeue
#SBATCH --time=1-00:00:00
# Specific configuration optimized for the Hugging Face Compute Cluster
# Be ye warned this may not work on other clusters!
module load cuda/12.4
# Refresh Weka on h4 cache
echo "Refreshing Weka filesystem..."
find -L /fsx/h4/ -type f | xargs -d '\n' -r -n512 -P64 weka fs tier fetch
# Needed for vLLM
export VLLM_WORKER_MULTIPROC_METHOD=spawn
set -x -e
source ~/.bashrc
@ -25,14 +34,11 @@ MODEL_REVISION=$4
# $7 is reserved for system_prompt, see line 51
NUM_GPUS=$(nvidia-smi -L | wc -l)
# Set Whether to use tensor parallelism or data parallelism
# Use TP to shard model across GPUs
if [ "$TENSOR_PARALLEL" = "True" ]; then
# use TP to shard model across NUM_GPUS
export VLLM_WORKER_MULTIPROC_METHOD=spawn
# FIXME: lighteval now requires us to manually pass the generation params
MODEL_ARGS="pretrained=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="model_name=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
else
MODEL_ARGS="pretrained=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
MODEL_ARGS="model_name=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
fi
LM_EVAL_REPO_ID="open-r1/open-r1-eval-leaderboard"
@ -41,27 +47,14 @@ DETAILS_REPO_ID="open-r1/details-$MODEL_NAME"
OUTPUT_DIR="eval_results/$MODEL_ID/$MODEL_REVISION/$TASK_NAME"
# We need this flag since we run this script from training jobs that use DeepSpeed and the env vars get propagated which causes errors during evaluation
ACCELERATE_USE_DEEPSPEED=false
# Enable fast downloads
HF_HUB_ENABLE_HF_TRANSFER=1
echo "Running lighteval script ..."
echo "Eval results will be saved to $OUTPUT_DIR"
# Check if "custom" is a substring of TASKS
if [[ $TASKS == *"custom"* ]]; then
echo "Custom task detected. Running custom task evaluation script ..."
lighteval vllm "$MODEL_ARGS" $TASKS \
--custom-tasks "src/open_r1/evaluate.py" \
lighteval vllm "$MODEL_ARGS" $TASKS \
--use-chat-template \
--output-dir $OUTPUT_DIR \
--save-details \
${7:+--system-prompt "$7"}
else
lighteval vllm "$MODEL_ARGS" $TASKS \
--use-chat-template \
--output-dir $OUTPUT_DIR \
--save-details \
${7:+--system-prompt "$7"}
fi
${7:+--system-prompt "$(echo "$7" | base64 --decode)"}
OUTPUT_FILEPATHS=$(find $OUTPUT_DIR/results/ -type f \( -name "*.json" \))
for filepath in $OUTPUT_FILEPATHS; do


@ -6,7 +6,7 @@
#SBATCH --exclusive
#SBATCH --gpus-per-node=8
#SBATCH --output=./logs/%x-%j.out
#SBATCH --err=./logs/%x-%j.err
#SBATCH --error=./logs/%x-%j.err
#SBATCH --time=04-00:00:00
# Parse command line arguments

slurm/morph_router.slurm Normal file

@ -0,0 +1,18 @@
#!/bin/bash
#SBATCH --partition=hopper-cpu
#SBATCH --mem=16g
#SBATCH --cpus-per-task=16
#SBATCH --output=/fsx/open-r1/logs/morph_router/%x-%j.out
#SBATCH --error=/fsx/open-r1/logs/morph_router/%x-%j.err
#SBATCH --requeue
#SBATCH --time=7-00:00:00
echo "Starting job"
set -x -e
source ~/.bashrc
source openr1/bin/activate
srun python scripts/morph_router.py --port 8001 --max_num_sandboxes 20


@ -17,10 +17,17 @@ slurm/piston/launch_piston_workers.sh 1
```
2. Assuming it's running on `ip-10-53-86-146:1234`, send the package install request:
For IOI:
```bash
curl -X POST http://ip-10-53-86-146:1234/api/v2/packages -H "Content-Type: application/json" -d '{"language": "cms_ioi", "version": "1.0.0"}'
```
For CodeForces:
```bash
curl -X POST http://ip-10-53-86-146:1234/api/v2/packages -H "Content-Type: application/json" -d '{"language": "codeforces", "version": "1.0.0"}'
```
3. You can now launch more workers and due to the shared mounted packages directory, they should already have the package installed.
To have the main script find the workers automatically, you can export the following environment variable:
@ -32,6 +39,7 @@ Alternatively you can add `PISTON_ENDPOINTS=slurm` to your .env file.
You can also change `PISTON_MAX_REQUESTS_PER_ENDPOINT`, which limits how many simultaneous requests each worker handles (1 by default). Keep in mind that this limit is local to each process: in distributed setups there is no global limit, so a worker can occasionally be overwhelmed when several processes hit it at the same time.
If you would like to adapt the code to run without piston, please see the [ioi repo](https://github.com/huggingface/ioi).
For CodeForces, you should implement the [`run`](https://github.com/guipenedo/piston/blob/master/packages/codeforces/1.0.0/run) and [`compile`](https://github.com/guipenedo/piston/blob/master/packages/codeforces/1.0.0/compile) scripts.
# Piston workers (local docker)
This will launch a single worker in a docker container. Consider launching multiple workers for better scalability. Replace 2000 with the port you want to use.
@ -57,10 +65,16 @@ docker run -d \
```
Install the package:
For IOI:
```bash
curl -X POST http://localhost:2000/api/v2/packages -H "Content-Type: application/json" -d '{"language": "cms_ioi", "version": "1.0.0"}'
```
For CodeForces:
```bash
curl -X POST http://localhost:2000/api/v2/packages -H "Content-Type: application/json" -d '{"language": "codeforces", "version": "1.0.0"}'
```
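If you prefer to script the install step, the same request can be built from Python. This is only a sketch using the standard library; `build_install_request` is a hypothetical helper, and the endpoint and payload are copied from the `curl` commands above.

```python
import json
import urllib.request


def build_install_request(endpoint, language, version="1.0.0"):
    """Build the POST request that asks a Piston worker to install a package."""
    payload = json.dumps({"language": language, "version": version}).encode("utf-8")
    return urllib.request.Request(
        f"{endpoint}/api/v2/packages",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_install_request("http://localhost:2000", "codeforces")
print(req.get_method(), req.full_url)
# To actually send it to a running worker: urllib.request.urlopen(req)
```

The request is only constructed here, not sent, so the snippet runs without a worker listening on the port.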
Remember to set `PISTON_ENDPOINTS`:
```bash
export PISTON_ENDPOINTS=http://localhost:2000/api/v2,http://localhost:2001/api/v2,http://localhost:2002/api/v2


@ -1,12 +1,27 @@
#!/bin/bash
#SBATCH --job-name=open-r1-sft
#SBATCH --job-name=open_r1
#SBATCH --ntasks-per-node=1
#SBATCH --exclusive
#SBATCH --gres=gpu:8
#SBATCH --partition=hopper-prod # Adjust this for your cluster
#SBATCH --output=./logs/%x-%j.out
#SBATCH --err=./logs/%x-%j.err
#SBATCH --error=./logs/%x-%j.err
#SBATCH --requeue
#SBATCH --time=3-00:00:00
if [[ "$*" == *"--help"* ]]; then
echo "Usage: sbatch slurm/train.slurm [options]"
echo "Options:"
echo " --model MODEL Model name"
echo " --task TASK Task name (e.g. sft, grpo)"
echo " --config SUFFIX Configuration suffix (e.g. demo, v00.00)"
echo " --accelerator CONFIG Accelerator configuration name (e.g. zero3)"
echo " --dp N Data parallelism for vLLM server (default: 1)"
echo " --tp N Tensor parallelism for vLLM server (default: 1)"
echo " --args \"ARGS\" Optional arguments to pass to the training script"
exit 0
fi
# Specific configuration optimized for the Hugging Face Compute Cluster
module load cuda/12.4
@ -14,15 +29,85 @@ set -x -e
source ~/.bashrc
source openr1/bin/activate
START_TIME=$(date +%s)
echo "START TIME: $(date)"
MODEL=$1
TASK=$2
CONFIG_SUFFIX=$3
ACCELERATOR=$4
OPTIONAL_ARGS=$5
# Refresh Weka on h4 cache
echo "Refreshing Weka filesystem..."
find -L /fsx/h4/ -type f | xargs -d '\n' -r -n512 -P64 weka fs tier fetch
# Default values
MODEL=""
TASK=""
CONFIG_SUFFIX=""
ACCELERATOR=""
DP=1
TP=1
OPTIONAL_ARGS=""
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL="$2"
shift 2
;;
--task)
TASK="$2"
shift 2
;;
--config)
CONFIG_SUFFIX="$2"
shift 2
;;
--accelerator)
ACCELERATOR="$2"
shift 2
;;
--dp)
DP="$2"
shift 2
;;
--tp)
TP="$2"
shift 2
;;
--args)
OPTIONAL_ARGS="$2"
shift 2
;;
*)
echo "Unknown option: $1"
echo "Use --help for usage information"
exit 1
;;
esac
done
# Validate required arguments
if [[ -z "$MODEL" || -z "$TASK" || -z "$CONFIG_SUFFIX" || -z "$ACCELERATOR" ]]; then
echo "Error: Missing required arguments"
echo "Run with --help for usage information"
exit 1
fi
CONFIG_FILE=recipes/$MODEL/$TASK/config_$CONFIG_SUFFIX.yaml
GRAD_ACC_STEPS=$(grep 'gradient_accumulation_steps' $CONFIG_FILE | awk '{print $2}')
# Split the string into individual arguments
IFS=' ' read -ra ARGS <<< "$OPTIONAL_ARGS"
# Loop through the arguments and find the one with "--gradient_accumulation_steps"
for arg in "${ARGS[@]}"; do
if [[ "$arg" == "--gradient_accumulation_steps="* ]]; then
# Extract the value after the equals sign
GRAD_ACC_STEPS="${arg#*=}"
break # Exit the loop once we find the desired argument
fi
done
echo "Gradient accumulation steps: $GRAD_ACC_STEPS"
MODEL=$(grep 'model_name_or_path:' $CONFIG_FILE | awk '{print $2}')
REVISION=$(grep 'model_revision:' $CONFIG_FILE | head -n 1 | awk '{print $2}')
@ -43,10 +128,9 @@ fi
if [[ "$USE_VLLM" == "true" ]]; then
TRAIN_NODES=("${NODELIST[@]:0:$((NUM_NODES - 1))}")
VLLM_NODE=${NODELIST[-1]} # Last node
TP=$(python scripts/get_tensor_parallel_size.py --model_name $MODEL --revision $REVISION --default_tp $GPUS_PER_NODE)
WORLD_SIZE=$((WORLD_SIZE - GPUS_PER_NODE))
NUM_NODES=$((NUM_NODES - 1))
srun --nodes=1 --ntasks=1 --nodelist=$VLLM_NODE trl vllm-serve --model $MODEL --revision $REVISION --tensor_parallel_size $TP &
srun --nodes=1 --ntasks=1 --nodelist=$VLLM_NODE trl vllm-serve --model $MODEL --revision $REVISION --tensor_parallel_size $TP --data_parallel_size $DP &
OPTIONAL_ARGS="$OPTIONAL_ARGS --vllm_server_host=$VLLM_NODE"
fi
@ -63,7 +147,7 @@ export CMD=" \
src/open_r1/$TASK.py --config $CONFIG_FILE $OPTIONAL_ARGS
"
export LAUNCHER="HF_HUB_ENABLE_HF_TRANSFER=1 ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \
export LAUNCHER="ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \
--config_file recipes/accelerate_configs/$ACCELERATOR.yaml \
--gradient_accumulation_steps $GRAD_ACC_STEPS \
--num_machines $NUM_NODES \
@ -73,19 +157,26 @@ export LAUNCHER="HF_HUB_ENABLE_HF_TRANSFER=1 ACCELERATE_LOG_LEVEL=info TRANSFORM
--machine_rank $SLURM_PROCID \
--rdzv_backend=c10d \
--max_restarts 1 \
--role \$(hostname -s): \
--tee 3 \
"
# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
NODELIST=$(IFS=,; echo "${TRAIN_NODES[*]}")
SRUN_ARGS=" \
--wait=60 \
--kill-on-bad-exit=1 \
--nodes=$NUM_NODES \
--ntasks=$NUM_NODES \
--nodelist=$TRAIN_NODES
--nodelist=$NODELIST
"
clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD" 2>&1
srun $SRUN_ARGS bash -c "$LAUNCHER $CMD" 2>&1
echo "END TIME: $(date)"
END_TIME=$(date +%s)
echo "END TIME: $(date)"
ELAPSED_SECONDS=$((END_TIME - START_TIME))
HOURS=$((ELAPSED_SECONDS / 3600))
MINUTES=$(( (ELAPSED_SECONDS % 3600) / 60 ))
# SECONDS is a special Bash variable (seconds since shell start), so use a different name
SECS=$((ELAPSED_SECONDS % 60))
echo "TOTAL JOB TIME: ${HOURS}h ${MINUTES}m ${SECS}s (${ELAPSED_SECONDS} seconds)"
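The script reads `gradient_accumulation_steps` from the recipe file and then lets a `--gradient_accumulation_steps=N` entry in `--args` override it. A rough Python rendering of that precedence rule is sketched below; `resolve_grad_acc_steps` is a hypothetical name, and the value stays a string as it does in the shell script.

```python
def resolve_grad_acc_steps(config_value, optional_args):
    """Return the CLI override if present, otherwise the value from the config."""
    steps = config_value
    for arg in optional_args.split():
        if arg.startswith("--gradient_accumulation_steps="):
            # Keep everything after the first "=", like ${arg#*=} in Bash.
            steps = arg.split("=", 1)[1]
            break  # the first match wins, as in the shell loop
    return steps


print(resolve_grad_acc_steps("8", "--num_train_epochs=1 --gradient_accumulation_steps=4"))
print(resolve_grad_acc_steps("8", ""))
```

Only the `--flag=value` spelling is detected. As in the shell loop, a space-separated `--gradient_accumulation_steps 4` would leave the config value in place for the launcher.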


@ -14,11 +14,112 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from typing import Any, Literal, Optional
import trl
@dataclass
class DatasetConfig:
"""Configuration for a dataset in a mixture."""
id: str
config: Optional[str] = None
split: str = "train"
columns: Optional[list[str]] = None
weight: Optional[float] = None
@dataclass
class DatasetMixtureConfig:
"""Configuration for a mixture of datasets."""
datasets: list[DatasetConfig]
seed: int = 0
test_split_size: Optional[float] = None
@dataclass
class ScriptArguments(trl.ScriptArguments):
"""
Extended version of ScriptArguments with support for dataset mixtures.
Args:
dataset_mixture (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
Configuration for creating dataset mixtures with advanced options.
Format:
dataset_mixture:
datasets:
- id: dataset_id1
config: config_name
columns:
- col1
- col2
weight: 0.5
- id: dataset_id2
config: config_name
columns:
- col1
- col2
weight: 0.5
seed: 42
test_split_size: 0.1
"""
# Override the dataset_name to make it optional
dataset_name: Optional[str] = field(
default=None, metadata={"help": "Dataset name. Can be omitted if using dataset_mixture."}
)
dataset_mixture: Optional[dict[str, Any]] = field(
default=None,
metadata={"help": "Configuration for creating dataset mixtures with advanced options like shuffling."},
)
def __post_init__(self):
if self.dataset_name is None and self.dataset_mixture is None:
raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided")
if self.dataset_mixture is not None:
if not isinstance(self.dataset_mixture, dict) or "datasets" not in self.dataset_mixture:
raise ValueError(
"dataset_mixture must be a dictionary with a 'datasets' key. "
"Expected format: {'datasets': [...], 'seed': int}"
)
datasets_list = []
datasets_data = self.dataset_mixture.get("datasets", [])
if isinstance(datasets_data, list):
for dataset_config in datasets_data:
datasets_list.append(
DatasetConfig(
id=dataset_config.get("id"),
config=dataset_config.get("config"),
split=dataset_config.get("split", "train"),
columns=dataset_config.get("columns"),
weight=dataset_config.get("weight", 1.0),
)
)
else:
raise ValueError("'datasets' must be a list of dataset configurations")
self.dataset_mixture = DatasetMixtureConfig(
datasets=datasets_list,
seed=self.dataset_mixture.get("seed", 0),
test_split_size=self.dataset_mixture.get("test_split_size", None),
)
# Check that column names are consistent across all dataset configs
columns_sets = [set(dataset.columns) for dataset in datasets_list if dataset.columns is not None]
if columns_sets:
first_columns = columns_sets[0]
if not all(columns == first_columns for columns in columns_sets):
raise ValueError(
"Column names must be consistent across all dataset configurations in a mixture. "
f"Found different column sets: {[list(cols) for cols in columns_sets]}"
)
# TODO: add the shared options with a mixin to reduce code duplication
@dataclass
class GRPOConfig(trl.GRPOConfig):
@ -27,21 +128,30 @@ class GRPOConfig(trl.GRPOConfig):
"""
benchmarks: list[str] = field(
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
default_factory=lambda: [],
metadata={"help": "The benchmarks to run after training."},
)
callbacks: list[str] = field(
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
default_factory=lambda: [],
metadata={"help": "The callbacks to run during training."},
)
chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
hub_model_revision: Optional[str] = field(
default="main", metadata={"help": "The Hub model branch to push the model to."}
)
num_completions_to_print: int = field(default=0, metadata={"help": "Number of completions to print."})
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
system_prompt: Optional[str] = field(
default=None,
metadata={"help": "The optional system prompt to use."},
)
hub_model_revision: Optional[str] = field(
default="main", metadata={"help": "The Hub model branch to push the model to."}
wandb_log_unique_prompts: bool = field(
default=True,
metadata={
"help": ("Whether to log the unique prompts to wandb. This will create a new run for each unique prompt.")
},
)
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
wandb_entity: Optional[str] = field(
default=None,
metadata={"help": ("The entity to store runs under.")},
@ -50,6 +160,10 @@ class GRPOConfig(trl.GRPOConfig):
default=None,
metadata={"help": ("The project to store runs under.")},
)
wandb_run_group: Optional[str] = field(
default=None,
metadata={"help": ("The group to store runs under.")},
)
@dataclass
@ -59,10 +173,12 @@ class SFTConfig(trl.SFTConfig):
"""
benchmarks: list[str] = field(
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
default_factory=lambda: [],
metadata={"help": "The benchmarks to run after training."},
)
callbacks: list[str] = field(
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
default_factory=lambda: [],
metadata={"help": "The callbacks to run during training."},
)
chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
system_prompt: Optional[str] = field(
@ -83,16 +199,20 @@ class SFTConfig(trl.SFTConfig):
default=None,
metadata={"help": ("The project to store runs under.")},
)
wandb_run_group: Optional[str] = field(
default=None,
metadata={"help": ("The group to store runs under.")},
)
@dataclass
class GRPOScriptArguments(trl.ScriptArguments):
class GRPOScriptArguments(ScriptArguments):
"""
Script arguments for the GRPO training script.
Args:
reward_funcs (`list[str]`):
List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'tag_count', 'code', 'ioi_code', 'code_format'.
List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'tag_count', 'code', 'ioi_code', 'code_format', 'soft_overlong_punishment'.
cosine_min_value_wrong (`float`):
Minimum reward for cosine scaling for wrong answers.
cosine_max_value_wrong (`float`):
@ -105,6 +225,10 @@ class GRPOScriptArguments(trl.ScriptArguments):
Maximum length for cosine scaling.
code_language (`str`):
Language for code format reward.
max_completion_len (`int`):
Maximum number of tokens in completion.
soft_punish_cache (`int`):
Minimum number of tokens in completion.
"""
reward_funcs: list[str] = field(
@ -143,6 +267,7 @@ class GRPOScriptArguments(trl.ScriptArguments):
)
code_language: str = field(
default="python",
# '(?:python|cpp)'
metadata={
"help": "Language for code format reward. Based on E2B supported languages https://e2b.dev/docs/code-interpreting/supported-languages",
"choices": ["python", "javascript", "r", "java", "bash", "cpp"],
@ -154,3 +279,53 @@ class GRPOScriptArguments(trl.ScriptArguments):
"help": "for each generation, evaluate these many test cases in parallel, then check if any of them failed (0 score): if so stop evaluating; otherwise continue with the next batch of test cases. Useful to avoid overloading the eval server + save time on wrong solutions"
},
)
code_eval_scoring_mode: Literal["pass_fail", "partial", "weighted_sum"] = field(
default="weighted_sum",
metadata={"help": "Scoring mode for code evaluation. One of 'pass_fail' (0/1 scoring), 'partial' (fraction of passed test cases), or 'weighted_sum'."},
)
parallel_code_exec_per_proc: int = field(
default=2,
metadata={
"help": "Number of parallel E2B code executions per process. Default of 2 is suitable for the Free Hobby tier of E2B with 8 GPUs used for training."
},
)
dataset_prompt_column: str = field(
default="prompt",
metadata={"help": "Column to use as prompts for training."},
)
e2b_router_url: Optional[str] = field(
default=None,
metadata={"help": "URL for the E2B router. See scripts/e2b_router.py"},
)
morph_router_url: Optional[str] = field(
default=None,
metadata={"help": "URL for the MorphCloud router. See scripts/morph_router.py"},
)
code_provider: Optional[str] = field(
default="e2b",
metadata={
"help": "Provider for code execution. Options: 'e2b', 'local', 'morph'.",
"choices": ["e2b", "local", "morph"],
},
)
ioi_provider: Optional[str] = field(
default="piston",
metadata={
"help": "Provider for IOI code execution. Options: 'piston', 'morph'.",
"choices": ["piston", "morph"],
},
)
max_completion_len: int = field(
default=16384,
metadata={"help": "Maximum number of tokens in completion."},
)
soft_punish_cache: int = field(
default=4096,
metadata={"help": "Minimum number of tokens in completion."},
)
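The `__post_init__` validation in `ScriptArguments` requires every dataset in a mixture that declares `columns` to declare the same set. The check can be sketched in isolation as follows; `check_consistent_columns` is a hypothetical stand-alone helper, and the mixture dict only mimics the format shown in the docstring.

```python
def check_consistent_columns(dataset_configs):
    """Raise ValueError if datasets that declare `columns` disagree on them."""
    column_sets = [set(cfg["columns"]) for cfg in dataset_configs if cfg.get("columns") is not None]
    if column_sets and any(cols != column_sets[0] for cols in column_sets):
        raise ValueError(f"Found different column sets: {[sorted(cols) for cols in column_sets]}")


# Hypothetical mixture in the format described by the docstring.
mixture = {
    "datasets": [
        {"id": "dataset_id1", "columns": ["col1", "col2"], "weight": 0.5},
        {"id": "dataset_id2", "columns": ["col2", "col1"], "weight": 0.5},
        {"id": "dataset_id3"},  # no `columns`, so the check skips it
    ],
    "seed": 42,
}
check_consistent_columns(mixture["datasets"])  # passes: column order is ignored
```

Since the comparison uses sets, column order does not matter, and datasets without a `columns` entry are not compared at all.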


@ -1,185 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom evaluation tasks for LightEval."""
import random
from lighteval.metrics.dynamic_metrics import (
ExprExtractionConfig,
IndicesExtractionConfig,
LatexExtractionConfig,
multilingual_extractive_match_metric,
)
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
# Prompt template adapted from
# - simple-evals: https://github.com/openai/simple-evals/blob/6e84f4e2aed6b60f6a0c7b8f06bbbf4bfde72e58/math_eval.py#L17
# - Llama 3: https://huggingface.co/datasets/meta-llama/Llama-3.2-1B-Instruct-evals/viewer/Llama-3.2-1B-Instruct-evals__math__details?views%5B%5D=llama_32_1b_instruct_evals__math__details
# Note that it is important to have the final answer in a box for math-verify to work correctly
MATH_QUERY_TEMPLATE = """
Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering.
{Question}
""".strip()
# Prompt template from simple-evals: https://github.com/openai/simple-evals/blob/83ed7640a7d9cd26849bcb3340125002ef14abbe/common.py#L14
GPQA_QUERY_TEMPLATE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
{Question}
A) {A}
B) {B}
C) {C}
D) {D}
""".strip()
latex_gold_metric = multilingual_extractive_match_metric(
language=Language.ENGLISH,
fallback_mode="first_match",
precision=5,
gold_extraction_target=(LatexExtractionConfig(),),
# Match boxed first before trying other regexes
pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
aggregation_function=max,
)
expr_gold_metric = multilingual_extractive_match_metric(
language=Language.ENGLISH,
fallback_mode="first_match",
precision=5,
gold_extraction_target=(ExprExtractionConfig(),),
# Match boxed first before trying other regexes
pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
aggregation_function=max,
)
gpqa_metric = multilingual_extractive_match_metric(
language=Language.ENGLISH,
gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
precision=5,
)
def math_prompt_fn(line, task_name: str = None):
return Doc(
task_name=task_name,
query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
choices=[line["solution"]],
gold_index=0,
)
def aime_prompt_fn(line, task_name: str = None):
return Doc(
task_name=task_name,
query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
choices=[line["answer"]],
gold_index=0,
)
def gpqa_prompt_fn(line, task_name: str = None):
gold_index = random.randint(0, 3)
choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]]
choices.insert(gold_index, line["Correct Answer"])
query = GPQA_QUERY_TEMPLATE.format(
A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=line["Question"]
)
return Doc(
task_name=task_name,
query=query,
choices=["A", "B", "C", "D"],
gold_index=gold_index,
instruction=query,
)
# Define tasks
aime24 = LightevalTaskConfig(
name="aime24",
suite=["custom"],
prompt_function=aime_prompt_fn,
hf_repo="HuggingFaceH4/aime_2024",
hf_subset="default",
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[expr_gold_metric],
version=1,
)
aime25 = LightevalTaskConfig(
name="aime25",
suite=["custom"],
prompt_function=aime_prompt_fn,
hf_repo="yentinglin/aime_2025",
hf_subset="default",
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[expr_gold_metric],
version=1,
)
math_500 = LightevalTaskConfig(
name="math_500",
suite=["custom"],
prompt_function=math_prompt_fn,
hf_repo="HuggingFaceH4/MATH-500",
hf_subset="default",
hf_avail_splits=["test"],
evaluation_splits=["test"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[latex_gold_metric],
version=1,
)
gpqa_diamond = LightevalTaskConfig(
name="gpqa:diamond",
suite=["custom"],
prompt_function=gpqa_prompt_fn,
hf_repo="Idavidrein/gpqa",
hf_subset="gpqa_diamond",
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768, # needed for reasoning models like R1
metric=[gpqa_metric],
stop_sequence=[], # no stop sequence, will use eos token
trust_dataset=True,
version=1,
)
# Add tasks to the table
TASKS_TABLE = []
TASKS_TABLE.append(aime24)
TASKS_TABLE.append(aime25)
TASKS_TABLE.append(math_500)
TASKS_TABLE.append(gpqa_diamond)
# MODULE LOGIC
if __name__ == "__main__":
print([t["name"] for t in TASKS_TABLE])
print(len(TASKS_TABLE))


@ -53,7 +53,7 @@ def build_distilabel_pipeline(
generation_kwargs=generation_kwargs,
),
template=prompt_template,
input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
input_mappings=({"instruction": prompt_column} if prompt_column is not None else {}),
input_batch_size=input_batch_size,
num_generations=num_generations,
group_generations=True,


@ -17,15 +17,13 @@ import os
import sys
import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import GRPOConfig, GRPOScriptArguments
from open_r1.rewards import get_reward_funcs
from open_r1.utils import get_tokenizer
from open_r1.utils import get_dataset, get_model, get_tokenizer
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import GRPOTrainer, ModelConfig, TrlParser, get_peft_config
@ -73,24 +71,33 @@ def main(script_args, training_args, model_args):
init_wandb_training(training_args)
# Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
dataset = get_dataset(script_args)
################
# Load tokenizer
################
tokenizer = get_tokenizer(model_args, training_args)
##############
# Load model #
##############
logger.info("*** Loading model ***")
model = get_model(model_args, training_args)
# Get reward functions from the registry
reward_funcs = get_reward_funcs(script_args)
# Format into conversation
def make_conversation(example):
def make_conversation(example, prompt_column: str = script_args.dataset_prompt_column):
prompt = []
if training_args.system_prompt is not None:
prompt.append({"role": "system", "content": training_args.system_prompt})
prompt.append({"role": "user", "content": example["problem"]})
if prompt_column not in example:
raise ValueError(f"Dataset Question Field Error: {prompt_column} is not supported.")
prompt.append({"role": "user", "content": example[prompt_column]})
return {"prompt": prompt}
dataset = dataset.map(make_conversation)
@ -99,28 +106,15 @@ def main(script_args, training_args, model_args):
if "messages" in dataset[split].column_names:
dataset[split] = dataset[split].remove_columns("messages")
logger.info("*** Initializing model kwargs ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
)
training_args.model_init_kwargs = model_kwargs
#############################
# Initialize the GRPO trainer
#############################
trainer = GRPOTrainer(
model=model_args.model_name_or_path,
model=model,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None),
peft_config=get_peft_config(model_args),
callbacks=get_callbacks(training_args, model_args),
processing_class=tokenizer,
@ -146,6 +140,9 @@ def main(script_args, training_args, model_args):
# Save model and create model card
##################################
logger.info("*** Save model ***")
# Align the model's generation config with the tokenizer's eos token
# to avoid unbounded generation in the transformers `pipeline()` function
trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")
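
Review note: the reworked `make_conversation` reads the user turn from a configurable column and fails fast when that column is missing. Below is a minimal standalone sketch of that behavior, assuming plain dicts in place of dataset rows. The name `build_prompt` and its keyword arguments are hypothetical and are not part of `open_r1`.

```python
from typing import Optional


def build_prompt(example: dict, prompt_column: str = "problem", system_prompt: Optional[str] = None) -> dict:
    """Hypothetical stand-in for make_conversation, operating on a plain dict."""
    prompt = []
    if system_prompt is not None:
        prompt.append({"role": "system", "content": system_prompt})
    # Fail fast if the configured column is absent, as in the diff above
    if prompt_column not in example:
        raise ValueError(f"Dataset Question Field Error: {prompt_column} is not supported.")
    prompt.append({"role": "user", "content": example[prompt_column]})
    return {"prompt": prompt}


row = {"question": "What is 2 + 2?"}
print(build_prompt(row, prompt_column="question", system_prompt="Think step by step."))
```

With the default column name, a row that only has a `question` field raises `ValueError` instead of failing later with a `KeyError` inside `dataset.map`.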


@ -1,3 +1,18 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reward functions for GRPO training."""
import asyncio
@ -5,25 +20,24 @@ import json
import math
import re
from functools import partial, update_wrapper
from typing import Callable, Dict
from typing import Callable, Dict, Literal, Optional
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from .utils import is_e2b_available
from .utils.ioi import SubtaskResult, add_includes, get_piston_client_from_env, score_subtask
from .utils.code_providers import get_provider
from .utils.competitive_programming import (
SubtaskResult,
add_includes,
get_morph_client_from_env,
get_piston_client_from_env,
)
from .utils.competitive_programming import patch_code as cf_patch_code
from .utils.competitive_programming import score_submission as cf_score_submission
from .utils.competitive_programming import score_subtask
if is_e2b_available():
from dotenv import load_dotenv
from e2b_code_interpreter import AsyncSandbox
load_dotenv()
else:
AsyncSandbox = None
def accuracy_reward(completions, solution, **kwargs):
def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
"""Reward function that checks if the completion is the same as the ground truth."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
@ -31,7 +45,6 @@ def accuracy_reward(completions, solution, **kwargs):
gold_parsed = parse(
sol,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
@ -54,15 +67,15 @@ def accuracy_reward(completions, solution, **kwargs):
],
extraction_mode="first_match",
)
# Reward 1 if the content is the same as the ground truth, 0 otherwise
# Compute binary rewards if verifiable, `None` otherwise to skip this example
try:
reward = float(verify(answer_parsed, gold_parsed))
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
reward = 0.0
reward = None
else:
# If the gold solution is not parseable, we reward 1 to skip this example
reward = 1.0
# If the gold solution is not parseable, we assign `None` to skip this example
reward = None
print("Failed to parse gold solution: ", sol)
rewards.append(reward)
@ -119,7 +132,7 @@ def reasoning_steps_reward(completions, **kwargs):
def len_reward(completions: list[Dict[str, str]], solution: list[str], **kwargs) -> float:
"""Compute length-based rewards to discourage overthinking and promote token efficiency.
Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599
Taken from the Kimi 1.5 tech report: https://huggingface.co/papers/2501.12599
Args:
completions: List of model completions
@ -217,7 +230,11 @@ def get_cosine_scaled_reward(
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
gold_parsed = parse(
sol,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if len(gold_parsed) == 0:
rewards.append(1.0) # Skip unparseable examples
print("Failed to parse gold solution: ", sol)
@ -265,21 +282,41 @@ def get_cosine_scaled_reward(
return cosine_scaled_reward
def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
def get_repetition_penalty_reward(ngram_size: int, max_penalty: float, language: str = "en"):
"""
Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373.
Computes N-gram repetition penalty as described in Appendix C.2 of https://huggingface.co/papers/2502.03373.
Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
Args:
ngram_size: size of the n-grams
max_penalty: Maximum (negative) penalty for wrong answers
language: Language of the text, defaults to `en`. Used to choose the way to split the text into n-grams.
"""
if max_penalty > 0:
raise ValueError(f"max_penalty {max_penalty} should not be positive")
def zipngram(text: str, ngram_size: int):
words = text.lower().split()
return zip(*[words[i:] for i in range(ngram_size)])
if language == "en":
def zipngram(text: str, ngram_size: int):
words = text.lower().split()
return zip(*[words[i:] for i in range(ngram_size)]), words
elif language == "zh":
from transformers.utils.import_utils import _is_package_available
if not _is_package_available("jieba"):
raise ValueError("Please install jieba to use Chinese language")
def zipngram(text: str, ngram_size: int):
import jieba
seg_list = list(jieba.cut(text))
return zip(*[seg_list[i:] for i in range(ngram_size)]), seg_list
else:
raise ValueError(
f"Word splitting for language `{language}` is not yet implemented. Please implement your own zip-ngram function."
)
def repetition_penalty_reward(completions, **kwargs) -> float:
"""
@ -296,13 +333,16 @@ def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
if completion == "":
rewards.append(0.0)
continue
if len(completion.split()) < ngram_size:
rewards.append(0.0)
continue
ngrams = set()
total = 0
for ng in zipngram(completion, ngram_size):
ngram_array, words = zipngram(completion, ngram_size)
if len(words) < ngram_size:
rewards.append(0.0)
continue
for ng in ngram_array:
ngrams.add(ng)
total += 1
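
Review note: a standalone sketch of the n-gram counting in this hunk, using the whitespace-splitting `en` branch. The final scaling step is not visible in the diff, so the formula used here (share of repeated n-grams multiplied by `max_penalty`) is an assumption, and the single-string function `repetition_penalty` is a hypothetical simplification of the batched reward.

```python
def zipngram(text: str, ngram_size: int):
    """Whitespace-split n-grams, mirroring the `en` branch in the diff."""
    words = text.lower().split()
    return zip(*[words[i:] for i in range(ngram_size)]), words


def repetition_penalty(completion: str, ngram_size: int, max_penalty: float) -> float:
    """Hypothetical single-completion version of the reward."""
    if max_penalty > 0:
        raise ValueError(f"max_penalty {max_penalty} should not be positive")
    if completion == "":
        return 0.0
    ngram_array, words = zipngram(completion, ngram_size)
    # Too few tokens to form a single n-gram: no penalty
    if len(words) < ngram_size:
        return 0.0
    ngrams = set()
    total = 0
    for ng in ngram_array:
        ngrams.add(ng)
        total += 1
    # Assumed scaling: fraction of repeated n-grams times the (non-positive) max penalty
    scaling = 1 - len(ngrams) / total
    return scaling * max_penalty


print(repetition_penalty("the cat sat on the mat", ngram_size=2, max_penalty=-1.0))
print(repetition_penalty("go go go go", ngram_size=2, max_penalty=-1.0))
```

Returning the token list alongside the n-gram iterator is what lets the length check work for segmenters such as `jieba`, where `completion.split()` would not reflect the real token count.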
@ -315,6 +355,7 @@ def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
def _init_event_loop():
"""Initialize or get the current event loop."""
try:
loop = asyncio.get_event_loop()
except RuntimeError:
@ -323,15 +364,24 @@ def _init_event_loop():
return loop
def ioi_code_reward(completions, test_batch_size: int = 1, **kwargs) -> list[float]:
"""Reward function that evaluates IOI problems using Piston+our IOI package.
def ioi_code_reward(completions, test_batch_size: int = 1, provider_type: str = "piston", **kwargs) -> list[float]:
"""Reward function that evaluates IOI problems using a specified execution client.
Assumes the dataset has the same format as hf.co/datasets/open-r1/ioi
test_batch_size: evaluate this many test cases in parallel, then check if any of them failed (0 score): if so stop evaluating; otherwise continue with the next batch of test cases.
Args:
completions: List of model completions to evaluate
test_batch_size: Evaluate this many test cases in parallel, then check if any of them failed (0 score):
if so stop evaluating; otherwise continue with the next batch of test cases.
provider_type: The execution provider to use (default: "piston"). Supported values: "piston", "morph"
**kwargs: Additional arguments passed from the dataset
"""
# for info on setting up piston workers, see slurm/piston/README.md
piston_client = get_piston_client_from_env()
# Get the appropriate client based on provider_type
if provider_type == "morph":
execution_client = get_morph_client_from_env()
else:
# for info on setting up piston workers, see slurm/piston/README.md
execution_client = get_piston_client_from_env()
code_snippets = [
# note: grading is automatically skipped if no code is extracted
@ -343,16 +393,22 @@ def ioi_code_reward(completions, test_batch_size: int = 1, **kwargs) -> list[flo
try:
return await task
except Exception as e:
print(f"Error from Piston worker: {e}")
return SubtaskResult() # score 0.0
print(f"Error from {provider_type} worker: {e}")
return SubtaskResult()
# load problem data. undo separating kwargs by column
problems_data = [dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values())]
loop = _init_event_loop()
evals = [
loop.create_task(
run_catch_exceptions(score_subtask(piston_client, problem_data, code, test_batch_size=test_batch_size))
run_catch_exceptions(
score_subtask(
execution_client,
problem_data,
code,
test_batch_size=test_batch_size,
)
)
)
for problem_data, code in zip(problems_data, code_snippets)
]
@ -361,32 +417,115 @@ def ioi_code_reward(completions, test_batch_size: int = 1, **kwargs) -> list[flo
return [result.score for result in results]
def extract_code(completion: str, language: str = "python") -> str:
def cf_code_reward(
completions,
test_batch_size: int = 1,
patch_code: bool = False,
scoring_mode: Literal["pass_fail", "partial", "weighted_sum"] = "weighted_sum",
**kwargs,
) -> list[float]:
"""Reward function that evaluates Codeforces problems using Piston+our CF package.
Assumes the dataset has the same format as hf.co/datasets/open-r1/codeforces (verifiable-prompts subset)
test_batch_size: evaluate this many test cases in parallel, then check if any of them failed (0 score): if so stop evaluating; otherwise continue with the next batch of test cases.
"""
# for info on setting up piston workers, see slurm/piston/README.md
piston_client = get_piston_client_from_env()
languages = kwargs["language"] if "language" in kwargs else [None] * len(completions)
code_snippets = [
# note: grading is automatically skipped if a problem has no tests
cf_patch_code(extract_code(completion[-1]["content"], language), language)
if patch_code
else extract_code(completion[-1]["content"], language)
for completion, language in zip(completions, languages)
]
async def run_catch_exceptions(task):
try:
return await task
except Exception as e:
print(f"Error from Piston worker: {e}")
return None
# load problem data. undo separating kwargs by column
problems_data = [dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values())]
loop = _init_event_loop()
evals = [
loop.create_task(
run_catch_exceptions(
cf_score_submission(
piston_client,
problem_data,
code,
test_batch_size=test_batch_size,
scoring_mode=scoring_mode,
submission_language=problem_data.get("language", None),
)
)
)
for problem_data, code in zip(problems_data, code_snippets)
]
results = loop.run_until_complete(asyncio.gather(*evals))
return results
def extract_code(completion: str, language: str | None = "python") -> str:
if language is None:
return ""
pattern = re.compile(rf"```{language}\n(.*?)```", re.DOTALL)
matches = pattern.findall(completion)
extracted_answer = matches[-1] if len(matches) >= 1 else ""
return extracted_answer
def binary_code_reward(completions, **kwargs) -> list[float]:
rewards = code_reward(completions, **kwargs)
def binary_code_reward(
completions,
num_parallel: int = 2,
provider_type: str = "e2b",
enforce_same_language: bool = False,
**kwargs,
) -> list[float]:
rewards = code_reward(
completions,
num_parallel=num_parallel,
provider_type=provider_type,
enforce_same_language=enforce_same_language,
**kwargs,
)
BINARY_THRESHOLD = 0.99
return [1.0 if reward > BINARY_THRESHOLD else 0.0 for reward in rewards]
output = []
for reward in rewards:
if reward is None:
output.append(None)
else:
output.append(1.0 if reward > BINARY_THRESHOLD else 0.0)
return output
def code_reward(completions, **kwargs) -> list[float]:
"""Reward function that evaluates code snippets using the E2B code interpreter.
def code_reward(
completions,
num_parallel: int = 2,
provider_type: str = "e2b",
enforce_same_language: bool = False,
**kwargs,
) -> list[float]:
"""Reward function that evaluates code snippets using a code execution provider.
Assumes the dataset contains a `verification_info` column with test cases.
"""
if not is_e2b_available():
raise ImportError(
"E2B is not available and required for this reward function. Please install E2B with "
"`pip install e2b-code-interpreter` and add an API key to a `.env` file."
)
# TODO: add support for other languages in E2B: https://e2b.dev/docs/code-interpreting/supported-languages
"""Returns a reward function that evaluates code snippets in a sandbox."""
Args:
completions: List of model completions to evaluate
num_parallel: Number of parallel code executions (default: 2)
provider_type: Which code execution provider to use (default: "e2b")
enforce_same_language: If True, verify all problems use the same language (default: False)
**kwargs: Additional arguments passed to the verification
"""
evaluation_script_template = """
import subprocess
import json
@ -426,25 +565,31 @@ def code_reward(completions, **kwargs) -> list[float]:
evaluate_code(code_snippet, test_cases)
"""
code_snippets = [extract_code(completion[-1]["content"]) for completion in completions]
verification_info = kwargs["verification_info"]
template = evaluation_script_template
scripts = [
evaluation_script_template.format(code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"])))
template.format(code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"])))
for code, info in zip(code_snippets, verification_info)
]
language = verification_info[0]["language"]
if not all(v["language"] == language for v in verification_info):
raise ValueError("All verification_info must have the same language", verification_info)
try:
rewards = run_async_from_sync(scripts, language)
if enforce_same_language:
all_same_language = all(v["language"] == language for v in verification_info)
if not all_same_language:
raise ValueError("All verification_info must have the same language", verification_info)
except Exception as e:
print(f"Error from E2B executor: {e}")
rewards = [0.0] * len(completions)
execution_provider = get_provider(
provider_type=provider_type,
num_parallel=num_parallel,
**kwargs,
)
return rewards
return execution_provider.execute_scripts(scripts, ["python"] * len(scripts))
def get_code_format_reward(language: str = "python"):
@ -453,55 +598,49 @@ def get_code_format_reward(language: str = "python"):
Args:
language: Programming language supported by E2B https://e2b.dev/docs/code-interpreting/supported-languages
"""
pattern = rf"^<think>\n.*?\n</think>\n<answer>\n.*?```{language}.*?```.*?\n</answer>$"
def code_format_reward(completions, **kwargs):
# if there is a language field, use it instead of the default language. This way we can have mixed language training.
languages = kwargs["language"] if "language" in kwargs else [language] * len(completions)
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
matches = [
re.match(
rf"^<think>\n.*?\n</think>\n<answer>\n.*?```{sample_language}.*?```.*?\n</answer>$",
content,
re.DOTALL | re.MULTILINE,
)
for content, sample_language in zip(completion_contents, languages)
]
return [1.0 if match else 0.0 for match in matches]
return code_format_reward
def run_async_from_sync(scripts: list[str], language: str) -> list[float]:
"""Function wrapping the `run_async` function."""
# Create a new event loop and set it
try:
# Run the async function and get the result
rewards = asyncio.run(run_async(scripts, language))
except Exception as e:
print(f"Error from E2B executor async: {e}")
raise e
def get_soft_overlong_punishment(max_completion_len, soft_punish_cache):
"""
Reward function that penalizes overlong completions without rewarding shorter ones.
Reference: Eq. (13) from the DAPO paper (https://huggingface.co/papers/2503.14476)
return rewards
Args:
max_completion_len: Maximum length of the completion
soft_punish_cache: Length of the soft punishment interval that ends at `max_completion_len`. If set to 0, no soft penalty is applied.
"""
def soft_overlong_punishment_reward(completion_ids: list[list[int]], **kwargs) -> list[float]:
"""Reward function that penalizes overlong completions."""
rewards = []
for ids in completion_ids:
completion_length = len(ids)
if completion_length <= max_completion_len - soft_punish_cache:
rewards.append(0.0)
elif max_completion_len - soft_punish_cache < completion_length <= max_completion_len:
rewards.append((max_completion_len - soft_punish_cache - completion_length) / soft_punish_cache)
else:
rewards.append(-1.0)
return rewards
async def run_async(scripts: list[str], language: str) -> list[float]:
# Create the sandbox by hand, currently there's no context manager for this version
sbx = await AsyncSandbox.create(timeout=30, request_timeout=3)
# Create a list of tasks for running scripts concurrently
tasks = [run_script(sbx, script, language) for script in scripts]
# Wait for all tasks to complete and gather their results as they finish
results = await asyncio.gather(*tasks)
rewards = list(results) # collect results
# Kill the sandbox after all the tasks are complete
await sbx.kill()
return rewards
async def run_script(sbx: AsyncSandbox, script: str, language: str) -> float:
execution = await sbx.run_code(script, language=language)
try:
return float(execution.text)
except (TypeError, ValueError):
return 0.0
except Exception as e:
print(f"Error from E2B executor run_script: {e}")
return 0.0
return soft_overlong_punishment_reward
def get_reward_funcs(script_args) -> list[Callable]:
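
Review note: the piecewise rule in `soft_overlong_punishment_reward` can be checked in isolation. The sketch below applies the same three branches to a single length; the function name and the scalar signature are hypothetical simplifications of the batched version in the diff.

```python
def soft_overlong_punishment(completion_length: int, max_completion_len: int, soft_punish_cache: int) -> float:
    """Hypothetical scalar version of the three-branch penalty."""
    threshold = max_completion_len - soft_punish_cache
    if completion_length <= threshold:
        # Short enough: no penalty (and no bonus)
        return 0.0
    if completion_length <= max_completion_len:
        # Inside the soft interval: penalty grows linearly from 0 towards -1
        return (threshold - completion_length) / soft_punish_cache
    # Beyond the hard limit: full penalty
    return -1.0


for length in (50, 80, 90, 100, 120):
    print(length, soft_overlong_punishment(length, max_completion_len=100, soft_punish_cache=20))
```

When `soft_punish_cache` is 0 the middle branch is unreachable, so the division is never evaluated and only the hard `-1.0` penalty remains.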
@ -521,13 +660,46 @@ def get_reward_funcs(script_args) -> list[Callable]:
max_penalty=script_args.repetition_max_penalty,
),
"length": len_reward,
"code": code_reward,
"binary_code": binary_code_reward,
"code": update_wrapper(
partial(
code_reward,
num_parallel=script_args.parallel_code_exec_per_proc,
provider_type=script_args.code_provider,
enforce_same_language=getattr(script_args, "enforce_same_language", False),
),
code_reward,
),
"binary_code": update_wrapper(
partial(
binary_code_reward,
num_parallel=script_args.parallel_code_exec_per_proc,
provider_type=script_args.code_provider,
enforce_same_language=getattr(script_args, "enforce_same_language", False),
),
binary_code_reward,
),
"ioi_code": update_wrapper(
partial(ioi_code_reward, test_batch_size=script_args.code_eval_test_batch_size), ioi_code_reward
partial(
ioi_code_reward,
test_batch_size=script_args.code_eval_test_batch_size,
provider_type=getattr(script_args, "ioi_provider", "piston"),
),
ioi_code_reward,
),
"cf_code": update_wrapper(
partial(
cf_code_reward,
test_batch_size=script_args.code_eval_test_batch_size,
scoring_mode=script_args.code_eval_scoring_mode,
),
cf_code_reward,
),
"code_format": get_code_format_reward(language=script_args.code_language),
"tag_count": tag_count_reward,
"soft_overlong_punishment": get_soft_overlong_punishment(
max_completion_len=script_args.max_completion_len,
soft_punish_cache=script_args.soft_punish_cache,
),
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
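
Review note: the registry binds script arguments with `partial` and then copies metadata over with `update_wrapper`. A small sketch of what that combination does, using a hypothetical `toy_reward`. The presumed motivation, that downstream code reads `__name__` from each reward function (for example when logging), is an assumption and is not stated in this diff.

```python
from functools import partial, update_wrapper


def toy_reward(completions, scale: float = 1.0, **kwargs) -> list[float]:
    """Hypothetical reward: scaled length of each completion string."""
    return [scale * len(c) for c in completions]


# Bind a configuration value, then copy __name__, __doc__, etc. from the original
wrapped = update_wrapper(partial(toy_reward, scale=0.5), toy_reward)

print(wrapped.__name__)
print(wrapped.__doc__)
print(wrapped(["ab", "abcd"]))
```

The wrapped callable keeps the bound `scale` while presenting the original function's name and docstring.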


@ -19,20 +19,18 @@ Usage:
# On 1 node of 8 x H100s
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
--dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
--learning_rate 2.0e-5 \
--num_train_epochs 1 \
--packing \
--max_seq_length 4096 \
--model_name_or_path open-r1/Qwen2.5-Math-7B-RoPE-300k \
--dataset_name open-r1/Mixture-of-Thoughts \
--dataset_config all \
--eos_token '<|im_end|>' \
--learning_rate 4.0e-5 \
--num_train_epochs 5 \
--max_seq_length 32768 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--bf16 \
--logging_steps 5 \
--eval_strategy steps \
--eval_steps 100 \
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
--use_liger_kernel \
--output_dir data/OpenR1-Distill-7B
"""
import logging
@ -40,32 +38,21 @@ import os
import sys
import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import SFTConfig
from open_r1.utils import get_tokenizer
from open_r1.configs import ScriptArguments, SFTConfig
from open_r1.utils import get_dataset, get_model, get_tokenizer
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import (
ModelConfig,
ScriptArguments,
SFTTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl import ModelConfig, SFTTrainer, TrlParser, get_peft_config, setup_chat_format
logger = logging.getLogger(__name__)
def main(script_args, training_args, model_args):
# Set seed for reproducibility
set_seed(training_args.seed)
###############
@ -97,44 +84,25 @@ def main(script_args, training_args, model_args):
if "wandb" in training_args.report_to:
init_wandb_training(training_args)
################
# Load datasets
################
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
################
# Load tokenizer
################
######################################
# Load dataset, tokenizer, and model #
######################################
dataset = get_dataset(script_args)
tokenizer = get_tokenizer(model_args, training_args)
tokenizer.pad_token = tokenizer.eos_token
model = get_model(model_args, training_args)
###################
# Model init kwargs
###################
logger.info("*** Initializing model kwargs ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
training_args.model_init_kwargs = model_kwargs
if tokenizer.chat_template is None:
logger.info("No chat template provided, defaulting to ChatML.")
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
############################
# Initialize the SFT Trainer
############################
trainer = SFTTrainer(
model=model_args.model_name_or_path,
model=model,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None),
processing_class=tokenizer,
peft_config=get_peft_config(model_args),
callbacks=get_callbacks(training_args, model_args),
@ -160,6 +128,9 @@ def main(script_args, training_args, model_args):
# Save model and create model card
##################################
logger.info("*** Save model ***")
# Align the model's generation config with the tokenizer's eos token
# to avoid unbounded generation in the transformers `pipeline()` function
trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")


@ -1,5 +1,6 @@
from .import_utils import is_e2b_available
from .model_utils import get_tokenizer
from .data import get_dataset
from .import_utils import is_e2b_available, is_morph_available
from .model_utils import get_model, get_tokenizer
__all__ = ["get_tokenizer", "is_e2b_available"]
__all__ = ["get_tokenizer", "is_e2b_available", "is_morph_available", "get_model", "get_dataset"]


@ -44,7 +44,13 @@ class PushToHubRevisionCallback(TrainerCallback):
def __init__(self, model_config) -> None:
self.model_config = model_config
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if state.is_world_process_zero:
global_step = state.global_step


@ -0,0 +1,366 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code execution providers for executing and evaluating code snippets."""
import abc
import asyncio
from typing import List, Optional
from ..utils import is_e2b_available, is_morph_available
if is_e2b_available():
from e2b_code_interpreter import AsyncSandbox
from e2b_code_interpreter.models import Execution
from .routed_sandbox import RoutedSandbox
else:
AsyncSandbox = None
Execution = None
RoutedSandbox = None
if is_morph_available():
from morphcloud.api import MorphCloudClient
from morphcloud.sandbox import Sandbox
from .routed_morph import RoutedMorphSandbox
else:
MorphCloudClient = None
Sandbox = None
RoutedMorphSandbox = None
class CodeExecutionProvider(abc.ABC):
"""Abstract base class for code execution providers."""
@abc.abstractmethod
def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[float]:
"""Execute multiple scripts and return their reward values.
Args:
scripts: List of code scripts to execute
languages: The programming language of each script (one entry per script)
Returns:
List of float rewards (one per script)
"""
pass
class E2BProvider(CodeExecutionProvider):
"""Provider that executes code using E2B sandboxes."""
def __init__(self, num_parallel: int = 2, e2b_router_url: Optional[str] = None):
"""Initialize the E2B provider.
Args:
num_parallel: Number of parallel sandboxes to use
e2b_router_url: URL for the E2B router (if using router mode)
"""
if not is_e2b_available():
raise ImportError(
"E2B is not available and required for this provider. Please install E2B with "
"`pip install e2b-code-interpreter` and add an API key to a `.env` file."
)
self.num_parallel = num_parallel
self.e2b_router_url = e2b_router_url
def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[float]:
"""Execute scripts using E2B sandboxes.
If e2b_router_url is provided, uses the RoutedSandbox for batch processing.
Otherwise, uses direct AsyncSandbox with parallelization.
"""
if self.e2b_router_url is not None:
routed_sandbox = RoutedSandbox(router_url=self.e2b_router_url)
executions = routed_sandbox.run_code(
scripts=scripts,
languages=languages,
timeout=30,
request_timeout=28,
)
rewards = []
for execution in executions:
try:
reward = float(execution.text)
rewards.append(reward)
except Exception:
rewards.append(None)
return rewards
try:
rewards = self._run_async_from_sync(scripts, languages, self.num_parallel)
except Exception as e:
print(f"Error from E2B executor: {e}")
rewards = [0.0] * len(scripts)
return rewards
def _run_async_from_sync(self, scripts: List[str], languages: List[str], num_parallel: int) -> List[float]:
"""Function wrapping the `_run_async` function."""
try:
rewards = asyncio.run(self._run_async(scripts, languages, num_parallel))
except Exception as e:
print(f"Error from E2B executor async: {e}")
raise e
return rewards
async def _run_async(self, scripts: List[str], languages: List[str], num_parallel: int) -> List[float]:
semaphore = asyncio.Semaphore(num_parallel)
tasks = [self._run_script(script, languages, semaphore) for script in scripts]
results = await asyncio.gather(*tasks)
rewards = list(results)
return rewards
async def _run_script(self, script: str, languages: List[str], semaphore: asyncio.Semaphore) -> float:
# We set a timeout margin, as the AsyncSandbox timeout does not seem to work
# These values are based on running 256 examples with the gold solution
# from open-r1/verifiable-coding-problems-python_decontaminated
# see scripts/benchmark_e2b.py
SANDBOX_TIMEOUT = 30
MARGIN = 2
REQUEST_TIMEOUT = SANDBOX_TIMEOUT - MARGIN
ASYNCIO_TIMEOUT = SANDBOX_TIMEOUT + MARGIN
async with semaphore:
try:
sandbox = await AsyncSandbox.create(timeout=SANDBOX_TIMEOUT, request_timeout=REQUEST_TIMEOUT)
execution = await asyncio.wait_for(
sandbox.run_code(script, languages=languages),
timeout=ASYNCIO_TIMEOUT,
)
return float(execution.text)
except (TypeError, ValueError):
return 0.0
except asyncio.TimeoutError:
print("Operation timed out")
return 0.0
except Exception as e:
print(f"Error in `_run_script` from E2B sandbox ID {sandbox.sandbox_id} : {e}")
return 0.0
finally:
try:
await sandbox.kill()
except Exception as e:
print(f"Error from E2B executor kill with sandbox ID {sandbox.sandbox_id} : {e}")
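
Review note: `E2BProvider` limits concurrency with an `asyncio.Semaphore` and guards each call with `asyncio.wait_for`, mapping failures to a reward of 0.0. The sketch below reproduces only that control flow with `asyncio.sleep` standing in for the sandbox call. `run_one`, `run_all`, and the delay values are hypothetical, and no E2B API is used.

```python
import asyncio


async def run_one(delay: float, semaphore: asyncio.Semaphore, timeout: float) -> float:
    """Hypothetical job: sleeps for `delay` seconds; 1.0 on success, 0.0 on timeout."""
    async with semaphore:  # at most `num_parallel` jobs hold the semaphore at once
        try:
            await asyncio.wait_for(asyncio.sleep(delay), timeout=timeout)
            return 1.0
        except asyncio.TimeoutError:
            return 0.0


async def run_all(delays: list[float], num_parallel: int = 2, timeout: float = 0.05) -> list[float]:
    semaphore = asyncio.Semaphore(num_parallel)
    tasks = [run_one(d, semaphore, timeout) for d in delays]
    # gather preserves input order, so rewards line up with the submitted jobs
    return list(await asyncio.gather(*tasks))


rewards = asyncio.run(run_all([0.0, 0.01, 0.2]))
print(rewards)
```

Because the timeout is applied per job inside the semaphore, one slow job cannot stall the others beyond its own time budget.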
class MorphProvider(CodeExecutionProvider):
"""Provider that executes code using MorphCloud's Sandbox API."""
def __init__(self, num_parallel: int = 2, morph_router_url: Optional[str] = None):
"""Initialize the Morph provider.
Args:
num_parallel: Number of parallel executions to use
morph_router_url: URL for the MorphCloud router (if using router mode)
"""
if not is_morph_available():
raise ImportError(
"MorphCloud is not available and required for this provider. Please install MorphCloud with "
"`pip install morphcloud` and add an API key to a `.env` file."
)
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
print("Warning: python-dotenv not installed. Environment variables must be set directly.")
self.num_parallel = num_parallel
self.morph_router_url = morph_router_url
if self.morph_router_url is not None:
self.routed_sandbox = RoutedMorphSandbox(router_url=self.morph_router_url)
return
import os
self.api_key = os.getenv("MORPH_API_KEY")
if not self.api_key:
raise ValueError("MorphCloud API key not found. Please set the MORPH_API_KEY environment variable.")
try:
self.client = MorphCloudClient(api_key=self.api_key)
self.Sandbox = Sandbox
except ImportError as e:
raise ImportError(f"Required MorphCloud dependencies not installed: {e}")
def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[float]:
"""Execute scripts using MorphCloud Sandbox API.
Args:
scripts: List of Python scripts to execute
languages: List of programming languages
Returns:
List of float rewards (one per script)
"""
if hasattr(self, "routed_sandbox"):
try:
results = self.routed_sandbox.run_code(
scripts=scripts,
languages=languages,
timeout=90,
request_timeout=96,
)
rewards = []
for result in results:
try:
reward = float(result.text)
rewards.append(reward)
except (ValueError, AttributeError):
rewards.append(0.0)
return rewards
except Exception as e:
print(f"Error from MorphCloud router: {e}")
return [0.0] * len(scripts)
import asyncio
try:
rewards = asyncio.run(self._run_async(scripts, languages, self.num_parallel))
except Exception as e:
print(f"Error from MorphCloud executor: {e}")
rewards = [0.0] * len(scripts)
return rewards
async def _run_async(self, scripts: List[str], languages: List[str], num_parallel: int) -> List[float]:
"""Run multiple scripts concurrently with limited parallelism.
Args:
scripts: List of scripts to execute
languages: List of programming languages
num_parallel: Maximum number of concurrent executions
Returns:
List of rewards
"""
semaphore = asyncio.Semaphore(num_parallel)
tasks = [self._run_script(script, languages, semaphore) for script in scripts]
results = await asyncio.gather(*tasks)
return list(results)
async def _run_script(self, script: str, languages: List[str], semaphore: asyncio.Semaphore) -> float:
"""Execute a single script in a MorphCloud Sandbox.
Args:
script: The script to execute
languages: List of programming languages
semaphore: Semaphore to limit concurrency
Returns:
Float reward from script execution
"""
SANDBOX_TIMEOUT = 90
MARGIN = 6
ASYNCIO_TIMEOUT = SANDBOX_TIMEOUT + MARGIN
sandbox = None
async with semaphore:
try:
sandbox = await asyncio.to_thread(self.Sandbox.new, client=self.client, ttl_seconds=SANDBOX_TIMEOUT)
result = await asyncio.wait_for(
asyncio.to_thread(
sandbox.run_code,
script,
languages=languages,
timeout=SANDBOX_TIMEOUT,
),
timeout=ASYNCIO_TIMEOUT,
)
reward = 0.0
try:
if hasattr(result, "text") and result.text:
lines = result.text.strip().split("\n")
if lines:
try:
reward = float(lines[-1])
except ValueError:
try:
reward = float(result.text.strip())
except ValueError:
pass
elif hasattr(result, "stdout") and result.stdout:
lines = result.stdout.strip().split("\n")
if lines:
try:
reward = float(lines[-1])
except ValueError:
pass
except (ValueError, AttributeError):
pass
return reward
except asyncio.TimeoutError:
return 0.0
except Exception:
return 0.0
finally:
if sandbox:
try:
await asyncio.to_thread(sandbox.close)
await asyncio.to_thread(sandbox.shutdown)
except Exception:
pass
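The Morph `_run_script` above reads the reward from the last line of the sandbox output and falls back to 0.0 when nothing parses. A simplified sketch of that parsing step, assuming the output is available as a plain string, could look like this. `parse_reward` is a hypothetical helper name, and the sketch omits the `stdout` fallback used above.

```python
def parse_reward(text: str) -> float:
    """Return the float found on the last output line, or 0.0 if none parses."""
    stripped = text.strip() if text else ""
    if not stripped:
        return 0.0
    last_line = stripped.split("\n")[-1]
    try:
        return float(last_line)
    except ValueError:
        return 0.0


print(parse_reward("compiling...\nrunning tests\n0.75\n"))
print(parse_reward("Traceback (most recent call last): ..."))
```

Taking only the last line lets the evaluated script print diagnostics first, as long as it prints the reward last.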
def get_provider(provider_type: str = "e2b", **kwargs) -> CodeExecutionProvider:
"""Factory function to get the appropriate code execution provider.
Args:
provider_type: Type of provider to use ("e2b", "morph")
**kwargs: Additional arguments to pass to the provider
Returns:
An instance of CodeExecutionProvider
"""
num_parallel = kwargs.pop("num_parallel", 2)
if provider_type == "e2b":
# Extract E2B-specific arguments
e2b_router_url = kwargs.pop("e2b_router_url", None)
return E2BProvider(
num_parallel=num_parallel,
e2b_router_url=e2b_router_url,
)
elif provider_type == "morph":
# Extract Morph-specific arguments
morph_router_url = kwargs.pop("morph_router_url", None)
return MorphProvider(
num_parallel=num_parallel,
morph_router_url=morph_router_url,
)
else:
raise ValueError(f"Unknown provider type: {provider_type}")

@ -0,0 +1,19 @@
from .cf_scoring import score_submission
from .code_patcher import patch_code
from .ioi_scoring import SubtaskResult, score_subtask, score_subtasks
from .ioi_utils import add_includes
from .morph_client import get_morph_client_from_env
from .piston_client import get_piston_client_from_env, get_slurm_piston_endpoints
__all__ = [
"get_piston_client_from_env",
"get_slurm_piston_endpoints",
"get_morph_client_from_env",
"patch_code",
"score_submission",
"score_subtask",
"score_subtasks",
"add_includes",
"SubtaskResult",
]

@ -0,0 +1,146 @@
import asyncio
import os
from io import BytesIO
from typing import Literal
from async_lru import alru_cache
from .piston_client import PistonClient
from .utils import batched
async def score_single_test_case(
client: PistonClient,
problem_data: dict,
test_input: str,
test_output: str,
submission: str,
submission_language: str = "cpp",
) -> dict | bool:  # Piston response dict, or False if the request failed
if submission_language not in ["python", "cpp"]:
raise ValueError(f"Invalid submission language: {submission_language}")
try:
result = await client.send_execute(
{
"files": [
{"name": f"main.{submission_language}", "content": submission},
*(
[{"name": "checker.py", "content": problem_data["generated_checker"]}]
if problem_data["generated_checker"]
else []
),
{"name": "input.txt", "content": test_input},
{"name": "correct_output.txt", "content": test_output},
{
"name": "grader_config",
"content": "\n".join(
f"{key}={value}"
for key, value in {
"TIME_LIMIT": problem_data["time_limit"],
"MEMORY_LIMIT": problem_data["memory_limit"],
"INPUT_MODE": problem_data["input_mode"],
}.items()
),
},
],
"run_timeout": (problem_data["time_limit"] + 10) * 1000,
# +10 seconds hard limit. time limits are handled by the codeforces script
},
language="cf_python3" if submission_language == "python" else "c++17",
)
except Exception as e:
print(f"Error scoring submission: {e}")
return False
return result
@alru_cache(maxsize=32) # TODO make this configurable
async def get_generated_contest_tests(contest_id: str) -> dict[str, list[dict]]:
import pandas as pd
import aiofiles
import aiofiles.os
tests_folder = os.environ.get("CF_TESTS_FOLDER", None)
if not tests_folder:
raise ValueError(
"CF_TESTS_FOLDER environment variable not set! Please download the codeforces generated tests and set CF_TESTS_FOLDER to the folder path. See https://huggingface.co/datasets/open-r1/codeforces for more information."
)
if not await aiofiles.os.path.exists(tests_folder):
raise ValueError(
f"CF_TESTS_FOLDER path '{tests_folder}' does not exist! Please download the codeforces generated tests and set CF_TESTS_FOLDER to the folder path. See https://huggingface.co/datasets/open-r1/codeforces for more information."
)
parquet_path = os.path.join(tests_folder, f"test_cases_{int(contest_id):04d}.parquet")
if not await aiofiles.os.path.exists(parquet_path):
return {}
# Read parquet file asynchronously
async with aiofiles.open(parquet_path, "rb") as f:
content = await f.read()
df = pd.read_parquet(BytesIO(content))
# Group by problem_id and convert to dictionary of lists
grouped_tests = df.groupby("problem_id").apply(lambda x: x[["input", "output"]].to_dict("records")).to_dict()
return grouped_tests
async def get_generated_tests(problem_id: str) -> list[dict]:
contest_id = problem_id.split("/")[0]
return (await get_generated_contest_tests(contest_id)).get(problem_id, [])
async def score_submission(
client: PistonClient,
problem_data: dict,
submission: str,
test_batch_size: int = 1,
scoring_mode: Literal["pass_fail", "partial", "weighted_sum"] = "weighted_sum",
no_compile_reward: float = -0.1,
no_submission_reward: float = -1.0,
submission_language: str = "cpp",
) -> float | None:  # None when the problem has no test cases
if submission_language not in ["python", "cpp"]:
raise ValueError(f"Invalid submission language: {submission_language}")
test_cases = problem_data["official_tests"] + (await get_generated_tests(problem_data["id"]))
# invalid/not a coding problem
if test_cases is None or len(test_cases) == 0:
return None
# no code extracted
if not submission:
return no_submission_reward
passed_test_cases = 0
# run one batch, check if any of them failed (0 score): if so stop evaluating (assuming non partial score); otherwise continue with the next batch of test cases.
for test_batch_to_run in batched(test_cases, test_batch_size) if test_batch_size >= 1 else [test_cases]:
results = await asyncio.gather(
*[
asyncio.create_task(
score_single_test_case(
client, problem_data, test_case["input"], test_case["output"], submission, submission_language
)
)
for test_case in test_batch_to_run
]
)
if any(result and result["compile"]["code"] != 0 for result in results):
return no_compile_reward
tests_passed_results = [
result and result["run"]["code"] == 0 and result["run"]["stdout"].strip() == "1" for result in results
]
if scoring_mode == "pass_fail" and any(not test_passed for test_passed in tests_passed_results):
break
passed_test_cases += sum(1 for test_passed in tests_passed_results if test_passed)
pass_fail_score = 1.0 if passed_test_cases == len(test_cases) else 0.0
if scoring_mode == "pass_fail":
return pass_fail_score
elif scoring_mode == "partial":
return passed_test_cases / len(test_cases)
elif scoring_mode == "weighted_sum":
return pass_fail_score + 0.1 * (passed_test_cases / len(test_cases))
else:
raise ValueError(f"Invalid scoring mode: {scoring_mode}")
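The three scoring modes at the end of `score_submission` differ only in how the pass count is turned into a reward. The arithmetic can be sketched in isolation as follows. `final_score` is a hypothetical helper that takes the counts directly instead of running tests.

```python
def final_score(passed: int, total: int, scoring_mode: str = "weighted_sum") -> float:
    """Map (passed, total) test counts to a reward, mirroring the modes above."""
    pass_fail = 1.0 if passed == total else 0.0
    if scoring_mode == "pass_fail":
        return pass_fail
    if scoring_mode == "partial":
        return passed / total
    if scoring_mode == "weighted_sum":
        # full credit for passing everything, plus a small bonus for partial progress
        return pass_fail + 0.1 * (passed / total)
    raise ValueError(f"Invalid scoring mode: {scoring_mode}")


for mode in ("pass_fail", "partial", "weighted_sum"):
    print(mode, final_score(5, 10, mode), final_score(10, 10, mode))
```

With `weighted_sum`, a submission passing half the tests earns 0.05 while a fully correct one earns 1.1, so partial progress is rewarded without ever approaching the reward for a correct solution.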

@ -0,0 +1,123 @@
import re
def fix_python3_imports(source_code):
"""
Fix common import and function changes between Python 3 versions
Args:
source_code (str): The Python source code to update
Returns:
str: The updated source code
"""
# List of (pattern, replacement) pairs, applied in order
replacements = [
# Fix collections.abc imports (changed in Python 3.3+)
(
r"from collections import (Mapping|Sequence|Set|Container|MutableMapping|MutableSet|MutableSequence)",
r"from collections.abc import \1",
),
# Fix imp module deprecation (deprecated in 3.4); \b keeps "import importlib" untouched
(r"\bimport imp\b", r"import importlib"),
# Fix asyncio.async() to asyncio.ensure_future() (renamed in 3.4.4)
(r"asyncio\.async\(", r"asyncio.ensure_future("),
# Fix inspect.getargspec to inspect.getfullargspec (deprecated in 3.5)
(r"inspect\.getargspec", r"inspect.getfullargspec"),
# Fix array.array 'c' type code to 'b' (removed in 3.9)
(r"array\.array\('c'", r"array.array('b'"),
# Fix backslash line continuation with multiple newlines (Python-specific issue)
(r"\\(\r\n|\r|\n)+", "\\\n"),
# some solutions use getlogin() to check if they are debugging or on an actual submission
(r"(?:os\s*\.\s*)?getlogin\s*\(\s*\)", "False"),
# Fix usage of fractions.gcd (moved to math in 3.5)
# 1. Fix direct usage: fractions.gcd -> math.gcd
(r"\bfractions\.gcd\b", r"math.gcd"),
# 2. Fix 'from fractions import gcd, X' -> 'from fractions import X' (start/middle)
(r"(from\s+fractions\s+import\s+(?:\([^)]*)?)\bgcd\s*,\s*", r"\1"),
# 3. Fix 'from fractions import X, gcd' -> 'from fractions import X' (end)
(r"(from\s+fractions\s+import\s+.*?\S)\s*,\s*\bgcd(\s*\)?\s*(?:#.*)?)", r"\1\2"),
# 4. Remove standalone 'from fractions import gcd'; the math import is appended to the import section below
(r"from\s+fractions\s+import\s+\(?\s*gcd\s*\)?", r""),
]
lines = source_code.splitlines()
last_import = max(
[
i
for i, line in enumerate(lines)
if line.strip().startswith("import") or (line.strip().startswith("from") and "import" in line)
],
default=0,
)
import_section = "\n".join(lines[: last_import + 1])
# Start after the last import line so that line is not emitted twice
main_source = "\n".join(lines[last_import + 1 :])
if "fractions.gcd" in source_code and "import math" not in source_code:
import_section += "\nimport math"
elif "gcd" in source_code and "from math import gcd" not in source_code:
import_section += "\nfrom math import gcd"
if "set_int_max_str_digits" not in source_code:
import_section += "\nimport sys\nsys.set_int_max_str_digits(0)"
source_code = import_section + "\n" + main_source
# Apply each replacement
for pattern, replacement in replacements:
source_code = re.sub(pattern, replacement, source_code)
source_code = source_code.rstrip("\\")
return source_code
def fix_cpp_includes(source_code):
# has most of the useful functions
code_header = "#include <bits/stdc++.h>\n"
# use namespace std since models forget std:: often
if "using namespace std;" not in source_code and "std::" not in source_code:
code_header += "\nusing namespace std;\n\n"
return code_header + source_code
def is_patchable(lang):
return lang in ("python", "python3", "Python 3", "PyPy 3", "PyPy 3-64", "cpp") or "C++" in lang
def patch_code(text, lang):
if not text:
return text
if lang in ("python", "python3", "Python 3", "PyPy 3", "PyPy 3-64"):
return fix_python3_imports(text)
elif "cpp" in lang or "C++" in lang:
return fix_cpp_includes(text)
return text
tests = [
"""read = lambda: map(int, input().split())
n, m, z = read()
from fractions import gcd
ans = z // (n * m // gcd(n, m))
print(ans)""",
"""from fractions import Fraction,gcd
a,b,c,d = [int(x) for x in input().split()]
if a*d > b*c:
num = a*d-b*c
denom = a*d
else:
num = b*c-a*d
denom = b*c
div = gcd(num,denom)
print('%d/%d'%(num//div,denom//div))""",
]
if __name__ == "__main__":
for test in tests:
print("ORIGINAL:", test, sep="\n\n")
print("PATCHED:", patch_code(test, "Python 3"), sep="\n\n")
print("=" * 50)
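The three `fractions` import rules are easier to follow when applied on their own. The sketch below copies those patterns and runs them over single import lines. `strip_gcd_import` is a hypothetical helper; the full patcher above additionally appends the `math` import.

```python
import re

# The three import-rewriting rules for fractions.gcd, copied from the patcher above
GCD_IMPORT_RULES = [
    # gcd at the start or middle of the import list
    (r"(from\s+fractions\s+import\s+(?:\([^)]*)?)\bgcd\s*,\s*", r"\1"),
    # gcd at the end of the import list
    (r"(from\s+fractions\s+import\s+.*?\S)\s*,\s*\bgcd(\s*\)?\s*(?:#.*)?)", r"\1\2"),
    # gcd as the only imported name: drop the statement entirely
    (r"from\s+fractions\s+import\s+\(?\s*gcd\s*\)?", r""),
]


def strip_gcd_import(line: str) -> str:
    """Apply the rules in order, as the patcher's replacement loop does."""
    for pattern, replacement in GCD_IMPORT_RULES:
        line = re.sub(pattern, replacement, line)
    return line


samples = [
    "from fractions import gcd, Fraction",
    "from fractions import Fraction,gcd",
    "from fractions import gcd",
]
results = [strip_gcd_import(sample) for sample in samples]
for sample, result in zip(samples, results):
    print(repr(sample), "->", repr(result))
```

Order matters here: the standalone rule runs last, so it only fires when no other name is imported alongside `gcd`.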

@ -2,8 +2,9 @@ import asyncio
from dataclasses import asdict, dataclass, field
from typing import Union
from .piston_client import PistonClient
from .utils import batched, load_ioi_tests
from .ioi_utils import load_ioi_tests
from .piston_client import PistonClient, PistonError
from .utils import batched
@dataclass
@ -295,4 +296,40 @@ async def run_submission(
), # +3 seconds hard limit. time limits are handled by the ioi script
"run_memory_limit": problem["memory_limit"],
}
return await client.execute(data)
return await execute_ioi(client, data)
async def execute_ioi(client, data) -> tuple[str, str]:
"""
Requests to the IOI package return the score as a float in the stdout, as well as optional feedback/errors in stderr.
Returns a tuple of (score, feedback).
"""
response = await client.send_execute(data)
if "message" in response:
raise PistonError(response["message"])
if "compile" in response and response["compile"]["code"] != 0:
return "0", "Compilation error exit code " + str(response["compile"]["code"]) + "\n" + response["compile"][
"stderr"
]
if "run" not in response:
raise PistonError(response)
if response["run"]["code"] == 1 and "MemoryError" in response["run"]["stderr"]:
return "0", "Memory limit exceeded"
# successful result
if response["run"]["stdout"]:
return response["run"]["stdout"], response["run"]["stderr"]
if response["run"]["signal"] == "SIGKILL":
return "0", "Time limit exceeded"
# other issues
if response["run"]["code"] != 0:
raise PistonError(
f"language={response['language']}, version={response['version']}, exit code={response['run']['code']}, stderr={response['run']['stderr']}, signal={response['run']['signal']}"
)
return "0", "Unknown error"
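To illustrate how `execute_ioi` maps a response to a `(score, feedback)` pair, here is a simplified mirror of its checks applied to hand-written dictionaries. The response shapes are assumptions inferred from the fields read above, not taken from the Piston documentation, and `RuntimeError` stands in for `PistonError`.

```python
from typing import Tuple


def classify(response: dict) -> Tuple[str, str]:
    """Simplified mirror of the checks in execute_ioi (RuntimeError replaces PistonError)."""
    if "message" in response:
        raise RuntimeError(response["message"])
    compile_step = response.get("compile")
    if compile_step and compile_step["code"] != 0:
        return "0", f"Compilation error exit code {compile_step['code']}\n{compile_step['stderr']}"
    if "run" not in response:
        raise RuntimeError(str(response))
    run = response["run"]
    if run["code"] == 1 and "MemoryError" in run["stderr"]:
        return "0", "Memory limit exceeded"
    if run["stdout"]:
        return run["stdout"], run["stderr"]  # the score is whatever the grader printed
    if run["signal"] == "SIGKILL":
        return "0", "Time limit exceeded"
    if run["code"] != 0:
        raise RuntimeError(f"exit code={run['code']}, stderr={run['stderr']}")
    return "0", "Unknown error"


# Hypothetical responses, shaped after the keys accessed above
ok = {"run": {"code": 0, "stdout": "0.5", "stderr": "", "signal": None}}
tle = {"run": {"code": None, "stdout": "", "stderr": "", "signal": "SIGKILL"}}
compile_error = {"compile": {"code": 1, "stderr": "error: expected ';'"}}

print(classify(ok))
print(classify(tle))
print(classify(compile_error))
```

Note that a non-empty stdout takes precedence over the signal check, so a grader that printed a score before being killed is still scored.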

@ -1,6 +1,5 @@
from collections import defaultdict
from functools import lru_cache
from itertools import islice
from datasets import load_dataset
@ -40,13 +39,3 @@ def load_ioi_tests(year: int, problem_id: str) -> dict[str, tuple[str, str]]:
Load IOI tests for a given year and problem id.
"""
return load_ioi_tests_for_year(year)[problem_id]
def batched(iterable, n):
"Batch data into lists of length n. The last batch may be shorter."
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
return iterable
it = iter(iterable)
while batch := list(islice(it, n)):
yield batch
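A short usage sketch of the `batched` helper shown in this hunk, copied as-is. One detail worth knowing: because the function is a generator, the `return iterable` branch for `n < 1` ends iteration without yielding anything, which is presumably why `score_submission` guards the call with `test_batch_size >= 1`.

```python
from itertools import islice


def batched(iterable, n):
    "Batch data into lists of length n. The last batch may be shorter."
    if n < 1:
        return iterable  # inside a generator this just stops iteration
    it = iter(iterable)
    while batch := list(islice(it, n)):
        yield batch


chunks = list(batched("ABCDEFG", 3))
empty = list(batched("ABC", 0))
print(chunks)
print(empty)
```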

@ -0,0 +1,742 @@
import asyncio
import json
import logging
import os
import tempfile
from typing import Any, Dict, Optional, Tuple
from dotenv import load_dotenv
from open_r1.utils.import_utils import is_morph_available
# Replace direct imports with conditional imports
if is_morph_available():
from morphcloud.api import Instance, InstanceExecResponse, MorphCloudClient
else:
Instance = None
InstanceExecResponse = None
MorphCloudClient = None
# Silence verbose logs from dependencies
logging.getLogger("paramiko").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
class MorphCloudError(Exception):
pass
class MorphCloudExecutionClient:
def __init__(
self,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
spans_log_path: Optional[str] = None,
):
"""
Initialize the MorphCloud execution client.
Args:
api_key: Optional API key for MorphCloud. If not provided, will use MORPH_API_KEY env var.
base_url: Optional base URL for MorphCloud API. If not provided, will use default.
spans_log_path: Optional path for logging API call spans. Accepted but not used by this constructor.
"""
self.client = MorphCloudClient(api_key=api_key, base_url=base_url)
self._snapshot_lock = asyncio.Lock()
async def _prepare_instance(self, snapshot_id=None) -> Instance:
"""
Prepare and start a MorphCloud instance.
Args:
snapshot_id: Optional snapshot ID to use. If None, will get or create base snapshot.
Returns:
Instance: The ready-to-use MorphCloud instance
Raises:
TimeoutError: If instance fails to start or become ready
"""
if not snapshot_id:
snapshot = await self._get_or_create_base_snapshot()
snapshot_id = snapshot.id
instance = None  # so the except block can test it even if astart() times out
try:
instance = await self.client.instances.astart(
snapshot_id, ttl_seconds=600
) # Auto-terminate after 10 minutes
await instance.await_until_ready(timeout=300)
return instance
except asyncio.TimeoutError as e:
print(f"Timeout while preparing instance: {str(e)}")
if instance:
try:
await instance.astop()
except Exception:
pass
raise
async def _prepare_files(self, data: Dict[str, Any], temp_dir: str) -> Tuple[str, Dict[str, Any], Dict[str, str]]:
"""
Process files, determine problem ID, and prepare configuration.
Args:
data: Dictionary containing file information
temp_dir: Local temporary directory for file operations
Returns:
tuple: (problem_id, grader_config, local_files)
Raises:
ValueError: If problem ID cannot be determined
"""
# Extract problem ID
problem_id = None
graders_files = []
for file in data["files"]:
if file["name"].startswith("graders/") and file["name"].endswith(".cpp"):
potential_id = os.path.basename(file["name"]).split(".")[0]
if potential_id not in ["grader", "manager", "stub"]:
problem_id = potential_id
if file["name"].startswith("graders/"):
graders_files.append(file)
if not problem_id:
raise ValueError("Could not determine problem ID from files")
grader_config = {
"task_type": "Batch",
"code": problem_id,
"time_limit": data["run_timeout"] / 1000,
"memory_limit": data["run_memory_limit"] * 1024 * 1024,
}
for file in graders_files:
if "manager.cpp" in file["name"]:
grader_config["task_type"] = "Communication"
grader_config["task_type_parameters_Communication_num_processes"] = 1
grader_config["task_type_parameters_Communication_user_io"] = "std_io"
break
config_path = os.path.join(temp_dir, "grader_config.json")
with open(config_path, "w") as f:
json.dump(grader_config, f)
local_files = {"grader_config.json": config_path}
for file in data["files"]:
local_path = os.path.join(temp_dir, os.path.basename(file["name"]))
with open(local_path, "w") as f:
f.write(file["content"])
local_files[file["name"]] = local_path
return problem_id, grader_config, local_files
async def _upload_files(self, instance: Instance, local_files: Dict[str, str]) -> bool:
"""
Upload all necessary files to the instance.
Args:
instance: The MorphCloud instance
local_files: Dictionary mapping remote paths to local file paths
Returns:
bool: True if all uploads were successful
Raises:
TimeoutError: If uploads time out
"""
for remote_name, local_path in local_files.items():
target_path = f"/workspace/{remote_name}"
dir_path = os.path.dirname(target_path)
if dir_path != "/workspace":
await instance.aexec(f"mkdir -p {dir_path}")
await instance.aupload(local_path, target_path)
await instance.aupload(local_files["grader_config.json"], "/workspace/graders/grader_config.json")
return True
async def _compile_code(self, instance: Instance) -> InstanceExecResponse:
"""
Compile the code on the instance.
Args:
instance: The MorphCloud instance
Returns:
InstanceExecResponse: Result of compilation
Raises:
RuntimeError: If compilation fails
"""
compile_result = await instance.aexec("cd /workspace && ./compile")
if compile_result.exit_code != 0:
raise RuntimeError(f"Compilation error exit code {compile_result.exit_code}\n{compile_result.stderr}")
return compile_result
async def _run_tests(self, instance: Instance, data: Dict[str, Any]) -> Tuple[str, str]:
"""
Run tests and evaluate results.
Args:
instance: The MorphCloud instance
data: Dictionary containing runtime parameters
Returns:
tuple: (score, feedback)
Raises:
TimeoutError: If test execution times out
"""
hard_timeout = data["run_timeout"] / 1000 + 3
run_command = f"cd /workspace && timeout {hard_timeout}s ./run"
run_result = await instance.aexec(run_command)
if run_result.exit_code == 124 or run_result.exit_code == 137 or run_result.exit_code == 143:
return "0", "Time limit exceeded"
if run_result.exit_code != 0 and "Memory limit exceeded" in run_result.stderr:
return "0", "Memory limit exceeded"
if run_result.stdout:
return run_result.stdout.strip(), run_result.stderr.strip()
if run_result.exit_code != 0:
return (
"0",
f"Runtime error with exit code {run_result.exit_code}\n{run_result.stderr}",
)
return "0", "Unknown error"
async def _execute_with_instance(self, instance: Instance, data: Dict[str, Any], temp_dir: str) -> Tuple[str, str]:
"""Execute code using a prepared instance.
Args:
instance: Ready MorphCloud instance
data: Execution data
temp_dir: Temporary directory for file operations
Returns:
Tuple of (score, feedback)
Raises:
Exception: Passes through exceptions for retry handling
"""
await instance.await_until_ready(timeout=300)
problem_id, grader_config, local_files = await self._prepare_files(data, temp_dir)
await self._upload_files(instance, local_files)
try:
await self._compile_code(instance)
except RuntimeError as e:
return "0", str(e)
score, feedback = await self._run_tests(instance, data)
return score, feedback
async def _execute(self, data: Dict[str, Any]) -> Tuple[str, str]:
"""
Internal implementation of execute with no retry logic.
Args:
data: Dictionary containing execution data
Returns:
Tuple of (score, feedback)
Raises:
Exception: If execution fails
"""
instance = None
# Set timeouts to ensure we don't block indefinitely
# INSTANCE_TIMEOUT = 300 # 5 minutes for instance operations
TOTAL_EXECUTION_TIMEOUT = 600 # 10 minutes total execution time
with tempfile.TemporaryDirectory(prefix="morph_exec_") as temp_dir:
snapshot = await self._get_or_create_base_snapshot()
instance = await self.client.instances.astart(
snapshot.id, ttl_seconds=600
) # Auto-terminate after 10 minutes
async with instance:
# Use asyncio.wait_for to add overall timeout to the execution process
return await asyncio.wait_for(
self._execute_with_instance(instance, data, temp_dir),
timeout=TOTAL_EXECUTION_TIMEOUT,
)
async def execute(self, data: Dict[str, Any]) -> Tuple[str, str]:
"""
Execute code on MorphCloud based on the provided data, retrying the whole pipeline on failure.
Each attempt runs the steps below. If any step raises, the entire attempt is retried
(up to 4 retries, with exponential backoff for errors other than timeouts):
1. Start an instance from the base snapshot
2. Prepare and upload files
3. Compile code
4. Run tests
Args:
data: Dictionary containing:
- files: List of file objects with name and content fields
- run_timeout: Timeout in milliseconds
- run_memory_limit: Memory limit in MB
Returns:
Tuple of (score, feedback) where:
- score is a string representation of a float between 0.0 and 1.0
- feedback is a string with execution details
"""
# TODO: would be faster to pass info about the subtask as well to create a snapshot per subtask
# would cache the uploads of all files other than the submission: input.txt, correct_output.txt, grader files
# rather than reusing the snapshot that only has the compile/run scripts on it
# currently, run_submission -> client.execute(data) does not easily pass subtask info
# Retry configuration
max_retries = 4
base_delay = 1.0
# Try execution with retries and exponential backoff
for attempt in range(max_retries + 1):
try:
return await self._execute(data)
except asyncio.TimeoutError:
if attempt < max_retries:
print(f"Execution timed out, retrying ({attempt + 1}/{max_retries})")
else:
return "0", "Execution timed out after multiple retries"
except Exception as e:
# Calculate exponential backoff
if attempt < max_retries:
retry_delay = min(base_delay * (2**attempt), 30) # Exponential backoff, capped at 30 seconds
print(
f"Execution failed with {type(e).__name__}: {str(e)}, retrying in {retry_delay:.2f}s ({attempt + 1}/{max_retries})"
)
await asyncio.sleep(retry_delay)
else:
print(f"Execution failed after {max_retries} retries: {type(e).__name__}: {str(e)}")
return "0", f"Execution failed after multiple retries: {str(e)}"
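The retry loop in `execute` sleeps `min(base_delay * 2**attempt, 30)` seconds after each failed attempt that is not a timeout. The resulting schedule can be computed directly. `backoff_delays` is a hypothetical helper used only for illustration.

```python
from typing import List


def backoff_delays(max_retries: int = 4, base_delay: float = 1.0, cap: float = 30.0) -> List[float]:
    """Delays slept before each retry; the final failed attempt does not sleep."""
    return [min(base_delay * (2**attempt), cap) for attempt in range(max_retries)]


delays = backoff_delays()
longer = backoff_delays(max_retries=7)
print(delays)
print(longer)
```

With the defaults used above, the worst case adds 15 seconds of sleeping across the four retries; the 30-second cap only comes into play if `max_retries` is raised beyond five.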
async def _get_or_create_base_snapshot(self):
"""Get or create a snapshot with the necessary dependencies and scripts for evaluation."""
async with self._snapshot_lock:
base_snapshots = await self.client.snapshots.alist(digest="ioi-evaluation-morph")
if not base_snapshots:
print("Creating base snapshot with build-essential cmake and g++")
# Create base snapshot with minimal specs
base_snapshot = await self.client.snapshots.acreate(
vcpus=2,
memory=4096,
disk_size=10240,
metadata={"purpose": "ioi_evaluation"},
)
# Start a temporary instance from the base snapshot
temp_instance = await self.client.instances.astart(
base_snapshot.id, ttl_seconds=900
) # Auto-terminate after 15 minutes
try:
# Wait for the instance to be ready
await temp_instance.await_until_ready(timeout=300)
# Get script contents
compile_script = await self._get_compile_script()
run_script = await self._get_run_script()
# Use temporary directory to store scripts
with tempfile.TemporaryDirectory(prefix="morph_setup_") as temp_dir:
# Create paths for script files
compile_path = os.path.join(temp_dir, "compile.sh")
run_path = os.path.join(temp_dir, "run.sh")
# Write scripts to temp files
with open(compile_path, "w") as f:
f.write(compile_script)
with open(run_path, "w") as f:
f.write(run_script)
async with temp_instance:
# Install dependencies
await temp_instance.aexec("apt-get update && apt-get install -y build-essential cmake g++")
# Create workspace directory
await temp_instance.aexec(
"mkdir -p /workspace && mkdir -p /workspace/graders && chmod 777 /workspace"
)
# Upload scripts to instance
await temp_instance.aupload(compile_path, "/workspace/compile")
await temp_instance.aupload(run_path, "/workspace/run")
# Make scripts executable
await temp_instance.aexec("chmod +x /workspace/compile /workspace/run")
# Create snapshot from the prepared instance
final_snapshot = await temp_instance.asnapshot(digest="ioi-evaluation-morph")
except Exception as e:
# Ensure instance is stopped if anything fails
await temp_instance.astop()
raise
else:
final_snapshot = base_snapshots[0]
return final_snapshot
async def _get_compile_script(self):
"""Get the compile script content."""
return """#!/bin/bash
manager_files=() # Array to store manager filenames
current_dir="$(pwd)"
# Checker compilation path
checker_dir="$current_dir/checker"
checker_src="$checker_dir/checker.cpp"
if [ -e "$checker_src" ]; then
echo "Compiling checker"
checker_exe="$checker_dir/checker"
# Check the compiler's exit status directly (a chmod in between would overwrite $?)
if ! g++ -x c++ -std=gnu++17 -O2 -o "$checker_exe" "$checker_src"; then
echo "Could not compile checker" >&2
exit 1
fi
chmod +x "$checker_exe"
echo "Compiled checker"
else
echo "No checker found at $checker_src"
fi
# Graders path
graders_dir="$current_dir/graders"
if [ ! -e "$graders_dir" ]; then
echo "Grader folder was not found" >&2
exit 1
fi
# Find and compile manager if it exists
manager_src="$graders_dir/manager.cpp"
if [ -e "$manager_src" ]; then
echo "Compiling manager"
manager_exe="$graders_dir/manager"
# Check the compiler's exit status directly (a chmod in between would overwrite $?)
if ! g++ -x c++ -std=gnu++17 -O2 -o "$manager_exe" "$manager_src"; then
echo "Could not compile manager" >&2
exit 1
fi
chmod +x "$manager_exe"
manager_files+=("manager")
fi
# Process other graders
graders_list=($(ls "$graders_dir" | grep -v 'manager.cpp'))
for grader_name in "${graders_list[@]}"; do
manager_files+=("$grader_name")
done
# Extract problem name and compile necessary files
problem_name='?'
for file in "${manager_files[@]}"; do
if [[ "$file" == *.h && "$file" != "testlib.h" ]]; then
problem_name="${file%.h}"
echo "Problem name: $problem_name"
break
fi
done
files_to_compile=("graders/$problem_name.cpp")
[ -e graders/grader.cpp ] && files_to_compile+=("graders/grader.cpp")
[ -e graders/stub.cpp ] && files_to_compile+=("graders/stub.cpp")
g++ -DEVAL -std=gnu++17 -O2 -pipe -s -o graders/"$problem_name" "${files_to_compile[@]}"
if [ $? -ne 0 ]; then
echo "Failed to compile $problem_name" >&2
exit 1
fi
chmod +x graders/"$problem_name"
echo "Compiled $problem_name from ${files_to_compile[@]} successfully"
echo "Manager files: ${manager_files[@]}"
"""
async def _get_run_script(self):
"""Get the run script content."""
return """#!/usr/bin/env bash
# disable stack limit so you don't get RE with recursion
ulimit -s unlimited
# some problems have 10MB+ input/output files in their test cases and you might get RE. uncomment if needed
# ulimit -f 2097152
# Check if grader_config.json exists
if [ ! -f "graders/grader_config.json" ]; then
echo "Error: graders/grader_config.json not found" >&2
echo "Current directory contents:" >&2
find . -type f -o -type d | sed -e 's/[^-][^\/]*\// |/g' -e 's/|\([^ ]\)/|-\1/' >&2
exit 1
fi
# Read task type, code, and time limit from grader_config.json using grep and sed
TASK_TYPE=$(grep -o '"task_type":[^,}]*' graders/grader_config.json | sed 's/"task_type":\\s*"\\([^"]*\\)"/\\1/')
TASK_NAME=$(grep -o '"code":[^,}]*' graders/grader_config.json | sed 's/"code":\\s*"\\([^"]*\\)"/\\1/')
TIME_LIMIT=$(grep -o '"time_limit":[^,}]*' graders/grader_config.json | sed 's/"time_limit":\\s*\\([^,}]*\\)/\\1/')
MEMORY_LIMIT=$(grep -o '"memory_limit":[^,}]*' graders/grader_config.json | sed 's/"memory_limit":\\s*\\([^,}]*\\)/\\1/')
TASK_EXECUTABLE="graders/$TASK_NAME"
# Set memory limit in KB (convert from bytes)
MEMORY_LIMIT_KB=0
if [ -n "$MEMORY_LIMIT" ]; then
MEMORY_LIMIT_KB=$(($MEMORY_LIMIT / 1024))
# Set the memory limit for the entire script and all child processes
ulimit -v $MEMORY_LIMIT_KB
fi
# "Securely" handle the correct output file
CORRECT_OUTPUT=""
if [ -f "correct_output.txt" ]; then
# Read the content and immediately remove the file
CORRECT_OUTPUT=$(cat correct_output.txt)
rm -f correct_output.txt
fi
# Create a temporary file for solution output
SOLUTION_OUTPUT=$(mktemp)
# Global variables for process tracking
declare -a ALL_PIDS
declare -a FIFO_DIRS
# Define cleanup function - simplified assuming timeout exists
function cleanup {
# Kill all tracked processes silently
exec 2>/dev/null
for pid in "${ALL_PIDS[@]:-}"; do
kill -9 "$pid" 2>/dev/null || true
done
# Clean up FIFO directories
for dir in "${FIFO_DIRS[@]:-}"; do
[ -d "$dir" ] && rm -rf "$dir"
done
# Clean up temporary files
rm -f "$SOLUTION_OUTPUT" || true
exec 2>&2
}
# Set up signal handling
trap cleanup EXIT INT TERM
# Function to handle exit codes consistently across task types
function handle_exit_code {
local exit_code=$1
# Check for known timeout exit codes:
# - 124: standard timeout exit code
# - 137: SIGKILL (128+9), used for hard timeouts
# - 143: SIGTERM (128+15), can also be used for timeouts
if [ $exit_code -eq 124 ] || [ $exit_code -eq 137 ] || [ $exit_code -eq 143 ]; then
echo "0"
echo "Time limit exceeded (${TIME_LIMIT}s)" >&2
return 124
# All other non-zero exit codes should be treated as runtime errors
elif [ $exit_code -ne 0 ]; then
echo "0"
echo "Runtime error with exit code $exit_code" >&2
return $exit_code
fi
# Success case - return 0
return 0
}
# Function to run a command with timeout (simplified assuming timeout exists)
function run_with_timeout {
local soft_limit=$1; shift
timeout --preserve-status "$soft_limit" "$@"
return $?
}
case "$TASK_TYPE" in
"Batch")
# Simple batch execution with timeout
run_with_timeout "$TIME_LIMIT" ./$TASK_EXECUTABLE < input.txt > "$SOLUTION_OUTPUT"
exit_code=$?
# Handle non-zero exit codes
handle_exit_code $exit_code
if [ $? -ne 0 ]; then
exit $?
fi
# Check the output if we have a correct output
if [ -n "$CORRECT_OUTPUT" ]; then
# Restore the correct output file
echo "$CORRECT_OUTPUT" > correct_output.txt
# Check if there's a custom checker
if [ -f "checker/checker" ]; then
# Let the checker handle everything
./checker/checker input.txt correct_output.txt "$SOLUTION_OUTPUT"
exit $?
else
# Simple diff-based checking
if diff -bq <(echo "$CORRECT_OUTPUT") "$SOLUTION_OUTPUT" >/dev/null; then
echo "1"
echo "Output is correct (diff)" >&2
else
echo "0"
echo "Output isn't correct (diff)" >&2
exit 0
fi
fi
else
# If no correct output was provided, just output the solution's output
cat "$SOLUTION_OUTPUT"
fi
;;
"Communication")
# Read Communication-specific parameters
NUM_PROCESSES=$(grep -o '"task_type_parameters_Communication_num_processes":[^,}]*' graders/grader_config.json | sed 's/.*:\\s*\\([0-9]*\\)/\\1/' || true)
if [ -z "$NUM_PROCESSES" ]; then
NUM_PROCESSES=1
fi
USER_IO=$(grep -o '"task_type_parameters_Communication_user_io":[^,}]*' graders/grader_config.json | sed 's/.*:\\s*"\\([^"]*\\)"/\\1/' || echo "std_io")
# Read custom manager arguments if they exist
MANAGER_CUSTOM_ARGS=""
if grep -q '"task_type_parameters_Communication_manager_args"' graders/grader_config.json; then
MANAGER_CUSTOM_ARGS=$(grep -o '"task_type_parameters_Communication_manager_args":[^,}]*' graders/grader_config.json | sed 's/.*:\\s*"\\([^"]*\\)"/\\1/')
fi
# Create temporary directories for FIFOs
for i in $(seq 0 $((NUM_PROCESSES-1))); do
FIFO_DIRS[$i]=$(mktemp -d)
# Create FIFOs for this process
mkfifo "${FIFO_DIRS[$i]}/u${i}_to_m"
mkfifo "${FIFO_DIRS[$i]}/m_to_u${i}"
chmod 755 "${FIFO_DIRS[$i]}"
chmod 666 "${FIFO_DIRS[$i]}/u${i}_to_m" "${FIFO_DIRS[$i]}/m_to_u${i}"
done
# Prepare manager arguments
MANAGER_ARGS=""
for i in $(seq 0 $((NUM_PROCESSES-1))); do
MANAGER_ARGS="$MANAGER_ARGS ${FIFO_DIRS[$i]}/u${i}_to_m ${FIFO_DIRS[$i]}/m_to_u${i}"
done
# Add custom manager arguments if specified
if [ -n "$MANAGER_CUSTOM_ARGS" ]; then
MANAGER_ARGS="$MANAGER_ARGS $MANAGER_CUSTOM_ARGS"
fi
# Start all user processes first
for i in $(seq 0 $((NUM_PROCESSES-1))); do
if [ "$USER_IO" = "fifo_io" ]; then
# Pass FIFOs as arguments
ARGS="${FIFO_DIRS[$i]}/m_to_u${i} ${FIFO_DIRS[$i]}/u${i}_to_m"
if [ "$NUM_PROCESSES" -ne 1 ]; then
ARGS="$ARGS $i"
fi
./$TASK_EXECUTABLE $ARGS &
ALL_PIDS+=($!)
else
# Use stdin/stdout redirection
if [ "$NUM_PROCESSES" -ne 1 ]; then
./$TASK_EXECUTABLE "$i" < "${FIFO_DIRS[$i]}/m_to_u${i}" > "${FIFO_DIRS[$i]}/u${i}_to_m" 2>/dev/null &
ALL_PIDS+=($!)
else
./$TASK_EXECUTABLE < "${FIFO_DIRS[$i]}/m_to_u${i}" > "${FIFO_DIRS[$i]}/u${i}_to_m" 2>/dev/null &
ALL_PIDS+=($!)
fi
fi
done
# Run the manager with timeout using direct pipe from input.txt
run_with_timeout "$TIME_LIMIT" ./graders/manager $MANAGER_ARGS < input.txt > "$SOLUTION_OUTPUT"
exit_code=$?
# Handle non-zero exit codes
handle_exit_code $exit_code
if [ $? -ne 0 ]; then
exit $?
fi
# Check the output if we have a correct output AND there's a checker (otherwise we assume the manager handles everything)
if [ -n "$CORRECT_OUTPUT" ] && [ -f "checker/checker" ]; then
# Restore the correct output file
echo "$CORRECT_OUTPUT" > correct_output.txt
# Let the checker handle it
./checker/checker input.txt correct_output.txt "$SOLUTION_OUTPUT"
exit $?
else
# we assume the manager handles it
cat "$SOLUTION_OUTPUT"
fi
;;
*)
echo "0"
echo "Unsupported task type \"$TASK_TYPE\"" >&2
exit 1
;;
esac
"""
def get_morph_client_from_env(session=None) -> MorphCloudExecutionClient:
"""
Creates a MorphCloudExecutionClient instance using environment variables.
Environment variables:
MORPH_API_KEY: API key for MorphCloud
Args:
session: Optional aiohttp.ClientSession to use for HTTP requests
Returns:
MorphCloudExecutionClient: A configured MorphCloud execution client
"""
if not is_morph_available():
raise ImportError(
"MorphCloud is not available and required for this function. Please install MorphCloud with "
"`pip install morphcloud` and add an API key to a `.env` file."
)
load_dotenv()
api_key = os.environ.get("MORPH_API_KEY")
if not api_key:
raise ValueError("MORPH_API_KEY environment variable is required")
return MorphCloudExecutionClient(api_key=api_key)
# noqa: W293
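The grader script's `handle_exit_code` function collapses three timeout-like exit codes into one "time limit exceeded" outcome and treats every other non-zero code as a runtime error. A minimal Python mirror of that mapping is sketched below. The name `classify_exit_code` and its return shape are hypothetical and not part of the repository.

```python
# Hypothetical Python mirror of the shell function `handle_exit_code`.
# 124: timeout(1) default, 137: SIGKILL (128+9), 143: SIGTERM (128+15)
TIMEOUT_EXIT_CODES = {124, 137, 143}


def classify_exit_code(exit_code: int, time_limit: str) -> tuple[int, str]:
    """Return (status, feedback) following the shell function's branches.

    status is 0 on success, 124 for any timeout-like code, and the
    original exit code for every other failure.
    """
    if exit_code in TIMEOUT_EXIT_CODES:
        return 124, f"Time limit exceeded ({time_limit}s)"
    if exit_code != 0:
        return exit_code, f"Runtime error with exit code {exit_code}"
    return 0, ""
```

Because the script calls `timeout --preserve-status`, a process stopped by the default SIGTERM typically reports 143 rather than 124, which is presumably why all three codes are checked.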


@ -14,16 +14,23 @@ class PistonError(Exception):
@lru_cache(maxsize=1)
def get_piston_client_from_env():
def get_piston_client_from_env(session=None):
piston_endpoints = os.getenv("PISTON_ENDPOINTS")
if piston_endpoints is None:
raise ValueError(
"For IOI problems Piston endpoints running our IOI package are required. Please add a list of valid Piston endpoints to a PISTON_ENDPOINTS varialbe in a `.env` file."
"For IOI/CF problems Piston endpoints running our IOI package are required. Please add a list of valid Piston endpoints to a PISTON_ENDPOINTS variable in a `.env` file."
)
piston_endpoints = piston_endpoints.split(",") if piston_endpoints != "slurm" else get_slurm_piston_endpoints()
piston_endpoints = sorted(
piston_endpoints.split(",") if piston_endpoints != "slurm" else get_slurm_piston_endpoints()
)
gpu_nb = int(os.getenv("LOCAL_RANK", 0)) # per-GPU index
world = int(os.getenv("WORLD_SIZE", 1)) # total GPUs
if world > 1:
print(f"Using a subset of piston endpoints for GPU#{gpu_nb}")
piston_endpoints = piston_endpoints[gpu_nb::world]
random.shuffle(piston_endpoints)
max_requests_per_endpoint = os.getenv("PISTON_MAX_REQUESTS_PER_ENDPOINT", "1")
return PistonClient(piston_endpoints, max_requests_per_endpoint=int(max_requests_per_endpoint))
return PistonClient(piston_endpoints, session, max_requests_per_endpoint=int(max_requests_per_endpoint))
class PistonClient:
@ -57,6 +64,8 @@ class PistonClient:
):
self.max_requests_per_endpoint = max_requests_per_endpoint
self.base_endpoints = [base_endpoint] if isinstance(base_endpoint, str) else base_endpoint
if len(self.base_endpoints) == 0:
raise ValueError("No Piston endpoints provided. Please check your PISTON_ENDPOINTS environment variable.")
self.endpoint_ids = {endpoint: i for i, endpoint in enumerate(self.base_endpoints)}
self._session = session
@ -73,7 +82,7 @@ class PistonClient:
def session(self):
if self._session is None:
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(sock_read=10),
timeout=aiohttp.ClientTimeout(sock_read=30),
connector=aiohttp.TCPConnector(
limit=self.max_requests_per_endpoint * len(self.base_endpoints),
ttl_dns_cache=300,
@ -112,41 +121,6 @@ class PistonClient:
async def get_supported_runtimes(self):
return await self._send_to_all("runtimes", method="get")
async def execute(self, data) -> tuple[str, str]:
"""
Requests to the IOI package return the score as a float in the stdout, as well as optional feedback/errors in stderr.
Returns a tuple of (score, feedback).
"""
response = await self._send_execute(data)
if "message" in response:
raise PistonError(response["message"])
if "compile" in response and response["compile"]["code"] != 0:
return "0", "Compilation error exit code " + str(response["compile"]["code"]) + "\n" + response["compile"][
"stderr"
]
if "run" not in response:
raise PistonError(response)
if response["run"]["code"] == 1 and "MemoryError" in response["run"]["stderr"]:
return "0", "Memory limit exceeded"
# successful result
if response["run"]["stdout"]:
return response["run"]["stdout"], response["run"]["stderr"]
if response["run"]["signal"] == "SIGKILL":
return "0", "Time limit exceeded"
# other issues
if response["run"]["code"] != 0:
raise PistonError(
f"language={response['language']}, version={response['version']}, exit code={response['run']['code']}, stderr={response['run']['stderr']}, signal={response['run']['signal']}"
)
return "0", "Unknown error"
async def _check_failed_endpoint(self, endpoint):
async with self._endpoint_failures_lock:
if endpoint in self._unhealthy_endpoints:
@ -157,14 +131,15 @@ class PistonClient:
except Exception as e:
print(f"Error checking endpoint {endpoint}, dropping it ({e})")
self._unhealthy_endpoints.add(endpoint)
if len(self._unhealthy_endpoints) >= len(self.base_endpoints):
raise PistonError("All endpoints are unhealthy. Please check your Piston workers.")
async def _send_execute(self, data):
async def send_execute(self, data, language="cms_ioi", max_retries=5):
data = data | {
"language": "cms_ioi",
"language": language,
"version": "*",
}
max_retries = 5
base_delay = 1.0
status = None
@ -182,7 +157,7 @@ class PistonClient:
res_json = await response.json(content_type=None)
if status != 200:
raise PistonError(f"Server error. status={status}")
raise PistonError(f"Server error. status={status}. {res_json}")
if res_json is None:
raise PistonError(f"Empty response. status={status}")
# piston overloaded
@ -197,7 +172,7 @@ class PistonClient:
delay = min(base_delay * (2**attempt), 10) # Exponential backoff, capped at 10 seconds
jitter = delay * 0.2 * (2 * asyncio.get_event_loop().time() % 1 - 0.5) # Add ±10% jitter
retry_delay = delay + jitter
print(f"Retrying in {retry_delay} seconds [{self.endpoint_ids[endpoint]}] {endpoint}")
print(f"Retrying in {retry_delay:.2f} seconds [{self.endpoint_ids[endpoint]}] {endpoint} - {e}")
# special case: worker died
if isinstance(e, aiohttp.ClientConnectionError) and "Connect call failed" in str(e):
@ -209,8 +184,7 @@ class PistonClient:
await asyncio.sleep(retry_delay)
else:
print(f"Giving up on retries. {e}")
raise e
await self._check_failed_endpoint(endpoint)
except Exception as e:
print(f"Propagating exception {type(e)}: {e}")
raise e
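Two pieces of arithmetic in this file's diff are easy to misread: the per-rank endpoint slice (`piston_endpoints[gpu_nb::world]`) and the capped exponential backoff with jitter. The sketch below reproduces both in isolation. The function names are hypothetical, and the jitter source is an assumption: the diff derives it from the event loop clock, whereas here it is passed in as `u` or drawn from `random.random()`.

```python
import random
from typing import List, Optional


def shard_endpoints(endpoints: List[str], local_rank: int, world_size: int) -> List[str]:
    """Give each rank every world_size-th endpoint of the sorted list."""
    ordered = sorted(endpoints)
    return ordered[local_rank::world_size] if world_size > 1 else ordered


def retry_delay(attempt: int, base_delay: float = 1.0, cap: float = 10.0, u: Optional[float] = None) -> float:
    """Exponential backoff capped at `cap`, with jitter in [-10%, +10%).

    `u` must lie in [0, 1); it defaults to random.random() (an assumption).
    """
    if u is None:
        u = random.random()
    delay = min(base_delay * (2**attempt), cap)
    jitter = delay * 0.2 * (u - 0.5)
    return delay + jitter
```

Sorting before slicing presumably ensures every rank sees the same order, so the slices do not overlap. Note that `LOCAL_RANK` is a per-node index, so on multi-node runs two ranks on different nodes could end up with the same slice.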


@ -0,0 +1,11 @@
from itertools import islice
def batched(iterable, n):
"Batch data into lists of length n. The last batch may be shorter."
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
return iterable
it = iter(iterable)
while batch := list(islice(it, n)):
yield batch
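One detail of this helper is worth spelling out: because the function body contains `yield`, it is a generator, so `return iterable` for `n < 1` ends the iteration without producing anything rather than handing the iterable back. The sketch below shows that behavior next to a stricter variant. `batched_returning` copies the shape of the helper above, and `batched_strict` is a hypothetical alternative, not the repository's code.

```python
from itertools import islice


def batched_returning(iterable, n):
    # Same shape as the helper in the diff: a generator with an early `return`.
    if n < 1:
        return iterable  # ends the generator; the value is never yielded
    it = iter(iterable)
    while batch := list(islice(it, n)):
        yield batch


def batched_strict(iterable, n):
    # Hypothetical variant that rejects n < 1 instead of yielding nothing.
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while batch := list(islice(it, n)):
        yield batch


print(list(batched_returning("ABCDEFG", 3)))  # [['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]
print(list(batched_returning("ABCDEFG", 0)))  # []
```

With the strict variant, the `ValueError` surfaces when the generator is first advanced, not when the function is called.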

65
src/open_r1/utils/data.py Normal file

@ -0,0 +1,65 @@
import logging
import datasets
from datasets import DatasetDict, concatenate_datasets
from ..configs import ScriptArguments
logger = logging.getLogger(__name__)
def get_dataset(args: ScriptArguments) -> DatasetDict:
"""Load a dataset or a mixture of datasets based on the configuration.
Args:
args (ScriptArguments): Script arguments containing dataset configuration.
Returns:
DatasetDict: The loaded datasets.
"""
if args.dataset_name and not args.dataset_mixture:
logger.info(f"Loading dataset: {args.dataset_name}")
return datasets.load_dataset(args.dataset_name, args.dataset_config)
elif args.dataset_mixture:
logger.info(f"Creating dataset mixture with {len(args.dataset_mixture.datasets)} datasets")
seed = args.dataset_mixture.seed
datasets_list = []
for dataset_config in args.dataset_mixture.datasets:
logger.info(f"Loading dataset for mixture: {dataset_config.id} (config: {dataset_config.config})")
ds = datasets.load_dataset(
dataset_config.id,
dataset_config.config,
split=dataset_config.split,
)
if dataset_config.columns is not None:
ds = ds.select_columns(dataset_config.columns)
if dataset_config.weight is not None:
ds = ds.shuffle(seed=seed).select(range(int(len(ds) * dataset_config.weight)))
logger.info(
f"Subsampled dataset '{dataset_config.id}' (config: {dataset_config.config}) with weight={dataset_config.weight} to {len(ds)} examples"
)
datasets_list.append(ds)
if datasets_list:
combined_dataset = concatenate_datasets(datasets_list)
combined_dataset = combined_dataset.shuffle(seed=seed)
logger.info(f"Created dataset mixture with {len(combined_dataset)} examples")
if args.dataset_mixture.test_split_size is not None:
combined_dataset = combined_dataset.train_test_split(
test_size=args.dataset_mixture.test_split_size, seed=seed
)
logger.info(
f"Split dataset into train and test sets with test size: {args.dataset_mixture.test_split_size}"
)
return combined_dataset
else:
return DatasetDict({"train": combined_dataset})
else:
raise ValueError("No datasets were loaded from the mixture configuration")
else:
raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided")
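The mixture logic in `get_dataset` has three steps: subsample each dataset to `int(len(ds) * weight)` rows after a seeded shuffle, concatenate, then shuffle the combined set with the same seed. The sketch below walks through the same arithmetic on plain Python lists so it runs without the `datasets` library. `subsample` and `mix` are hypothetical names, and the row order will differ from what `Dataset.shuffle` produces.

```python
import random
from typing import List, Optional, Tuple


def subsample(rows: list, weight: float, seed: int) -> list:
    """Shuffle with a fixed seed, then keep the first int(len * weight) rows."""
    shuffled = rows[:]
    random.Random(seed).shuffle(shuffled)
    return shuffled[: int(len(shuffled) * weight)]


def mix(parts: List[Tuple[list, Optional[float]]], seed: int) -> list:
    """Combine (rows, weight) pairs; a weight of None keeps every row."""
    combined = []
    for rows, weight in parts:
        combined.extend(rows if weight is None else subsample(rows, weight, seed))
    random.Random(seed).shuffle(combined)
    return combined


mixture = mix([(list(range(10)), 0.5), (list(range(100, 107)), None)], seed=0)
print(len(mixture))  # 12 = int(10 * 0.5) + 7
```

Since `int()` truncates, a weight of 0.5 on a 7-row dataset keeps 3 rows, not 4.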


@ -7,6 +7,7 @@ from .hub import get_gpu_count_for_vllm, get_param_count_from_repo_id
if TYPE_CHECKING:
from trl import GRPOConfig, SFTConfig, ModelConfig
import base64
import os
@ -24,7 +25,11 @@ VLLM_SLURM_PREFIX = [
def register_lighteval_task(
configs: Dict[str, str], eval_suite: str, task_name: str, task_list: str, num_fewshot: int = 0
configs: Dict[str, str],
eval_suite: str,
task_name: str,
task_list: str,
num_fewshot: int = 0,
):
"""Registers a LightEval task configuration.
@ -46,10 +51,10 @@ def register_lighteval_task(
LIGHTEVAL_TASKS = {}
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "math_500", "math_500", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime24", "aime24", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime25", "aime25", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "gpqa", "gpqa:diamond", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "lighteval", "math_500", "math_500", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "lighteval", "aime24", "aime24", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "lighteval", "aime25", "aime25", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "lighteval", "gpqa", "gpqa:diamond", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "extended", "lcb", "lcb:codegeneration", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "extended", "lcb_v4", "lcb:codegeneration_v4", 0)
@ -62,7 +67,9 @@ SUPPORTED_BENCHMARKS = get_lighteval_tasks()
def run_lighteval_job(
benchmark: str, training_args: Union["SFTConfig", "GRPOConfig"], model_args: "ModelConfig"
benchmark: str,
training_args: Union["SFTConfig", "GRPOConfig"],
model_args: "ModelConfig",
) -> None:
task_list = LIGHTEVAL_TASKS[benchmark]
model_name = training_args.hub_model_id
@ -72,7 +79,7 @@ def run_lighteval_job(
if get_param_count_from_repo_id(model_name) >= 30_000_000_000:
tensor_parallel = True
else:
num_gpus = 8
num_gpus = 2 # Hack while cluster is full
tensor_parallel = False
cmd = VLLM_SLURM_PREFIX.copy()
@ -88,7 +95,10 @@ def run_lighteval_job(
f"{model_args.trust_remote_code}",
]
if training_args.system_prompt is not None:
cmd_args.append(f"--system_prompt={training_args.system_prompt}")
# encode to base64 to avoid issues with special characters
# we decode in the sbatch script
prompt_encoded = base64.b64encode(training_args.system_prompt.encode()).decode()
cmd_args.append(prompt_encoded)
cmd[-1] += " " + " ".join(cmd_args)
subprocess.run(cmd, check=True)
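The system prompt is now base64-encoded before being appended to the sbatch command line, so quotes, newlines and shell metacharacters cannot break argument parsing. A minimal round-trip sketch follows. The decoding side lives in the sbatch script and is not shown in this diff, so `decode_prompt` is only an illustrative Python counterpart, and the sample prompt is made up.

```python
import base64


def encode_prompt(system_prompt: str) -> str:
    """Encode the prompt the same way the diff does before passing it to sbatch."""
    return base64.b64encode(system_prompt.encode()).decode()


def decode_prompt(encoded: str) -> str:
    """Illustrative inverse; the real decoding happens in the sbatch script."""
    return base64.b64decode(encoded).decode()


prompt = 'You are "helpful".\nThink step by step; use $x$ & <think> tags.'
encoded = encode_prompt(prompt)
```

The encoded string contains only letters, digits, `+`, `/` and `=`, none of which needs shell quoting.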


@ -76,7 +76,8 @@ def check_hub_revision_exists(training_args: SFTConfig | GRPOConfig):
# If the revision exists, we next check it has a README file
if training_args.hub_model_revision in revisions:
repo_files = list_repo_files(
repo_id=training_args.hub_model_id, revision=training_args.hub_model_revision
repo_id=training_args.hub_model_id,
revision=training_args.hub_model_revision,
)
if "README.md" in repo_files and training_args.overwrite_hub_revision is False:
raise ValueError(


@ -21,3 +21,10 @@ _e2b_available = _is_package_available("e2b")
def is_e2b_available() -> bool:
return _e2b_available
_morph_available = _is_package_available("morphcloud")
def is_morph_available() -> bool:
return _morph_available


@ -1,12 +0,0 @@
from .piston_client import get_piston_client_from_env, get_slurm_piston_endpoints
from .scoring import SubtaskResult, score_subtask
from .utils import add_includes
__all__ = [
"get_piston_client_from_env",
"get_slurm_piston_endpoints",
"score_subtask",
"add_includes",
"SubtaskResult",
]


@ -1,16 +1,12 @@
from transformers import AutoTokenizer, PreTrainedTokenizer
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
from trl import ModelConfig
from trl import ModelConfig, get_kbit_device_map, get_quantization_config
from ..configs import GRPOConfig, SFTConfig
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
def get_tokenizer(
model_args: ModelConfig, training_args: SFTConfig | GRPOConfig, auto_set_chat_template: bool = True
) -> PreTrainedTokenizer:
def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> PreTrainedTokenizer:
"""Get the tokenizer for the model."""
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
@ -20,7 +16,27 @@ def get_tokenizer(
if training_args.chat_template is not None:
tokenizer.chat_template = training_args.chat_template
elif auto_set_chat_template and tokenizer.get_chat_template() is None:
tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
return tokenizer
def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoModelForCausalLM:
"""Get the model."""
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=not training_args.gradient_checkpointing,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
**model_kwargs,
)
return model


@ -0,0 +1,120 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import requests
class RoutedMorphSandbox:
"""
Client for the MorphCloud router service that mimics the API of MorphCloud's Sandbox.
This class provides a simple interface to execute code via a central MorphCloud router,
which manages sandbox creation and cleanup. It allows batch processing of multiple scripts
in a single request for improved efficiency.
Attributes:
router_url (str): The URL of the MorphCloud router service.
timeout (int): Execution timeout in seconds.
request_timeout (int): HTTP request timeout in seconds.
"""
def __init__(self, router_url: str, timeout: int = 300, request_timeout: int = 60):
"""
Initialize the routed MorphCloud sandbox client.
Args:
router_url: The URL of the MorphCloud router, including host and port.
timeout: Default execution timeout in seconds.
request_timeout: Default HTTP request timeout in seconds.
"""
self.router_url = router_url
self.timeout = timeout
self.request_timeout = request_timeout
def run_code(
self,
scripts: List[str],
languages: Optional[List[str]] = None,
timeout: Optional[int] = None,
request_timeout: Optional[int] = None,
) -> List:
"""
Execute multiple scripts using MorphCloud via the router.
Args:
scripts: List of code scripts to execute.
languages: List of programming languages for each script. If None, defaults to Python for all scripts.
timeout: Execution timeout in seconds. If None, uses the instance timeout.
request_timeout: HTTP request timeout in seconds. If None, uses the instance request_timeout.
Returns:
List of execution results with text and exception_str properties.
"""
actual_timeout = timeout if timeout is not None else self.timeout
actual_request_timeout = request_timeout if request_timeout is not None else self.request_timeout
# Default to Python for all scripts if languages is not provided
if languages is None:
languages = ["python"] * len(scripts)
payload = {
"scripts": scripts,
"languages": languages,
"timeout": actual_timeout,
"request_timeout": actual_request_timeout,
}
try:
endpoint = f"http://{self.router_url}/execute_batch"
response = requests.post(endpoint, json=payload, timeout=actual_request_timeout)
if response.status_code != 200:
error = f"Request to MorphCloud router failed with status code: {response.status_code}"
print(error)
results = []
for _ in scripts:
results.append(type("obj", (object,), {"text": None, "exception_str": error}))
return results
response_data = response.json()
results = []
for item in response_data:
# Log the response data to see what we're getting
# print(f"RoutedMorphSandbox: Got response item: {item}")
result = type(
"obj",
(object,),
{
"text": item.get("text"),
"exception_str": item.get("exception_str"),
},
)
results.append(result)
return results
except Exception as e:
error = f"Error communicating with MorphCloud router: {str(e)}"
print(error)
results = []
for _ in scripts:
results.append(type("obj", (object,), {"text": None, "exception_str": error}))
return results
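`RoutedMorphSandbox.run_code` builds one JSON payload per batch and, on any failure, returns one placeholder result per script so callers always get a list of the same length as `scripts`. The sketch below isolates those two steps without making an HTTP call. `build_payload` and `error_results` are hypothetical helpers, and `SimpleNamespace` is a swapped-in alternative to the `type("obj", (object,), {...})` idiom used in the diff, which creates a class rather than an instance.

```python
from types import SimpleNamespace
from typing import List, Optional


def build_payload(
    scripts: List[str],
    languages: Optional[List[str]] = None,
    timeout: int = 300,
    request_timeout: int = 60,
) -> dict:
    """Assemble the request body; languages default to Python for every script."""
    if languages is None:
        languages = ["python"] * len(scripts)
    return {
        "scripts": scripts,
        "languages": languages,
        "timeout": timeout,
        "request_timeout": request_timeout,
    }


def error_results(scripts: List[str], error: str) -> list:
    """One placeholder per script, exposing `text` and `exception_str`."""
    return [SimpleNamespace(text=None, exception_str=error) for _ in scripts]
```

Either idiom supports the attribute access the tests rely on (`result.text`, `result.exception_str`).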


@ -0,0 +1,109 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import requests
from e2b_code_interpreter.models import Execution, ExecutionError, Result
class RoutedSandbox:
"""
A sandbox environment that routes code execution requests to the E2B Router.
This class is designed for batched execution of scripts, primarily for Python code.
It mimics the usage of 'Sandbox' from 'e2b_code_interpreter', but adds support for batch processing.
Attributes:
router_url (str): The URL of the E2B Router to which code execution requests are sent.
"""
def __init__(self, router_url: str):
"""
Initializes the RoutedSandbox with the specified router URL.
Args:
router_url (str): The URL of the E2B Router.
"""
self.router_url = router_url
def run_code(
self,
scripts: list[str],
languages: Optional[List[str]] = None,
timeout: Optional[int] = None,
request_timeout: Optional[int] = None,
) -> list[Execution]:
"""
Executes a batch of scripts in the sandbox environment.
Args:
scripts (list[str]): A list of code scripts to execute.
languages (list[str], optional): List of programming languages for each script. If None, defaults to Python for all scripts.
timeout (Optional[int], optional): The maximum execution time for each script in seconds. Defaults to 300 seconds.
request_timeout (Optional[int], optional): The timeout for the HTTP request in seconds. Defaults to 30 seconds.
Returns:
list[Execution]: A list of Execution objects containing the results, logs, and errors (if any) for each script.
"""
# Set default values for timeouts if not provided
if timeout is None:
timeout = 300 # Default to 5 minutes
if request_timeout is None:
request_timeout = 30 # Default to 30 seconds
# Default to Python for all scripts if languages is not provided
if languages is None:
languages = ["python"] * len(scripts)
# Prepare the payload for the HTTP POST request
payload = {
"scripts": scripts,
"languages": languages,
"timeout": timeout,
"request_timeout": request_timeout,
}
# Send the request to the E2B Router
response = requests.post(f"http://{self.router_url}/execute_batch", json=payload)
if not response.ok:
print(f"Request failed with status code: {response.status_code}")
# Parse the response and construct Execution objects
results = response.json()
output = []
for result in results:
if result["execution"] is None:
# If execution is None, create an empty Execution object
# This can happen when a script times out or fails to execute
execution = Execution()
else:
execution = Execution(
results=[Result(**r) for r in result["execution"]["results"]],
logs=result["execution"]["logs"],
error=(ExecutionError(**result["execution"]["error"]) if result["execution"]["error"] else None),
execution_count=result["execution"]["execution_count"],
)
output.append(execution)
return output
if __name__ == "__main__":
# for local testing launch an E2B router with: python scripts/e2b_router.py
sbx = RoutedSandbox(router_url="0.0.0.0:8000")
codes = ["print('hello world')", "print('hello world)"]
executions = sbx.run_code(codes) # Execute Python inside the sandbox
print(executions)


@ -9,3 +9,5 @@ def init_wandb_training(training_args):
os.environ["WANDB_ENTITY"] = training_args.wandb_entity
if training_args.wandb_project is not None:
os.environ["WANDB_PROJECT"] = training_args.wandb_project
if training_args.wandb_run_group is not None:
os.environ["WANDB_RUN_GROUP"] = training_args.wandb_run_group


@ -17,13 +17,16 @@ import unittest
from datasets import load_dataset
from e2b_code_interpreter.models import Execution, ExecutionError
from open_r1.rewards import code_reward, ioi_code_reward
from open_r1.utils.routed_morph import RoutedMorphSandbox
from open_r1.utils.routed_sandbox import RoutedSandbox
class TestCodeRewards(unittest.TestCase):
def test_python_code_reward(self):
# requires E2B, see the README.md file
code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested")
code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled")
NUM_SAMPLES = 20
samples = code_dataset["train"].select(range(NUM_SAMPLES))
test_completions = [[{"content": sample["gold_standard_solution"]}] for sample in samples]
@ -32,6 +35,42 @@ class TestCodeRewards(unittest.TestCase):
print(rewards)
assert rewards == [1.0] * NUM_SAMPLES
def test_e2b_router(self):
# run router locally: python scripts/e2b_router.py
code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled")
NUM_SAMPLES = 128
samples = code_dataset["train"].select(range(NUM_SAMPLES))
test_completions = [[{"content": sample["gold_standard_solution"]}] for sample in samples]
reward_kwargs = {"verification_info": [sample["verification_info"] for sample in samples]}
rewards = code_reward(test_completions, e2b_router_url="0.0.0.0:8000", **reward_kwargs)
print(rewards)
assert rewards == [1.0] * NUM_SAMPLES
def test_e2b_router_parallel(self):
# run router locally: python scripts/e2b_router.py
code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled")
BATCH_SIZE = 32
NUM_SAMPLES = 256
def batch_code_reward(examples):
test_completions = [[{"content": solution}] for solution in examples["gold_standard_solution"]]
reward_kwargs = {
"verification_info": [verification_info for verification_info in examples["verification_info"]]
}
rewards = code_reward(test_completions, e2b_router_url="0.0.0.0:8000", **reward_kwargs)
assert rewards == [1.0] * BATCH_SIZE
return examples
code_dataset = code_dataset["train"].select(range(NUM_SAMPLES))
code_dataset = code_dataset.map(
batch_code_reward,
batched=True,
batch_size=BATCH_SIZE,
num_proc=4,
load_from_cache_file=False,
)
def test_ioi_code_reward(self):
# This slow test case requires spinning up a bunch (I tested with ~64) of piston workers, see docs here
# slurm/piston/README.md
@ -45,6 +84,136 @@ class TestCodeRewards(unittest.TestCase):
print(rewards)
assert rewards == [1.0] * NUM_SAMPLES
def test_e2b_router_run_code_success(self):
# run router locally: python scripts/e2b_router.py
routed_sandbox = RoutedSandbox(router_url="localhost:8000")
scripts = [
"print('hello from integration test')",
"result = 2 + 2\nprint(result)",
]
results = routed_sandbox.run_code(scripts)
assert len(results) == 2
for result in results:
assert isinstance(result, Execution)
# assert result.exit_code == 0
assert result.error is None
assert "hello" in result.logs["stdout"][0] or "4" in result.logs["stdout"][0]
def test_e2b_router_run_code_with_error(self):
# run router locally: python scripts/e2b_router.py
routed_sandbox = RoutedSandbox(router_url="localhost:8000")
scripts = ["print('this is fine')", "print('unterminated string"]
results = routed_sandbox.run_code(scripts)
assert len(results) == 2
# First one should be okay
# assert results[0].exit_code == 0 # Execution object has no attribute 'exit_code'
assert results[0].error is None
assert "this is fine" in results[0].logs["stdout"][0]
# Second one should have a syntax error
# assert results[1].exit_code != 0 # Execution object has no attribute 'exit_code'
assert results[1].error is not None
assert isinstance(results[1].error, ExecutionError)
assert "SyntaxError" in results[1].error.name
def test_python_code_reward_morph(self):
# requires MorphCloud, see the README.md file
code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled")
NUM_SAMPLES = 20
samples = code_dataset["train"].select(range(NUM_SAMPLES))
test_completions = [[{"content": sample["gold_standard_solution"]}] for sample in samples]
reward_kwargs = {
"verification_info": [sample["verification_info"] for sample in samples],
"provider_type": "morph",
}
rewards = code_reward(test_completions, **reward_kwargs)
print(rewards)
assert rewards == [1.0] * NUM_SAMPLES
def test_morph_router(self):
        # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20
        code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled")
        NUM_SAMPLES = 32
        samples = code_dataset["train"].select(range(NUM_SAMPLES))
        test_completions = [[{"content": sample["gold_standard_solution"]}] for sample in samples]
        reward_kwargs = {
            "verification_info": [sample["verification_info"] for sample in samples],
            "provider_type": "morph",
            "morph_router_url": "0.0.0.0:8001",
        }
        rewards = code_reward(test_completions, **reward_kwargs)
        print(rewards)
        assert rewards == [1.0] * NUM_SAMPLES

    def test_morph_router_parallel(self):
        # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20
        code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled")

        BATCH_SIZE = 32
        NUM_SAMPLES = 256

        def batch_code_reward(examples):
            test_completions = [[{"content": solution}] for solution in examples["gold_standard_solution"]]
            reward_kwargs = {
                "verification_info": [verification_info for verification_info in examples["verification_info"]],
                "provider_type": "morph",
                "morph_router_url": "0.0.0.0:8001",
            }
            rewards = code_reward(test_completions, **reward_kwargs)
            assert rewards == [1.0] * BATCH_SIZE
            return examples

        code_dataset = code_dataset["train"].select(range(NUM_SAMPLES))
        code_dataset = code_dataset.map(
            batch_code_reward,
            batched=True,
            batch_size=BATCH_SIZE,
            num_proc=4,
            load_from_cache_file=False,
        )

    def test_morph_router_run_code_success(self):
        # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20
        routed_sandbox = RoutedMorphSandbox(router_url="localhost:8001")
        scripts = [
            "print('hello from morph integration test')",
            "result = 2 + 2\nprint(result)",
        ]

        results = routed_sandbox.run_code(scripts)

        assert len(results) == 2

        for result in results:
            assert result.exception_str is None
            assert "hello" in result.text or "4" in result.text

    def test_morph_router_run_code_with_error(self):
        # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20
        routed_sandbox = RoutedMorphSandbox(router_url="localhost:8001")
        scripts = ["print('this is fine with morph')", "print('unterminated string"]

        results = routed_sandbox.run_code(scripts)

        assert len(results) == 2

        # First one should be okay
        assert results[0].exception_str is None
        assert "this is fine with morph" in results[0].text

        # Second one should have a syntax error
        assert "SyntaxError" in results[1].text


if __name__ == "__main__":
    unittest.main()


@@ -15,6 +15,7 @@
import unittest

from dotenv import load_dotenv
from open_r1.configs import GRPOScriptArguments
from open_r1.rewards import (
    accuracy_reward,
@@ -23,12 +24,16 @@ from open_r1.rewards import (
    get_cosine_scaled_reward,
    get_repetition_penalty_reward,
    get_reward_funcs,
    get_soft_overlong_punishment,
    len_reward,
    reasoning_steps_reward,
    tag_count_reward,
)

load_dotenv()


class TestGetRewardFuncs(unittest.TestCase):
    def test_get_reward_funcs(self):
        """Test get_reward_funcs with various reward functions."""
@@ -82,7 +87,13 @@ class TestRewards(unittest.TestCase):
        """Test accuracy_reward with an incorrect answer."""
        completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
        solution = [r"\frac{63}{400}"]
        rewards = accuracy_reward(completion, solution)
        self.assertEqual(rewards[0], 0.0)

    def test_accuracy_reward_wrong_answer_no_latex(self):
        """Test accuracy_reward with an incorrect answer and gold solution with no latex."""
        completion = [[{"content": r"\boxed{3}"}]]
        solution = ["6"]
        rewards = accuracy_reward(completion, solution)
        self.assertEqual(rewards[0], 0.0)
@@ -127,7 +138,10 @@ class TestRewards(unittest.TestCase):
    def test_multiple_completions(self):
        """Test handling multiple completions at once."""
        completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
        completions = [
            [{"content": r"\boxed{\frac{63}{400}}"}],
            [{"content": r"\boxed{\frac{64}{400}}"}],
        ]
        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
        rewards = accuracy_reward(completions, solutions)
@@ -148,11 +162,31 @@ class TestRewards(unittest.TestCase):
        test_cases = [
            # Correct answers with different lengths
            (r"\boxed{\frac{63}{400}}", r"\frac{63}{400}", 20, 0.943),  # Short correct answer
            (r"\boxed{\frac{63}{400}}", r"\frac{63}{400}", 80, 0.547),  # Long correct answer
            (
                r"\boxed{\frac{63}{400}}",
                r"\frac{63}{400}",
                20,
                0.943,
            ),  # Short correct answer
            (
                r"\boxed{\frac{63}{400}}",
                r"\frac{63}{400}",
                80,
                0.547,
            ),  # Long correct answer
            # Wrong answers with different lengths
            (r"\boxed{\frac{64}{400}}", r"\frac{63}{400}", 20, -0.942),  # Short wrong answer
            (r"\boxed{\frac{64}{400}}", r"\frac{63}{400}", 80, -0.547),  # Long wrong answer
            (
                r"\boxed{\frac{64}{400}}",
                r"\frac{63}{400}",
                20,
                -0.942,
            ),  # Short wrong answer
            (
                r"\boxed{\frac{64}{400}}",
                r"\frac{63}{400}",
                80,
                -0.547,
            ),  # Long wrong answer
        ]
        for content, solution, content_len, expected_reward in test_cases:
@@ -172,7 +206,10 @@ class TestRewards(unittest.TestCase):
    def test_same_length_responses(self):
        """Test len_reward when all responses have the same length."""
        completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
        completions = [
            [{"content": r"\boxed{\frac{63}{400}}"}],
            [{"content": r"\boxed{\frac{64}{400}}"}],
        ]
        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
        rewards = len_reward(completions, solutions)
@@ -232,7 +269,10 @@ class TestRewards(unittest.TestCase):
    def test_unparseable_solution(self):
        """Test len_reward with unparseable solution."""
        completions = [[{"content": r"\boxed{answer}"}], [{"content": r"\boxed{answer} " + "x" * 10}]]
        completions = [
            [{"content": r"\boxed{answer}"}],
            [{"content": r"\boxed{answer} " + "x" * 10}],
        ]
        solutions = ["unparseable_latex", "unparseable_latex"]
        rewards = len_reward(completions, solutions)
@@ -407,6 +447,40 @@ class TestRepetitionPenaltyReward(unittest.TestCase):
        rewards = tag_count_reward(completion)
        self.assertEqual(rewards[0], 0.0)

    def test_full_repetition_with_language(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0, language="en")
        completions = [[{"content": "that that that that that"}]]
        rewards = reward_fn(completions)
        self.assertEqual(rewards, [-0.75])
        # begin test for zh language
        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0, language="zh")
        completions = [[{"content": "这个这个这个这个这个"}]]
        rewards = reward_fn(completions)
        self.assertEqual(rewards, [-0.75])

    def test_soft_overlong_punishment_short_completion(self):
        """Test soft overlong punishment reward function with a short completion."""
        # length 50, with max=100 and soft cache=20, reward should be 0.
        reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
        completion_ids = [[1] * 50]  # 50 <= 80
        rewards = reward_fn(completion_ids=completion_ids)
        self.assertEqual(rewards, [0])

    def test_soft_overlong_punishment_long_completion(self):
        """Test soft overlong punishment reward function with a longer than max completion."""
        # 110 > 100, reward should be -1.
        reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
        completion_ids = [[1] * 110]
        rewards = reward_fn(completion_ids)
        self.assertEqual(rewards, [-1])

    def test_soft_overlong_punishment_intermediate_completion(self):
        """Test soft overlong punishment reward function for intermediate length completion."""
        reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
        completion_ids = [[1] * 90]  # 90 is between 80 and 100
        rewards = reward_fn(completion_ids)
        self.assertAlmostEqual(rewards[0], -0.5, places=4)


class TestCodeFormat(unittest.TestCase):
    def test_correct_python_format(self):
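
The tests added in this hunk pin down the numeric behaviour of two rewards: a soft length penalty that is 0 up to `max_completion_len - soft_punish_cache`, falls linearly to -1 at `max_completion_len`, and stays at -1 beyond it; and a repetition penalty that scales `max_penalty` by the fraction of repeated n-grams. The sketch below reproduces only the expected values from those tests and is not the code in `open_r1.rewards`. The helper names `soft_overlong_penalty` and `repetition_penalty` are hypothetical, and plain whitespace splitting is an assumption that stands in for the language-aware tokenization the `language` argument implies.

```python
def soft_overlong_penalty(length, max_len, cache):
    """Piecewise-linear length penalty (hypothetical helper inferred from the tests)."""
    soft_limit = max_len - cache
    if length <= soft_limit:
        return 0.0  # short enough: no penalty
    if length <= max_len:
        # linear ramp from 0 at soft_limit down to -1 at max_len
        return (soft_limit - length) / cache
    return -1.0  # longer than the hard limit


def repetition_penalty(text, ngram_size, max_penalty):
    """Scale max_penalty by the share of repeated n-grams (assumes whitespace tokens)."""
    words = text.split()
    ngrams = [tuple(words[i : i + ngram_size]) for i in range(len(words) - ngram_size + 1)]
    if not ngrams:
        return 0.0
    repeated_fraction = 1 - len(set(ngrams)) / len(ngrams)
    return repeated_fraction * max_penalty


print(soft_overlong_penalty(50, 100, 20))   # 0.0
print(soft_overlong_penalty(90, 100, 20))   # -0.5
print(soft_overlong_penalty(110, 100, 20))  # -1.0
# 5 words give 4 bigrams, 1 unique: (1 - 1/4) * -1.0
print(repetition_penalty("that that that that that", 2, -1.0))  # -0.75
```

The Chinese case in the test has no spaces, so a whitespace split would yield a single token and no bigrams; matching the expected -0.75 there requires a word segmenter, which this sketch deliberately leaves out.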

tests/utils/test_data.py

@@ -0,0 +1,129 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from dataclasses import asdict

from datasets import DatasetDict, load_dataset

from open_r1.configs import DatasetConfig, DatasetMixtureConfig, ScriptArguments
from open_r1.utils.data import get_dataset


class TestGetDataset(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.dataset_name = "trl-internal-testing/zen"
        cls.dataset_config = "conversational_preference"
        cls.ref_dataset = load_dataset(cls.dataset_name, cls.dataset_config)

    def test_dataset_and_config_name(self):
        args = ScriptArguments(dataset_name=self.dataset_name, dataset_config=self.dataset_config)
        dataset = get_dataset(args)
        self.assertIsInstance(dataset, DatasetDict)
        self.assertIn("train", dataset)
        self.assertEqual(len(dataset["train"]), len(self.ref_dataset["train"]))

    def test_unweighted_mixture(self):
        """Mix train and test splits of the same dataset."""
        dataset_configs = [
            DatasetConfig(id=self.dataset_name, config=self.dataset_config, split="train", columns=None, weight=None),
            DatasetConfig(id=self.dataset_name, config=self.dataset_config, split="test", columns=None, weight=None),
        ]
        dataset_mixture = DatasetMixtureConfig(
            datasets=dataset_configs,
        )
        args = ScriptArguments(dataset_mixture=asdict(dataset_mixture))
        dataset = get_dataset(args)
        self.assertIsInstance(dataset, DatasetDict)
        self.assertIn("train", dataset)
        self.assertEqual(len(dataset["train"]), len(self.ref_dataset["train"]) + len(self.ref_dataset["test"]))

    def test_weighted_mixture(self):
        """Test loading a dataset mixture with weights."""
        dataset_configs = [
            DatasetConfig(id=self.dataset_name, config=self.dataset_config, split="train", columns=None, weight=0.25),
            DatasetConfig(id=self.dataset_name, config=self.dataset_config, split="test", columns=None, weight=0.5),
        ]
        dataset_mixture = DatasetMixtureConfig(
            datasets=dataset_configs,
        )
        args = ScriptArguments(dataset_mixture=asdict(dataset_mixture))
        dataset = get_dataset(args)
        self.assertIsInstance(dataset, DatasetDict)
        self.assertIn("train", dataset)
        self.assertEqual(
            len(dataset["train"]), len(self.ref_dataset["train"]) // 4 + len(self.ref_dataset["test"]) // 2
        )

    def test_mixture_and_test_split(self):
        """Test loading a dataset mixture with test split."""
        dataset_configs = [
            DatasetConfig(
                id=self.dataset_name, config=self.dataset_config, split="train[:10]", columns=None, weight=None
            ),
        ]
        dataset_mixture = DatasetMixtureConfig(datasets=dataset_configs, test_split_size=0.2)
        args = ScriptArguments(dataset_name=None, dataset_mixture=asdict(dataset_mixture))
        dataset = get_dataset(args)
        self.assertIsInstance(dataset, DatasetDict)
        self.assertIn("train", dataset)
        self.assertIn("test", dataset)
        self.assertEqual(len(dataset["train"]), 8)
        self.assertEqual(len(dataset["test"]), 2)

    def test_mixture_column_selection(self):
        """Test loading a dataset mixture with column selection."""
        dataset_configs = [
            DatasetConfig(
                id=self.dataset_name,
                config=self.dataset_config,
                split="train",
                columns=["prompt", "chosen"],
                weight=None,
            ),
        ]
        dataset_mixture = DatasetMixtureConfig(
            datasets=dataset_configs,
        )
        args = ScriptArguments(dataset_mixture=asdict(dataset_mixture))
        dataset = get_dataset(args)
        self.assertIsInstance(dataset, DatasetDict)
        self.assertIn("train", dataset)
        self.assertIn("prompt", dataset["train"].column_names)
        self.assertIn("chosen", dataset["train"].column_names)

    def test_mixture_with_mismatched_columns(self):
        dataset_configs = [
            DatasetConfig(
                id=self.dataset_name, config=self.dataset_config, split="train", columns=["prompt"], weight=None
            ),
            DatasetConfig(
                id=self.dataset_name, config=self.dataset_config, split="train", columns=["chosen"], weight=None
            ),
        ]
        dataset_mixture = DatasetMixtureConfig(
            datasets=dataset_configs,
        )
        with self.assertRaises(ValueError) as context:
            _ = ScriptArguments(dataset_mixture=asdict(dataset_mixture))
        self.assertIn("Column names must be consistent", str(context.exception))

    def test_no_dataset_name_or_mixture(self):
        with self.assertRaises(ValueError) as context:
            _ = ScriptArguments(dataset_name=None, dataset_mixture=None)
        self.assertIn("Either `dataset_name` or `dataset_mixture` must be provided", str(context.exception))


if __name__ == "__main__":
    unittest.main()
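
The mixture tests above fix the size of the combined train split: with no weights it is the sum of the source lengths, and with weights 0.25 and 0.5 it is `len(train) // 4 + len(test) // 2`. The following is a minimal sketch of one way to get those sizes, assuming each source is shuffled and then truncated to `int(len * weight)` rows before concatenation. The function `mix` is hypothetical, works on plain Python lists, and is not the implementation behind `get_dataset`.

```python
import random


def mix(sources, weights, seed=0):
    """Concatenate sources, keeping a `weight` fraction of each (hypothetical helper)."""
    rng = random.Random(seed)
    mixed = []
    for rows, weight in zip(sources, weights):
        rows = list(rows)
        rng.shuffle(rows)  # subsample at random rather than taking the head
        keep = len(rows) if weight is None else int(len(rows) * weight)
        mixed.extend(rows[:keep])
    rng.shuffle(mixed)  # interleave the sources
    return mixed


train = list(range(17))
test = list(range(100, 109))

print(len(mix([train, test], [None, None])))  # 26 = 17 + 9
print(len(mix([train, test], [0.25, 0.5])))   # 8 = 17 // 4 + 9 // 2
```

Truncating with `int(...)` rounds down, which is what makes the weighted size agree with the floor division used in the test's assertion.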