Compare commits

..

68 commits

Author SHA1 Message Date
aymeric-roucher
933ea920a7 Merge branch 'agent-traces' of github.com:huggingface/open-r1 into agent-traces 2025-06-26 07:33:50 +00:00
aymeric-roucher
5fa7e51a5e Deactivate multinodes 2025-06-26 07:33:39 +00:00
aymeric@huggingface.co
a452f2fff6 Impreove data collection script 2025-06-25 09:14:13 +00:00
aymeric-roucher
3c3e954578 Start adapting script for VLM training 2025-06-24 12:24:27 +00:00
aymeric-roucher
c4d41268e0 Working SFT for text model 2025-06-24 09:43:50 +00:00
aymeric@huggingface.co
0ee52fc6f2 Fix env 2025-06-23 20:09:02 +00:00
aymeric@huggingface.co
eb39096a66 Try edit 2025-06-23 20:02:31 +00:00
aymeric@huggingface.co
81c64ac201 Change weka path 2025-06-23 19:47:20 +00:00
aymeric@huggingface.co
3b77977ce2 Revert to new shitty script 2025-06-23 19:25:15 +00:00
aymeric@huggingface.co
b7a700e119 Revert training script to the good old time when it worked 2025-06-23 19:20:35 +00:00
aymeric@huggingface.co
7cb592c062 Remove parsing 2025-06-23 19:17:38 +00:00
aymeric@huggingface.co
3aee6efd1e Modify train slurm 2025-06-23 19:15:28 +00:00
aymeric@huggingface.co
fbd987c3ad Remove weka 2025-06-23 19:06:16 +00:00
aymeric@huggingface.co
a675552585 Fix env variables 2025-06-23 19:01:26 +00:00
aymeric@huggingface.co
984d63120b Add Readme for agents 2025-06-23 18:16:41 +00:00
aymeric@huggingface.co
80f7ce833d Improve collection script 2025-06-23 18:00:04 +00:00
Aymeric
a9b541195a Add aguvis download script 2025-06-23 11:06:05 +02:00
Aymeric
a66a5e69a4 Merge branch 'main' into agent-traces 2025-06-23 10:34:42 +02:00
Aksel
9347590c47 adding qwen 3b training setup 2025-04-15 12:55:39 +00:00
Aymeric
60472f6613 Increase epochs 2025-04-09 13:51:08 +02:00
Aymeric
08a449c8f5 Update dataset name 2025-04-09 13:43:27 +02:00
Aymeric
2030e166f3 Increase epochs 2025-04-08 21:10:56 +02:00
Aymeric
2043be9ea0 Change job name 2025-04-08 21:10:44 +02:00
Aymeric
1eaf1d15f5 Move script to proper file 2025-04-08 17:23:44 +02:00
Aymeric
2a08444b20 Switch to new venv 2025-04-08 11:50:04 +02:00
Aymeric
cae3c7c5aa Update train slurm 2025-04-04 17:52:26 +02:00
Aymeric
d28d07b63d Remove deepspeed config 2025-04-04 16:11:50 +02:00
Aymeric
8a7951c0bc Revert to zero3 config 2025-04-03 22:47:46 +02:00
Aymeric
5647c262f8 Add distributed type 2025-04-03 22:42:56 +02:00
Aymeric
de2b792dba Point to proper config file 2025-04-03 22:20:09 +02:00
Aymeric
49083cc87c Intervert sft training configs 2025-04-03 22:14:43 +02:00
Aymeric
2ddf70e6a8 Change job name 2025-04-03 21:59:12 +02:00
Aymeric
b7522e3925 Add training scripts for agents 2025-04-03 21:51:54 +02:00
Aymeric
38efcfcbd5 Working trace generation with auto verification by running test cases 2025-04-03 19:25:51 +02:00
Aymeric
2b1bc05ebc Merge branch 'main' into agent-traces 2025-04-03 15:51:00 +02:00
Aymeric
6961c3650c Remove some dependencies to work on mac 2025-04-03 15:26:41 +02:00
Aymeric
4a20ba4bfd Try Qwen Coder 32B 2025-04-02 21:35:33 +02:00
Aymeric
4c2fce6bbf Also store final outputs 2025-03-12 20:01:18 -07:00
Aymeric
ddc1cddc25 Improve explanations in prompt 2025-02-28 15:55:08 +01:00
Aymeric
a07cd54ab5 Prompt more around testing the function 2025-02-28 15:33:10 +01:00
Aymeric
cf52433cdc Flatten messages 2025-02-28 15:24:38 +01:00
Aymeric
2876d52491 Adjust 2025-02-28 15:11:25 +01:00
Aymeric
8e70ca4f0d Update timeouts 2025-02-28 15:02:27 +01:00
Aymeric
d87e3f3fa3 Running with gpt-4o 2025-02-28 14:55:13 +01:00
Aymeric
83a679fc68 Test 2025-02-28 14:46:58 +01:00
Aymeric
e42b1cd606 Add dummy completion 2025-02-28 14:42:05 +01:00
Aymeric
d8cb19b616 Fix message roles an add token counting 2025-02-28 14:31:59 +01:00
Aymeric
a97eb27683 Add token counting 2025-02-28 14:19:48 +01:00
Aymeric
2a1ff761f1 Reduce context length 2025-02-28 14:12:55 +01:00
Aymeric
9a2d16f62a Even more detailed request error logging 2025-02-28 14:06:41 +01:00
Aymeric
cb2a2c2db7 More detailed error logging 2025-02-28 14:05:42 +01:00
Aymeric
f78b8651e3 Small adapts to script 2025-02-28 13:58:06 +01:00
Aymeric
b738e58ecf Make synchronous 2025-02-28 13:44:19 +01:00
Aymeric
d2588cd290 Add port 2025-02-27 14:49:12 +01:00
Aymeric
dd15ad8b5c Add stop sequences 2025-02-26 18:04:42 +01:00
Aymeric
0cd0999b9c Try fixing async func 2025-02-26 17:59:04 +01:00
Aymeric
b47a4be6f4 Add await 2025-02-26 17:53:06 +01:00
Aymeric
e35800c5a1 Fix slurm script 2025-02-26 17:43:16 +01:00
Aymeric
cffa36268b Add conda init 2025-02-26 17:34:45 +01:00
Aymeric
cf13c2b63e Log 2025-02-26 17:33:52 +01:00
Aymeric
6cffffe504 128 concurrent 2025-02-26 17:26:52 +01:00
Aymeric
0af9e75346 Use local model 2025-02-26 17:23:52 +01:00
Aymeric
143fcfa3da Add conda activation 2025-02-26 17:23:10 +01:00
Aymeric
a00f0ee768 Update sbatch params 2025-02-26 16:05:30 +01:00
Aymeric
ad948c22a5 Increase concurrent requests 2025-02-26 16:02:53 +01:00
Aymeric
69b26518fc Update api addr 2025-02-26 15:30:13 +01:00
Aymeric
6c231d2130 Working local version with o1 2025-02-25 17:27:45 +01:00
Aymeric
352008b017 Start agent traces 2025-02-24 15:06:07 +01:00
18 changed files with 1144 additions and 49 deletions

View file

@ -16,9 +16,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
uses: actions/checkout@v4
- name: Setup Python environment
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
uses: actions/setup-python@v5
with:
python-version: 3.10.10
- name: Install dependencies

View file

@ -299,7 +299,6 @@ Make sure your dataset contains a `verification_info` column with the following
}
],
}
```
For example, to train a smol model on Python problems, start the vLLM server:
@ -799,7 +798,7 @@ If you find this project is useful in your own work, please consider citing as f
@misc{openr1,
title = {Open R1: A fully open reproduction of DeepSeek-R1},
url = {https://github.com/huggingface/open-r1},
author = {{Hugging Face}},
author = {Hugging Face},
month = {January},
year = {2025}
}

18
README_AGENTS.md Normal file
View file

@ -0,0 +1,18 @@
Launch:
```bash
sbatch --nodes=1 slurm/train.slurm --model SmolLM2-1.7B-Instruct --task sft --config agent --accelerator zero3
```
This uses the training config at `recipes/SmolLM2-1.7B-Instruct/sft/config_agent.yaml`.
`zero3` selects one of the accelerate configs in `recipes/accelerate_configs`.
Launch VLM training:
```bash
sbatch --nodes=1 slurm/train.slurm --model Qwen2.5-VL-3B-Instruct --task sft --config agent --accelerator zero3
```
Simple mode (uses the `ddp` accelerate config instead of `zero3`):
```bash
sbatch --nodes=1 slurm/train.slurm --model Qwen2.5-VL-3B-Instruct --task sft --config agent --accelerator ddp
```

View file

View file

@ -0,0 +1,46 @@
# Model arguments
# To train with a longer context, download the model and manually set its RoPE base to 300k/500k and max_position_embeddings to 32768
model_name_or_path: HuggingFaceTB/SmolLM2-1.7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: open-r1/OpenR1-Math-220k
dataset_num_proc: 48
# SFT hyperparameters
max_length: 8192 # You can set this to 32768 if you change the rope, but you need to change the config.json file
weight_decay: 0.0001
optim: adamw_torch
lr_scheduler_type: linear
warmup_ratio: 0.1
learning_rate: 5.0e-05
gradient_accumulation_steps: 2
per_device_eval_batch_size: 4
per_device_train_batch_size: 4 # Change this depending on the context length of the model to keep a 500M GBS.
# SFT trainer config
max_steps: -1
num_train_epochs: 3
bf16: true
do_eval: false
eval_strategy: 'no'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: OpenR1-Qwen-7B-SFT
hub_strategy: every_save
log_level: info
logging_steps: 5
logging_strategy: steps
packing: true
output_dir: data/OpenR1-Qwen-7B-SFT
overwrite_output_dir: true
push_to_hub: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 500
save_total_limit: 1
seed: 42

View file

@ -0,0 +1,46 @@
# Model arguments
# To train with a longer context, download the model and manually set its RoPE base to 300k/500k and max_position_embeddings to 32768
model_name_or_path: Qwen/Qwen2.5-3B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: smolagents/training-traces
dataset_num_proc: 48
# SFT hyperparameters
max_length: 8192 # You can set this to 32768 if you change the rope, but you need to change the config.json file
weight_decay: 0.0001
optim: adamw_torch
lr_scheduler_type: linear
warmup_ratio: 0.1
learning_rate: 4.0e-05
gradient_accumulation_steps: 1
per_device_eval_batch_size: 4
per_device_train_batch_size: 2 # Change this depending on the context length of the model to keep a 500M GBS.
# SFT trainer config
max_steps: -1
num_train_epochs: 2
bf16: true
do_eval: false
eval_strategy: 'no'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: oR1-Qwen-3B-Agentic-e2-lr4e-b2
hub_strategy: every_save
log_level: info
logging_steps: 5
logging_strategy: steps
packing: true
output_dir: data/oR1-Qwen-3B-Agentic-e2-lr4e-b2
overwrite_output_dir: true
push_to_hub: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 500
save_total_limit: 1
seed: 42

View file

@ -0,0 +1,60 @@
# Model arguments
# To train with a longer context, download the model and manually set its RoPE base to 300k/500k and max_position_embeddings to 32768
model_name_or_path: Qwen/Qwen2.5-VL-3B-Instruct
vision_model: true
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: smolagents/aguvis-stage-2
dataset_num_proc: 48
# SFT hyperparameters
max_length: 32768
weight_decay: 0.0001
optim: adamw_torch
lr_scheduler_type: linear
warmup_ratio: 0.1
learning_rate: 5.0e-05
gradient_accumulation_steps: 2
per_device_eval_batch_size: 4
per_device_train_batch_size: 4 # Change this depending on the context length of the model to keep a 500M GBS.
single_gpu: true
# SFT trainer config
max_steps: -1
num_train_epochs: 6
bf16: true
do_eval: false
eval_strategy: 'no'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: smolagents/Qwen2.5-VL-3B-Instruct-Agentic
hub_strategy: every_save
push_to_hub: false
log_level: info
logging_steps: 5
logging_strategy: steps
output_dir: data/smolagents-Qwen2.5-VL-3B-Instruct-Agentic
overwrite_output_dir: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 500
save_total_limit: 1
seed: 42
dataset_mixture:
datasets: # List of datasets to include in the mixture
- id: smolagents/aguvis-stage-2 # Hub dataset ID
config: mind2web # Name of the dataset config
split: train # Split to use from the dataset
columns: # Columns to keep
- images
- texts
weight: 1. # Fraction of dataset to use
seed: 42 # Seed for shuffling the combined dataset
test_split_size: 0.1

View file

@ -0,0 +1,45 @@
# Model arguments
# To train with a longer context, download the model and manually set its RoPE base to 300k/500k and max_position_embeddings to 32768
model_name_or_path: HuggingFaceTB/SmolLM2-1.7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: smolagents/codeagent-traces
dataset_num_proc: 48
# SFT hyperparameters
max_length: 8192 # You can set this to 32768 if you change the rope, but you need to change the config.json file
weight_decay: 0.0001
optim: adamw_torch
lr_scheduler_type: linear
warmup_ratio: 0.1
learning_rate: 5.0e-05
gradient_accumulation_steps: 2
per_device_eval_batch_size: 4
per_device_train_batch_size: 4 # Change this depending on the context length of the model to keep a 500M GBS.
# SFT trainer config
max_steps: -1
num_train_epochs: 1
bf16: true
do_eval: false
eval_strategy: 'no'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: OpenR1-SmolLM2-1.7B-Instruct-Agentic
hub_strategy: every_save
log_level: info
logging_steps: 5
logging_strategy: steps
output_dir: data/OpenR1-SmolLM2-1.7B-Instruct-Agentic
overwrite_output_dir: true
push_to_hub: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 500
save_total_limit: 1
seed: 42

View file

@ -0,0 +1,330 @@
import argparse
import hashlib
import inspect
import json
import os
import random
import time
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from typing import Set, Any, List
from pathlib import Path
import traceback
from datasets import load_dataset
from tqdm import tqdm
import requests
import requests.adapters
from transformers import AutoTokenizer
from smolagents import CodeAgent, Tool, HfApiModel
from open_r1.rewards import run_tests
from dotenv import load_dotenv
# Load variables from a local .env file; override=True is passed so values from
# the file take precedence over ones already present in the environment.
load_dotenv(override=True)
# Fail at import time when no Hugging Face token is configured
# (NOTE(review): an assert is stripped under `python -O`).
assert os.getenv("HF_TOKEN") is not None
# from huggingface_hub import login
# print("LOGIN:\n", login(token=os.getenv("HF_TOKEN")))
# Guards appends to the shared output file from the worker threads.
file_lock = Lock()
# Tokenizer used by generate_completion_from_messages to report prompt token counts.
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")
print("Launching generation")
def test_cases_on_function(function: Any, test_cases: List[dict]) -> str:
    """Submit ``function`` and ``test_cases`` to ``run_tests``.

    The function's ``__source__`` attribute, followed by a call to the function,
    is wrapped in a Python Markdown code fence and passed as a single completion.

    Args:
        function: Object exposing ``__source__`` (its source text) and ``__name__``.
        test_cases: Test cases forwarded under ``verification_info``.

    NOTE(review): the result of ``run_tests`` is discarded and nothing is returned,
    despite the ``-> str`` annotation — confirm whether failures are meant to
    surface as exceptions raised by ``run_tests``.
    """
    program = f"{function.__source__}\n{function.__name__}()"
    fenced_program = f"```python\n{program}```"
    verification_info = [{"language": "python", "test_cases": test_cases}]
    run_tests(
        [[{"content": fenced_program}]],
        e2b_router_url="0.0.0.0:8000",
        test_mode=True,
        verification_info=verification_info,
    )
class ModifiedFinalAnswerTool(Tool):
    """``final_answer`` tool that runs the task's test cases on the submitted function.

    NOTE(review): ``forward`` always returns the function's source; the outcome of
    the test run is not inspected here, so the "else returns the test case that
    errored" part of the description relies on ``test_cases_on_function`` raising.
    """

    name = "final_answer"
    description = "Tests a function: if correct, returns it as the final answer; Else returns the test case that errored for solving."
    inputs = {
        "answer_function": {
            "type": "any",
            "description": "The final function that solves the problem",
        }
    }
    output_type = "string"

    def __init__(self, test_cases):
        """Store the test cases that every submitted function is checked against."""
        self.test_cases = test_cases
        self.is_initialized = False
        super().__init__()

    def forward(self, answer_function: Any) -> str:
        """Run the stored test cases on ``answer_function`` and return its source text."""
        test_cases_on_function(answer_function, self.test_cases)
        submitted_source = answer_function.__source__
        return submitted_source
class ChatMessage:
    """Minimal message wrapper exposing a single ``content`` attribute."""

    def __init__(self, content):
        """Keep ``content`` unchanged on the instance."""
        self.content = content
def generate_completion_from_messages(session, messages, args, stop_sequences) -> str:
    """Request one chat completion from ``http://{args.api_addr}/v1/chat/completions``.

    The request is attempted up to 10 times, sleeping 20 seconds after each
    failure. A failure is an HTTP status >= 400, a response body that is not
    valid JSON or lacks ``choices[0].message.content``, or a ``requests``
    transport error.

    Args:
        session: Object exposing ``post`` (a ``requests.Session`` in this script).
        messages: Chat messages, sent verbatim as the ``messages`` field.
        args: Namespace providing ``api_addr``, ``max_tokens``, ``temperature``
            and ``top_p``.
        stop_sequences: Value sent as the ``stop`` field.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        RuntimeError: If no attempt produced a usable response.
    """
    max_attempts = 10
    retry_delay_s = 20

    # The prompt is identical across attempts, so count its tokens only once.
    formatted_chat = tokenizer.apply_chat_template(messages, tokenize=False)
    print("Input token count:", len(tokenizer.encode(formatted_chat)))

    payload = {
        "model": "default",
        "messages": messages,
        "max_tokens": args.max_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "stop": stop_sequences,
    }

    for _ in range(max_attempts):
        try:
            # Add a small random delay to prevent overwhelming the API
            time.sleep(random.uniform(0.0, 0.1))
            response = session.post(
                f"http://{args.api_addr}/v1/chat/completions",
                json=payload,
                headers={"Authorization": "Bearer EMPTY"},
                timeout=2 * 60 * 60,
            )
        except requests.exceptions.RequestException as e:
            print(f"API request error (will retry): {e}")
            traceback.print_exc()
            time.sleep(retry_delay_s)
            continue

        if response.status_code >= 400:
            # No exception is active here, so there is no traceback to print.
            print(f"HTTP Error {response.status_code}: {response.reason}")
            print(f"Response content: {response.text}")
            time.sleep(retry_delay_s)
            continue

        try:
            return response.json()["choices"][0]["message"]["content"]
        except (ValueError, KeyError, IndexError, TypeError) as e:
            # ValueError: body is not JSON; the others: JSON with an unexpected shape.
            print(f"JSON parsing error: {e}")
            print(f"Response content: {response.text}")
            traceback.print_exc()
            time.sleep(retry_delay_s)

    raise RuntimeError("Failed to get a valid response after multiple retries")
def get_agent_run(session, task, test_cases, args):
    """Run one ``CodeAgent`` attempt on ``task`` and return its output and memory.

    The agent uses ``Qwen/Qwen2.5-Coder-32B-Instruct`` through ``HfApiModel``
    (provider ``fireworks-ai``), is limited to 10 steps, and gets a single tool:
    a ``ModifiedFinalAnswerTool`` bound to ``test_cases``.

    Args:
        session: Unused here; kept so callers can keep passing their HTTP session.
        task: Prompt handed to ``agent.run``.
        test_cases: Test cases given to the final-answer tool and passed to the
            agent as ``additional_args``.
        args: Unused here; kept for signature compatibility.

    Returns:
        ``(output, messages)`` on success, where ``messages`` comes from
        ``agent.write_memory_to_messages()``; ``(None, None)`` when the run
        raised. A 2-tuple is always returned so callers can unpack the result.
    """
    # NOTE: a locally served model (via generate_completion_from_messages) was
    # previously wired in here; the hosted HfApiModel below replaced it.
    model = HfApiModel(
        "Qwen/Qwen2.5-Coder-32B-Instruct",
        provider="fireworks-ai",
        token=os.getenv("HF_TOKEN"),
    )
    agent = CodeAgent(
        model=model,
        tools=[ModifiedFinalAnswerTool(test_cases)],
        additional_authorized_imports=["numpy", "math"],
        max_steps=10,
        verbosity_level=2,
    )
    try:
        output = agent.run(task, additional_args={"test_cases": test_cases})
        return output, agent.write_memory_to_messages()
    except Exception as e:
        print(f"Error when generating agentic trace: {e}")
        # Returning a bare None here made the caller's tuple-unpacking raise TypeError.
        return None, None
# Value types that are kept when flattening agent messages for JSON output.
_JSON_SAFE_TYPES = (str, int, float, bool, type(None), list, dict)


def _serialize_generation(generation):
    """Reduce one agent memory to plain lists/dicts/strings (``None`` stays ``None``)."""
    if generation is None:
        return None
    if isinstance(generation, list):
        return [
            {k: v for k, v in msg.items() if isinstance(v, _JSON_SAFE_TYPES)}
            for msg in generation
            if isinstance(msg, dict)
        ]
    # Unknown container: fall back to its string form.
    return str(generation)


def process_example(example, session, args, output_file, pbar=None):
    """Generate ``args.num_generations`` agent traces for one example and append them to ``output_file``.

    Args:
        example: Dataset row; must contain ``args.prompt_column`` and ``"test_cases"``.
        session: HTTP session forwarded to ``get_agent_run``.
        args: Namespace providing ``prompt_column`` and ``num_generations``.
        output_file: Path of the JSONL file that results are appended to.
        pbar: Optional progress bar; advanced by one exactly once per call.

    Returns:
        The written record (the original fields plus ``"generations"`` and
        ``"final_outputs"``), or ``None`` when any generation failed or an
        exception occurred.
    """
    prompt = f"""Here is a task to solve using a function:
{example[args.prompt_column]}
Now write a function that solves the problem, then you can at once test and return it by using the tool final_answer(your_function).
- The function should take the inputs described in the task above: use the input() function to get them.
- As such your function will not take any arguments. IT SHOULD BE ONE SINGLE MONOLITHIC FUNCTION, no helper functions, all imports and variable defs should be inside the function.
- And your function should give its output to stdout via print(). Returning anything is useless.
- ALSO, DO NOT TRY TO CRAFT CUSTOM TEST FUNCTIONS or do not run your function: just test it using final_answer.
- If you get this error: 'Forbidden function evaluation: 'input' is not among the explicitly allowed tools', it just means that you've tried to run your function: don't do that, just return it using final_answer
"""
    try:
        agent_outputs, agent_memories = [], []
        for _ in range(args.num_generations):
            run = get_agent_run(session, prompt, example["test_cases"], args)
            # Tolerate a failed run being reported as a bare None instead of a pair.
            agent_output, agent_memory = run if run is not None else (None, None)
            if agent_output is None:
                # One failed generation discards the example, so stop spending calls on it.
                print("Error processing example")
                return None
            agent_outputs.append(agent_output)
            agent_memories.append(agent_memory)

        # Combine original dataset fields with generations
        result = {
            **example,  # Preserve all original dataset fields
            "generations": [_serialize_generation(memory) for memory in agent_memories],
            "final_outputs": agent_outputs,
        }

        # Serialize before touching the file so a failure never leaves a partial line.
        try:
            line = json.dumps(result)
        except TypeError as e:
            print(f"JSON serialization error: {e}")
            # Fallback: store with minimal information
            fallback_result = {
                **{k: v for k, v in example.items() if isinstance(v, _JSON_SAFE_TYPES)},
                "error": f"Failed to serialize full result: {e}",
            }
            line = json.dumps(fallback_result)

        # Write to file with lock
        with file_lock:
            with open(output_file, mode="a") as f:
                f.write(line + "\n")
                f.flush()
        return result
    except Exception as e:
        print(f"Error processing example: {e}")
        return None
    finally:
        # Single place where progress is reported, whatever path was taken above.
        if pbar is not None:
            pbar.update(1)
def load_processed_uuids(output_file, uuid_column):
    """Collect the identifiers of examples already written to ``output_file``.

    Each line of the file is parsed as JSON; the value under ``uuid_column`` is
    converted with ``str`` and hashed with MD5. Lines that are not valid JSON,
    are not JSON objects, or lack ``uuid_column`` are skipped so that one bad
    record cannot prevent a run from resuming.

    Args:
        output_file: Path of the JSONL results file; it may not exist yet.
        uuid_column: Key whose value identifies an example.

    Returns:
        A set of MD5 hex digests; empty when the file does not exist.
    """
    processed_uuids = set()
    if not os.path.exists(output_file):
        return processed_uuids

    with open(output_file, mode="r") as f:
        for line in f:
            try:
                identifier = json.loads(line)[uuid_column]
            except (json.JSONDecodeError, KeyError, TypeError):
                # Malformed line, missing column, or a non-object JSON value.
                continue
            processed_uuids.add(hashlib.md5(str(identifier).encode()).hexdigest())
    return processed_uuids
def process_example_wrapper(args_tuple):
    """Unpack a 5-tuple of ``process_example`` arguments and forward the call.

    Lets a single positional argument be submitted to an executor.
    """
    (row, http_session, cli_args, destination, progress) = args_tuple
    outcome = process_example(row, http_session, cli_args, destination, progress)
    return outcome
def main():
    """Generate agent traces for ``open-r1/codeforces-test-cases`` and append them to a JSONL file.

    Parses the command line, loads and shuffles the dataset (keeping rows whose
    ``full_test_set`` is truthy), skips rows whose id is already present in the
    output file, and processes the remaining rows on a thread pool.
    """
    # Leftover smoke-test scaffolding: the tool and sample function are built
    # but never invoked below.
    smoke_test_tool = ModifiedFinalAnswerTool([{"input": "1 2 3", "output": 5}])

    def add_numbers():
        numbers = input()
        print(sum([int(number) for number in numbers.split()]))

    from textwrap import dedent

    add_numbers.__source__ = dedent(inspect.getsource(add_numbers))

    # --- Command line -----------------------------------------------------
    parser = argparse.ArgumentParser()
    for required_flag in ("--output-file", "--prompt-column", "--uuid-column"):
        parser.add_argument(required_flag, type=str, required=True)
    parser.add_argument("--api-addr", type=str, default="localhost:39876")
    tunables = (
        ("--num-generations", int, 4),
        ("--temperature", float, 0.6),
        ("--top-p", float, 0.95),
        ("--max-tokens", int, 8096),
        ("--max-concurrent", int, 1000),
    )
    for flag, value_type, default_value in tunables:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()

    # --- Dataset ----------------------------------------------------------
    dataset = load_dataset(
        "open-r1/codeforces-test-cases",
        split="train",
        token=os.getenv("HF_TOKEN"),
    )
    dataset = dataset.shuffle().filter(lambda row: row["full_test_set"])

    processed_uuids = load_processed_uuids(args.output_file, args.uuid_column)
    if processed_uuids:
        print(f"Found {len(processed_uuids)} already processed examples, resuming from there...")

    # --- Output file: make sure the directory and the file itself exist ---
    output_path = Path(args.output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if not output_path.exists():
        output_path.write_text("")

    # --- One HTTP session shared by every worker thread -------------------
    session = requests.Session()
    adapter = requests.adapters.HTTPAdapter(
        pool_connections=args.max_concurrent,
        pool_maxsize=args.max_concurrent,
        max_retries=3,
    )
    for scheme in ("http://", "https://"):
        session.mount(scheme, adapter)

    # --- Drop rows that were handled by a previous run ---------------------
    def row_digest(row):
        return hashlib.md5(str(row[args.uuid_column]).encode()).hexdigest()

    examples_to_process = [row for row in dataset if row_digest(row) not in processed_uuids]
    print(f"Processing {len(examples_to_process)} examples with {args.max_concurrent} workers")

    pbar = tqdm(
        total=len(examples_to_process),
        desc="Generating responses",
        unit="row",
        mininterval=2,
        smoothing=0.0001,
    )

    # --- Fan out over a thread pool and wait for every task ----------------
    with ThreadPoolExecutor(max_workers=args.max_concurrent) as executor:
        futures = [
            executor.submit(
                process_example_wrapper,
                (example, session, args, args.output_file, pbar),
            )
            for example in examples_to_process
        ]
        for future in futures:
            # result() re-raises anything that escaped the worker.
            future.result()

    pbar.close()
    print("All examples processed!")


if __name__ == "__main__":
    main()

View file

@ -0,0 +1,327 @@
#!/usr/bin/env python3
"""
Script to download, process, and upload the aguvis-stage2 dataset.
Downloads from huggingface.co/datasets/xlangai/aguvis-stage2 and uploads to smolagents/aguvis-stage-2
"""
import gc
import json
import os
import shutil
import zipfile
from pathlib import Path
from typing import Any, Dict, List, Generator, Callable
from tqdm import tqdm
from datasets import Dataset
from dotenv import load_dotenv
from huggingface_hub import HfApi, login, snapshot_download
from PIL import Image
from huggingface_hub import upload_large_folder
# Load variables from a local .env file; override=True is passed so values from
# the file take precedence over ones already present in the environment.
load_dotenv(override=True)
# Hub API client; in the visible part of this file it is only referenced by
# commented-out code.
api = HfApi()
# One entry per source annotation file of the raw dataset:
#   json_path          annotation file, relative to the dataset root
#   images_folder      image directory, relative to the subset's folder
#   sampling_strategy  NOTE(review): not read anywhere in the visible code
#
# Every json_path is listed exactly once: process_split loads every entry that
# matches a subset, so a repeated entry would load the same records twice.
# NOTE(review): the list previously repeated mind2web-l2.json,
# gui-odyssey-l2.json and amex-l2.json — confirm whether a third level file
# was intended for those subsets.
config_dict = [
    {
        "json_path": "mind2web-l1.json",
        "images_folder": "mind2web/",
        "sampling_strategy": "all",
    },
    {
        "json_path": "mind2web-l2.json",
        "images_folder": "mind2web/",
        "sampling_strategy": "all",
    },
    {
        "json_path": "guiact-web-single.json",
        "images_folder": "guiact-web-single/images/",
        "sampling_strategy": "all",
    },
    {
        "json_path": "guiact-web-multi-l1.json",
        "images_folder": "guiact-web-multi/images/",
        "sampling_strategy": "all",
    },
    {
        "json_path": "guiact-web-multi-l2.json",
        "images_folder": "guiact-web-multi/images/",
        "sampling_strategy": "all",
    },
    {
        "json_path": "miniwob-l1.json",
        "images_folder": "miniwob/images",
        "sampling_strategy": "all",
    },
    {
        "json_path": "miniwob-l2.json",
        "images_folder": "miniwob/images/",
        "sampling_strategy": "all",
    },
    {
        "json_path": "coat.json",
        "images_folder": "coat/images/",
        "sampling_strategy": "all",
    },
    {
        "json_path": "android_control.json",
        "images_folder": "android_control/images/",
        "sampling_strategy": "all",
    },
    {
        "json_path": "gui-odyssey-l1.json",
        "images_folder": "gui-odyssey/images/",
        "sampling_strategy": "random:33%",
    },
    {
        "json_path": "gui-odyssey-l2.json",
        "images_folder": "gui-odyssey/images/",
        "sampling_strategy": "random:33%",
    },
    {
        "json_path": "amex-l1.json",
        "images_folder": "amex/images/",
        "sampling_strategy": "random:33%",
    },
    {
        "json_path": "amex-l2.json",
        "images_folder": "amex/images/",
        "sampling_strategy": "random:33%",
    },
    {
        "json_path": "aitw-l1.json",
        "images_folder": "aitw/images",
        "sampling_strategy": "all",
    },
    {
        "json_path": "aitw-l2.json",
        "images_folder": "aitw/images/",
        "sampling_strategy": "all",
    },
]
def discover_dataset_config(dataset_path: str) -> List[Dict[str, Any]]:
    """Return one ``config_dict`` entry per subset, tagged with its subset name.

    The subset name is the entry's ``json_path`` with ``.json``, ``-l1`` and
    ``-l2`` removed. Only the first entry of each subset is returned, and that
    entry is updated in place with a ``"subset_name"`` key.

    Args:
        dataset_path: Root directory of the downloaded dataset; must exist.

    Raises:
        FileNotFoundError: If ``dataset_path`` does not exist.
    """
    root = Path(dataset_path)
    if not root.exists():
        raise FileNotFoundError(f"Train directory not found: {root}")

    seen_subsets = set()
    selected = []
    for entry in config_dict:
        stem = entry["json_path"].replace(".json", "")
        name = stem.replace("-l1", "").replace("-l2", "")
        if name not in seen_subsets:
            seen_subsets.add(name)
            entry["subset_name"] = name
            selected.append(entry)
            print(f"Discovered config: {name} -> {entry['images_folder']}")
    return selected
def download_dataset(
    repo_id: str = "xlangai/aguvis-stage2", local_dir: str = "./aguvis_raw"
) -> str:
    """Fetch the dataset repository ``repo_id`` into ``local_dir`` via ``snapshot_download``.

    Returns:
        The local path reported by ``snapshot_download``.
    """
    print(f"Downloading dataset from {repo_id}...")
    download_kwargs = {
        "repo_id": repo_id,
        "local_dir": local_dir,
        "repo_type": "dataset",
    }
    snapshot_path = snapshot_download(**download_kwargs)
    print(f"Dataset downloaded to: {snapshot_path}")
    return snapshot_path
def extract_zip_files(dataset_path: str):
    """Unpack every ``*.zip`` found under ``dataset_path`` (searched recursively).

    Each archive is extracted into a sibling directory named after the archive
    without its ``.zip`` suffix. Archives whose target directory already exists
    and is non-empty are left alone.
    """
    print("Extracting zip files...")
    for archive in Path(dataset_path).rglob("*.zip"):
        target = archive.with_suffix("")
        already_extracted = target.exists() and any(target.iterdir())
        if already_extracted:
            print(f"Skipping extraction for {archive} (already extracted at {target})")
        else:
            print(f"Extracting: {archive}")
            with zipfile.ZipFile(archive, "r") as bundle:
                bundle.extractall(target)
            print(f"Extracted to: {target}")
def check_subset_exists(repo_id: str, subset_name: str) -> bool:
    """Report whether ``subset_name`` is among the config names of dataset ``repo_id``.

    Any error raised while looking the names up (including a failed import) is
    printed and treated as "does not exist".
    """
    found = False
    try:
        # Imported lazily; a failure here is handled like any other lookup error.
        from datasets import get_dataset_config_names

        found = subset_name in get_dataset_config_names(repo_id)
    except Exception as e:
        print(f"Could not check if subset exists: {e}")
    return found
def load_images_from_folder(
    images_folder: Path, image_paths: List[str]
) -> List[Image.Image]:
    """Load the given images from ``images_folder`` as independent in-memory copies.

    Args:
        images_folder: Directory that the relative ``image_paths`` are joined to.
        image_paths: Image file paths relative to ``images_folder``.

    Returns:
        One copied image per path, in the same order. The underlying files are
        closed before returning, even when copying an image fails.
    """
    images = []
    for img_path in image_paths:
        # The context manager releases the file handle even if copy() raises.
        with Image.open(images_folder / img_path) as img:
            images.append(img.copy())
    return images
def convert_to_chat_format(data_item: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Turn a raw data item into a list of ``{"role", "content"}`` chat messages.

    Two input layouts are recognised:

    * ``"conversations"``: each turn having both ``"from"`` and ``"value"``
      becomes a message; ``"human"`` maps to the ``user`` role and every other
      sender to ``assistant``. Turns missing either key are dropped.
    * ``"instruction"`` together with ``"response"``: one user/assistant pair.

    ``"conversations"`` takes precedence when present; any other layout yields
    an empty list.
    """
    if "conversations" in data_item:
        return [
            {
                "role": "user" if turn["from"] == "human" else "assistant",
                "content": turn["value"],
            }
            for turn in data_item["conversations"]
            if "from" in turn and "value" in turn
        ]

    if "instruction" in data_item and "response" in data_item:
        question = {"role": "user", "content": data_item["instruction"]}
        answer = {"role": "assistant", "content": data_item["response"]}
        return [question, answer]

    return []
def process_split(config: Dict[str, Any], dataset_path: str, destination_path: str) -> Callable:
    """Prepare a generator function that yields the processed rows of one subset.

    All annotation files listed in ``config_dict`` that belong to the subset
    (same name once ``.json``, ``-l1`` and ``-l2`` are removed) are loaded and
    concatenated; each file is loaded at most once.

    Args:
        config: Entry carrying ``"subset_name"`` and ``"images_folder"``.
        dataset_path: Root directory of the extracted raw dataset.
        destination_path: Unused; kept for signature compatibility.

    Returns:
        A zero-argument generator function yielding ``{"images", "texts"}``
        dicts, or ``None`` when the subset already exists in the remote
        ``smolagents/aguvis-stage-2`` dataset.

    NOTE(review): ``config["sampling_strategy"]`` is not applied here — every
    loaded item is yielded.
    """
    subset_name = config["subset_name"]
    repo_id = "smolagents/aguvis-stage-2"

    # Check if the subset already exists in the remote dataset
    if check_subset_exists(repo_id, subset_name):
        print(f"Subset '{subset_name}' already exists in {repo_id}, skipping processing.")
        return None

    print(f"Processing split: {subset_name}")
    dataset_dir = Path(dataset_path)
    images_folder = dataset_dir / subset_name / config["images_folder"]

    # Find all JSON files that match this split (e.g., mind2web-l1.json, mind2web-l2.json)
    json_files = []
    for cfg in config_dict:
        cfg_split = cfg["json_path"].replace(".json", "").replace("-l1", "").replace("-l2", "")
        if cfg_split != subset_name:
            continue
        json_path = dataset_dir / cfg["json_path"]
        # Guard against repeated config entries loading the same records twice.
        if json_path.exists() and json_path not in json_files:
            json_files.append(json_path)

    if not json_files:
        print(f"No annotation files found for split: {subset_name}")

    # Load and merge JSON data from all matching files
    data = []
    for json_file in json_files:
        print(f"Loading data from: {json_file}")
        with open(json_file, "r") as f:
            file_data = json.load(f)
        data.extend(file_data)
        print(f" Added {len(file_data)} items")

    def process_items() -> Generator[Dict[str, Any], None, None]:
        """Yield one ``{"images", "texts"}`` entry per loaded item."""
        for item in tqdm(data):
            # An item references its images under "images" (list or single path)
            # or under "image" (single path).
            if "images" in item:
                image_paths = item["images"] if isinstance(item["images"], list) else [item["images"]]
            elif "image" in item:
                image_paths = [item["image"]]
            else:
                image_paths = []

            yield {
                "images": load_images_from_folder(images_folder, image_paths),
                "texts": convert_to_chat_format(item),
            }

    return process_items
def authenticate_huggingface():
    """Log in to the Hugging Face Hub with the token from the ``HF_TOKEN`` variable.

    Raises:
        ValueError: If ``HF_TOKEN`` is unset or empty.
    """
    token = os.getenv("HF_TOKEN")
    if not token:
        raise ValueError("HF_TOKEN environment variable not set.")
    print("Authenticating with HuggingFace Hub using token...")
    login(token=token)
def main():
    """Download aguvis-stage2, convert each subset and push it to ``smolagents/aguvis-stage-2``.

    Steps: authenticate, download the raw snapshot, extract its zip archives,
    then for every discovered subset build a ``Dataset`` from its generator and
    push it as a ``train`` split under a config named after the subset. Subsets
    for which ``process_split`` returns ``None`` are skipped.
    """
    print("Starting aguvis-stage2 dataset processing...")

    # Step 0: Authenticate with HuggingFace Hub
    authenticate_huggingface()

    dataset_path = download_dataset("xlangai/aguvis-stage2", Path("./aguvis_raw"))
    extract_zip_files(dataset_path)
    dataset_configs = discover_dataset_config(dataset_path)

    converted_folder = "./aguvis_converted"
    target_repo = "smolagents/aguvis-stage-2"
    separator = "=" * 50

    for config in dataset_configs:
        subset = config["subset_name"]
        print(f"\n{separator}")
        print(config)

        process_items = process_split(config, dataset_path, f"{subset}")
        if process_items is None:
            # Nothing to do for this subset.
            continue

        print("Creating dataset...")
        data = Dataset.from_generator(process_items)

        print("Pushing to hub...")
        # config_name carries the subset name; the split is always "train".
        data.push_to_hub(target_repo, config_name=subset, split="train")
        print(f"Processed and uploaded subset: {subset}")

        # Force garbage collection to manage memory
        gc.collect()

    print("Subsets uploaded!")

    # Cleanup: only the conversion folder is removed; the raw download is kept.
    print("\nCleaning up temporary files...")
    shutil.rmtree(converted_folder, ignore_errors=True)
    print("All done!")


if __name__ == "__main__":
    main()

View file

@ -21,7 +21,6 @@ from pathlib import Path
from setuptools import find_packages, setup
# Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
if stale_egg_info.exists():
@ -42,7 +41,7 @@ if stale_egg_info.exists():
# * If a dependency is fast-moving (e.g. trl), pin to the exact version
_deps = [
"accelerate==1.4.0",
"bitsandbytes>=0.43.0",
"bitsandbytes>=0.42.0",
"datasets>=3.2.0",
"deepspeed==0.16.8",
"distilabel[vllm,ray,openai]>=1.5.2",
@ -54,8 +53,8 @@ _deps = [
"isort>=5.12.0",
"jieba", # Needed for Chinese language support
"langdetect", # Needed for LightEval's extended tasks
"latex2sympy2_extended>=1.0.6",
"liger-kernel>=0.5.10",
# "latex2sympy2_extended>=1.0.6",
# "liger-kernel>=0.5.10",
"lighteval @ git+https://github.com/huggingface/lighteval.git@d3da6b9bbf38104c8b5e1acc86f83541f9a502d1", # Critical bug fix for tokenizer revisions: https://github.com/huggingface/lighteval/pull/721
"math-verify==0.5.2", # Used for math verification in grpo
"morphcloud==0.1.67",
@ -82,7 +81,13 @@ _deps = [
# packaging: "packaging"
#
# some of the values are versioned whereas others aren't.
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
deps = {
b: a
for a, b in (
re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0]
for x in _deps
)
}
def deps_list(*pkgs):
@ -93,7 +98,9 @@ extras = {}
extras["tests"] = deps_list("pytest", "parameterized", "math-verify", "jieba")
extras["torch"] = deps_list("torch")
extras["quality"] = deps_list("ruff", "isort", "flake8")
extras["code"] = deps_list("e2b-code-interpreter", "python-dotenv", "morphcloud", "jieba", "pandas", "aiofiles")
extras["code"] = deps_list(
"e2b-code-interpreter", "python-dotenv", "morphcloud", "jieba", "pandas", "aiofiles"
)
extras["eval"] = deps_list("lighteval", "math-verify")
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"] + extras["code"]
@ -107,9 +114,9 @@ install_requires = [
deps["hf_transfer"],
deps["huggingface-hub"],
deps["langdetect"],
deps["latex2sympy2_extended"],
# deps["latex2sympy2_extended"],
deps["math-verify"],
deps["liger-kernel"],
# deps["liger-kernel"],
deps["packaging"], # utilities from PyPA to e.g., compare versions
deps["safetensors"],
deps["sentencepiece"],

View file

@ -0,0 +1,23 @@
#!/bin/bash
#SBATCH --job-name=agentic-r1
#SBATCH --gres=gpu:8
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --nodes=1
#SBATCH --output=./logs/%x_%j_%n.out
#SBATCH --error=./logs/%x_%j_%n.err
#SBATCH --time=7-00:00:00

# Single-node Slurm job that runs scripts/generate_agent_traces.py against an
# already-running inference endpoint and writes the generations to a JSONL file.
# Log files are named <job-name>_<job-id>_<node-index> under ./logs.
# NOTE(review): Slurm does not create the log directory; make sure ./logs exists
# before submitting.

# Fail fast and trace commands. `nounset` is enabled only after the environment
# is set up: ~/.bashrc and conda's activation hooks commonly reference unset
# variables, which would abort the job under `set -u`.
set -exo pipefail

source ~/.bashrc
# Quoted so that a conda base path containing spaces is not word-split.
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate /fsx/aymeric/venv

set -u

# NOTE(review): the API address is a hard-coded host:port; it must point at a
# live server when the job starts.
# NOTE(review): 8096 is not a power of two; confirm 8192 was not intended.
python scripts/generate_agent_traces.py \
    --output-file "codeforces_agentic_generations.jsonl" \
    --prompt-column "prompt" \
    --uuid-column "contestId" \
    --api-addr "10.53.83.199:39876" \
    --num-generations 5 \
    --max-tokens 8096 \
    --max-concurrent 64

View file

@ -1,5 +1,5 @@
#!/bin/bash
#SBATCH --job-name=open_r1
#SBATCH --job-name=agent-sft
#SBATCH --ntasks-per-node=1
#SBATCH --exclusive
#SBATCH --gres=gpu:8
@ -23,6 +23,7 @@ if [[ "$*" == *"--help"* ]]; then
exit 0
fi
# Specific configuration optimized for the Hugging Face Compute Cluster
module load cuda/12.4
set -x -e
@ -32,9 +33,9 @@ source openr1/bin/activate
START_TIME=$(date +%s)
echo "START TIME: $(date)"
# Refresh Weka on h4 cache
# Refresh Weka on cache
echo "Refreshing Weka filesystem..."
find -L /fsx/h4/ -type f | xargs -d '\n' -r -n512 -P64 weka fs tier fetch
# find -L /fsx/aymeric/ -type f | xargs -d '\n' -r -n512 -P64 weka fs tier fetch
# Default values
MODEL=""
@ -45,6 +46,7 @@ DP=1
TP=1
OPTIONAL_ARGS=""
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
@ -84,6 +86,9 @@ while [[ $# -gt 0 ]]; do
esac
done
export HF_HOME="/fsx/aymeric/.cache/"
HF_HOME="/fsx/aymeric/.cache/"
# Validate required arguments
if [[ -z "$MODEL" || -z "$TASK" || -z "$CONFIG_SUFFIX" || -z "$ACCELERATOR" ]]; then
echo "Error: Missing required arguments"
@ -106,14 +111,12 @@ for arg in "${ARGS[@]}"; do
fi
done
echo "Gradient accumulation steps: $GRAD_ACC_STEPS"
MODEL=$(grep 'model_name_or_path:' $CONFIG_FILE | awk '{print $2}')
REVISION=$(grep 'model_revision:' $CONFIG_FILE | head -n 1 | awk '{print $2}')
# Distributed configuration
NUM_NODES=$SLURM_NNODES
GPUS_PER_NODE=8
GPUS_PER_NODE=1
WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE))
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
MASTER_ADDR=${NODELIST[0]} # First node for main process
@ -136,8 +139,8 @@ if [[ "$USE_VLLM" == "true" ]]; then
fi
# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_DEBUG=INFO
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1

View file

@ -74,6 +74,10 @@ class ScriptArguments(trl.ScriptArguments):
default=None,
metadata={"help": "Configuration for creating dataset mixtures with advanced options like shuffling."},
)
single_gpu: bool = field(
default=False,
metadata={"help": "Force training on single GPU only, disabling distributed training."},
)
def __post_init__(self):
if self.dataset_name is None and self.dataset_mixture is None:
@ -185,6 +189,10 @@ class SFTConfig(trl.SFTConfig):
default=None,
metadata={"help": "The optional system prompt to use for benchmarking."},
)
vision_model: bool = field(
default=False,
metadata={"help": "Whether this is a vision-language model training."},
)
hub_model_revision: Optional[str] = field(
default="main",
metadata={"help": "The Hub model branch to push the model to."},

View file

@ -25,6 +25,19 @@ from typing import Callable, Dict, Literal, Optional
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
<<<<<<< HEAD
from .utils import is_e2b_available
from .utils.ioi import SubtaskResult, add_includes, get_piston_client_from_env, score_subtask
if is_e2b_available():
from dotenv import load_dotenv
from e2b_code_interpreter import AsyncSandbox, Sandbox
load_dotenv()
else:
AsyncSandbox = None
=======
from .utils.code_providers import get_provider
from .utils.competitive_programming import (
SubtaskResult,
@ -35,6 +48,7 @@ from .utils.competitive_programming import (
from .utils.competitive_programming import patch_code as cf_patch_code
from .utils.competitive_programming import score_submission as cf_score_submission
from .utils.competitive_programming import score_subtask
>>>>>>> main
def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
@ -592,6 +606,70 @@ def code_reward(
return execution_provider.execute_scripts(scripts, ["python"] * len(scripts))
def run_tests(completions, **kwargs) -> list[float]:
    """Reward function that evaluates code snippets using the E2B code interpreter.

    The code extracted from the last message of each completion is embedded in
    a generated evaluation script, which feeds every test case's ``input`` to
    the code on stdin and compares stdout line by line with the expected
    ``output``. Each script is then executed through ``run_script_test``.

    Args:
        completions: One conversation per sample; the last message's
            ``"content"`` is passed to ``extract_code``.
        **kwargs: Must contain ``verification_info``, one dict per completion
            with a ``"language"`` key and a ``"test_cases"`` list whose items
            have ``"input"`` and ``"output"`` keys.

    Returns:
        One result per generated script, as returned by ``run_script_test``.
        An empty list when ``verification_info`` is empty.

    Raises:
        ImportError: If E2B is not installed.
        ValueError: If the ``verification_info`` entries do not all share the
            same language.
    """
    if not is_e2b_available():
        raise ImportError(
            "E2B is not available and required for this reward function. Please install E2B with "
            "`pip install e2b-code-interpreter` and add an API key to a `.env` file."
        )

    # This template is rendered with str.format, so the only braces it may
    # contain are the {code} and {test_cases} placeholders.
    # The return code is wrapped in str(): concatenating the int directly raised
    # a TypeError inside the sandbox and hid the real failure message.
    # Known limitation (see TODO in the template): zip() stops at the shorter of
    # the two outputs, so missing or extra trailing lines are not detected.
    evaluation_script_template = """
import subprocess
import json

def evaluate_code(code, test_cases):
    passed = 0
    total = len(test_cases)
    exec_timeout = 20

    for case in test_cases:
        process = subprocess.run(
            ["python3", "-c", code],
            input=case["input"],
            text=True,
            capture_output=True,
            timeout=exec_timeout
        )

        if process.returncode != 0:
            error_msg = "Process exited with code " + str(process.returncode)
            if process.stderr:
                error_msg += ": Error: " + process.stderr
            if process.stdout:
                error_msg += ": Output: " + process.stdout
            raise Exception(error_msg)

        output = process.stdout.strip()

        # TODO: implement a proper validator to compare against ground truth. For now we just check for exact string match on each line of stdout.
        for line1, line2 in zip(output.split('\\n'), str(case['output']).split('\\n')):
            if not line1.strip() == line2.strip():
                raise Exception("Function output did not match gold truth for test case "+ case['input'] + " : Got " + str(line1.strip()) + " instead of " + str(line2.strip()))

code_snippet = {code}
test_cases = json.loads({test_cases})

evaluate_code(code_snippet, test_cases)
"""
    verification_info = kwargs["verification_info"]
    # Nothing to evaluate: avoid indexing into an empty batch below.
    if not verification_info:
        return []

    # Validate before doing any work: all scripts run under a single language.
    language = verification_info[0]["language"]
    if not all(v["language"] == language for v in verification_info):
        raise ValueError("All verification_info must have the same language", verification_info)

    code_snippets = [extract_code(completion[-1]["content"]) for completion in completions]
    # json.dumps turns the code into a valid Python string literal; the test
    # cases are dumped twice because the script calls json.loads on a literal.
    scripts = [
        evaluation_script_template.format(code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"])))
        for code, info in zip(code_snippets, verification_info)
    ]

    return [run_script_test(script, language) for script in scripts]
def get_code_format_reward(language: str = "python"):
"""Format reward function specifically for code responses.
@ -643,6 +721,14 @@ def get_soft_overlong_punishment(max_completion_len, soft_punish_cache):
return soft_overlong_punishment_reward
def run_script_test(script: str, language: str) -> float:
    """Run ``script`` in a fresh E2B sandbox and report whether it succeeded.

    Args:
        script: Source code to execute inside the sandbox.
        language: Language identifier forwarded to ``Sandbox.run_code``.

    Returns:
        ``1.0`` when the execution finished without an error.

    Raises:
        Exception: If the execution reports an error; the message contains the
            captured logs followed by the error name and value.
    """
    sandbox = Sandbox(timeout=30, request_timeout=30)
    try:
        execution = sandbox.run_code(script, language=language)
    finally:
        # Release the remote sandbox instead of leaving it alive until timeout.
        sandbox.kill()

    if execution.error:
        raise Exception(f"{execution.logs}\n{execution.error.name}: {execution.error.value}")
    # The function is declared to return a float; it previously fell through
    # and returned None on success.
    return 1.0
def get_reward_funcs(script_args) -> list[Callable]:
REWARD_FUNCS_REGISTRY = {
"accuracy": accuracy_reward,

View file

@ -13,18 +13,18 @@
# limitations under the License.
"""
Supervised fine-tuning script for decoder language models.
Supervised fine-tuning script for decoder language models and vision-language models.
Usage:
# On 1 node of 8 x H100s
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path open-r1/Qwen2.5-Math-7B-RoPE-300k \
--dataset_name open-r1/Mixture-of-Thoughts \
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
--dataset_name smolagents/gaia-traces \
--num_train_epochs 1 \
--dataset_config all \
--eos_token '<|im_end|>' \
--learning_rate 4.0e-5 \
--num_train_epochs 5 \
--max_seq_length 32768 \
--per_device_train_batch_size 2 \
--gradient_checkpointing \
@ -39,20 +39,57 @@ import sys
import datasets
import transformers
from transformers import set_seed
from transformers import set_seed, AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import ScriptArguments, SFTConfig
from open_r1.utils import get_dataset, get_model, get_tokenizer
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import ModelConfig, SFTTrainer, TrlParser, get_peft_config, setup_chat_format
from open_r1.configs import ScriptArguments, SFTConfig
from open_r1.utils import get_dataset, get_model, get_tokenizer, get_processor
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
logger = logging.getLogger(__name__)
def create_vlm_collate_fn(processor):
    """Build a collate function for vision-language SFT around ``processor``.

    The returned callable expects a list of examples, each with a ``"texts"``
    entry (rendered with the processor's chat template) and an ``"images"``
    entry. It returns the processor's batch with an added ``"labels"`` entry:
    a copy of ``input_ids`` where padding tokens, and the image token when the
    processor defines one, are replaced by -100.
    """

    def collate_fn(examples):
        # Render every conversation to plain text; tokenization happens below.
        prompts = []
        for sample in examples:
            prompts.append(processor.apply_chat_template(sample["texts"], tokenize=False))
        pictures = [sample["images"] for sample in examples]

        # Tokenize the text and preprocess the images in one call.
        model_inputs = processor(text=prompts, images=pictures, return_tensors="pt", padding=True)

        # Labels start as a copy of the inputs; padding positions are masked out.
        tokenizer = processor.tokenizer
        targets = model_inputs["input_ids"].clone()
        targets[targets == tokenizer.pad_token_id] = -100

        # Not every processor exposes an image placeholder token.
        if hasattr(processor, "image_token"):
            placeholder_id = tokenizer.convert_tokens_to_ids(processor.image_token)
            targets[targets == placeholder_id] = -100

        model_inputs["labels"] = targets
        return model_inputs

    return collate_fn
def main(script_args, training_args, model_args):
# Force single GPU mode if requested
# if hasattr(script_args, 'single_gpu') and script_args.single_gpu:
# logger.info("Single GPU mode requested - setting CUDA_VISIBLE_DEVICES=0")
# # Disable distributed training
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# training_args.local_rank = -1
# training_args.ddp_backend = None
set_seed(training_args.seed)
###############
@ -85,25 +122,60 @@ def main(script_args, training_args, model_args):
init_wandb_training(training_args)
######################################
# Load dataset, tokenizer, and model #
# Load dataset, processor/tokenizer, and model #
######################################
dataset = get_dataset(script_args)
tokenizer = get_tokenizer(model_args, training_args)
model = get_model(model_args, training_args)
if tokenizer.chat_template is None:
logger.info("No chat template provided, defaulting to ChatML.")
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
if training_args.vision_model:
logger.info("Setting up vision-language model training")
# Set VLM-specific training arguments (following TRL reference)
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
# Load processor and model for VLM
processor = get_processor(model_args, training_args)
model = get_model(model_args, training_args) # This should return AutoModelForVision2Seq
data_collator = create_vlm_collate_fn(processor)
processing_class = processor.tokenizer
model_tags = ["open-r1", "vision-language", "vlm"]
else:
logger.info("Setting up text-only model training")
# Load tokenizer and model for text-only
tokenizer = get_tokenizer(model_args, training_args)
model = get_model(model_args, training_args)
if tokenizer.chat_template is None:
logger.info("No chat template provided, defaulting to ChatML.")
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
data_collator = None # Use default
processing_class = tokenizer
model_tags = ["open-r1"]
############################
# Initialize the SFT Trainer
############################
logger.info(f"""WHOLE DATASET INFO:
{model}
{training_args}
{script_args}
{model_args}
""")
trainer = SFTTrainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None),
processing_class=tokenizer,
eval_dataset=(
dataset[script_args.dataset_test_split]
if training_args.eval_strategy != "no"
else None
),
processing_class=processing_class,
peft_config=get_peft_config(model_args),
callbacks=get_callbacks(training_args, model_args),
)
@ -128,16 +200,13 @@ def main(script_args, training_args, model_args):
# Save model and create model card
##################################
logger.info("*** Save model ***")
# Align the model's generation config with the tokenizer's eos token
# to avoid unbounded generation in the transformers `pipeline()` function
trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")
# Save everything else on main process
kwargs = {
"dataset_name": script_args.dataset_name,
"tags": ["open-r1"],
"tags": model_tags,
}
if trainer.accelerator.is_main_process:
trainer.create_model_card(**kwargs)
@ -161,6 +230,9 @@ def main(script_args, training_args, model_args):
if training_args.push_to_hub:
logger.info("Pushing to hub...")
trainer.push_to_hub(**kwargs)
# Also push processor for VLM models
if training_args.vision_model and trainer.accelerator.is_main_process:
processor.push_to_hub(training_args.hub_model_id)
if __name__ == "__main__":

View file

@ -1,6 +1,6 @@
from .data import get_dataset
from .import_utils import is_e2b_available, is_morph_available
from .model_utils import get_model, get_tokenizer
from .model_utils import get_model, get_tokenizer, get_processor
__all__ = ["get_tokenizer", "is_e2b_available", "is_morph_available", "get_model", "get_dataset"]
__all__ = ["get_tokenizer", "get_processor", "is_e2b_available", "is_morph_available", "get_model", "get_dataset"]

View file

@ -1,5 +1,5 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer, AutoProcessor, AutoModelForVision2Seq
from trl import ModelConfig, get_kbit_device_map, get_quantization_config
@ -20,8 +20,22 @@ def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig
return tokenizer
def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoModelForCausalLM:
"""Get the model"""
def get_processor(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoProcessor:
    """Load the ``AutoProcessor`` for a vision-language model.

    The processor is loaded from ``model_args.model_name_or_path`` at
    ``model_args.model_revision``. When ``training_args.chat_template`` is set,
    it replaces the chat template that ships with the processor.
    """
    vlm_processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )

    template_override = training_args.chat_template
    if template_override is not None:
        vlm_processor.chat_template = template_override

    return vlm_processor
def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoModelForCausalLM | AutoModelForVision2Seq:
"""Get the model - supports both text-only and vision-language models"""
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
@ -35,8 +49,19 @@ def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) ->
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
**model_kwargs,
)
# Check if this is a VLM model using the explicit flag
if hasattr(training_args, 'vision_model') and training_args.vision_model:
# Load as vision-language model
model = AutoModelForVision2Seq.from_pretrained(
model_args.model_name_or_path,
**model_kwargs,
)
else:
# Load as text-only model
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
**model_kwargs,
)
return model