mirror of
https://github.com/huggingface/open-r1.git
synced 2026-06-24 01:54:06 +00:00
Compare commits
1 commit
main
...
data-agent
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0dfe13e563 |
4 changed files with 107 additions and 1 deletions
14
recipes/DataAgent-7B/README.md
Normal file
14
recipes/DataAgent-7B/README.md
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# How to train the DataAgent models (Qwen-7B and Llama-8B)
|
||||
|
||||
|
||||
For the Qwen model
|
||||
```bash
|
||||
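# NOTE(review): this recipe's files live under recipes/DataAgent-7B/, not recipes/DataAgent-Qwen-7B/ — confirm that --model resolves to the intended recipe directory.
sbatch --job-name=train-data-agent-qwen --nodes=1 slurm/train.slurm --model DataAgent-Qwen-7B --task sft --config v00.00 --accelerator zero3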
|
||||
```
|
||||
|
||||
For the Llama model
|
||||
```bash
|
||||
sbatch --job-name=train-data-agent-llama --nodes=1 slurm/train.slurm --model DataAgent-Llama-8B --task sft --config v00.00 --accelerator zero3
|
||||
```
|
||||
|
||||
|
||||
45
recipes/DataAgent-7B/sft/config_v00.00.yaml
Normal file
45
recipes/DataAgent-7B/sft/config_v00.00.yaml
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
# Config for 1 node of 8 H100s with DeepSpeed ZeRO-3
|
||||
# Model arguments
|
||||
model_name_or_path: Qwen/Qwen2.5-Coder-7B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: data-agents/jupyter-tulu-interleaved
|
||||
dataset_num_proc: 48
|
||||
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: 'no'
|
||||
gradient_accumulation_steps: 8
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: data-agents/DataAgent-Qwen-7B
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-05
|
||||
log_level: info
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
packing: false
|
||||
max_grad_norm: 0.2
|
||||
max_length: 32768
|
||||
max_steps: -1
|
||||
num_train_epochs: 10
|
||||
output_dir: data/DataAgent-Qwen-7B
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 1
|
||||
per_device_train_batch_size: 2
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
save_strategy: epoch
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
use_liger_kernel: true
|
||||
warmup_ratio: 0.03
|
||||
45
recipes/DataAgent-Llama-8B/sft/config_v00.00.yaml
Normal file
45
recipes/DataAgent-Llama-8B/sft/config_v00.00.yaml
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
# Config for 1 node of 8 H100s with DeepSpeed ZeRO-3
|
||||
# Model arguments
|
||||
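model_name_or_path: Qwen/Qwen2.5-Coder-7B-Instruct  # NOTE(review): this is the DataAgent-Llama-8B recipe (see hub_model_id / output_dir) but the base model is a Qwen checkpoint — confirm the intended Llama base model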
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: data-agents/jupyter-tulu-interleaved
|
||||
dataset_num_proc: 48
|
||||
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: 'no'
|
||||
gradient_accumulation_steps: 8
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: data-agents/DataAgent-Llama-8B
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-05
|
||||
log_level: info
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
packing: false
|
||||
max_grad_norm: 0.2
|
||||
max_length: 32768
|
||||
max_steps: -1
|
||||
num_train_epochs: 10
|
||||
output_dir: data/DataAgent-Llama-8B
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 1
|
||||
per_device_train_batch_size: 2
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
save_strategy: epoch
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
use_liger_kernel: true
|
||||
warmup_ratio: 0.03
|
||||
|
|
@ -1,8 +1,10 @@
|
|||
#!/bin/bash
|
||||
#SBATCH --job-name=open_r1
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=5:00:00
|
||||
#SBATCH --gres=gpu:8
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --partition=hopper-prod # Adjust this for your cluster
|
||||
#SBATCH --output=./logs/%x-%j.out
|
||||
#SBATCH --error=./logs/%x-%j.err
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue