Compare commits

..

16 commits

Author SHA1 Message Date
amir.mahla@huggingface.co
31a99af5bb CHG collator 2025-08-10 12:03:15 +00:00
amir.mahla@huggingface.co
76c56724ee CHG collator 2025-08-10 10:23:19 +00:00
amir.mahla@huggingface.co
b5a27167f1 CHG transform_messages 2025-08-08 16:38:05 +00:00
amir.mahla@huggingface.co
dca3a06ada PUSH last config 2025-08-08 15:35:50 +00:00
amir.mahla@huggingface.co
803c468507 ADD aguvis-stage-1 dataprocessing 2025-08-04 16:17:40 +00:00
amir.mahla@huggingface.co
afbd97b1ec ADD new config dataset 2025-07-30 12:46:11 +00:00
amir.mahla@huggingface.co
55c49d66c3 CHG recipe - add eval step 2025-07-24 16:29:53 +00:00
amir.mahla@huggingface.co
4c89c85fff CHG recipe 2025-07-24 16:02:14 +00:00
amir.mahla@huggingface.co
2ef6b50ccd FIX collator 2025-07-24 15:54:58 +00:00
amir.mahla@huggingface.co
7852ddefc8 CHG qwenvl collator 2025-07-24 13:58:04 +00:00
amir.mahla@huggingface.co
648a523325 Training script for qwenvl 2025-07-24 13:55:13 +00:00
amir.mahla@huggingface.co
342f8f7856 Training script for qwenvl 2025-07-24 13:38:17 +00:00
amir.mahla@huggingface.co
02819cf0ab Deleted action 2025-07-21 12:43:26 +00:00
amir.mahla@huggingface.co
4a55c49641 Aguvis dataset transform 2025-07-21 12:41:02 +00:00
Amir Mahla
dfcaecc92c Aguvis Data Pipeline Done 2025-07-21 10:41:59 +02:00
amir.mahla@huggingface.co
c13574e28a ADD function_parser 2025-07-17 22:43:57 +00:00
28 changed files with 3136 additions and 969 deletions

View file

@ -16,9 +16,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
uses: actions/checkout@v4
- name: Setup Python environment
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
uses: actions/setup-python@v5
with:
python-version: 3.10.10
- name: Install dependencies

View file

@ -299,7 +299,6 @@ Make sure your dataset contains a `verification_info` column with the following
}
],
}
```
For example, to train a smol model on Python problems, start the vLLM server:

View file

@ -0,0 +1,61 @@
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from datasets import load_dataset
# Load the fine-tuned model on the available device(s).
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "smolagents/Qwen2.5-VL-3B-Instruct-Agentic", torch_dtype="auto", device_map="auto"
)

# We recommend enabling flash_attention_2 for better acceleration and memory saving,
# especially in multi-image and video scenarios.
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#     "Qwen/Qwen2.5-VL-3B-Instruct",
#     torch_dtype=torch.bfloat16,
#     attn_implementation="flash_attention_2",
#     device_map="auto",
# )

# Default processor (loaded from the base Qwen2.5-VL checkpoint).
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

# The default range for the number of visual tokens per image in the model is 4-16384.
# You can set min_pixels and max_pixels according to your needs, such as a token range
# of 256-1280, to balance performance and cost.
# min_pixels = 256*28*28
# max_pixels = 1280*28*28
# processor = AutoProcessor.from_pretrained(
#     "Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
# )

# Build one chat prompt (system + image + user text) from the first dataset example.
dataset = load_dataset("smolagents/aguvis-stage-2", "mind2web", split="train")
example = next(iter(dataset))
messages = [
    {"role": "system", "content": example["system"]},
    {
        "role": "user",
        "content": [
            {"type": "image", "image": example["image"]},
            {"type": "text", "text": example["user"]},
        ],
    },
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
# Follow the model's own placement instead of hard-coding "cuda": the model is loaded
# with device_map="auto", so it is not guaranteed to live on a CUDA device.
inputs = inputs.to(model.device)

# Inference: generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=4096)
# Drop the prompt tokens so that only the newly generated tokens are decoded.
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)

View file

@ -0,0 +1,142 @@
# Model arguments
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
model_name_or_path: Qwen/Qwen2.5-VL-3B-Instruct
vision_model: true
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: smolagents/aguvis-stage-2
dataset_num_proc: 48
# SFT hyperparameters
max_length: 4096
optim: adamw_torch
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_grad_norm: 0.2
warmup_ratio: 0.03
learning_rate: 2.0e-05
gradient_accumulation_steps: 16
per_device_eval_batch_size: 4
per_device_train_batch_size: 4 # Change this depending on the context length of the model to keep a 500M GBS.
# Image resize arguments
image_resize:
factor: 28
min_pixels: 200704
max_pixels: 1003520
# SFT trainer config
max_steps: -1
num_train_epochs: 1
bf16: true
do_eval: true
eval_strategy: 'steps'
eval_steps: 100
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: A-Mahla/Qwen2.5-VL-3B-Instruct-Agentic-GUI
hub_strategy: end
push_to_hub: true
log_level: info
logging_steps: 5
logging_strategy: steps
output_dir: /fsx/amir_mahla/smolagents-Qwen2.5-VL-3B-Instruct-Agentic
overwrite_output_dir: true
report_to:
- wandb
wandb_project: smolagents
save_strategy: "epoch"
save_steps: 1
save_total_limit: 1
seed: 42
dataset_mixture:
datasets: # List of datasets to include in the mixture
- id: smolagents/aguvis-stage-2 # Hub dataset ID
config: mind2web # Name of the dataset config
split: train # Split to use from the dataset
columns: # Columns to keep
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: guiact-web-single
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: guiact-web-multi
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: miniwob
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: coat
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: android_control
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: gui-odyssey
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: amex
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: aitw
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
seed: 42 # Seed for shuffling the combined dataset
test_split_size: 0.01

View file

@ -0,0 +1,116 @@
# Model arguments
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
model_name_or_path: HuggingFaceTB/SmolVLM2-2.2B-Instruct
vision_model: true
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: smolagents/aguvis-stage-2
dataset_num_proc: 48
# SFT hyperparameters
max_length: 4096
optim: adamw_torch
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_grad_norm: 0.2
warmup_ratio: 0.03
learning_rate: 2.0e-05
gradient_accumulation_steps: 32
per_device_eval_batch_size: 2
per_device_train_batch_size: 2 # Change this depending on the context length of the model to keep a 500M GBS.
image_resize:
resolution_max_side: 1152
to_pixel_coordinates: true
# SFT trainer config
max_steps: -1
num_train_epochs: 1
bf16: true
do_eval: false
eval_strategy: 'steps'
eval_steps: 100
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI
hub_model_revision: main
hub_strategy: end
push_to_hub: false
log_level: info
logging_steps: 5
logging_strategy: steps
output_dir: /fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-phase-1-max-size-1152-pixel-coordinates
overwrite_output_dir: true
report_to:
- wandb
wandb_project: smolagents
save_strategy: steps
save_steps: 800
save_total_limit: 1
seed: 42
dataset_mixture:
datasets: # List of datasets to include in the mixture
- id: smolagents/aguvis-stage-1 # Hub dataset ID
config: guienv # Name of the dataset config
split: train # Split to use from the dataset
columns: # Columns to keep
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: omniact
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: ricoig16k
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: ricosca
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: seeclick
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: ui_refexp
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: webui350k
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: widget_captioning
split: train
columns:
- images
- texts
weight: 1.
seed: 42 # Seed for shuffling the combined dataset
test_split_size: 0.007

View file

@ -0,0 +1,116 @@
# Model arguments
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
model_name_or_path: HuggingFaceTB/SmolVLM2-2.2B-Instruct
vision_model: true
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: smolagents/aguvis-stage-2
dataset_num_proc: 48
# SFT hyperparameters
max_length: 4096
optim: adamw_torch
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_grad_norm: 0.2
warmup_ratio: 0.03
learning_rate: 2.0e-05
gradient_accumulation_steps: 32
per_device_eval_batch_size: 2
per_device_train_batch_size: 2 # Change this depending on the context length of the model to keep a 500M GBS.
image_resize:
resolution_max_side: 384
to_pixel_coordinates: true
# SFT trainer config
max_steps: -1
num_train_epochs: 1
bf16: true
do_eval: false
eval_strategy: 'steps'
eval_steps: 100
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI
hub_model_revision: main
hub_strategy: end
push_to_hub: false
log_level: info
logging_steps: 5
logging_strategy: steps
output_dir: /fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-phase-1-max-size-384-pixel-coordinates
overwrite_output_dir: true
report_to:
- wandb
wandb_project: smolagents
save_strategy: steps
save_steps: 800
save_total_limit: 1
seed: 42
dataset_mixture:
datasets: # List of datasets to include in the mixture
- id: smolagents/aguvis-stage-1 # Hub dataset ID
config: guienv # Name of the dataset config
split: train # Split to use from the dataset
columns: # Columns to keep
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: omniact
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: ricoig16k
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: ricosca
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: seeclick
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: ui_refexp
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: webui350k
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: widget_captioning
split: train
columns:
- images
- texts
weight: 1.
seed: 42 # Seed for shuffling the combined dataset
test_split_size: 0.007

View file

@ -0,0 +1,116 @@
# Model arguments
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
model_name_or_path: HuggingFaceTB/SmolVLM2-2.2B-Instruct
vision_model: true
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: smolagents/aguvis-stage-2
dataset_num_proc: 48
# SFT hyperparameters
max_length: 4096
optim: adamw_torch
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_grad_norm: 0.2
warmup_ratio: 0.03
learning_rate: 2.0e-05
gradient_accumulation_steps: 32
per_device_eval_batch_size: 2
per_device_train_batch_size: 2 # Change this depending on the context length of the model to keep a 500M GBS.
image_resize:
resolution_max_side: 764
to_pixel_coordinates: true
# SFT trainer config
max_steps: -1
num_train_epochs: 1
bf16: true
do_eval: false
eval_strategy: 'steps'
eval_steps: 100
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI
hub_model_revision: main
hub_strategy: end
push_to_hub: false
log_level: info
logging_steps: 5
logging_strategy: steps
output_dir: /fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-phase-1-max-size-764-pixel-coordinates
overwrite_output_dir: true
report_to:
- wandb
wandb_project: smolagents
save_strategy: steps
save_steps: 800
save_total_limit: 1
seed: 42
dataset_mixture:
datasets: # List of datasets to include in the mixture
- id: smolagents/aguvis-stage-1 # Hub dataset ID
config: guienv # Name of the dataset config
split: train # Split to use from the dataset
columns: # Columns to keep
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: omniact
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: ricoig16k
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: ricosca
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: seeclick
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: ui_refexp
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: webui350k
split: train
columns:
- images
- texts
weight: 1.
- id: smolagents/aguvis-stage-1
config: widget_captioning
split: train
columns:
- images
- texts
weight: 1.
seed: 42 # Seed for shuffling the combined dataset
test_split_size: 0.007

View file

@ -0,0 +1,137 @@
# Model arguments
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
model_name_or_path: HuggingFaceTB/SmolVLM2-2.2B-Instruct
vision_model: true
model_revision: main
torch_dtype: bfloat16
attn_implementation: sdpa
# Data training arguments
dataset_name: smolagents/aguvis-stage-2
dataset_num_proc: 48
# SFT hyperparameters
max_length: 4096
optim: adamw_torch
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_grad_norm: 0.2
warmup_ratio: 0.03
learning_rate: 2.0e-05
gradient_accumulation_steps: 16
per_device_eval_batch_size: 4
per_device_train_batch_size: 4 # Change this depending on the context length of the model to keep a 500M GBS.
# SFT trainer config
max_steps: -1
num_train_epochs: 1
bf16: true
do_eval: true
eval_strategy: 'steps'
eval_steps: 100
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI
hub_model_revision: max-resolution-1152-without-system
hub_strategy: end
push_to_hub: true
log_level: info
logging_steps: 5
logging_strategy: steps
output_dir: /fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-max-resolution-1152-without-system
overwrite_output_dir: true
report_to:
- wandb
wandb_project: smolagents
save_strategy: "epoch"
save_steps: 1
save_total_limit: 1
seed: 42
dataset_mixture:
datasets: # List of datasets to include in the mixture
- id: smolagents/aguvis-stage-2 # Hub dataset ID
config: mind2web # Name of the dataset config
split: train # Split to use from the dataset
columns: # Columns to keep
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: guiact-web-single
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: guiact-web-multi
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: miniwob
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: coat
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: android_control
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: gui-odyssey
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: amex
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
- id: smolagents/aguvis-stage-2
config: aitw
split: train
columns:
- system
- user
- assistant
- image
weight: 1.
seed: 42 # Seed for shuffling the combined dataset
test_split_size: 0.01

View file

@ -0,0 +1,194 @@
from function_parser import FunctionCall
from copy import deepcopy
# Mapping from the Aguvis action space to the custom action space:
# mobile.home() -> navigate_home()
# mobile.open_app(app_name='drupe') -> open_app(app_name: str) -> str:
# mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518]) -> swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])
# mobile.back() -> navigate_back()
# mobile.long_press(x=0.799, y=0.911) -> long_press(x, y)
# mobile.terminate(status='success') -> final_answer(answer: str)
# answer('text') -> final_answer('text') OK
# mobile.wait(seconds=3) -> wait(seconds=3) OK
# pyautogui.hscroll(page=-0.1)
# ?
# pyautogui.scroll(page=-0.1) or pyautogui.scroll(0.13) OK
# -> negative: scroll(direction: Literal["up", "down"] = "up", amount: int = abs(page * 10))
# -> positive: scroll(direction: Literal["up", "down"] = "down", amount: int = abs(page * 10))
# pyautogui.click(x=0.8102, y=0.9463) -> click(x: int, y: int) OK
# pyautogui.doubleClick() -> double_click() OK
# pyautogui.hotkey(keys=['ctrl', 'c']) -> press(keys: str | list) OK
# pyautogui.press(keys='enter') or pyautogui.press(keys=['enter']) -> press(keys: str | list) OK
# pyautogui.moveTo(x=0.04, y=0.405) -> move_mouse(x: int, y: int) OK
# pyautogui.write(message='bread buns') -> type(text: str) OK
# pyautogui.dragTo(x=0.8102, y=0.9463) -> drag(x1, y1, x2, y2) OK, but the formatting in the official dataset still needs to be rechecked
def convert_to_pixel_coordinates(action: FunctionCall, resolution: tuple[int, int]) -> None:
    """Replace positional ``arg_0`` / ``arg_1`` with named integer coordinates, in place.

    A list/tuple value is treated as a coordinate pair: its first item is
    multiplied by ``resolution[0]`` and its second by ``resolution[1]``, and
    the result is stored as ``from_coord`` (for ``arg_0``) or ``to_coord``
    (for ``arg_1``). Any other value is treated as a scalar: ``arg_0`` is
    multiplied by ``resolution[0]`` and stored as ``x``, ``arg_1`` is
    multiplied by ``resolution[1]`` and stored as ``y``. Products are
    truncated with ``int``.

    Args:
        action: Call whose ``parameters`` dict is modified in place.
        resolution: Two scale factors, presumably (width, height) in pixels —
            TODO confirm against the caller.
    """
    params = action.parameters
    # (positional key, name for a coordinate pair, name for a scalar, scalar axis)
    layout = (
        ("arg_0", "from_coord", "x", 0),
        ("arg_1", "to_coord", "y", 1),
    )
    for positional, pair_name, scalar_name, axis in layout:
        if positional not in params:
            continue
        raw = params[positional]
        if isinstance(raw, (list, tuple)):
            params[pair_name] = (int(raw[0] * resolution[0]), int(raw[1] * resolution[1]))
        else:
            params[scalar_name] = int(raw * resolution[axis])
        # Only drop the positional key once the named replacement is stored.
        del params[positional]
def change_argument_name(action: FunctionCall) -> None:
    """Replace positional ``arg_0`` / ``arg_1`` with named float coordinates, in place.

    A list/tuple value becomes a 2-tuple of floats stored as ``from_coord``
    (for ``arg_0``) or ``to_coord`` (for ``arg_1``). Any other value is
    converted with ``float`` and stored as ``x`` (for ``arg_0``) or ``y``
    (for ``arg_1``). Parameters under other keys are left untouched.

    Args:
        action: Call whose ``parameters`` dict is modified in place.
    """
    params = action.parameters
    # (positional key, name for a coordinate pair, name for a scalar)
    layout = (
        ("arg_0", "from_coord", "x"),
        ("arg_1", "to_coord", "y"),
    )
    for positional, pair_name, scalar_name in layout:
        if positional not in params:
            continue
        raw = params[positional]
        if isinstance(raw, (list, tuple)):
            params[pair_name] = (float(raw[0]), float(raw[1]))
        else:
            params[scalar_name] = float(raw)
        # Only drop the positional key once the named replacement is stored.
        del params[positional]
def rename_parameters(action: FunctionCall) -> None:
    """Rename the parameters of ``action`` to ``arg_0``, ``arg_1``, ... in place.

    The i-th parameter (in the dict's current order) becomes ``arg_<i>``; the
    values and their order are preserved. The new mapping is built first and
    then swapped in, so a parameter that is already called ``arg_<n>`` cannot
    be overwritten or dropped while the others are being renamed. The
    ``parameters`` dict object itself is reused; values are stored as-is
    (they are not copied).

    Args:
        action: FunctionCall object whose parameters are renamed.
    """
    if not action.parameters:
        return
    renamed = {f"arg_{i}": value for i, value in enumerate(action.parameters.values())}
    # Keep the same dict object so existing references to it stay valid.
    action.parameters.clear()
    action.parameters.update(renamed)
def action_conversion(
    actions: list[FunctionCall], resolution: tuple[int, int]
) -> list[FunctionCall]:
    """Convert Aguvis ``mobile.*`` / ``pyautogui.*`` calls to the custom action space.

    Every action is modified in place: its parameters are first renamed to
    positional ``arg_<i>`` keys, then the function name (and, where needed,
    the parameter names) are rewritten for the target action space, and
    finally ``original_string`` is regenerated with ``to_string()``.

    Args:
        actions: Parsed calls to convert; the list and its items are mutated.
        resolution: Not used by the conversion itself; kept so existing
            callers keep working.

    Returns:
        The same ``actions`` list, after conversion.

    Raises:
        ValueError: If an action's function name is not a supported one.
            Actions that precede it in the list have already been converted.
        KeyError: If a scroll action carries no argument.
    """
    for action in actions:
        # Normalise names to arg_<i> so the branches below only deal with positions.
        rename_parameters(action)
        name = action.function_name

        # MOBILE ACTIONS
        if name == "mobile.home":
            action.function_name = "navigate_home"
        elif name == "mobile.open_app":
            action.function_name = "open_app"
        elif name == "mobile.swipe":
            action.function_name = "swipe"
            change_argument_name(action)
        elif name == "mobile.back":
            action.function_name = "navigate_back"
        elif name == "mobile.long_press":
            action.function_name = "long_press"
            change_argument_name(action)
        elif name in ("mobile.terminate", "answer"):
            action.function_name = "final_answer"
        elif name == "mobile.wait":
            action.function_name = "wait"
            if "arg_0" in action.parameters:
                action.parameters["seconds"] = int(action.parameters["arg_0"])
                del action.parameters["arg_0"]

        # OS ACTIONS
        elif name == "pyautogui.click":
            action.function_name = "click"
            change_argument_name(action)
        elif name == "pyautogui.doubleClick":
            action.function_name = "double_click"
            change_argument_name(action)
        elif name == "pyautogui.rightClick":
            action.function_name = "right_click"
            change_argument_name(action)
        elif name in ("pyautogui.hotkey", "pyautogui.press"):
            action.function_name = "press"
            if "arg_0" in action.parameters:
                action.parameters["keys"] = action.parameters["arg_0"]
                del action.parameters["arg_0"]
        elif name == "pyautogui.moveTo":
            action.function_name = "move_mouse"
            change_argument_name(action)
        elif name == "pyautogui.write":
            action.function_name = "type"
        elif name in ("pyautogui.scroll", "pyautogui.hscroll"):
            arg_value = action.parameters["arg_0"]
            horizontal = name == "pyautogui.hscroll"
            # The sign of the argument selects the direction; its magnitude the amount.
            if arg_value < 0:
                action.parameters["direction"] = "left" if horizontal else "up"
            else:
                action.parameters["direction"] = "right" if horizontal else "down"
            del action.parameters["arg_0"]
            action.function_name = "scroll"
            action.parameters["amount"] = int(abs(arg_value * 100))
        elif name == "pyautogui.dragTo":
            action.function_name = "drag"
            change_argument_name(action)
        else:
            # The exception used to be created but never raised, so unsupported
            # calls slipped through with only their parameters renamed.
            raise ValueError(f"Error FunctionCall Formatting: unsupported function {name!r}")

        action.original_string = action.to_string()
    return actions
if __name__ == "__main__":
    from function_parser import FunctionCall

    # One (function name, parameters, original string) example per supported action.
    examples = [
        # MOBILE ACTIONS
        ("mobile.home", {}, "mobile.home()"),
        ("mobile.open_app", {"app_name": "drupe"}, "mobile.open_app(app_name='drupe')"),
        (
            "mobile.swipe",
            {"from_coord": [0.581, 0.898], "to_coord": [0.601, 0.518]},
            "mobile.swipe(from_coord=[0.581,0.898],to_coord=[0.601,0.518])",
        ),
        ("mobile.back", {}, "mobile.back()"),
        ("mobile.long_press", {"x": 0.799, "y": 0.911}, "mobile.long_press(x=0.799, y=0.911)"),
        ("mobile.terminate", {"status": "success"}, "mobile.terminate(status='success')"),
        ("answer", {"arg_0": "text"}, "answer('text')"),
        ("mobile.wait", {"seconds": 3}, "mobile.wait(seconds=3)"),
        # OS ACTIONS
        ("pyautogui.hscroll", {"page": -0.1}, "pyautogui.hscroll(page=-0.1)"),
        ("pyautogui.scroll", {"page": 0.13}, "pyautogui.scroll(page=0.13)"),
        ("pyautogui.click", {"x": 0.8102, "y": 0.9463}, "pyautogui.click(x=0.8102, y=0.9463)"),
        ("pyautogui.doubleClick", {}, "pyautogui.doubleClick()"),
        ("pyautogui.hotkey", {"keys": ["ctrl", "c"]}, "pyautogui.hotkey(keys=['ctrl','c'])"),
        ("pyautogui.press", {"keys": "enter"}, "pyautogui.press(keys='enter')"),
        ("pyautogui.moveTo", {"x": 0.04, "y": 0.405}, "pyautogui.moveTo(x=0.04, y=0.405)"),
        ("pyautogui.write", {"message": "bread buns"}, "pyautogui.write(message='bread buns')"),
        (
            "pyautogui.dragTo",
            {"from_coord": [0.87, 0.423], "to_coord": [0.8102, 0.9463]},
            "pyautogui.dragTo(from_coord=[0.87, 0.423], to_coord=[0.8102, 0.9463])",
        ),
    ]
    actions = [FunctionCall(*example) for example in examples]
    resolution = (1080, 1920)

    # Show every example before and after conversion (the conversion mutates in place).
    print("Before conversion:")
    print(*actions, sep="\n")

    print("\nAfter conversion:")
    converted = action_conversion(actions, resolution)
    print(*converted, sep="\n")

231
scripts/agents/config.py Normal file
View file

@ -0,0 +1,231 @@
# Aguvis JSON files that use the mobile action space.
# "coat.json" and "amex-l2.json" were previously fused into a single bogus
# entry ("coat.jsonamex-l2.json"), so neither file was matched.
MOBILE_FILE = [
    "android_control.json",
    "gui-odyssey-l1.json",
    "aitw-l3.json",
    "coat.json",
    "amex-l2.json",
    "amex-l1.json",
    "amex-l3.json",
    "gui-odyssey-l3.json",
    "aitw-l1.json",
    "aitw-l2.json",
    "gui-odyssey-l2.json",
]
# Processing: guienv
# Max conversations by image: 5 conversations
# Duplicates: 0
# Duplicate images in guienv.json. difference: 257578
# len(images_path): 327972
# len(images_set_path): 70394
# user/assistant by image: 3.6590902633747193
#
# Processing: omniact
# Max conversations by image: 0 conversations
# Duplicates: 0
# No duplicate images in omniact.json
# len(images_path): 6720
# len(images_set_path): 6720
# user/assistant by image: 0.0
#
# Processing: ricoig16k
# Max conversations by image: 0 conversations
# Duplicates: 0
# No duplicate images in ricoig16k.json
# len(images_path): 16133
# len(images_set_path): 16133
# user/assistant by image: 0.0
#
# Processing: ricosca
# Max conversations by image: 20 conversations
# Duplicates: 0
# Duplicate images in ricosca.json. difference: 155066
# len(images_path): 173212
# len(images_set_path): 18146
# user/assistant by image: 8.54546456519343
#
# Processing: seeclick
# Max conversations by image: 0 conversations
# Duplicates: 0
# No duplicate images in seeclick.json
# len(images_path): 271121
# len(images_set_path): 271121
# user/assistant by image: 0.0
#
# Processing: webui350k
# Max conversations by image: 0 conversations
# Duplicates: 0
# No duplicate images in webui350k.json
# len(images_path): 57389
# len(images_set_path): 57389
# user/assistant by image: 0.0
#
# Processing: ui_refexp
# Max conversations by image: 15 conversations
# Duplicates: 32
# Duplicate images in ui_refexp.json. difference: 10978
# len(images_path): 15624
# len(images_set_path): 4646
# user/assistant by image: 2.3628928110202323
#
# Processing: widget_captioning
# Max conversations by image: 161 conversations
# Duplicates: 4877
# Duplicate images in widget_captioning.json. difference: 87017
# len(images_path): 101426
# len(images_set_path): 14409
# user/assistant by image: 6.039072801721146
#
# total_samples = 458958
# Stage-1 sources: one {"json_path", "images_folder"} entry per annotation file.
config_dict_stage_1 = [
    {"json_path": json_path, "images_folder": images_folder}
    for json_path, images_folder in (
        ("guienv.json", "guienvs/images/"),
        ("omniact.json", "omniact/images/"),
        ("ricoig16k.json", "ricoig16k/images/"),
        ("ricosca.json", "ricosca/images/"),
        ("seeclick.json", "seeclick/seeclick_web_imgs/"),
        ("webui350k.json", "webui350k/images/"),
        ("ui_refexp.json", "ui_refexp/images/"),
        ("widget_captioning.json", "widget_captioning/images/"),
    )
]
# Processing: mind2web-l3
# Max conversations by image: 0 conversations
# Duplicates: 0
# No duplicate images in mind2web-l3.json
# len(images_path): 7591
# len(images_set_path): 7591
# user/assistant by image: 0.0
#
# Processing: guiact-web-single
# Max conversations by image: 12 conversations
# Duplicates: 0
# Duplicate images in guiact-web-single.json. difference: 54134
# len(images_path): 67396
# len(images_set_path): 13262
# user/assistant by image: 4.081888101342181
#
# Processing: guiact-web-multi-l3
# Max conversations by image: 2 conversations
# Duplicates: 0
# Duplicate images in guiact-web-multi-l3.json. difference: 24
# len(images_path): 16704
# len(images_set_path): 16680
# user/assistant by image: 0.0014388489208633094
#
# Processing: miniwob-l3
# Max conversations by image: 6 conversations
# Duplicates: 0
# Duplicate images in miniwob-l3.json. difference: 161
# len(images_path): 9826
# len(images_set_path): 9665
# user/assistant by image: 0.016658044490429385
#
# Processing: coat
# Max conversations by image: 0 conversations
# Duplicates: 0
# No duplicate images in coat.json
# len(images_path): 11921
# len(images_set_path): 11921
# user/assistant by image: 0.0
#
# Processing: android_control
# Max conversations by image: 0 conversations
# Duplicates: 0
# No duplicate images in android_control.json
# len(images_path): 74714
# len(images_set_path): 74714
# user/assistant by image: 0.0
#
# Processing: gui-odyssey-l3
# Max conversations by image: 2 conversations
# Duplicates: 0
# Duplicate images in gui-odyssey-l3.json. difference: 24
# len(images_path): 118282
# len(images_set_path): 118258
# user/assistant by image: 0.0002029461008980365
#
# Processing: amex-l3
# Max conversations by image: 0 conversations
# Duplicates: 0
# No duplicate images in amex-l3.json
# len(images_path): 38469
# len(images_set_path): 38469
# user/assistant by image: 0.0
#
# Processing: aitw-l3
# Max conversations by image: 0 conversations
# Duplicates: 0
# No duplicate images in aitw-l3.json
# len(images_path): 18992
# len(images_set_path): 18992
# user/assistant by image: 0.0
#
# Total samples: 309552
# Stage-2 sources: one {"json_path", "images_folder"} entry per annotation file.
config_dict_stage_2 = [
    {"json_path": json_path, "images_folder": images_folder}
    for json_path, images_folder in (
        ("mind2web-l3.json", "mind2web/"),
        ("guiact-web-single.json", "guiact-web-single/images/"),
        ("guiact-web-multi-l3.json", "guiact-web-multi-v2/images"),
        ("miniwob-l3.json", "images"),
        ("coat.json", "coat/images/"),
        ("android_control.json", "android_control/images/"),
        ("gui-odyssey-l3.json", "gui-odyssey/images/"),
        ("amex-l3.json", "amex/images/"),
        ("aitw-l3.json", "aitw-v1/images/"),
    )
]

View file

@ -0,0 +1,547 @@
#!/usr/bin/env python3
"""
Function parser for extracting function names, parameter names, and values from string function calls.
Supports both mobile and pyautogui function patterns.
"""
import re
from typing import Dict, List, Tuple, Any, Union
from dataclasses import dataclass
from collections import OrderedDict
@dataclass
class FunctionCall:
    """Represents a parsed function call with its parameters.

    Attributes:
        function_name: Name of the called function (may be dotted).
        parameters: Parameter values keyed by name. Keys of the form
            ``arg_<n>`` (``n`` a decimal number) stand for positional
            arguments; every other key is a keyword argument.
        original_string: Text of the call this object was built from.
        description: Optional free-form description; defaults to ``""``.
    """

    function_name: str
    parameters: Dict[str, Any]
    original_string: str
    description: str = ""

    def to_string(self) -> str:
        """
        Reconstruct the function call string from the parsed data.

        Positional arguments (``arg_<n>`` keys) come first, ordered by ``n``;
        keyword arguments follow in their dict order. A key that merely starts
        with ``arg_`` without a numeric suffix (e.g. ``arg_name``) is emitted
        as a keyword argument.

        Returns:
            String representation of the function call

        Examples:
            >>> call = FunctionCall("mobile.wait", {"seconds": 3}, "mobile.wait(seconds=3)")
            >>> call.to_string()
            'mobile.wait(seconds=3)'
            >>> call = FunctionCall("function", {"arg_0": 1, "arg_1": 2, "x": 0.5}, "function(1, 2, x=0.5)")
            >>> call.to_string()
            'function(1, 2, x=0.5)'
        """
        if not self.parameters:
            return f"{self.function_name}()"

        # Separate positional and named arguments
        positional_args = []
        named_args = []
        for name, value in self.parameters.items():
            index = name[len("arg_"):] if name.startswith("arg_") else ""
            if index.isdecimal():
                # Positional argument
                positional_args.append((int(index), value))
            else:
                # kwargs (including "arg_<non-number>" names, which used to crash int())
                named_args.append((name, value))

        # Sort positional arguments by index
        positional_args.sort(key=lambda item: item[0])

        # Build parameter string: positional arguments first, then named ones
        param_parts = [self._value_to_string(value) for _, value in positional_args]
        param_parts.extend(
            f"{name}={self._value_to_string(value)}" for name, value in named_args
        )
        return f"{self.function_name}({', '.join(param_parts)})"

    def _value_to_string(self, value: Any) -> str:
        """
        Convert a value to its string representation for function calls.

        Args:
            value: The value to convert

        Returns:
            String representation of the value
        """
        if isinstance(value, str):
            # Quote strings (embedded quotes are not escaped)
            return f"'{value}'"
        elif isinstance(value, (list, tuple)):
            # Lists and tuples are both rendered with square brackets
            items = [self._value_to_string(item) for item in value]
            return f"[{', '.join(items)}]"
        elif isinstance(value, dict):
            # Dictionary keys are always rendered single-quoted
            items = [f"'{k}': {self._value_to_string(v)}" for k, v in value.items()]
            return f"{{{', '.join(items)}}}"
        elif isinstance(value, bool):
            # Booleans are rendered lowercase ("true" / "false")
            return str(value).lower()
        elif value is None:
            return "None"
        else:
            # Numbers and other types
            return str(value)
# A call starts with an identifier (dots allowed) immediately followed by "(".
_CALL_START_RE = re.compile(r"([a-zA-Z_][a-zA-Z0-9_.]*)\(")


def _find_closing_paren(text: str, open_index: int) -> int:
    """Return the index of the ")" matching the "(" at ``open_index``, or -1.

    Parentheses inside single- or double-quoted text are ignored. Backslash
    escapes are not interpreted.
    """
    depth = 0
    quote_char = None
    for index in range(open_index, len(text)):
        char = text[index]
        if quote_char is not None:
            if char == quote_char:
                quote_char = None
        elif char in ("'", '"'):
            quote_char = char
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return index
    return -1


def parse_function_call(function_string: str, pattern_to_match: list[str] | None = None) -> List[FunctionCall]:
    """
    Parse a string and extract every function call found in it.

    The argument list of a call ends at the parenthesis that balances its
    opening one, so a ")" inside a quoted string or a nested call no longer
    truncates the arguments. If no balanced close can be found (for example an
    unterminated quote), the first ")" after the opening parenthesis is used.

    Args:
        function_string: Text containing zero or more function calls
        pattern_to_match: Optional substrings; when given and non-empty, only
            calls whose name contains at least one of them are returned

    Returns:
        List of FunctionCall objects, in order of appearance. The function name
        keeps any dotted prefix that is directly attached to it.

    Examples:
        >>> parse_function_call("mobile.wait(seconds=3)")
        [FunctionCall(function_name='mobile.wait', parameters={'seconds': 3}, ...)]
        >>> parse_function_call("mobile. wait(seconds=3)")
        [FunctionCall(function_name='wait', parameters={'seconds': 3}, ...)]
        >>> parse_function_call("mobile.wait(seconds=3) mobile.home()")
        [FunctionCall(function_name='mobile.wait', parameters={'seconds': 3}, ...), FunctionCall(function_name='mobile.home', parameters={}, ...)]
    """
    function_string = function_string.strip()

    results = []
    position = 0
    while True:
        start = _CALL_START_RE.search(function_string, position)
        if start is None:
            break

        function_name = start.group(1)
        open_index = start.end() - 1
        close_index = _find_closing_paren(function_string, open_index)
        if close_index == -1:
            # Unbalanced quotes/parentheses: fall back to the first ")".
            close_index = function_string.find(")", open_index)
            if close_index == -1:
                # No ")" remains, so no later call can be complete either.
                break

        params_string = function_string[open_index + 1:close_index]
        position = close_index + 1

        if pattern_to_match and all(pattern not in function_name for pattern in pattern_to_match):
            continue

        results.append(FunctionCall(
            function_name=function_name,
            parameters=parse_parameters(params_string),
            # Rebuilt from the captured pieces, so text before the call is dropped.
            original_string=f"{function_name}({params_string})",
        ))

    return results
def parse_parameters(params_string: str) -> Dict[str, Any]:
    """
    Turn the text between a call's parentheses into a name -> value mapping.

    Positional arguments are keyed "arg_0", "arg_1", ... in order of
    appearance. Note that any parsed name starting with "arg_" is renumbered
    this way, including a keyword argument that happens to use that prefix.

    Args:
        params_string: Argument text, e.g. "x=0.5, y=0.6, text='hello'"

    Returns:
        Mapping of parameter names to parsed values, in call order
        (a plain empty dict when there are no arguments)

    Examples:
        >>> parse_parameters("x=0.5, y=0.6")
        {'x': 0.5, 'y': 0.6}
        >>> parse_parameters("app_name='drupe'")
        {'app_name': 'drupe'}
        >>> parse_parameters("'text'")
        {'arg_0': 'text'}
        >>> parse_parameters("1, 3, 4")
        {'arg_0': 1, 'arg_1': 3, 'arg_2': 4}
        >>> parse_parameters("arg1, arg2, x=0.5")
        {'arg_0': 'arg1', 'arg_1': 'arg2', 'x': 0.5}
    """
    if not params_string.strip():
        return {}

    parsed = OrderedDict()
    next_positional = 0
    # split_parameters keeps commas inside quotes/brackets intact.
    for raw_chunk in split_parameters(params_string):
        chunk = raw_chunk.strip()
        if not chunk:
            continue

        key, value = parse_single_parameter(chunk)
        if key.startswith("arg_"):
            key = f"arg_{next_positional}"
            next_positional += 1

        parsed[key] = value

    return parsed
def split_parameters(params_string: str) -> List[str]:
    """
    Split an argument string on its top-level commas.

    A comma only separates parameters when it is outside quotes and all of
    the (), [] and {} nesting counters are zero. Pieces are stripped; an empty
    piece between two commas is kept, a trailing empty piece is dropped.

    Args:
        params_string: String containing parameters

    Returns:
        List of individual parameter strings
    """
    openers = {"(": "round", "[": "square", "{": "curly"}
    closers = {")": "round", "]": "square", "}": "curly"}
    depth = {"round": 0, "square": 0, "curly": 0}

    pieces: List[str] = []
    buffer: List[str] = []
    active_quote = None

    for char in params_string:
        if active_quote is not None:
            # Inside a quoted run only the matching quote character matters.
            if char == active_quote:
                active_quote = None
            buffer.append(char)
            continue

        if char == '"' or char == "'":
            active_quote = char
        elif char in openers:
            depth[openers[char]] += 1
        elif char in closers:
            depth[closers[char]] -= 1
        elif char == "," and not any(depth.values()):
            pieces.append("".join(buffer).strip())
            buffer = []
            continue

        buffer.append(char)

    tail = "".join(buffer).strip()
    if tail:
        pieces.append(tail)

    return pieces
def parse_single_parameter(param_string: str) -> Tuple[str, Any]:
    """
    Split one parameter into its name and parsed value.

    Text of the form ``identifier = value`` is a keyword argument. Anything
    else is positional and is reported under the placeholder name "arg_0".

    Args:
        param_string: String like "x=0.5" or "app_name='drupe'" or just "value"

    Returns:
        Tuple of (parameter_name, parameter_value)

    Examples:
        >>> parse_single_parameter("x=0.5")
        ('x', 0.5)
        >>> parse_single_parameter("app_name='drupe'")
        ('app_name', 'drupe')
        >>> parse_single_parameter("'text'")
        ('arg_0', 'text')
        >>> parse_single_parameter("123")
        ('arg_0', 123)
        >>> parse_single_parameter("3")
        ('arg_0', 3)
    """
    keyword = re.match(r'^([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*(.+)$', param_string)

    if keyword is None:
        # Positional parameter - no "name =" prefix.
        return "arg_0", parse_value(param_string)

    name, raw_value = keyword.group(1), keyword.group(2)
    return name, parse_value(raw_value.strip())
def parse_value(value_string: str) -> Any:
    """
    Parse a value string into the appropriate Python type.

    Checks run in this order: quoted string, list, dict, boolean, None,
    integer, float. Anything else is returned unchanged as a string.

    Args:
        value_string: String representation of a value

    Returns:
        Parsed value (int, float, str, bool, None, list or dict)

    Examples:
        >>> parse_value("3")
        3
        >>> parse_value("3.14")
        3.14
        >>> parse_value("1e-05")
        1e-05
        >>> parse_value("'hello'")
        'hello'
        >>> parse_value("[0.581, 0.898]")
        [0.581, 0.898]
    """
    value_string = value_string.strip()

    # Quoted strings: drop the surrounding quotes. The length check stops a
    # lone quote character from being read as an empty string.
    if len(value_string) >= 2 and value_string[0] == value_string[-1] and value_string[0] in ("'", '"'):
        return value_string[1:-1]

    # List values
    if value_string.startswith('[') and value_string.endswith(']'):
        return parse_list(value_string)

    # Dictionary values
    if value_string.startswith('{') and value_string.endswith('}'):
        return parse_dict(value_string)

    # Booleans and None are matched case-insensitively.
    lowered = value_string.lower()
    if lowered in ('true', 'false'):
        return lowered == 'true'
    if lowered == 'none':
        return None

    # Integers first, so "3" stays an int.
    try:
        return int(value_string)
    except ValueError:
        pass

    # Then floats, including exponent notation without a decimal point
    # ("1e-05"), which is how Python itself prints very small floats.
    try:
        number = float(value_string)
    except ValueError:
        return value_string

    # float() also accepts bare words such as "inf" and "nan"; those stay
    # strings, as they always have.
    if any(char.isdigit() for char in value_string):
        return number
    return value_string
def parse_list(list_string: str) -> List[Any]:
    """
    Parse a bracketed list literal into a Python list.

    The first and last characters are assumed to be the brackets and are
    dropped without being checked.

    Args:
        list_string: String like "[0.581, 0.898]"

    Returns:
        List of parsed values

    Examples:
        >>> parse_list("[0.581, 0.898]")
        [0.581, 0.898]
    """
    inner = list_string[1:-1].strip()
    if not inner:
        return []

    # Commas inside nested structures or quotes do not split elements.
    elements = []
    for piece in split_parameters(inner):
        elements.append(parse_value(piece.strip()))
    return elements
def parse_dict(dict_string: str) -> Dict[str, Any]:
    """
    Parse a braced dictionary literal into a Python dict.

    Each top-level entry is split at its first ":" into key and value, and
    both sides go through parse_value. Entries without a ":" are ignored.

    Args:
        dict_string: String like "{'key': 'value'}"

    Returns:
        Dictionary of parsed key-value pairs
    """
    inner = dict_string[1:-1].strip()
    if not inner:
        return {}

    parsed = {}
    for raw_entry in split_parameters(inner):
        entry = raw_entry.strip()
        key_text, separator, value_text = entry.partition(':')
        if not separator:
            continue
        key = parse_value(key_text.strip())
        value = parse_value(value_text.strip())
        parsed[key] = value

    return parsed
def parse_multiple_functions(function_strings: List[str]) -> List[FunctionCall]:
    """
    Parse several function call strings and pool the results.

    A string that raises while being parsed is reported on stdout and
    skipped; the remaining strings are still processed.

    Args:
        function_strings: List of function call strings

    Returns:
        List of FunctionCall objects, in input order
    """
    collected = []
    for call_text in function_strings:
        try:
            collected.extend(parse_function_call(call_text))
        except Exception as e:
            print(f"Warning: Could not parse function call '{call_text}': {e}")
    return collected
def extract_function_calls_from_text(text: str) -> List[FunctionCall]:
    """
    Find call-shaped substrings in free text and parse each of them.

    A candidate is an identifier (dots allowed) followed by a parenthesised
    run that contains no ")"; every candidate is passed to
    parse_multiple_functions.

    Args:
        text: Text containing function calls

    Returns:
        List of FunctionCall objects
    """
    call_finder = re.compile(r'[a-zA-Z_][a-zA-Z0-9_.]*\([^)]*\)')
    candidates = call_finder.findall(text)
    return parse_multiple_functions(candidates)
# Example usage and testing
# Running this module directly prints parse results for sample calls, then
# asserts that parse -> to_string() reproduces a set of inputs exactly.
if __name__ == "__main__":
    # Sample inputs: mobile.* and pyautogui.* calls, a bare answer() call, one
    # two-line input, and generic calls with positional arguments.
    test_cases = [
        "mobile.home()",
        "mobile.open_app(app_name='drupe')",
        "mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
        "mobile.back()",
        "mobile.long_press(x=0.799, y=0.911)",
        "mobile.terminate(status='success')",
        "answer('text')",
        "pyautogui.hscroll(page=-0.1)",
        "pyautogui.scroll(page=-0.1)",
        "pyautogui.scroll(0.13)",
        "pyautogui.click(x=0.8102, y=0.9463)",
        "pyautogui.hotkey(keys=['ctrl', 'c'])",
        "pyautogui.doubleClick()",
        "pyautogui.press(keys='enter')",
        "pyautogui.press(keys=['enter'])",
        "pyautogui.moveTo(x=0.04, y=0.405)",
        "pyautogui.write(message='bread buns')",
        "pyautogui.dragTo(x=0.8102, y=0.9463)",
        "mobile.wait(seconds=3)\nmobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
        # Additional test cases for multiple positional arguments
        "function(arg1, arg2, arg3)",
        "function('hello', 123, x=0.5)",
        "function(arg1, arg2, named_param='value')",
        "function(1, 2, 3, 4, 5)",
        "function('a', 'b', 'c', x=1, y=2)",
    ]
    print("Testing function parser:")
    print("=" * 50)
    # Print every call found in each sample; an exception is reported and the
    # loop moves on to the next sample.
    for test_case in test_cases:
        try:
            results = parse_function_call(test_case)
            print(f"{test_case}")
            for result in results:
                print(f" Function: {result.function_name}")
                print(f" Parameters: {result.parameters}")
            print()
        except Exception as e:
            print(f"{test_case}")
            print(f" Error: {e}")
            print()
    # Test extracting from text
    print("Testing text extraction:")
    print("=" * 50)
    sample_text = """
mobile.wait(seconds=3)
mobile.open_app(app_name='drupe')
pyautogui.click(x=0.8102, y=0.9463)
pyautogui.write(message='bread buns')
"""
    extracted = extract_function_calls_from_text(sample_text)
    for func_call in extracted:
        print(f"Found: {func_call.function_name} with params: {func_call.parameters}")
    # Test reconstruction
    print("\nTesting function call reconstruction:")
    print("=" * 50)
    # Each of these must survive parse -> to_string() unchanged.
    reconstruction_tests = [
        "mobile.wait(seconds=3)",
        "mobile.home()",
        "mobile.open_app(app_name='drupe')",
        "mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
        "answer('text')",
        "pyautogui.scroll(0.13)",
        "pyautogui.click(x=0.8102, y=0.9463)",
        "pyautogui.hotkey(keys=['ctrl', 'c'])",
        "function(1, 2, 3)",
        "function('hello', 123, x=0.5, y=0.8)",
        "function([1, 3], 'arg2', named_param='value')",
    ]
    for test_case in reconstruction_tests:
        parsed_list = parse_function_call(test_case)
        for parsed in parsed_list:
            reconstructed = parsed.to_string()
            print(f"Original: {test_case}")
            print(f"Reconstructed: {reconstructed}")
            print(f"Match: {test_case == reconstructed}")
            # Hard failure on any round-trip mismatch.
            assert test_case == reconstructed
            print()

View file

@ -0,0 +1,519 @@
#!/usr/bin/env python3
"""
Script to download, process, and upload the aguvis-stage2 dataset.
Downloads from huggingface.co/datasets/xlangai/aguvis-stage2 and uploads to smolagents/aguvis-stage-2
"""
import re
import gc
import sys
import json
import os
import shutil
import zipfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Generator, Callable, Literal
from tqdm import tqdm
from datasets import Dataset, load_dataset, concatenate_datasets
from dotenv import load_dotenv
from huggingface_hub import HfApi, login, snapshot_download
from collections import defaultdict
from PIL import Image
import tarfile
from itertools import islice
import multiprocessing as mp
from multiprocessing import Pool, Manager
from prompts import OS_SYSTEM_PROMPT, MOBILE_SYSTEM_PROMPT
from models import ConversationDataList, ConversationData, ChatMessage, DataRow
from function_parser import parse_function_call
from action_conversion import action_conversion
from pydantic import BaseModel
from config import config_dict_stage_1, config_dict_stage_2, MOBILE_FILE
# Module-level Hub client, created at import time. In the code visible here it
# is only referenced by a commented-out upload call.
api = HfApi()
def authenticate_huggingface():
    """Log in to the HuggingFace Hub with the token from the HF_TOKEN env var.

    Raises:
        ValueError: If HF_TOKEN is unset or empty.
    """
    token = os.getenv("HF_TOKEN")
    if not token:
        raise ValueError("HF_TOKEN environment variable not set.")

    print("Authenticating with HuggingFace Hub using token...")
    login(token=token)
def discover_dataset_config(dataset_path: str, config_dict: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Derive a subset name for each config entry and drop repeated subsets.

    The subset name is the entry's "json_path" with ".json" and the level
    markers "-l1"/"-l2"/"-l3" removed. When several entries map to the same
    subset name only the first is kept. The directory itself is not scanned;
    it is only checked for existence.

    Args:
        dataset_path: Directory holding the downloaded dataset.
        config_dict: Entries with at least "json_path" and "images_folder".

    Returns:
        New dicts (the input entries are left untouched), each a copy of the
        kept entry with an added "subset_name" key.

    Raises:
        FileNotFoundError: If ``dataset_path`` does not exist.
    """
    dataset_dir = Path(dataset_path)
    if not dataset_dir.exists():
        raise FileNotFoundError(f"Train directory not found: {dataset_dir}")

    configs = []
    processed_splits = set()

    for config in config_dict:
        subset_name = config["json_path"].replace(".json", "")
        for level_marker in ("-l1", "-l2", "-l3"):
            subset_name = subset_name.replace(level_marker, "")

        # Skip if we already processed this split
        if subset_name in processed_splits:
            continue
        processed_splits.add(subset_name)

        # Copy instead of writing into the caller's dict: the input is
        # typically a module-level constant shared between runs.
        resolved = {**config, "subset_name": subset_name}
        configs.append(resolved)
        print(
            f"Discovered config: {resolved['subset_name']} -> {resolved['images_folder']}"
        )

    return configs
def download_dataset(
    repo_id: str = "xlangai/aguvis-stage2", local_dir: str = "./aguvis_raw"
) -> str:
    """Fetch a snapshot of a Hub dataset repo into ``local_dir``.

    Args:
        repo_id: Dataset repository to download.
        local_dir: Directory the snapshot is written to.

    Returns:
        The path returned by ``snapshot_download``.
    """
    print(f"Downloading dataset from {repo_id}...")
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        repo_type="dataset",
    )
    print(f"Dataset downloaded to: {snapshot_path}")
    return snapshot_path
def extract_zip_files(dataset_path: str):
    """Unpack every ``*.zip`` below ``dataset_path`` next to the archive.

    Each archive is extracted into a sibling directory named after the
    archive without its extension. An archive whose target directory already
    exists and is non-empty is skipped.
    """
    print("Extracting zip files...")

    for archive in Path(dataset_path).rglob("*.zip"):
        target_dir = archive.parent / archive.stem

        already_extracted = target_dir.exists() and any(target_dir.iterdir())
        if already_extracted:
            print(
                f"Skipping extraction for {archive} (already extracted at {target_dir})"
            )
            continue

        print(f"Extracting: {archive}")
        with zipfile.ZipFile(archive, "r") as bundle:
            bundle.extractall(target_dir)
        print(f"Extracted to: {target_dir}")
def extract_tar_parts_grouped(dataset_path: str):
    """
    Merge split ``<prefix>.tar.gz.part_<suffix>`` files and extract each archive.

    Parts directly inside ``dataset_path`` are grouped by ``<prefix>``, joined
    in suffix order into ``<prefix>.tar.gz`` and extracted into a directory
    named ``<prefix>``. A prefix whose directory already exists and is
    non-empty is skipped. The merged archive is left on disk.

    Args:
        dataset_path: Directory that contains the part files.
    """
    marker = ".tar.gz.part_"
    dataset_dir = Path(dataset_path)
    part_files = list(dataset_dir.glob("*.tar.gz.part_*"))
    if not part_files:
        print("No split .tar.gz.part_* files found.")
        return

    # Group part files by prefix
    groups = defaultdict(list)
    for part in part_files:
        groups[part.name.split(marker)[0]].append(part)

    def part_order(part: Path):
        # Shorter suffixes first, then alphabetical: keeps zero-padded and
        # letter suffixes in their usual order and also puts "part_2" ahead
        # of "part_10", which a plain string sort gets wrong.
        suffix = part.name.split(marker, 1)[1]
        return (len(suffix), suffix)

    # Use the "data" extraction filter where this Python provides it (blocks
    # absolute paths, "..", and similar); older interpreters reject the keyword.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}

    CHUNK_SIZE = 1024 * 1024
    for prefix, parts in groups.items():
        parts = sorted(parts, key=part_order)
        merged_tar_path = dataset_dir / f"{prefix}.tar.gz"
        extract_dir = dataset_dir / prefix

        if extract_dir.exists() and any(extract_dir.iterdir()):
            print(
                f"Skipping extraction for '{prefix}' (already extracted at {extract_dir})"
            )
            continue

        # Merge parts
        print(f"Merging parts for '{prefix}'...")
        with open(merged_tar_path, "wb") as outfile:
            for part in parts:
                print(f" Adding: {part.name}")
                with open(part, "rb") as infile:
                    while chunk := infile.read(CHUNK_SIZE):
                        outfile.write(chunk)
        print(f"Merged to: {merged_tar_path}")

        # Extract
        print(f"Extracting to: {extract_dir}")
        with tarfile.open(merged_tar_path, "r:gz") as tar:
            tar.extractall(path=extract_dir, **extract_kwargs)
        print(f"Done extracting '{prefix}'\n")
def check_subset_exists(repo_id: str, subset_name: str) -> bool:
    """Report whether ``subset_name`` is one of the repo's config names.

    Any failure during the lookup is printed and treated as "does not exist".
    """
    try:
        from datasets import get_dataset_config_names

        return subset_name in get_dataset_config_names(repo_id)
    except Exception as e:
        print(f"Could not check if subset exists: {e}")
        return False
def load_image_from_folder(images_folder: Path, img_path: str) -> Image.Image:
    """Load one image and return a copy that no longer depends on the file.

    Args:
        images_folder: Directory the image path is relative to.
        img_path: Image path relative to ``images_folder``.

    Returns:
        A copy of the opened image. The source file is closed before
        returning, including when copying fails.
    """
    with Image.open(images_folder / img_path) as img:
        return img.copy()
def convert_to_code_agent_format(messages: list[ChatMessage], json_path: str, reasoning: bool):
    """Rewrite a conversation in place into the code-agent text format.

    Per message:
      * system: content is replaced by MOBILE_SYSTEM_PROMPT when ``json_path``
        is in MOBILE_FILE, otherwise by OS_SYSTEM_PROMPT;
      * user: "<image>" placeholders are removed;
      * assistant: the "Action: ", "Observation: " and "Thought: " labels are
        removed. With ``reasoning`` the final message of the conversation is
        wrapped in <code> tags and every other assistant message in <think>
        tags; without it the text is only stripped.

    Consecutive messages with the same role are then fused into one
    (joined by a newline unless ``reasoning`` is set).

    Args:
        messages: Conversation to convert; modified in place.
        json_path: Source file identifier used to pick the system prompt.
        reasoning: Whether to emit <think>/<code> wrappers.

    Returns:
        The same ``messages`` list object, after conversion and fusing.
    """
    # Fixed before any fusing, so "last message" always means the last input.
    last_index = len(messages) - 1
    fused = []

    for i, message in enumerate(messages):
        content = message.content
        if message.role == "system":
            content = MOBILE_SYSTEM_PROMPT if json_path in MOBILE_FILE else OS_SYSTEM_PROMPT
        elif message.role == "user":
            content = content.replace("<image>\n", "").replace("<image>", "")
        elif message.role == "assistant":
            content = (
                content.replace("Action: ", "")
                .replace("Observation: ", "")
                .replace("Thought: ", "")
            )
            if reasoning and i == last_index:
                content = "<code>\n" + content.strip() + "\n</code>"
            elif reasoning:
                # TODO: Check if there is always only 2 assistants
                content = "<think>\n" + content.strip() + "\n</think>\n"
            else:
                content = content.strip()
        message.content = content

        # Fuse into the previous message when the role repeats. Results are
        # collected in a separate list: removing items from `messages` while
        # iterating over it would skip the message after each fusion.
        if fused and fused[-1].role == message.role:
            separator = "" if reasoning else "\n"
            fused[-1].content += separator + content
        else:
            fused.append(message)

    messages[:] = fused
    return messages
def convert_to_chat_format(
    data: ConversationData, json_path: str, reasoning: bool
) -> list[ChatMessage]:
    """Convert one conversation item into code-agent style chat messages.

    The item is turned into ChatMessage objects and then passed through
    convert_to_code_agent_format with the same ``json_path`` and ``reasoning``.
    """
    # NOTE: an earlier experiment that sorted source files into mobile/OS
    # lists by inspecting assistant messages used to live here (disabled).
    return convert_to_code_agent_format(data.to_chat_messages(), json_path, reasoning)
def _merge_drag_calls(function_calls):
    """Fold each "previous call + pyautogui.dragTo(x, y)" pair into one drag call.

    A dragTo whose first argument is not already a list/tuple takes its start
    point from the first two argument values of the call just before it; that
    previous call is removed and the dragTo is rewritten with
    ``from_coord``/``to_coord``.
    """
    merged = []
    for call in function_calls:
        first_value = next(iter(call.parameters.values()), None)
        needs_origin = call.function_name == "pyautogui.dragTo" and not isinstance(
            first_value, (list, tuple)
        )
        # NOTE(review): a dragTo with no preceding call is passed through
        # unchanged — confirm action_conversion handles that shape.
        if needs_origin and merged:
            previous = merged.pop()
            x1, y1 = islice(previous.parameters.values(), 2)
            x2, y2 = islice(call.parameters.values(), 2)
            call.parameters = {"from_coord": (x1, y1), "to_coord": (x2, y2)}
            call.original_string = call.to_string()
        merged.append(call)
    return merged


def convert_to_new_action_space(
    messages: list[ChatMessage], resolution: tuple[int, int], code_format: bool = True
) -> list[ChatMessage]:
    """Rewrite the actions in assistant messages through ``action_conversion``.

    For every assistant message the action text is located (the body of the
    first ``<code>`` block when ``code_format`` is set, otherwise the whole
    message), parsed into function calls, converted, and substituted back
    into the message. Messages without action text or without a recognised
    call are left unchanged.

    Args:
        messages: Conversation to update; modified in place.
        resolution: Passed to ``action_conversion`` as ``resolution``.
        code_format: Whether actions are wrapped in ``<code>`` tags.

    Returns:
        The same ``messages`` list.
    """
    code_block = r"<code>\n(.*?)\n</code>"

    for index, msg in enumerate(messages):
        if msg.role != "assistant":
            continue

        if code_format:
            found = re.search(code_block, msg.content, re.DOTALL)
            if found is None:
                continue
            action_text = found.group(1)
        else:
            action_text = msg.content

        function_calls = parse_function_call(
            action_text,
            pattern_to_match=["pyautogui", "mobile", "terminate", "answer"],
        )
        if not function_calls:
            continue

        # Merging builds a new list; the earlier in-loop pop() shifted the
        # indices of later calls and could wrap around to the last element.
        function_calls = _merge_drag_calls(function_calls)
        function_calls = action_conversion(function_calls, resolution=resolution)

        new_action_string = "\n".join(
            function_call.to_string() for function_call in function_calls
        )
        messages[index].content = messages[index].content.replace(
            action_text, new_action_string
        )

    return messages
def process_subset(
    config: Dict[str, Any],
    dataset_path: str,
) -> tuple[ConversationDataList, Path]:
    """Load and validate the conversations JSON for one subset.

    Args:
        config: Entry with "subset_name", "images_folder" and "json_path".
        dataset_path: Root directory of the downloaded dataset.

    Returns:
        The validated conversations and the subset's images directory
        (``dataset_path / subset_name / images_folder``). A missing images
        directory is only reported, not treated as an error.
    """
    subset_name = config["subset_name"]
    print(f"Processing split: {subset_name}")

    dataset_dir = Path(dataset_path)
    images_folder = dataset_dir / subset_name / config["images_folder"]
    if not images_folder.exists():
        print(f"Images folder not found: {images_folder}")
    else:
        print(f"Images folder: {images_folder}")

    json_config_path = dataset_dir / config["json_path"]
    # JSON interchange text is UTF-8; do not rely on the platform default.
    with open(json_config_path, "r", encoding="utf-8") as f:
        data = ConversationDataList.model_validate_json(f.read())
    print(f"Added '{json_config_path}'")

    return data, images_folder
def row_generator(
    data: ConversationDataList, images_folder: Path, json_path: str, reasoning: bool
) -> Generator[Dict[str, Any], None, None]:
    """Yield one dataset row (as a dict) per conversation item.

    Each item's image is loaded and its messages are converted to the chat
    format and the new action space. Items that end up with no messages are
    skipped; an item that raises is reported with a traceback and skipped.
    The row's ``source`` is the file name of ``json_path`` up to its first dot.
    """
    for item in tqdm(data.root):
        try:
            image = load_image_from_folder(images_folder, item.image)
            converted = convert_to_new_action_space(
                convert_to_chat_format(item, json_path, reasoning),
                image.size,
                code_format=reasoning,
            )
            if not converted:
                continue

            source_name = json_path.split("/")[-1].split(".")[0]
            row = DataRow.from_chat_messages(converted, image, source=source_name)
            yield row.model_dump(exclude_none=True)
            del image
        except Exception as e:
            import traceback

            traceback.print_exc()
            print(f"Error processing item: {e}", item)
class DatasetConfig(BaseModel):
    """Settings for converting one source dataset and publishing the result."""

    # Source dataset repo on the Hub; passed to download_dataset() as repo_id.
    huggingface_repo_id: str
    # Local directory the snapshot is downloaded into.
    local_path: str
    # Per-subset entries; each is read for "json_path" and "images_folder".
    config_dict: List[Dict[str, Any]]
    # Destination dataset repo that the processed rows are pushed to.
    smolagents_repo_id: str
    # Forwarded to the message converters: when True, assistant turns get
    # <think>/<code> wrappers and actions are read from the <code> block.
    reasoning: bool
def process_single_config(config: Dict[str, Any], dataset_path: str, smolagents_repo_id: str, reasoning: bool) -> bool:
    """Convert one subset and push it to the Hub; meant to run in a worker process.

    Rows are accumulated in memory and turned into a Dataset each time more
    than 20000 are pending, then all pieces are concatenated and pushed to
    ``smolagents_repo_id`` as the "train" split.

    Returns:
        True on success; False if anything raised (the error and traceback
        are printed).
    """
    try:
        # Each worker process authenticates on its own.
        authenticate_huggingface()

        print(f"\n{'=' * 50}")
        print(f"Processing config: {config}")

        # NOTE: a check that skipped subsets already present in the remote
        # repo (check_subset_exists) is currently disabled.
        subset_name = config["subset_name"]
        json_path = config["json_path"]

        data, image_folder = process_subset(config, dataset_path)

        shards = []
        pending = []
        for row in row_generator(data, image_folder, json_path, reasoning):
            pending.append(row)
            if len(pending) <= 20000:
                continue
            print("Creating batch dataset")
            shards.append(Dataset.from_list(pending))
            pending = []
            gc.collect()

        if pending:
            shards.append(Dataset.from_list(pending))
            pending = []

        combined = concatenate_datasets(shards)

        # The split is always "train"; no config name is set on the push.
        combined.push_to_hub(
            smolagents_repo_id,
            split="train",
        )
        print(f"Processed and uploaded subset: {subset_name}")

        # Force garbage collection to manage memory
        gc.collect()
        return True
    except Exception as e:
        print(f"Error processing config {config.get('subset_name', 'unknown')}: {e}")
        import traceback

        traceback.print_exc()
        return False
def make_dataset_from_original_data(
    dataset_config: DatasetConfig,
    max_processes: int | None = None,
    subset_names: list[str] | tuple[str, ...] | None = ("guiact-web-single",),
):
    """Download a source dataset, convert the selected subsets and upload them.

    Args:
        dataset_config: Source/destination repos, local path, subset entries
            and the ``reasoning`` flag.
        max_processes: Upper bound on worker processes; defaults to the CPU
            count.
        subset_names: Subset names to process. The default keeps the selection
            that was previously hard-coded; pass None to process every
            discovered subset.

    Nothing is returned; progress and a success/failure summary are printed.
    If the process pool itself fails, all selected subsets are processed
    again sequentially.
    """
    load_dotenv(override=True)

    print(f"Starting {dataset_config.smolagents_repo_id} dataset processing...")

    # Step 0: Authenticate with HuggingFace Hub
    authenticate_huggingface()

    dataset_path = download_dataset(
        dataset_config.huggingface_repo_id, dataset_config.local_path
    )
    # NOTE: archive extraction (extract_zip_files / extract_tar_parts_grouped)
    # is currently not run as part of this pipeline.

    dataset_configs = discover_dataset_config(dataset_path, dataset_config.config_dict)
    converted_repo_id = dataset_config.smolagents_repo_id
    reasoning = dataset_config.reasoning

    if subset_names is None:
        selected_configs = dataset_configs
    else:
        wanted = set(subset_names)
        selected_configs = [
            config for config in dataset_configs if config["subset_name"] in wanted
        ]

    # Every count below refers to the subsets actually selected, so the
    # progress lines and the final summary match the work performed.
    total = len(selected_configs)
    if total == 0:
        print("No configs selected for processing; nothing to do.")
        return

    available_cpus = mp.cpu_count()
    if max_processes is None:
        max_processes = available_cpus
    num_processes = max(1, min(max_processes, total))

    print(f"Using {num_processes} processes (out of {available_cpus} available CPUs) to process {total} configs")

    process_args = [
        (config, dataset_path, converted_repo_id, reasoning)
        for config in selected_configs
    ]

    print(f"Starting parallel processing of {total} configs...")
    try:
        with Pool(processes=num_processes) as pool:
            results = []
            # starmap returns only once every task is done, so these lines
            # are printed together at the end.
            for i, result in enumerate(pool.starmap(process_single_config, process_args)):
                results.append(result)
                print(f"Completed {i+1}/{total} configs")
    except Exception as e:
        print(f"Multiprocessing failed: {e}")
        print("Falling back to sequential processing...")
        results = []
        for i, args in enumerate(process_args):
            result = process_single_config(*args)
            results.append(result)
            print(f"Completed {i+1}/{total} configs (sequential)")

    # Check results
    successful = sum(results)
    print(f"\nProcessing complete: {successful}/{total} configs processed successfully")

    if successful < total:
        failed_count = total - successful
        print(f"Warning: {failed_count} configs failed to process. Check the logs above for details.")
    else:
        print("All configs processed successfully!")

    # NOTE: the downloaded snapshot is left on disk; no cleanup is performed.
if __name__ == "__main__":
    # Other known invocations, kept for reference:
    #
    # dataset_config_1 = DatasetConfig(
    #     huggingface_repo_id="xlangai/aguvis-stage1",
    #     local_path="/fsx/amir_mahla/aguvis_raw_stage_1",
    #     config_dict=config_dict_stage_1,
    #     smolagents_repo_id="smolagents/aguvis-stage-1",
    #     reasoning=False,
    # )
    # dataset_config_2 = DatasetConfig(
    #     huggingface_repo_id="xlangai/aguvis-stage2",
    #     local_path="/fsx/amir_mahla/aguvis_raw_stage_2",
    #     config_dict=config_dict_stage_2,
    #     smolagents_repo_id="smolagents/aguvis-stage-2",
    #     reasoning=True,
    # )
    # make_dataset_from_original_data(dataset_config_1, max_processes=4)

    # Active run: stage-2 source data, pushed to the guiact-web-single repo,
    # with a single worker process.
    guiact_web_single_config = DatasetConfig(
        huggingface_repo_id="xlangai/aguvis-stage2",
        local_path="/fsx/amir_mahla/aguvis_raw_stage_2",
        config_dict=config_dict_stage_2,
        smolagents_repo_id="smolagents/guiact-web-single",
        reasoning=True,
    )
    make_dataset_from_original_data(guiact_web_single_config, max_processes=1)

154
scripts/agents/models.py Normal file
View file

@ -0,0 +1,154 @@
from typing import List, Optional, Literal
from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
from copy import deepcopy
from PIL import Image
from collections import OrderedDict
class ChatMessage(BaseModel):
    """One chat turn: who is speaking and the text."""

    role: Literal["user", "assistant", "system"]
    content: str

    @staticmethod
    def from_conversation_list(data: list[dict[str, str]]) -> list["ChatMessage"]:
        """Build chat messages from raw ``{"from": ..., "value": ...}`` dicts.

        "system" and "human" map to the system and user roles; every other
        sender becomes the assistant. Only the first system entry is kept.
        """
        messages = []
        seen_system = False
        for item in data:
            sender = item["from"]
            if sender == "system":
                if seen_system:
                    continue
                messages.append(ChatMessage(role="system", content=item["value"]))
                seen_system = True
                continue

            role: Literal["user", "assistant", "system"] = (
                "user" if sender == "human" else "assistant"
            )
            messages.append(ChatMessage(role=role, content=item["value"]))
        return messages
class ConversationEntry(BaseModel):
    """One raw conversation turn as stored in the source JSON."""

    # Populated from the JSON key "from" ("from" is a Python keyword).
    from_: Literal["system", "human", "gpt"] = Field(alias="from")
    value: str
    recipient: Optional[str] = None
    end_turn: Optional[bool] = None

    def to_chat_message(self) -> ChatMessage:
        """Convert to a ChatMessage: system -> system, human -> user, gpt -> assistant."""
        role_by_sender: dict[str, Literal["user", "assistant", "system"]] = {
            "system": "system",
            "human": "user",
        }
        role = role_by_sender.get(self.from_, "assistant")
        return ChatMessage(role=role, content=self.value)
class ConversationData(BaseModel):
    """One source record: an image reference and the turns that belong to it."""

    image: str
    conversations: List[ConversationEntry]
    recipient: Optional[str] = None
    end_turn: Optional[bool] = None

    @field_validator("image", mode="before")
    @classmethod
    def validate_image(cls, v):
        """Reduce a list-valued ``image`` to a single path.

        A one-element list yields that element and a two-element list yields
        its second element; any other list length is rejected. Values that
        are not lists pass through unchanged.
        """
        if isinstance(v, list):
            if len(v) == 1:
                return v[0]
            if len(v) == 2:
                # NOTE(review): the second image is kept when two are given;
                # presumably it is the relevant screenshot — confirm.
                return v[1]
            raise ValueError(f"Expected 1 or 2 images, got {len(v)}")
        return v

    def to_chat_messages(self) -> list[ChatMessage]:
        """Return the turns as ChatMessage objects, in order."""
        return [conversation.to_chat_message() for conversation in self.conversations]
class ConversationDataList(RootModel[List[ConversationData]]):
    """All records of one source file, merged so each image appears once."""

    @model_validator(mode="after")
    def validate_conversation(self):
        """Deduplicate and merge records that share the same image.

        Records are grouped by ``image`` in first-seen order. Within a group,
        a record whose turns are identical to an earlier record's is dropped;
        the turns of the remaining records are appended to the group's first
        record, which becomes the single entry for that image.
        """
        grouped: dict[str, List[ConversationData]] = {}
        for conversation in self.root:
            grouped.setdefault(conversation.image, []).append(conversation)

        merged: List[ConversationData] = []
        for records in grouped.values():
            seen_signatures = set()
            first = None
            for record in records:
                # A hashable fingerprint of the record's turns: each turn is
                # serialised once instead of once per pairwise comparison.
                # It is taken before `first` is extended below.
                signature = tuple(
                    entry.model_dump_json() for entry in record.conversations
                )
                if signature in seen_signatures:
                    continue
                seen_signatures.add(signature)

                if first is None:
                    first = record
                    merged.append(first)
                else:
                    first.conversations.extend(record.conversations)

        self.root = merged
        return self
class DataRow(BaseModel):
    """One output dataset row: images, the chat turns and the source name."""

    # PIL images are not a pydantic-native type, so arbitrary types must be
    # allowed. Declared through `model_config`, the pydantic v2 form of the
    # nested `class Config`.
    model_config = {"arbitrary_types_allowed": True}

    images: list[Image.Image]
    # One mapping per completed exchange, with keys "system" (only when the
    # conversation has a system message), "user" and "assistant".
    texts: list[OrderedDict[str, str]]
    source: str

    @classmethod
    def from_chat_messages(cls, messages: list[ChatMessage], image: Image.Image, source: str) -> "DataRow":
        """Group chat messages into user/assistant exchanges for one image.

        Messages are read in order, remembering the latest text per role. As
        soon as both a user and an assistant text are available (and a system
        text too, if the conversation contains one) an exchange is emitted and
        the user/assistant slots are cleared. The system text is kept, so it
        is repeated in every exchange.
        """
        system, user, assistant = None, None, None
        have_system = any(message.role == "system" for message in messages)
        texts: list[OrderedDict[str, str]] = []

        for message in messages:
            if message.role == "system":
                system = message.content
            elif message.role == "user":
                user = message.content
            elif message.role == "assistant":
                assistant = message.content

            exchange_ready = user is not None and assistant is not None
            if have_system:
                exchange_ready = exchange_ready and system is not None
            if not exchange_ready:
                continue

            exchange: OrderedDict[str, str] = OrderedDict()
            if have_system:
                exchange["system"] = system
            exchange["user"] = user
            exchange["assistant"] = assistant
            texts.append(exchange)
            user, assistant = None, None

        return cls(images=[image], texts=texts, source=source)

    def to_model_dump(self) -> dict:
        """Return the row as a plain dict holding the original field objects."""
        return {
            "images": self.images,
            "texts": self.texts,
            "source": self.source,
        }

145
scripts/agents/prompts.py Normal file
View file

@ -0,0 +1,145 @@
from typing import Literal
OS_ACTIONS = """
def final_answer(answer: any) -> any:
\"\"\"
Provides a final answer to the given problem.
Args:
answer: The final answer to the problem
\"\"\"
def move_mouse(x: float, y: float) -> str:
\"\"\"
Moves the mouse cursor to the specified coordinates
Args:
x: The x coordinate (horizontal position)
y: The y coordinate (vertical position)
\"\"\"
def click(x: Optional[float] = None, y: Optional[float] = None) -> str:
\"\"\"
Performs a left-click at the specified normalized coordinates
Args:
x: The x coordinate (horizontal position)
y: The y coordinate (vertical position)
\"\"\"
def double_click(x: Optional[float] = None, y: Optional[float] = None) -> str:
\"\"\"
Performs a double-click at the specified normalized coordinates
Args:
x: The x coordinate (horizontal position)
y: The y coordinate (vertical position)
\"\"\"
def type(text: str) -> str:
\"\"\"
Types the specified text at the current cursor position.
Args:
text: The text to type
\"\"\"
def press(keys: str | list[str]) -> str:
\"\"\"
Presses a keyboard key
Args:
keys: The key or list of keys to press (e.g. "enter", "space", "backspace", "ctrl", etc.).
\"\"\"
def navigate_back() -> str:
\"\"\"
Goes back to the previous page in the browser. If using this tool doesn't work, just click the button directly.
\"\"\"
def drag(from_coord: list[float], to_coord: list[float]) -> str:
\"\"\"
Clicks [x1, y1], drags mouse to [x2, y2], then release click.
Args:
x1: origin x coordinate
y1: origin y coordinate
x2: end x coordinate
y2: end y coordinate
\"\"\"
def scroll(direction: Literal["up", "down"] = "down", amount: int = 1) -> str:
\"\"\"
Moves the mouse to selected coordinates, then uses the scroll button: this could scroll the page or zoom, depending on the app. DO NOT use scroll to move through linux desktop menus.
Args:
x: The x coordinate (horizontal position) of the element to scroll/zoom, defaults to None to not focus on specific coordinates
y: The y coordinate (vertical position) of the element to scroll/zoom, defaults to None to not focus on specific coordinates
direction: The direction to scroll ("up" or "down"), defaults to "down". For zoom, "up" zooms in, "down" zooms out.
amount: The amount to scroll. A good amount is 1 or 2.
\"\"\"
def wait(seconds: float) -> str:
\"\"\"
Waits for the specified number of seconds. Very useful in case the prior order is still executing (for example starting very heavy applications like browsers or office apps)
Args:
seconds: Number of seconds to wait, generally 2 is enough.
\"\"\"
"""
# Text listing the mobile-only action stubs; MOBILE_SYSTEM_PROMPT interpolates it
# after OS_ACTIONS.
# NOTE(review): `navigate_back` is also declared in OS_ACTIONS with a different
# description ("Goes back to the previous page ..." vs "Return to home page"), so the
# mobile prompt lists the function twice -- confirm which description is intended.
MOBILE_ACTIONS = """
def navigate_back() -> str:
\"\"\"
Return to home page
\"\"\"
def open_app(app_name: str) -> str:
\"\"\"
Launches the specified application.
Args:
app_name: the name of the application to launch
\"\"\"
def swipe(from_coord: list[str], to_coord: list[str]) -> str:
\"\"\"
swipe from 'from_coord' to 'to_coord'
Args:
from_coord: origin coordinates
to_coord: end coordinates
\"\"\"
def long_press(x: int, y: int) -> str:
\"\"\"
Performs a long-press at the specified coordinates
Args:
x: The x coordinate (horizontal position)
y: The y coordinate (vertical position)
\"\"\"
"""
# System prompt for desktop tasks: the step format (<think> then <code>) followed by
# the OS_ACTIONS stubs wrapped in <code> tags.
# NOTE(review): "Youll" is missing its apostrophe. The string is left untouched here
# because the prompt text presumably has to match what existing datasets/checkpoints
# were built with -- confirm before fixing.
OS_SYSTEM_PROMPT = f"""You are a helpful GUI agent. Youll be given a task and a screenshot of the screen. Complete the task using Python function calls.
For each step:
First, <think></think> to express the thought process guiding your next action and the reasoning behind it.
Then, use <code></code> to perform the action. it will be executed in a stateful environment.
The following functions are exposed to the Python interpreter:
<code>
{OS_ACTIONS}
</code>
The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
"""
# System prompt for mobile tasks: same step format as the desktop prompt, listing the
# OS_ACTIONS stubs followed by the MOBILE_ACTIONS stubs inside one <code> section.
# NOTE(review): "Youll" is missing its apostrophe. The string is left untouched here
# because the prompt text presumably has to match what existing datasets/checkpoints
# were built with -- confirm before fixing.
MOBILE_SYSTEM_PROMPT = f"""You are a helpful GUI agent. Youll be given a task and a screenshot of the screen. Complete the task using Python function calls.
For each step:
First, <think></think> to express the thought process guiding your next action and the reasoning behind it.
Then, use <code></code> to perform the action. it will be executed in a stateful environment.
The following functions are exposed to the Python interpreter:
<code>
# OS ACTIONS
{OS_ACTIONS}
# MOBILE ACTIONS
{MOBILE_ACTIONS}
</code>
The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
"""

View file

@ -0,0 +1,153 @@
from PIL import Image
from scripts.agents.function_parser import parse_function_call
from qwen_vl_utils import smart_resize
def resize_images_in_messages(batch_messages, script_args) -> list[Image.Image]:
    """Resize every conversation's screenshot and rescale its action coordinates.

    Each conversation is read as ``[system, user, assistant]``: the image is the
    first content item of the user turn (index 1) and the function calls are
    parsed from the assistant turn's text (index 2). Both are updated in place.

    Args:
        batch_messages: Conversations to update in place.
        script_args: Object whose ``image_resize`` mapping provides the
            ``min_pixels``, ``max_pixels`` and ``factor`` passed to ``smart_resize``.

    Returns:
        One single-element list containing the resized image per conversation.
    """
    min_pixels = script_args.image_resize["min_pixels"]
    max_pixels = script_args.image_resize["max_pixels"]
    factor = script_args.image_resize["factor"]

    point_actions = ["click", "long_press", "double_click", "move_mouse"]
    span_actions = ["swipe", "drag"]

    resized_batch = []
    for conversation in batch_messages:
        image_slot = conversation[1]["content"][0]
        assistant_turn = conversation[2]

        source_image = image_slot["image"]
        target_height, target_width = smart_resize(
            source_image.height,
            source_image.width,
            factor=factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
        target_image = source_image.resize((target_width, target_height))
        image_slot["image"] = target_image

        calls = parse_function_call(assistant_turn["content"])
        # Snapshot the textual form first so it can be located after mutation.
        previous_strings = [call.to_string() for call in calls]

        for call, previous in zip(calls, previous_strings):
            params = call.parameters
            if call.function_name in point_actions:
                # Positional x / y given in pixels of the original image.
                params["arg_0"] = int(
                    params["arg_0"] / source_image.width * target_image.width
                )
                params["arg_1"] = int(
                    params["arg_1"] / source_image.height * target_image.height
                )
            elif call.function_name in span_actions:
                # Start and end points, each an (x, y) pair.
                params["arg_0"] = (
                    int(params["arg_0"][0] / source_image.width * target_image.width),
                    int(params["arg_0"][1] / source_image.height * target_image.height),
                )
                params["arg_1"] = (
                    int(params["arg_1"][0] / source_image.width * target_image.width),
                    int(params["arg_1"][1] / source_image.height * target_image.height),
                )
            assistant_turn["content"] = assistant_turn["content"].replace(
                previous, call.to_string()
            )

        resized_batch.append([target_image])
    return resized_batch
def create_vlm_collate_fn(processor, script_args):
    """Build a collate function for VLM training that masks prompt tokens in the labels.

    Args:
        processor: Multimodal processor exposing ``apply_chat_template``,
            ``__call__``, ``tokenizer`` and ``image_token``.
        script_args: Script arguments; when ``image_resize`` is not ``None`` the
            screenshots are resized through ``resize_images_in_messages``.

    Returns:
        A ``collate_fn(examples)`` callable returning the processor batch with an
        added ``labels`` entry.
    """

    def collate_fn(examples: list[dict[str, str | Image.Image]]):
        """Collate ``system``/``user``/``assistant``/``image`` examples into a batch.

        Raises:
            ValueError: If the processor has no ``image_token`` attribute.
        """
        batch_messages = []
        system_prompts = []
        user_prompts = []

        for example in examples:
            system = example["system"]
            user = example["user"]
            system_prompts.append(system)
            user_prompts.append(user)
            batch_messages.append(
                [
                    {"role": "system", "content": system},
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": example["image"]},
                            {"type": "text", "text": user},
                        ],
                    },
                    {"role": "assistant", "content": example["assistant"]},
                ]
            )

        if script_args.image_resize is not None:
            all_image_inputs = resize_images_in_messages(batch_messages, script_args)
        else:
            # Every conversation carries an image entry, so the images must be
            # forwarded to the processor even when no resizing is requested.
            all_image_inputs = [
                [messages[1]["content"][0]["image"]] for messages in batch_messages
            ]

        texts = [
            processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            for messages in batch_messages
        ]

        # NOTE(review): max_length is passed without truncation=True; it presumably
        # has no effect with padding=True -- confirm the intended behavior.
        batch = processor(
            text=texts,
            images=all_image_inputs if all_image_inputs else None,
            padding=True,
            return_tensors="pt",
            max_length=4096,
        )

        input_ids = batch["input_ids"]
        labels = input_ids.clone()

        # Padding and image placeholder tokens never contribute to the loss.
        labels[labels == processor.tokenizer.pad_token_id] = -100
        if not hasattr(processor, "image_token"):
            raise ValueError("Processor does not have image_token")
        image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
        if image_token_id is not None:
            labels[labels == image_token_id] = -100

        system_encodings = processor.tokenizer(
            system_prompts, add_special_tokens=False, padding=False
        )["input_ids"]
        user_encodings = processor.tokenizer(
            user_prompts, add_special_tokens=False, padding=False
        )["input_ids"]

        # Mask the first occurrence of each system / user prompt in its row.
        for encodings in (system_encodings, user_encodings):
            for i, prompt_ids in enumerate(encodings):
                width = len(prompt_ids)
                if input_ids[i, :width].tolist() == prompt_ids:
                    labels[i, :width] = -100
                    continue
                seq = input_ids[i].tolist()
                for j in range(len(seq) - width + 1):
                    if seq[j : j + width] == prompt_ids:
                        labels[i, j : j + width] = -100
                        break  # only the first match is masked

        batch["labels"] = labels
        return batch

    return collate_fn

View file

@ -0,0 +1,224 @@
from PIL import Image
from scripts.agents.function_parser import parse_function_call
import numpy as np
from transformers.models.smolvlm.image_processing_smolvlm import (
get_resize_output_image_size,
)
from transformers.image_utils import ChannelDimension
def transform_messages(
batch_messages,
image_resize: dict[str, int | bool],
) -> list[list[Image.Image]]:
resolution_max_side = image_resize["resolution_max_side"] if "resolution_max_side" in image_resize else None
to_pixel_coordinates = image_resize["to_pixel_coordinates"] if "to_pixel_coordinates" in image_resize else False
if not to_pixel_coordinates and resolution_max_side is None:
return batch_messages
all_image_inputs: list[list[Image.Image]] = []
for messages in batch_messages:
new_image = None
for i in range(len(messages)):
if "image" in messages[i]["content"][0]:
old_image = messages[i]["content"][0]["image"]
if resolution_max_side is not None:
resized_height, resized_width = get_resize_output_image_size(
np.array(old_image),
resolution_max_side=resolution_max_side,
input_data_format=ChannelDimension.LAST,
)
new_image = old_image.resize((resized_width, resized_height))
else:
resized_height, resized_width = old_image.height, old_image.width
new_image = old_image
messages[i]["content"][0]["image"] = new_image
all_image_inputs.append([new_image])
if messages[i]["role"] == "assistant" and to_pixel_coordinates:
assert new_image is not None, "new_image is None"
function_calls = parse_function_call(messages[i]["content"][0]["text"])
old_function_call_strings = [
function_call.to_string() for function_call in function_calls
]
for function_call, old_function_call_string in zip(
function_calls, old_function_call_strings
):
if function_call.function_name in [
"click",
"long_press",
"double_click",
"move_mouse",
]:
function_call.parameters["x"] = int(
function_call.parameters["x"] * new_image.width
)
function_call.parameters["y"] = int(
function_call.parameters["y"] * new_image.height
)
elif function_call.function_name in ["swipe", "drag"]:
function_call.parameters["from_coord"] = (
int(function_call.parameters["from_coord"][0] * new_image.width),
int(
function_call.parameters["from_coord"][1] * new_image.height
),
)
function_call.parameters["to_coord"] = (
int(function_call.parameters["to_coord"][0] * new_image.width),
int(
function_call.parameters["to_coord"][1] * new_image.height
),
)
messages[i]["content"][0]["text"] = messages[i]["content"][0][
"text"
].replace(old_function_call_string, function_call.to_string())
return all_image_inputs
def _find_token_spans(seq: list[int], ids: list[int]) -> list[tuple[int, int]]:
    """Return ``(start, end)`` for each non-overlapping occurrence of ``ids`` in ``seq``."""
    width = len(ids)
    spans: list[tuple[int, int]] = []
    if width == 0:
        # An empty pattern matches everywhere; treating it as "no match" avoids an endless scan.
        return spans
    start = 0
    last_start = len(seq) - width
    while start <= last_start:
        if seq[start : start + width] == ids:
            spans.append((start, start + width))
            start += width
        else:
            start += 1
    return spans


def create_vlm_collate_fn(processor, training_args, script_args):
    """Build a collate function for VLM training that only keeps assistant tokens as labels.

    Args:
        processor: Multimodal processor exposing ``apply_chat_template``,
            ``__call__`` and ``tokenizer``.
        training_args: Training arguments; ``max_length`` is used for truncation.
        script_args: Script arguments; ``image_resize`` (or ``None``) controls the
            optional call to ``transform_messages``.

    Returns:
        A ``collate_fn(examples)`` callable returning the processor batch with an
        added ``labels`` entry.
    """

    def collate_fn(examples: list[dict[str, list | str | Image.Image]]):
        """Collate examples carrying ``images`` and a list of ``texts`` turns."""
        batch_messages: list[list[dict[str, list | str | Image.Image]]] = []
        assistant_messages: list[list[str]] = []
        all_image_inputs: list[list[Image.Image]] = []

        for example in examples:
            images: list[Image.Image] = example["images"]
            sample: list[dict[str, list | str | Image.Image]] = []
            assistant: list[str] = []
            is_first_user = True

            for text in example["texts"]:
                if "system" in text:
                    sample.append(
                        {
                            "role": "system",
                            "content": [{"type": "text", "text": text["system"]}],
                        }
                    )

                user_content = [{"type": "text", "text": text["user"]}]
                if is_first_user:
                    # Only the first user turn carries the screenshot.
                    user_content.insert(0, {"type": "image", "image": images[0]})
                    is_first_user = False
                sample.append({"role": "user", "content": user_content})

                sample.append(
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": "\n" + text["assistant"]}],
                    }
                )
                # Tokenized separately below to locate the assistant spans.
                assistant.append(text["assistant"] + "<end_of_utterance>")

            batch_messages.append(sample)
            assistant_messages.append(assistant)
            all_image_inputs.append(images)

        image_resize = script_args.image_resize
        if image_resize is not None and image_resize.get("to_pixel_coordinates"):
            all_image_inputs = transform_messages(
                batch_messages,
                image_resize=image_resize,
            )

        texts = [
            processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            for messages in batch_messages
        ]

        batch = processor(
            text=texts,
            images=all_image_inputs if all_image_inputs else None,
            max_length=training_args.max_length,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )

        input_ids = batch["input_ids"]
        labels = input_ids.clone()

        assistant_encodings = [
            processor.tokenizer(
                assistant_message, add_special_tokens=False, padding=False
            )["input_ids"]
            for assistant_message in assistant_messages
        ]

        # Mask out everything except the assistant messages.
        for i, assistant_ids_list in enumerate(assistant_encodings):
            seq = input_ids[i].tolist()
            keep = [False] * len(seq)
            for ids in assistant_ids_list:
                for start, end in _find_token_spans(seq, ids):
                    keep[start:end] = [True] * (end - start)
            masked = [pos for pos, kept in enumerate(keep) if not kept]
            if masked:
                # One indexed assignment instead of one tensor write per position.
                labels[i, masked] = -100

        batch["labels"] = labels
        return batch

    return collate_fn
if __name__ == "__main__":
    # Utility run: measure the longest tokenized sample per dataset config.
    from transformers import AutoProcessor
    from datasets import load_dataset

    class ScriptArguments:
        image_resize = None

    class TrainingArguments:
        # The collator reads `training_args.max_length`.
        # NOTE(review): with truncation=True and max_length=None the tokenizer
        # presumably falls back to its own model_max_length -- confirm that the
        # measured lengths are not capped by it.
        max_length = None

    processor = AutoProcessor.from_pretrained(
        "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
    )
    processor.image_processor.size = {"longest_edge": 384}

    # `training_args` is a required parameter of create_vlm_collate_fn.
    collate_fn = create_vlm_collate_fn(
        processor,
        training_args=TrainingArguments,
        script_args=ScriptArguments,
    )

    max_length = []
    for dataset_name in ["ricosca"]:
        dataset_max_length = 0
        data = load_dataset("smolagents/aguvis-stage-1", dataset_name, split="train")
        print("processing", dataset_name)
        for example in data:
            batch = collate_fn([example])
            dataset_max_length = max(dataset_max_length, batch["input_ids"].shape[1])
        print("dataset_max_length", dataset_name, dataset_max_length)
        max_length.append(dataset_max_length)

    print(max_length)
    print("max_length", max(max_length))
    # Append one value per line so repeated runs stay readable, and close the file.
    with open("max_length_384_phase_1.txt", "a", encoding="utf-8") as out_file:
        out_file.write(f"{max(max_length)}\n")

View file

@ -0,0 +1,54 @@
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers.models.smolvlm.image_processing_smolvlm import SmolVLMImageProcessor
class TransformersModel:
    """Thin wrapper pairing an image-text-to-text checkpoint with the SmolVLM2 processor."""

    def __init__(self, model_id: str, to_device: str = "cuda"):
        """Load the processor and the bfloat16 model, then move the model to ``to_device``.

        Args:
            model_id: Checkpoint name or path passed to ``from_pretrained`` for the model.
            to_device: Device string the model is moved to.
        """
        self.model_id = model_id

        # The processor always comes from the base SmolVLM2 repository, not from `model_id`.
        processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
        processor.image_processor.size = {"longest_edge": 3 * 384}
        self.processor = processor

        model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        self.model = model.to(to_device)

    def generate(self, messages: list[dict], **kwargs):
        """Generate a completion for ``messages`` and return only the newly produced text.

        Args:
            messages: Chat messages accepted by the processor's chat template.
            **kwargs: Extra arguments forwarded to ``model.generate``.

        Returns:
            The decoded completion of the first sequence, without the prompt tokens.
        """
        template_options = dict(
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        inputs = self.processor.apply_chat_template(messages, **template_options)
        inputs = inputs.to(self.model.device, dtype=torch.bfloat16)

        generated_ids = self.model.generate(**inputs, **kwargs)

        # Drop the prompt part so only the continuation is decoded.
        prompt_length = len(inputs["input_ids"][0])
        completions = self.processor.batch_decode(
            generated_ids[:, prompt_length:], skip_special_tokens=True
        )
        return completions[0]
if __name__ == "__main__":
    from PIL import Image

    # Manual smoke test: load a local checkpoint and ask for a click on a sample screenshot.
    checkpoint_path = "/fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-phase-1-max-size-1152/checkpoint-800"
    screenshot_path = "/admin/home/amir_mahla/screensuite/examples/sample_image.png"
    instruction = "Given the screenshot, and the instruction, output a click that completes the instruction or targets the given element (always target the center of the element).\n\nOutput the click position as follows:\n\n<think>(thought process)</think><code>click(x, y)</code>\nWith x the number of pixels from the left edge and y the number of pixels from the top edge.\n\nNow write the click needed to complete the instruction:\nInstruction: view more information about bomber\n"

    model = TransformersModel(model_id=checkpoint_path, to_device="cuda:0")
    image = Image.open(screenshot_path)

    user_content = [
        {"type": "image", "image": image},
        {"type": "text", "text": instruction},
    ]
    messages = [{"role": "user", "content": user_content}]

    answer = model.generate(messages, max_new_tokens=128)
    print(answer)

View file

@ -74,6 +74,8 @@ _deps = [
"async-lru>=2.0.5",
"aiofiles>=24.1.0",
"pandas>=2.2.3",
"qwen-vl-utils>=0.1.0",
"setuptools>=80.9.0",
]
# this is a lookup table with items like:

View file

@ -33,8 +33,8 @@ START_TIME=$(date +%s)
echo "START TIME: $(date)"
# Refresh Weka on h4 cache
echo "Refreshing Weka filesystem..."
find -L /fsx/h4/ -type f | xargs -d '\n' -r -n512 -P64 weka fs tier fetch
# echo "Refreshing Weka filesystem..."
# find -L /fsx/${USER}/ -type f | xargs -d '\n' -r -n512 -P64 weka fs tier fetch
# Default values
MODEL=""

View file

@ -68,19 +68,51 @@ class ScriptArguments(trl.ScriptArguments):
# Override the dataset_name to make it optional
dataset_name: Optional[str] = field(
default=None, metadata={"help": "Dataset name. Can be omitted if using dataset_mixture."}
default=None,
metadata={"help": "Dataset name. Can be omitted if using dataset_mixture."},
)
dataset_mixture: Optional[dict[str, Any]] = field(
default=None,
metadata={"help": "Configuration for creating dataset mixtures with advanced options like shuffling."},
metadata={
"help": "Configuration for creating dataset mixtures with advanced options like shuffling."
},
)
single_gpu: bool = field(
default=False,
metadata={
"help": "Force training on single GPU only, disabling distributed training."
},
)
image_resize: Optional[dict[str, int]] = field(
default=None,
metadata={"help": "Resize the image to the given minimum and maximum pixels."},
)
def __post_init__(self):
if self.dataset_name is None and self.dataset_mixture is None:
raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided")
raise ValueError(
"Either `dataset_name` or `dataset_mixture` must be provided"
)
if self.image_resize is not None:
    # Both keys are mandatory; the error message names exactly the keys checked here.
    required_keys = ("to_pixel_coordinates", "resolution_max_side")
    if not isinstance(self.image_resize, dict) or any(
        key not in self.image_resize for key in required_keys
    ):
        raise ValueError(
            "image_resize must be a dictionary with a 'to_pixel_coordinates' and "
            f"'resolution_max_side' key. {self.image_resize}"
        )
if self.dataset_mixture is not None:
if not isinstance(self.dataset_mixture, dict) or "datasets" not in self.dataset_mixture:
if (
not isinstance(self.dataset_mixture, dict)
or "datasets" not in self.dataset_mixture
):
raise ValueError(
"dataset_mixture must be a dictionary with a 'datasets' key. "
"Expected format: {'datasets': [...], 'seed': int}"
@ -110,7 +142,11 @@ class ScriptArguments(trl.ScriptArguments):
)
# Check that column names are consistent across all dataset configs
columns_sets = [set(dataset.columns) for dataset in datasets_list if dataset.columns is not None]
columns_sets = [
set(dataset.columns)
for dataset in datasets_list
if dataset.columns is not None
]
if columns_sets:
first_columns = columns_sets[0]
if not all(columns == first_columns for columns in columns_sets):
@ -135,13 +171,21 @@ class GRPOConfig(trl.GRPOConfig):
default_factory=lambda: [],
metadata={"help": "The callbacks to run during training."},
)
chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
chat_template: Optional[str] = field(
default=None, metadata={"help": "The chat template to use."}
)
hub_model_revision: Optional[str] = field(
default="main", metadata={"help": "The Hub model branch to push the model to."}
)
num_completions_to_print: int = field(default=0, metadata={"help": "Number of completions to print."})
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
num_completions_to_print: int = field(
default=0, metadata={"help": "Number of completions to print."}
)
overwrite_hub_revision: bool = field(
default=False, metadata={"help": "Whether to overwrite the Hub revision."}
)
push_to_hub_revision: bool = field(
default=False, metadata={"help": "Whether to push to a Hub revision/branch."}
)
system_prompt: Optional[str] = field(
default=None,
metadata={"help": "The optional system prompt to use."},
@ -149,7 +193,9 @@ class GRPOConfig(trl.GRPOConfig):
wandb_log_unique_prompts: bool = field(
default=True,
metadata={
"help": ("Whether to log the unique prompts to wandb. This will create a new run for each unique prompt.")
"help": (
"Whether to log the unique prompts to wandb. This will create a new run for each unique prompt."
)
},
)
wandb_entity: Optional[str] = field(
@ -180,17 +226,27 @@ class SFTConfig(trl.SFTConfig):
default_factory=lambda: [],
metadata={"help": "The callbacks to run during training."},
)
chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
chat_template: Optional[str] = field(
default=None, metadata={"help": "The chat template to use."}
)
system_prompt: Optional[str] = field(
default=None,
metadata={"help": "The optional system prompt to use for benchmarking."},
)
vision_model: bool = field(
default=False,
metadata={"help": "Whether this is a vision-language model training."},
)
hub_model_revision: Optional[str] = field(
default="main",
metadata={"help": "The Hub model branch to push the model to."},
)
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
overwrite_hub_revision: bool = field(
default=False, metadata={"help": "Whether to overwrite the Hub revision."}
)
push_to_hub_revision: bool = field(
default=False, metadata={"help": "Whether to push to a Hub revision/branch."}
)
wandb_entity: Optional[str] = field(
default=None,
metadata={"help": ("The entity to store runs under.")},
@ -263,7 +319,9 @@ class GRPOScriptArguments(ScriptArguments):
)
repetition_max_penalty: float = field(
default=-1.0,
metadata={"help": "Maximum (negative) penalty for for repetition penalty reward"},
metadata={
"help": "Maximum (negative) penalty for for repetition penalty reward"
},
)
code_language: str = field(
default="python",
@ -281,7 +339,9 @@ class GRPOScriptArguments(ScriptArguments):
)
code_eval_scoring_mode: Literal["pass_fail", "partial", "weighted_sum"] = field(
default="weighted_sum",
metadata={"help": "use fraction of passed test cases as reward. If false, use 0/1 scoring."},
metadata={
"help": "use fraction of passed test cases as reward. If false, use 0/1 scoring."
},
)
parallel_code_exec_per_proc: int = field(
default=2,

View file

@ -13,18 +13,18 @@
# limitations under the License.
"""
Supervised fine-tuning script for decoder language models.
Supervised fine-tuning script for decoder language models and vision-language models.
Usage:
# On 1 node of 8 x H100s
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path open-r1/Qwen2.5-Math-7B-RoPE-300k \
--dataset_name open-r1/Mixture-of-Thoughts \
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
--dataset_name smolagents/gaia-traces \
--num_train_epochs 1 \
--dataset_config all \
--eos_token '<|im_end|>' \
--learning_rate 4.0e-5 \
--num_train_epochs 5 \
--max_seq_length 32768 \
--per_device_train_batch_size 2 \
--gradient_checkpointing \
@ -39,20 +39,41 @@ import sys
import datasets
import transformers
from transformers import set_seed
from transformers import (
set_seed,
AutoModelForVision2Seq,
AutoProcessor,
LlavaForConditionalGeneration,
)
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import ScriptArguments, SFTConfig
from open_r1.utils import get_dataset, get_model, get_tokenizer
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import ModelConfig, SFTTrainer, TrlParser, get_peft_config, setup_chat_format
from open_r1.configs import ScriptArguments, SFTConfig
from open_r1.utils import get_dataset, get_model, get_tokenizer, get_processor
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from PIL import Image
from transformers import Qwen2VLProcessor
from typing import Any
from scripts.agents.function_parser import parse_function_call
from scripts.agents.smolvlm2_collator import create_vlm_collate_fn
logger = logging.getLogger(__name__)
from dotenv import load_dotenv
load_dotenv()
def main(script_args, training_args, model_args):
# Force single GPU mode if requested
# if hasattr(script_args, 'single_gpu') and script_args.single_gpu:
# logger.info("Single GPU mode requested - setting CUDA_VISIBLE_DEVICES=0")
# # Disable distributed training
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# training_args.local_rank = -1
# training_args.ddp_backend = None
set_seed(training_args.seed)
###############
@ -85,15 +106,42 @@ def main(script_args, training_args, model_args):
init_wandb_training(training_args)
######################################
# Load dataset, tokenizer, and model #
# Load dataset, processor/tokenizer, and model #
######################################
dataset = get_dataset(script_args)
tokenizer = get_tokenizer(model_args, training_args)
model = get_model(model_args, training_args)
if tokenizer.chat_template is None:
logger.info("No chat template provided, defaulting to ChatML.")
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
if training_args.vision_model:
logger.info("Setting up vision-language model training")
# Set VLM-specific training arguments (following TRL reference)
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
training_args.ddp_find_unused_parameters = True
# Load processor and model for VLM
processor = get_processor(model_args, training_args, script_args)
model = get_model(
model_args, training_args
) # This should return AutoModelForVision2Seq
data_collator = create_vlm_collate_fn(processor, training_args, script_args)
processing_class = processor.tokenizer
model_tags = ["open-r1", "vision-language", "vlm"]
else:
logger.info("Setting up text-only model training")
# Load tokenizer and model for text-only
tokenizer = get_tokenizer(model_args, training_args)
model = get_model(model_args, training_args)
if tokenizer.chat_template is None:
logger.info("No chat template provided, defaulting to ChatML.")
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
data_collator = None # Use default
processing_class = tokenizer
model_tags = ["open-r1"]
############################
# Initialize the SFT Trainer
@ -101,9 +149,14 @@ def main(script_args, training_args, model_args):
trainer = SFTTrainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None),
processing_class=tokenizer,
eval_dataset=(
dataset[script_args.dataset_test_split]
if training_args.eval_strategy != "no"
else None
),
processing_class=processing_class,
peft_config=get_peft_config(model_args),
callbacks=get_callbacks(training_args, model_args),
)
@ -128,16 +181,17 @@ def main(script_args, training_args, model_args):
# Save model and create model card
##################################
logger.info("*** Save model ***")
# Align the model's generation config with the tokenizer's eos token
# to avoid unbounded generation in the transformers `pipeline()` function
trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")
# `processor` only exists on the vision-language branch; without this guard a
# text-only run raises NameError, which was swallowed and logged as a save failure.
if training_args.vision_model:
    try:
        processor.save_pretrained(training_args.output_dir)
    except Exception as e:
        # Best effort: the model itself has already been saved above.
        logger.error(f"Error saving processor: {e}")
# Save everything else on main process
kwargs = {
"dataset_name": script_args.dataset_name,
"tags": ["open-r1"],
"tags": model_tags,
}
if trainer.accelerator.is_main_process:
trainer.create_model_card(**kwargs)
@ -160,7 +214,10 @@ def main(script_args, training_args, model_args):
#############
if training_args.push_to_hub:
logger.info("Pushing to hub...")
trainer.push_to_hub(**kwargs)
trainer.push_to_hub(**kwargs, token=os.getenv("HF_TOKEN"))
# Also push processor for VLM models
if training_args.vision_model and trainer.accelerator.is_main_process:
processor.push_to_hub(training_args.hub_model_id)
if __name__ == "__main__":

View file

@ -1,6 +1,6 @@
from .data import get_dataset
from .import_utils import is_e2b_available, is_morph_available
from .model_utils import get_model, get_tokenizer
from .model_utils import get_model, get_tokenizer, get_processor
__all__ = ["get_tokenizer", "is_e2b_available", "is_morph_available", "get_model", "get_dataset"]
__all__ = ["get_tokenizer", "get_processor", "is_e2b_available", "is_morph_available", "get_model", "get_dataset"]

View file

@ -1,12 +1,20 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedTokenizer,
AutoProcessor,
AutoModelForImageTextToText,
)
from trl import ModelConfig, get_kbit_device_map, get_quantization_config
from ..configs import GRPOConfig, SFTConfig
from ..configs import GRPOConfig, SFTConfig, ScriptArguments
def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> PreTrainedTokenizer:
def get_tokenizer(
model_args: ModelConfig, training_args: SFTConfig | GRPOConfig
) -> PreTrainedTokenizer:
"""Get the tokenizer for the model."""
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
@ -20,10 +28,42 @@ def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig
return tokenizer
def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoModelForCausalLM:
"""Get the model"""
def get_processor(
    model_args: ModelConfig,
    training_args: SFTConfig | GRPOConfig,
    script_args: ScriptArguments,
) -> AutoProcessor:
    """Load the processor for a vision-language model and apply run-specific settings.

    Args:
        model_args: Supplies the model name/path, revision and ``trust_remote_code``.
        training_args: Supplies an optional ``chat_template`` override.
        script_args: Supplies the optional ``image_resize`` mapping; its
            ``resolution_max_side`` entry sets the image processor's longest edge.

    Returns:
        The configured processor.
    """
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )

    # Align the image processor's target size with the configured maximum side.
    resize_options = script_args.image_resize
    if resize_options is not None and "resolution_max_side" in resize_options:
        longest_edge = resize_options["resolution_max_side"]
        processor.image_processor.size = {"longest_edge": longest_edge}

    # Truncate and pad on the right.
    if hasattr(processor, "tokenizer"):
        for side_attribute in ("truncation_side", "padding_side"):
            setattr(processor.tokenizer, side_attribute, "right")

    custom_template = training_args.chat_template
    if custom_template is not None:
        processor.chat_template = custom_template

    return processor
def get_model(
model_args: ModelConfig, training_args: SFTConfig | GRPOConfig
) -> AutoModelForCausalLM | AutoModelForImageTextToText:
"""Get the model - supports both text-only and vision-language models"""
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
@ -35,8 +75,19 @@ def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) ->
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
**model_kwargs,
)
# Check if this is a VLM model using the explicit flag
if hasattr(training_args, "vision_model") and training_args.vision_model:
# Load as vision-language model
model = AutoModelForImageTextToText.from_pretrained(
model_args.model_name_or_path,
**model_kwargs,
)
else:
# Load as text-only model
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
**model_kwargs,
)
return model

View file

View file

@ -1,219 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from datasets import load_dataset
from e2b_code_interpreter.models import Execution, ExecutionError
from open_r1.rewards import code_reward, ioi_code_reward
from open_r1.utils.routed_morph import RoutedMorphSandbox
from open_r1.utils.routed_sandbox import RoutedSandbox
class TestCodeRewards(unittest.TestCase):
    """Integration tests for ``code_reward``, ``ioi_code_reward`` and the routed sandboxes.

    Each test depends on an external service; the comment at the top of every
    test says which one has to be available.
    """

    # Dataset of Python problems whose gold solutions should all earn a reward of 1.0.
    PYTHON_DATASET = "open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled"

    def _assert_gold_rewards(self, num_samples, **reward_kwargs):
        """Score the first ``num_samples`` gold solutions and expect 1.0 for each.

        Args:
            num_samples: Number of rows taken from the start of the train split.
            **reward_kwargs: Extra keyword arguments forwarded to ``code_reward``.
        """
        code_dataset = load_dataset(self.PYTHON_DATASET)
        samples = code_dataset["train"].select(range(num_samples))
        test_completions = [[{"content": sample["gold_standard_solution"]}] for sample in samples]
        verification_info = [sample["verification_info"] for sample in samples]
        rewards = code_reward(test_completions, verification_info=verification_info, **reward_kwargs)
        print(rewards)
        self.assertEqual(rewards, [1.0] * num_samples)

    def _assert_gold_rewards_parallel(self, num_samples, batch_size, **reward_kwargs):
        """Score gold solutions through a batched ``Dataset.map`` call with ``num_proc=4``.

        Args:
            num_samples: Number of rows taken from the start of the train split.
            batch_size: Batch size handed to ``map``; every batch must be full.
            **reward_kwargs: Extra keyword arguments forwarded to ``code_reward``.
        """

        def batch_code_reward(examples):
            test_completions = [[{"content": solution}] for solution in examples["gold_standard_solution"]]
            rewards = code_reward(
                test_completions,
                verification_info=list(examples["verification_info"]),
                **reward_kwargs,
            )
            # Plain assert on purpose: the mapped function only closes over simple
            # values and never references the TestCase instance.
            assert rewards == [1.0] * batch_size
            return examples

        code_dataset = load_dataset(self.PYTHON_DATASET)
        code_dataset = code_dataset["train"].select(range(num_samples))
        code_dataset.map(
            batch_code_reward,
            batched=True,
            batch_size=batch_size,
            num_proc=4,
            load_from_cache_file=False,
        )

    def test_python_code_reward(self):
        # requires E2B, see the README.md file
        self._assert_gold_rewards(20)

    def test_e2b_router(self):
        # run router locally: python scripts/e2b_router.py
        self._assert_gold_rewards(128, e2b_router_url="0.0.0.0:8000")

    def test_e2b_router_parallel(self):
        # run router locally: python scripts/e2b_router.py
        self._assert_gold_rewards_parallel(256, 32, e2b_router_url="0.0.0.0:8000")

    def test_ioi_code_reward(self):
        # This slow test case requires spinning up a bunch (I tested with ~64) of piston workers, see docs here
        # slurm/piston/README.md
        code_dataset = load_dataset("open-r1/ioi-reward-test-dataset")
        num_samples = 16
        samples = code_dataset["train"].select(range(num_samples))
        test_completions = [[{"content": f"```cpp\n{sample['sample_solution']}```"}] for sample in samples]
        # Every column except the prompt/completion pair is forwarded to the reward function.
        keys = [key for key in samples[0] if key not in ["prompt", "completion"]]
        reward_kwargs = {key: [example[key] for example in samples] for key in keys}
        rewards = ioi_code_reward(test_completions, **reward_kwargs)
        print(rewards)
        self.assertEqual(rewards, [1.0] * num_samples)

    def test_e2b_router_run_code_success(self):
        # run router locally: python scripts/e2b_router.py
        routed_sandbox = RoutedSandbox(router_url="localhost:8000")
        scripts = [
            "print('hello from integration test')",
            "result = 2 + 2\nprint(result)",
        ]

        results = routed_sandbox.run_code(scripts)

        self.assertEqual(len(results), 2)
        for result in results:
            # NOTE: exit_code is not checked; Execution objects have no such attribute.
            self.assertIsInstance(result, Execution)
            self.assertIsNone(result.error)
            first_stdout = result.logs["stdout"][0]
            self.assertTrue("hello" in first_stdout or "4" in first_stdout)

    def test_e2b_router_run_code_with_error(self):
        # run router locally: python scripts/e2b_router.py
        routed_sandbox = RoutedSandbox(router_url="localhost:8000")
        scripts = ["print('this is fine')", "print('unterminated string"]

        results = routed_sandbox.run_code(scripts)

        self.assertEqual(len(results), 2)

        # First one should be okay (exit_code is not checked: Execution has no such attribute)
        self.assertIsNone(results[0].error)
        self.assertIn("this is fine", results[0].logs["stdout"][0])

        # Second one should have a syntax error
        self.assertIsNotNone(results[1].error)
        self.assertIsInstance(results[1].error, ExecutionError)
        self.assertIn("SyntaxError", results[1].error.name)

    def test_python_code_reward_morph(self):
        # requires MorphCloud, see the README.md file
        self._assert_gold_rewards(20, provider_type="morph")

    def test_morph_router(self):
        # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20
        self._assert_gold_rewards(32, provider_type="morph", morph_router_url="0.0.0.0:8001")

    def test_morph_router_parallel(self):
        # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20
        self._assert_gold_rewards_parallel(256, 32, provider_type="morph", morph_router_url="0.0.0.0:8001")

    def test_morph_router_run_code_success(self):
        # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20
        routed_sandbox = RoutedMorphSandbox(router_url="localhost:8001")
        scripts = [
            "print('hello from morph integration test')",
            "result = 2 + 2\nprint(result)",
        ]

        results = routed_sandbox.run_code(scripts)

        self.assertEqual(len(results), 2)
        for result in results:
            self.assertIsNone(result.exception_str)
            self.assertTrue("hello" in result.text or "4" in result.text)

    def test_morph_router_run_code_with_error(self):
        # run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20
        routed_sandbox = RoutedMorphSandbox(router_url="localhost:8001")
        scripts = ["print('this is fine with morph')", "print('unterminated string"]

        results = routed_sandbox.run_code(scripts)

        self.assertEqual(len(results), 2)

        # First one should be okay
        self.assertIsNone(results[0].exception_str)
        self.assertIn("this is fine with morph", results[0].text)

        # Second one should have a syntax error
        self.assertIn("SyntaxError", results[1].text)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

View file

@ -1,568 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from dotenv import load_dotenv
from open_r1.configs import GRPOScriptArguments
from open_r1.rewards import (
accuracy_reward,
format_reward,
get_code_format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
get_reward_funcs,
get_soft_overlong_punishment,
len_reward,
reasoning_steps_reward,
tag_count_reward,
)
# Runs at import time, before any test class is defined.
# NOTE(review): presumably loads credentials for the reward backends from a local .env file — confirm.
load_dotenv()
class TestGetRewardFuncs(unittest.TestCase):
    """Checks that ``get_reward_funcs`` resolves reward names to the expected callables."""

    def test_get_reward_funcs(self):
        """Test get_reward_funcs with various reward functions."""
        # (name passed in ``reward_funcs``, expected ``__name__`` of the returned callable).
        # Kept as pairs so the two columns cannot drift out of sync.
        expected = [
            ("accuracy", "accuracy_reward"),
            ("format", "format_reward"),
            ("reasoning_steps", "reasoning_steps_reward"),
            ("cosine", "cosine_scaled_reward"),
            ("repetition_penalty", "repetition_penalty_reward"),
            ("length", "len_reward"),
            ("tag_count", "tag_count_reward"),
            ("code", "code_reward"),
            ("ioi_code", "ioi_code_reward"),
            ("code_format", "code_format_reward"),
            ("binary_code", "binary_code_reward"),
        ]
        reward_names = [reward_name for reward_name, _ in expected]

        args = GRPOScriptArguments(
            dataset_name="dummy",
            reward_funcs=reward_names,
        )

        reward_funcs = get_reward_funcs(args)

        # One callable per requested name, in the order the names were given.
        self.assertEqual(len(reward_funcs), len(expected))
        for (reward_name, func_name), func in zip(expected, reward_funcs):
            with self.subTest(reward=reward_name):
                self.assertEqual(func_name, func.__name__)
class TestRewards(unittest.TestCase):
    """Unit tests for the accuracy, format, reasoning-steps, cosine and length rewards."""

    def test_accuracy_reward_correct_answer(self):
        """Test accuracy_reward with a correct answer."""
        completion = [[{"content": r"\boxed{\frac{63}{400}}"}]]
        solution = [r"\frac{63}{400}"]
        rewards = accuracy_reward(completion, solution)
        self.assertEqual(rewards[0], 1.0)

    def test_accuracy_reward_wrong_answer(self):
        """Test accuracy_reward with an incorrect answer."""
        completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
        solution = [r"\frac{63}{400}"]
        rewards = accuracy_reward(completion, solution)
        self.assertEqual(rewards[0], 0.0)

    def test_accuracy_reward_wrong_answer_no_latex(self):
        """Test accuracy_reward with an incorrect answer and gold solution with no latex."""
        completion = [[{"content": r"\boxed{3}"}]]
        solution = ["6"]
        rewards = accuracy_reward(completion, solution)
        self.assertEqual(rewards[0], 0.0)

    def test_format_reward_correct(self):
        """Test format_reward with correct format."""
        completion = [[{"content": "<think>\nSome reasoning\n</think>\n<answer>\nThe answer\n</answer>"}]]
        rewards = format_reward(completion)
        self.assertEqual(rewards[0], 1.0)

    def test_format_reward_incorrect(self):
        """Test format_reward with incorrect format."""
        incorrect_formats = [
            "<think>Only thinking</think>",
            "<answer>Only answer</answer>",
            "No tags at all",
            "<think>Missing closing</think><answer>Missing closing",
            "<think>Wrong order</answer><answer>Wrong order</think>",
        ]

        for fmt in incorrect_formats:
            # subTest keeps going after a failure and reports which format broke.
            with self.subTest(fmt=fmt):
                completion = [[{"content": fmt}]]
                rewards = format_reward(completion)
                self.assertEqual(rewards[0], 0.0)

    def test_reasoning_steps_reward(self):
        """Test reasoning_steps_reward with various formats."""
        test_cases = [
            # Full credit cases (3 or more steps)
            ("Step 1: First step\nStep 2: Second step\nStep 3: Third step", 1.0),
            ("First, we do this.\nSecond, we do that.\nFinally, we conclude.", 1.0),
            # Partial credit cases (less than 3 steps)
            ("Step 1: Only step", 1 / 3),
            ("First, we do this.\nFinally, we conclude.", 2 / 3),
            # No credit case
            ("Just plain text without any clear steps", 0.0),
        ]

        for content, expected_reward in test_cases:
            with self.subTest(content=content):
                completion = [[{"content": content}]]
                rewards = reasoning_steps_reward(completion)
                self.assertAlmostEqual(rewards[0], expected_reward)

    def test_multiple_completions(self):
        """Test handling multiple completions at once."""
        completions = [
            [{"content": r"\boxed{\frac{63}{400}}"}],
            [{"content": r"\boxed{\frac{64}{400}}"}],
        ]
        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]

        rewards = accuracy_reward(completions, solutions)
        self.assertEqual(len(rewards), 2)
        self.assertEqual(rewards[0], 1.0)
        self.assertEqual(rewards[1], 0.0)

    def test_cosine_scaled_reward(self):
        """Test cosine_scaled_reward with various cases."""
        # Test parameters
        test_params = {
            "min_value_wrong": -1.0,
            "max_value_wrong": -0.5,
            "min_value_correct": 0.5,
            "max_value_correct": 1.0,
            "max_len": 100,
        }

        # (completion, gold solution, padded completion length, expected reward)
        test_cases = [
            # Correct answers with different lengths
            (r"\boxed{\frac{63}{400}}", r"\frac{63}{400}", 20, 0.943),  # Short correct answer
            (r"\boxed{\frac{63}{400}}", r"\frac{63}{400}", 80, 0.547),  # Long correct answer
            # Wrong answers with different lengths
            (r"\boxed{\frac{64}{400}}", r"\frac{63}{400}", 20, -0.942),  # Short wrong answer
            (r"\boxed{\frac{64}{400}}", r"\frac{63}{400}", 80, -0.547),  # Long wrong answer
        ]

        # The parameters are the same for every case, so build the reward function once.
        reward_fn = get_cosine_scaled_reward(**test_params)

        for content, solution, content_len, expected_reward in test_cases:
            with self.subTest(content=content, content_len=content_len):
                # Pad content with trailing spaces to the desired length
                padded_content = content.ljust(content_len)
                completion = [[{"content": padded_content}]]

                rewards = reward_fn(completion, [solution])
                self.assertAlmostEqual(rewards[0], expected_reward, places=2)

    def test_format_reward_specific_multiline(self):
        """Test format_reward with a specific multiline input."""
        inputs = "<think>\nI will count each distinct object in the image:\n1. Purple scooter\n2. Red bicycle\n3. Green motorcycle\n4. Gray sedan\n5. Yellow school bus\n6. Small green double-decker bus\n7. Small red car\n8. Small purple car\n9. Small gray dirt bike\n\nThere are 9 distinct objects in total.\n</think>\n<answer>\n9\n</answer>"
        completion = [[{"content": inputs}]]
        rewards = format_reward(completion)
        self.assertEqual(rewards[0], 1.0)

    def test_same_length_responses(self):
        """Test len_reward when all responses have the same length."""
        completions = [
            [{"content": r"\boxed{\frac{63}{400}}"}],
            [{"content": r"\boxed{\frac{64}{400}}"}],
        ]
        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]

        rewards = len_reward(completions, solutions)
        self.assertEqual(rewards, [0.0, 0.0])

    def test_different_lengths_correct_answers(self):
        """Test len_reward with different length correct answers."""
        completions = [
            [{"content": r"\boxed{\frac{63}{400}}"}],  # shorter
            [{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}],  # longer
        ]
        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]

        rewards = len_reward(completions, solutions)
        self.assertGreater(rewards[0], rewards[1])  # shorter answer should get higher reward
        self.assertAlmostEqual(rewards[0], 0.5)  # shortest correct answer gets maximum reward

    def test_different_lengths_incorrect_answers(self):
        """Test len_reward with different length incorrect answers."""
        completions = [
            [{"content": r"\boxed{\frac{64}{400}}"}],  # shorter
            [{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}],  # longer
        ]
        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]

        rewards = len_reward(completions, solutions)
        self.assertLessEqual(rewards[0], 0.0)  # incorrect answers should get non-positive rewards
        self.assertLessEqual(rewards[1], 0.0)
        self.assertGreater(rewards[0], rewards[1])  # shorter answer should still be penalized less

    def test_mixed_correctness(self):
        """Test len_reward with mix of correct and incorrect answers of different lengths."""
        completions = [
            [{"content": r"\boxed{\frac{63}{400}}"}],  # correct, shorter
            [{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}],  # correct, longer
            [{"content": r"\boxed{\frac{64}{400}}"}],  # incorrect, shorter
            [{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}],  # incorrect, longer
        ]
        solutions = [r"\frac{63}{400}"] * 4

        rewards = len_reward(completions, solutions)

        # Shortest correct answer should get positive reward
        self.assertGreater(rewards[0], 0.0)

        # The longer correct answer scores below the shorter incorrect one,
        # but never below the equally long incorrect one.
        self.assertGreater(rewards[2], rewards[1])
        self.assertGreaterEqual(rewards[1], rewards[3])

        # Incorrect answers should get non-positive rewards
        self.assertLessEqual(rewards[2], 0.0)
        self.assertLessEqual(rewards[3], 0.0)

        # Shorter answers should get better rewards within their correctness category
        self.assertGreater(rewards[0], rewards[1])  # correct answers
        self.assertGreater(rewards[2], rewards[3])  # incorrect answers

    def test_unparseable_solution(self):
        """Test len_reward with unparseable solution."""
        completions = [
            [{"content": r"\boxed{answer}"}],
            [{"content": r"\boxed{answer} " + "x" * 10}],
        ]
        solutions = ["unparseable_latex", "unparseable_latex"]

        rewards = len_reward(completions, solutions)
        self.assertGreater(rewards[0], rewards[1])  # shorter answer should still get better reward
        self.assertAlmostEqual(rewards[0], 0.5)  # treated as correct, shortest gets maximum reward
class TestRepetitionPenaltyReward(unittest.TestCase):
    """Unit tests for the repetition-penalty reward.

    The class also holds the ``tag_count_reward`` and soft-overlong-punishment
    tests that live in the same group.
    """

    def test_positive_max_penalty_raises_value_error(self):
        with self.assertRaises(ValueError):
            get_repetition_penalty_reward(ngram_size=2, max_penalty=1.0)
        with self.assertRaisesRegex(ValueError, "max_penalty 1.5 should not be positive"):
            get_repetition_penalty_reward(ngram_size=2, max_penalty=1.5)

    def test_no_repetition(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
        completions = [[{"content": "this is a test sentence"}]]
        rewards = reward_fn(completions)
        self.assertEqual(rewards, [0.0])

    def test_full_repetition(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
        completions = [[{"content": "this this this this this"}]]

        rewards = reward_fn(completions)
        # (1 - 1/4) * -1 = -0.75
        self.assertEqual(rewards, [-0.75])

    def test_partial_repetition(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
        completions = [[{"content": "this is a this is a test"}]]

        rewards = reward_fn(completions)
        # Unique 2-grams: (this, is), (is, a), (a, this), (a, test). 4 unique out of 6 total
        # (1 - 4/6) * -1 = -1/3 = -0.3333...
        self.assertAlmostEqual(rewards[0], -1 / 3)

    def test_multiple_completions(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)
        completions = [
            [{"content": "this is a test"}],
            [{"content": "test test test test"}],
        ]

        rewards = reward_fn(completions)
        # Completion 1: (this, is, a), (is, a, test) -> 2 unique / 2 total -> (1 - 2/2) * -0.5 = 0
        # Completion 2: (test, test, test) -> 1 unique / 2 total -> (1 - 1/2) * -0.5 = -0.25
        self.assertAlmostEqual(rewards[0], 0.0)
        self.assertAlmostEqual(rewards[1], -0.25)

    def test_empty_completion(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
        completions = [[{"content": ""}]]
        rewards = reward_fn(completions)
        self.assertEqual(rewards, [0.0])

    def test_different_ngram_size(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-2.0)
        completions = [[{"content": "this is a this is a test"}]]

        rewards = reward_fn(completions)
        self.assertAlmostEqual(rewards[0], -0.4)

    def test_mixed_case(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
        completions = [
            [{"content": "This is A Test"}],
            [{"content": "this IS a test"}],
        ]

        rewards = reward_fn(completions)
        # both completions should produce the same reward, because the text gets lowercased
        self.assertAlmostEqual(rewards[0], rewards[1])

    def test_one_word_completion(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
        completions = [[{"content": "word"}]]

        rewards = reward_fn(completions)
        self.assertEqual(rewards, [0.0])

    def test_two_word_completion(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
        completions = [[{"content": "two words"}]]

        rewards = reward_fn(completions)
        self.assertEqual(rewards, [0.0])

    def test_three_word_completion(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
        completions = [[{"content": "three different words"}]]

        rewards = reward_fn(completions)
        self.assertEqual(rewards, [0.0])

    def test_three_word_repetition_completion(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
        completions = [[{"content": "word word word word"}]]

        rewards = reward_fn(completions)
        self.assertEqual(rewards, [-0.5])

    def test_four_word_completion_with_repetition(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
        completions = [[{"content": "one two one two"}]]

        rewards = reward_fn(completions)
        # ngrams are (one two one) (two one two). unique is 2 and count is 2, therefore (1-1) * -1.
        self.assertEqual(rewards, [0.0])

    def test_five_word_completion_with_repetition(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)
        completions = [[{"content": "A B C A B"}]]

        rewards = reward_fn(completions)
        # (A B C) (B C A) (C A B). unique is 3. count is 3 (1-1) * -.5 = 0
        self.assertEqual(rewards, [0.0])

    def test_six_word_completion_with_repetition(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
        completions = [[{"content": "A B C A B C"}]]

        rewards = reward_fn(completions)
        self.assertEqual(rewards, [-0.25])

    def test_long_completion_with_repetition(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
        completions = [[{"content": "A B C A B C E F G A B C A B C"}]]
        rewards = reward_fn(completions)
        self.assertAlmostEqual(rewards[0], -0.3846, places=4)

    def test_long_completion_without_repetition(self):
        reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
        completions = [[{"content": "A B C D E F G H I J K L"}]]

        rewards = reward_fn(completions)
        self.assertEqual(rewards, [0.0])

    def test_tag_count_rewards_all_correct(self):
        """Test tag_count_reward with correct tags."""
        completion = [[{"content": "<think>\nSome reasoning\n</think>\n<answer>\nThe answer\n</answer>"}]]
        rewards = tag_count_reward(completion)
        self.assertEqual(rewards[0], 1.0)

    def test_tag_count_rewards_missing_think_begin(self):
        """Test tag_count_reward with missing <think> tag."""
        completion = [[{"content": "Some reasoning\n</think>\n<answer>\nThe answer\n</answer>"}]]
        rewards = tag_count_reward(completion)
        self.assertEqual(rewards[0], 0.75)

    def test_tag_count_rewards_missing_think_end(self):
        """Test tag_count_reward with missing </think> tag."""
        completion = [[{"content": "<think>\nSome reasoning\n<answer>\nThe answer\n</answer>"}]]
        rewards = tag_count_reward(completion)
        self.assertEqual(rewards[0], 0.75)

    def test_tag_count_rewards_missing_answer_begin(self):
        """Test tag_count_reward with missing <answer> tag."""
        completion = [[{"content": "<think>\nSome reasoning\n</think>\nThe answer\n</answer>"}]]
        rewards = tag_count_reward(completion)
        self.assertEqual(rewards[0], 0.75)

    def test_tag_count_rewards_missing_answer_end(self):
        """Test tag_count_reward with missing </answer> tag."""
        completion = [[{"content": "<think>\nSome reasoning\n</think>\n<answer>\nThe answer"}]]
        rewards = tag_count_reward(completion)
        self.assertEqual(rewards[0], 0.75)

    def test_tag_count_rewards_missing_all_tags(self):
        """Test tag_count_reward with missing all tags."""
        completion = [[{"content": "Some reasoning\nThe answer"}]]
        rewards = tag_count_reward(completion)
        self.assertEqual(rewards[0], 0.0)

    def test_full_repetition_with_language(self):
        """Test a fully repeated completion for both the "en" and "zh" language settings."""
        cases = [
            ("en", "that that that that that"),
            ("zh", "这个这个这个这个这个"),
        ]
        for language, content in cases:
            # subTest so a failure in one language does not hide the other.
            with self.subTest(language=language):
                reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0, language=language)
                completions = [[{"content": content}]]
                rewards = reward_fn(completions)
                self.assertEqual(rewards, [-0.75])

    def test_soft_overlong_punishment_short_completion(self):
        """Test soft overlong punishment reward function with a short completion."""
        # length 50, with max=100 and soft cache=20, reward should be 0.
        reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
        completion_ids = [[1] * 50]  # 50 <= 80
        rewards = reward_fn(completion_ids=completion_ids)
        self.assertEqual(rewards, [0])

    def test_soft_overlong_punishment_long_completion(self):
        """Test soft overlong punishment reward function with a longer than max completion."""
        # 110 > 100, reward should be -1.
        reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
        completion_ids = [[1] * 110]
        rewards = reward_fn(completion_ids)
        self.assertEqual(rewards, [-1])

    def test_soft_overlong_punishment_intermediate_completion(self):
        """Test soft overlong punishment reward function for intermediate length completion."""
        reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
        completion_ids = [[1] * 90]  # 90 is between 80 and 100
        rewards = reward_fn(completion_ids)
        self.assertAlmostEqual(rewards[0], -0.5, places=4)
class TestCodeFormat(unittest.TestCase):
    """Unit tests for the reward functions built by ``get_code_format_reward``."""

    def test_correct_python_format(self):
        """Test code format reward with correct Python format."""
        completion = [
            [
                {
                    "content": "<think>\nLet's solve this\nStep 1: First step\n</think>\n<answer>\n```python\ndef hello():\n print('world')\n```\n</answer>"
                }
            ]
        ]
        reward_fn = get_code_format_reward(language="python")
        rewards = reward_fn(completion)
        self.assertEqual(rewards[0], 1.0)

    def test_incorrect_formats(self):
        """Test code format reward with various incorrect formats."""
        incorrect_formats = [
            # Missing think/answer tags
            "```python\ndef hello():\n print('world')\n```",
            # Missing code block
            "<think>Some thinking</think><answer>Just plain text</answer>",
            # Wrong language
            "<think>Analysis</think><answer>```javascript\nconsole.log('hello');\n```</answer>",
            # Missing language identifier
            "<think>Analysis</think><answer>```\ndef hello(): pass\n```</answer>",
            # Wrong order of tags
            "<answer>```python\ndef hello(): pass\n```</answer><think>Analysis</think>",
        ]

        reward_fn = get_code_format_reward(language="python")
        for fmt in incorrect_formats:
            # subTest reports which malformed completion was wrongly accepted.
            with self.subTest(fmt=fmt):
                completion = [[{"content": fmt}]]
                rewards = reward_fn(completion)
                self.assertEqual(rewards[0], 0.0)

    def test_multiple_code_blocks(self):
        """Test format reward with multiple code blocks in think and answer sections."""
        completion = [
            [
                {
                    "content": "<think>\nHere's an example:\n```python\nx = 1\n```\nNow the solution:\n</think>\n<answer>\n```python\ndef solution():\n return 42\n```\n</answer>"
                }
            ]
        ]

        reward_fn = get_code_format_reward(language="python")
        rewards = reward_fn(completion)
        self.assertEqual(rewards[0], 1.0)

    def test_different_languages(self):
        """Test code format reward with different programming languages."""
        completion = [
            [
                {
                    "content": "<think>\nAnalysis\n</think>\n<answer>\n```javascript\nconsole.log('hello');\n```\n</answer>"
                }
            ]
        ]

        # Test with JavaScript
        js_reward_fn = get_code_format_reward(language="javascript")
        rewards = js_reward_fn(completion)
        self.assertEqual(rewards[0], 1.0)

        # Same completion should fail for Python
        py_reward_fn = get_code_format_reward(language="python")
        rewards = py_reward_fn(completion)
        self.assertEqual(rewards[0], 0.0)

    def test_multiline_code(self):
        """Test format reward with complex multiline code blocks."""
        completion = [
            [
                {
                    "content": "<think>\nHere's the analysis\n</think>\n<answer>\n```python\nclass Solution:\n def __init__(self):\n self.value = 42\n \n def get_value(self):\n return self.value\n```\n</answer>"
                }
            ]
        ]

        reward_fn = get_code_format_reward(language="python")
        rewards = reward_fn(completion)
        self.assertEqual(rewards[0], 1.0)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

View file

@ -1,129 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from dataclasses import asdict
from datasets import DatasetDict, load_dataset
from open_r1.configs import DatasetConfig, DatasetMixtureConfig, ScriptArguments
from open_r1.utils.data import get_dataset
class TestGetDataset(unittest.TestCase):
    """Tests for ``get_dataset`` with a single dataset and with dataset mixtures."""

    @classmethod
    def setUpClass(cls):
        # Load the reference dataset once; the tests compare split sizes against it.
        cls.dataset_name = "trl-internal-testing/zen"
        cls.dataset_config = "conversational_preference"
        cls.ref_dataset = load_dataset(cls.dataset_name, cls.dataset_config)

    def _dataset_config(self, split, columns=None, weight=None):
        """Build a ``DatasetConfig`` for the reference dataset with the given split, columns and weight."""
        return DatasetConfig(
            id=self.dataset_name,
            config=self.dataset_config,
            split=split,
            columns=columns,
            weight=weight,
        )

    def test_dataset_and_config_name(self):
        args = ScriptArguments(dataset_name=self.dataset_name, dataset_config=self.dataset_config)
        dataset = get_dataset(args)
        self.assertIsInstance(dataset, DatasetDict)
        self.assertIn("train", dataset)
        self.assertEqual(len(dataset["train"]), len(self.ref_dataset["train"]))

    def test_unweighted_mixture(self):
        """Mix train and test splits of the same dataset."""
        dataset_configs = [
            self._dataset_config("train"),
            self._dataset_config("test"),
        ]
        dataset_mixture = DatasetMixtureConfig(
            datasets=dataset_configs,
        )
        args = ScriptArguments(dataset_mixture=asdict(dataset_mixture))
        dataset = get_dataset(args)
        self.assertIsInstance(dataset, DatasetDict)
        self.assertIn("train", dataset)
        self.assertEqual(len(dataset["train"]), len(self.ref_dataset["train"]) + len(self.ref_dataset["test"]))

    def test_weighted_mixture(self):
        """Test loading a dataset mixture with weights."""
        dataset_configs = [
            self._dataset_config("train", weight=0.25),
            self._dataset_config("test", weight=0.5),
        ]
        dataset_mixture = DatasetMixtureConfig(
            datasets=dataset_configs,
        )
        args = ScriptArguments(dataset_mixture=asdict(dataset_mixture))
        dataset = get_dataset(args)
        self.assertIsInstance(dataset, DatasetDict)
        self.assertIn("train", dataset)
        # A quarter of the train split plus half of the test split (integer division).
        self.assertEqual(
            len(dataset["train"]), len(self.ref_dataset["train"]) // 4 + len(self.ref_dataset["test"]) // 2
        )

    def test_mixture_and_test_split(self):
        """Test loading a dataset mixture with test split."""
        dataset_configs = [
            self._dataset_config("train[:10]"),
        ]
        dataset_mixture = DatasetMixtureConfig(datasets=dataset_configs, test_split_size=0.2)
        args = ScriptArguments(dataset_name=None, dataset_mixture=asdict(dataset_mixture))
        dataset = get_dataset(args)
        self.assertIsInstance(dataset, DatasetDict)
        self.assertIn("train", dataset)
        self.assertIn("test", dataset)
        # 10 rows with test_split_size=0.2 are expected to split 8 / 2.
        self.assertEqual(len(dataset["train"]), 8)
        self.assertEqual(len(dataset["test"]), 2)

    def test_mixture_column_selection(self):
        """Test loading a dataset mixture with column selection."""
        dataset_configs = [
            self._dataset_config("train", columns=["prompt", "chosen"]),
        ]
        dataset_mixture = DatasetMixtureConfig(
            datasets=dataset_configs,
        )
        args = ScriptArguments(dataset_mixture=asdict(dataset_mixture))
        dataset = get_dataset(args)
        self.assertIsInstance(dataset, DatasetDict)
        self.assertIn("train", dataset)
        self.assertIn("prompt", dataset["train"].column_names)
        self.assertIn("chosen", dataset["train"].column_names)

    def test_mixture_with_mismatched_columns(self):
        """Test that a mixture whose datasets select different columns is rejected."""
        dataset_configs = [
            self._dataset_config("train", columns=["prompt"]),
            self._dataset_config("train", columns=["chosen"]),
        ]
        dataset_mixture = DatasetMixtureConfig(
            datasets=dataset_configs,
        )
        with self.assertRaises(ValueError) as context:
            _ = ScriptArguments(dataset_mixture=asdict(dataset_mixture))
        self.assertIn("Column names must be consistent", str(context.exception))

    def test_no_dataset_name_or_mixture(self):
        """Test that omitting both ``dataset_name`` and ``dataset_mixture`` is rejected."""
        with self.assertRaises(ValueError) as context:
            _ = ScriptArguments(dataset_name=None, dataset_mixture=None)
        self.assertIn("Either `dataset_name` or `dataset_mixture` must be provided", str(context.exception))
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

5
token_stat Normal file
View file

@ -0,0 +1,5 @@
max length: 3678
total user token: 52272359
total system token: 326634875
total answer token: 32224010
total token: 843473151