mirror of
https://github.com/workdd/LLM_Foreign_Block.git
synced 2026-06-17 01:49:06 +00:00
rm blocker numpy update vllm to use blocker_torch.py
This commit is contained in:
parent
7dc1cdcb07
commit
0c8e413f9b
5 changed files with 43 additions and 117 deletions
|
|
@ -14,13 +14,12 @@ LLM 모델의 외국어 토큰 생성을 차단하는 코드 구현
|
|||
4. 결과 생성: 차단된 토큰을 제외한 나머지 토큰으로 텍스트 생성
|
||||
|
||||
## 이슈 업데이트
|
||||
- `blocker_numpy.py`는 현재 성능이 좋지 않아 torch 버전의 blocker를 사용하는 것을 추천드립니다.
|
||||
- 따라서, 기존 vllm 버전은 blocker_numpy.py를 사용하고 있는데, blocker_torch_vllm.py를 사용하시면 됩니다.
|
||||
- 기존 `blocker_numpy.py`는 추론 성능을 떨어뜨리는 이슈가 있어 삭제했습니다.
|
||||
|
||||
## 파일 구조 및 설명
|
||||
|
||||
- `blocker_numpy.py`, `blocker_torch.py`
|
||||
- NumPy 혹은, Torch Tensor 기반 외국어 토큰 차단 구현
|
||||
- `blocker_torch.py`
|
||||
- Torch Tensor 기반 외국어 토큰 차단 구현
|
||||
- 중국어, 일본어, 러시아어에 해당하는 유니코드 범위의 토큰을 식별하고 차단
|
||||
```
|
||||
chinese_ranges = [
|
||||
|
|
|
|||
|
|
@ -1,44 +0,0 @@
|
|||
import numpy as np
|
||||
|
||||
foreign_lang_mask = None
|
||||
mask_indices = None
|
||||
|
||||
|
||||
def blocker(tokenizer, input_ids, logits):
|
||||
"""중국어, 일본어, 러시아어 토큰을 차단하는 함수"""
|
||||
global foreign_lang_mask, mask_indices
|
||||
|
||||
if foreign_lang_mask is None:
|
||||
# 어휘집의 모든 토큰 ID 생성
|
||||
vocab_size = logits.shape[-1]
|
||||
token_ids = np.arange(vocab_size)
|
||||
|
||||
# vLLM에서는 batch_decode 대신 tokenizer.decode를 사용
|
||||
decoded_tokens = [tokenizer.decode([id]) for id in token_ids]
|
||||
|
||||
# 마스킹할 문자 범위 정의
|
||||
chinese_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||
]
|
||||
|
||||
japanese_ranges = [(0x3040, 0x309F), (0x30A0, 0x30FF), (0x31F0, 0x31FF)] # Hiragana # Katakana # Katakana Phonetic Extensions
|
||||
|
||||
russian_ranges = [(0x0400, 0x04FF), (0x0500, 0x052F)] # Cyrillic # Cyrillic Supplement
|
||||
|
||||
all_ranges = chinese_ranges + japanese_ranges + russian_ranges
|
||||
|
||||
# 해당 언어 문자 범위에 해당하는 토큰을 마스킹
|
||||
foreign_lang_mask = np.array([any(any(start <= ord(c) <= end for start, end in all_ranges) for c in token if c) for token in decoded_tokens])
|
||||
|
||||
# 차단할 인덱스 저장
|
||||
mask_indices = np.where(foreign_lang_mask)[0]
|
||||
|
||||
# NumPy 배열로 처리
|
||||
logits_max_idx = min(logits.shape[0], np.max(mask_indices) + 1 if len(mask_indices) > 0 else 0)
|
||||
valid_indices = mask_indices[mask_indices < logits_max_idx]
|
||||
logits[valid_indices] = -float("inf")
|
||||
|
||||
return logits
|
||||
|
|
@ -31,6 +31,9 @@ def blocker(tokenizer, input_ids, logits):
|
|||
logits.device
|
||||
)
|
||||
|
||||
# 해당 토큰에 대한 로짓을 -inf로 설정
|
||||
logits[:, foreign_lang_mask] = -float("inf")
|
||||
# 해당 토큰에 대한 로짓을 -inf로 설정 (차원 확인)
|
||||
if logits.dim() == 1:
|
||||
logits[foreign_lang_mask] = -float("inf")
|
||||
else:
|
||||
logits[:, foreign_lang_mask] = -float("inf")
|
||||
return logits
|
||||
|
|
|
|||
|
|
@ -1,38 +0,0 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import time
|
||||
|
||||
foreign_lang_mask = None
|
||||
|
||||
def blocker(tokenizer, input_ids, logits):
|
||||
global foreign_lang_mask
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# logits: shape (batch_size, vocab_size)
|
||||
# -> numpy일 경우 강제 변환
|
||||
if isinstance(logits, np.ndarray):
|
||||
logits = torch.from_numpy(logits).to("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
vocab_size = logits.shape[-1]
|
||||
|
||||
if foreign_lang_mask is None:
|
||||
token_ids = list(range(vocab_size))
|
||||
decoded_tokens = [tokenizer.decode([i]) for i in token_ids]
|
||||
|
||||
# 마스킹할 문자 범위 정의
|
||||
def is_foreign(token):
|
||||
ranges = [
|
||||
(0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0x20000, 0x2A6DF), (0xF900, 0xFAFF), # Chinese
|
||||
(0x3040, 0x309F), (0x30A0, 0x30FF), (0x31F0, 0x31FF), # Japanese
|
||||
(0x0400, 0x04FF), (0x0500, 0x052F), # Russian
|
||||
]
|
||||
return any(any(start <= ord(c) <= end for start, end in ranges) for c in token if c)
|
||||
|
||||
mask_list = [is_foreign(token) for token in decoded_tokens]
|
||||
foreign_lang_mask = torch.tensor(mask_list, dtype=torch.bool, device=logits.device)
|
||||
|
||||
logits[foreign_lang_mask] = float("-inf")
|
||||
|
||||
print("logit 처리 시간:", time.time() - start_time)
|
||||
return logits
|
||||
|
|
@ -1,12 +1,24 @@
|
|||
import numpy as np
|
||||
from vllm import LLM, SamplingParams
|
||||
from blocker_numpy import blocker
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.outputs import RequestOutput
|
||||
from blocker_torch import blocker
|
||||
|
||||
|
||||
class BlockerProcessor:
|
||||
def __init__(self, tokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
self.foreign_lang_mask = None
|
||||
self.mask_indices = None
|
||||
|
||||
def __call__(self, input_ids, logits):
|
||||
return blocker(self.tokenizer, input_ids, logits)
|
||||
|
||||
|
||||
def inference():
|
||||
model_name = "Qwen/Qwen2.5-7B-Instruct-AWQ"
|
||||
llm = LLM(model=model_name)
|
||||
model_name = "Qwen/Qwen2.5-14B-Instruct"
|
||||
|
||||
llm = LLM(model=model_name, download_dir="/opt/models")
|
||||
tokenizer = llm.get_tokenizer()
|
||||
|
||||
test_prompts = [
|
||||
|
|
@ -15,39 +27,33 @@ def inference():
|
|||
"'안녕'을 중국어로 뭐라고 해?",
|
||||
]
|
||||
|
||||
def logits_processor_wrapper(input_ids, logits):
|
||||
return blocker(tokenizer, input_ids, logits)
|
||||
foreign_processor = BlockerProcessor(tokenizer)
|
||||
|
||||
# LogitsProcessor를 적용한 샘플링 파라미터
|
||||
sampling_params_with_processor = SamplingParams(
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=512,
|
||||
logits_processors=[logits_processor_wrapper]
|
||||
)
|
||||
sampling_params_with_processor = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=512, logits_processors=[foreign_processor])
|
||||
|
||||
# LogitsProcessor를 적용하지 않은 샘플링 파라미터
|
||||
sampling_params_without_processor = SamplingParams(
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=512
|
||||
)
|
||||
sampling_params_without_processor = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=512)
|
||||
|
||||
# 추론 실행 (LogitsProcessor 적용)
|
||||
outputs_with_processor = llm.generate(test_prompts, sampling_params_with_processor)
|
||||
|
||||
# 추론 실행 (LogitsProcessor 미적용)
|
||||
outputs_without_processor = llm.generate(test_prompts, sampling_params_without_processor)
|
||||
|
||||
# 결과 출력
|
||||
for i, (output_with, output_without) in enumerate(zip(outputs_with_processor, outputs_without_processor)):
|
||||
prompt = output_with.prompt
|
||||
generated_text_with = output_with.outputs[0].text
|
||||
generated_text_without = output_without.outputs[0].text
|
||||
|
||||
for i, prompt in enumerate(test_prompts):
|
||||
print(f"\n============== 테스트 프롬프트: {prompt} ==================")
|
||||
|
||||
print("\n--- LogitsProcessor 적용 ---")
|
||||
outputs_with_processor = llm.generate([prompt], sampling_params_with_processor)
|
||||
for output in outputs_with_processor:
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"생성된 텍스트: {generated_text}")
|
||||
|
||||
print(generated_text_with)
|
||||
|
||||
print("\n--- LogitsProcessor 미적용 ---")
|
||||
outputs_without_processor = llm.generate([prompt], sampling_params_without_processor)
|
||||
for output in outputs_without_processor:
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"생성된 텍스트: {generated_text}")
|
||||
print(generated_text_without)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
inference()
|
||||
inference()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue