mirror of
https://github.com/workdd/LLM_Foreign_Block.git
synced 2026-06-17 01:49:06 +00:00
65 lines
No EOL
2.4 KiB
Python
65 lines
No EOL
2.4 KiB
Python
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList, BitsAndBytesConfig
|
|
import torch
|
|
from blocker_torch import blocker
|
|
|
|
|
|
class BlockerProcessor(LogitsProcessor):
    """Adapt ``blocker_torch.blocker`` to the ``LogitsProcessor`` protocol.

    All score manipulation happens inside ``blocker``; this wrapper only
    stores the tokenizer and forwards each decoding step to it, so the blocker
    can be placed in a ``LogitsProcessorList`` handed to ``model.generate``.
    Presumably ``blocker`` suppresses foreign-language tokens (the demo
    prompts ask for Chinese output) — TODO confirm in ``blocker_torch``.
    """

    def __init__(self, tokenizer):
        # Forwarded unchanged to ``blocker`` on every step.
        self.tokenizer = tokenizer
        # Neither attribute is read anywhere in this class; possibly intended
        # as a cached mask / index set — NOTE(review): confirm still needed.
        self.mask_indices = None
        self.foreign_lang_mask = None

    def __call__(self, input_ids, scores):
        """Run ``blocker`` for the current generation step.

        Args:
            input_ids: Token ids produced so far, as supplied by ``generate``.
            scores: Next-token scores for this step, as supplied by ``generate``.

        Returns:
            The value returned by ``blocker(tokenizer, input_ids, scores)``,
            which ``generate`` uses as the step's (adjusted) scores.
        """
        tokenizer = self.tokenizer
        adjusted_scores = blocker(tokenizer, input_ids, scores)
        return adjusted_scores
|
|
|
|
|
|
def _build_generation_inputs(tokenizer, prompt, device):
    """Encode *prompt* as a single user chat turn and move it to *device*.

    Qwen2.5-Instruct is chat-tuned, so the prompt is wrapped with the
    tokenizer's chat template (plus the assistant generation prompt) instead
    of being encoded as raw text. Returns the tokenizer's encoding (with
    ``input_ids`` and ``attention_mask``) placed on *device*.
    """
    messages = [{"role": "user", "content": prompt}]
    chat_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Calling the tokenizer (rather than ``encode``) also yields the
    # attention mask, which ``generate`` otherwise has to guess.
    return tokenizer(chat_text, return_tensors="pt").to(device)


def _generate_text(model, tokenizer, model_inputs, gen_kwargs, logits_processor=None):
    """Sample one completion and decode only the newly generated tokens.

    Args:
        model: Causal LM used for generation.
        tokenizer: Tokenizer used to decode the output ids.
        model_inputs: Encoding from ``_build_generation_inputs``.
        gen_kwargs: Generation keyword arguments shared by every run.
        logits_processor: Optional processor applied at every decoding step.

    Returns:
        The decoded completion, without the prompt/template tokens.
    """
    kwargs = dict(gen_kwargs)  # copy so the shared settings are never mutated
    if logits_processor is not None:
        kwargs["logits_processor"] = LogitsProcessorList([logits_processor])
    output_ids = model.generate(**model_inputs, **kwargs)
    # generate() returns prompt + completion; strip the prompt part.
    prompt_length = model_inputs["input_ids"].shape[-1]
    return tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True)


def inference(
    model_name="Qwen/Qwen2.5-7B-Instruct",
    cache_dir="/opt/models",
    prompts=None,
    max_new_tokens=512,
):
    """Print sampled generations with and without ``BlockerProcessor``.

    Loads *model_name* quantized to 4-bit NF4 (bitsandbytes, double
    quantization, bfloat16 compute). For each prompt it prints one sample
    generated with ``BlockerProcessor`` and one generated without it, using
    identical sampling settings.

    Args:
        model_name: Hugging Face model id to load.
        cache_dir: Cache directory for the downloaded tokenizer and model.
        prompts: Prompts to test. Defaults to three Korean prompts that ask
            for Chinese text — presumably the output the blocker is meant to
            suppress.
        max_new_tokens: Maximum number of generated tokens per sample. Prompt
            tokens are not counted, so the chat template does not eat into it.
    """
    if prompts is None:
        prompts = [
            "너가 아는 중국어를 모두 말해줘",
            "중국어로 짧은 소설을 써줘",
            "'안녕'을 중국어로 뭐라고 해?",
        ]

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        cache_dir=cache_dir,
        quantization_config=bnb_config,
    )

    # One shared config so the two runs differ only by the logits processor.
    sampling_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": 0.8,
        "top_p": 0.95,
    }
    foreign_processor = BlockerProcessor(tokenizer)

    for prompt in prompts:
        print(f"\n============== 테스트 프롬프트: {prompt} ==================")
        # Inputs must live on the same device as the (quantized) model.
        model_inputs = _build_generation_inputs(tokenizer, prompt, model.device)

        print("\n--- LogitsProcessor 적용 ---")
        print(_generate_text(model, tokenizer, model_inputs, sampling_kwargs, foreign_processor))

        print("\n--- LogitsProcessor 미적용 ---")
        print(_generate_text(model, tokenizer, model_inputs, sampling_kwargs))
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the with/without-blocker generation comparison.
    inference()