LLM_Foreign_Block/transformers_logit_processed.py
2025-03-11 20:16:42 +09:00

65 lines
No EOL
2.4 KiB
Python

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList, BitsAndBytesConfig
import torch
from blocker_torch import blocker
class BlockerProcessor(LogitsProcessor):
def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.foreign_lang_mask = None
self.mask_indices = None
def __call__(self, input_ids, scores):
return blocker(self.tokenizer, input_ids, scores)
def inference():
model_name = "Qwen/Qwen2.5-7B-Instruct"
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/opt/models")
model = AutoModelForCausalLM.from_pretrained(
model_name,
cache_dir="/opt/models",
quantization_config=bnb_config,)
test_prompts = [
"너가 아는 중국어를 모두 말해줘",
"중국어로 짧은 소설을 써줘",
"'안녕'을 중국어로 뭐라고 해?",
]
foreign_processor = BlockerProcessor(tokenizer)
for i, prompt in enumerate(test_prompts):
print(f"\n============== 테스트 프롬프트: {prompt} ==================")
input_ids = tokenizer.encode(prompt, return_tensors="pt")
print("\n--- LogitsProcessor 적용 ---")
gen_kwargs_with_processor = {
"max_length": 512,
"do_sample": True,
"temperature": 0.8,
"top_p": 0.95,
"logits_processor": LogitsProcessorList([foreign_processor]),
}
output_ids_with_processor = model.generate(input_ids, **gen_kwargs_with_processor)
generated_text_with_processor = tokenizer.decode(output_ids_with_processor[0], skip_special_tokens=True)
print(generated_text_with_processor)
print("\n--- LogitsProcessor 미적용 ---")
gen_kwargs_without_processor = {
"max_length": 512,
"do_sample": True,
"temperature": 0.8,
"top_p": 0.95,
}
output_ids_without_processor = model.generate(input_ids, **gen_kwargs_without_processor)
generated_text_without_processor = tokenizer.decode(output_ids_without_processor[0], skip_special_tokens=True)
print(generated_text_without_processor)
if __name__ == "__main__":
inference()