Add files via upload

inference
This commit is contained in:
albertan017 2025-10-16 23:29:55 +08:00 committed by GitHub
commit b82f0e511e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 487 additions and 0 deletions

View file

@ -0,0 +1,96 @@
from vllm import LLM, SamplingParams
from argparse import ArgumentParser
import os
import json
from transformers import AutoTokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "true"
inputs = []
def parse_args() -> ArgumentParser:
parser = ArgumentParser()
parser.add_argument("--model_path", type=str)
parser.add_argument("--gpus", type=int, default=1)
parser.add_argument("--max_num_seqs", type=int, default=1)
parser.add_argument("--gpu_memory_utilization", type=float, default=0.95)
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument("--max_total_tokens", type=int, default=8192)
parser.add_argument("--max_new_tokens", type=int, default=512)
parser.add_argument("--stop_sequences", type=str, default=None)
parser.add_argument("--testset_path", type=str)
parser.add_argument("--output_path", type=str, default=None)
return parser.parse_args()
# def llm_inference(inputs, args):
# llm = LLM(
# model=args.model_path,
# tensor_parallel_size=args.gpus,
# max_model_len=args.max_total_tokens,
# gpu_memory_utilization=args.gpu_memory_utilization,
# )
# sampling_params = SamplingParams(
# temperature=args.temperature,
# max_tokens=args.max_new_tokens,
# stop=args.stop_sequences,
# )
# gen_results = llm.generate(inputs, sampling_params)
# gen_results = [[output.outputs[0].text] for output in gen_results]
# return gen_results
def llm_inference(inputs,
model_path,
gpus=1,
max_total_tokens=8192,
gpu_memory_utilization=0.95,
temperature=0,
max_new_tokens=512,
stop_sequences=None):
llm = LLM(
model=model_path,
tensor_parallel_size=gpus,
max_model_len=max_total_tokens,
gpu_memory_utilization=gpu_memory_utilization,
)
sampling_params = SamplingParams(
temperature=temperature,
max_tokens=max_new_tokens,
stop=stop_sequences,
)
gen_results = llm.generate(inputs, sampling_params)
gen_results = [[output.outputs[0].text] for output in gen_results]
return gen_results
if __name__ == "__main__":
args = parse_args()
with open(args.testset_path, "r") as f:
samples = json.load(f)
before = "# This is the assembly code:\n"
after = "\n# What is the source code?\n"
for sample in samples:
prompt = before + sample["input_asm_prompt"].strip() + after
inputs.append(prompt)
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
if args.stop_sequences is None:
args.stop_sequences = [tokenizer.eos_token]
gen_results = llm_inference(inputs, args.model_path,
args.gpus,
args.max_total_tokens,
args.gpu_memory_utilization,
args.temperature,
args.max_new_tokens,
args.stop_sequences)
if not os.path.exists(args.output_path):
os.mkdir(args.output_path)
idx = 0
for gen_result in gen_results:
with open(args.output_path + '/' + str(idx) + '.c', 'w') as f:
f.write(gen_result[0])
idx += 1

View file

@ -0,0 +1,212 @@
import re
import json
import argparse
from multiprocessing import Pool, cpu_count
from tqdm import tqdm # ✅ 添加进度条模块
import random
import subprocess
def good_func(func):
func = '{'.join(func.split('{')[1:])
func_sp = func.split('\n')
total = 0
for line in func_sp:
if len(line.strip())>=3:
total+=1
if total>3 and total<300:
return True
return False
def strip_empty(code):
return "\n".join(line for line in code.splitlines() if line.strip())
def format_with_clang(func: str, style: str = "Google") -> str:
# Build the command
if not func:
return None
cmd = ["clang-format", f"--style={style}"]
try:
proc = subprocess.run(
cmd,
input=func,
text=True,
capture_output=True,
check=True,
timeout=0.5
)
return proc.stdout
except Exception as e:
# print(f"clang-format failed:{e}")
# print(func)
# print('-------------------------')
return None
# ----------------------------
# 1. 十六进制转十进制
# ----------------------------
def hex_to_dec(text):
pattern = re.compile(r'\b(0x[0-9a-fA-F]+)([uUlL]{1,3})?\b')
def convert(match):
hex_part = match.group(1)
suffix = match.group(2) or ""
dec_value = str(int(hex_part, 16))
return dec_value + suffix
return pattern.sub(convert, text)
# ----------------------------
# 2. 删除特定关键字
# ----------------------------
def remove_keywords(text):
patterns = [
r'\b__fastcall\b',
r'\b__cdecl\b',
r'\b__ptr32\b',
r'\b__noreturn\s+noreturn\b'
]
combined_pattern = re.compile('|'.join(patterns))
return combined_pattern.sub('', text)
# ----------------------------
# 3. 替换 typedef 类型为原始类型
# ----------------------------
typedef_map = {
"cpu_set_t": "int",
"nl_item": "int",
"__time_t": "int",
"__mode_t": "unsigned short",
"__off64_t": "long long",
"__blksize_t": "long",
"__ino_t": "unsigned long",
"__blkcnt_t": "unsigned long long",
"__syscall_slong_t": "long",
"__ssize_t": "long int",
"wchar_t": "unsigned short int",
"wctype_t": "unsigned short int",
"__int64": "long long",
"__int32": "int",
"__int16": "short",
"__int8": "char",
"_QWORD": "uint64_t",
"_OWORD": "long double",
"_DWORD": "uint32_t",
"size_t": "unsigned int",
"_BYTE": "uint8_t",
"_TBYTE": "uint16_t",
"_BOOL8": "uint8_t",
"gcc_va_list": "va_list",
"_WORD": "unsigned short",
"_BOOL4": "int",
"__va_list_tag": "va_list",
"_IO_FILE": "FILE",
"DIR": "int",
"__fsword_t": "long",
"__kernel_ulong_t": "int",
"cc_t": "int",
"speed_t": "int",
"fd_set": "int",
"__suseconds_t": "int",
"_UNKNOWN": "void",
"__sighandler_t": "void (*)(int)",
"__compar_fn_t": "int (*)(const void *, const void *)",
}
def replace_typedefs(text):
for alias, original in typedef_map.items():
pattern = re.compile(rf'\b{re.escape(alias)}\b')
text = pattern.sub(original, text)
return text
# ----------------------------
# 4. 删除注释
# ----------------------------
def remove_comments(text):
text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL)
text = re.sub(r'//.*?$', '', text, flags=re.MULTILINE)
return text
# ----------------------------
# 5. 单条伪代码处理
# ----------------------------
def process_code(code_str):
code_str = remove_comments(code_str)
code_str = hex_to_dec(code_str)
code_str = remove_keywords(code_str)
code_str = replace_typedefs(code_str)
return code_str
# 包装 process_code使其接受一个 dict 并处理字段
def process_entry(entry, key_name='pseudo'):
# result = {}
# # 原始字段保留
# result['ida_pseudo'] = entry.get('ida_pseudo', '')
# result['ida_strip_pseudo'] = entry.get('ida_strip_pseudo', '')
# # 分别处理两个字段
# result['ida_pseudo_result'] = process_code(result['ida_pseudo'])
# result['ida_strip_pseudo_result'] = process_code(result['ida_strip_pseudo'])
result = process_code(entry.get(key_name, ''))
if not result.strip():
return ''
formatted = format_with_clang(result)
if formatted is None:
return None
cleaned = strip_empty(formatted)
return cleaned
# 主函数
def normalize_code_list_parallel(input_json, output_json, key_name='pseudo', num_workers=None, remove=1):
with open(input_json, 'r', encoding='utf-8') as f:
data = json.load(f)
if not isinstance(data, list):
raise ValueError("输入 JSON 应为对象数组")
num_workers = num_workers or cpu_count()
print(f"[+] 开始处理 {len(data)} 条记录,使用 {num_workers} 个进程")
from functools import partial
process_entry_key = partial(process_entry, key_name=key_name)
with Pool(processes=num_workers) as pool:
result = list(tqdm(pool.imap(process_entry_key, data), total=len(data), desc="Processing"))
data_good = []
for record, norm in zip(data, result):
if norm:
if not good_func(norm):
continue
record[f"{key_name}_norm"] = norm
data_good.append(record)
elif norm is None:
if not remove:
record[f"{key_name}_norm"] = record[f"{key_name}"]
data_good.append(record)
with open(output_json, 'w', encoding='utf-8') as f:
json.dump(data_good, f, indent=2, ensure_ascii=False)
print(f"[✓] 完成处理:{input_json}:{len(data)}{output_json}:{len(data_good)}")
# ----------------------------
# 7. 命令行入口
# ----------------------------
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="并行处理 IDA 伪代码字符串列表")
parser.add_argument('--input_json', default="exebench_format_top1p.json", help='输入 JSON 文件路径(每项为字符串)')
parser.add_argument('--output_json', default="exebench_format_pseudo_top1p.json", help='输出 JSON 文件路径')
parser.add_argument('--key_name', default="pseudo", help='输出 JSON 文件路径')
parser.add_argument('--workers', type=int, default=32, help='进程数默认使用8核心')
parser.add_argument('--remove', type=int, default=1, help='remove fail cases')
args = parser.parse_args()
normalize_code_list_parallel(args.input_json, args.output_json, args.key_name, args.workers, args.remove)

View file

@ -0,0 +1,179 @@
from llm_server import llm_inference
from transformers import AutoTokenizer
import json
import argparse
import shutil
import os
from tqdm import tqdm
opts = ["O0", "O1", "O2", "O3"]
current_dir = os.path.dirname(os.path.abspath(__file__))
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--model_path",type=str,default="LLM4Binary/sk2decompile-struct-6.7b")
arg_parser.add_argument("--dataset_path",type=str,default='reverse_sample.json')
arg_parser.add_argument("--decompiler",type=str,default='ida_pseudo_norm')
arg_parser.add_argument("--gpus", type=int, default=1)
arg_parser.add_argument("--max_num_seqs", type=int, default=1)
arg_parser.add_argument("--gpu_memory_utilization", type=float, default=0.8)
arg_parser.add_argument("--temperature", type=float, default=0)
arg_parser.add_argument("--max_total_tokens", type=int, default=32768)
arg_parser.add_argument("--max_new_tokens", type=int, default=4096)
arg_parser.add_argument("--stop_sequences", type=str, default=None)
arg_parser.add_argument("--recover_model_path", type=str, default='LLM4Binary/sk2decompile-ident-6.7', help="Path to the model to recover from, if any.")
arg_parser.add_argument("--output_path", type=str, default='./result/sk2decompile')
arg_parser.add_argument("--only_save", type=int, default=0)
arg_parser.add_argument("--strip", type=int, default=1)
arg_parser.add_argument("--language", type=str, default='c')
args = arg_parser.parse_args()
before = "# This is the assembly code:\n"
after = "\n# What is the source code?\n"
if args.dataset_path.endswith('.json'):
with open(args.dataset_path, "r") as f:
print("===========")
print(f"Loading dataset from {args.dataset_path}")
print("===========")
samples = json.load(f)
elif args.dataset_path.endswith('.jsonl'):
samples = []
with open(args.dataset_path, "r") as f:
for line in f:
line = line.strip()
if line:
samples.append(json.loads(line))
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
if args.stop_sequences is None:
args.stop_sequences = [tokenizer.eos_token]
inputs = []
infos = []
for sample in samples:
prompt = before + sample[args.decompiler].strip() + after
sample['prompt_model1'] = prompt
inputs.append(prompt)
infos.append({
"opt": sample["opt"],
"language": sample["language"],
"index": sample["index"],
"func_name": sample["func_name"]
})
print("Starting first model inference...")
gen_results = llm_inference(inputs, args.model_path,
args.gpus,
args.max_total_tokens,
args.gpu_memory_utilization,
args.temperature,
args.max_new_tokens,
args.stop_sequences)
gen_results = [gen_result[0] for gen_result in gen_results]
for idx in range(len(gen_results)):
samples[idx]['gen_result_model1'] = gen_results[idx]
inputs_recovery = []
before_recovery = "# This is the normalized code:\n"
after_recovery = "\n# What is the source code?\n"
for idx, sample in enumerate(gen_results):
prompt_recovery = before_recovery + sample.strip() + after_recovery
samples[idx]['prompt_model2'] = prompt_recovery
inputs_recovery.append(prompt_recovery)
print("Starting recovery model inference...")
gen_results_recovery = llm_inference(inputs_recovery, args.recover_model_path,
args.gpus,
args.max_total_tokens,
args.gpu_memory_utilization,
args.temperature,
args.max_new_tokens,
args.stop_sequences)
gen_results_recovery = [gen_result[0] for gen_result in gen_results_recovery]
for idx in range(len(gen_results_recovery)):
samples[idx]['gen_result_model2'] = gen_results_recovery[idx]
if args.output_path:
if os.path.exists(args.output_path):
shutil.rmtree(args.output_path)
for opt in opts:
os.makedirs(os.path.join(args.output_path, opt))
if args.strip:
print("Processing function name stripping...")
for idx in range(len(gen_results_recovery)):
one = gen_results_recovery[idx]
func_name_in_gen = one.split('(')[0].split(' ')[-1].strip()
if func_name_in_gen.strip() and func_name_in_gen[0:2] == '**':
func_name_in_gen = func_name_in_gen[2:]
elif func_name_in_gen.strip() and func_name_in_gen[0] == '*':
func_name_in_gen = func_name_in_gen[1:]
original_func_name = samples[idx]["func_name"]
gen_results_recovery[idx] = one.replace(func_name_in_gen, original_func_name)
samples[idx]["gen_result_model2_stripped"] = gen_results_recovery[idx]
print("Saving inference results and logs...")
for idx_sample, final_result in enumerate(gen_results_recovery):
opt = infos[idx_sample]['opt']
language = infos[idx_sample]['language']
original_index = samples[idx_sample]['index']
save_path = os.path.join(args.output_path, opt, f"{original_index}_{opt}.{language}")
with open(save_path, "w") as f:
f.write(final_result)
log_path = save_path + ".log"
log_data = {
"index": original_index,
"opt": opt,
"language": language,
"func_name": samples[idx_sample]["func_name"],
"decompiler": args.decompiler,
"input_asm": samples[idx_sample][args.decompiler].strip(),
"prompt_model1": samples[idx_sample]['prompt_model1'],
"gen_result_model1": samples[idx_sample]['gen_result_model1'],
"prompt_model2": samples[idx_sample]['prompt_model2'],
"gen_result_model2": samples[idx_sample]['gen_result_model2'],
"final_result": final_result,
"stripped": args.strip
}
if args.strip and "gen_result_model2_stripped" in samples[idx_sample]:
log_data["gen_result_model2_stripped"] = samples[idx_sample]["gen_result_model2_stripped"]
with open(log_path, "w") as f:
json.dump(log_data, f, indent=2, ensure_ascii=False)
json_path = os.path.join(args.output_path, 'inference_results.jsonl')
with open(json_path, 'w') as f:
for sample in samples:
f.write(json.dumps(sample) + '\n')
stats_path = os.path.join(args.output_path, 'inference_stats.txt')
with open(stats_path, 'w') as f:
f.write(f"Total samples processed: {len(samples)}\n")
f.write(f"Model path: {args.model_path}\n")
f.write(f"Recovery model path: {args.recover_model_path}\n")
f.write(f"Dataset path: {args.dataset_path}\n")
f.write(f"Language: {args.language}\n")
f.write(f"Decompiler: {args.decompiler}\n")
f.write(f"Strip function names: {bool(args.strip)}\n")
opt_counts = {"O0": 0, "O1": 0, "O2": 0, "O3": 0}
for sample in samples:
opt_counts[sample['opt']] += 1
f.write("\nSamples per optimization level:\n")
for opt, count in opt_counts.items():
f.write(f" {opt}: {count}\n")
print(f"Inference completed! Results saved to {args.output_path}")
print(f"Total {len(samples)} samples processed.")