mirror of
https://github.com/albertan017/LLM4Decompile.git
synced 2026-06-17 01:55:50 +00:00
212 lines
No EOL
6.5 KiB
Python
212 lines
No EOL
6.5 KiB
Python
import re
|
||
import json
|
||
import argparse
|
||
from multiprocessing import Pool, cpu_count
|
||
from tqdm import tqdm # ✅ 添加进度条模块
|
||
import random
|
||
|
||
import subprocess
|
||
|
||
def good_func(func):
    """Heuristic size filter for a decompiled function.

    Everything up to and including the first '{' (the signature) is
    discarded.  The remainder is split on newline characters, and lines whose
    stripped length is at least 3 are counted as "substantive".

    Returns:
        True when the substantive-line count lies strictly between 3 and 300,
        False otherwise (including when *func* contains no '{' at all).
    """
    _, _, body = func.partition('{')
    substantive = sum(1 for row in body.split('\n') if len(row.strip()) >= 3)
    return 3 < substantive < 300
|
||
def strip_empty(code):
    """Return *code* with whitespace-only lines removed, joined by newlines."""
    # str.strip yields '' (falsy) for blank lines, so filter drops them.
    return "\n".join(filter(str.strip, code.splitlines()))
|
||
def format_with_clang(func: str, style: str = "Google", timeout: float = 0.5) -> "str | None":
    """Format a C snippet with the external ``clang-format`` binary.

    Args:
        func: C source text to format.
        style: Value passed to clang-format as ``--style=<style>``.
        timeout: Seconds to wait for clang-format before giving up.

    Returns:
        The formatted text, or ``None`` when *func* is empty or clang-format
        could not produce output (binary not found, non-zero exit status,
        timeout, or a text encoding/decoding error).
    """
    if not func:
        return None

    cmd = ["clang-format", f"--style={style}"]
    try:
        proc = subprocess.run(
            cmd,
            input=func,
            text=True,
            capture_output=True,
            check=True,  # non-zero exit -> CalledProcessError -> None below
            timeout=timeout,
        )
    except (subprocess.SubprocessError, OSError, ValueError):
        # Best-effort formatting: SubprocessError covers TimeoutExpired and
        # CalledProcessError, OSError covers a missing binary, and ValueError
        # covers UnicodeError from text-mode I/O.  Anything else is a
        # programming error and is allowed to propagate.
        return None
    return proc.stdout
|
||
|
||
|
||
# ----------------------------
# 1. Convert hexadecimal literals to decimal
# ----------------------------
# A C hexadecimal integer literal (0x / 0X prefix) with an optional integer
# suffix of up to three u/U/l/L characters (u, l, ul, ll, ull, ...).
_HEX_LITERAL_RE = re.compile(r'\b(0[xX][0-9a-fA-F]+)([uUlL]{1,3})?\b')


def hex_to_dec(text):
    """Rewrite hexadecimal integer literals in *text* as decimal.

    Any integer suffix is preserved, e.g. ``0x1Fu`` -> ``31u`` and
    ``0XFF`` -> ``255``.  A literal glued to a preceding word character
    (``v0x1``) is not a standalone token and is left unchanged.
    """
    def convert(match):
        # int(..., 16) accepts the 0x / 0X prefix as-is.
        return str(int(match.group(1), 16)) + (match.group(2) or "")

    return _HEX_LITERAL_RE.sub(convert, text)
|
||
|
||
|
||
# ----------------------------
# 2. Remove specific keywords
# ----------------------------
def remove_keywords(text):
    """Delete selected keyword decorations from *text*.

    Removes the whole-word tokens ``__fastcall``, ``__cdecl`` and
    ``__ptr32``, plus the two-word sequence ``__noreturn noreturn``.
    Surrounding whitespace is left in place.
    """
    keywords = ('__fastcall', '__cdecl', '__ptr32', r'__noreturn\s+noreturn')
    alternation = '|'.join(rf'\b{kw}\b' for kw in keywords)
    return re.sub(alternation, '', text)
|
||
|
||
|
||
# ----------------------------
# 3. Replace typedef type names with their underlying types
# ----------------------------
# Alias -> replacement table consumed by replace_typedefs(); each key is
# matched as a whole word and substituted textually.
# NOTE(review): several entries do not preserve width on common 64-bit Linux
# targets (e.g. size_t -> unsigned int, wchar_t -> unsigned short int,
# _TBYTE -> uint16_t, _BOOL8 -> uint8_t). Presumably this is a deliberate
# vocabulary normalization for the dataset rather than an ABI-accurate
# mapping -- confirm before relying on type sizes.
typedef_map = {
    "cpu_set_t": "int",
    "nl_item": "int",
    "__time_t": "int",
    "__mode_t": "unsigned short",
    "__off64_t": "long long",
    "__blksize_t": "long",
    "__ino_t": "unsigned long",
    "__blkcnt_t": "unsigned long long",
    "__syscall_slong_t": "long",
    "__ssize_t": "long int",
    "wchar_t": "unsigned short int",
    "wctype_t": "unsigned short int",
    "__int64": "long long",
    "__int32": "int",
    "__int16": "short",
    "__int8": "char",
    "_QWORD": "uint64_t",
    "_OWORD": "long double",
    "_DWORD": "uint32_t",
    "size_t": "unsigned int",
    "_BYTE": "uint8_t",
    "_TBYTE": "uint16_t",
    "_BOOL8": "uint8_t",
    "gcc_va_list": "va_list",
    "_WORD": "unsigned short",
    "_BOOL4": "int",
    "__va_list_tag": "va_list",
    "_IO_FILE": "FILE",
    "DIR": "int",
    "__fsword_t": "long",
    "__kernel_ulong_t": "int",
    "cc_t": "int",
    "speed_t": "int",
    "fd_set": "int",
    "__suseconds_t": "int",
    "_UNKNOWN": "void",
    # NOTE(review): the two function-pointer aliases below expand to abstract
    # declarators, so textual substitution into a declaration such as
    # `__sighandler_t h;` yields `void (*)(int) h;`, which is not valid C.
    # Confirm downstream consumers tolerate this.
    "__sighandler_t": "void (*)(int)",
    "__compar_fn_t": "int (*)(const void *, const void *)",
}
|
||
|
||
def replace_typedefs(text, mapping=None):
    """Replace whole-word typedef aliases in *text* with their underlying types.

    Args:
        text: Pseudocode to rewrite.
        mapping: ``{alias: replacement}`` table. Defaults to the module-level
            ``typedef_map``. Aliases are matched only as whole words
            (``\\b`` boundaries), so ``my_DWORD`` is not touched.

    Returns:
        *text* with every alias occurrence replaced in a single pass.
    """
    if mapping is None:
        mapping = typedef_map
    if not mapping:
        # An empty alternation would match the empty string everywhere.
        return text

    # Longest alias first so a longer alias wins over any shorter one that
    # happens to match at the same position.
    alternation = '|'.join(
        re.escape(alias) for alias in sorted(mapping, key=len, reverse=True)
    )
    pattern = re.compile(rf'\b(?:{alternation})\b')
    # One pass: replaced text is never re-scanned, so an alias that occurs
    # inside another alias's replacement is not rewritten a second time (the
    # previous per-alias re.sub loop would cascade such replacements).
    return pattern.sub(lambda m: mapping[m.group(0)], text)
|
||
|
||
|
||
# ----------------------------
# 4. Remove comments
# ----------------------------
# One alternation that recognises string/char literals as well as comments.
# Because the scan is left to right, whichever construct starts first wins,
# so comment markers inside a literal (e.g. "http://...") are not treated
# as comments.
_COMMENT_OR_LITERAL_RE = re.compile(
    r'"(?:\\.|[^"\\\n])*"'    # string literal
    r"|'(?:\\.|[^'\\\n])*'"   # character literal
    r'|/\*.*?\*/'             # block comment (may span lines)
    r'|//[^\n]*',             # line comment (the newline itself is kept)
    re.DOTALL,
)


def _drop_comment(match):
    """Return the replacement for one token matched by _COMMENT_OR_LITERAL_RE."""
    token = match.group(0)
    if token.startswith('/*'):
        # C treats a block comment as whitespace; keep one space so that
        # `a/*x*/b` does not collapse into the single token `ab`.
        return ' '
    if token.startswith('//'):
        return ''
    # String or character literal: keep it verbatim.
    return token


def remove_comments(text):
    """Strip C block (``/* */``) and line (``//``) comments from *text*.

    String and character literals are left intact, even when they contain
    comment markers. Each block comment becomes a single space. A line
    comment is removed up to, but not including, its terminating newline.
    """
    return _COMMENT_OR_LITERAL_RE.sub(_drop_comment, text)
|
||
|
||
|
||
# ----------------------------
# 5. Normalize a single pseudocode snippet
# ----------------------------
def process_code(code_str):
    """Run the text-normalization pipeline over one pseudocode string.

    The steps are applied in order: remove comments, rewrite hex literals
    as decimal, remove selected keywords, then expand typedef aliases.
    """
    pipeline = (remove_comments, hex_to_dec, remove_keywords, replace_typedefs)
    for step in pipeline:
        code_str = step(code_str)
    return code_str
|
||
|
||
|
||
# Wrap process_code so it can be mapped over dict records in a worker pool.
def process_entry(entry, key_name='pseudo'):
    """Normalize and clang-format one record's pseudocode field.

    Args:
        entry: Record dict; ``entry[key_name]`` holds the pseudocode text.
        key_name: Name of the field to normalize.

    Returns:
        ``''`` if the field is missing, null or empty, or is blank after
        normalization; ``None`` if clang-format failed; otherwise the
        formatted code with blank lines removed.
    """
    # A missing key or JSON null (or other falsy value) is treated as empty.
    # Passing None to process_code would raise TypeError in re.sub and abort
    # the whole pool.
    raw = entry.get(key_name) or ''
    result = process_code(raw)
    if not result.strip():
        return ''

    formatted = format_with_clang(result)
    if formatted is None:
        return None
    return strip_empty(formatted)
|
||
|
||
# Main driver: normalize a JSON dataset in parallel
def normalize_code_list_parallel(input_json, output_json, key_name='pseudo', num_workers=None, remove=1):
    """Normalize the ``key_name`` field of every record in a JSON array.

    Records are processed by ``process_entry`` in a multiprocessing pool and
    then kept or dropped as follows:

    * normalized text that passes ``good_func``: kept, stored under
      ``f"{key_name}_norm"``;
    * normalized text that fails ``good_func``, or empty text: dropped;
    * clang-format failure (``None``): dropped when ``remove`` is truthy,
      otherwise kept with the raw ``key_name`` text copied into the
      ``_norm`` field.

    Kept records (mutated in place) are written to ``output_json``.

    Args:
        input_json: Path to a JSON file containing a list of records.
        output_json: Path the kept records are written to.
        key_name: Record field holding the pseudocode.
        num_workers: Pool size; falsy means ``cpu_count()``.
        remove: Whether to drop records whose clang-format step failed.

    Raises:
        ValueError: if the top-level JSON value is not a list.
    """
    from functools import partial

    with open(input_json, 'r', encoding='utf-8') as src:
        records = json.load(src)

    if not isinstance(records, list):
        raise ValueError("输入 JSON 应为对象数组")

    workers = num_workers if num_workers else cpu_count()
    print(f"[+] 开始处理 {len(records)} 条记录,使用 {workers} 个进程")

    worker = partial(process_entry, key_name=key_name)
    with Pool(processes=workers) as pool:
        progress = tqdm(pool.imap(worker, records), total=len(records), desc="Processing")
        normalized = list(progress)

    norm_field = f"{key_name}_norm"
    kept = []
    for record, norm in zip(records, normalized):
        if norm is None:
            # clang-format failed: optionally fall back to the raw text.
            if remove:
                continue
            record[norm_field] = record[f"{key_name}"]
        elif not norm or not good_func(norm):
            # Empty after normalization, or outside the size heuristic.
            continue
        else:
            record[norm_field] = norm
        kept.append(record)

    with open(output_json, 'w', encoding='utf-8') as dst:
        json.dump(kept, dst, indent=2, ensure_ascii=False)

    print(f"[✓] 完成处理:{input_json}:{len(records)} → {output_json}:{len(kept)}")
|
||
|
||
|
||
|
||
# ----------------------------
# 7. Command-line entry point
# ----------------------------
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="并行处理 IDA 伪代码字符串列表")
    # The input must be a JSON array of objects, each holding the field named
    # by --key_name (the old help text wrongly said "each item is a string").
    parser.add_argument('--input_json', default="exebench_format_top1p.json", help='输入 JSON 文件路径(对象数组,每项含 key_name 字段)')
    parser.add_argument('--output_json', default="exebench_format_pseudo_top1p.json", help='输出 JSON 文件路径')
    # Previously copy-pasted as "output JSON path"; this is the field to normalize.
    parser.add_argument('--key_name', default="pseudo", help='要处理的伪代码字段名')
    # Help previously claimed a default of 8 cores; the actual default is 32.
    parser.add_argument('--workers', type=int, default=32, help='进程数(默认 32)')
    parser.add_argument('--remove', type=int, default=1, help='remove fail cases')
    args = parser.parse_args()

    normalize_code_list_parallel(args.input_json, args.output_json, args.key_name, args.workers, args.remove)