mirror of
https://github.com/huggingface/open-r1.git
synced 2026-06-24 01:54:06 +00:00
Compare commits
88 commits
main
...
agent-trac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
18fea48144 | ||
|
|
e8a4c2bd08 | ||
|
|
6a63f2fed6 | ||
|
|
4c83688dba | ||
|
|
868d4a4011 | ||
|
|
1b50860c3e | ||
|
|
880a5853fb | ||
|
|
029dc608ba | ||
|
|
31cf3a2c53 | ||
|
|
b0d794cfaf | ||
|
|
f692c10b74 | ||
|
|
035f134bf4 | ||
|
|
22b84cf091 | ||
|
|
f6b8f7cafd | ||
|
|
2ba1c6580b | ||
|
|
b316210f28 | ||
|
|
5eadb0607a | ||
|
|
db30467a95 | ||
|
|
24ea1127b3 | ||
|
|
a658db9def | ||
|
|
933ea920a7 | ||
|
|
5fa7e51a5e | ||
|
|
a452f2fff6 | ||
|
|
3c3e954578 | ||
|
|
c4d41268e0 | ||
|
|
0ee52fc6f2 | ||
|
|
eb39096a66 | ||
|
|
81c64ac201 | ||
|
|
3b77977ce2 | ||
|
|
b7a700e119 | ||
|
|
7cb592c062 | ||
|
|
3aee6efd1e | ||
|
|
fbd987c3ad | ||
|
|
a675552585 | ||
|
|
984d63120b | ||
|
|
80f7ce833d | ||
|
|
a9b541195a | ||
|
|
a66a5e69a4 | ||
|
|
9347590c47 | ||
|
|
60472f6613 | ||
|
|
08a449c8f5 | ||
|
|
2030e166f3 | ||
|
|
2043be9ea0 | ||
|
|
1eaf1d15f5 | ||
|
|
2a08444b20 | ||
|
|
cae3c7c5aa | ||
|
|
d28d07b63d | ||
|
|
8a7951c0bc | ||
|
|
5647c262f8 | ||
|
|
de2b792dba | ||
|
|
49083cc87c | ||
|
|
2ddf70e6a8 | ||
|
|
b7522e3925 | ||
|
|
38efcfcbd5 | ||
|
|
2b1bc05ebc | ||
|
|
6961c3650c | ||
|
|
4a20ba4bfd | ||
|
|
4c2fce6bbf | ||
|
|
ddc1cddc25 | ||
|
|
a07cd54ab5 | ||
|
|
cf52433cdc | ||
|
|
2876d52491 | ||
|
|
8e70ca4f0d | ||
|
|
d87e3f3fa3 | ||
|
|
83a679fc68 | ||
|
|
e42b1cd606 | ||
|
|
d8cb19b616 | ||
|
|
a97eb27683 | ||
|
|
2a1ff761f1 | ||
|
|
9a2d16f62a | ||
|
|
cb2a2c2db7 | ||
|
|
f78b8651e3 | ||
|
|
b738e58ecf | ||
|
|
d2588cd290 | ||
|
|
dd15ad8b5c | ||
|
|
0cd0999b9c | ||
|
|
b47a4be6f4 | ||
|
|
e35800c5a1 | ||
|
|
cffa36268b | ||
|
|
cf13c2b63e | ||
|
|
6cffffe504 | ||
|
|
0af9e75346 | ||
|
|
143fcfa3da | ||
|
|
a00f0ee768 | ||
|
|
ad948c22a5 | ||
|
|
69b26518fc | ||
|
|
6c231d2130 | ||
|
|
352008b017 |
17 changed files with 1543 additions and 47 deletions
21
README_AGENTS.md
Normal file
21
README_AGENTS.md
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
Launch:
|
||||
```bash
|
||||
sbatch --nodes=1 slurm/train.slurm --model SmolLM2-1.7B-Instruct --task sft --config agent --accelerator zero3
|
||||
```
|
||||
Refers to the config recipes/SmolLM2-1.7B-Instruct/sft/config_agent.yaml
|
||||
zero3 is one of the accelerate configs in recipes/accelerate_configs
|
||||
|
||||
|
||||
### VLM training
|
||||
|
||||
Launch in multi GPU:
|
||||
```bash
|
||||
sbatch --qos=high --nodes=1 slurm/train.slurm --model Qwen2.5-VL-3B-Instruct --task sft --config agent --accelerator zero3
|
||||
```
|
||||
|
||||
🛑 For me the above fails because of NCCL issues, I launch it in single-GPU mode as follows:
|
||||
```bash
|
||||
sbatch slurm/trainsingle.slurm --model Qwen2.5-VL-3B-Instruct --task sft --config agent
|
||||
```
|
||||
|
||||
The config is located under recipes/Qwen2.5-VL-3B-Instruct/sft/config_agent.yaml
|
||||
46
recipes/Qwen2.5-3B-Instruct/sft/config.yaml
Normal file
46
recipes/Qwen2.5-3B-Instruct/sft/config.yaml
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
# Model arguments
|
||||
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
|
||||
model_name_or_path: HuggingFaceTB/SmolLM2-1.7B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: sdpa
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-220k
|
||||
dataset_num_proc: 48
|
||||
|
||||
#SFT hyperparam
|
||||
max_length: 8192 # You can set this to 32768 if you change the rope, but you need to change the config.json file
|
||||
weight_decay: 0.0001
|
||||
optim: adamw_torch
|
||||
lr_scheduler_type: linear
|
||||
warmup_ratio: 0.1
|
||||
learning_rate: 5.0e-05
|
||||
gradient_accumulation_steps: 2
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 4 # Change this depending on the context length of the model to keep a 500M GBS.
|
||||
|
||||
# SFT trainer config
|
||||
max_steps: -1
|
||||
num_train_epochs: 3
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: 'no'
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: OpenR1-Qwen-7B-SFT
|
||||
hub_strategy: every_save
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
packing: true
|
||||
output_dir: data/OpenR1-Qwen-7B-SFT
|
||||
overwrite_output_dir: true
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
save_strategy: "steps"
|
||||
save_steps: 500
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
46
recipes/Qwen2.5-3B-Instruct/sft/config_agent.yaml
Normal file
46
recipes/Qwen2.5-3B-Instruct/sft/config_agent.yaml
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
# Model arguments
|
||||
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
|
||||
model_name_or_path: Qwen/Qwen2.5-3B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: sdpa
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: smolagents/training-traces
|
||||
dataset_num_proc: 48
|
||||
|
||||
#SFT hyperparam
|
||||
max_length: 8192 # You can set this to 32768 if you change the rope, but you need to change the config.json file
|
||||
weight_decay: 0.0001
|
||||
optim: adamw_torch
|
||||
lr_scheduler_type: linear
|
||||
warmup_ratio: 0.1
|
||||
learning_rate: 4.0e-05
|
||||
gradient_accumulation_steps: 1
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 2 # Change this depending on the context length of the model to keep a 500M GBS.
|
||||
|
||||
# SFT trainer config
|
||||
max_steps: -1
|
||||
num_train_epochs: 2
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: 'no'
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: oR1-Qwen-3B-Agentic-e2-lr4e-b2
|
||||
hub_strategy: every_save
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
packing: true
|
||||
output_dir: data/oR1-Qwen-3B-Agentic-e2-lr4e-b2
|
||||
overwrite_output_dir: true
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
save_strategy: "steps"
|
||||
save_steps: 500
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
76
recipes/Qwen2.5-VL-3B-Instruct/sft/config_agent.yaml
Normal file
76
recipes/Qwen2.5-VL-3B-Instruct/sft/config_agent.yaml
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
# Model arguments
|
||||
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
|
||||
model_name_or_path: Qwen/Qwen2.5-VL-3B-Instruct
|
||||
vision_model: true
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: sdpa
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: smolagents/aguvis-stage-2
|
||||
dataset_num_proc: 48
|
||||
|
||||
#SFT hyperparam
|
||||
max_length: 32768
|
||||
optim: adamw_torch
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_grad_norm: 0.2
|
||||
warmup_ratio: 0.03
|
||||
learning_rate: 1.0e-05
|
||||
gradient_accumulation_steps: 8
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 4 # Change this depending on the context length of the model to keep a 500M GBS.
|
||||
|
||||
single_gpu: true
|
||||
|
||||
# SFT trainer config
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: 'no'
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: smolagents/Qwen2.5-VL-3B-Instruct-Agentic
|
||||
hub_strategy: end
|
||||
push_to_hub: true
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
output_dir: data/smolagents-Qwen2.5-VL-3B-Instruct-Agentic
|
||||
overwrite_output_dir: true
|
||||
report_to:
|
||||
- wandb
|
||||
save_strategy: "epoch"
|
||||
save_steps: 1
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
|
||||
dataset_mixture:
|
||||
datasets: # List of datasets to include in the mixture
|
||||
- id: smolagents/aguvis-stage-2 # Hub dataset ID
|
||||
config: mind2web # Name of the dataset config
|
||||
split: train # Split to use from the dataset
|
||||
columns: # Columns to keep
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: guiact-web-single
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: guiact-web-multi
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
seed: 42 # Seed for shuffling the combined dataset
|
||||
test_split_size: 0.1
|
||||
45
recipes/SmolLM2-1.7B-Instruct/sft/config_agent.yaml
Normal file
45
recipes/SmolLM2-1.7B-Instruct/sft/config_agent.yaml
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
# Model arguments
|
||||
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
|
||||
model_name_or_path: HuggingFaceTB/SmolLM2-1.7B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: sdpa
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: smolagents/codeagent-traces
|
||||
dataset_num_proc: 48
|
||||
|
||||
#SFT hyperparam
|
||||
max_length: 8192 # You can set this to 32768 if you change the rope, but you need to change the config.json file
|
||||
weight_decay: 0.0001
|
||||
optim: adamw_torch
|
||||
lr_scheduler_type: linear
|
||||
warmup_ratio: 0.1
|
||||
learning_rate: 5.0e-05
|
||||
gradient_accumulation_steps: 2
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 4 # Change this depending on the context length of the model to keep a 500M GBS.
|
||||
|
||||
# SFT trainer config
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: 'no'
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: OpenR1-SmolLM2-1.7B-Instruct-Agentic
|
||||
hub_strategy: every_save
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
output_dir: data/OpenR1-SmolLM2-1.7B-Instruct-Agentic
|
||||
overwrite_output_dir: true
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
save_strategy: "steps"
|
||||
save_steps: 500
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
330
scripts/agents/generate_agent_traces.py
Normal file
330
scripts/agents/generate_agent_traces.py
Normal file
|
|
@ -0,0 +1,330 @@
|
|||
import argparse
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from threading import Lock
|
||||
from typing import Set, Any, List
|
||||
from pathlib import Path
|
||||
import traceback
|
||||
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
import requests.adapters
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from smolagents import CodeAgent, Tool, HfApiModel
|
||||
from open_r1.rewards import run_tests
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
assert os.getenv("HF_TOKEN") is not None
|
||||
|
||||
# from huggingface_hub import login
|
||||
|
||||
# print("LOGIN:\n", login(token=os.getenv("HF_TOKEN")))
|
||||
|
||||
file_lock = Lock()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")
|
||||
|
||||
print("Launching generation")
|
||||
|
||||
def test_cases_on_function(function: Any, test_cases: List[dict]) -> str:
|
||||
source_code = f"```python\n{function.__source__}\n{function.__name__}()```"
|
||||
test_completions = [[{"content": source_code}]]
|
||||
test_kwargs = {"verification_info": [{"language": "python", "test_cases": test_cases}]}
|
||||
run_tests(test_completions, e2b_router_url="0.0.0.0:8000", test_mode=True, **test_kwargs)
|
||||
|
||||
|
||||
class ModifiedFinalAnswerTool(Tool):
|
||||
name = "final_answer"
|
||||
description = "Tests a function: if correct, returns it as the final answer; Else returns the test case that errored for solving."
|
||||
inputs = {'answer_function': {'type': 'any', 'description': 'The final function that solves the problem'}}
|
||||
output_type = "string"
|
||||
|
||||
def __init__(self, test_cases):
|
||||
self.is_initialized = False
|
||||
self.test_cases = test_cases
|
||||
super().__init__()
|
||||
|
||||
def forward(self, answer_function: Any) -> str:
|
||||
test_cases_on_function(answer_function, self.test_cases)
|
||||
return answer_function.__source__
|
||||
|
||||
|
||||
class ChatMessage:
|
||||
def __init__(self, content):
|
||||
self.content = content
|
||||
|
||||
def generate_completion_from_messages(session, messages, args, stop_sequences) -> str:
|
||||
retry_budget = 10
|
||||
while retry_budget > 0:
|
||||
try:
|
||||
formatted_chat = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
print("Input token count:", len(tokenizer.encode(formatted_chat)))
|
||||
# Add a small random delay to prevent overwhelming the API
|
||||
time.sleep(random.uniform(0.0, 0.1))
|
||||
response = session.post(
|
||||
f"http://{args.api_addr}/v1/chat/completions",
|
||||
json={
|
||||
"model": "default",
|
||||
"messages": messages,
|
||||
"max_tokens": args.max_tokens,
|
||||
"temperature": args.temperature,
|
||||
"top_p": args.top_p,
|
||||
"stop": stop_sequences,
|
||||
},
|
||||
headers={"Authorization": "Bearer EMPTY"},
|
||||
timeout=2*60*60
|
||||
)
|
||||
|
||||
# Check status code and log error content if needed
|
||||
if response.status_code >= 400:
|
||||
print(f"HTTP Error {response.status_code}: {response.reason}")
|
||||
print(f"Response content: {response.text}")
|
||||
traceback.print_exc()
|
||||
retry_budget -= 1
|
||||
time.sleep(20)
|
||||
continue
|
||||
|
||||
# Parse JSON response
|
||||
try:
|
||||
output = response.json()["choices"][0]["message"]["content"]
|
||||
return output
|
||||
except ValueError as e:
|
||||
print(f"JSON parsing error: {e}")
|
||||
print(f"Response content: {response.text}")
|
||||
traceback.print_exc()
|
||||
retry_budget -= 1
|
||||
time.sleep(20)
|
||||
continue
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"API request error (will retry): {e}")
|
||||
traceback.print_exc()
|
||||
retry_budget -= 1
|
||||
time.sleep(20)
|
||||
|
||||
raise Exception("Failed to get a valid response after multiple retries")
|
||||
|
||||
def get_agent_run(session, task, test_cases, args):
|
||||
# def model(messages, stop_sequences = None):
|
||||
# cleaned_messages = get_clean_message_list(messages, {"system": "user", "tool-call": "assistant", "tool-response": "user"}, flatten_messages_as_text=True)
|
||||
# result = generate_completion_from_messages(session, cleaned_messages, args, stop_sequences)
|
||||
# return ChatMessage(content=result)
|
||||
|
||||
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct", provider="fireworks-ai", token=os.getenv("HF_TOKEN"))
|
||||
|
||||
agent = CodeAgent(
|
||||
model=model,
|
||||
tools=[ModifiedFinalAnswerTool(test_cases)],
|
||||
additional_authorized_imports=["numpy", "math"],
|
||||
max_steps=10,
|
||||
verbosity_level=2
|
||||
)
|
||||
|
||||
try:
|
||||
output = agent.run(task, additional_args={"test_cases": test_cases})
|
||||
return output, agent.write_memory_to_messages()
|
||||
except Exception as e:
|
||||
print(f"Error when generating agentic trace: {e}")
|
||||
return None
|
||||
|
||||
def process_example(example, session, args, output_file, pbar=None):
|
||||
prompt = f"""Here is a task to solve using a function:
|
||||
{example[args.prompt_column]}
|
||||
|
||||
Now write a function that solves the problem, then you can at once test and return it by using the tool final_answer(your_function).
|
||||
- The function should take the inputs described in the task above: use the input() function to get them.
|
||||
- As such your function will not take any arguments. IT SHOULD BE ONE SINGLE MONOLITHIC FUNCTION, no helper functions, all imports and variable defs should be inside the function.
|
||||
- And your function should give its output to stdout via print(). Returning anything is useless.
|
||||
- ALSO, DO NOT TRY TO CRAFT CUSTOM TEST FUNCTIONS or do not run your function: just test it using final_answer.
|
||||
- If you get this error: 'Forbidden function evaluation: 'input' is not among the explicitly allowed tools', it just means that you've tried to run your function: don't do that, just return it using final_answer
|
||||
"""
|
||||
try:
|
||||
agent_outputs, agent_memories = [], []
|
||||
for _ in range(args.num_generations):
|
||||
agent_output, agent_memory = get_agent_run(session, prompt, example["test_cases"], args)
|
||||
agent_outputs.append(agent_output)
|
||||
agent_memories.append(agent_memory)
|
||||
|
||||
if any(agent_output is None for agent_output in agent_outputs):
|
||||
print("Error processing example")
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return None
|
||||
|
||||
finish_reasons = []
|
||||
api_metadata = []
|
||||
|
||||
for agent_run in agent_output:
|
||||
finish_reasons.append(None)
|
||||
api_metadata.append(None)
|
||||
|
||||
# Convert agent_run to a serializable format
|
||||
serializable_generations = []
|
||||
for generation in agent_memories:
|
||||
if generation is not None:
|
||||
# Convert to a simple list of dictionaries if it's not already
|
||||
if isinstance(generation, list):
|
||||
serializable_generations.append([
|
||||
{k: v for k, v in msg.items() if isinstance(v, (str, int, float, bool, type(None), list, dict))}
|
||||
for msg in generation if isinstance(msg, dict)
|
||||
])
|
||||
else:
|
||||
# Handle other formats or provide a placeholder
|
||||
serializable_generations.append(str(generation))
|
||||
else:
|
||||
serializable_generations.append(None)
|
||||
|
||||
# Combine original dataset fields with generations
|
||||
result = {
|
||||
**example, # Preserve all original dataset fields
|
||||
"generations": serializable_generations,
|
||||
"final_outputs": agent_outputs,
|
||||
# "finish_reasons": finish_reasons,
|
||||
# "api_metadata": api_metadata,
|
||||
}
|
||||
|
||||
# Write to file with lock
|
||||
with file_lock:
|
||||
with open(output_file, mode="a") as f:
|
||||
try:
|
||||
f.write(json.dumps(result) + "\n")
|
||||
f.flush()
|
||||
except TypeError as e:
|
||||
print(f"JSON serialization error: {e}")
|
||||
# Fallback: store with minimal information
|
||||
fallback_result = {
|
||||
**{k: v for k, v in example.items() if isinstance(v, (str, int, float, bool, type(None), list, dict))},
|
||||
"error": f"Failed to serialize full result: {e}"
|
||||
}
|
||||
f.write(json.dumps(fallback_result) + "\n")
|
||||
f.flush()
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"Error processing example: {e}")
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return None
|
||||
|
||||
def load_processed_uuids(output_file, uuid_column):
|
||||
processed_uuids = set()
|
||||
if os.path.exists(output_file):
|
||||
with open(output_file, mode="r") as f:
|
||||
for line in f:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
processed_uuids.add(hashlib.md5(str(data[uuid_column]).encode()).hexdigest())
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
return processed_uuids
|
||||
|
||||
def process_example_wrapper(args_tuple):
|
||||
example, session, args, output_file, pbar = args_tuple
|
||||
return process_example(example, session, args, output_file, pbar)
|
||||
|
||||
def main():
|
||||
test_function = ModifiedFinalAnswerTool([{"input": "1 2 3", "output": 5}])
|
||||
|
||||
def add_numbers():
|
||||
numbers = input()
|
||||
print(sum([int(number) for number in numbers.split()]))
|
||||
|
||||
from textwrap import dedent
|
||||
|
||||
add_numbers.__source__ = dedent(inspect.getsource(add_numbers))
|
||||
# print(test_function(
|
||||
# add_numbers,
|
||||
# ))
|
||||
# quit()
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# parser.add_argument("--dataset-name", type=str, required=True)
|
||||
parser.add_argument("--output-file", type=str, required=True)
|
||||
parser.add_argument("--prompt-column", type=str, required=True)
|
||||
parser.add_argument("--uuid-column", type=str, required=True)
|
||||
parser.add_argument("--api-addr", type=str, default="localhost:39876")
|
||||
parser.add_argument("--num-generations", type=int, default=4)
|
||||
parser.add_argument("--temperature", type=float, default=0.6)
|
||||
parser.add_argument("--top-p", type=float, default=0.95)
|
||||
parser.add_argument("--max-tokens", type=int, default=8096)
|
||||
parser.add_argument("--max-concurrent", type=int, default=1000)
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset = load_dataset(
|
||||
"open-r1/codeforces-test-cases",
|
||||
split="train",
|
||||
token=os.getenv("HF_TOKEN")
|
||||
).shuffle()
|
||||
dataset = dataset.filter(lambda x: x["full_test_set"])
|
||||
|
||||
processed_uuids = load_processed_uuids(args.output_file, args.uuid_column)
|
||||
if processed_uuids:
|
||||
print(f"Found {len(processed_uuids)} already processed examples, resuming from there...")
|
||||
|
||||
# Ensure the output directory exists
|
||||
output_path = Path(args.output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create the file if it doesn't exist
|
||||
if not output_path.exists():
|
||||
with open(args.output_file, mode="w") as f:
|
||||
f.write("")
|
||||
|
||||
# Create a session that will be shared among threads
|
||||
session = requests.Session()
|
||||
adapter = requests.adapters.HTTPAdapter(
|
||||
pool_connections=args.max_concurrent,
|
||||
pool_maxsize=args.max_concurrent,
|
||||
max_retries=3
|
||||
)
|
||||
session.mount('http://', adapter)
|
||||
session.mount('https://', adapter)
|
||||
|
||||
# Filter out already processed examples
|
||||
examples_to_process = []
|
||||
for example in dataset:
|
||||
uuid = hashlib.md5(str(example[args.uuid_column]).encode()).hexdigest()
|
||||
if uuid not in processed_uuids:
|
||||
examples_to_process.append(example)
|
||||
|
||||
print(f"Processing {len(examples_to_process)} examples with {args.max_concurrent} workers")
|
||||
|
||||
pbar = tqdm(
|
||||
total=len(examples_to_process),
|
||||
desc="Generating responses",
|
||||
unit="row",
|
||||
mininterval=2,
|
||||
smoothing=0.0001,
|
||||
)
|
||||
|
||||
# Prepare arguments for each example
|
||||
example_args = [(example, session, args, args.output_file, pbar) for example in examples_to_process]
|
||||
|
||||
# Use ThreadPoolExecutor to process examples concurrently
|
||||
with ThreadPoolExecutor(max_workers=args.max_concurrent) as executor:
|
||||
# Submit all tasks
|
||||
futures = [executor.submit(process_example_wrapper, arg) for arg in example_args]
|
||||
|
||||
# Wait for all futures to complete
|
||||
for future in futures:
|
||||
future.result() # This ensures exceptions are raised
|
||||
|
||||
pbar.close()
|
||||
print("All examples processed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
563
scripts/agents/get_aguvis_data.py
Normal file
563
scripts/agents/get_aguvis_data.py
Normal file
|
|
@ -0,0 +1,563 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to download, process, and upload the aguvis-stage2 dataset.
|
||||
Downloads from huggingface.co/datasets/xlangai/aguvis-stage2 and uploads to smolagents/aguvis-stage-2
|
||||
"""
|
||||
import re
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Generator, Callable
|
||||
from tqdm import tqdm
|
||||
from datasets import Dataset, load_dataset
|
||||
from dotenv import load_dotenv
|
||||
from huggingface_hub import HfApi, login, snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
|
||||
api = HfApi()
|
||||
|
||||
SYSTEM_PROMPT= """You are a helpful GUI agent. You will be given a task and a screenshot of the screen. You need to perform a series of function calls in code to complete the task.
|
||||
|
||||
When you send a message containing Python code between '<code>' and '</code>' tags, it will be executed in a stateful Jupyter notebook environment, and you will then be given the output to continued reasoning in an agentic loop.
|
||||
|
||||
The following functions are exposed to the Python interpreter:
|
||||
<code>
|
||||
def final_answer(answer: any) -> any:
|
||||
\"\"\"
|
||||
Provides a final answer to the given problem.
|
||||
Args:
|
||||
answer: The final answer to the problem
|
||||
\"\"\"
|
||||
|
||||
def click(x: int, y: int) -> str:
|
||||
\"\"\"
|
||||
Performs a left-click at the specified coordinates
|
||||
Args:
|
||||
x: The x coordinate (horizontal position)
|
||||
y: The y coordinate (vertical position)
|
||||
\"\"\"
|
||||
|
||||
def right_click(x: int, y: int) -> str:
|
||||
\"\"\"
|
||||
Performs a right-click at the specified coordinates
|
||||
Args:
|
||||
x: The x coordinate (horizontal position)
|
||||
y: The y coordinate (vertical position)
|
||||
\"\"\"
|
||||
|
||||
def double_click(x: int, y: int) -> str:
|
||||
\"\"\"
|
||||
Performs a double-click at the specified coordinates
|
||||
Args:
|
||||
x: The x coordinate (horizontal position)
|
||||
y: The y coordinate (vertical position)
|
||||
\"\"\"
|
||||
|
||||
def write(text: str) -> str:
|
||||
\"\"\"
|
||||
Types the specified text at the current cursor position.
|
||||
Args:
|
||||
text: The text to type
|
||||
\"\"\"
|
||||
|
||||
def press_key(key: str) -> str:
|
||||
\"\"\"
|
||||
Presses a keyboard key
|
||||
Args:
|
||||
key: The key to press (e.g. "enter", "space", "backspace", etc.).
|
||||
\"\"\"
|
||||
|
||||
def go_back() -> str:
|
||||
\"\"\"
|
||||
Goes back to the previous page in the browser. If using this tool doesn't work, just click the button directly.
|
||||
Args:
|
||||
\"\"\"
|
||||
|
||||
def drag_and_drop(x1: int, y1: int, x2: int, y2: int) -> str:
|
||||
\"\"\"
|
||||
Clicks [x1, y1], drags mouse to [x2, y2], then release click.
|
||||
Args:
|
||||
x1: origin x coordinate
|
||||
y1: origin y coordinate
|
||||
x2: end x coordinate
|
||||
y2: end y coordinate
|
||||
\"\"\"
|
||||
|
||||
def scroll(x: int = None, y: int = None, direction: Literal["up", "down"] = "down", amount: int = 2) -> str:
|
||||
\"\"\"
|
||||
Moves the mouse to selected coordinates, then uses the scroll button: this could scroll the page or zoom, depending on the app. DO NOT use scroll to move through linux desktop menus.
|
||||
Args:
|
||||
x: The x coordinate (horizontal position) of the element to scroll/zoom, defaults to None to not focus on specific coordinates
|
||||
y: The y coordinate (vertical position) of the element to scroll/zoom, defaults to None to not focus on specific coordinates
|
||||
direction: The direction to scroll ("up" or "down"), defaults to "down". For zoom, "up" zooms in, "down" zooms out.
|
||||
amount: The amount to scroll. A good amount is 1 or 2.
|
||||
\"\"\"
|
||||
|
||||
def wait(seconds: float) -> str:
|
||||
\"\"\"
|
||||
Waits for the specified number of seconds. Very useful in case the prior order is still executing (for example starting very heavy applications like browsers or office apps)
|
||||
Args:
|
||||
seconds: Number of seconds to wait, generally 3 is enough.
|
||||
\"\"\"
|
||||
</code>
|
||||
|
||||
The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
|
||||
"""
|
||||
|
||||
|
||||
# TODO: some of the mappings above must be wrong because the conversion fails for some subsets
|
||||
config_dict = [{
|
||||
"json_path": "mind2web-l1.json",
|
||||
"images_folder": "mind2web/",
|
||||
"sampling_strategy": "all"
|
||||
}, {
|
||||
"json_path": "mind2web-l2.json",
|
||||
"images_folder": "mind2web/",
|
||||
"sampling_strategy": "all"}, {
|
||||
"json_path": "mind2web-l2.json",
|
||||
"images_folder": "mind2web/",
|
||||
"sampling_strategy": "all"}, {
|
||||
"json_path": "guiact-web-single.json",
|
||||
"images_folder": "guiact-web-single/images/",
|
||||
"sampling_strategy": "all"}, {
|
||||
"json_path": "guiact-web-multi-l1.json",
|
||||
"images_folder": "guiact-web-multi/images/",
|
||||
"sampling_strategy": "all"}, {
|
||||
"json_path": "guiact-web-multi-l2.json",
|
||||
"images_folder": "guiact-web-multi/images/",
|
||||
"sampling_strategy": "all"}, {
|
||||
"json_path": "miniwob-l1.json",
|
||||
"images_folder": "miniwob/images",
|
||||
"sampling_strategy": "all"}, {
|
||||
"json_path": "miniwob-l2.json",
|
||||
"images_folder": "miniwob/images/",
|
||||
"sampling_strategy": "all"},
|
||||
{
|
||||
"json_path": "coat.json",
|
||||
"images_folder": "coat/images/",
|
||||
"sampling_strategy": "all"},
|
||||
{
|
||||
"json_path": "android_control.json",
|
||||
"images_folder": "android_control/images/",
|
||||
"sampling_strategy": "all"},
|
||||
{
|
||||
"json_path": "gui-odyssey-l1.json",
|
||||
"images_folder": "gui-odyssey/images/",
|
||||
"sampling_strategy": "random:33%"}, {
|
||||
"json_path": "gui-odyssey-l2.json",
|
||||
"images_folder": "gui-odyssey/images/",
|
||||
"sampling_strategy": "random:33%"}, {
|
||||
"json_path": "gui-odyssey-l2.json",
|
||||
"images_folder": "gui-odyssey/images/",
|
||||
"sampling_strategy": "random:33%"}, {
|
||||
"json_path": "amex-l1.json",
|
||||
"images_folder": "amex/images/",
|
||||
"sampling_strategy": "random:33%"}, {
|
||||
"json_path": "amex-l2.json",
|
||||
"images_folder": "amex/images/",
|
||||
"sampling_strategy": "random:33%"}, {
|
||||
"json_path": "amex-l2.json",
|
||||
"images_folder": "amex/images/",
|
||||
"sampling_strategy": "random:33%"}, {
|
||||
"json_path": "aitw-l1.json",
|
||||
"images_folder": "aitw/images",
|
||||
"sampling_strategy": "all"},
|
||||
{
|
||||
"json_path": "aitw-l2.json",
|
||||
"images_folder": "aitw/images/",
|
||||
"sampling_strategy": "all"
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
||||
def authenticate_huggingface():
|
||||
"""Authenticate with HuggingFace Hub using token."""
|
||||
hf_token = os.getenv("HF_TOKEN")
|
||||
if hf_token:
|
||||
print("Authenticating with HuggingFace Hub using token...")
|
||||
login(token=hf_token)
|
||||
else:
|
||||
raise ValueError("HF_TOKEN environment variable not set.")
|
||||
|
||||
|
||||
def discover_dataset_config(dataset_path: str) -> List[Dict[str, Any]]:
|
||||
"""Discover dataset configuration by scanning the data directory."""
|
||||
dataset_dir = Path(dataset_path)
|
||||
train_dir = dataset_dir
|
||||
|
||||
if not train_dir.exists():
|
||||
raise FileNotFoundError(f"Train directory not found: {train_dir}")
|
||||
|
||||
configs = []
|
||||
processed_splits = set()
|
||||
|
||||
# Find all JSON files in the train directory
|
||||
for config in config_dict:
|
||||
subset_name = config["json_path"].replace(".json", "").replace("-l1", "").replace("-l2", "")
|
||||
|
||||
# Skip if we already processed this split
|
||||
if subset_name in processed_splits:
|
||||
continue
|
||||
|
||||
config["subset_name"] = subset_name
|
||||
configs.append(config)
|
||||
processed_splits.add(subset_name)
|
||||
print(f"Discovered config: {config['subset_name']} -> {config['images_folder']}")
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def download_dataset(
|
||||
repo_id: str = "xlangai/aguvis-stage2", local_dir: str = "./aguvis_raw"
|
||||
) -> str:
|
||||
"""Download the dataset using snapshot_download."""
|
||||
print(f"Downloading dataset from {repo_id}...")
|
||||
local_path = snapshot_download(
|
||||
repo_id=repo_id, local_dir=local_dir, repo_type="dataset"
|
||||
)
|
||||
print(f"Dataset downloaded to: {local_path}")
|
||||
return local_path
|
||||
|
||||
|
||||
def extract_zip_files(dataset_path: str):
|
||||
"""Extract all zip files found in the dataset directory, but only if not already extracted."""
|
||||
print("Extracting zip files...")
|
||||
dataset_dir = Path(dataset_path)
|
||||
|
||||
for zip_file in dataset_dir.rglob("*.zip"):
|
||||
extract_dir = zip_file.parent / zip_file.stem
|
||||
if extract_dir.exists() and any(extract_dir.iterdir()):
|
||||
print(f"Skipping extraction for {zip_file} (already extracted at {extract_dir})")
|
||||
continue
|
||||
|
||||
print(f"Extracting: {zip_file}")
|
||||
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
||||
zip_ref.extractall(extract_dir)
|
||||
print(f"Extracted to: {extract_dir}")
|
||||
|
||||
|
||||
def check_subset_exists(repo_id: str, subset_name: str) -> bool:
|
||||
"""Check if a subset already exists in the remote dataset."""
|
||||
try:
|
||||
# Try to get dataset info with specific subset
|
||||
from datasets import get_dataset_config_names
|
||||
config_names = get_dataset_config_names(repo_id)
|
||||
return subset_name in config_names
|
||||
except Exception as e:
|
||||
print(f"Could not check if subset exists: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def load_images_from_folder(
|
||||
images_folder: Path, image_paths: List[str]
|
||||
) -> List[Image.Image]:
|
||||
"""Load images from the specified folder."""
|
||||
images = []
|
||||
for img_path in image_paths:
|
||||
full_path = images_folder / img_path
|
||||
img = Image.open(full_path)
|
||||
images.append(img.copy())
|
||||
img.close()
|
||||
return images
|
||||
|
||||
|
||||
def convert_to_smolagents(messages: list[dict[str, Any]]):
|
||||
output_messages = [{
|
||||
"content": SYSTEM_PROMPT,
|
||||
"role": "system"
|
||||
}]
|
||||
previous_role = None
|
||||
for i in range(1, len(messages)):
|
||||
content = messages[i]["content"]
|
||||
|
||||
# Convert the format for content
|
||||
content = content.replace("answer(", "final_answer(")
|
||||
|
||||
if messages[i]["role"] == "assistant":
|
||||
if content.startswith("Action: "):
|
||||
content = content.replace("Action: ", "<think>\n").strip()
|
||||
content += "\n</think>\n"
|
||||
else:
|
||||
content = "<code>\n" + content.replace("pyautogui.", "").strip() + "\n</code>"
|
||||
|
||||
messages[i]["content"] = content
|
||||
|
||||
# Fuse subsequent messages if they are both assistants
|
||||
if messages[i]["role"] == "assistant" and messages[i-1]["role"] == "assistant":
|
||||
# Need to fuse both messages
|
||||
output_messages[-1]["content"] += messages[i]["content"]
|
||||
else:
|
||||
output_messages.append(messages[i])
|
||||
return output_messages
|
||||
|
||||
def test_conversion():
|
||||
origin = [
|
||||
{
|
||||
"content": "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of actions to complete the task.\n\nYou have access to the following functions:\n- {\"name\": \"answer\", \"description\": \"Answer a question\", \"parameters\": {\"type\": \"object\", \"properties\": {\"answer\": {\"type\": \"string\", \"description\": \"The answer to the question\"}}, \"required\": [\"answer\"]}}\n",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"content": "<image>\nPlease generate the next move according to the UI screenshot, instruction and previous actions.\n\nInstruction: What information does the site provide about Judith Lauand's career, works and exhibitions?\n\nPrevious actions:\nStep 1: Click on the link labeled 'Judith Lauand: Brazilian 1922-2022' to explore more about her career and exhibitions.\nStep 2: Click on the 'more' link below the overview text to access additional information about Judith Lauand's career and exhibitions.\nStep 3: Scroll down slightly to view additional information about Judith Lauand's career and exhibitions.",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "Action: The answer is 'Judith Lauand was a Brazilian painter who was born in 1922.\\nNumerous key galleries and museums such as MASP, Museu de Arte de São Paulo have featured Judith Lauand's work in the past.\\nJudith Lauand's work has been offered at auction multiple times, with realized prices ranging from 515 USD to 87,500 USD, depending on the size and medium of the artwork. Since 2011 the record price for this artist at auction is 87,500 USD for Composition on Red Background, sold at Christie's New York in 2015.'\n",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"content": "answer('Judith Lauand was a Brazilian painter who was born in 1922.\nNumerous key galleries and museums such as MASP, Museu de Arte de São Paulo have featured Judith Lauand's work in the past.\nJudith Lauand's work has been offered at auction multiple times, with realized prices ranging from 515 USD to 87,500 USD, depending on the size and medium of the artwork. Since 2011 the record price for this artist at auction is 87,500 USD for Composition on Red Background, sold at Christie's New York in 2015.')",
|
||||
"role": "assistant"
|
||||
}
|
||||
]
|
||||
converted = convert_to_smolagents(origin)
|
||||
print("CONVERTED:\n", converted)
|
||||
expected_messages = [
|
||||
{
|
||||
"content": SYSTEM_PROMPT,
|
||||
"role": "system"
|
||||
},
|
||||
{
|
||||
"content": "<image>\nPlease generate the next move according to the UI screenshot, instruction and previous actions.\n\nInstruction: What information does the site provide about Judith Lauand's career, works and exhibitions?\n\nPrevious actions:\nStep 1: Click on the link labeled 'Judith Lauand: Brazilian 1922-2022' to explore more about her career and exhibitions.\nStep 2: Click on the 'more' link below the overview text to access additional information about Judith Lauand's career and exhibitions.\nStep 3: Scroll down slightly to view additional information about Judith Lauand's career and exhibitions.",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "<think>The answer is 'Judith Lauand was a Brazilian painter who was born in 1922.\\nNumerous key galleries and museums such as MASP, Museu de Arte de São Paulo have featured Judith Lauand's work in the past.\\nJudith Lauand's work has been offered at auction multiple times, with realized prices ranging from 515 USD to 87,500 USD, depending on the size and medium of the artwork. Since 2011 the record price for this artist at auction is 87,500 USD for Composition on Red Background, sold at Christie's New York in 2015.'\n</think>\n<code>\nfinal_answer(\"The answer is 'Judith Lauand was a Brazilian painter who was born in 1922.\\nNumerous key galleries and museums such as MASP, Museu de Arte de São Paulo have featured Judith Lauand's work in the past.\\nJudith Lauand's work has been offered at auction multiple times, with realized prices ranging from 515 USD to 87,500 USD, depending on the size and medium of the artwork. Since 2011 the record price for this artist at auction is 87,500 USD for Composition on Red Background, sold at Christie's New York in 2015.'\n\")\n</code>",
|
||||
"role": "assistant"
|
||||
},
|
||||
]
|
||||
for i, message in enumerate(converted):
|
||||
if not message == expected_messages[i]:
|
||||
print(f"Message {i} is not equal to expected message")
|
||||
print(f"Expected: {expected_messages[i]}")
|
||||
print(f"Actual: {message}")
|
||||
return False
|
||||
return True
|
||||
|
||||
def convert_to_chat_format(data_item: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Convert data item to chat template format."""
|
||||
# This is a placeholder - you'll need to adapt this based on the actual data structure
|
||||
# The exact conversion depends on how the original data is structured
|
||||
chat_messages = []
|
||||
|
||||
# Example conversion - adapt based on actual data structure
|
||||
if "conversations" in data_item:
|
||||
for conv in data_item["conversations"]:
|
||||
if "from" in conv and "value" in conv:
|
||||
role = "user" if conv["from"] == "human" else "assistant"
|
||||
message = {"role": role, "content": conv["value"]}
|
||||
chat_messages.append(message)
|
||||
elif "instruction" in data_item and "response" in data_item:
|
||||
chat_messages = [
|
||||
{"role": "user", "content": data_item["instruction"]},
|
||||
{"role": "assistant", "content": data_item["response"]},
|
||||
]
|
||||
|
||||
chat_messages = convert_to_smolagents(chat_messages)
|
||||
return chat_messages
|
||||
|
||||
|
||||
def process_subset(config: Dict[str, Any], dataset_path: str, destination_path: str, override_existing: bool = False) -> Callable:
|
||||
"""Process a single dataset subset."""
|
||||
subset_name = config['subset_name']
|
||||
repo_id = "smolagents/aguvis-stage-2"
|
||||
|
||||
# Check if the subset already exists in the remote dataset
|
||||
if check_subset_exists(repo_id, subset_name) and not override_existing:
|
||||
print(f"Subset '{subset_name}' already exists in {repo_id}, skipping processing.")
|
||||
return None
|
||||
|
||||
print(f"Processing split: {subset_name}")
|
||||
|
||||
dataset_dir = Path(dataset_path)
|
||||
images_folder = dataset_dir / config["subset_name"] / config["images_folder"]
|
||||
|
||||
# Find all JSON files that match this split (e.g., mind2web-l1.json, mind2web-l2.json)
|
||||
json_files = []
|
||||
for cfg in config_dict:
|
||||
cfg_split = cfg["json_path"].replace(".json", "").replace("-l1", "").replace("-l2", "")
|
||||
if cfg_split == subset_name:
|
||||
json_path = dataset_dir / cfg["json_path"]
|
||||
if json_path.exists():
|
||||
json_files.append(json_path)
|
||||
|
||||
# Load and merge JSON data from all matching files
|
||||
data = []
|
||||
for json_file in json_files:
|
||||
print(f"Loading data from: {json_file}")
|
||||
with open(json_file, "r") as f:
|
||||
file_data = json.load(f)
|
||||
data.extend(file_data)
|
||||
print(f" Added {len(file_data)} items")
|
||||
|
||||
def process_items() -> Generator[Dict[str, Any], None, None]:
|
||||
pbar = tqdm(data)
|
||||
for item in pbar:
|
||||
# Extract image paths from the data item
|
||||
try:
|
||||
image_paths = []
|
||||
if "images" in item:
|
||||
image_paths = (
|
||||
item["images"]
|
||||
if isinstance(item["images"], list)
|
||||
else [item["images"]]
|
||||
)
|
||||
elif "image" in item:
|
||||
image_paths = [item["image"]]
|
||||
|
||||
# Load images
|
||||
images = load_images_from_folder(images_folder, image_paths)
|
||||
|
||||
texts = convert_to_chat_format(item)
|
||||
|
||||
entry = {"images": images, "texts": texts}
|
||||
entry = convert_row_to_screenenv(entry)
|
||||
yield entry
|
||||
except Exception as e:
|
||||
print(f"Error processing item: {e}", item)
|
||||
continue
|
||||
return process_items
|
||||
|
||||
def convert_row_to_screenenv(example: dict[str, Image.Image | list[dict[str, Any]]]) -> dict[str, Image.Image | list[dict[str, Any]]]:
|
||||
"""
|
||||
Converts the dataset to the action space defined in ScreenEnv: https://github.com/huggingface/screenenv/blob/f8fb60d4e805e4c139f39855c04263f81e82155f/examples/desktop_agent.py#L114
|
||||
Also, converts the action space to absolute coordinates for qwen models.
|
||||
"""
|
||||
# example["texts"][0]["content"] = SYSTEM_PROMPT
|
||||
for i, message in enumerate(example["texts"]):
|
||||
if message["role"] == "assistant":
|
||||
if "click(" in message["content"] or "right_click(" in message["content"] or "double_click(" in message["content"]:
|
||||
# Regex that detects to consecutive floats between parentheses, also preceded by OPTIONAL x= and y=, like (x=1.0, y=2.028) or (1.0, 2.028)
|
||||
pattern = r"(click|right_click|double_click)\((?:x=)?(\d+\.\d+), (?:y=)?(\d+\.\d+)\)"
|
||||
matches = re.finditer(pattern, message["content"])
|
||||
for match in matches:
|
||||
name, x, y = match.groups()
|
||||
assert x is not None and y is not None
|
||||
image_size = example["images"][0].size
|
||||
x_absolute = round(float(x) * image_size[0])
|
||||
y_absolute = round(float(y) * image_size[1])
|
||||
message["content"] = message["content"].replace(match.group(0), f"{name}(x={x_absolute}, y={y_absolute})")
|
||||
|
||||
if "scroll(" in message["content"]:
|
||||
# Convert scroll(page=-0.33) to scroll(direction="down", amount=0.33)
|
||||
pattern = r"scroll\((?:page=)?(-?\d+\.\d+)\)"
|
||||
matches = re.finditer(pattern, message["content"])
|
||||
for match in matches:
|
||||
if float(match.group(1)) < 0:
|
||||
message["content"] = message["content"].replace(match.group(0), f"scroll(direction='up', amount={-1*float(match.group(1))})")
|
||||
else:
|
||||
message["content"] = message["content"].replace(match.group(0), f"scroll(direction='down', amount={float(match.group(1))})")
|
||||
|
||||
if "write(" in message["content"]:
|
||||
# Replace "write(message=...)" with "write(text=...)"
|
||||
message["content"] = message["content"].replace("write(message=", "write(text=")
|
||||
|
||||
if "press(" in message["content"]:
|
||||
message["content"] = message["content"].replace("press(keys=", "press(key=")
|
||||
|
||||
if i == len(example["texts"]) - 1 and not any(action in message["content"] for action in ["click", "right_click", "double_click", "scroll", "write", "press"]):
|
||||
# If no action is detected in the final assistant message, wrap the message in a final_answer call
|
||||
message_content = message["content"]
|
||||
message_content = message_content.replace("<code>", "").replace("</code>", "").strip()
|
||||
message["content"] = f"<code>\nfinal_answer({message_content})\n</code>"
|
||||
return example
|
||||
|
||||
test_sample = {
|
||||
"texts": [
|
||||
{
|
||||
"role": "system", # Should not be changed
|
||||
"content": "click(x=0.5, y=0.5)"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "click(x=0.5, y=0.5)\ndouble_click(x=0.5, y=0.597814)"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "scroll(page=0.33)\nscroll(page=-0.33)"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "<code>\nThe answer is 12\n</code>"
|
||||
},
|
||||
],
|
||||
"images": [
|
||||
Image.new("RGB", (100, 100))
|
||||
]
|
||||
}
|
||||
|
||||
test_output = convert_row_to_screenenv(test_sample)
|
||||
assert test_output["texts"][1]["content"] == "click(x=50, y=50)\ndouble_click(x=50, y=60)", test_output["texts"][1]["content"]
|
||||
assert test_output["texts"][2]["content"] == "scroll(direction='down', amount=0.33)\nscroll(direction='up', amount=0.33)", test_output["texts"][2]["content"]
|
||||
assert test_output["texts"][3]["content"] == "<code>\nfinal_answer(The answer is 12)\n</code>", test_output["texts"][3]["content"]
|
||||
|
||||
|
||||
def make_dataset_from_original_data():
|
||||
"""Main function to orchestrate the entire process."""
|
||||
load_dotenv(override=True)
|
||||
|
||||
print("Starting aguvis-stage2 dataset processing...")
|
||||
|
||||
# Step 0: Authenticate with HuggingFace Hub
|
||||
authenticate_huggingface()
|
||||
|
||||
data_folder = Path("./aguvis_raw")
|
||||
|
||||
dataset_path = download_dataset("xlangai/aguvis-stage2", data_folder)
|
||||
|
||||
extract_zip_files(dataset_path)
|
||||
|
||||
dataset_configs = discover_dataset_config(dataset_path)
|
||||
converted_folder = "./aguvis_converted"
|
||||
|
||||
for config in dataset_configs:
|
||||
print(f"\n{'=' * 50}")
|
||||
print(config)
|
||||
process_items = process_subset(config, dataset_path, f"{config['subset_name']}", override_existing=True)
|
||||
|
||||
# Skip if process_subset returned None (subset already exists)
|
||||
if process_items is None:
|
||||
continue
|
||||
|
||||
print("Creating dataset...")
|
||||
data = Dataset.from_generator(process_items)
|
||||
print("Pushing to hub...")
|
||||
# Fix: Use config_name for subset name and split="train"
|
||||
data.push_to_hub(
|
||||
"smolagents/aguvis-stage-2",
|
||||
config_name=config['subset_name'], # This sets the subset name
|
||||
split="train", # This should be "train" not the subset name
|
||||
)
|
||||
|
||||
print(f"Processed and uploaded subset: {config['subset_name']}")
|
||||
|
||||
# Force garbage collection to manage memory
|
||||
gc.collect()
|
||||
|
||||
print(f"Subsets uploaded!")
|
||||
|
||||
# Cleanup
|
||||
print("\nCleaning up temporary files...")
|
||||
# shutil.rmtree(dataset_path, ignore_errors=True)
|
||||
|
||||
# api.upload_large_folder(folder_path=converted_folder, repo_id="smolagents/aguvis-stage-2", repo_type="dataset")
|
||||
|
||||
shutil.rmtree(converted_folder, ignore_errors=True)
|
||||
|
||||
print("All done!")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# for subset in ['guiact-web-single', 'mind2web']:
|
||||
# dataset = load_dataset("smolagents/aguvis-stage-2", subset, split="train", revision="cc2441320a990e930d20732d6375ee2f026d6d19")
|
||||
# print(dataset)
|
||||
|
||||
# dataset = dataset.map(change_coordinates, num_proc=32)
|
||||
|
||||
# dataset.push_to_hub("smolagents/aguvis-stage-2", subset, split="train")
|
||||
make_dataset_from_original_data()
|
||||
23
setup.py
23
setup.py
|
|
@ -21,7 +21,6 @@ from pathlib import Path
|
|||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
# Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
|
||||
stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
|
||||
if stale_egg_info.exists():
|
||||
|
|
@ -42,7 +41,7 @@ if stale_egg_info.exists():
|
|||
# * If a dependency is fast-moving (e.g. trl), pin to the exact version
|
||||
_deps = [
|
||||
"accelerate==1.4.0",
|
||||
"bitsandbytes>=0.43.0",
|
||||
"bitsandbytes>=0.42.0",
|
||||
"datasets>=3.2.0",
|
||||
"deepspeed==0.16.8",
|
||||
"distilabel[vllm,ray,openai]>=1.5.2",
|
||||
|
|
@ -54,8 +53,8 @@ _deps = [
|
|||
"isort>=5.12.0",
|
||||
"jieba", # Needed for Chinese language support
|
||||
"langdetect", # Needed for LightEval's extended tasks
|
||||
"latex2sympy2_extended>=1.0.6",
|
||||
"liger-kernel>=0.5.10",
|
||||
# "latex2sympy2_extended>=1.0.6",
|
||||
# "liger-kernel>=0.5.10",
|
||||
"lighteval @ git+https://github.com/huggingface/lighteval.git@d3da6b9bbf38104c8b5e1acc86f83541f9a502d1", # Critical bug fix for tokenizer revisions: https://github.com/huggingface/lighteval/pull/721
|
||||
"math-verify==0.5.2", # Used for math verification in grpo
|
||||
"morphcloud==0.1.67",
|
||||
|
|
@ -82,7 +81,13 @@ _deps = [
|
|||
# packaging: "packaging"
|
||||
#
|
||||
# some of the values are versioned whereas others aren't.
|
||||
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
|
||||
deps = {
|
||||
b: a
|
||||
for a, b in (
|
||||
re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0]
|
||||
for x in _deps
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def deps_list(*pkgs):
|
||||
|
|
@ -93,7 +98,9 @@ extras = {}
|
|||
extras["tests"] = deps_list("pytest", "parameterized", "math-verify", "jieba")
|
||||
extras["torch"] = deps_list("torch")
|
||||
extras["quality"] = deps_list("ruff", "isort", "flake8")
|
||||
extras["code"] = deps_list("e2b-code-interpreter", "python-dotenv", "morphcloud", "jieba", "pandas", "aiofiles")
|
||||
extras["code"] = deps_list(
|
||||
"e2b-code-interpreter", "python-dotenv", "morphcloud", "jieba", "pandas", "aiofiles"
|
||||
)
|
||||
extras["eval"] = deps_list("lighteval", "math-verify")
|
||||
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"] + extras["code"]
|
||||
|
||||
|
|
@ -107,9 +114,9 @@ install_requires = [
|
|||
deps["hf_transfer"],
|
||||
deps["huggingface-hub"],
|
||||
deps["langdetect"],
|
||||
deps["latex2sympy2_extended"],
|
||||
# deps["latex2sympy2_extended"],
|
||||
deps["math-verify"],
|
||||
deps["liger-kernel"],
|
||||
# deps["liger-kernel"],
|
||||
deps["packaging"], # utilities from PyPA to e.g., compare versions
|
||||
deps["safetensors"],
|
||||
deps["sentencepiece"],
|
||||
|
|
|
|||
23
slurm/agents/agentic_generation.slurm
Normal file
23
slurm/agents/agentic_generation.slurm
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
#!/bin/bash
|
||||
#SBATCH --job-name=agentic-r1
|
||||
#SBATCH --gres=gpu:8
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --nodes=1
|
||||
#SBATCH --output=./logs/%x_%j_%n.out
|
||||
#SBATCH --error=./logs/%x_%j_%n.err
|
||||
#SBATCH --time=7-00:00:00
|
||||
set -exuo pipefail
|
||||
|
||||
source ~/.bashrc
|
||||
source $(conda info --base)/etc/profile.d/conda.sh
|
||||
conda activate /fsx/aymeric/venv
|
||||
|
||||
python scripts/generate_agent_traces.py \
|
||||
--output-file "codeforces_agentic_generations.jsonl" \
|
||||
--prompt-column "prompt" \
|
||||
--uuid-column "contestId" \
|
||||
--api-addr "10.53.83.199:39876" \
|
||||
--num-generations 5 \
|
||||
--max-tokens 8096 \
|
||||
--max-concurrent 64
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
#!/bin/bash
|
||||
#SBATCH --job-name=open_r1
|
||||
#SBATCH --job-name=agent-sft
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --gres=gpu:8
|
||||
#SBATCH --partition=hopper-prod # Adjust this for your cluster
|
||||
#SBATCH --output=./logs/%x-%j.out
|
||||
|
|
@ -23,6 +23,7 @@ if [[ "$*" == *"--help"* ]]; then
|
|||
exit 0
|
||||
fi
|
||||
|
||||
|
||||
# Specific configuration optimized for the Hugging Face Compute Cluster
|
||||
module load cuda/12.4
|
||||
set -x -e
|
||||
|
|
@ -32,9 +33,9 @@ source openr1/bin/activate
|
|||
START_TIME=$(date +%s)
|
||||
echo "START TIME: $(date)"
|
||||
|
||||
# Refresh Weka on h4 cache
|
||||
echo "Refreshing Weka filesystem..."
|
||||
find -L /fsx/h4/ -type f | xargs -d '\n' -r -n512 -P64 weka fs tier fetch
|
||||
# Refresh Weka on cache
|
||||
# echo "Refreshing Weka filesystem..."
|
||||
# find -L /fsx/aymeric/ -type f | xargs -d '\n' -r -n512 -P64 weka fs tier fetch
|
||||
|
||||
# Default values
|
||||
MODEL=""
|
||||
|
|
@ -45,6 +46,7 @@ DP=1
|
|||
TP=1
|
||||
OPTIONAL_ARGS=""
|
||||
|
||||
|
||||
# Parse command line arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
|
|
@ -84,6 +86,9 @@ while [[ $# -gt 0 ]]; do
|
|||
esac
|
||||
done
|
||||
|
||||
export HF_HOME="/fsx/aymeric/.cache/"
|
||||
HF_HOME="/fsx/aymeric/.cache/"
|
||||
|
||||
# Validate required arguments
|
||||
if [[ -z "$MODEL" || -z "$TASK" || -z "$CONFIG_SUFFIX" || -z "$ACCELERATOR" ]]; then
|
||||
echo "Error: Missing required arguments"
|
||||
|
|
@ -106,8 +111,6 @@ for arg in "${ARGS[@]}"; do
|
|||
fi
|
||||
done
|
||||
|
||||
echo "Gradient accumulation steps: $GRAD_ACC_STEPS"
|
||||
|
||||
MODEL=$(grep 'model_name_or_path:' $CONFIG_FILE | awk '{print $2}')
|
||||
REVISION=$(grep 'model_revision:' $CONFIG_FILE | head -n 1 | awk '{print $2}')
|
||||
|
||||
|
|
@ -136,8 +139,8 @@ if [[ "$USE_VLLM" == "true" ]]; then
|
|||
fi
|
||||
|
||||
# force crashing on nccl issues like hanging broadcast
|
||||
export NCCL_ASYNC_ERROR_HANDLING=1
|
||||
# export NCCL_DEBUG=INFO
|
||||
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
|
||||
export NCCL_DEBUG=INFO
|
||||
# export NCCL_DEBUG_SUBSYS=COLL
|
||||
# export NCCL_SOCKET_NTHREADS=1
|
||||
# export NCCL_NSOCKS_PERTHREAD=1
|
||||
|
|
|
|||
107
slurm/trainsingle.slurm
Normal file
107
slurm/trainsingle.slurm
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
#!/bin/bash
|
||||
#SBATCH --job-name=agent-sft-single
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --gres=gpu:1
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --output=./logs/%x-%j.out
|
||||
#SBATCH --error=./logs/%x-%j.err
|
||||
#SBATCH --requeue
|
||||
#SBATCH --time=3-00:00:00
|
||||
|
||||
if [[ "$*" == *"--help"* ]]; then
|
||||
echo "Usage: sbatch slurm/train_single.slurm [options]"
|
||||
echo "Options:"
|
||||
echo " --model MODEL Model name"
|
||||
echo " --task TASK Task name (e.g. sft, grpo)"
|
||||
echo " --config SUFFIX Configuration suffix (e.g. demo, v00.00)"
|
||||
echo " --args \"ARGS\" Optional arguments to pass to the training script"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Specific configuration optimized for the Hugging Face Compute Cluster
|
||||
module load cuda/12.4
|
||||
set -x -e
|
||||
|
||||
source ~/.bashrc
|
||||
source openr1/bin/activate
|
||||
START_TIME=$(date +%s)
|
||||
echo "START TIME: $(date)"
|
||||
|
||||
# Default values
|
||||
MODEL=""
|
||||
TASK=""
|
||||
CONFIG_SUFFIX=""
|
||||
OPTIONAL_ARGS=""
|
||||
|
||||
# Parse command line arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--model)
|
||||
MODEL="$2"
|
||||
shift 2
|
||||
;;
|
||||
--task)
|
||||
TASK="$2"
|
||||
shift 2
|
||||
;;
|
||||
--config)
|
||||
CONFIG_SUFFIX="$2"
|
||||
shift 2
|
||||
;;
|
||||
--args)
|
||||
OPTIONAL_ARGS="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1"
|
||||
echo "Use --help for usage information"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
export HF_HOME="/fsx/aymeric/.cache/"
|
||||
|
||||
# Validate required arguments
|
||||
if [[ -z "$MODEL" || -z "$TASK" || -z "$CONFIG_SUFFIX" ]]; then
|
||||
echo "Error: Missing required arguments"
|
||||
echo "Run with --help for usage information"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
CONFIG_FILE=recipes/$MODEL/$TASK/config_$CONFIG_SUFFIX.yaml
|
||||
|
||||
# Extract model info from config (for potential vLLM usage)
|
||||
MODEL_PATH=$(grep 'model_name_or_path:' $CONFIG_FILE | awk '{print $2}')
|
||||
REVISION=$(grep 'model_revision:' $CONFIG_FILE | head -n 1 | awk '{print $2}')
|
||||
|
||||
# Check if vLLM is needed (though unlikely for single GPU training)
|
||||
USE_VLLM="false"
|
||||
if [[ -f "$CONFIG_FILE" ]] && grep -qE '^\s*use_vllm:\s*true' "$CONFIG_FILE"; then
|
||||
USE_VLLM="true"
|
||||
echo "Warning: vLLM usage detected in single GPU setup. This may not be optimal."
|
||||
fi
|
||||
|
||||
# Set up environment for better debugging
|
||||
export CUDA_LAUNCH_BLOCKING=1
|
||||
export TRANSFORMERS_VERBOSITY=info
|
||||
|
||||
# Build command
|
||||
export CMD="src/open_r1/$TASK.py --config $CONFIG_FILE $OPTIONAL_ARGS"
|
||||
|
||||
echo "Running command: python $CMD"
|
||||
echo "Config file: $CONFIG_FILE"
|
||||
echo "Model: $MODEL_PATH"
|
||||
echo "Revision: $REVISION"
|
||||
|
||||
# Run training directly with python (no accelerate needed for single GPU)
|
||||
python $CMD 2>&1
|
||||
|
||||
END_TIME=$(date +%s)
|
||||
echo "END TIME: $(date)"
|
||||
ELAPSED_SECONDS=$((END_TIME - START_TIME))
|
||||
HOURS=$((ELAPSED_SECONDS / 3600))
|
||||
MINUTES=$(( (ELAPSED_SECONDS % 3600) / 60 ))
|
||||
SECONDS=$((ELAPSED_SECONDS % 60))
|
||||
echo "TOTAL JOB TIME: ${HOURS}h ${MINUTES}m ${SECONDS}s (${ELAPSED_SECONDS} seconds)"
|
||||
|
|
@ -74,6 +74,10 @@ class ScriptArguments(trl.ScriptArguments):
|
|||
default=None,
|
||||
metadata={"help": "Configuration for creating dataset mixtures with advanced options like shuffling."},
|
||||
)
|
||||
single_gpu: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Force training on single GPU only, disabling distributed training."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.dataset_mixture is None:
|
||||
|
|
@ -185,6 +189,10 @@ class SFTConfig(trl.SFTConfig):
|
|||
default=None,
|
||||
metadata={"help": "The optional system prompt to use for benchmarking."},
|
||||
)
|
||||
vision_model: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether this is a vision-language model training."},
|
||||
)
|
||||
hub_model_revision: Optional[str] = field(
|
||||
default="main",
|
||||
metadata={"help": "The Hub model branch to push the model to."},
|
||||
|
|
|
|||
|
|
@ -25,6 +25,19 @@ from typing import Callable, Dict, Literal, Optional
|
|||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
||||
<<<<<<< HEAD
|
||||
from .utils import is_e2b_available
|
||||
from .utils.ioi import SubtaskResult, add_includes, get_piston_client_from_env, score_subtask
|
||||
|
||||
|
||||
if is_e2b_available():
|
||||
from dotenv import load_dotenv
|
||||
from e2b_code_interpreter import AsyncSandbox, Sandbox
|
||||
|
||||
load_dotenv()
|
||||
else:
|
||||
AsyncSandbox = None
|
||||
=======
|
||||
from .utils.code_providers import get_provider
|
||||
from .utils.competitive_programming import (
|
||||
SubtaskResult,
|
||||
|
|
@ -35,6 +48,7 @@ from .utils.competitive_programming import (
|
|||
from .utils.competitive_programming import patch_code as cf_patch_code
|
||||
from .utils.competitive_programming import score_submission as cf_score_submission
|
||||
from .utils.competitive_programming import score_subtask
|
||||
>>>>>>> main
|
||||
|
||||
|
||||
def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
|
||||
|
|
@ -592,6 +606,70 @@ def code_reward(
|
|||
return execution_provider.execute_scripts(scripts, ["python"] * len(scripts))
|
||||
|
||||
|
||||
def run_tests(completions, **kwargs) -> list[float]:
|
||||
"""Reward function that evaluates code snippets using the E2B code interpreter.
|
||||
|
||||
Assumes the dataset contains a `verification_info` column with test cases.
|
||||
"""
|
||||
if not is_e2b_available():
|
||||
raise ImportError(
|
||||
"E2B is not available and required for this reward function. Please install E2B with "
|
||||
"`pip install e2b-code-interpreter` and add an API key to a `.env` file."
|
||||
)
|
||||
|
||||
evaluation_script_template = """
|
||||
import subprocess
|
||||
import json
|
||||
|
||||
def evaluate_code(code, test_cases):
|
||||
passed = 0
|
||||
total = len(test_cases)
|
||||
exec_timeout = 20
|
||||
|
||||
for case in test_cases:
|
||||
process = subprocess.run(
|
||||
["python3", "-c", code],
|
||||
input=case["input"],
|
||||
text=True,
|
||||
capture_output=True,
|
||||
timeout=exec_timeout
|
||||
)
|
||||
|
||||
if process.returncode != 0:
|
||||
error_msg = "Process exited with code" + process.returncode
|
||||
if process.stderr:
|
||||
error_msg += ": Error: " + process.stderr
|
||||
if process.stdout:
|
||||
error_msg += ": Output: " + process.stdout
|
||||
raise Exception(error_msg)
|
||||
|
||||
output = process.stdout.strip()
|
||||
|
||||
# TODO: implement a proper validator to compare against ground truth. For now we just check for exact string match on each line of stdout.
|
||||
for line1, line2 in zip(output.split('\\n'), str(case['output']).split('\\n')):
|
||||
if not line1.strip() == line2.strip():
|
||||
raise Exception("Function output did not match gold truth for test case "+ case['input'] + " : Got " + str(line1.strip()) + " instead of " + str(line2.strip()))
|
||||
|
||||
code_snippet = {code}
|
||||
test_cases = json.loads({test_cases})
|
||||
|
||||
evaluate_code(code_snippet, test_cases)
|
||||
"""
|
||||
code_snippets = [extract_code(completion[-1]["content"]) for completion in completions]
|
||||
verification_info = kwargs["verification_info"]
|
||||
scripts = [
|
||||
evaluation_script_template.format(code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"])))
|
||||
for code, info in zip(code_snippets, verification_info)
|
||||
]
|
||||
|
||||
language = verification_info[0]["language"]
|
||||
|
||||
if not all(v["language"] == language for v in verification_info):
|
||||
raise ValueError("All verification_info must have the same language", verification_info)
|
||||
|
||||
return [run_script_test(script, language) for script in scripts]
|
||||
|
||||
|
||||
def get_code_format_reward(language: str = "python"):
|
||||
"""Format reward function specifically for code responses.
|
||||
|
||||
|
|
@ -643,6 +721,14 @@ def get_soft_overlong_punishment(max_completion_len, soft_punish_cache):
|
|||
return soft_overlong_punishment_reward
|
||||
|
||||
|
||||
def run_script_test(script: str, language: str) -> float:
|
||||
sandbox = Sandbox(timeout=30, request_timeout=30)
|
||||
execution = sandbox.run_code(script, language=language)
|
||||
if execution.error:
|
||||
raise Exception(f"{execution.logs}\n{execution.error.name}: {execution.error.value}")
|
||||
|
||||
|
||||
|
||||
def get_reward_funcs(script_args) -> list[Callable]:
|
||||
REWARD_FUNCS_REGISTRY = {
|
||||
"accuracy": accuracy_reward,
|
||||
|
|
|
|||
|
|
@ -13,18 +13,18 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Supervised fine-tuning script for decoder language models.
|
||||
Supervised fine-tuning script for decoder language models and vision-language models.
|
||||
|
||||
Usage:
|
||||
|
||||
# One 1 node of 8 x H100s
|
||||
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
|
||||
--model_name_or_path open-r1/Qwen2.5-Math-7B-RoPE-300k \
|
||||
--dataset_name open-r1/Mixture-of-Thoughts \
|
||||
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
|
||||
--dataset_name smolagents/gaia-traces \
|
||||
--num_train_epochs 1 \
|
||||
--dataset_config all \
|
||||
--eos_token '<|im_end|>' \
|
||||
--learning_rate 4.0e-5 \
|
||||
--num_train_epochs 5 \
|
||||
--max_seq_length 32768 \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_checkpointing \
|
||||
|
|
@ -39,20 +39,100 @@ import sys
|
|||
|
||||
import datasets
|
||||
import transformers
|
||||
from transformers import set_seed
|
||||
from transformers import set_seed, AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from open_r1.configs import ScriptArguments, SFTConfig
|
||||
from open_r1.utils import get_dataset, get_model, get_tokenizer
|
||||
from open_r1.utils.callbacks import get_callbacks
|
||||
from open_r1.utils.wandb_logging import init_wandb_training
|
||||
from trl import ModelConfig, SFTTrainer, TrlParser, get_peft_config, setup_chat_format
|
||||
|
||||
from open_r1.configs import ScriptArguments, SFTConfig
|
||||
from open_r1.utils import get_dataset, get_model, get_tokenizer, get_processor
|
||||
from open_r1.utils.callbacks import get_callbacks
|
||||
from open_r1.utils.wandb_logging import init_wandb_training
|
||||
from transformers import Qwen2VLProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
def create_vlm_collate_fn(processor):
|
||||
"""Create a data collator for VLM training that handles images and text."""
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
def collate_fn(examples):
|
||||
# Convert dataset format to Qwen2.5-VL message format
|
||||
batch_messages = []
|
||||
|
||||
for example in examples:
|
||||
example_texts = example["texts"]
|
||||
example_images = example["images"]
|
||||
|
||||
# Convert to Qwen2.5-VL structured message format
|
||||
messages = []
|
||||
for i, msg in enumerate(example_texts):
|
||||
if msg["role"] == "user" and i == 0 and example_images:
|
||||
# First user message - add images
|
||||
content = []
|
||||
# Add images first
|
||||
for img in example_images:
|
||||
content.append({"type": "image", "image": img})
|
||||
# Then add text
|
||||
content.append({"type": "text", "text": msg["content"]})
|
||||
messages.append({"role": "user", "content": content})
|
||||
else:
|
||||
# Regular text message
|
||||
messages.append({"role": msg["role"], "content": msg["content"]})
|
||||
|
||||
batch_messages.append(messages)
|
||||
|
||||
# Process each example
|
||||
texts = []
|
||||
all_image_inputs = []
|
||||
all_video_inputs = []
|
||||
|
||||
for messages in batch_messages:
|
||||
# Apply chat template
|
||||
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
|
||||
texts.append(text)
|
||||
|
||||
# Extract vision info
|
||||
image_inputs, _ = process_vision_info(messages)
|
||||
all_image_inputs.extend(image_inputs if image_inputs else [])
|
||||
|
||||
# Process the batch
|
||||
batch = processor(
|
||||
text=texts,
|
||||
images=all_image_inputs if all_image_inputs else None,
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||
labels = batch["input_ids"].clone()
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100 #
|
||||
|
||||
# Ignore the image token index in the loss computation (model specific)
|
||||
if isinstance(processor, Qwen2VLProcessor):
|
||||
logger.info("DETECTED PROCESSOR")
|
||||
image_tokens = [151652,151653,151655]
|
||||
else:
|
||||
image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]
|
||||
for image_token_id in image_tokens:
|
||||
labels[labels == image_token_id] = -100
|
||||
batch["labels"] = labels
|
||||
|
||||
return batch
|
||||
|
||||
return collate_fn
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
# Force single GPU mode if requested
|
||||
# if hasattr(script_args, 'single_gpu') and script_args.single_gpu:
|
||||
# logger.info("Single GPU mode requested - setting CUDA_VISIBLE_DEVICES=0")
|
||||
# # Disable distributed training
|
||||
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
# training_args.local_rank = -1
|
||||
# training_args.ddp_backend = None
|
||||
|
||||
set_seed(training_args.seed)
|
||||
|
||||
###############
|
||||
|
|
@ -85,15 +165,40 @@ def main(script_args, training_args, model_args):
|
|||
init_wandb_training(training_args)
|
||||
|
||||
######################################
|
||||
# Load dataset, tokenizer, and model #
|
||||
# Load dataset, processor/tokenizer, and model #
|
||||
######################################
|
||||
dataset = get_dataset(script_args)
|
||||
tokenizer = get_tokenizer(model_args, training_args)
|
||||
model = get_model(model_args, training_args)
|
||||
|
||||
if tokenizer.chat_template is None:
|
||||
logger.info("No chat template provided, defaulting to ChatML.")
|
||||
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
|
||||
if training_args.vision_model:
|
||||
logger.info("Setting up vision-language model training")
|
||||
|
||||
# Set VLM-specific training arguments (following TRL reference)
|
||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
training_args.remove_unused_columns = False
|
||||
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
|
||||
training_args.ddp_find_unused_parameters = True
|
||||
|
||||
# Load processor and model for VLM
|
||||
processor = get_processor(model_args, training_args)
|
||||
model = get_model(model_args, training_args) # This should return AutoModelForVision2Seq
|
||||
data_collator = create_vlm_collate_fn(processor)
|
||||
processing_class = processor.tokenizer
|
||||
model_tags = ["open-r1", "vision-language", "vlm"]
|
||||
|
||||
else:
|
||||
logger.info("Setting up text-only model training")
|
||||
|
||||
# Load tokenizer and model for text-only
|
||||
tokenizer = get_tokenizer(model_args, training_args)
|
||||
model = get_model(model_args, training_args)
|
||||
|
||||
if tokenizer.chat_template is None:
|
||||
logger.info("No chat template provided, defaulting to ChatML.")
|
||||
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
|
||||
|
||||
data_collator = None # Use default
|
||||
processing_class = tokenizer
|
||||
model_tags = ["open-r1"]
|
||||
|
||||
############################
|
||||
# Initialize the SFT Trainer
|
||||
|
|
@ -101,9 +206,14 @@ def main(script_args, training_args, model_args):
|
|||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None),
|
||||
processing_class=tokenizer,
|
||||
eval_dataset=(
|
||||
dataset[script_args.dataset_test_split]
|
||||
if training_args.eval_strategy != "no"
|
||||
else None
|
||||
),
|
||||
processing_class=processing_class,
|
||||
peft_config=get_peft_config(model_args),
|
||||
callbacks=get_callbacks(training_args, model_args),
|
||||
)
|
||||
|
|
@ -128,16 +238,13 @@ def main(script_args, training_args, model_args):
|
|||
# Save model and create model card
|
||||
##################################
|
||||
logger.info("*** Save model ***")
|
||||
# Align the model's generation config with the tokenizer's eos token
|
||||
# to avoid unbounded generation in the transformers `pipeline()` function
|
||||
trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id
|
||||
trainer.save_model(training_args.output_dir)
|
||||
logger.info(f"Model saved to {training_args.output_dir}")
|
||||
|
||||
# Save everything else on main process
|
||||
kwargs = {
|
||||
"dataset_name": script_args.dataset_name,
|
||||
"tags": ["open-r1"],
|
||||
"tags": model_tags,
|
||||
}
|
||||
if trainer.accelerator.is_main_process:
|
||||
trainer.create_model_card(**kwargs)
|
||||
|
|
@ -160,7 +267,10 @@ def main(script_args, training_args, model_args):
|
|||
#############
|
||||
if training_args.push_to_hub:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
trainer.push_to_hub(**kwargs, token=os.getenv("HF_TOKEN"))
|
||||
# Also push processor for VLM models
|
||||
if training_args.vision_model and trainer.accelerator.is_main_process:
|
||||
processor.push_to_hub(training_args.hub_model_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from .data import get_dataset
|
||||
from .import_utils import is_e2b_available, is_morph_available
|
||||
from .model_utils import get_model, get_tokenizer
|
||||
from .model_utils import get_model, get_tokenizer, get_processor
|
||||
|
||||
|
||||
__all__ = ["get_tokenizer", "is_e2b_available", "is_morph_available", "get_model", "get_dataset"]
|
||||
__all__ = ["get_tokenizer", "get_processor", "is_e2b_available", "is_morph_available", "get_model", "get_dataset"]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer, AutoProcessor, AutoModelForVision2Seq
|
||||
|
||||
from trl import ModelConfig, get_kbit_device_map, get_quantization_config
|
||||
|
||||
|
|
@ -20,8 +20,22 @@ def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig
|
|||
return tokenizer
|
||||
|
||||
|
||||
def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoModelForCausalLM:
|
||||
"""Get the model"""
|
||||
def get_processor(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoProcessor:
|
||||
"""Get the processor for VLM models."""
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
)
|
||||
|
||||
if training_args.chat_template is not None:
|
||||
processor.chat_template = training_args.chat_template
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoModelForCausalLM | AutoModelForVision2Seq:
|
||||
"""Get the model - supports both text-only and vision-language models"""
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
|
|
@ -35,8 +49,19 @@ def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) ->
|
|||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# Check if this is a VLM model using the explicit flag
|
||||
if hasattr(training_args, 'vision_model') and training_args.vision_model:
|
||||
# Load as vision-language model
|
||||
model = AutoModelForVision2Seq.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
# Load as text-only model
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
return model
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue