mirror of
https://github.com/albertan017/LLM4Decompile.git
synced 2026-06-17 01:55:50 +00:00
Add files via upload
inference
This commit is contained in:
parent
cff0c7c3fc
commit
b82f0e511e
3 changed files with 487 additions and 0 deletions
96
sk2decompile/evaluation/llm_server.py
Normal file
96
sk2decompile/evaluation/llm_server.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
from vllm import LLM, SamplingParams
|
||||
from argparse import ArgumentParser
|
||||
import os
|
||||
import json
|
||||
from transformers import AutoTokenizer
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
inputs = []
|
||||
def parse_args() -> ArgumentParser:
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--model_path", type=str)
|
||||
parser.add_argument("--gpus", type=int, default=1)
|
||||
parser.add_argument("--max_num_seqs", type=int, default=1)
|
||||
parser.add_argument("--gpu_memory_utilization", type=float, default=0.95)
|
||||
parser.add_argument("--temperature", type=float, default=0)
|
||||
parser.add_argument("--max_total_tokens", type=int, default=8192)
|
||||
parser.add_argument("--max_new_tokens", type=int, default=512)
|
||||
parser.add_argument("--stop_sequences", type=str, default=None)
|
||||
parser.add_argument("--testset_path", type=str)
|
||||
parser.add_argument("--output_path", type=str, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
# def llm_inference(inputs, args):
|
||||
# llm = LLM(
|
||||
# model=args.model_path,
|
||||
# tensor_parallel_size=args.gpus,
|
||||
# max_model_len=args.max_total_tokens,
|
||||
# gpu_memory_utilization=args.gpu_memory_utilization,
|
||||
# )
|
||||
|
||||
# sampling_params = SamplingParams(
|
||||
# temperature=args.temperature,
|
||||
# max_tokens=args.max_new_tokens,
|
||||
# stop=args.stop_sequences,
|
||||
# )
|
||||
|
||||
# gen_results = llm.generate(inputs, sampling_params)
|
||||
# gen_results = [[output.outputs[0].text] for output in gen_results]
|
||||
|
||||
# return gen_results
|
||||
|
||||
|
||||
def llm_inference(inputs,
|
||||
model_path,
|
||||
gpus=1,
|
||||
max_total_tokens=8192,
|
||||
gpu_memory_utilization=0.95,
|
||||
temperature=0,
|
||||
max_new_tokens=512,
|
||||
stop_sequences=None):
|
||||
llm = LLM(
|
||||
model=model_path,
|
||||
tensor_parallel_size=gpus,
|
||||
max_model_len=max_total_tokens,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=temperature,
|
||||
max_tokens=max_new_tokens,
|
||||
stop=stop_sequences,
|
||||
)
|
||||
|
||||
gen_results = llm.generate(inputs, sampling_params)
|
||||
gen_results = [[output.outputs[0].text] for output in gen_results]
|
||||
|
||||
return gen_results
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
with open(args.testset_path, "r") as f:
|
||||
samples = json.load(f)
|
||||
before = "# This is the assembly code:\n"
|
||||
after = "\n# What is the source code?\n"
|
||||
for sample in samples:
|
||||
prompt = before + sample["input_asm_prompt"].strip() + after
|
||||
inputs.append(prompt)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
|
||||
if args.stop_sequences is None:
|
||||
args.stop_sequences = [tokenizer.eos_token]
|
||||
gen_results = llm_inference(inputs, args.model_path,
|
||||
args.gpus,
|
||||
args.max_total_tokens,
|
||||
args.gpu_memory_utilization,
|
||||
args.temperature,
|
||||
args.max_new_tokens,
|
||||
args.stop_sequences)
|
||||
|
||||
if not os.path.exists(args.output_path):
|
||||
os.mkdir(args.output_path)
|
||||
idx = 0
|
||||
for gen_result in gen_results:
|
||||
with open(args.output_path + '/' + str(idx) + '.c', 'w') as f:
|
||||
f.write(gen_result[0])
|
||||
idx += 1
|
||||
212
sk2decompile/evaluation/normalize_pseudo.py
Normal file
212
sk2decompile/evaluation/normalize_pseudo.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
import re
|
||||
import json
|
||||
import argparse
|
||||
from multiprocessing import Pool, cpu_count
|
||||
from tqdm import tqdm # ✅ 添加进度条模块
|
||||
import random
|
||||
|
||||
import subprocess
|
||||
|
||||
def good_func(func):
|
||||
func = '{'.join(func.split('{')[1:])
|
||||
func_sp = func.split('\n')
|
||||
total = 0
|
||||
for line in func_sp:
|
||||
if len(line.strip())>=3:
|
||||
total+=1
|
||||
if total>3 and total<300:
|
||||
return True
|
||||
return False
|
||||
def strip_empty(code):
|
||||
return "\n".join(line for line in code.splitlines() if line.strip())
|
||||
def format_with_clang(func: str, style: str = "Google") -> str:
|
||||
# Build the command
|
||||
if not func:
|
||||
return None
|
||||
cmd = ["clang-format", f"--style={style}"]
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
input=func,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
check=True,
|
||||
timeout=0.5
|
||||
)
|
||||
return proc.stdout
|
||||
except Exception as e:
|
||||
# print(f"clang-format failed:{e}")
|
||||
# print(func)
|
||||
# print('-------------------------')
|
||||
return None
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 1. 十六进制转十进制
|
||||
# ----------------------------
|
||||
def hex_to_dec(text):
|
||||
pattern = re.compile(r'\b(0x[0-9a-fA-F]+)([uUlL]{1,3})?\b')
|
||||
def convert(match):
|
||||
hex_part = match.group(1)
|
||||
suffix = match.group(2) or ""
|
||||
dec_value = str(int(hex_part, 16))
|
||||
return dec_value + suffix
|
||||
return pattern.sub(convert, text)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 2. 删除特定关键字
|
||||
# ----------------------------
|
||||
def remove_keywords(text):
|
||||
patterns = [
|
||||
r'\b__fastcall\b',
|
||||
r'\b__cdecl\b',
|
||||
r'\b__ptr32\b',
|
||||
r'\b__noreturn\s+noreturn\b'
|
||||
]
|
||||
combined_pattern = re.compile('|'.join(patterns))
|
||||
return combined_pattern.sub('', text)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 3. 替换 typedef 类型为原始类型
|
||||
# ----------------------------
|
||||
typedef_map = {
|
||||
"cpu_set_t": "int",
|
||||
"nl_item": "int",
|
||||
"__time_t": "int",
|
||||
"__mode_t": "unsigned short",
|
||||
"__off64_t": "long long",
|
||||
"__blksize_t": "long",
|
||||
"__ino_t": "unsigned long",
|
||||
"__blkcnt_t": "unsigned long long",
|
||||
"__syscall_slong_t": "long",
|
||||
"__ssize_t": "long int",
|
||||
"wchar_t": "unsigned short int",
|
||||
"wctype_t": "unsigned short int",
|
||||
"__int64": "long long",
|
||||
"__int32": "int",
|
||||
"__int16": "short",
|
||||
"__int8": "char",
|
||||
"_QWORD": "uint64_t",
|
||||
"_OWORD": "long double",
|
||||
"_DWORD": "uint32_t",
|
||||
"size_t": "unsigned int",
|
||||
"_BYTE": "uint8_t",
|
||||
"_TBYTE": "uint16_t",
|
||||
"_BOOL8": "uint8_t",
|
||||
"gcc_va_list": "va_list",
|
||||
"_WORD": "unsigned short",
|
||||
"_BOOL4": "int",
|
||||
"__va_list_tag": "va_list",
|
||||
"_IO_FILE": "FILE",
|
||||
"DIR": "int",
|
||||
"__fsword_t": "long",
|
||||
"__kernel_ulong_t": "int",
|
||||
"cc_t": "int",
|
||||
"speed_t": "int",
|
||||
"fd_set": "int",
|
||||
"__suseconds_t": "int",
|
||||
"_UNKNOWN": "void",
|
||||
"__sighandler_t": "void (*)(int)",
|
||||
"__compar_fn_t": "int (*)(const void *, const void *)",
|
||||
}
|
||||
|
||||
def replace_typedefs(text):
|
||||
for alias, original in typedef_map.items():
|
||||
pattern = re.compile(rf'\b{re.escape(alias)}\b')
|
||||
text = pattern.sub(original, text)
|
||||
return text
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 4. 删除注释
|
||||
# ----------------------------
|
||||
def remove_comments(text):
|
||||
text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL)
|
||||
text = re.sub(r'//.*?$', '', text, flags=re.MULTILINE)
|
||||
return text
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 5. 单条伪代码处理
|
||||
# ----------------------------
|
||||
def process_code(code_str):
|
||||
code_str = remove_comments(code_str)
|
||||
code_str = hex_to_dec(code_str)
|
||||
code_str = remove_keywords(code_str)
|
||||
code_str = replace_typedefs(code_str)
|
||||
return code_str
|
||||
|
||||
|
||||
# 包装 process_code,使其接受一个 dict 并处理字段
|
||||
def process_entry(entry, key_name='pseudo'):
|
||||
# result = {}
|
||||
|
||||
# # 原始字段保留
|
||||
# result['ida_pseudo'] = entry.get('ida_pseudo', '')
|
||||
# result['ida_strip_pseudo'] = entry.get('ida_strip_pseudo', '')
|
||||
|
||||
# # 分别处理两个字段
|
||||
# result['ida_pseudo_result'] = process_code(result['ida_pseudo'])
|
||||
# result['ida_strip_pseudo_result'] = process_code(result['ida_strip_pseudo'])
|
||||
|
||||
result = process_code(entry.get(key_name, ''))
|
||||
if not result.strip():
|
||||
return ''
|
||||
formatted = format_with_clang(result)
|
||||
if formatted is None:
|
||||
return None
|
||||
cleaned = strip_empty(formatted)
|
||||
|
||||
return cleaned
|
||||
|
||||
# 主函数
|
||||
def normalize_code_list_parallel(input_json, output_json, key_name='pseudo', num_workers=None, remove=1):
|
||||
with open(input_json, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("输入 JSON 应为对象数组")
|
||||
|
||||
num_workers = num_workers or cpu_count()
|
||||
print(f"[+] 开始处理 {len(data)} 条记录,使用 {num_workers} 个进程")
|
||||
|
||||
from functools import partial
|
||||
process_entry_key = partial(process_entry, key_name=key_name)
|
||||
|
||||
with Pool(processes=num_workers) as pool:
|
||||
result = list(tqdm(pool.imap(process_entry_key, data), total=len(data), desc="Processing"))
|
||||
|
||||
data_good = []
|
||||
for record, norm in zip(data, result):
|
||||
if norm:
|
||||
if not good_func(norm):
|
||||
continue
|
||||
record[f"{key_name}_norm"] = norm
|
||||
data_good.append(record)
|
||||
elif norm is None:
|
||||
if not remove:
|
||||
record[f"{key_name}_norm"] = record[f"{key_name}"]
|
||||
data_good.append(record)
|
||||
|
||||
with open(output_json, 'w', encoding='utf-8') as f:
|
||||
json.dump(data_good, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"[✓] 完成处理:{input_json}:{len(data)} → {output_json}:{len(data_good)}")
|
||||
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 7. 命令行入口
|
||||
# ----------------------------
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="并行处理 IDA 伪代码字符串列表")
|
||||
parser.add_argument('--input_json', default="exebench_format_top1p.json", help='输入 JSON 文件路径(每项为字符串)')
|
||||
parser.add_argument('--output_json', default="exebench_format_pseudo_top1p.json", help='输出 JSON 文件路径')
|
||||
parser.add_argument('--key_name', default="pseudo", help='输出 JSON 文件路径')
|
||||
parser.add_argument('--workers', type=int, default=32, help='进程数默认使用8核心')
|
||||
parser.add_argument('--remove', type=int, default=1, help='remove fail cases')
|
||||
args = parser.parse_args()
|
||||
|
||||
normalize_code_list_parallel(args.input_json, args.output_json, args.key_name, args.workers, args.remove)
|
||||
179
sk2decompile/evaluation/sk2decompile_inf.py
Normal file
179
sk2decompile/evaluation/sk2decompile_inf.py
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
from llm_server import llm_inference
|
||||
from transformers import AutoTokenizer
|
||||
import json
|
||||
import argparse
|
||||
import shutil
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
|
||||
opts = ["O0", "O1", "O2", "O3"]
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
if __name__ == "__main__":
|
||||
arg_parser = argparse.ArgumentParser()
|
||||
arg_parser.add_argument("--model_path",type=str,default="LLM4Binary/sk2decompile-struct-6.7b")
|
||||
arg_parser.add_argument("--dataset_path",type=str,default='reverse_sample.json')
|
||||
arg_parser.add_argument("--decompiler",type=str,default='ida_pseudo_norm')
|
||||
arg_parser.add_argument("--gpus", type=int, default=1)
|
||||
arg_parser.add_argument("--max_num_seqs", type=int, default=1)
|
||||
arg_parser.add_argument("--gpu_memory_utilization", type=float, default=0.8)
|
||||
arg_parser.add_argument("--temperature", type=float, default=0)
|
||||
arg_parser.add_argument("--max_total_tokens", type=int, default=32768)
|
||||
arg_parser.add_argument("--max_new_tokens", type=int, default=4096)
|
||||
arg_parser.add_argument("--stop_sequences", type=str, default=None)
|
||||
arg_parser.add_argument("--recover_model_path", type=str, default='LLM4Binary/sk2decompile-ident-6.7', help="Path to the model to recover from, if any.")
|
||||
arg_parser.add_argument("--output_path", type=str, default='./result/sk2decompile')
|
||||
arg_parser.add_argument("--only_save", type=int, default=0)
|
||||
arg_parser.add_argument("--strip", type=int, default=1)
|
||||
arg_parser.add_argument("--language", type=str, default='c')
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
before = "# This is the assembly code:\n"
|
||||
after = "\n# What is the source code?\n"
|
||||
|
||||
if args.dataset_path.endswith('.json'):
|
||||
with open(args.dataset_path, "r") as f:
|
||||
print("===========")
|
||||
print(f"Loading dataset from {args.dataset_path}")
|
||||
print("===========")
|
||||
samples = json.load(f)
|
||||
elif args.dataset_path.endswith('.jsonl'):
|
||||
samples = []
|
||||
with open(args.dataset_path, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
samples.append(json.loads(line))
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
|
||||
if args.stop_sequences is None:
|
||||
args.stop_sequences = [tokenizer.eos_token]
|
||||
|
||||
inputs = []
|
||||
infos = []
|
||||
for sample in samples:
|
||||
prompt = before + sample[args.decompiler].strip() + after
|
||||
sample['prompt_model1'] = prompt
|
||||
inputs.append(prompt)
|
||||
infos.append({
|
||||
"opt": sample["opt"],
|
||||
"language": sample["language"],
|
||||
"index": sample["index"],
|
||||
"func_name": sample["func_name"]
|
||||
})
|
||||
|
||||
|
||||
print("Starting first model inference...")
|
||||
gen_results = llm_inference(inputs, args.model_path,
|
||||
args.gpus,
|
||||
args.max_total_tokens,
|
||||
args.gpu_memory_utilization,
|
||||
args.temperature,
|
||||
args.max_new_tokens,
|
||||
args.stop_sequences)
|
||||
gen_results = [gen_result[0] for gen_result in gen_results]
|
||||
|
||||
for idx in range(len(gen_results)):
|
||||
samples[idx]['gen_result_model1'] = gen_results[idx]
|
||||
|
||||
inputs_recovery = []
|
||||
before_recovery = "# This is the normalized code:\n"
|
||||
after_recovery = "\n# What is the source code?\n"
|
||||
|
||||
for idx, sample in enumerate(gen_results):
|
||||
prompt_recovery = before_recovery + sample.strip() + after_recovery
|
||||
samples[idx]['prompt_model2'] = prompt_recovery
|
||||
inputs_recovery.append(prompt_recovery)
|
||||
|
||||
print("Starting recovery model inference...")
|
||||
gen_results_recovery = llm_inference(inputs_recovery, args.recover_model_path,
|
||||
args.gpus,
|
||||
args.max_total_tokens,
|
||||
args.gpu_memory_utilization,
|
||||
args.temperature,
|
||||
args.max_new_tokens,
|
||||
args.stop_sequences)
|
||||
gen_results_recovery = [gen_result[0] for gen_result in gen_results_recovery]
|
||||
|
||||
|
||||
for idx in range(len(gen_results_recovery)):
|
||||
samples[idx]['gen_result_model2'] = gen_results_recovery[idx]
|
||||
|
||||
if args.output_path:
|
||||
if os.path.exists(args.output_path):
|
||||
shutil.rmtree(args.output_path)
|
||||
for opt in opts:
|
||||
os.makedirs(os.path.join(args.output_path, opt))
|
||||
|
||||
if args.strip:
|
||||
print("Processing function name stripping...")
|
||||
for idx in range(len(gen_results_recovery)):
|
||||
one = gen_results_recovery[idx]
|
||||
func_name_in_gen = one.split('(')[0].split(' ')[-1].strip()
|
||||
if func_name_in_gen.strip() and func_name_in_gen[0:2] == '**':
|
||||
func_name_in_gen = func_name_in_gen[2:]
|
||||
elif func_name_in_gen.strip() and func_name_in_gen[0] == '*':
|
||||
func_name_in_gen = func_name_in_gen[1:]
|
||||
|
||||
original_func_name = samples[idx]["func_name"]
|
||||
gen_results_recovery[idx] = one.replace(func_name_in_gen, original_func_name)
|
||||
samples[idx]["gen_result_model2_stripped"] = gen_results_recovery[idx]
|
||||
|
||||
print("Saving inference results and logs...")
|
||||
for idx_sample, final_result in enumerate(gen_results_recovery):
|
||||
opt = infos[idx_sample]['opt']
|
||||
language = infos[idx_sample]['language']
|
||||
original_index = samples[idx_sample]['index']
|
||||
|
||||
save_path = os.path.join(args.output_path, opt, f"{original_index}_{opt}.{language}")
|
||||
with open(save_path, "w") as f:
|
||||
f.write(final_result)
|
||||
|
||||
log_path = save_path + ".log"
|
||||
log_data = {
|
||||
"index": original_index,
|
||||
"opt": opt,
|
||||
"language": language,
|
||||
"func_name": samples[idx_sample]["func_name"],
|
||||
"decompiler": args.decompiler,
|
||||
"input_asm": samples[idx_sample][args.decompiler].strip(),
|
||||
"prompt_model1": samples[idx_sample]['prompt_model1'],
|
||||
"gen_result_model1": samples[idx_sample]['gen_result_model1'],
|
||||
"prompt_model2": samples[idx_sample]['prompt_model2'],
|
||||
"gen_result_model2": samples[idx_sample]['gen_result_model2'],
|
||||
"final_result": final_result,
|
||||
"stripped": args.strip
|
||||
}
|
||||
|
||||
if args.strip and "gen_result_model2_stripped" in samples[idx_sample]:
|
||||
log_data["gen_result_model2_stripped"] = samples[idx_sample]["gen_result_model2_stripped"]
|
||||
|
||||
with open(log_path, "w") as f:
|
||||
json.dump(log_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
json_path = os.path.join(args.output_path, 'inference_results.jsonl')
|
||||
with open(json_path, 'w') as f:
|
||||
for sample in samples:
|
||||
f.write(json.dumps(sample) + '\n')
|
||||
|
||||
stats_path = os.path.join(args.output_path, 'inference_stats.txt')
|
||||
with open(stats_path, 'w') as f:
|
||||
f.write(f"Total samples processed: {len(samples)}\n")
|
||||
f.write(f"Model path: {args.model_path}\n")
|
||||
f.write(f"Recovery model path: {args.recover_model_path}\n")
|
||||
f.write(f"Dataset path: {args.dataset_path}\n")
|
||||
f.write(f"Language: {args.language}\n")
|
||||
f.write(f"Decompiler: {args.decompiler}\n")
|
||||
f.write(f"Strip function names: {bool(args.strip)}\n")
|
||||
|
||||
opt_counts = {"O0": 0, "O1": 0, "O2": 0, "O3": 0}
|
||||
for sample in samples:
|
||||
opt_counts[sample['opt']] += 1
|
||||
|
||||
f.write("\nSamples per optimization level:\n")
|
||||
for opt, count in opt_counts.items():
|
||||
f.write(f" {opt}: {count}\n")
|
||||
|
||||
print(f"Inference completed! Results saved to {args.output_path}")
|
||||
print(f"Total {len(samples)} samples processed.")
|
||||
Loading…
Add table
Add a link
Reference in a new issue