rm blocker numpy update vllm to use blocker_torch.py

This commit is contained in:
workdd 2025-07-10 11:00:39 +09:00
commit 0c8e413f9b
5 changed files with 43 additions and 117 deletions

View file

@ -14,13 +14,12 @@ LLM 모델의 외국어 토큰 생성을 차단하는 코드 구현
4. 결과 생성: 차단된 토큰을 제외한 나머지 토큰으로 텍스트 생성
## 이슈 업데이트
- `blocker_numpy.py`는 현재 성능이 좋지 않아 torch 버전의 blocker를 사용하는 것을 추천드립니다.
- 따라서, 기존 vllm 버전은 blocker_numpy.py를 사용하고 있는데, blocker_torch_vllm.py를 사용하시면 됩니다.
- 기존 `blocker_numpy.py`는 추론 성능을 떨어뜨리는 이슈가 있어 삭제했습니다.
## 파일 구조 및 설명
- `blocker_numpy.py`, `blocker_torch.py`
- NumPy 혹은, Torch Tensor 기반 외국어 토큰 차단 구현
- `blocker_torch.py`
- Torch Tensor 기반 외국어 토큰 차단 구현
- 중국어, 일본어, 러시아어에 해당하는 유니코드 범위의 토큰을 식별하고 차단
```
chinese_ranges = [

View file

@ -1,44 +0,0 @@
import numpy as np
foreign_lang_mask = None
mask_indices = None
def blocker(tokenizer, input_ids, logits):
"""중국어, 일본어, 러시아어 토큰을 차단하는 함수"""
global foreign_lang_mask, mask_indices
if foreign_lang_mask is None:
# 어휘집의 모든 토큰 ID 생성
vocab_size = logits.shape[-1]
token_ids = np.arange(vocab_size)
# vLLM에서는 batch_decode 대신 tokenizer.decode를 사용
decoded_tokens = [tokenizer.decode([id]) for id in token_ids]
# 마스킹할 문자 범위 정의
chinese_ranges = [
(0x4E00, 0x9FFF), # CJK Unified Ideographs
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
]
japanese_ranges = [(0x3040, 0x309F), (0x30A0, 0x30FF), (0x31F0, 0x31FF)] # Hiragana # Katakana # Katakana Phonetic Extensions
russian_ranges = [(0x0400, 0x04FF), (0x0500, 0x052F)] # Cyrillic # Cyrillic Supplement
all_ranges = chinese_ranges + japanese_ranges + russian_ranges
# 해당 언어 문자 범위에 해당하는 토큰을 마스킹
foreign_lang_mask = np.array([any(any(start <= ord(c) <= end for start, end in all_ranges) for c in token if c) for token in decoded_tokens])
# 차단할 인덱스 저장
mask_indices = np.where(foreign_lang_mask)[0]
# NumPy 배열로 처리
logits_max_idx = min(logits.shape[0], np.max(mask_indices) + 1 if len(mask_indices) > 0 else 0)
valid_indices = mask_indices[mask_indices < logits_max_idx]
logits[valid_indices] = -float("inf")
return logits

View file

@ -31,6 +31,9 @@ def blocker(tokenizer, input_ids, logits):
logits.device
)
# 해당 토큰에 대한 로짓을 -inf로 설정
logits[:, foreign_lang_mask] = -float("inf")
# 해당 토큰에 대한 로짓을 -inf로 설정 (차원 확인)
if logits.dim() == 1:
logits[foreign_lang_mask] = -float("inf")
else:
logits[:, foreign_lang_mask] = -float("inf")
return logits

View file

@ -1,38 +0,0 @@
import numpy as np
import torch
import time
foreign_lang_mask = None
def blocker(tokenizer, input_ids, logits):
global foreign_lang_mask
start_time = time.time()
# logits: shape (batch_size, vocab_size)
# -> numpy일 경우 강제 변환
if isinstance(logits, np.ndarray):
logits = torch.from_numpy(logits).to("cuda" if torch.cuda.is_available() else "cpu")
vocab_size = logits.shape[-1]
if foreign_lang_mask is None:
token_ids = list(range(vocab_size))
decoded_tokens = [tokenizer.decode([i]) for i in token_ids]
# 마스킹할 문자 범위 정의
def is_foreign(token):
ranges = [
(0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0x20000, 0x2A6DF), (0xF900, 0xFAFF), # Chinese
(0x3040, 0x309F), (0x30A0, 0x30FF), (0x31F0, 0x31FF), # Japanese
(0x0400, 0x04FF), (0x0500, 0x052F), # Russian
]
return any(any(start <= ord(c) <= end for start, end in ranges) for c in token if c)
mask_list = [is_foreign(token) for token in decoded_tokens]
foreign_lang_mask = torch.tensor(mask_list, dtype=torch.bool, device=logits.device)
logits[foreign_lang_mask] = float("-inf")
print("logit 처리 시간:", time.time() - start_time)
return logits

View file

@ -1,12 +1,24 @@
import numpy as np
from vllm import LLM, SamplingParams
from blocker_numpy import blocker
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.outputs import RequestOutput
from blocker_torch import blocker
class BlockerProcessor:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.foreign_lang_mask = None
self.mask_indices = None
def __call__(self, input_ids, logits):
return blocker(self.tokenizer, input_ids, logits)
def inference():
model_name = "Qwen/Qwen2.5-7B-Instruct-AWQ"
llm = LLM(model=model_name)
model_name = "Qwen/Qwen2.5-14B-Instruct"
llm = LLM(model=model_name, download_dir="/opt/models")
tokenizer = llm.get_tokenizer()
test_prompts = [
@ -15,39 +27,33 @@ def inference():
"'안녕'을 중국어로 뭐라고 해?",
]
def logits_processor_wrapper(input_ids, logits):
return blocker(tokenizer, input_ids, logits)
foreign_processor = BlockerProcessor(tokenizer)
# LogitsProcessor를 적용한 샘플링 파라미터
sampling_params_with_processor = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=512,
logits_processors=[logits_processor_wrapper]
)
sampling_params_with_processor = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=512, logits_processors=[foreign_processor])
# LogitsProcessor를 적용하지 않은 샘플링 파라미터
sampling_params_without_processor = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=512
)
sampling_params_without_processor = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=512)
# 추론 실행 (LogitsProcessor 적용)
outputs_with_processor = llm.generate(test_prompts, sampling_params_with_processor)
# 추론 실행 (LogitsProcessor 미적용)
outputs_without_processor = llm.generate(test_prompts, sampling_params_without_processor)
# 결과 출력
for i, (output_with, output_without) in enumerate(zip(outputs_with_processor, outputs_without_processor)):
prompt = output_with.prompt
generated_text_with = output_with.outputs[0].text
generated_text_without = output_without.outputs[0].text
for i, prompt in enumerate(test_prompts):
print(f"\n============== 테스트 프롬프트: {prompt} ==================")
print("\n--- LogitsProcessor 적용 ---")
outputs_with_processor = llm.generate([prompt], sampling_params_with_processor)
for output in outputs_with_processor:
generated_text = output.outputs[0].text
print(f"생성된 텍스트: {generated_text}")
print(generated_text_with)
print("\n--- LogitsProcessor 미적용 ---")
outputs_without_processor = llm.generate([prompt], sampling_params_without_processor)
for output in outputs_without_processor:
generated_text = output.outputs[0].text
print(f"생성된 텍스트: {generated_text}")
print(generated_text_without)
if __name__ == "__main__":
inference()
inference()