LLM4Decompile/sk2decompile/Preprocess/normalize_src_basedonpseudo.py
2025-10-08 18:15:57 +08:00

212 lines
No EOL
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import json
import random
import re
import subprocess
from multiprocessing import Pool, cpu_count
from typing import Optional

from tqdm import tqdm  # progress bar
def good_func(func):
    """Heuristic size filter for a normalized function.

    Everything up to and including the first ``{`` (the signature) is
    discarded.  The remaining body is split on newline characters, and every
    line whose stripped length is at least 3 is counted.  The function is
    accepted when that count is strictly between 3 and 300.  Text without any
    ``{`` has an empty body and is therefore rejected.
    """
    _, _, body = func.partition('{')
    meaningful = sum(1 for row in body.split('\n') if len(row.strip()) >= 3)
    return 3 < meaningful < 300
def strip_empty(code):
    """Return *code* with every whitespace-only line dropped.

    Lines are split with ``str.splitlines`` and rejoined with newlines, so
    the result carries no trailing newline.
    """
    non_blank = filter(str.strip, code.splitlines())
    return "\n".join(non_blank)
def format_with_clang(func: str, style: str = "Google", timeout: float = 0.5) -> Optional[str]:
    """Pretty-print C/C++ source with the external ``clang-format`` tool.

    Args:
        func: Source text to format.  Empty (or otherwise falsy) input is
            rejected without spawning a process.
        style: Value passed as ``clang-format --style=<style>``.
        timeout: Seconds to wait for clang-format before giving up.  Defaults
            to 0.5, the limit that was previously hard-coded.

    Returns:
        The formatted text.  Returns None when the input is empty or when
        formatting fails: clang-format is missing, exits non-zero, times out,
        or the text cannot be encoded/decoded as UTF-8.
    """
    if not func:
        return None
    cmd = ["clang-format", f"--style={style}"]
    try:
        proc = subprocess.run(
            cmd,
            input=func,
            text=True,
            # Use explicit UTF-8 rather than the locale default.  Decompiled
            # code can contain non-ASCII string literals, which would
            # otherwise fail to encode under a C/POSIX locale.
            encoding="utf-8",
            capture_output=True,
            check=True,
            timeout=timeout,
        )
    except (subprocess.SubprocessError, OSError, ValueError):
        # Best effort: one unformattable function must not abort a bulk run.
        # - SubprocessError covers CalledProcessError and TimeoutExpired.
        # - OSError covers a missing clang-format binary.
        # - ValueError covers UnicodeEncodeError and UnicodeDecodeError.
        return None
    return proc.stdout
# ----------------------------
# 1. 十六进制转十进制
# ----------------------------
# Hexadecimal integer literal (0x / 0X prefix) followed by an optional C
# integer suffix of up to three u/U/l/L characters (u, l, ul, ll, ull, ...).
# Compiled once at import time instead of on every call.
_HEX_LITERAL_RE = re.compile(r'\b(0[xX][0-9a-fA-F]+)([uUlL]{1,3})?\b')


def hex_to_dec(text):
    """Rewrite every hexadecimal integer literal in *text* as decimal.

    Any integer suffix is preserved, e.g. ``0x10`` -> ``16`` and
    ``0xFFul`` -> ``255ul``.  A literal glued to a preceding identifier
    character (``a0x10``) is not touched.  The substitution is purely
    lexical, so matching text inside string literals is rewritten as well.
    """
    def convert(match):
        suffix = match.group(2) or ""
        return str(int(match.group(1), 16)) + suffix
    return _HEX_LITERAL_RE.sub(convert, text)
# ----------------------------
# 2. 删除特定关键字
# ----------------------------
def remove_keywords(text):
    """Delete decompiler-specific qualifiers from *text*.

    Removes whole-word occurrences of ``__fastcall``, ``__cdecl`` and
    ``__ptr32``.  It also removes the two-word sequence ``__noreturn
    noreturn``, with any whitespace between the two words.  Matches are
    replaced by the empty string, so the surrounding whitespace is left as is.
    """
    keyword_regexes = (
        r'\b__fastcall\b',
        r'\b__cdecl\b',
        r'\b__ptr32\b',
        r'\b__noreturn\s+noreturn\b',
    )
    return re.sub('|'.join(keyword_regexes), '', text)
# ----------------------------
# 3. 替换 typedef 类型为原始类型
# ----------------------------
# Alias -> replacement table consumed by replace_typedefs().  Each key is
# matched as a whole word in the (IDA) pseudo-code and replaced by its value.
# Many targets are deliberately coarse: for example, cpu_set_t, DIR and fd_set
# all become "int".  The output is therefore a normalized text form, not a
# type-accurate rewrite.
typedef_map = {
    "cpu_set_t": "int",
    "nl_item": "int",
    "__time_t": "int",
    "__mode_t": "unsigned short",
    "__off64_t": "long long",
    "__blksize_t": "long",
    "__ino_t": "unsigned long",
    "__blkcnt_t": "unsigned long long",
    "__syscall_slong_t": "long",
    "__ssize_t": "long int",
    "wchar_t": "unsigned short int",
    "wctype_t": "unsigned short int",
    # Sized integer spellings, presumably as emitted by the decompiler.
    "__int64": "long long",
    "__int32": "int",
    "__int16": "short",
    "__int8": "char",
    "_QWORD": "uint64_t",
    "_OWORD": "long double",
    "_DWORD": "uint32_t",
    # NOTE(review): size_t maps to "unsigned int".  Confirm this is intended
    # for 64-bit targets.
    "size_t": "unsigned int",
    "_BYTE": "uint8_t",
    # NOTE(review): _TBYTE and _BOOL8 map to 2- and 1-byte types.  If these
    # names denote wider types, this mapping is lossy; confirm the intent.
    "_TBYTE": "uint16_t",
    "_BOOL8": "uint8_t",
    "gcc_va_list": "va_list",
    "_WORD": "unsigned short",
    "_BOOL4": "int",
    "__va_list_tag": "va_list",
    "_IO_FILE": "FILE",
    "DIR": "int",
    "__fsword_t": "long",
    "__kernel_ulong_t": "int",
    "cc_t": "int",
    "speed_t": "int",
    "fd_set": "int",
    "__suseconds_t": "int",
    "_UNKNOWN": "void",
    # Function-pointer typedefs are expanded to their full declarator text.
    "__sighandler_t": "void (*)(int)",
    "__compar_fn_t": "int (*)(const void *, const void *)",
}
def replace_typedefs(text, mapping=None):
    """Replace whole-word type aliases in *text* with their target types.

    Args:
        text: Source text to rewrite.
        mapping: Alias -> replacement table.  Defaults to the module-level
            ``typedef_map``.

    Returns:
        The rewritten text.

    All aliases are matched in one pass with a single alternation regex,
    instead of one full scan per alias.  For the default table this gives the
    same result as substituting alias by alias, because no replacement text
    contains another alias.  Replacement text is inserted literally and never
    re-scanned, so chained mappings (A -> B, B -> C) are not expanded.
    """
    table = typedef_map if mapping is None else mapping
    if not table:
        # An empty alternation would match the empty string at every word
        # boundary; there is nothing to replace anyway.
        return text
    pattern = re.compile(r'\b(?:' + '|'.join(map(re.escape, table)) + r')\b')
    return pattern.sub(lambda match: table[match.group(0)], text)
# ----------------------------
# 4. 删除注释
# ----------------------------
# Tokens that matter when stripping comments: line comments, block comments,
# and string/character literals.  Literals are matched too, so that comment
# markers inside them (e.g. "http://...") are left alone.  A literal may not
# span an unescaped newline.
_COMMENT_OR_LITERAL_RE = re.compile(
    r'//[^\n]*'
    r'|/\*.*?\*/'
    r'|"(?:\\.|[^"\\\n])*"'
    r"|'(?:\\.|[^'\\\n])*'",
    re.DOTALL,
)


def remove_comments(text):
    """Strip C ``//`` and ``/* */`` comments from *text*.

    The text is scanned left to right, so comment markers that appear inside
    string or character literals are preserved.  A line comment is removed up
    to, but not including, its newline.  A block comment is replaced by a
    single space, as in C, so ``int/**/x`` does not fuse into ``intx``.  An
    unterminated ``/*`` is left as is.
    """
    def _replace(match):
        token = match.group(0)
        if token.startswith('//'):
            return ''
        if token.startswith('/*'):
            return ' '
        return token  # string/char literal: keep verbatim
    return _COMMENT_OR_LITERAL_RE.sub(_replace, text)
# ----------------------------
# 5. 单条伪代码处理
# ----------------------------
def process_code(code_str):
    """Apply the text-level normalization passes to one code string.

    The passes run in this fixed order:
    1. strip comments;
    2. convert hex literals to decimal;
    3. drop decompiler-specific keywords;
    4. rewrite typedef aliases.
    """
    pipeline = (remove_comments, hex_to_dec, remove_keywords, replace_typedefs)
    for apply_pass in pipeline:
        code_str = apply_pass(code_str)
    return code_str
def process_entry(entry, key_name='pseudo'):
    """Normalize one record's code field; this adapts process_code to a dict.

    Args:
        entry: Record dict whose ``key_name`` field holds the code.
        key_name: Name of the field to read.  Defaults to ``'pseudo'``.

    Returns:
        - ``''`` if the field is missing, null, or empty after preprocessing
          (or if formatting leaves only blank lines);
        - ``None`` if clang-format could not format the code;
        - otherwise, the formatted code with blank lines removed.
    """
    # `or ''` also covers an explicit JSON null.  Without it, the regex passes
    # would raise TypeError, and the exception would surface through
    # Pool.imap and fail the whole batch.
    source = entry.get(key_name) or ''
    result = process_code(source)
    if not result.strip():
        return ''
    formatted = format_with_clang(result)
    if formatted is None:
        return None
    return strip_empty(formatted)
# ----------------------------
# 6. Main driver
# ----------------------------
def normalize_code_list_parallel(input_json, output_json, key_name='pseudo', num_workers=None, remove=1):
    """Normalize the ``key_name`` field of every record in a JSON file, in parallel.

    Reads *input_json* (a JSON array of objects) and runs ``process_entry``
    over the records in a process pool.  The kept records, each with an added
    ``<key_name>_norm`` field, are written to *output_json*.

    Per-record outcome:
        - Normalized and accepted by ``good_func``: kept, with the normalized
          code in ``<key_name>_norm``.
        - Empty after preprocessing, or rejected by ``good_func``: dropped.
        - Formatting failed (``process_entry`` returned None): dropped when
          *remove* is truthy.  Otherwise kept, with the raw field copied into
          ``<key_name>_norm``.

    Args:
        input_json: Path of the input JSON file.
        output_json: Path of the output JSON file (overwritten).
        key_name: Field holding the code to normalize.
        num_workers: Pool size; a falsy value means ``cpu_count()``.
        remove: Whether to drop records whose formatting failed.

    Raises:
        ValueError: If the top-level JSON value is not an array.
    """
    from functools import partial

    with open(input_json, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("输入 JSON 应为对象数组")
    num_workers = num_workers or cpu_count()
    print(f"[+] 开始处理 {len(data)} 条记录,使用 {num_workers} 个进程")

    process_entry_key = partial(process_entry, key_name=key_name)
    with Pool(processes=num_workers) as pool:
        # imap preserves input order, so results line up with `data` below.
        results = list(tqdm(pool.imap(process_entry_key, data), total=len(data), desc="Processing"))

    norm_key = f"{key_name}_norm"
    data_good = []
    for record, norm in zip(data, results):
        if norm is None:
            # clang-format failed: optionally keep the un-normalized code.
            if not remove:
                record[norm_key] = record[key_name]
                data_good.append(record)
        elif norm and good_func(norm):
            record[norm_key] = norm
            data_good.append(record)
        # Otherwise the record is empty or rejected by good_func: drop it.

    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(data_good, f, indent=2, ensure_ascii=False)
    print(f"[✓] 完成处理:{input_json}:{len(data)} -> {output_json}:{len(data_good)}")
# ----------------------------
# 7. 命令行入口
# ----------------------------
if __name__ == '__main__':
    # Command-line entry point.  The help strings describe what the code
    # actually does: the input is an array of objects, --key_name selects the
    # field to normalize, and the --workers default is shown via %(default)s.
    parser = argparse.ArgumentParser(description="并行处理 IDA 伪代码字符串列表")
    parser.add_argument('--input_json', default="exebench_format_top1p.json", help='输入 JSON 文件路径(对象数组)')
    parser.add_argument('--output_json', default="exebench_format_pseudo_top1p.json", help='输出 JSON 文件路径')
    parser.add_argument('--key_name', default="pseudo", help='待规范化的字段名(结果写入 <key_name>_norm)')
    parser.add_argument('--workers', type=int, default=32, help='进程数(默认 %(default)s)')
    parser.add_argument('--remove', type=int, default=1, help='1: drop records clang-format failed on; 0: keep them un-normalized')
    args = parser.parse_args()
    normalize_code_list_parallel(args.input_json, args.output_json, args.key_name, args.workers, args.remove)