mirror of
https://github.com/huggingface/open-r1.git
synced 2026-06-24 01:54:06 +00:00
Compare commits
37 commits
main
...
agent-trac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2fbac03f07 | ||
|
|
3b021de966 | ||
|
|
8cc3983a7a | ||
|
|
c1cea15129 | ||
|
|
5d7205d888 | ||
|
|
91e4dc16dc | ||
|
|
ef3f888a99 | ||
|
|
9cdf0d92dc | ||
|
|
b6de9cbc59 | ||
|
|
e245aa0335 | ||
|
|
6a9db1b584 | ||
|
|
ce7d8bdc45 | ||
|
|
5ed2005cb1 | ||
|
|
28afbef24c | ||
|
|
7bcb96e699 | ||
|
|
2e7d1dad0f | ||
|
|
64ae55198f | ||
|
|
884c8e94f2 | ||
|
|
0adc082393 | ||
|
|
52ac4e2fc2 | ||
|
|
23c2128d20 | ||
|
|
f6f138bb61 | ||
|
|
b2996c1bae | ||
|
|
b402450b78 | ||
|
|
6df61613bb | ||
|
|
c8aa2c4c27 | ||
|
|
69d55f6226 | ||
|
|
319ae52c1d | ||
|
|
28bc464568 | ||
|
|
f35337e681 | ||
|
|
1a7becfeb6 | ||
|
|
7a1fb98c8b | ||
|
|
7d9fc6e483 | ||
|
|
38bfa931ae | ||
|
|
a6f5a15129 | ||
|
|
6d0963ebda | ||
|
|
e7df0369d0 |
6 changed files with 1638 additions and 0 deletions
63
agentic-traces/README.md
Normal file
63
agentic-traces/README.md
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
# Generate agent traces
|
||||
|
||||
## Step 1: Install (set up the environment)
|
||||
|
||||
```bash
|
||||
make install
|
||||
```
|
||||
|
||||
```bash
|
||||
source openr1/bin/activate
|
||||
uv pip install -e ".[smolagents,jupyter]"
|
||||
```
|
||||
|
||||
## Step 2: Start the R1 server
|
||||
|
||||
Before submitting the jobs, do not forget to add the router address to the `serve_r1.slurm` file.
|
||||
|
||||
```bash
|
||||
sbatch slurm/serve_router.slurm
|
||||
sbatch slurm/serve_r1.slurm
|
||||
```
|
||||
|
||||
## Step 3: Generate traces
|
||||
|
||||
This takes ~3 days to complete.
|
||||
|
||||
```bash
|
||||
sbatch slurm/agentic_generation.slurm
|
||||
```
|
||||
|
||||
## Step 4: Process the traces and upload dataset to the hub
|
||||
|
||||
This is done in a jupyter notebook for ease of use during development.
|
||||
|
||||
Follow the instructions in eda.ipynb to process the traces into a training dataset.
|
||||
The notebook filters out the failed generation traces, then uploads the dataset to the hub for later use.
|
||||
|
||||
**TODO:**
|
||||
- filter the traces to keep traces that pass the test cases
|
||||
- filter by length of the generation, so traces that converge quickly are favoured.
|
||||
|
||||
**Remarks:**
|
||||
Right now, the `generate_agent_traces.py` file seems to be buggy: it does not generate a single correct trace. By correct, I mean a trace that passes the test cases.
|
||||
|
||||
The dataset can be found at https://huggingface.co/datasets/baptistecolle/codeforces-agentic-generations
|
||||
|
||||
## Step 5: Train on the traces and upload the model to the hub
|
||||
|
||||
```bash
|
||||
sbatch --nodes=1 --time=8:00:00 slurm/train.slurm Qwen2.5-1.5B-Instruct sft demo_agentic_trace zero3 '--per_device_train_batch_size=1 --num_train_epochs=5'
|
||||
```
|
||||
|
||||
The trained model can be found at https://huggingface.co/baptistecolle/Qwen2.5-1.5B-Open-R1-Distill-Agentic-Trace
|
||||
|
||||
## Step 6: Test the model
|
||||
The `generate_agent_traces.py` file first needs to be fixed before testing the model, I believe (see: the `generate_agent_traces.py` file is not working).
|
||||
**TODO:** create some custom metrics in lighteval for the agentic traces.
|
||||
|
||||
# TODOs:
|
||||
- **The `generate_agent_traces.py` file is not working**: most of the trace generations fail. Furthermore, based on the EDA (exploratory data analysis), none of the generated traces actually pass the test cases: almost all traces end with `Error:\\nReached max steps.`, so none of the generated traces actually solve the test cases.
|
||||
|
||||
# Current status
|
||||
- The pipeline is present, now we just need to debug it to increase performance.
|
||||
1201
agentic-traces/eda.ipynb
Normal file
1201
agentic-traces/eda.ipynb
Normal file
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,44 @@
|
|||
# SFT recipe for fine-tuning on the agentic-trace dataset described in
# agentic-traces/README.md (presumably the `demo_agentic_trace` config named in
# that README's Step 5 command — confirm against the file path).

# Model arguments
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2

# Data training arguments
# Dataset referenced in the README as the upload target of the trace pipeline.
dataset_name: baptistecolle/codeforces-agentic-generations
dataset_num_proc: 48

# SFT trainer config
bf16: true
do_eval: false
eval_strategy: 'no'
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
# Matches the model name the README links to on the hub.
hub_model_id: Qwen2.5-1.5B-Open-R1-Distill-Agentic-Trace
hub_strategy: every_save
learning_rate: 5.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
  min_lr_rate: 0.1
packing: true
max_length: 16384
max_steps: -1
# NOTE: the README's Step 5 command overrides this with --num_train_epochs=5.
num_train_epochs: 1
output_dir: data/Qwen2.5-1.5B-Open-R1-Distill-Agentic-Trace
overwrite_output_dir: true
per_device_eval_batch_size: 16
# NOTE: the README's Step 5 command overrides this with --per_device_train_batch_size=1.
per_device_train_batch_size: 16
push_to_hub: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
use_liger: true
warmup_ratio: 0.05
|
||||
302
scripts/generate_agent_traces.py
Normal file
302
scripts/generate_agent_traces.py
Normal file
|
|
@ -0,0 +1,302 @@
|
|||
import argparse
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from threading import Lock
|
||||
from typing import Set, Any, List
|
||||
from pathlib import Path
|
||||
import traceback
|
||||
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
import requests.adapters
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from smolagents import CodeAgent, Tool
|
||||
from smolagents.models import get_clean_message_list
|
||||
|
||||
# Guards appends to the shared JSONL output file; process_example holds it while writing.
file_lock = Lock()
|
||||
|
||||
# Loaded once at import time; generate_completion_from_messages uses it only to
# render the chat template and print the prompt's token count.
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")
|
||||
|
||||
# Startup marker, emitted as soon as the module body runs.
print("Launching generation")
|
||||
class ModifiedFinalAnswerTool(Tool):
    """`final_answer` tool that returns the source code of the submitted function.

    The task prompt asks the agent to call ``final_answer(your_function)``;
    ``forward`` turns that function object into its source text.
    """

    name = "final_answer"
    description = "Provides a final answer to the given problem."
    inputs = {
        'answer_function': {
            'type': 'any',
            'description': 'The final function that solves the problem',
        }
    }
    output_type = "string"

    def __init__(self, *args, **kwargs):
        # Constructor arguments are accepted but ignored; only the flag is set.
        self.is_initialized = False

    def forward(self, answer_function: Any) -> str:
        """Return the source of ``answer_function`` and echo it to stdout."""
        function_source = inspect.getsource(answer_function)
        print("USING MODIFIED FINAL ANSWER TOOL, got source code:\n", function_source)
        return function_source
|
||||
|
||||
class ChatMessage:
    """Minimal message wrapper exposing the generated text as ``content``."""

    def __init__(self, content):
        # Keep the completion text so callers can read it back via `.content`.
        self.content = content
|
||||
|
||||
def generate_completion_from_messages(
    session, messages, args, stop_sequences, max_retries=10, retry_delay=20
) -> str:
    """Request one chat completion from the inference server, retrying on failure.

    Args:
        session: ``requests.Session``-like object used to POST the request.
        messages: Chat messages, sent as the ``messages`` field of the payload.
        args: Parsed CLI arguments; reads ``api_addr``, ``max_tokens``,
            ``temperature`` and ``top_p``.
        stop_sequences: Value forwarded as the ``stop`` field of the payload.
        max_retries: Maximum number of attempts before giving up.
        retry_delay: Seconds to wait between two attempts.

    Returns:
        The ``content`` of the first choice in the server's response.

    Raises:
        RuntimeError: If no attempt produced a usable response.
    """
    # Loop-invariant: the prompt does not change between retries, so its size
    # is logged once instead of being re-tokenized on every attempt.
    formatted_chat = tokenizer.apply_chat_template(messages, tokenize=False)
    print("Input token count:", len(tokenizer.encode(formatted_chat)))

    url = f"http://{args.api_addr}/v1/chat/completions"
    payload = {
        "model": "default",
        "messages": messages,
        "max_tokens": args.max_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "stop": stop_sequences,
    }

    for attempts_left in range(max_retries - 1, -1, -1):
        try:
            # Add a small random delay to prevent overwhelming the API
            time.sleep(random.uniform(0.0, 0.1))
            response = session.post(
                url,
                json=payload,
                headers={"Authorization": "Bearer EMPTY"},
                timeout=2 * 60 * 60,
            )

            if response.status_code >= 400:
                # No exception is active on this path, so there is no
                # traceback to print; the status and body are the diagnostics.
                print(f"HTTP Error {response.status_code}: {response.reason}")
                print(f"Response content: {response.text}")
            else:
                try:
                    return response.json()["choices"][0]["message"]["content"]
                except (ValueError, KeyError, IndexError, TypeError) as e:
                    # ValueError: body is not JSON. The others: body is JSON
                    # but lacks the expected choices[0].message.content path.
                    # Both are treated as a retryable bad response.
                    print(f"JSON parsing error: {e}")
                    print(f"Response content: {response.text}")
                    traceback.print_exc()

        except requests.exceptions.RequestException as e:
            print(f"API request error (will retry): {e}")
            traceback.print_exc()

        # Only wait when another attempt will actually follow.
        if attempts_left:
            time.sleep(retry_delay)

    raise RuntimeError("Failed to get a valid response after multiple retries")
|
||||
|
||||
def get_agent_run(session, task, args):
    """Run a single ``CodeAgent`` attempt on ``task``.

    Args:
        session: HTTP session forwarded to ``generate_completion_from_messages``.
        task: Prompt text handed to ``agent.run``.
        args: Parsed CLI arguments forwarded to the completion helper.

    Returns:
        A 2-tuple ``(memory_messages, output)``: the result of
        ``agent.write_memory_to_messages()`` followed by the value returned by
        ``agent.run``. If the run raises, ``(None, None)`` is returned so that
        callers can always unpack two values.
    """
    role_conversions = {"system": "user", "tool-call": "assistant", "tool-response": "user"}

    def model(messages, stop_sequences=None):
        # Adapter giving the agent a callable model backed by the HTTP endpoint.
        cleaned_messages = get_clean_message_list(
            messages, role_conversions, flatten_messages_as_text=True
        )
        result = generate_completion_from_messages(session, cleaned_messages, args, stop_sequences)
        return ChatMessage(content=result)

    agent = CodeAgent(
        model=model,
        tools=[ModifiedFinalAnswerTool()],
        additional_authorized_imports=["sympy", "numpy", "math"],
        max_steps=10,
        verbosity_level=2,
    )

    try:
        output = agent.run(task)
        return agent.write_memory_to_messages(), output
    except Exception as e:
        # Top-level boundary for one agent run: log with traceback and report
        # failure as a pair of Nones rather than a bare None.
        print(f"Error when generating agentic trace: {e}")
        traceback.print_exc()
        return None, None
|
||||
|
||||
# Value types that are kept when copying dict fields into a JSON record.
_JSON_SAFE_TYPES = (str, int, float, bool, type(None), list, dict)


def _build_agent_prompt(task_description):
    """Wrap the raw task statement in the fixed instructions given to the agent."""
    return f"""Here is a task to solve using a function:
{task_description}

Now write a function that solves the problem, test it and return it using final_answer(your_function).
The function should take the inputs described in the task above, using them in this way: the function will be passed the 'lines' described in the task as different arguments.
For instance:
- if the task says 'the first line is a number, the second line is a list of numbers', your function should take two arguments like this: def your_function(n, numbers).
- if the task says 'the first line will contain a number n, the n lines after that will be strings', your function should take flexible arguments like this: def your_function(n, *n_lines).
Make sure to properly extract the inputs from the string arguments.
ALWAYS RUN THE FUNCTION IN A CODE SNIPPET WITH TEST CASES BEFORE RETURNING IT.
"""


def _to_serializable_generation(generation):
    """Reduce one run's memory to JSON-friendly data (list of dicts, str or None)."""
    if generation is None:
        return None
    if isinstance(generation, list):
        return [
            {k: v for k, v in msg.items() if isinstance(v, _JSON_SAFE_TYPES)}
            for msg in generation
            if isinstance(msg, dict)
        ]
    # Unknown shape: store its string form rather than dropping it.
    return str(generation)


def process_example(example, session, args, output_file, pbar=None):
    """Generate ``args.num_generations`` agent runs for one example and append them to disk.

    Args:
        example: Dataset row (mapping); ``example[args.prompt_column]`` is the task text.
        session: HTTP session forwarded to ``get_agent_run``.
        args: Parsed CLI arguments; reads ``prompt_column`` and ``num_generations``.
        output_file: Path of the JSONL file the record is appended to.
        pbar: Optional progress bar; advanced by one whether or not the
            example succeeds.

    Returns:
        The record that was built (original fields plus ``generations``,
        ``final_outputs``, ``finish_reasons`` and ``api_metadata``), or ``None``
        if any run failed or an error occurred.
    """
    # Built outside the try block on purpose: a wrong --prompt-column should
    # surface as an error instead of being swallowed per example.
    prompt = _build_agent_prompt(example[args.prompt_column])

    try:
        agent_memories, agent_outputs = [], []
        for _ in range(args.num_generations):
            run = get_agent_run(session, prompt, args)
            # get_agent_run returns (memory_messages, final_output); a failed
            # run is reported either as None or as a pair of Nones.
            agent_memory, agent_output = run if run is not None else (None, None)
            agent_memories.append(agent_memory)
            agent_outputs.append(agent_output)

        # The whole example is discarded as soon as one run failed.
        if any(memory is None for memory in agent_memories):
            print("Error processing example")
            return None

        # One placeholder per generation: the endpoint's finish reason and
        # metadata are not collected by this script.
        finish_reasons = [None] * len(agent_outputs)
        api_metadata = [None] * len(agent_outputs)

        # Combine original dataset fields with generations
        result = {
            **example,  # Preserve all original dataset fields
            "generations": [_to_serializable_generation(memory) for memory in agent_memories],
            "final_outputs": agent_outputs,
            "finish_reasons": finish_reasons,
            "api_metadata": api_metadata,
        }

        # Serialize before taking the lock so it is held only for the write.
        try:
            line = json.dumps(result)
        except TypeError as e:
            print(f"JSON serialization error: {e}")
            # Fallback: store with minimal information
            fallback_result = {
                **{k: v for k, v in example.items() if isinstance(v, _JSON_SAFE_TYPES)},
                "error": f"Failed to serialize full result: {e}",
            }
            line = json.dumps(fallback_result)

        with file_lock, open(output_file, mode="a") as f:
            f.write(line + "\n")

        return result

    except Exception as e:
        print(f"Error processing example: {e}")
        return None

    finally:
        # Single place where progress is advanced, whatever the outcome.
        if pbar is not None:
            pbar.update(1)
|
||||
|
||||
def load_processed_uuids(output_file, uuid_column):
    """Collect the hashed identifiers of examples already present in ``output_file``.

    Each line of the file is expected to be one JSON object. Lines that are
    blank, are not valid JSON, are not objects, or lack ``uuid_column`` are
    skipped, so one damaged record cannot abort a resumed run.

    Args:
        output_file: Path to the JSONL output file; it may not exist yet.
        uuid_column: Key whose value identifies an example.

    Returns:
        A set of MD5 hex digests of ``str(record[uuid_column])``; empty when
        the file does not exist.
    """
    processed_uuids = set()
    if not os.path.exists(output_file):
        return processed_uuids

    with open(output_file, mode="r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                identifier = json.loads(line)[uuid_column]
            except (json.JSONDecodeError, KeyError, TypeError):
                # Truncated line, non-object JSON, or record without the column.
                continue
            processed_uuids.add(hashlib.md5(str(identifier).encode()).hexdigest())

    return processed_uuids
|
||||
|
||||
def process_example_wrapper(args_tuple):
    """Unpack a 5-item work tuple and forward it to ``process_example``.

    ``args_tuple`` is ``(example, session, args, output_file, pbar)``; the
    return value of ``process_example`` is passed through unchanged.
    """
    row, http_session, cli_args, destination, progress = args_tuple
    return process_example(row, http_session, cli_args, destination, pbar=progress)
|
||||
|
||||
def main():
    """Parse CLI arguments and generate agent traces for every unprocessed example.

    Loads the ``open-r1/codeforces-test-cases`` train split, shuffles it with a
    fixed seed, keeps rows whose ``full_test_set`` field is truthy, skips
    examples already recorded in ``--output-file``, and processes the rest
    concurrently with a thread pool that shares one HTTP session.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-file", type=str, required=True)
    parser.add_argument("--prompt-column", type=str, required=True)
    parser.add_argument("--uuid-column", type=str, required=True)
    parser.add_argument("--api-addr", type=str, default="localhost:39876")
    parser.add_argument("--num-generations", type=int, default=5)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--max-tokens", type=int, default=8096)
    parser.add_argument("--max-concurrent", type=int, default=1000)
    args = parser.parse_args()

    # Split slice appended to "train"; set to e.g. "[:10]" for a quick smoke test.
    subset = ""
    seed = 42

    dataset = load_dataset(
        "open-r1/codeforces-test-cases",
        split=f"train{subset}",
    ).shuffle(seed=seed)
    dataset = dataset.filter(lambda x: x["full_test_set"])

    processed_uuids = load_processed_uuids(args.output_file, args.uuid_column)
    if processed_uuids:
        print(f"Found {len(processed_uuids)} already processed examples, resuming from there...")

    # Ensure the output directory and file exist before worker threads append.
    output_path = Path(args.output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.touch(exist_ok=True)

    # Filter out already processed examples
    examples_to_process = [
        example
        for example in dataset
        if hashlib.md5(str(example[args.uuid_column]).encode()).hexdigest() not in processed_uuids
    ]

    print(f"Processing {len(examples_to_process)} examples with {args.max_concurrent} workers")

    # One session shared among threads; the context manager closes its
    # connection pool even if a worker raises.
    with requests.Session() as session:
        adapter = requests.adapters.HTTPAdapter(
            pool_connections=args.max_concurrent,
            pool_maxsize=args.max_concurrent,
            max_retries=3,
        )
        session.mount('http://', adapter)
        session.mount('https://', adapter)

        progress = tqdm(
            total=len(examples_to_process),
            desc="Generating responses",
            unit="row",
            mininterval=2,
            smoothing=0.0001,
        )
        # Context managers exit in reverse order: the executor drains its
        # workers first, then the progress bar is closed.
        with progress as pbar, ThreadPoolExecutor(max_workers=args.max_concurrent) as executor:
            futures = [
                executor.submit(
                    process_example_wrapper,
                    (example, session, args, args.output_file, pbar),
                )
                for example in examples_to_process
            ]

            # Wait for all futures; result() re-raises any exception from a worker.
            for future in futures:
                future.result()

    print("All examples processed!")
|
||||
|
||||
# Run the generation pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
5
setup.py
5
setup.py
|
|
@ -70,6 +70,9 @@ _deps = [
|
|||
"trl==0.16.0",
|
||||
"vllm==0.7.2",
|
||||
"wandb>=0.19.1",
|
||||
"smolagents==1.12.0",
|
||||
"ipykernel",
|
||||
"ipywidgets",
|
||||
]
|
||||
|
||||
# this is a lookup table with items like:
|
||||
|
|
@ -92,6 +95,8 @@ extras["quality"] = deps_list("ruff", "isort", "flake8")
|
|||
extras["code"] = deps_list("e2b-code-interpreter", "python-dotenv")
|
||||
extras["eval"] = deps_list("lighteval", "math-verify")
|
||||
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"] + extras["code"]
|
||||
extras["smolagents"] = deps_list("smolagents")
|
||||
extras["jupyter"] = deps_list("ipykernel", "ipywidgets")
|
||||
|
||||
# core dependencies shared across the whole project - keep this to a bare minimum :)
|
||||
install_requires = [
|
||||
|
|
|
|||
23
slurm/agentic_generation.slurm
Normal file
23
slurm/agentic_generation.slurm
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
#!/bin/bash
#SBATCH --job-name=agentic-r1
#SBATCH --partition=hopper-cpu
#SBATCH --qos=high
#SBATCH --nodes=1
#SBATCH --cpus-per-task=64
#SBATCH --exclusive
#SBATCH --output=./logs/%x_%j_%n.out
#SBATCH --error=./logs/%x_%j_%n.err
#SBATCH --time=7-00:00:00

# Generates agentic traces with scripts/generate_agent_traces.py against an
# already-running inference endpoint.
#
# Usage:
#   sbatch slurm/agentic_generation.slurm
#   API_ADDR=<host:port> sbatch slurm/agentic_generation.slurm   # other endpoint

set -exuo pipefail

source ~/.bashrc
source openr1/bin/activate

# host:port of the completion endpoint; overridable from the submitting
# environment, defaults to the address previously hard-coded here.
API_ADDR="${API_ADDR:-10.53.86.164:39876}"

# NOTE(review): the script's resume logic skips every example whose
# --uuid-column value already appears in the output file. Confirm that
# "contestId" is unique per example; if several rows share it, rows are
# silently skipped after a restart.
python scripts/generate_agent_traces.py \
    --output-file "data/codeforces_agentic_generations.jsonl" \
    --prompt-column "prompt" \
    --uuid-column "contestId" \
    --api-addr "${API_ADDR}" \
    --num-generations 5 \
    --max-tokens 8096 \
    --max-concurrent 64
|
||||
Loading…
Add table
Add a link
Reference in a new issue