mirror of
https://github.com/huggingface/open-r1.git
synced 2026-06-24 01:54:06 +00:00
Compare commits
16 commits
main
...
gui-traini
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
31a99af5bb | ||
|
|
76c56724ee | ||
|
|
b5a27167f1 | ||
|
|
dca3a06ada | ||
|
|
803c468507 | ||
|
|
afbd97b1ec | ||
|
|
55c49d66c3 | ||
|
|
4c89c85fff | ||
|
|
2ef6b50ccd | ||
|
|
7852ddefc8 | ||
|
|
648a523325 | ||
|
|
342f8f7856 | ||
|
|
02819cf0ab | ||
|
|
4a55c49641 | ||
|
|
dfcaecc92c | ||
|
|
c13574e28a |
28 changed files with 3136 additions and 969 deletions
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
|
|
@ -16,9 +16,9 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup Python environment
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.10.10
|
||||
- name: Install dependencies
|
||||
|
|
|
|||
|
|
@ -299,7 +299,6 @@ Make sure your dataset contains a `verification_info` column with the following
|
|||
}
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
For example, to train a smol model on Python problems, start the vLLM server:
|
||||
|
||||
|
|
|
|||
61
inference/Qwen2.5-VL-3B-instruct.py
Normal file
61
inference/Qwen2.5-VL-3B-instruct.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from datasets import load_dataset
|
||||
|
||||
# default: Load the model on the available device(s)
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
"smolagents/Qwen2.5-VL-3B-Instruct-Agentic", torch_dtype="auto", device_map="auto"
|
||||
)
|
||||
|
||||
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
||||
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
# "Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
# torch_dtype=torch.bfloat16,
|
||||
# attn_implementation="flash_attention_2",
|
||||
# device_map="auto",
|
||||
# )
|
||||
|
||||
# default processer
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
|
||||
|
||||
# The default range for the number of visual tokens per image in the model is 4-16384.
|
||||
# You can set min_pixels and max_pixels according to your needs, such as a token range of 256-1280, to balance performance and cost.
|
||||
# min_pixels = 256*28*28
|
||||
# max_pixels = 1280*28*28
|
||||
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
|
||||
dataset = load_dataset("smolagents/aguvis-stage-2", "mind2web", split="train")
|
||||
|
||||
for example in dataset:
|
||||
messages = [
|
||||
{"role": "system", "content": example["system"]},
|
||||
{"role": "user", "content": [
|
||||
{"type": "image", "image": example["image"]},
|
||||
{"type": "text", "text": example["user"]}
|
||||
]},
|
||||
]
|
||||
break
|
||||
|
||||
# Preparation for inference
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Inference: Generation of the output
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=4096)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
||||
142
recipes/Qwen2.5-VL-3B-Instruct/sft/config_gui.yaml
Normal file
142
recipes/Qwen2.5-VL-3B-Instruct/sft/config_gui.yaml
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
# Model arguments
|
||||
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
|
||||
model_name_or_path: Qwen/Qwen2.5-VL-3B-Instruct
|
||||
vision_model: true
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: sdpa
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: smolagents/aguvis-stage-2
|
||||
dataset_num_proc: 48
|
||||
|
||||
#SFT hyperparam
|
||||
max_length: 4096
|
||||
optim: adamw_torch
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_grad_norm: 0.2
|
||||
warmup_ratio: 0.03
|
||||
learning_rate: 2.0e-05
|
||||
gradient_accumulation_steps: 16
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 4 # Change this depending on the context length of the model to keep a 500M GBS.
|
||||
|
||||
# Image resize arguments
|
||||
image_resize:
|
||||
factor: 28
|
||||
min_pixels: 200704
|
||||
max_pixels: 1003520
|
||||
|
||||
# SFT trainer config
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
bf16: true
|
||||
do_eval: true
|
||||
eval_strategy: 'steps'
|
||||
eval_steps: 100
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: A-Mahla/Qwen2.5-VL-3B-Instruct-Agentic-GUI
|
||||
hub_strategy: end
|
||||
push_to_hub: true
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
output_dir: /fsx/amir_mahla/smolagents-Qwen2.5-VL-3B-Instruct-Agentic
|
||||
overwrite_output_dir: true
|
||||
report_to:
|
||||
- wandb
|
||||
wandb_project: smolagents
|
||||
save_strategy: "epoch"
|
||||
save_steps: 1
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
|
||||
dataset_mixture:
|
||||
datasets: # List of datasets to include in the mixture
|
||||
- id: smolagents/aguvis-stage-2 # Hub dataset ID
|
||||
config: mind2web # Name of the dataset config
|
||||
split: train # Split to use from the dataset
|
||||
columns: # Columns to keep
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: guiact-web-single
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: guiact-web-multi
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: miniwob
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: coat
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: android_control
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: gui-odyssey
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: amex
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: aitw
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
seed: 42 # Seed for shuffling the combined dataset
|
||||
test_split_size: 0.01
|
||||
116
recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_1152.yaml
Normal file
116
recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_1152.yaml
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
# Model arguments
|
||||
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
|
||||
model_name_or_path: HuggingFaceTB/SmolVLM2-2.2B-Instruct
|
||||
vision_model: true
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: sdpa
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: smolagents/aguvis-stage-2
|
||||
dataset_num_proc: 48
|
||||
|
||||
#SFT hyperparam
|
||||
max_length: 4096
|
||||
optim: adamw_torch
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_grad_norm: 0.2
|
||||
warmup_ratio: 0.03
|
||||
learning_rate: 2.0e-05
|
||||
gradient_accumulation_steps: 32
|
||||
per_device_eval_batch_size: 2
|
||||
per_device_train_batch_size: 2 # Change this depending on the context length of the model to keep a 500M GBS.
|
||||
|
||||
image_resize:
|
||||
resolution_max_side: 1152
|
||||
to_pixel_coordinates: true
|
||||
|
||||
# SFT trainer config
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: 'steps'
|
||||
eval_steps: 100
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI
|
||||
hub_model_revision: main
|
||||
hub_strategy: end
|
||||
push_to_hub: false
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
output_dir: /fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-phase-1-max-size-1152-pixel-coordinates
|
||||
overwrite_output_dir: true
|
||||
report_to:
|
||||
- wandb
|
||||
wandb_project: smolagents
|
||||
save_strategy: steps
|
||||
save_steps: 800
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
|
||||
dataset_mixture:
|
||||
datasets: # List of datasets to include in the mixture
|
||||
- id: smolagents/aguvis-stage-1 # Hub dataset ID
|
||||
config: guienv # Name of the dataset config
|
||||
split: train # Split to use from the dataset
|
||||
columns: # Columns to keep
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: omniact
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: ricoig16k
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: ricosca
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: seeclick
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: ui_refexp
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: webui350k
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: widget_captioning
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
seed: 42 # Seed for shuffling the combined dataset
|
||||
test_split_size: 0.007
|
||||
116
recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_384.yaml
Normal file
116
recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_384.yaml
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
# Model arguments
|
||||
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
|
||||
model_name_or_path: HuggingFaceTB/SmolVLM2-2.2B-Instruct
|
||||
vision_model: true
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: sdpa
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: smolagents/aguvis-stage-2
|
||||
dataset_num_proc: 48
|
||||
|
||||
#SFT hyperparam
|
||||
max_length: 4096
|
||||
optim: adamw_torch
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_grad_norm: 0.2
|
||||
warmup_ratio: 0.03
|
||||
learning_rate: 2.0e-05
|
||||
gradient_accumulation_steps: 32
|
||||
per_device_eval_batch_size: 2
|
||||
per_device_train_batch_size: 2 # Change this depending on the context length of the model to keep a 500M GBS.
|
||||
|
||||
image_resize:
|
||||
resolution_max_side: 384
|
||||
to_pixel_coordinates: true
|
||||
|
||||
# SFT trainer config
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: 'steps'
|
||||
eval_steps: 100
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI
|
||||
hub_model_revision: main
|
||||
hub_strategy: end
|
||||
push_to_hub: false
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
output_dir: /fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-phase-1-max-size-384-pixel-coordinates
|
||||
overwrite_output_dir: true
|
||||
report_to:
|
||||
- wandb
|
||||
wandb_project: smolagents
|
||||
save_strategy: steps
|
||||
save_steps: 800
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
|
||||
dataset_mixture:
|
||||
datasets: # List of datasets to include in the mixture
|
||||
- id: smolagents/aguvis-stage-1 # Hub dataset ID
|
||||
config: guienv # Name of the dataset config
|
||||
split: train # Split to use from the dataset
|
||||
columns: # Columns to keep
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: omniact
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: ricoig16k
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: ricosca
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: seeclick
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: ui_refexp
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: webui350k
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: widget_captioning
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
seed: 42 # Seed for shuffling the combined dataset
|
||||
test_split_size: 0.007
|
||||
116
recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_764.yaml
Normal file
116
recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_1_764.yaml
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
# Model arguments
|
||||
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
|
||||
model_name_or_path: HuggingFaceTB/SmolVLM2-2.2B-Instruct
|
||||
vision_model: true
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: sdpa
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: smolagents/aguvis-stage-2
|
||||
dataset_num_proc: 48
|
||||
|
||||
#SFT hyperparam
|
||||
max_length: 4096
|
||||
optim: adamw_torch
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_grad_norm: 0.2
|
||||
warmup_ratio: 0.03
|
||||
learning_rate: 2.0e-05
|
||||
gradient_accumulation_steps: 32
|
||||
per_device_eval_batch_size: 2
|
||||
per_device_train_batch_size: 2 # Change this depending on the context length of the model to keep a 500M GBS.
|
||||
|
||||
image_resize:
|
||||
resolution_max_side: 764
|
||||
to_pixel_coordinates: true
|
||||
|
||||
# SFT trainer config
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: 'steps'
|
||||
eval_steps: 100
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI
|
||||
hub_model_revision: main
|
||||
hub_strategy: end
|
||||
push_to_hub: false
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
output_dir: /fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-phase-1-max-size-764-pixel-coordinates
|
||||
overwrite_output_dir: true
|
||||
report_to:
|
||||
- wandb
|
||||
wandb_project: smolagents
|
||||
save_strategy: steps
|
||||
save_steps: 800
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
|
||||
dataset_mixture:
|
||||
datasets: # List of datasets to include in the mixture
|
||||
- id: smolagents/aguvis-stage-1 # Hub dataset ID
|
||||
config: guienv # Name of the dataset config
|
||||
split: train # Split to use from the dataset
|
||||
columns: # Columns to keep
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: omniact
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: ricoig16k
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: ricosca
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: seeclick
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: ui_refexp
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: webui350k
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-1
|
||||
config: widget_captioning
|
||||
split: train
|
||||
columns:
|
||||
- images
|
||||
- texts
|
||||
weight: 1.
|
||||
seed: 42 # Seed for shuffling the combined dataset
|
||||
test_split_size: 0.007
|
||||
137
recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_2.yaml
Normal file
137
recipes/SmolVLM2-2.2B-Instruct/sft/config_gui_phase_2.yaml
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
# Model arguments
|
||||
# You can download the model and manually change the rope to 300k/500k and max_position_embeddings to 32768
|
||||
model_name_or_path: HuggingFaceTB/SmolVLM2-2.2B-Instruct
|
||||
vision_model: true
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: sdpa
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: smolagents/aguvis-stage-2
|
||||
dataset_num_proc: 48
|
||||
|
||||
#SFT hyperparam
|
||||
max_length: 4096
|
||||
optim: adamw_torch
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_grad_norm: 0.2
|
||||
warmup_ratio: 0.03
|
||||
learning_rate: 2.0e-05
|
||||
gradient_accumulation_steps: 16
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 4 # Change this depending on the context length of the model to keep a 500M GBS.
|
||||
|
||||
# SFT trainer config
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
bf16: true
|
||||
do_eval: true
|
||||
eval_strategy: 'steps'
|
||||
eval_steps: 100
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI
|
||||
hub_model_revision: max-resolution-1152-without-system
|
||||
hub_strategy: end
|
||||
push_to_hub: true
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
output_dir: /fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-max-resolution-1152-without-system
|
||||
overwrite_output_dir: true
|
||||
report_to:
|
||||
- wandb
|
||||
wandb_project: smolagents
|
||||
save_strategy: "epoch"
|
||||
save_steps: 1
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
|
||||
dataset_mixture:
|
||||
datasets: # List of datasets to include in the mixture
|
||||
- id: smolagents/aguvis-stage-2 # Hub dataset ID
|
||||
config: mind2web # Name of the dataset config
|
||||
split: train # Split to use from the dataset
|
||||
columns: # Columns to keep
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: guiact-web-single
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: guiact-web-multi
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: miniwob
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: coat
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: android_control
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: gui-odyssey
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: amex
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
- id: smolagents/aguvis-stage-2
|
||||
config: aitw
|
||||
split: train
|
||||
columns:
|
||||
- system
|
||||
- user
|
||||
- assistant
|
||||
- image
|
||||
weight: 1.
|
||||
seed: 42 # Seed for shuffling the combined dataset
|
||||
test_split_size: 0.01
|
||||
194
scripts/agents/action_conversion.py
Normal file
194
scripts/agents/action_conversion.py
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
from function_parser import FunctionCall
|
||||
from copy import deepcopy
|
||||
|
||||
# from aguvis aguvis action space to custom action space:
|
||||
|
||||
# mobile.home() -> navigate_home()
|
||||
# mobile.open_app(app_name='drupe') -> open_app(app_name: str) -> str:
|
||||
# mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518]) -> swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])
|
||||
# mobile.back() -> navigate_back()
|
||||
# mobile.long_press(x=0.799, y=0.911) -> long_press(x, y)
|
||||
# mobile.terminate(status='success') -> final_answer(answer: str)
|
||||
|
||||
# answer('text') -> final_answer('text') OK
|
||||
# mobile.wait(seconds=3) -> wait(seconds=3) OK
|
||||
# pyautogui.hscroll(page=-0.1)
|
||||
# ?
|
||||
# pyautogui.scroll(page=-0.1) or pyautogui.scroll(0.13) OK
|
||||
# -> negative: scroll(direction: Literal["up", "down"] = "up", amount: int = abs(page * 10))
|
||||
# -> positive: scroll(direction: Literal["up", "down"] = "down", amount: int = abs(page * 10))
|
||||
# pyautogui.click(x=0.8102, y=0.9463) -> click(x: int, y: int) OK
|
||||
# pyautogui.doubleClick() -> double_click() OK
|
||||
# pyautogui.hotkey(keys=['ctrl', 'c']) -> press(keys: str | list) OK
|
||||
# pyautogui.press(keys='enter') or pyautogui.press(keys=['enter']) -> press(keys: str | list) OK
|
||||
# pyautogui.moveTo(x=0.04, y=0.405) -> move_mouse(x: int, y: int) OK
|
||||
# pyautogui.write(message='bread buns') -> type(text: str) OK
|
||||
# pyautogui.dragTo(x=0.8102, y=0.9463) -> drag(x1, y1, x2, y2) OK but to recheck formatage in official dataset
|
||||
|
||||
|
||||
def convert_to_pixel_coordinates(action: FunctionCall, resolution: tuple[int, int]) -> None:
|
||||
if "arg_0" in action.parameters:
|
||||
if isinstance(action.parameters["arg_0"], (list, tuple)):
|
||||
action.parameters["from_coord"] = (int(action.parameters["arg_0"][0] * resolution[0]), int(action.parameters["arg_0"][1] * resolution[1]))
|
||||
else:
|
||||
action.parameters["x"] = int(action.parameters["arg_0"] * resolution[0])
|
||||
del action.parameters["arg_0"]
|
||||
if "arg_1" in action.parameters:
|
||||
if isinstance(action.parameters["arg_1"], (list, tuple)):
|
||||
action.parameters["to_coord"] = (int(action.parameters["arg_1"][0] * resolution[0]), int(action.parameters["arg_1"][1] * resolution[1]))
|
||||
else:
|
||||
action.parameters["y"] = int(action.parameters["arg_1"] * resolution[1])
|
||||
del action.parameters["arg_1"]
|
||||
|
||||
def change_argument_name(action: FunctionCall) -> None:
|
||||
if "arg_0" in action.parameters:
|
||||
if isinstance(action.parameters["arg_0"], (list, tuple)):
|
||||
action.parameters["from_coord"] = (float(action.parameters["arg_0"][0]), float(action.parameters["arg_0"][1]))
|
||||
else:
|
||||
action.parameters["x"] = float(action.parameters["arg_0"])
|
||||
del action.parameters["arg_0"]
|
||||
if "arg_1" in action.parameters:
|
||||
if isinstance(action.parameters["arg_1"], (list, tuple)):
|
||||
action.parameters["to_coord"] = (float(action.parameters["arg_1"][0]), float(action.parameters["arg_1"][1]))
|
||||
else:
|
||||
action.parameters["y"] = float(action.parameters["arg_1"])
|
||||
del action.parameters["arg_1"]
|
||||
|
||||
|
||||
def rename_parameters(action: FunctionCall) -> None:
|
||||
"""
|
||||
Reorder FunctionCall parameters to use arg_0, arg_1, arg_2, etc. as keys.
|
||||
Preserves the order of the original parameters.
|
||||
|
||||
Args:
|
||||
action: FunctionCall object to reorder parameters for
|
||||
|
||||
"""
|
||||
if not action.parameters:
|
||||
return
|
||||
|
||||
for i, (key, value) in enumerate(deepcopy(action.parameters).items()):
|
||||
tmp = value
|
||||
del action.parameters[key]
|
||||
action.parameters[f"arg_{i}"] = tmp
|
||||
|
||||
|
||||
|
||||
def action_conversion(
|
||||
actions: list[FunctionCall], resolution: tuple[int, int]
|
||||
) -> list[FunctionCall]:
|
||||
for i, action in enumerate(actions):
|
||||
rename_parameters(action)
|
||||
# MOBILE ACTIONS
|
||||
if action.function_name == "mobile.home":
|
||||
actions[i].function_name = "navigate_home"
|
||||
|
||||
elif action.function_name == "mobile.open_app":
|
||||
actions[i].function_name = "open_app"
|
||||
|
||||
elif action.function_name == "mobile.swipe":
|
||||
actions[i].function_name = "swipe"
|
||||
change_argument_name(actions[i])
|
||||
|
||||
elif action.function_name == "mobile.back":
|
||||
actions[i].function_name = "navigate_back"
|
||||
|
||||
elif action.function_name == "mobile.long_press":
|
||||
actions[i].function_name = "long_press"
|
||||
change_argument_name(actions[i])
|
||||
|
||||
elif action.function_name in ["mobile.terminate", "answer"]:
|
||||
actions[i].function_name = "final_answer"
|
||||
|
||||
elif action.function_name == "mobile.wait":
|
||||
actions[i].function_name = "wait"
|
||||
if "arg_0" in actions[i].parameters:
|
||||
actions[i].parameters["seconds"] = int(actions[i].parameters["arg_0"])
|
||||
del actions[i].parameters["arg_0"]
|
||||
|
||||
# OS ACTION
|
||||
elif action.function_name == "pyautogui.click":
|
||||
actions[i].function_name = "click"
|
||||
change_argument_name(actions[i])
|
||||
|
||||
elif action.function_name == "pyautogui.doubleClick":
|
||||
actions[i].function_name = "double_click"
|
||||
change_argument_name(actions[i])
|
||||
|
||||
elif action.function_name == "pyautogui.rightClick":
|
||||
actions[i].function_name = "right_click"
|
||||
change_argument_name(actions[i])
|
||||
|
||||
elif action.function_name in ["pyautogui.hotkey", "pyautogui.press"]:
|
||||
actions[i].function_name = "press"
|
||||
if "arg_0" in actions[i].parameters:
|
||||
actions[i].parameters["keys"] = actions[i].parameters["arg_0"]
|
||||
del actions[i].parameters["arg_0"]
|
||||
|
||||
elif action.function_name == "pyautogui.moveTo":
|
||||
actions[i].function_name = "move_mouse"
|
||||
change_argument_name(actions[i])
|
||||
|
||||
elif action.function_name == "pyautogui.write":
|
||||
actions[i].function_name = "type"
|
||||
|
||||
elif action.function_name in ["pyautogui.scroll", "pyautogui.hscroll"]:
|
||||
arg_value = actions[i].parameters["arg_0"]
|
||||
if arg_value < 0:
|
||||
if action.function_name == "pyautogui.hscroll":
|
||||
actions[i].parameters["direction"] = "left"
|
||||
else:
|
||||
actions[i].parameters["direction"] = "up"
|
||||
else:
|
||||
if action.function_name == "pyautogui.hscroll":
|
||||
actions[i].parameters["direction"] = "right"
|
||||
else:
|
||||
actions[i].parameters["direction"] = "down"
|
||||
del actions[i].parameters["arg_0"]
|
||||
actions[i].function_name = "scroll"
|
||||
actions[i].parameters["amount"] = int(abs(arg_value * 100))
|
||||
|
||||
elif action.function_name == "pyautogui.dragTo":
|
||||
actions[i].function_name = "drag"
|
||||
change_argument_name(actions[i])
|
||||
|
||||
else:
|
||||
ValueError("Error FonctionCall Formatting")
|
||||
|
||||
actions[i].original_string = actions[i].to_string()
|
||||
|
||||
return actions
|
||||
|
||||
if __name__ == "__main__":
|
||||
from function_parser import FunctionCall
|
||||
|
||||
# Example actions for all function types
|
||||
actions = [
|
||||
# MOBILE ACTIONS
|
||||
FunctionCall("mobile.home", {}, "mobile.home()"),
|
||||
FunctionCall("mobile.open_app", {"app_name": "drupe"}, "mobile.open_app(app_name='drupe')"),
|
||||
FunctionCall("mobile.swipe", {"from_coord": [0.581, 0.898], "to_coord": [0.601, 0.518]}, "mobile.swipe(from_coord=[0.581,0.898],to_coord=[0.601,0.518])"),
|
||||
FunctionCall("mobile.back", {}, "mobile.back()"),
|
||||
FunctionCall("mobile.long_press", {"x": 0.799, "y": 0.911}, "mobile.long_press(x=0.799, y=0.911)"),
|
||||
FunctionCall("mobile.terminate", {"status": "success"}, "mobile.terminate(status='success')"),
|
||||
FunctionCall("answer", {"arg_0": "text"}, "answer('text')"),
|
||||
FunctionCall("mobile.wait", {"seconds": 3}, "mobile.wait(seconds=3)"),
|
||||
# OS ACTIONS
|
||||
FunctionCall("pyautogui.hscroll", {"page": -0.1}, "pyautogui.hscroll(page=-0.1)"),
|
||||
FunctionCall("pyautogui.scroll", {"page": 0.13}, "pyautogui.scroll(page=0.13)"),
|
||||
FunctionCall("pyautogui.click", {"x": 0.8102, "y": 0.9463}, "pyautogui.click(x=0.8102, y=0.9463)"),
|
||||
FunctionCall("pyautogui.doubleClick", {}, "pyautogui.doubleClick()"),
|
||||
FunctionCall("pyautogui.hotkey", {"keys": ["ctrl", "c"]}, "pyautogui.hotkey(keys=['ctrl','c'])"),
|
||||
FunctionCall("pyautogui.press", {"keys": "enter"}, "pyautogui.press(keys='enter')"),
|
||||
FunctionCall("pyautogui.moveTo", {"x": 0.04, "y": 0.405}, "pyautogui.moveTo(x=0.04, y=0.405)"),
|
||||
FunctionCall("pyautogui.write", {"message": "bread buns"}, "pyautogui.write(message='bread buns')"),
|
||||
FunctionCall("pyautogui.dragTo", {"from_coord": [0.87, 0.423], "to_coord": [0.8102, 0.9463]}, "pyautogui.dragTo(from_coord=[0.87, 0.423], to_coord=[0.8102, 0.9463])"),
|
||||
]
|
||||
resolution = (1080, 1920)
|
||||
print("Before conversion:")
|
||||
for action in actions:
|
||||
print(action)
|
||||
print("\nAfter conversion:")
|
||||
converted = action_conversion(actions, resolution)
|
||||
for action in converted:
|
||||
print(action)
|
||||
231
scripts/agents/config.py
Normal file
231
scripts/agents/config.py
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
|
||||
# aguvis json file with mobile action space
|
||||
MOBILE_FILE = [
|
||||
"android_control.json",
|
||||
"gui-odyssey-l1.json",
|
||||
"aitw-l3.json",
|
||||
"coat.jsonamex-l2.json",
|
||||
"amex-l1.json",
|
||||
"amex-l3.json",
|
||||
"gui-odyssey-l3.json",
|
||||
"aitw-l1.json",
|
||||
"aitw-l2.json",
|
||||
"gui-odyssey-l2.json",
|
||||
]
|
||||
|
||||
# Processing: guienv
|
||||
# Max conversations by image: 5 conversations
|
||||
# Duplicates: 0
|
||||
# Duplicate images in guienv.json. difference: 257578
|
||||
# len(images_path): 327972
|
||||
# len(images_set_path): 70394
|
||||
# user/assistant by image: 3.6590902633747193
|
||||
#
|
||||
# Processing: omniact
|
||||
# Max conversations by image: 0 conversations
|
||||
# Duplicates: 0
|
||||
# No duplicate images in omniact.json
|
||||
# len(images_path): 6720
|
||||
# len(images_set_path): 6720
|
||||
# user/assistant by image: 0.0
|
||||
#
|
||||
# Processing: ricoig16k
|
||||
# Max conversations by image: 0 conversations
|
||||
# Duplicates: 0
|
||||
# No duplicate images in ricoig16k.json
|
||||
# len(images_path): 16133
|
||||
# len(images_set_path): 16133
|
||||
# user/assistant by image: 0.0
|
||||
#
|
||||
# Processing: ricosca
|
||||
# Max conversations by image: 20 conversations
|
||||
# Duplicates: 0
|
||||
# Duplicate images in ricosca.json. difference: 155066
|
||||
# len(images_path): 173212
|
||||
# len(images_set_path): 18146
|
||||
# user/assistant by image: 8.54546456519343
|
||||
#
|
||||
# Processing: seeclick
|
||||
# Max conversations by image: 0 conversations
|
||||
# Duplicates: 0
|
||||
# No duplicate images in seeclick.json
|
||||
# len(images_path): 271121
|
||||
# len(images_set_path): 271121
|
||||
# user/assistant by image: 0.0
|
||||
#
|
||||
# Processing: webui350k
|
||||
# Max conversations by image: 0 conversations
|
||||
# Duplicates: 0
|
||||
# No duplicate images in webui350k.json
|
||||
# len(images_path): 57389
|
||||
# len(images_set_path): 57389
|
||||
# user/assistant by image: 0.0
|
||||
#
|
||||
# Processing: ui_refexp
|
||||
# Max conversations by image: 15 conversations
|
||||
# Duplicates: 32
|
||||
# Duplicate images in ui_refexp.json. difference: 10978
|
||||
# len(images_path): 15624
|
||||
# len(images_set_path): 4646
|
||||
# user/assistant by image: 2.3628928110202323
|
||||
#
|
||||
# Processing: widget_captioning
|
||||
# Max conversations by image: 161 conversations
|
||||
# Duplicates: 4877
|
||||
# Duplicate images in widget_captioning.json. difference: 87017
|
||||
# len(images_path): 101426
|
||||
# len(images_set_path): 14409
|
||||
# user/assistant by image: 6.039072801721146
|
||||
#
|
||||
# total_samples = 458958
|
||||
|
||||
config_dict_stage_1 = [
|
||||
{
|
||||
"json_path": "guienv.json",
|
||||
"images_folder": "guienvs/images/",
|
||||
},
|
||||
{
|
||||
"json_path": "omniact.json",
|
||||
"images_folder": "omniact/images/",
|
||||
},
|
||||
{
|
||||
"json_path": "ricoig16k.json",
|
||||
"images_folder": "ricoig16k/images/",
|
||||
},
|
||||
{
|
||||
"json_path": "ricosca.json",
|
||||
"images_folder": "ricosca/images/",
|
||||
},
|
||||
{
|
||||
"json_path": "seeclick.json",
|
||||
"images_folder": "seeclick/seeclick_web_imgs/",
|
||||
},
|
||||
{
|
||||
"json_path": "webui350k.json",
|
||||
"images_folder": "webui350k/images/",
|
||||
},
|
||||
{
|
||||
"json_path": "ui_refexp.json",
|
||||
"images_folder": "ui_refexp/images/",
|
||||
},
|
||||
{
|
||||
"json_path": "widget_captioning.json",
|
||||
"images_folder": "widget_captioning/images/",
|
||||
},
|
||||
|
||||
]
|
||||
|
||||
|
||||
# Processing: mind2web-l3
|
||||
# Max conversations by image: 0 conversations
|
||||
# Duplicates: 0
|
||||
# No duplicate images in mind2web-l3.json
|
||||
# len(images_path): 7591
|
||||
# len(images_set_path): 7591
|
||||
# user/assistant by image: 0.0
|
||||
#
|
||||
# Processing: guiact-web-single
|
||||
# Max conversations by image: 12 conversations
|
||||
# Duplicates: 0
|
||||
# Duplicate images in guiact-web-single.json. difference: 54134
|
||||
# len(images_path): 67396
|
||||
# len(images_set_path): 13262
|
||||
# user/assistant by image: 4.081888101342181
|
||||
#
|
||||
# Processing: guiact-web-multi-l3
|
||||
# Max conversations by image: 2 conversations
|
||||
# Duplicates: 0
|
||||
# Duplicate images in guiact-web-multi-l3.json. difference: 24
|
||||
# len(images_path): 16704
|
||||
# len(images_set_path): 16680
|
||||
# user/assistant by image: 0.0014388489208633094
|
||||
#
|
||||
# Processing: miniwob-l3
|
||||
# Max conversations by image: 6 conversations
|
||||
# Duplicates: 0
|
||||
# Duplicate images in miniwob-l3.json. difference: 161
|
||||
# len(images_path): 9826
|
||||
# len(images_set_path): 9665
|
||||
# user/assistant by image: 0.016658044490429385
|
||||
#
|
||||
# Processing: coat
|
||||
# Max conversations by image: 0 conversations
|
||||
# Duplicates: 0
|
||||
# No duplicate images in coat.json
|
||||
# len(images_path): 11921
|
||||
# len(images_set_path): 11921
|
||||
# user/assistant by image: 0.0
|
||||
#
|
||||
# Processing: android_control
|
||||
# Max conversations by image: 0 conversations
|
||||
# Duplicates: 0
|
||||
# No duplicate images in android_control.json
|
||||
# len(images_path): 74714
|
||||
# len(images_set_path): 74714
|
||||
# user/assistant by image: 0.0
|
||||
#
|
||||
# Processing: gui-odyssey-l3
|
||||
# Max conversations by image: 2 conversations
|
||||
# Duplicates: 0
|
||||
# Duplicate images in gui-odyssey-l3.json. difference: 24
|
||||
# len(images_path): 118282
|
||||
# len(images_set_path): 118258
|
||||
# user/assistant by image: 0.0002029461008980365
|
||||
#
|
||||
# Processing: amex-l3
|
||||
# Max conversations by image: 0 conversations
|
||||
# Duplicates: 0
|
||||
# No duplicate images in amex-l3.json
|
||||
# len(images_path): 38469
|
||||
# len(images_set_path): 38469
|
||||
# user/assistant by image: 0.0
|
||||
#
|
||||
# Processing: aitw-l3
|
||||
# Max conversations by image: 0 conversations
|
||||
# Duplicates: 0
|
||||
# No duplicate images in aitw-l3.json
|
||||
# len(images_path): 18992
|
||||
# len(images_set_path): 18992
|
||||
# user/assistant by image: 0.0
|
||||
#
|
||||
# Total samples: 309552
|
||||
|
||||
|
||||
config_dict_stage_2 = [
|
||||
{
|
||||
"json_path": "mind2web-l3.json",
|
||||
"images_folder": "mind2web/",
|
||||
},
|
||||
{
|
||||
"json_path": "guiact-web-single.json",
|
||||
"images_folder": "guiact-web-single/images/",
|
||||
},
|
||||
{
|
||||
"json_path": "guiact-web-multi-l3.json",
|
||||
"images_folder": "guiact-web-multi-v2/images",
|
||||
},
|
||||
{
|
||||
"json_path": "miniwob-l3.json",
|
||||
"images_folder": "images",
|
||||
},
|
||||
{
|
||||
"json_path": "coat.json",
|
||||
"images_folder": "coat/images/",
|
||||
},
|
||||
{
|
||||
"json_path": "android_control.json",
|
||||
"images_folder": "android_control/images/",
|
||||
},
|
||||
{
|
||||
"json_path": "gui-odyssey-l3.json",
|
||||
"images_folder": "gui-odyssey/images/",
|
||||
},
|
||||
{
|
||||
"json_path": "amex-l3.json",
|
||||
"images_folder": "amex/images/",
|
||||
},
|
||||
{
|
||||
"json_path": "aitw-l3.json",
|
||||
"images_folder": "aitw-v1/images/",
|
||||
},
|
||||
]
|
||||
547
scripts/agents/function_parser.py
Normal file
547
scripts/agents/function_parser.py
Normal file
|
|
@ -0,0 +1,547 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Function parser for extracting function names, parameter names, and values from string function calls.
|
||||
Supports both mobile and pyautogui function patterns.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Tuple, Any, Union
|
||||
from dataclasses import dataclass
|
||||
from collections import OrderedDict
|
||||
|
||||
@dataclass
|
||||
class FunctionCall:
|
||||
"""Represents a parsed function call with its parameters."""
|
||||
function_name: str
|
||||
parameters: Dict[str, Any]
|
||||
original_string: str
|
||||
description: str = ""
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""
|
||||
Reconstruct the function call string from the parsed data.
|
||||
|
||||
Returns:
|
||||
String representation of the function call
|
||||
|
||||
Examples:
|
||||
>>> call = FunctionCall("mobile.wait", {"seconds": 3}, "mobile.wait(seconds=3)")
|
||||
>>> call.to_string()
|
||||
"mobile.wait(seconds=3)"
|
||||
|
||||
>>> call = FunctionCall("function", {"arg_0": 1, "arg_1": 2, "x": 0.5}, "function(1, 2, x=0.5)")
|
||||
>>> call.to_string()
|
||||
"function(1, 2, x=0.5)"
|
||||
"""
|
||||
if not self.parameters:
|
||||
return f"{self.function_name}()"
|
||||
|
||||
# Separate positional and named arguments
|
||||
positional_args = []
|
||||
named_args = []
|
||||
|
||||
for name, value in self.parameters.items():
|
||||
if name.startswith("arg_"):
|
||||
# Positional argument
|
||||
positional_args.append((int(name.split("_")[1]), value))
|
||||
else:
|
||||
# kwargs
|
||||
named_args.append((name, value))
|
||||
|
||||
# Sort positional arguments by index
|
||||
positional_args.sort(key=lambda x: x[0])
|
||||
|
||||
# Build parameter string
|
||||
param_parts = []
|
||||
|
||||
# Add positional arguments
|
||||
for _, value in positional_args:
|
||||
param_parts.append(self._value_to_string(value))
|
||||
|
||||
# Add named arguments
|
||||
for name, value in named_args:
|
||||
param_parts.append(f"{name}={self._value_to_string(value)}")
|
||||
|
||||
return f"{self.function_name}({', '.join(param_parts)})"
|
||||
|
||||
def _value_to_string(self, value: Any) -> str:
|
||||
"""
|
||||
Convert a value to its string representation for function calls.
|
||||
|
||||
Args:
|
||||
value: The value to convert
|
||||
|
||||
Returns:
|
||||
String representation of the value
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
# Quote strings
|
||||
return f"'{value}'"
|
||||
elif isinstance(value, (list, tuple)):
|
||||
# Convert lists/tuples to string representation
|
||||
items = [self._value_to_string(item) for item in value]
|
||||
return f"[{', '.join(items)}]"
|
||||
elif isinstance(value, dict):
|
||||
# Convert dictionaries to string representation
|
||||
items = [f"'{k}': {self._value_to_string(v)}" for k, v in value.items()]
|
||||
return f"{{{', '.join(items)}}}"
|
||||
elif isinstance(value, bool):
|
||||
# Convert booleans to lowercase
|
||||
return str(value).lower()
|
||||
elif value is None:
|
||||
return "None"
|
||||
else:
|
||||
# Numbers and other types
|
||||
return str(value)
|
||||
|
||||
|
||||
def parse_function_call(function_string: str, pattern_to_match: list[str] = []) -> List[FunctionCall]:
|
||||
"""
|
||||
Parse a function call string and extract all function calls found.
|
||||
|
||||
Args:
|
||||
function_string: String representation of function calls
|
||||
|
||||
Returns:
|
||||
List of FunctionCall objects with parsed information
|
||||
|
||||
Examples:
|
||||
>>> parse_function_call("mobile.wait(seconds=3)")
|
||||
[FunctionCall(function_name='wait', parameters={'seconds': 3}, ...)]
|
||||
|
||||
>>> parse_function_call("mobile. wait(seconds=3)")
|
||||
[FunctionCall(function_name='wait', parameters={'seconds': 3}, ...)]
|
||||
|
||||
>>> parse_function_call("mobile.wait(seconds=3) mobile.home()")
|
||||
[FunctionCall(function_name='wait', parameters={'seconds': 3}, ...), FunctionCall(function_name='home', parameters={}, ...)]
|
||||
"""
|
||||
# Remove any leading/trailing whitespace
|
||||
function_string = function_string.strip()
|
||||
|
||||
# Pattern to match function calls with parameters
|
||||
# Matches: function_name(param1=value1, param2=value2, ...)
|
||||
# Can have any characters before the function call, extracts just the function name
|
||||
pattern = r'.*?([a-zA-Z_][a-zA-Z0-9_.]*)\(([^)]*)\)'
|
||||
|
||||
matches = re.findall(pattern, function_string)
|
||||
if not matches:
|
||||
# No valid function calls found in: {function_string}
|
||||
return []
|
||||
|
||||
results = []
|
||||
for match in matches:
|
||||
function_name = match[0]
|
||||
params_string = match[1]
|
||||
|
||||
if pattern_to_match and all(pattern not in function_name for pattern in pattern_to_match):
|
||||
continue
|
||||
|
||||
# Parse parameters
|
||||
parameters = parse_parameters(params_string)
|
||||
|
||||
# Create the original string for this specific function call
|
||||
original_string = f"{function_name}({params_string})"
|
||||
|
||||
results.append(FunctionCall(
|
||||
function_name=function_name,
|
||||
parameters=parameters,
|
||||
original_string=original_string
|
||||
))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def parse_parameters(params_string: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse parameter string and extract parameter names and values.
|
||||
|
||||
Args:
|
||||
params_string: String containing parameters (e.g., "x=0.5, y=0.6, text='hello'")
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter names to their values
|
||||
|
||||
Examples:
|
||||
>>> parse_parameters("x=0.5, y=0.6")
|
||||
{'x': 0.5, 'y': 0.6}
|
||||
|
||||
>>> parse_parameters("app_name='drupe'")
|
||||
{'app_name': 'drupe'}
|
||||
|
||||
>>> parse_parameters("'text'")
|
||||
{'arg_0': 'text'}
|
||||
|
||||
>>> parse_parameters("1, 3, 4")
|
||||
{'arg_0': 1, 'arg_1': 3, 'arg_2': 4}
|
||||
|
||||
>>> parse_parameters("arg1, arg2, x=0.5")
|
||||
{'arg_0': 'arg1', 'arg_1': 'arg2', 'x': 0.5}
|
||||
"""
|
||||
if not params_string.strip():
|
||||
return {}
|
||||
|
||||
parameters = OrderedDict()
|
||||
|
||||
# Split by commas, but be careful with commas inside quotes or brackets
|
||||
param_parts = split_parameters(params_string)
|
||||
|
||||
positional_index = 0
|
||||
|
||||
for part in param_parts:
|
||||
part = part.strip()
|
||||
if not part:
|
||||
continue
|
||||
|
||||
# Parse individual parameter
|
||||
name, value = parse_single_parameter(part)
|
||||
|
||||
# For positional arguments, use index-based naming
|
||||
if name.startswith("arg_"):
|
||||
name = f"arg_{positional_index}"
|
||||
positional_index += 1
|
||||
|
||||
parameters[name] = value
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
def split_parameters(params_string: str) -> List[str]:
|
||||
"""
|
||||
Split parameter string by commas, respecting quotes and brackets.
|
||||
|
||||
Args:
|
||||
params_string: String containing parameters
|
||||
|
||||
Returns:
|
||||
List of individual parameter strings
|
||||
"""
|
||||
parts = []
|
||||
current_part = ""
|
||||
paren_count = 0
|
||||
bracket_count = 0
|
||||
brace_count = 0
|
||||
in_quotes = False
|
||||
quote_char = None
|
||||
|
||||
for char in params_string:
|
||||
if char in ['"', "'"] and (not in_quotes or char == quote_char):
|
||||
if not in_quotes:
|
||||
in_quotes = True
|
||||
quote_char = char
|
||||
else:
|
||||
in_quotes = False
|
||||
quote_char = None
|
||||
elif not in_quotes:
|
||||
if char == '(':
|
||||
paren_count += 1
|
||||
elif char == ')':
|
||||
paren_count -= 1
|
||||
elif char == '[':
|
||||
bracket_count += 1
|
||||
elif char == ']':
|
||||
bracket_count -= 1
|
||||
elif char == '{':
|
||||
brace_count += 1
|
||||
elif char == '}':
|
||||
brace_count -= 1
|
||||
elif char == ',' and paren_count == 0 and bracket_count == 0 and brace_count == 0:
|
||||
parts.append(current_part.strip())
|
||||
current_part = ""
|
||||
continue
|
||||
|
||||
current_part += char
|
||||
|
||||
if current_part.strip():
|
||||
parts.append(current_part.strip())
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def parse_single_parameter(param_string: str) -> Tuple[str, Any]:
|
||||
"""
|
||||
Parse a single parameter string into name and value.
|
||||
|
||||
Args:
|
||||
param_string: String like "x=0.5" or "app_name='drupe'" or just "value"
|
||||
|
||||
Returns:
|
||||
Tuple of (parameter_name, parameter_value)
|
||||
|
||||
Examples:
|
||||
>>> parse_single_parameter("x=0.5")
|
||||
('x', 0.5)
|
||||
|
||||
>>> parse_single_parameter("app_name='drupe'")
|
||||
('app_name', 'drupe')
|
||||
|
||||
>>> parse_single_parameter("'text'")
|
||||
('arg_0', 'text')
|
||||
|
||||
>>> parse_single_parameter("123")
|
||||
('arg_0', 123)
|
||||
|
||||
>>> parse_single_parameter("3")
|
||||
('arg_0', 3)
|
||||
"""
|
||||
# Pattern to match parameter name and value
|
||||
pattern = r'^([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*(.+)$'
|
||||
|
||||
match = re.match(pattern, param_string)
|
||||
if match:
|
||||
# Named parameter
|
||||
param_name = match.group(1)
|
||||
param_value_str = match.group(2).strip()
|
||||
param_value = parse_value(param_value_str)
|
||||
return param_name, param_value
|
||||
else:
|
||||
# Positional parameter - treat as unnamed argument
|
||||
param_value = parse_value(param_string)
|
||||
return "arg_0", param_value
|
||||
|
||||
|
||||
def parse_value(value_string: str) -> Any:
|
||||
"""
|
||||
Parse a value string into appropriate Python type.
|
||||
|
||||
Args:
|
||||
value_string: String representation of a value
|
||||
|
||||
Returns:
|
||||
Parsed value (int, float, str, list, etc.)
|
||||
|
||||
Examples:
|
||||
>>> parse_value("3")
|
||||
3
|
||||
|
||||
>>> parse_value("3.14")
|
||||
3.14
|
||||
|
||||
>>> parse_value("'hello'")
|
||||
'hello'
|
||||
|
||||
>>> parse_value("[0.581, 0.898]")
|
||||
[0.581, 0.898]
|
||||
"""
|
||||
value_string = value_string.strip()
|
||||
|
||||
# String values (quoted)
|
||||
if (value_string.startswith("'") and value_string.endswith("'")) or \
|
||||
(value_string.startswith('"') and value_string.endswith('"')):
|
||||
return value_string[1:-1]
|
||||
|
||||
# List values
|
||||
if value_string.startswith('[') and value_string.endswith(']'):
|
||||
return parse_list(value_string)
|
||||
|
||||
# Dictionary values
|
||||
if value_string.startswith('{') and value_string.endswith('}'):
|
||||
return parse_dict(value_string)
|
||||
|
||||
# Boolean values
|
||||
if value_string.lower() in ['true', 'false']:
|
||||
return value_string.lower() == 'true'
|
||||
|
||||
# None value
|
||||
if value_string.lower() == 'none':
|
||||
return None
|
||||
|
||||
# Numeric values
|
||||
try:
|
||||
# Try integer first
|
||||
if '.' not in value_string:
|
||||
return int(value_string)
|
||||
else:
|
||||
return float(value_string)
|
||||
except ValueError:
|
||||
# If it's not a number, return as string (remove quotes if present)
|
||||
if value_string.startswith("'") and value_string.endswith("'"):
|
||||
return value_string[1:-1]
|
||||
elif value_string.startswith('"') and value_string.endswith('"'):
|
||||
return value_string[1:-1]
|
||||
else:
|
||||
return value_string
|
||||
|
||||
|
||||
def parse_list(list_string: str) -> List[Any]:
|
||||
"""
|
||||
Parse a list string into a Python list.
|
||||
|
||||
Args:
|
||||
list_string: String like "[0.581, 0.898]"
|
||||
|
||||
Returns:
|
||||
List of parsed values
|
||||
|
||||
Examples:
|
||||
>>> parse_list("[0.581, 0.898]")
|
||||
[0.581, 0.898]
|
||||
"""
|
||||
# Remove outer brackets
|
||||
content = list_string[1:-1].strip()
|
||||
if not content:
|
||||
return []
|
||||
|
||||
# Split by commas, respecting nested structures
|
||||
parts = split_parameters(content)
|
||||
|
||||
return [parse_value(part.strip()) for part in parts]
|
||||
|
||||
|
||||
def parse_dict(dict_string: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse a dictionary string into a Python dict.
|
||||
|
||||
Args:
|
||||
dict_string: String like "{'key': 'value'}"
|
||||
|
||||
Returns:
|
||||
Dictionary of parsed key-value pairs
|
||||
"""
|
||||
# Remove outer braces
|
||||
content = dict_string[1:-1].strip()
|
||||
if not content:
|
||||
return {}
|
||||
|
||||
# Split by commas, respecting nested structures
|
||||
parts = split_parameters(content)
|
||||
|
||||
result = {}
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if ':' in part:
|
||||
key_str, value_str = part.split(':', 1)
|
||||
key = parse_value(key_str.strip())
|
||||
value = parse_value(value_str.strip())
|
||||
result[key] = value
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parse_multiple_functions(function_strings: List[str]) -> List[FunctionCall]:
|
||||
"""
|
||||
Parse multiple function call strings.
|
||||
|
||||
Args:
|
||||
function_strings: List of function call strings
|
||||
|
||||
Returns:
|
||||
List of FunctionCall objects
|
||||
"""
|
||||
results = []
|
||||
for func_str in function_strings:
|
||||
try:
|
||||
result_list = parse_function_call(func_str)
|
||||
results.extend(result_list)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not parse function call '{func_str}': {e}")
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def extract_function_calls_from_text(text: str) -> List[FunctionCall]:
|
||||
"""
|
||||
Extract and parse function calls from a text block.
|
||||
|
||||
Args:
|
||||
text: Text containing function calls
|
||||
|
||||
Returns:
|
||||
List of FunctionCall objects
|
||||
"""
|
||||
# Pattern to find function calls in text
|
||||
# Matches: function_name(param1=value1, param2=value2)
|
||||
pattern = r'[a-zA-Z_][a-zA-Z0-9_.]*\([^)]*\)'
|
||||
|
||||
matches = re.findall(pattern, text)
|
||||
return parse_multiple_functions(matches)
|
||||
|
||||
|
||||
# Example usage and testing
|
||||
if __name__ == "__main__":
|
||||
test_cases = [
|
||||
"mobile.home()",
|
||||
"mobile.open_app(app_name='drupe')",
|
||||
"mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
|
||||
"mobile.back()",
|
||||
"mobile.long_press(x=0.799, y=0.911)",
|
||||
"mobile.terminate(status='success')",
|
||||
"answer('text')",
|
||||
"pyautogui.hscroll(page=-0.1)",
|
||||
"pyautogui.scroll(page=-0.1)",
|
||||
"pyautogui.scroll(0.13)",
|
||||
"pyautogui.click(x=0.8102, y=0.9463)",
|
||||
"pyautogui.hotkey(keys=['ctrl', 'c'])",
|
||||
"pyautogui.doubleClick()",
|
||||
"pyautogui.press(keys='enter')",
|
||||
"pyautogui.press(keys=['enter'])",
|
||||
"pyautogui.moveTo(x=0.04, y=0.405)",
|
||||
"pyautogui.write(message='bread buns')",
|
||||
"pyautogui.dragTo(x=0.8102, y=0.9463)",
|
||||
"mobile.wait(seconds=3)\nmobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
|
||||
# Additional test cases for multiple positional arguments
|
||||
"function(arg1, arg2, arg3)",
|
||||
"function('hello', 123, x=0.5)",
|
||||
"function(arg1, arg2, named_param='value')",
|
||||
"function(1, 2, 3, 4, 5)",
|
||||
"function('a', 'b', 'c', x=1, y=2)",
|
||||
]
|
||||
|
||||
print("Testing function parser:")
|
||||
print("=" * 50)
|
||||
|
||||
for test_case in test_cases:
|
||||
try:
|
||||
results = parse_function_call(test_case)
|
||||
print(f"✓ {test_case}")
|
||||
for result in results:
|
||||
print(f" Function: {result.function_name}")
|
||||
print(f" Parameters: {result.parameters}")
|
||||
print()
|
||||
except Exception as e:
|
||||
print(f"✗ {test_case}")
|
||||
print(f" Error: {e}")
|
||||
print()
|
||||
|
||||
# Test extracting from text
|
||||
print("Testing text extraction:")
|
||||
print("=" * 50)
|
||||
|
||||
sample_text = """
|
||||
mobile.wait(seconds=3)
|
||||
mobile.open_app(app_name='drupe')
|
||||
pyautogui.click(x=0.8102, y=0.9463)
|
||||
pyautogui.write(message='bread buns')
|
||||
"""
|
||||
|
||||
extracted = extract_function_calls_from_text(sample_text)
|
||||
for func_call in extracted:
|
||||
print(f"Found: {func_call.function_name} with params: {func_call.parameters}")
|
||||
|
||||
# Test reconstruction
|
||||
print("\nTesting function call reconstruction:")
|
||||
print("=" * 50)
|
||||
|
||||
reconstruction_tests = [
|
||||
"mobile.wait(seconds=3)",
|
||||
"mobile.home()",
|
||||
"mobile.open_app(app_name='drupe')",
|
||||
"mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
|
||||
"answer('text')",
|
||||
"pyautogui.scroll(0.13)",
|
||||
"pyautogui.click(x=0.8102, y=0.9463)",
|
||||
"pyautogui.hotkey(keys=['ctrl', 'c'])",
|
||||
"function(1, 2, 3)",
|
||||
"function('hello', 123, x=0.5, y=0.8)",
|
||||
"function([1, 3], 'arg2', named_param='value')",
|
||||
]
|
||||
|
||||
for test_case in reconstruction_tests:
|
||||
parsed_list = parse_function_call(test_case)
|
||||
for parsed in parsed_list:
|
||||
reconstructed = parsed.to_string()
|
||||
print(f"Original: {test_case}")
|
||||
print(f"Reconstructed: {reconstructed}")
|
||||
print(f"Match: {test_case == reconstructed}")
|
||||
assert test_case == reconstructed
|
||||
print()
|
||||
519
scripts/agents/get_aguvis_data.py
Normal file
519
scripts/agents/get_aguvis_data.py
Normal file
|
|
@ -0,0 +1,519 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to download, process, and upload the aguvis-stage2 dataset.
|
||||
Downloads from huggingface.co/datasets/xlangai/aguvis-stage2 and uploads to smolagents/aguvis-stage-2
|
||||
"""
|
||||
|
||||
import re
|
||||
import gc
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Generator, Callable, Literal
|
||||
from tqdm import tqdm
|
||||
from datasets import Dataset, load_dataset, concatenate_datasets
|
||||
from dotenv import load_dotenv
|
||||
from huggingface_hub import HfApi, login, snapshot_download
|
||||
from collections import defaultdict
|
||||
from PIL import Image
|
||||
import tarfile
|
||||
from itertools import islice
|
||||
import multiprocessing as mp
|
||||
from multiprocessing import Pool, Manager
|
||||
from prompts import OS_SYSTEM_PROMPT, MOBILE_SYSTEM_PROMPT
|
||||
from models import ConversationDataList, ConversationData, ChatMessage, DataRow
|
||||
from function_parser import parse_function_call
|
||||
from action_conversion import action_conversion
|
||||
from pydantic import BaseModel
|
||||
from config import config_dict_stage_1, config_dict_stage_2, MOBILE_FILE
|
||||
|
||||
|
||||
api = HfApi()
|
||||
|
||||
|
||||
def authenticate_huggingface():
|
||||
"""Authenticate with HuggingFace Hub using token."""
|
||||
hf_token = os.getenv("HF_TOKEN")
|
||||
if hf_token:
|
||||
print("Authenticating with HuggingFace Hub using token...")
|
||||
login(token=hf_token)
|
||||
else:
|
||||
raise ValueError("HF_TOKEN environment variable not set.")
|
||||
|
||||
|
||||
def discover_dataset_config(dataset_path: str, config_dict: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Discover dataset configuration by scanning the data directory."""
|
||||
dataset_dir = Path(dataset_path)
|
||||
train_dir = dataset_dir
|
||||
|
||||
if not train_dir.exists():
|
||||
raise FileNotFoundError(f"Train directory not found: {train_dir}")
|
||||
|
||||
configs = []
|
||||
processed_splits = set()
|
||||
|
||||
# Find all JSON files in the train directory
|
||||
for config in config_dict:
|
||||
subset_name = (
|
||||
config["json_path"]
|
||||
.replace(".json", "")
|
||||
.replace("-l1", "")
|
||||
.replace("-l2", "")
|
||||
.replace("-l3", "")
|
||||
)
|
||||
|
||||
# Skip if we already processed this split
|
||||
if subset_name in processed_splits:
|
||||
continue
|
||||
|
||||
config["subset_name"] = subset_name
|
||||
configs.append(config)
|
||||
processed_splits.add(subset_name)
|
||||
print(
|
||||
f"Discovered config: {config['subset_name']} -> {config['images_folder']}"
|
||||
)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def download_dataset(
|
||||
repo_id: str = "xlangai/aguvis-stage2", local_dir: str = "./aguvis_raw"
|
||||
) -> str:
|
||||
"""Download the dataset using snapshot_download."""
|
||||
print(f"Downloading dataset from {repo_id}...")
|
||||
local_path = snapshot_download(
|
||||
repo_id=repo_id, local_dir=local_dir, repo_type="dataset"
|
||||
)
|
||||
print(f"Dataset downloaded to: {local_path}")
|
||||
return local_path
|
||||
|
||||
|
||||
def extract_zip_files(dataset_path: str):
|
||||
"""Extract all zip files found in the dataset directory, but only if not already extracted."""
|
||||
print("Extracting zip files...")
|
||||
dataset_dir = Path(dataset_path)
|
||||
|
||||
for zip_file in dataset_dir.rglob("*.zip"):
|
||||
extract_dir = zip_file.parent / zip_file.stem
|
||||
if extract_dir.exists() and any(extract_dir.iterdir()):
|
||||
print(
|
||||
f"Skipping extraction for {zip_file} (already extracted at {extract_dir})"
|
||||
)
|
||||
continue
|
||||
|
||||
print(f"Extracting: {zip_file}")
|
||||
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
||||
zip_ref.extractall(extract_dir)
|
||||
print(f"Extracted to: {extract_dir}")
|
||||
|
||||
|
||||
def extract_tar_parts_grouped(dataset_path: str):
|
||||
"""
|
||||
Finds all .tar.gz.part_* groups, merges them, and extracts them into directories
|
||||
named after their common prefix.
|
||||
"""
|
||||
dataset_dir = Path(dataset_path)
|
||||
part_files = list(dataset_dir.glob("*.tar.gz.part_*"))
|
||||
|
||||
if not part_files:
|
||||
print("No split .tar.gz.part_* files found.")
|
||||
return
|
||||
|
||||
# Group part files by prefix
|
||||
groups = defaultdict(list)
|
||||
for part in part_files:
|
||||
prefix = part.name.split(".tar.gz.part_")[0]
|
||||
groups[prefix].append(part)
|
||||
|
||||
for prefix, parts in groups.items():
|
||||
parts = sorted(parts) # Ensure correct order
|
||||
merged_tar_path = dataset_dir / f"{prefix}.tar.gz"
|
||||
extract_dir = dataset_dir / prefix
|
||||
|
||||
if extract_dir.exists() and any(extract_dir.iterdir()):
|
||||
print(
|
||||
f"Skipping extraction for '{prefix}' (already extracted at {extract_dir})"
|
||||
)
|
||||
continue
|
||||
|
||||
# Merge parts
|
||||
CHUNK_SIZE = 1024 * 1024
|
||||
print(f"Merging parts for '{prefix}'...")
|
||||
with open(merged_tar_path, "wb") as outfile:
|
||||
for part in parts:
|
||||
print(f" Adding: {part.name}")
|
||||
with open(part, "rb") as infile:
|
||||
while chunk := infile.read(CHUNK_SIZE):
|
||||
outfile.write(chunk)
|
||||
|
||||
print(f"Merged to: {merged_tar_path}")
|
||||
|
||||
# Extract
|
||||
print(f"Extracting to: {extract_dir}")
|
||||
with tarfile.open(merged_tar_path, "r:gz") as tar:
|
||||
tar.extractall(path=extract_dir)
|
||||
print(f"Done extracting '{prefix}'\n")
|
||||
|
||||
|
||||
def check_subset_exists(repo_id: str, subset_name: str) -> bool:
|
||||
"""Check if a subset already exists in the remote dataset."""
|
||||
try:
|
||||
# Try to get dataset info with specific subset
|
||||
from datasets import get_dataset_config_names
|
||||
|
||||
config_names = get_dataset_config_names(repo_id)
|
||||
return subset_name in config_names
|
||||
except Exception as e:
|
||||
print(f"Could not check if subset exists: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def load_image_from_folder(images_folder: Path, img_path: str) -> Image.Image:
|
||||
"""Load images from the specified folder."""
|
||||
full_path = images_folder / img_path
|
||||
img = Image.open(full_path)
|
||||
new_img = img.copy()
|
||||
img.close()
|
||||
return new_img
|
||||
|
||||
|
||||
def convert_to_code_agent_format(messages: list[ChatMessage], json_path: str, reasoning: bool):
|
||||
for i, message in enumerate(messages):
|
||||
content = message.content
|
||||
|
||||
if message.role == "system":
|
||||
if json_path in MOBILE_FILE:
|
||||
content = MOBILE_SYSTEM_PROMPT
|
||||
else:
|
||||
content = OS_SYSTEM_PROMPT
|
||||
|
||||
if message.role == "user":
|
||||
content = content.replace("<image>\n", "").replace("<image>", "")
|
||||
|
||||
elif message.role == "assistant":
|
||||
content = (
|
||||
content.replace("Action: ", "")
|
||||
.replace("Observation: ", "")
|
||||
.replace("Thought: ", "")
|
||||
)
|
||||
if reasoning and i == len(messages) - 1:
|
||||
content = (
|
||||
"<code>\n" + content.strip() + "\n</code>"
|
||||
)
|
||||
elif reasoning:
|
||||
# TODO: Check if there is always only 2 assistants
|
||||
content = (
|
||||
"<think>\n"
|
||||
+ content.strip()
|
||||
+ "\n</think>\n"
|
||||
)
|
||||
else:
|
||||
content = content.strip()
|
||||
|
||||
messages[i].content = content
|
||||
|
||||
# Fuse subsequent messages have the same role, merge it
|
||||
if i > 0 and messages[i].role == messages[i - 1].role:
|
||||
# Need to fuse both messages
|
||||
if reasoning:
|
||||
messages[i - 1].content += messages[i].content
|
||||
else:
|
||||
messages[i - 1].content += "\n" + messages[i].content
|
||||
messages.pop(i)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def convert_to_chat_format(
|
||||
data: ConversationData, json_path: str, reasoning: bool
|
||||
) -> list[ChatMessage]:
|
||||
"""Convert data item to chat template format."""
|
||||
# This is a placeholder - you'll need to adapt this based on the actual data structure
|
||||
# The exact conversion depends on how the original data is structured
|
||||
chat_messages = data.to_chat_messages()
|
||||
# mobile = json_path in open("mobile_files.txt", "r").read()
|
||||
# os = json_path in open("os_files.txt", "r").read()
|
||||
# if not mobile and not os:
|
||||
# for message in chat_messages:
|
||||
# if mobile and os:
|
||||
# break
|
||||
# if message.role == "assistant":
|
||||
# if not mobile and "mobile" in message.content:
|
||||
# with open("mobile_files.txt", "a") as mobile_files:
|
||||
# mobile_files.write(json_path + "\n")
|
||||
# mobile = True
|
||||
# if not os and "pyautogui" in message.content:
|
||||
# with open("os_files.txt", "a") as os_files:
|
||||
# os_files.write(json_path + "\n")
|
||||
# os = True
|
||||
# Aguvis stage 1
|
||||
chat_messages = convert_to_code_agent_format(chat_messages, json_path, reasoning)
|
||||
return chat_messages
|
||||
|
||||
|
||||
def convert_to_new_action_space(
|
||||
messages: list[ChatMessage], resolution: tuple[int, int], code_format: bool = True
|
||||
) -> list[ChatMessage]:
|
||||
regex_match: re.Match | str | None = None
|
||||
index = -1
|
||||
regex = r"<code>\n(.*?)\n</code>"
|
||||
assistant_msg = [(i, message) for i, message in enumerate(messages) if message.role == "assistant"]
|
||||
if assistant_msg:
|
||||
for index, msg in assistant_msg:
|
||||
|
||||
if code_format:
|
||||
regex_match = re.search(regex, msg.content, re.DOTALL)
|
||||
else:
|
||||
regex_match = msg.content
|
||||
|
||||
if regex_match is not None:
|
||||
function_calls = parse_function_call(
|
||||
regex_match.group(1) if isinstance(regex_match, re.Match) else regex_match,
|
||||
pattern_to_match=["pyautogui", "mobile", "terminate", "answer"],
|
||||
)
|
||||
|
||||
|
||||
if len(function_calls) > 0:
|
||||
|
||||
for i, function_call in enumerate(deepcopy(function_calls)):
|
||||
|
||||
if function_call.function_name == "pyautogui.dragTo" and not isinstance(list(function_calls[i].parameters.values())[0], (list, tuple)):
|
||||
x1, y1 = islice(function_calls[i-1].parameters.values(), 2)
|
||||
x2, y2 = islice(function_calls[i].parameters.values(), 2)
|
||||
function_calls[i].parameters = {"from_coord": (x1, y1), "to_coord": (x2, y2)}
|
||||
function_calls[i].original_string = function_calls[i].to_string()
|
||||
function_calls.pop(i-1)
|
||||
|
||||
function_calls = action_conversion(function_calls, resolution=resolution)
|
||||
|
||||
new_action_string = "\n".join(
|
||||
[function_call.to_string() for function_call in function_calls]
|
||||
)
|
||||
messages[index].content = messages[index].content.replace(
|
||||
regex_match.group(1) if isinstance(regex_match, re.Match) else regex_match, new_action_string
|
||||
)
|
||||
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def process_subset(
|
||||
config: Dict[str, Any],
|
||||
dataset_path: str,
|
||||
) -> tuple[ConversationDataList, Path]:
|
||||
"""Process a single dataset subset."""
|
||||
subset_name = config["subset_name"]
|
||||
|
||||
print(f"Processing split: {subset_name}")
|
||||
|
||||
dataset_dir = Path(dataset_path)
|
||||
images_folder = dataset_dir / config["subset_name"] / config["images_folder"]
|
||||
|
||||
if not images_folder.exists():
|
||||
print(f"Images folder not found: {images_folder}")
|
||||
else:
|
||||
print(f"Images folder: {images_folder}")
|
||||
|
||||
json_config_path = dataset_dir / config["json_path"]
|
||||
with open(json_config_path, "r") as f:
|
||||
data = ConversationDataList.model_validate_json(f.read())
|
||||
# data = f.read()
|
||||
print(f"Added '{json_config_path}'")
|
||||
|
||||
return data, images_folder
|
||||
|
||||
|
||||
def row_generator(
|
||||
data: ConversationDataList, images_folder: Path, json_path: str, reasoning: bool
|
||||
) -> Generator[Dict[str, Any], None, None]:
|
||||
conversations: list[ConversationData] = data.root
|
||||
for item in tqdm(conversations):
|
||||
# Extract image paths from the data item
|
||||
try:
|
||||
# Load images
|
||||
image = load_image_from_folder(images_folder, item.image)
|
||||
chat_message = convert_to_chat_format(item, json_path, reasoning)
|
||||
chat_message = convert_to_new_action_space(chat_message, image.size, code_format=reasoning)
|
||||
if len(chat_message) == 0:
|
||||
continue
|
||||
|
||||
row = DataRow.from_chat_messages(chat_message, image, source=json_path.split("/")[-1].split(".")[0])
|
||||
yield row.model_dump(exclude_none=True)
|
||||
del image
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print(f"Error processing item: {e}", item)
|
||||
continue
|
||||
|
||||
|
||||
class DatasetConfig(BaseModel):
|
||||
huggingface_repo_id: str
|
||||
local_path: str
|
||||
config_dict: List[Dict[str, Any]]
|
||||
smolagents_repo_id: str
|
||||
reasoning: bool
|
||||
|
||||
|
||||
def process_single_config(config: Dict[str, Any], dataset_path: str, smolagents_repo_id: str, reasoning: bool) -> bool:
|
||||
"""Process a single config in a separate process."""
|
||||
try:
|
||||
# Authenticate in this process
|
||||
authenticate_huggingface()
|
||||
|
||||
print(f"\n{'=' * 50}")
|
||||
print(f"Processing config: {config}")
|
||||
|
||||
# Check if the subset already exists in the remote dataset
|
||||
subset_name = config["subset_name"]
|
||||
# if check_subset_exists(smolagents_repo_id, subset_name):
|
||||
# print(
|
||||
# f"Subset '{subset_name}' already exists in {smolagents_repo_id}, skipping processing."
|
||||
# )
|
||||
# return True
|
||||
|
||||
json_path = config["json_path"]
|
||||
data, image_folder = process_subset(config, dataset_path)
|
||||
|
||||
# Collect all rows first
|
||||
rows = []
|
||||
datasets = []
|
||||
for row in row_generator(data, image_folder, json_path, reasoning):
|
||||
rows.append(row)
|
||||
if len(rows) > 20000:
|
||||
print("Creating batch dataset")
|
||||
dataset = Dataset.from_list(rows)
|
||||
datasets.append(dataset)
|
||||
rows = []
|
||||
gc.collect()
|
||||
|
||||
if len(rows) > 0:
|
||||
# Create dataset from collected data
|
||||
dataset = Dataset.from_list(rows)
|
||||
datasets.append(dataset)
|
||||
rows = []
|
||||
|
||||
dataset_to_push = concatenate_datasets(datasets)
|
||||
|
||||
# Push to hub
|
||||
dataset_to_push.push_to_hub(
|
||||
smolagents_repo_id,
|
||||
# config_name=subset_name, # This sets the subset name
|
||||
split="train", # This should be "train" not the subset name
|
||||
)
|
||||
|
||||
print(f"Processed and uploaded subset: {config['subset_name']}")
|
||||
|
||||
# Force garbage collection to manage memory
|
||||
gc.collect()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing config {config.get('subset_name', 'unknown')}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def make_dataset_from_original_data(dataset_config: DatasetConfig, max_processes: int | None = None):
|
||||
"""Main function to orchestrate the entire process."""
|
||||
load_dotenv(override=True)
|
||||
|
||||
print(f"Starting {dataset_config.smolagents_repo_id} dataset processing...")
|
||||
|
||||
# Step 0: Authenticate with HuggingFace Hub
|
||||
authenticate_huggingface()
|
||||
|
||||
dataset_path = download_dataset(
|
||||
dataset_config.huggingface_repo_id, dataset_config.local_path
|
||||
)
|
||||
|
||||
# extract_zip_files(dataset_path)
|
||||
# extract_tar_parts_grouped(dataset_path)
|
||||
|
||||
dataset_configs = discover_dataset_config(dataset_path, dataset_config.config_dict)
|
||||
converted_repo_id = dataset_config.smolagents_repo_id
|
||||
reasoning = dataset_config.reasoning
|
||||
|
||||
# Use multiprocessing to process configs in parallel
|
||||
available_cpus = mp.cpu_count()
|
||||
if max_processes is None:
|
||||
max_processes = available_cpus
|
||||
num_processes = min(max_processes, len(dataset_configs))
|
||||
print(f"Using {num_processes} processes (out of {available_cpus} available CPUs) to process {len(dataset_configs)} configs")
|
||||
|
||||
# Prepare arguments for multiprocessing
|
||||
process_args = [
|
||||
(config, dataset_path, converted_repo_id, reasoning)
|
||||
for config in dataset_configs if config["subset_name"] if config["subset_name"] in ["guiact-web-single"]
|
||||
]
|
||||
|
||||
# Process configs in parallel with progress tracking
|
||||
print(f"Starting parallel processing of {len(dataset_configs)} configs...")
|
||||
try:
|
||||
with Pool(processes=num_processes) as pool:
|
||||
results = []
|
||||
for i, result in enumerate(pool.starmap(process_single_config, process_args)):
|
||||
results.append(result)
|
||||
print(f"Completed {i+1}/{len(dataset_configs)} configs")
|
||||
except Exception as e:
|
||||
print(f"Multiprocessing failed: {e}")
|
||||
print("Falling back to sequential processing...")
|
||||
results = []
|
||||
for i, args in enumerate(process_args):
|
||||
result = process_single_config(*args)
|
||||
results.append(result)
|
||||
print(f"Completed {i+1}/{len(dataset_configs)} configs (sequential)")
|
||||
|
||||
# Check results
|
||||
successful = sum(results)
|
||||
total = len(dataset_configs)
|
||||
print(f"\nProcessing complete: {successful}/{total} configs processed successfully")
|
||||
|
||||
if successful < total:
|
||||
failed_count = total - successful
|
||||
print(f"Warning: {failed_count} configs failed to process. Check the logs above for details.")
|
||||
else:
|
||||
print("All configs processed successfully!")
|
||||
|
||||
# # Cleanup
|
||||
# print("\nCleaning up temporary files...")
|
||||
# # shutil.rmtree(dataset_path, ignore_errors=True)
|
||||
#
|
||||
# # api.upload_large_folder(folder_path=converted_folder, repo_id="smolagents/aguvis-stage-2", repo_type="dataset")
|
||||
#
|
||||
# shutil.rmtree(converted_folder, ignore_errors=True)
|
||||
#
|
||||
# print("All done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# dataset_config_1 = DatasetConfig(
|
||||
# huggingface_repo_id="xlangai/aguvis-stage1",
|
||||
# local_path="/fsx/amir_mahla/aguvis_raw_stage_1",
|
||||
# config_dict=config_dict_stage_1,
|
||||
# smolagents_repo_id="smolagents/aguvis-stage-1",
|
||||
# reasoning=False,
|
||||
# )
|
||||
# dataset_config_2 = DatasetConfig(
|
||||
# huggingface_repo_id="xlangai/aguvis-stage2",
|
||||
# local_path="/fsx/amir_mahla/aguvis_raw_stage_2",
|
||||
# config_dict=config_dict_stage_2,
|
||||
# smolagents_repo_id="smolagents/aguvis-stage-2",
|
||||
# reasoning=True,
|
||||
# )
|
||||
dataset_config_3 = DatasetConfig(
|
||||
huggingface_repo_id="xlangai/aguvis-stage2",
|
||||
local_path="/fsx/amir_mahla/aguvis_raw_stage_2",
|
||||
config_dict=config_dict_stage_2,
|
||||
smolagents_repo_id="smolagents/guiact-web-single",
|
||||
reasoning=True,
|
||||
)
|
||||
# You can specify max_processes to limit the number of parallel processes
|
||||
# make_dataset_from_original_data(dataset_config_1, max_processes=4)
|
||||
make_dataset_from_original_data(dataset_config_3, 1)
|
||||
154
scripts/agents/models.py
Normal file
154
scripts/agents/models.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
from typing import List, Optional, Literal
|
||||
from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
|
||||
from copy import deepcopy
|
||||
from PIL import Image
|
||||
from collections import OrderedDict
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Literal["user", "assistant", "system"]
|
||||
content: str
|
||||
|
||||
@staticmethod
|
||||
def from_conversation_list(data: list[dict[str, str]]) -> list["ChatMessage"]:
|
||||
messages = []
|
||||
system_added = False
|
||||
for item in data:
|
||||
if item["from"] == "system":
|
||||
if not system_added:
|
||||
role: Literal["user", "assistant", "system"] = "system"
|
||||
messages.append(ChatMessage(role=role, content=item["value"]))
|
||||
system_added = True
|
||||
elif item["from"] == "human":
|
||||
role = "user"
|
||||
messages.append(ChatMessage(role=role, content=item["value"]))
|
||||
else:
|
||||
role = "assistant"
|
||||
messages.append(ChatMessage(role=role, content=item["value"]))
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
class ConversationEntry(BaseModel):
|
||||
from_: Literal["system", "human", "gpt"] = Field(alias="from")
|
||||
value: str
|
||||
recipient: Optional[str] = None
|
||||
end_turn: Optional[bool] = None
|
||||
|
||||
def to_chat_message(self) -> ChatMessage:
|
||||
if self.from_ == "system":
|
||||
role: Literal["user", "assistant", "system"] = "system"
|
||||
elif self.from_ == "human":
|
||||
role = "user"
|
||||
else:
|
||||
role = "assistant"
|
||||
return ChatMessage(role=role, content=self.value)
|
||||
|
||||
class ConversationData(BaseModel):
|
||||
image: str
|
||||
conversations: List[ConversationEntry]
|
||||
recipient: Optional[str] = None
|
||||
end_turn: Optional[bool] = None
|
||||
|
||||
@field_validator("image", mode="before")
|
||||
def validate_image(cls, v):
|
||||
if isinstance(v, list):
|
||||
if len(v) == 1:
|
||||
return v[0]
|
||||
elif len(v) == 2:
|
||||
return v[1]
|
||||
else:
|
||||
raise ValueError("Expected 1 or 2 images, got multiple")
|
||||
return v
|
||||
|
||||
|
||||
def to_chat_messages(self) -> list[ChatMessage]:
|
||||
return [conversation.to_chat_message() for conversation in self.conversations]
|
||||
|
||||
class ConversationDataList(RootModel[List[ConversationData]]):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_conversation(self):
|
||||
new_conversations: dict[str, List[ConversationData]] = {}
|
||||
|
||||
# merge image duplicates
|
||||
for conversation in self.root:
|
||||
if conversation.image not in new_conversations:
|
||||
new_conversations[conversation.image] = [conversation]
|
||||
else:
|
||||
new_conversations[conversation.image].append(conversation)
|
||||
|
||||
# delete text duplicates
|
||||
duplicates = 0
|
||||
for data in new_conversations.values():
|
||||
if isinstance(data, list):
|
||||
index_to_pop = set()
|
||||
for i in range(len(data) - 1):
|
||||
for j in range(i + 1, len(data)):
|
||||
if [c1.model_dump() for c1 in data[i].conversations] == [c2.model_dump() for c2 in data[j].conversations]:
|
||||
if j not in index_to_pop:
|
||||
duplicates += 1
|
||||
index_to_pop.add(j)
|
||||
for index in sorted(index_to_pop, reverse=True):
|
||||
data.pop(index)
|
||||
|
||||
# delete text duplicates
|
||||
new_data = []
|
||||
for data in new_conversations.values():
|
||||
for i in range(len(data)):
|
||||
if i == 0:
|
||||
new_data.append(data[i])
|
||||
else:
|
||||
new_data[-1].conversations.extend(data[i].conversations)
|
||||
|
||||
|
||||
self.root = new_data
|
||||
|
||||
return self
|
||||
|
||||
class DataRow(BaseModel):
|
||||
images: list[Image.Image]
|
||||
texts: list[OrderedDict[str, str]]
|
||||
source: str
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@classmethod
|
||||
def from_chat_messages(cls, messages: list[ChatMessage], image: Image.Image, source: str) -> "DataRow":
|
||||
|
||||
system, user, assistant = None, None, None
|
||||
have_system = any(message.role == "system" for message in messages)
|
||||
texts: list[OrderedDict[str, str]] = []
|
||||
images = [image]
|
||||
chat_messages: OrderedDict[str, str] = OrderedDict()
|
||||
for message in messages:
|
||||
if message.role == "system":
|
||||
system = message.content
|
||||
elif message.role == "user":
|
||||
user = message.content
|
||||
elif message.role == "assistant":
|
||||
assistant = message.content
|
||||
|
||||
if have_system and user is not None and assistant is not None and system is not None:
|
||||
chat_messages["system"] = system
|
||||
chat_messages["user"] = user
|
||||
chat_messages["assistant"] = assistant
|
||||
texts.append(chat_messages)
|
||||
chat_messages = OrderedDict()
|
||||
user, assistant = None, None
|
||||
|
||||
elif not have_system and user is not None and assistant is not None:
|
||||
chat_messages["user"] = user
|
||||
chat_messages["assistant"] = assistant
|
||||
texts.append(chat_messages)
|
||||
chat_messages = OrderedDict()
|
||||
user, assistant = None, None
|
||||
|
||||
return cls(images=images, texts=texts, source=source)
|
||||
|
||||
def to_model_dump(self) -> dict:
|
||||
return {
|
||||
"images": self.images,
|
||||
"texts": self.texts,
|
||||
"source": self.source,
|
||||
}
|
||||
145
scripts/agents/prompts.py
Normal file
145
scripts/agents/prompts.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
from typing import Literal
|
||||
|
||||
OS_ACTIONS = """
|
||||
def final_answer(answer: any) -> any:
|
||||
\"\"\"
|
||||
Provides a final answer to the given problem.
|
||||
Args:
|
||||
answer: The final answer to the problem
|
||||
\"\"\"
|
||||
|
||||
def move_mouse(self, x: float, y: float) -> str:
|
||||
\"\"\"
|
||||
Moves the mouse cursor to the specified coordinates
|
||||
Args:
|
||||
x: The x coordinate (horizontal position)
|
||||
y: The y coordinate (vertical position)
|
||||
\"\"\"
|
||||
|
||||
def click(x: Optional[float] = None, y: Optional[float] = None) -> str:
|
||||
\"\"\"
|
||||
Performs a left-click at the specified normalized coordinates
|
||||
Args:
|
||||
x: The x coordinate (horizontal position)
|
||||
y: The y coordinate (vertical position)
|
||||
\"\"\"
|
||||
|
||||
def double_click(x: Optional[float] = None, y: Optional[float] = None) -> str:
|
||||
\"\"\"
|
||||
Performs a double-click at the specified normalized coordinates
|
||||
Args:
|
||||
x: The x coordinate (horizontal position)
|
||||
y: The y coordinate (vertical position)
|
||||
\"\"\"
|
||||
|
||||
def type(text: str) -> str:
|
||||
\"\"\"
|
||||
Types the specified text at the current cursor position.
|
||||
Args:
|
||||
text: The text to type
|
||||
\"\"\"
|
||||
|
||||
def press(keys: str | list[str]) -> str:
|
||||
\"\"\"
|
||||
Presses a keyboard key
|
||||
Args:
|
||||
keys: The key or list of keys to press (e.g. "enter", "space", "backspace", "ctrl", etc.).
|
||||
\"\"\"
|
||||
|
||||
def navigate_back() -> str:
|
||||
\"\"\"
|
||||
Goes back to the previous page in the browser. If using this tool doesn't work, just click the button directly.
|
||||
\"\"\"
|
||||
|
||||
def drag(from_coord: list[float], to_coord: list[float]) -> str:
|
||||
\"\"\"
|
||||
Clicks [x1, y1], drags mouse to [x2, y2], then release click.
|
||||
Args:
|
||||
x1: origin x coordinate
|
||||
y1: origin y coordinate
|
||||
x2: end x coordinate
|
||||
y2: end y coordinate
|
||||
\"\"\"
|
||||
|
||||
def scroll(direction: Literal["up", "down"] = "down", amount: int = 1) -> str:
|
||||
\"\"\"
|
||||
Moves the mouse to selected coordinates, then uses the scroll button: this could scroll the page or zoom, depending on the app. DO NOT use scroll to move through linux desktop menus.
|
||||
Args:
|
||||
x: The x coordinate (horizontal position) of the element to scroll/zoom, defaults to None to not focus on specific coordinates
|
||||
y: The y coordinate (vertical position) of the element to scroll/zoom, defaults to None to not focus on specific coordinates
|
||||
direction: The direction to scroll ("up" or "down"), defaults to "down". For zoom, "up" zooms in, "down" zooms out.
|
||||
amount: The amount to scroll. A good amount is 1 or 2.
|
||||
\"\"\"
|
||||
|
||||
def wait(seconds: float) -> str:
|
||||
\"\"\"
|
||||
Waits for the specified number of seconds. Very useful in case the prior order is still executing (for example starting very heavy applications like browsers or office apps)
|
||||
Args:
|
||||
seconds: Number of seconds to wait, generally 2 is enough.
|
||||
\"\"\"
|
||||
"""
|
||||
|
||||
MOBILE_ACTIONS = """
|
||||
def navigate_back() -> str:
|
||||
\"\"\"
|
||||
Return to home page
|
||||
\"\"\"
|
||||
|
||||
def open_app(app_name: str) -> str:
|
||||
\"\"\"
|
||||
Launches the specified application.
|
||||
Args:
|
||||
app_name: the name of the application to launch
|
||||
\"\"\"
|
||||
|
||||
def swipe(from_coord: list[str], to_coord: list[str]) -> str:
|
||||
\"\"\"
|
||||
swipe from 'from_coord' to 'to_coord'
|
||||
Args:
|
||||
from_coord: origin coordinates
|
||||
to_coord: end coordinates
|
||||
\"\"\"
|
||||
|
||||
def long_press(x: int, y: int) -> str:
|
||||
\"\"\"
|
||||
Performs a long-press at the specified coordinates
|
||||
Args:
|
||||
x: The x coordinate (horizontal position)
|
||||
y: The y coordinate (vertical position)
|
||||
\"\"\"
|
||||
"""
|
||||
|
||||
OS_SYSTEM_PROMPT = f"""You are a helpful GUI agent. You’ll be given a task and a screenshot of the screen. Complete the task using Python function calls.
|
||||
|
||||
For each step:
|
||||
• First, <think></think> to express the thought process guiding your next action and the reasoning behind it.
|
||||
• Then, use <code></code> to perform the action. it will be executed in a stateful environment.
|
||||
|
||||
The following functions are exposed to the Python interpreter:
|
||||
<code>
|
||||
{OS_ACTIONS}
|
||||
</code>
|
||||
|
||||
The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
|
||||
"""
|
||||
|
||||
MOBILE_SYSTEM_PROMPT = f"""You are a helpful GUI agent. You’ll be given a task and a screenshot of the screen. Complete the task using Python function calls.
|
||||
|
||||
For each step:
|
||||
• First, <think></think> to express the thought process guiding your next action and the reasoning behind it.
|
||||
• Then, use <code></code> to perform the action. it will be executed in a stateful environment.
|
||||
|
||||
The following functions are exposed to the Python interpreter:
|
||||
<code>
|
||||
|
||||
# OS ACTIONS
|
||||
|
||||
{OS_ACTIONS}
|
||||
|
||||
# MOBILE ACTIONS
|
||||
|
||||
{MOBILE_ACTIONS}
|
||||
</code>
|
||||
|
||||
The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
|
||||
"""
|
||||
153
scripts/agents/qwenvl_collator.py
Normal file
153
scripts/agents/qwenvl_collator.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
from PIL import Image
|
||||
from scripts.agents.function_parser import parse_function_call
|
||||
|
||||
from qwen_vl_utils import smart_resize
|
||||
|
||||
def resize_images_in_messages(batch_messages, script_args) -> list[Image.Image]:
|
||||
|
||||
min_pixels = script_args.image_resize["min_pixels"]
|
||||
max_pixels = script_args.image_resize["max_pixels"]
|
||||
factor = script_args.image_resize["factor"]
|
||||
|
||||
all_image_inputs = []
|
||||
for messages in batch_messages:
|
||||
|
||||
old_image = messages[1]["content"][0]["image"]
|
||||
resized_height, resized_width = smart_resize(
|
||||
old_image.height,
|
||||
old_image.width,
|
||||
factor=factor,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
new_image = old_image.resize((resized_width, resized_height))
|
||||
messages[1]["content"][0]["image"] = new_image
|
||||
|
||||
function_calls = parse_function_call(messages[2]["content"])
|
||||
old_function_call_strings = [
|
||||
function_call.to_string() for function_call in function_calls
|
||||
]
|
||||
for function_call, old_function_call_string in zip(function_calls, old_function_call_strings):
|
||||
if function_call.function_name in [
|
||||
"click",
|
||||
"long_press",
|
||||
"double_click",
|
||||
"move_mouse",
|
||||
]:
|
||||
function_call.parameters["arg_0"] = (
|
||||
int(function_call.parameters["arg_0"]
|
||||
/ old_image.width
|
||||
* new_image.width)
|
||||
)
|
||||
function_call.parameters["arg_1"] = (
|
||||
int(function_call.parameters["arg_1"]
|
||||
/ old_image.height
|
||||
* new_image.height)
|
||||
)
|
||||
elif function_call.function_name in ["swipe", "drag"]:
|
||||
function_call.parameters["arg_0"] = (
|
||||
int(function_call.parameters["arg_0"][0]
|
||||
/ old_image.width
|
||||
* new_image.width),
|
||||
int(function_call.parameters["arg_0"][1]
|
||||
/ old_image.height
|
||||
* new_image.height)
|
||||
)
|
||||
function_call.parameters["arg_1"] = (
|
||||
int(function_call.parameters["arg_1"][0]
|
||||
/ old_image.width
|
||||
* new_image.width),
|
||||
int(function_call.parameters["arg_1"][1]
|
||||
/ old_image.height
|
||||
* new_image.height)
|
||||
)
|
||||
messages[2]["content"] = messages[2]["content"].replace(old_function_call_string, function_call.to_string())
|
||||
|
||||
all_image_inputs.append([new_image])
|
||||
return all_image_inputs
|
||||
|
||||
def create_vlm_collate_fn(processor, script_args):
|
||||
"""Optimized collate function for VLM training that masks system prompt tokens."""
|
||||
|
||||
def collate_fn(examples: list[dict[str, str | Image.Image]]):
|
||||
batch_messages = []
|
||||
system_prompts = []
|
||||
user_prompts = []
|
||||
for example in examples:
|
||||
system = example["system"]
|
||||
user = example["user"]
|
||||
assistant = example["assistant"]
|
||||
image = example["image"]
|
||||
|
||||
system_prompts.append(system)
|
||||
user_prompts.append(user)
|
||||
batch_messages.append(
|
||||
[
|
||||
{"role": "system", "content": system},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": image},
|
||||
{"type": "text", "text": user},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": assistant},
|
||||
]
|
||||
)
|
||||
|
||||
all_image_inputs = []
|
||||
if script_args.image_resize is not None:
|
||||
all_image_inputs = resize_images_in_messages(batch_messages, script_args)
|
||||
|
||||
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
for messages in batch_messages
|
||||
]
|
||||
|
||||
batch = processor(
|
||||
text=texts,
|
||||
images=all_image_inputs if all_image_inputs else None,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
max_length=4096,
|
||||
)
|
||||
|
||||
input_ids = batch["input_ids"]
|
||||
labels = input_ids.clone()
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
|
||||
if hasattr(processor, "image_token"):
|
||||
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||
processor.image_token
|
||||
)
|
||||
if image_token_id is not None:
|
||||
labels[labels == image_token_id] = -100
|
||||
else:
|
||||
raise ValueError("Processor does not have image_token")
|
||||
|
||||
system_encodings = processor.tokenizer(
|
||||
system_prompts, add_special_tokens=False, padding=False
|
||||
)["input_ids"]
|
||||
|
||||
user_encodings = processor.tokenizer(
|
||||
user_prompts, add_special_tokens=False, padding=False
|
||||
)["input_ids"]
|
||||
|
||||
for encodings in [system_encodings, user_encodings]:
|
||||
for i, system_ids in enumerate(encodings):
|
||||
if input_ids[i, : len(system_ids)].tolist() == system_ids:
|
||||
labels[i, : len(system_ids)] = -100
|
||||
else:
|
||||
seq = input_ids[i].tolist()
|
||||
for j in range(len(seq) - len(system_ids) + 1):
|
||||
if seq[j : j + len(system_ids)] == system_ids:
|
||||
labels[i, j : j + len(system_ids)] = -100
|
||||
break # early exit
|
||||
|
||||
batch["labels"] = labels
|
||||
return batch
|
||||
|
||||
return collate_fn
|
||||
224
scripts/agents/smolvlm2_collator.py
Normal file
224
scripts/agents/smolvlm2_collator.py
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
from PIL import Image
|
||||
from scripts.agents.function_parser import parse_function_call
|
||||
import numpy as np
|
||||
from transformers.models.smolvlm.image_processing_smolvlm import (
|
||||
get_resize_output_image_size,
|
||||
)
|
||||
from transformers.image_utils import ChannelDimension
|
||||
|
||||
|
||||
def transform_messages(
|
||||
batch_messages,
|
||||
image_resize: dict[str, int | bool],
|
||||
) -> list[list[Image.Image]]:
|
||||
|
||||
resolution_max_side = image_resize["resolution_max_side"] if "resolution_max_side" in image_resize else None
|
||||
to_pixel_coordinates = image_resize["to_pixel_coordinates"] if "to_pixel_coordinates" in image_resize else False
|
||||
|
||||
if not to_pixel_coordinates and resolution_max_side is None:
|
||||
return batch_messages
|
||||
|
||||
all_image_inputs: list[list[Image.Image]] = []
|
||||
for messages in batch_messages:
|
||||
new_image = None
|
||||
for i in range(len(messages)):
|
||||
if "image" in messages[i]["content"][0]:
|
||||
old_image = messages[i]["content"][0]["image"]
|
||||
|
||||
if resolution_max_side is not None:
|
||||
resized_height, resized_width = get_resize_output_image_size(
|
||||
np.array(old_image),
|
||||
resolution_max_side=resolution_max_side,
|
||||
input_data_format=ChannelDimension.LAST,
|
||||
)
|
||||
new_image = old_image.resize((resized_width, resized_height))
|
||||
else:
|
||||
resized_height, resized_width = old_image.height, old_image.width
|
||||
new_image = old_image
|
||||
|
||||
messages[i]["content"][0]["image"] = new_image
|
||||
all_image_inputs.append([new_image])
|
||||
|
||||
if messages[i]["role"] == "assistant" and to_pixel_coordinates:
|
||||
assert new_image is not None, "new_image is None"
|
||||
|
||||
function_calls = parse_function_call(messages[i]["content"][0]["text"])
|
||||
old_function_call_strings = [
|
||||
function_call.to_string() for function_call in function_calls
|
||||
]
|
||||
for function_call, old_function_call_string in zip(
|
||||
function_calls, old_function_call_strings
|
||||
):
|
||||
if function_call.function_name in [
|
||||
"click",
|
||||
"long_press",
|
||||
"double_click",
|
||||
"move_mouse",
|
||||
]:
|
||||
function_call.parameters["x"] = int(
|
||||
function_call.parameters["x"] * new_image.width
|
||||
)
|
||||
function_call.parameters["y"] = int(
|
||||
function_call.parameters["y"] * new_image.height
|
||||
)
|
||||
elif function_call.function_name in ["swipe", "drag"]:
|
||||
function_call.parameters["from_coord"] = (
|
||||
int(function_call.parameters["from_coord"][0] * new_image.width),
|
||||
int(
|
||||
function_call.parameters["from_coord"][1] * new_image.height
|
||||
),
|
||||
)
|
||||
function_call.parameters["to_coord"] = (
|
||||
int(function_call.parameters["to_coord"][0] * new_image.width),
|
||||
int(
|
||||
function_call.parameters["to_coord"][1] * new_image.height
|
||||
),
|
||||
)
|
||||
messages[i]["content"][0]["text"] = messages[i]["content"][0][
|
||||
"text"
|
||||
].replace(old_function_call_string, function_call.to_string())
|
||||
|
||||
return all_image_inputs
|
||||
|
||||
|
||||
def create_vlm_collate_fn(processor, training_args, script_args):
|
||||
"""Optimized collate function for VLM training that masks system prompt tokens."""
|
||||
|
||||
def collate_fn(examples: list[dict[str, list | str | Image.Image]]):
|
||||
batch_messages: list[list[dict[str, list | str | Image.Image]]] = []
|
||||
assistant_messages: list[list[str]] = []
|
||||
all_image_inputs: list[list[Image.Image]] = []
|
||||
for example in examples:
|
||||
images: list[Image.Image] = example["images"]
|
||||
is_first_user = True
|
||||
sample: list[dict[str, list | str | Image.Image]] = []
|
||||
assistant: list[str] = []
|
||||
for text in example["texts"]:
|
||||
if "system" in text.keys():
|
||||
sample.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": text["system"]}],
|
||||
}
|
||||
)
|
||||
|
||||
if is_first_user:
|
||||
sample.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": images[0]},
|
||||
{"type": "text", "text": text["user"]},
|
||||
],
|
||||
}
|
||||
)
|
||||
is_first_user = False
|
||||
else:
|
||||
sample.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": text["user"]},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
sample.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "\n" + text["assistant"]}],
|
||||
}
|
||||
)
|
||||
assistant.append(text["assistant"] + "<end_of_utterance>")
|
||||
|
||||
batch_messages.append(sample)
|
||||
assistant_messages.append(assistant)
|
||||
all_image_inputs.append(images)
|
||||
|
||||
if script_args.image_resize is not None and "to_pixel_coordinates" in script_args.image_resize and script_args.image_resize["to_pixel_coordinates"]:
|
||||
all_image_inputs = transform_messages(
|
||||
batch_messages,
|
||||
image_resize=script_args.image_resize,
|
||||
)
|
||||
|
||||
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
for messages in batch_messages
|
||||
]
|
||||
|
||||
batch = processor(
|
||||
text=texts,
|
||||
images=all_image_inputs if all_image_inputs else None,
|
||||
max_length=training_args.max_length,
|
||||
truncation=True,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = batch["input_ids"]
|
||||
labels = input_ids.clone()
|
||||
|
||||
assistant_encodings = [
|
||||
processor.tokenizer(
|
||||
assistant_message, add_special_tokens=False, padding=False
|
||||
)["input_ids"]
|
||||
for assistant_message in assistant_messages
|
||||
]
|
||||
|
||||
# Mask out all except the assistant messages
|
||||
for i, assistant_ids_list in enumerate(assistant_encodings):
|
||||
seq = input_ids[i].tolist()
|
||||
assistant_positions: list[int] = []
|
||||
for ids in assistant_ids_list:
|
||||
start_pos = 0
|
||||
while start_pos < len(seq) - len(ids) + 1:
|
||||
found = False
|
||||
for j in range(start_pos, len(seq) - len(ids) + 1):
|
||||
if seq[j : j + len(ids)] == ids:
|
||||
assistant_positions.extend(range(j, j + len(ids)))
|
||||
start_pos = j + len(ids)
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
break
|
||||
|
||||
for pos in range(len(seq)):
|
||||
if pos not in assistant_positions:
|
||||
labels[i, pos] = -100
|
||||
|
||||
|
||||
batch["labels"] = labels
|
||||
return batch
|
||||
|
||||
return collate_fn
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from transformers import AutoProcessor
|
||||
from datasets import load_dataset
|
||||
|
||||
class ScriptArguments:
|
||||
image_resize = None
|
||||
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
|
||||
)
|
||||
processor.image_processor.size = {"longest_edge": 384}
|
||||
collate_fn = create_vlm_collate_fn(processor, script_args=ScriptArguments)
|
||||
max_length = []
|
||||
for dataset_name in ['ricosca']:
|
||||
dataset_max_length = 0
|
||||
data = load_dataset("smolagents/aguvis-stage-1", dataset_name, split="train")
|
||||
print("processing", dataset_name)
|
||||
for example in data:
|
||||
batch = collate_fn([example])
|
||||
dataset_max_length = max(dataset_max_length, batch["input_ids"].shape[1])
|
||||
print("dataset_max_length", dataset_name, dataset_max_length)
|
||||
max_length.append(dataset_max_length)
|
||||
|
||||
print(max_length)
|
||||
print("max_length", max(max_length))
|
||||
open("max_length_384_phase_1.txt", "a").write(str(max(max_length)))
|
||||
54
scripts/agents/smolvlm_inference.py
Normal file
54
scripts/agents/smolvlm_inference.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
import torch
|
||||
from transformers import AutoModelForImageTextToText, AutoProcessor
|
||||
from transformers.models.smolvlm.image_processing_smolvlm import SmolVLMImageProcessor
|
||||
|
||||
|
||||
class TransformersModel:
|
||||
def __init__(self, model_id: str, to_device: str = "cuda"):
|
||||
self.model_id = model_id
|
||||
self.processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
|
||||
self.processor.image_processor.size = {"longest_edge": 3 * 384}
|
||||
self.model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(to_device)
|
||||
|
||||
def generate(self, messages: list[dict], **kwargs):
|
||||
inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(self.model.device, dtype=torch.bfloat16)
|
||||
generated_ids = self.model.generate(**inputs, **kwargs)
|
||||
return self.processor.batch_decode(
|
||||
generated_ids[:, len(inputs["input_ids"][0]) :], skip_special_tokens=True
|
||||
)[0]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from PIL import Image
|
||||
|
||||
model = TransformersModel(
|
||||
model_id="/fsx/amir_mahla/smolagents-SmolVLM2-2.2B-Instruct-Agentic-GUI-phase-1-max-size-1152/checkpoint-800",
|
||||
to_device="cuda:0",
|
||||
)
|
||||
|
||||
image = Image.open("/admin/home/amir_mahla/screensuite/examples/sample_image.png")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": image,
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Given the screenshot, and the instruction, output a click that completes the instruction or targets the given element (always target the center of the element).\n\nOutput the click position as follows:\n\n<think>(thought process)</think><code>click(x, y)</code>\nWith x the number of pixels from the left edge and y the number of pixels from the top edge.\n\nNow write the click needed to complete the instruction:\nInstruction: view more information about bomber\n",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
print(model.generate(messages, max_new_tokens=128))
|
||||
2
setup.py
2
setup.py
|
|
@ -74,6 +74,8 @@ _deps = [
|
|||
"async-lru>=2.0.5",
|
||||
"aiofiles>=24.1.0",
|
||||
"pandas>=2.2.3",
|
||||
"qwen-vl-utils>=0.1.0",
|
||||
"setuptools>=80.9.0",
|
||||
]
|
||||
|
||||
# this is a lookup table with items like:
|
||||
|
|
|
|||
|
|
@ -33,8 +33,8 @@ START_TIME=$(date +%s)
|
|||
echo "START TIME: $(date)"
|
||||
|
||||
# Refresh Weka on h4 cache
|
||||
echo "Refreshing Weka filesystem..."
|
||||
find -L /fsx/h4/ -type f | xargs -d '\n' -r -n512 -P64 weka fs tier fetch
|
||||
# echo "Refreshing Weka filesystem..."
|
||||
# find -L /fsx/${USER}/ -type f | xargs -d '\n' -r -n512 -P64 weka fs tier fetch
|
||||
|
||||
# Default values
|
||||
MODEL=""
|
||||
|
|
|
|||
|
|
@ -68,19 +68,51 @@ class ScriptArguments(trl.ScriptArguments):
|
|||
|
||||
# Override the dataset_name to make it optional
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Dataset name. Can be omitted if using dataset_mixture."}
|
||||
default=None,
|
||||
metadata={"help": "Dataset name. Can be omitted if using dataset_mixture."},
|
||||
)
|
||||
dataset_mixture: Optional[dict[str, Any]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Configuration for creating dataset mixtures with advanced options like shuffling."},
|
||||
metadata={
|
||||
"help": "Configuration for creating dataset mixtures with advanced options like shuffling."
|
||||
},
|
||||
)
|
||||
single_gpu: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Force training on single GPU only, disabling distributed training."
|
||||
},
|
||||
)
|
||||
|
||||
image_resize: Optional[dict[str, int]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Resize the image to the given minimum and maximum pixels."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.dataset_mixture is None:
|
||||
raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided")
|
||||
raise ValueError(
|
||||
"Either `dataset_name` or `dataset_mixture` must be provided"
|
||||
)
|
||||
|
||||
if self.image_resize is not None:
|
||||
if (
|
||||
not isinstance(self.image_resize, dict)
|
||||
# or "min_pixels" not in self.image_resize
|
||||
# or "max_pixels" not in self.image_resize
|
||||
# or "factor" not in self.image_resize
|
||||
or "to_pixel_coordinates" not in self.image_resize
|
||||
or "resolution_max_side" not in self.image_resize
|
||||
):
|
||||
raise ValueError(
|
||||
f"image_resize must be a dictionary with a 'min_pixels', 'max_pixels' and 'factor' key. {self.image_resize}"
|
||||
)
|
||||
|
||||
if self.dataset_mixture is not None:
|
||||
if not isinstance(self.dataset_mixture, dict) or "datasets" not in self.dataset_mixture:
|
||||
if (
|
||||
not isinstance(self.dataset_mixture, dict)
|
||||
or "datasets" not in self.dataset_mixture
|
||||
):
|
||||
raise ValueError(
|
||||
"dataset_mixture must be a dictionary with a 'datasets' key. "
|
||||
"Expected format: {'datasets': [...], 'seed': int}"
|
||||
|
|
@ -110,7 +142,11 @@ class ScriptArguments(trl.ScriptArguments):
|
|||
)
|
||||
|
||||
# Check that column names are consistent across all dataset configs
|
||||
columns_sets = [set(dataset.columns) for dataset in datasets_list if dataset.columns is not None]
|
||||
columns_sets = [
|
||||
set(dataset.columns)
|
||||
for dataset in datasets_list
|
||||
if dataset.columns is not None
|
||||
]
|
||||
if columns_sets:
|
||||
first_columns = columns_sets[0]
|
||||
if not all(columns == first_columns for columns in columns_sets):
|
||||
|
|
@ -135,13 +171,21 @@ class GRPOConfig(trl.GRPOConfig):
|
|||
default_factory=lambda: [],
|
||||
metadata={"help": "The callbacks to run during training."},
|
||||
)
|
||||
chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
|
||||
chat_template: Optional[str] = field(
|
||||
default=None, metadata={"help": "The chat template to use."}
|
||||
)
|
||||
hub_model_revision: Optional[str] = field(
|
||||
default="main", metadata={"help": "The Hub model branch to push the model to."}
|
||||
)
|
||||
num_completions_to_print: int = field(default=0, metadata={"help": "Number of completions to print."})
|
||||
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
|
||||
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
|
||||
num_completions_to_print: int = field(
|
||||
default=0, metadata={"help": "Number of completions to print."}
|
||||
)
|
||||
overwrite_hub_revision: bool = field(
|
||||
default=False, metadata={"help": "Whether to overwrite the Hub revision."}
|
||||
)
|
||||
push_to_hub_revision: bool = field(
|
||||
default=False, metadata={"help": "Whether to push to a Hub revision/branch."}
|
||||
)
|
||||
system_prompt: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The optional system prompt to use."},
|
||||
|
|
@ -149,7 +193,9 @@ class GRPOConfig(trl.GRPOConfig):
|
|||
wandb_log_unique_prompts: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": ("Whether to log the unique prompts to wandb. This will create a new run for each unique prompt.")
|
||||
"help": (
|
||||
"Whether to log the unique prompts to wandb. This will create a new run for each unique prompt."
|
||||
)
|
||||
},
|
||||
)
|
||||
wandb_entity: Optional[str] = field(
|
||||
|
|
@ -180,17 +226,27 @@ class SFTConfig(trl.SFTConfig):
|
|||
default_factory=lambda: [],
|
||||
metadata={"help": "The callbacks to run during training."},
|
||||
)
|
||||
chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
|
||||
chat_template: Optional[str] = field(
|
||||
default=None, metadata={"help": "The chat template to use."}
|
||||
)
|
||||
system_prompt: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The optional system prompt to use for benchmarking."},
|
||||
)
|
||||
vision_model: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether this is a vision-language model training."},
|
||||
)
|
||||
hub_model_revision: Optional[str] = field(
|
||||
default="main",
|
||||
metadata={"help": "The Hub model branch to push the model to."},
|
||||
)
|
||||
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
|
||||
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
|
||||
overwrite_hub_revision: bool = field(
|
||||
default=False, metadata={"help": "Whether to overwrite the Hub revision."}
|
||||
)
|
||||
push_to_hub_revision: bool = field(
|
||||
default=False, metadata={"help": "Whether to push to a Hub revision/branch."}
|
||||
)
|
||||
wandb_entity: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": ("The entity to store runs under.")},
|
||||
|
|
@ -263,7 +319,9 @@ class GRPOScriptArguments(ScriptArguments):
|
|||
)
|
||||
repetition_max_penalty: float = field(
|
||||
default=-1.0,
|
||||
metadata={"help": "Maximum (negative) penalty for for repetition penalty reward"},
|
||||
metadata={
|
||||
"help": "Maximum (negative) penalty for for repetition penalty reward"
|
||||
},
|
||||
)
|
||||
code_language: str = field(
|
||||
default="python",
|
||||
|
|
@ -281,7 +339,9 @@ class GRPOScriptArguments(ScriptArguments):
|
|||
)
|
||||
code_eval_scoring_mode: Literal["pass_fail", "partial", "weighted_sum"] = field(
|
||||
default="weighted_sum",
|
||||
metadata={"help": "use fraction of passed test cases as reward. If false, use 0/1 scoring."},
|
||||
metadata={
|
||||
"help": "use fraction of passed test cases as reward. If false, use 0/1 scoring."
|
||||
},
|
||||
)
|
||||
parallel_code_exec_per_proc: int = field(
|
||||
default=2,
|
||||
|
|
|
|||
|
|
@ -13,18 +13,18 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Supervised fine-tuning script for decoder language models.
|
||||
Supervised fine-tuning script for decoder language models and vision-language models.
|
||||
|
||||
Usage:
|
||||
|
||||
# One 1 node of 8 x H100s
|
||||
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
|
||||
--model_name_or_path open-r1/Qwen2.5-Math-7B-RoPE-300k \
|
||||
--dataset_name open-r1/Mixture-of-Thoughts \
|
||||
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
|
||||
--dataset_name smolagents/gaia-traces \
|
||||
--num_train_epochs 1 \
|
||||
--dataset_config all \
|
||||
--eos_token '<|im_end|>' \
|
||||
--learning_rate 4.0e-5 \
|
||||
--num_train_epochs 5 \
|
||||
--max_seq_length 32768 \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_checkpointing \
|
||||
|
|
@ -39,20 +39,41 @@ import sys
|
|||
|
||||
import datasets
|
||||
import transformers
|
||||
from transformers import set_seed
|
||||
from transformers import (
|
||||
set_seed,
|
||||
AutoModelForVision2Seq,
|
||||
AutoProcessor,
|
||||
LlavaForConditionalGeneration,
|
||||
)
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from open_r1.configs import ScriptArguments, SFTConfig
|
||||
from open_r1.utils import get_dataset, get_model, get_tokenizer
|
||||
from open_r1.utils.callbacks import get_callbacks
|
||||
from open_r1.utils.wandb_logging import init_wandb_training
|
||||
from trl import ModelConfig, SFTTrainer, TrlParser, get_peft_config, setup_chat_format
|
||||
|
||||
from open_r1.configs import ScriptArguments, SFTConfig
|
||||
from open_r1.utils import get_dataset, get_model, get_tokenizer, get_processor
|
||||
from open_r1.utils.callbacks import get_callbacks
|
||||
from open_r1.utils.wandb_logging import init_wandb_training
|
||||
from PIL import Image
|
||||
from transformers import Qwen2VLProcessor
|
||||
from typing import Any
|
||||
from scripts.agents.function_parser import parse_function_call
|
||||
from scripts.agents.smolvlm2_collator import create_vlm_collate_fn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
# Force single GPU mode if requested
|
||||
# if hasattr(script_args, 'single_gpu') and script_args.single_gpu:
|
||||
# logger.info("Single GPU mode requested - setting CUDA_VISIBLE_DEVICES=0")
|
||||
# # Disable distributed training
|
||||
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
# training_args.local_rank = -1
|
||||
# training_args.ddp_backend = None
|
||||
|
||||
set_seed(training_args.seed)
|
||||
|
||||
###############
|
||||
|
|
@ -85,15 +106,42 @@ def main(script_args, training_args, model_args):
|
|||
init_wandb_training(training_args)
|
||||
|
||||
######################################
|
||||
# Load dataset, tokenizer, and model #
|
||||
# Load dataset, processor/tokenizer, and model #
|
||||
######################################
|
||||
dataset = get_dataset(script_args)
|
||||
tokenizer = get_tokenizer(model_args, training_args)
|
||||
model = get_model(model_args, training_args)
|
||||
|
||||
if tokenizer.chat_template is None:
|
||||
logger.info("No chat template provided, defaulting to ChatML.")
|
||||
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
|
||||
if training_args.vision_model:
|
||||
logger.info("Setting up vision-language model training")
|
||||
|
||||
# Set VLM-specific training arguments (following TRL reference)
|
||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
training_args.remove_unused_columns = False
|
||||
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
|
||||
training_args.ddp_find_unused_parameters = True
|
||||
|
||||
# Load processor and model for VLM
|
||||
processor = get_processor(model_args, training_args, script_args)
|
||||
model = get_model(
|
||||
model_args, training_args
|
||||
) # This should return AutoModelForVision2Seq
|
||||
data_collator = create_vlm_collate_fn(processor, training_args, script_args)
|
||||
processing_class = processor.tokenizer
|
||||
model_tags = ["open-r1", "vision-language", "vlm"]
|
||||
|
||||
else:
|
||||
logger.info("Setting up text-only model training")
|
||||
|
||||
# Load tokenizer and model for text-only
|
||||
tokenizer = get_tokenizer(model_args, training_args)
|
||||
model = get_model(model_args, training_args)
|
||||
|
||||
if tokenizer.chat_template is None:
|
||||
logger.info("No chat template provided, defaulting to ChatML.")
|
||||
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
|
||||
|
||||
data_collator = None # Use default
|
||||
processing_class = tokenizer
|
||||
model_tags = ["open-r1"]
|
||||
|
||||
############################
|
||||
# Initialize the SFT Trainer
|
||||
|
|
@ -101,9 +149,14 @@ def main(script_args, training_args, model_args):
|
|||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None),
|
||||
processing_class=tokenizer,
|
||||
eval_dataset=(
|
||||
dataset[script_args.dataset_test_split]
|
||||
if training_args.eval_strategy != "no"
|
||||
else None
|
||||
),
|
||||
processing_class=processing_class,
|
||||
peft_config=get_peft_config(model_args),
|
||||
callbacks=get_callbacks(training_args, model_args),
|
||||
)
|
||||
|
|
@ -128,16 +181,17 @@ def main(script_args, training_args, model_args):
|
|||
# Save model and create model card
|
||||
##################################
|
||||
logger.info("*** Save model ***")
|
||||
# Align the model's generation config with the tokenizer's eos token
|
||||
# to avoid unbounded generation in the transformers `pipeline()` function
|
||||
trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id
|
||||
trainer.save_model(training_args.output_dir)
|
||||
logger.info(f"Model saved to {training_args.output_dir}")
|
||||
try:
|
||||
processor.save_pretrained(training_args.output_dir)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving processor: {e}")
|
||||
|
||||
# Save everything else on main process
|
||||
kwargs = {
|
||||
"dataset_name": script_args.dataset_name,
|
||||
"tags": ["open-r1"],
|
||||
"tags": model_tags,
|
||||
}
|
||||
if trainer.accelerator.is_main_process:
|
||||
trainer.create_model_card(**kwargs)
|
||||
|
|
@ -160,7 +214,10 @@ def main(script_args, training_args, model_args):
|
|||
#############
|
||||
if training_args.push_to_hub:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
trainer.push_to_hub(**kwargs, token=os.getenv("HF_TOKEN"))
|
||||
# Also push processor for VLM models
|
||||
if training_args.vision_model and trainer.accelerator.is_main_process:
|
||||
processor.push_to_hub(training_args.hub_model_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from .data import get_dataset
|
||||
from .import_utils import is_e2b_available, is_morph_available
|
||||
from .model_utils import get_model, get_tokenizer
|
||||
from .model_utils import get_model, get_tokenizer, get_processor
|
||||
|
||||
|
||||
__all__ = ["get_tokenizer", "is_e2b_available", "is_morph_available", "get_model", "get_dataset"]
|
||||
__all__ = ["get_tokenizer", "get_processor", "is_e2b_available", "is_morph_available", "get_model", "get_dataset"]
|
||||
|
|
|
|||
|
|
@ -1,12 +1,20 @@
|
|||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
PreTrainedTokenizer,
|
||||
AutoProcessor,
|
||||
AutoModelForImageTextToText,
|
||||
)
|
||||
|
||||
from trl import ModelConfig, get_kbit_device_map, get_quantization_config
|
||||
|
||||
from ..configs import GRPOConfig, SFTConfig
|
||||
from ..configs import GRPOConfig, SFTConfig, ScriptArguments
|
||||
|
||||
|
||||
def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> PreTrainedTokenizer:
|
||||
def get_tokenizer(
|
||||
model_args: ModelConfig, training_args: SFTConfig | GRPOConfig
|
||||
) -> PreTrainedTokenizer:
|
||||
"""Get the tokenizer for the model."""
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
|
|
@ -20,10 +28,42 @@ def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig
|
|||
return tokenizer
|
||||
|
||||
|
||||
def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoModelForCausalLM:
|
||||
"""Get the model"""
|
||||
def get_processor(
|
||||
model_args: ModelConfig,
|
||||
training_args: SFTConfig | GRPOConfig,
|
||||
script_args: ScriptArguments,
|
||||
) -> AutoProcessor:
|
||||
"""Get the processor for VLM models."""
|
||||
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
)
|
||||
|
||||
# Set the image processor resize size
|
||||
if script_args.image_resize is not None and "resolution_max_side" in script_args.image_resize:
|
||||
processor.image_processor.size = {
|
||||
"longest_edge": script_args.image_resize["resolution_max_side"]
|
||||
}
|
||||
if hasattr(processor, "tokenizer"):
|
||||
processor.tokenizer.truncation_side = "right"
|
||||
processor.tokenizer.padding_side = "right"
|
||||
|
||||
if training_args.chat_template is not None:
|
||||
processor.chat_template = training_args.chat_template
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
def get_model(
|
||||
model_args: ModelConfig, training_args: SFTConfig | GRPOConfig
|
||||
) -> AutoModelForCausalLM | AutoModelForImageTextToText:
|
||||
"""Get the model - supports both text-only and vision-language models"""
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
|
|
@ -35,8 +75,19 @@ def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) ->
|
|||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# Check if this is a VLM model using the explicit flag
|
||||
if hasattr(training_args, "vision_model") and training_args.vision_model:
|
||||
# Load as vision-language model
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
# Load as text-only model
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -1,219 +0,0 @@
|
|||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from e2b_code_interpreter.models import Execution, ExecutionError
|
||||
from open_r1.rewards import code_reward, ioi_code_reward
|
||||
from open_r1.utils.routed_morph import RoutedMorphSandbox
|
||||
from open_r1.utils.routed_sandbox import RoutedSandbox
|
||||
|
||||
|
||||
class TestCodeRewards(unittest.TestCase):
|
||||
def test_python_code_reward(self):
|
||||
# requires E2B, see the README.md file
|
||||
code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled")
|
||||
NUM_SAMPLES = 20
|
||||
samples = code_dataset["train"].select(range(NUM_SAMPLES))
|
||||
test_completions = [[{"content": sample["gold_standard_solution"]}] for sample in samples]
|
||||
reward_kwargs = {"verification_info": [sample["verification_info"] for sample in samples]}
|
||||
rewards = code_reward(test_completions, **reward_kwargs)
|
||||
print(rewards)
|
||||
assert rewards == [1.0] * NUM_SAMPLES
|
||||
|
||||
def test_e2b_router(self):
|
||||
# run router locally: python scripts/e2b_router.py
|
||||
code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled")
|
||||
NUM_SAMPLES = 128
|
||||
samples = code_dataset["train"].select(range(NUM_SAMPLES))
|
||||
test_completions = [[{"content": sample["gold_standard_solution"]}] for sample in samples]
|
||||
reward_kwargs = {"verification_info": [sample["verification_info"] for sample in samples]}
|
||||
rewards = code_reward(test_completions, e2b_router_url="0.0.0.0:8000", **reward_kwargs)
|
||||
print(rewards)
|
||||
assert rewards == [1.0] * NUM_SAMPLES
|
||||
|
||||
def test_e2b_router_parallel(self):
|
||||
# run router locally: python scripts/e2b_router.py
|
||||
code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled")
|
||||
|
||||
BATCH_SIZE = 32
|
||||
NUM_SAMPLES = 256
|
||||
|
||||
def batch_code_reward(examples):
|
||||
test_completions = [[{"content": solution}] for solution in examples["gold_standard_solution"]]
|
||||
reward_kwargs = {
|
||||
"verification_info": [verification_info for verification_info in examples["verification_info"]]
|
||||
}
|
||||
rewards = code_reward(test_completions, e2b_router_url="0.0.0.0:8000", **reward_kwargs)
|
||||
assert rewards == [1.0] * BATCH_SIZE
|
||||
return examples
|
||||
|
||||
code_dataset = code_dataset["train"].select(range(NUM_SAMPLES))
|
||||
code_dataset = code_dataset.map(
|
||||
batch_code_reward,
|
||||
batched=True,
|
||||
batch_size=BATCH_SIZE,
|
||||
num_proc=4,
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
|
||||
def test_ioi_code_reward(self):
|
||||
# This slow test case requires spinning up a bunch (I tested with ~64) of piston workers, see docs here
|
||||
# slurm/piston/README.md
|
||||
code_dataset = load_dataset("open-r1/ioi-reward-test-dataset")
|
||||
NUM_SAMPLES = 16
|
||||
samples = code_dataset["train"].select(range(NUM_SAMPLES))
|
||||
test_completions = [[{"content": f"```cpp\n{sample['sample_solution']}```"}] for sample in samples]
|
||||
keys = [key for key in samples[0] if key not in ["prompt", "completion"]]
|
||||
reward_kwargs = {key: [example[key] for example in samples] for key in keys}
|
||||
rewards = ioi_code_reward(test_completions, **reward_kwargs)
|
||||
print(rewards)
|
||||
assert rewards == [1.0] * NUM_SAMPLES
|
||||
|
||||
def test_e2b_router_run_code_success(self):
|
||||
# run router locally: python scripts/e2b_router.py
|
||||
routed_sandbox = RoutedSandbox(router_url="localhost:8000")
|
||||
scripts = [
|
||||
"print('hello from integration test')",
|
||||
"result = 2 + 2\nprint(result)",
|
||||
]
|
||||
|
||||
results = routed_sandbox.run_code(scripts)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
for result in results:
|
||||
assert isinstance(result, Execution)
|
||||
# assert result.exit_code == 0
|
||||
assert result.error is None
|
||||
assert "hello" in result.logs["stdout"][0] or "4" in result.logs["stdout"][0]
|
||||
|
||||
def test_e2b_router_run_code_with_error(self):
|
||||
# run router locally: python scripts/e2b_router.py
|
||||
|
||||
routed_sandbox = RoutedSandbox(router_url="localhost:8000")
|
||||
scripts = ["print('this is fine')", "print('unterminated string"]
|
||||
|
||||
results = routed_sandbox.run_code(scripts)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
# First one should be okay
|
||||
# assert results[0].exit_code == 0 # Execution object has no attribute 'exit_code'
|
||||
assert results[0].error is None
|
||||
assert "this is fine" in results[0].logs["stdout"][0]
|
||||
|
||||
# Second one should have a syntax error
|
||||
|
||||
# assert results[1].exit_code != 0 # Execution object has no attribute 'exit_code'
|
||||
assert results[1].error is not None
|
||||
assert isinstance(results[1].error, ExecutionError)
|
||||
assert "SyntaxError" in results[1].error.name
|
||||
|
||||
def test_python_code_reward_morph(self):
|
||||
# requires MorphCloud, see the README.md file
|
||||
code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled")
|
||||
NUM_SAMPLES = 20
|
||||
samples = code_dataset["train"].select(range(NUM_SAMPLES))
|
||||
test_completions = [[{"content": sample["gold_standard_solution"]}] for sample in samples]
|
||||
reward_kwargs = {
|
||||
"verification_info": [sample["verification_info"] for sample in samples],
|
||||
"provider_type": "morph",
|
||||
}
|
||||
rewards = code_reward(test_completions, **reward_kwargs)
|
||||
print(rewards)
|
||||
assert rewards == [1.0] * NUM_SAMPLES
|
||||
|
||||
def test_morph_router(self):
|
||||
# run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20
|
||||
code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled")
|
||||
NUM_SAMPLES = 32
|
||||
samples = code_dataset["train"].select(range(NUM_SAMPLES))
|
||||
test_completions = [[{"content": sample["gold_standard_solution"]}] for sample in samples]
|
||||
reward_kwargs = {
|
||||
"verification_info": [sample["verification_info"] for sample in samples],
|
||||
"provider_type": "morph",
|
||||
"morph_router_url": "0.0.0.0:8001",
|
||||
}
|
||||
rewards = code_reward(test_completions, **reward_kwargs)
|
||||
print(rewards)
|
||||
assert rewards == [1.0] * NUM_SAMPLES
|
||||
|
||||
def test_morph_router_parallel(self):
|
||||
# run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20
|
||||
code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated-tested-shuffled")
|
||||
|
||||
BATCH_SIZE = 32
|
||||
NUM_SAMPLES = 256
|
||||
|
||||
def batch_code_reward(examples):
|
||||
test_completions = [[{"content": solution}] for solution in examples["gold_standard_solution"]]
|
||||
reward_kwargs = {
|
||||
"verification_info": [verification_info for verification_info in examples["verification_info"]],
|
||||
"provider_type": "morph",
|
||||
"morph_router_url": "0.0.0.0:8001",
|
||||
}
|
||||
rewards = code_reward(test_completions, **reward_kwargs)
|
||||
assert rewards == [1.0] * BATCH_SIZE
|
||||
return examples
|
||||
|
||||
code_dataset = code_dataset["train"].select(range(NUM_SAMPLES))
|
||||
code_dataset = code_dataset.map(
|
||||
batch_code_reward,
|
||||
batched=True,
|
||||
batch_size=BATCH_SIZE,
|
||||
num_proc=4,
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
|
||||
def test_morph_router_run_code_success(self):
|
||||
# run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20
|
||||
|
||||
routed_sandbox = RoutedMorphSandbox(router_url="localhost:8001")
|
||||
scripts = [
|
||||
"print('hello from morph integration test')",
|
||||
"result = 2 + 2\nprint(result)",
|
||||
]
|
||||
|
||||
results = routed_sandbox.run_code(scripts)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
for result in results:
|
||||
assert result.exception_str is None
|
||||
assert "hello" in result.text or "4" in result.text
|
||||
|
||||
def test_morph_router_run_code_with_error(self):
|
||||
# run router locally: python scripts/morph_router.py --port 8001 --max_num_sandboxes 20
|
||||
|
||||
routed_sandbox = RoutedMorphSandbox(router_url="localhost:8001")
|
||||
scripts = ["print('this is fine with morph')", "print('unterminated string"]
|
||||
|
||||
results = routed_sandbox.run_code(scripts)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
# First one should be okay
|
||||
assert results[0].exception_str is None
|
||||
assert "this is fine with morph" in results[0].text
|
||||
|
||||
# Second one should have a syntax error
|
||||
assert "SyntaxError" in results[1].text
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -1,568 +0,0 @@
|
|||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from open_r1.configs import GRPOScriptArguments
|
||||
from open_r1.rewards import (
|
||||
accuracy_reward,
|
||||
format_reward,
|
||||
get_code_format_reward,
|
||||
get_cosine_scaled_reward,
|
||||
get_repetition_penalty_reward,
|
||||
get_reward_funcs,
|
||||
get_soft_overlong_punishment,
|
||||
len_reward,
|
||||
reasoning_steps_reward,
|
||||
tag_count_reward,
|
||||
)
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class TestGetRewardFuncs(unittest.TestCase):
|
||||
def test_get_reward_funcs(self):
|
||||
"""Test get_reward_funcs with various reward functions."""
|
||||
reward_names = [
|
||||
"accuracy",
|
||||
"format",
|
||||
"reasoning_steps",
|
||||
"cosine",
|
||||
"repetition_penalty",
|
||||
"length",
|
||||
"tag_count",
|
||||
"code",
|
||||
"ioi_code",
|
||||
"code_format",
|
||||
"binary_code",
|
||||
]
|
||||
reward_func_names = [
|
||||
"accuracy_reward",
|
||||
"format_reward",
|
||||
"reasoning_steps_reward",
|
||||
"cosine_scaled_reward",
|
||||
"repetition_penalty_reward",
|
||||
"len_reward",
|
||||
"tag_count_reward",
|
||||
"code_reward",
|
||||
"ioi_code_reward",
|
||||
"code_format_reward",
|
||||
"binary_code_reward",
|
||||
]
|
||||
|
||||
args = GRPOScriptArguments(
|
||||
dataset_name="dummy",
|
||||
reward_funcs=reward_names,
|
||||
)
|
||||
|
||||
reward_funcs = get_reward_funcs(args)
|
||||
self.assertEqual(len(reward_funcs), 11)
|
||||
for func_name, func in zip(reward_func_names, reward_funcs):
|
||||
self.assertEqual(func_name, func.__name__)
|
||||
|
||||
|
||||
class TestRewards(unittest.TestCase):
|
||||
def test_accuracy_reward_correct_answer(self):
|
||||
"""Test accuracy_reward with a correct answer."""
|
||||
completion = [[{"content": r"\boxed{\frac{63}{400}}"}]]
|
||||
solution = [r"\frac{63}{400}"]
|
||||
rewards = accuracy_reward(completion, solution)
|
||||
self.assertEqual(rewards[0], 1.0)
|
||||
|
||||
def test_accuracy_reward_wrong_answer(self):
|
||||
"""Test accuracy_reward with an incorrect answer."""
|
||||
completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
|
||||
solution = [r"\frac{63}{400}"]
|
||||
rewards = accuracy_reward(completion, solution)
|
||||
self.assertEqual(rewards[0], 0.0)
|
||||
|
||||
def test_accuracy_reward_wrong_answer_no_latex(self):
|
||||
"""Test accuracy_reward with an incorrect answer and gold solution with no latex."""
|
||||
completion = [[{"content": r"\boxed{3}"}]]
|
||||
solution = ["6"]
|
||||
rewards = accuracy_reward(completion, solution)
|
||||
self.assertEqual(rewards[0], 0.0)
|
||||
|
||||
def test_format_reward_correct(self):
|
||||
"""Test format_reward with correct format."""
|
||||
completion = [[{"content": "<think>\nSome reasoning\n</think>\n<answer>\nThe answer\n</answer>"}]]
|
||||
rewards = format_reward(completion)
|
||||
self.assertEqual(rewards[0], 1.0)
|
||||
|
||||
def test_format_reward_incorrect(self):
|
||||
"""Test format_reward with incorrect format."""
|
||||
incorrect_formats = [
|
||||
"<think>Only thinking</think>",
|
||||
"<answer>Only answer</answer>",
|
||||
"No tags at all",
|
||||
"<think>Missing closing</think><answer>Missing closing",
|
||||
"<think>Wrong order</answer><answer>Wrong order</think>",
|
||||
]
|
||||
|
||||
for fmt in incorrect_formats:
|
||||
completion = [[{"content": fmt}]]
|
||||
rewards = format_reward(completion)
|
||||
self.assertEqual(rewards[0], 0.0)
|
||||
|
||||
def test_reasoning_steps_reward(self):
|
||||
"""Test reasoning_steps_reward with various formats."""
|
||||
test_cases = [
|
||||
# Full credit cases (3 or more steps)
|
||||
("Step 1: First step\nStep 2: Second step\nStep 3: Third step", 1.0),
|
||||
("First, we do this.\nSecond, we do that.\nFinally, we conclude.", 1.0),
|
||||
# Partial credit cases (less than 3 steps)
|
||||
("Step 1: Only step", 1 / 3),
|
||||
("First, we do this.\nFinally, we conclude.", 2 / 3),
|
||||
# No credit case
|
||||
("Just plain text without any clear steps", 0.0),
|
||||
]
|
||||
|
||||
for content, expected_reward in test_cases:
|
||||
completion = [[{"content": content}]]
|
||||
rewards = reasoning_steps_reward(completion)
|
||||
self.assertAlmostEqual(rewards[0], expected_reward)
|
||||
|
||||
def test_multiple_completions(self):
|
||||
"""Test handling multiple completions at once."""
|
||||
completions = [
|
||||
[{"content": r"\boxed{\frac{63}{400}}"}],
|
||||
[{"content": r"\boxed{\frac{64}{400}}"}],
|
||||
]
|
||||
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
|
||||
|
||||
rewards = accuracy_reward(completions, solutions)
|
||||
self.assertEqual(len(rewards), 2)
|
||||
self.assertEqual(rewards[0], 1.0)
|
||||
self.assertEqual(rewards[1], 0.0)
|
||||
|
||||
def test_cosine_scaled_reward(self):
|
||||
"""Test cosine_scaled_reward with various cases."""
|
||||
# Test parameters
|
||||
test_params = {
|
||||
"min_value_wrong": -1.0,
|
||||
"max_value_wrong": -0.5,
|
||||
"min_value_correct": 0.5,
|
||||
"max_value_correct": 1.0,
|
||||
"max_len": 100,
|
||||
}
|
||||
|
||||
test_cases = [
|
||||
# Correct answers with different lengths
|
||||
(
|
||||
r"\boxed{\frac{63}{400}}",
|
||||
r"\frac{63}{400}",
|
||||
20,
|
||||
0.943,
|
||||
), # Short correct answer
|
||||
(
|
||||
r"\boxed{\frac{63}{400}}",
|
||||
r"\frac{63}{400}",
|
||||
80,
|
||||
0.547,
|
||||
), # Long correct answer
|
||||
# Wrong answers with different lengths
|
||||
(
|
||||
r"\boxed{\frac{64}{400}}",
|
||||
r"\frac{63}{400}",
|
||||
20,
|
||||
-0.942,
|
||||
), # Short wrong answer
|
||||
(
|
||||
r"\boxed{\frac{64}{400}}",
|
||||
r"\frac{63}{400}",
|
||||
80,
|
||||
-0.547,
|
||||
), # Long wrong answer
|
||||
]
|
||||
|
||||
for content, solution, content_len, expected_reward in test_cases:
|
||||
# Pad content to desired length
|
||||
padded_content = content + " " * (content_len - len(content))
|
||||
completion = [[{"content": padded_content}]]
|
||||
|
||||
rewards = get_cosine_scaled_reward(**test_params)(completion, [solution])
|
||||
self.assertAlmostEqual(rewards[0], expected_reward, places=2)
|
||||
|
||||
def test_format_reward_specific_multiline(self):
|
||||
"""Test format_reward with a specific multiline input."""
|
||||
inputs = "<think>\nI will count each distinct object in the image:\n1. Purple scooter\n2. Red bicycle\n3. Green motorcycle\n4. Gray sedan\n5. Yellow school bus\n6. Small green double-decker bus\n7. Small red car\n8. Small purple car\n9. Small gray dirt bike\n\nThere are 9 distinct objects in total.\n</think>\n<answer>\n9\n</answer>"
|
||||
completion = [[{"content": inputs}]]
|
||||
rewards = format_reward(completion)
|
||||
self.assertEqual(rewards[0], 1.0)
|
||||
|
||||
def test_same_length_responses(self):
|
||||
"""Test len_reward when all responses have the same length."""
|
||||
completions = [
|
||||
[{"content": r"\boxed{\frac{63}{400}}"}],
|
||||
[{"content": r"\boxed{\frac{64}{400}}"}],
|
||||
]
|
||||
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
|
||||
|
||||
rewards = len_reward(completions, solutions)
|
||||
self.assertEqual(rewards, [0.0, 0.0])
|
||||
|
||||
def test_different_lengths_correct_answers(self):
|
||||
"""Test len_reward with different length correct answers."""
|
||||
completions = [
|
||||
[{"content": r"\boxed{\frac{63}{400}}"}], # shorter
|
||||
[{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}], # longer
|
||||
]
|
||||
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
|
||||
|
||||
rewards = len_reward(completions, solutions)
|
||||
self.assertGreater(rewards[0], rewards[1]) # shorter answer should get higher reward
|
||||
self.assertAlmostEqual(rewards[0], 0.5) # shortest correct answer gets maximum reward
|
||||
|
||||
def test_different_lengths_incorrect_answers(self):
|
||||
"""Test len_reward with different length incorrect answers."""
|
||||
completions = [
|
||||
[{"content": r"\boxed{\frac{64}{400}}"}], # shorter
|
||||
[{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}], # longer
|
||||
]
|
||||
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
|
||||
|
||||
rewards = len_reward(completions, solutions)
|
||||
self.assertLessEqual(rewards[0], 0.0) # incorrect answers should get non-positive rewards
|
||||
self.assertLessEqual(rewards[1], 0.0)
|
||||
self.assertGreater(rewards[0], rewards[1]) # shorter answer should still be penalized less
|
||||
|
||||
def test_mixed_correctness(self):
|
||||
"""Test len_reward with mix of correct and incorrect answers of different lengths."""
|
||||
completions = [
|
||||
[{"content": r"\boxed{\frac{63}{400}}"}], # correct, shorter
|
||||
[{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}], # correct, longer
|
||||
[{"content": r"\boxed{\frac{64}{400}}"}], # incorrect, shorter
|
||||
[{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}], # incorrect, longer
|
||||
]
|
||||
solutions = [r"\frac{63}{400}"] * 4
|
||||
|
||||
rewards = len_reward(completions, solutions)
|
||||
|
||||
# Shortest correct answer should get positive reward
|
||||
self.assertGreater(rewards[0], 0.0)
|
||||
|
||||
# Longer correct answer might get negative reward:
|
||||
self.assertGreater(rewards[2], rewards[1])
|
||||
self.assertGreaterEqual(rewards[1], rewards[3])
|
||||
|
||||
# Incorrect answers should get non-positive rewards
|
||||
self.assertLessEqual(rewards[2], 0.0)
|
||||
self.assertLessEqual(rewards[3], 0.0)
|
||||
|
||||
# Shorter answers should get better rewards within their correctness category
|
||||
self.assertGreater(rewards[0], rewards[1]) # correct answers
|
||||
self.assertGreater(rewards[2], rewards[3]) # incorrect answers
|
||||
|
||||
def test_unparseable_solution(self):
|
||||
"""Test len_reward with unparseable solution."""
|
||||
completions = [
|
||||
[{"content": r"\boxed{answer}"}],
|
||||
[{"content": r"\boxed{answer} " + "x" * 10}],
|
||||
]
|
||||
solutions = ["unparseable_latex", "unparseable_latex"]
|
||||
|
||||
rewards = len_reward(completions, solutions)
|
||||
self.assertGreater(rewards[0], rewards[1]) # shorter answer should still get better reward
|
||||
self.assertAlmostEqual(rewards[0], 0.5) # treated as correct, shortest gets maximum reward
|
||||
|
||||
|
||||
class TestRepetitionPenaltyReward(unittest.TestCase):
|
||||
def test_positive_max_penalty_raises_value_error(self):
|
||||
with self.assertRaises(ValueError):
|
||||
get_repetition_penalty_reward(ngram_size=2, max_penalty=1.0)
|
||||
with self.assertRaisesRegex(ValueError, "max_penalty 1.5 should not be positive"):
|
||||
get_repetition_penalty_reward(ngram_size=2, max_penalty=1.5)
|
||||
|
||||
def test_no_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
|
||||
completions = [[{"content": "this is a test sentence"}]]
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_full_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
|
||||
completions = [[{"content": "this this this this this"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
# (1 - 1/4) * -1 = -0.75
|
||||
self.assertEqual(rewards, [-0.75])
|
||||
|
||||
def test_partial_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
|
||||
completions = [[{"content": "this is a this is a test"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
# Unique 2-grams: (this, is), (is, a), (a, this), (a, test). 4 unique out of 6 total
|
||||
# (1 - 4/6) * -1 = -1/3 = -0.3333...
|
||||
self.assertAlmostEqual(rewards[0], -1 / 3)
|
||||
|
||||
def test_multiple_completions(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)
|
||||
completions = [
|
||||
[{"content": "this is a test"}],
|
||||
[{"content": "test test test test"}],
|
||||
]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
# Completion 1: (this, is, a), (is, a, test) -> 2 unique / 2 total -> (1 - 2/2) * -0.5 = 0
|
||||
# Completion 2: (test, test, test) -> 1 unique / 2 total -> (1 - 1/2) * -0.5 = -0.25
|
||||
self.assertAlmostEqual(rewards[0], 0.0)
|
||||
self.assertAlmostEqual(rewards[1], -0.25)
|
||||
|
||||
def test_empty_completion(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
|
||||
completions = [[{"content": ""}]]
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_different_ngram_size(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-2.0)
|
||||
completions = [[{"content": "this is a this is a test"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
self.assertAlmostEqual(rewards[0], -0.4)
|
||||
|
||||
def test_mixed_case(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
|
||||
completions = [
|
||||
[{"content": "This is A Test"}],
|
||||
[{"content": "this IS a test"}],
|
||||
]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
# both completions should produce the same reward, because the text gets lowercased
|
||||
self.assertAlmostEqual(rewards[0], rewards[1])
|
||||
|
||||
def test_one_word_completion(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "word"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_two_word_completion(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "two words"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_three_word_completion(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "three different words"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_three_word_repetition_completion(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "word word word word"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [-0.5])
|
||||
|
||||
def test_four_word_completion_with_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "one two one two"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
# ngrams are (one two one) (two one two). unique is 2 and count is 2, therefore (1-1) * -1.
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_five_word_completion_with_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)
|
||||
completions = [[{"content": "A B C A B"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
# (A B C) (B C A) (C A B). unique is 3. count is 3 (1-1) * -.5 = 0
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_six_word_completion_with_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "A B C A B C"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [-0.25])
|
||||
|
||||
def test_long_completion_with_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "A B C A B C E F G A B C A B C"}]]
|
||||
rewards = reward_fn(completions)
|
||||
self.assertAlmostEqual(rewards[0], -0.3846, places=4)
|
||||
|
||||
def test_long_completion_without_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "A B C D E F G H I J K L"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_tag_count_rewards_all_correct(self):
|
||||
"""Test tag_count_reward with correct tags."""
|
||||
completion = [[{"content": "<think>\nSome reasoning\n</think>\n<answer>\nThe answer\n</answer>"}]]
|
||||
rewards = tag_count_reward(completion)
|
||||
self.assertEqual(rewards[0], 1.0)
|
||||
|
||||
def test_tag_count_rewards_missing_think_begin(self):
|
||||
"""Test tag_count_reward with missing <think> tag."""
|
||||
completion = [[{"content": "Some reasoning\n</think>\n<answer>\nThe answer\n</answer>"}]]
|
||||
rewards = tag_count_reward(completion)
|
||||
self.assertEqual(rewards[0], 0.75)
|
||||
|
||||
def test_tag_count_rewards_missing_think_end(self):
|
||||
"""Test tag_count_reward with missing </think> tag."""
|
||||
completion = [[{"content": "<think>\nSome reasoning\n<answer>\nThe answer\n</answer>"}]]
|
||||
rewards = tag_count_reward(completion)
|
||||
self.assertEqual(rewards[0], 0.75)
|
||||
|
||||
def test_tag_count_rewards_missing_answer_begin(self):
|
||||
"""Test tag_count_reward with missing <answer> tag."""
|
||||
completion = [[{"content": "<think>\nSome reasoning\n</think>\nThe answer\n</answer>"}]]
|
||||
rewards = tag_count_reward(completion)
|
||||
self.assertEqual(rewards[0], 0.75)
|
||||
|
||||
def test_tag_count_rewards_missing_answer_end(self):
|
||||
"""Test tag_count_reward with missing </answer> tag."""
|
||||
completion = [[{"content": "<think>\nSome reasoning\n</think>\n<answer>\nThe answer"}]]
|
||||
rewards = tag_count_reward(completion)
|
||||
self.assertEqual(rewards[0], 0.75)
|
||||
|
||||
def test_tag_count_rewards_missing_all_tags(self):
|
||||
"""Test tag_count_reward with missing all tags."""
|
||||
completion = [[{"content": "Some reasoning\nThe answer"}]]
|
||||
rewards = tag_count_reward(completion)
|
||||
self.assertEqual(rewards[0], 0.0)
|
||||
|
||||
def test_full_repetition_with_language(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0, language="en")
|
||||
completions = [[{"content": "that that that that that"}]]
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [-0.75])
|
||||
# begin test for zh language
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0, language="zh")
|
||||
completions = [[{"content": "这个这个这个这个这个"}]]
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [-0.75])
|
||||
|
||||
def test_soft_overlong_punishment_short_completion(self):
|
||||
"""Test soft overlong punishment reward function with a short completion."""
|
||||
# length 50, with max=100 and soft cache=20, reward should be 0.
|
||||
reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
|
||||
completion_ids = [[1] * 50] # 50 <= 80
|
||||
rewards = reward_fn(completion_ids=completion_ids)
|
||||
self.assertEqual(rewards, [0])
|
||||
|
||||
def test_soft_overlong_punishment_long_completion(self):
|
||||
"""Test soft overlong punishment reward function with a longer than max completion."""
|
||||
# 110 > 100, reward should be -1.
|
||||
reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
|
||||
completion_ids = [[1] * 110]
|
||||
rewards = reward_fn(completion_ids)
|
||||
self.assertEqual(rewards, [-1])
|
||||
|
||||
def test_soft_overlong_punishment_intermediate_completion(self):
|
||||
"""Test soft overlong punishment reward function for intermediate length completion."""
|
||||
reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
|
||||
completion_ids = [[1] * 90] # 90 is between 80 and 100
|
||||
rewards = reward_fn(completion_ids)
|
||||
self.assertAlmostEqual(rewards[0], -0.5, places=4)
|
||||
|
||||
|
||||
class TestCodeFormat(unittest.TestCase):
|
||||
def test_correct_python_format(self):
|
||||
"""Test code format reward with correct Python format."""
|
||||
completion = [
|
||||
[
|
||||
{
|
||||
"content": "<think>\nLet's solve this\nStep 1: First step\n</think>\n<answer>\n```python\ndef hello():\n print('world')\n```\n</answer>"
|
||||
}
|
||||
]
|
||||
]
|
||||
reward_fn = get_code_format_reward(language="python")
|
||||
rewards = reward_fn(completion)
|
||||
self.assertEqual(rewards[0], 1.0)
|
||||
|
||||
def test_incorrect_formats(self):
|
||||
"""Test code format reward with various incorrect formats."""
|
||||
incorrect_formats = [
|
||||
# Missing think/answer tags
|
||||
"```python\ndef hello():\n print('world')\n```",
|
||||
# Missing code block
|
||||
"<think>Some thinking</think><answer>Just plain text</answer>",
|
||||
# Wrong language
|
||||
"<think>Analysis</think><answer>```javascript\nconsole.log('hello');\n```</answer>",
|
||||
# Missing language identifier
|
||||
"<think>Analysis</think><answer>```\ndef hello(): pass\n```</answer>",
|
||||
# Wrong order of tags
|
||||
"<answer>```python\ndef hello(): pass\n```</answer><think>Analysis</think>",
|
||||
]
|
||||
|
||||
reward_fn = get_code_format_reward(language="python")
|
||||
for fmt in incorrect_formats:
|
||||
completion = [[{"content": fmt}]]
|
||||
rewards = reward_fn(completion)
|
||||
self.assertEqual(rewards[0], 0.0)
|
||||
|
||||
def test_multiple_code_blocks(self):
|
||||
"""Test format reward with multiple code blocks in think and answer sections."""
|
||||
completion = [
|
||||
[
|
||||
{
|
||||
"content": "<think>\nHere's an example:\n```python\nx = 1\n```\nNow the solution:\n</think>\n<answer>\n```python\ndef solution():\n return 42\n```\n</answer>"
|
||||
}
|
||||
]
|
||||
]
|
||||
reward_fn = get_code_format_reward(language="python")
|
||||
rewards = reward_fn(completion)
|
||||
self.assertEqual(rewards[0], 1.0)
|
||||
|
||||
def test_different_languages(self):
|
||||
"""Test code format reward with different programming languages."""
|
||||
completion = [
|
||||
[
|
||||
{
|
||||
"content": "<think>\nAnalysis\n</think>\n<answer>\n```javascript\nconsole.log('hello');\n```\n</answer>"
|
||||
}
|
||||
]
|
||||
]
|
||||
|
||||
# Test with JavaScript
|
||||
js_reward_fn = get_code_format_reward(language="javascript")
|
||||
rewards = js_reward_fn(completion)
|
||||
self.assertEqual(rewards[0], 1.0)
|
||||
|
||||
# Same completion should fail for Python
|
||||
py_reward_fn = get_code_format_reward(language="python")
|
||||
rewards = py_reward_fn(completion)
|
||||
self.assertEqual(rewards[0], 0.0)
|
||||
|
||||
def test_multiline_code(self):
|
||||
"""Test format reward with complex multiline code blocks."""
|
||||
completion = [
|
||||
[
|
||||
{
|
||||
"content": "<think>\nHere's the analysis\n</think>\n<answer>\n```python\nclass Solution:\n def __init__(self):\n self.value = 42\n \n def get_value(self):\n return self.value\n```\n</answer>"
|
||||
}
|
||||
]
|
||||
]
|
||||
reward_fn = get_code_format_reward(language="python")
|
||||
rewards = reward_fn(completion)
|
||||
self.assertEqual(rewards[0], 1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -1,129 +0,0 @@
|
|||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
from dataclasses import asdict
|
||||
|
||||
from datasets import DatasetDict, load_dataset
|
||||
|
||||
from open_r1.configs import DatasetConfig, DatasetMixtureConfig, ScriptArguments
|
||||
from open_r1.utils.data import get_dataset
|
||||
|
||||
|
||||
class TestGetDataset(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.dataset_name = "trl-internal-testing/zen"
|
||||
cls.dataset_config = "conversational_preference"
|
||||
cls.ref_dataset = load_dataset(cls.dataset_name, cls.dataset_config)
|
||||
|
||||
def test_dataset_and_config_name(self):
|
||||
args = ScriptArguments(dataset_name=self.dataset_name, dataset_config=self.dataset_config)
|
||||
dataset = get_dataset(args)
|
||||
self.assertIsInstance(dataset, DatasetDict)
|
||||
self.assertIn("train", dataset)
|
||||
self.assertEqual(len(dataset["train"]), len(self.ref_dataset["train"]))
|
||||
|
||||
def test_unweighted_mixture(self):
|
||||
"""Mix train and test splits of the same dataset."""
|
||||
dataset_configs = [
|
||||
DatasetConfig(id=self.dataset_name, config=self.dataset_config, split="train", columns=None, weight=None),
|
||||
DatasetConfig(id=self.dataset_name, config=self.dataset_config, split="test", columns=None, weight=None),
|
||||
]
|
||||
dataset_mixture = DatasetMixtureConfig(
|
||||
datasets=dataset_configs,
|
||||
)
|
||||
args = ScriptArguments(dataset_mixture=asdict(dataset_mixture))
|
||||
dataset = get_dataset(args)
|
||||
self.assertIsInstance(dataset, DatasetDict)
|
||||
self.assertIn("train", dataset)
|
||||
self.assertEqual(len(dataset["train"]), len(self.ref_dataset["train"]) + len(self.ref_dataset["test"]))
|
||||
|
||||
def test_weighted_mixture(self):
|
||||
"""Test loading a dataset mixture with weights."""
|
||||
dataset_configs = [
|
||||
DatasetConfig(id=self.dataset_name, config=self.dataset_config, split="train", columns=None, weight=0.25),
|
||||
DatasetConfig(id=self.dataset_name, config=self.dataset_config, split="test", columns=None, weight=0.5),
|
||||
]
|
||||
dataset_mixture = DatasetMixtureConfig(
|
||||
datasets=dataset_configs,
|
||||
)
|
||||
args = ScriptArguments(dataset_mixture=asdict(dataset_mixture))
|
||||
dataset = get_dataset(args)
|
||||
self.assertIsInstance(dataset, DatasetDict)
|
||||
self.assertIn("train", dataset)
|
||||
self.assertEqual(
|
||||
len(dataset["train"]), len(self.ref_dataset["train"]) // 4 + len(self.ref_dataset["test"]) // 2
|
||||
)
|
||||
|
||||
def test_mixture_and_test_split(self):
|
||||
"""Test loading a dataset mixture with test split."""
|
||||
dataset_configs = [
|
||||
DatasetConfig(
|
||||
id=self.dataset_name, config=self.dataset_config, split="train[:10]", columns=None, weight=None
|
||||
),
|
||||
]
|
||||
dataset_mixture = DatasetMixtureConfig(datasets=dataset_configs, test_split_size=0.2)
|
||||
args = ScriptArguments(dataset_name=None, dataset_mixture=asdict(dataset_mixture))
|
||||
dataset = get_dataset(args)
|
||||
self.assertIsInstance(dataset, DatasetDict)
|
||||
self.assertIn("train", dataset)
|
||||
self.assertIn("test", dataset)
|
||||
self.assertEqual(len(dataset["train"]), 8)
|
||||
self.assertEqual(len(dataset["test"]), 2)
|
||||
|
||||
def test_mixture_column_selection(self):
|
||||
"""Test loading a dataset mixture with column selection."""
|
||||
dataset_configs = [
|
||||
DatasetConfig(
|
||||
id=self.dataset_name,
|
||||
config=self.dataset_config,
|
||||
split="train",
|
||||
columns=["prompt", "chosen"],
|
||||
weight=None,
|
||||
),
|
||||
]
|
||||
dataset_mixture = DatasetMixtureConfig(
|
||||
datasets=dataset_configs,
|
||||
)
|
||||
args = ScriptArguments(dataset_mixture=asdict(dataset_mixture))
|
||||
dataset = get_dataset(args)
|
||||
self.assertIsInstance(dataset, DatasetDict)
|
||||
self.assertIn("train", dataset)
|
||||
self.assertIn("prompt", dataset["train"].column_names)
|
||||
self.assertIn("chosen", dataset["train"].column_names)
|
||||
|
||||
def test_mixture_with_mismatched_columns(self):
|
||||
dataset_configs = [
|
||||
DatasetConfig(
|
||||
id=self.dataset_name, config=self.dataset_config, split="train", columns=["prompt"], weight=None
|
||||
),
|
||||
DatasetConfig(
|
||||
id=self.dataset_name, config=self.dataset_config, split="train", columns=["chosen"], weight=None
|
||||
),
|
||||
]
|
||||
dataset_mixture = DatasetMixtureConfig(
|
||||
datasets=dataset_configs,
|
||||
)
|
||||
with self.assertRaises(ValueError) as context:
|
||||
_ = ScriptArguments(dataset_mixture=asdict(dataset_mixture))
|
||||
self.assertIn("Column names must be consistent", str(context.exception))
|
||||
|
||||
def test_no_dataset_name_or_mixture(self):
|
||||
with self.assertRaises(ValueError) as context:
|
||||
_ = ScriptArguments(dataset_name=None, dataset_mixture=None)
|
||||
self.assertIn("Either `dataset_name` or `dataset_mixture` must be provided", str(context.exception))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
5
token_stat
Normal file
5
token_stat
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
max length: 3678
|
||||
total user token: 52272359
|
||||
total system token: 326634875
|
||||
total answer token: 32224010
|
||||
total token: 843473151
|
||||
Loading…
Add table
Add a link
Reference in a new issue