mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-06-17 02:27:23 +00:00
remove rendering plugin
This commit is contained in:
parent
21ff8599c1
commit
28126bcc8f
5 changed files with 3 additions and 236 deletions
|
|
@ -22,8 +22,8 @@ sibling modules:
|
|||
- ``markers`` -- per-model assistant role markers (explicit whitelist)
|
||||
- ``collation`` -- batch padding/truncation/MM alignment (consumed by the batch generators)
|
||||
|
||||
Per-template steps can be customized by registering an override via ``RenderingPlugin``
|
||||
(see ``plugins/model_plugins/rendering.py``) and constructing ``Renderer(processor, name=...)``.
|
||||
To support a new model, add its assistant-role markers to ``markers._ASSISTANT_MARKERS``; the
|
||||
built-in ``_render_messages`` / ``_parse_message`` then handle it via the model's own chat template.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
|
@ -179,9 +179,8 @@ def _parse_message(generated_text: str) -> Message:
|
|||
|
||||
|
||||
class Renderer:
|
||||
def __init__(self, processor: Processor, config=None, name: str | None = None):
|
||||
def __init__(self, processor: Processor, config=None):
|
||||
self.processor = processor
|
||||
self.name = name
|
||||
|
||||
# Resolve the assistant role markers from the explicit per-model whitelist (no probing),
|
||||
# then encode them with this model's tokenizer to get the token-id forms used for labeling.
|
||||
|
|
@ -194,18 +193,6 @@ class Renderer:
|
|||
if not self._assistant_start_ids or not self._assistant_end_ids:
|
||||
raise ValueError(f"Empty assistant marker ids for model_type {model_type!r}.")
|
||||
|
||||
def _override(self, method_name: str):
    """Look up a plugin-registered replacement for ``method_name``.

    Args:
        method_name: Name of the rendering step to look up
            (e.g. ``"render_messages"`` or ``"parse_message"``).

    Returns:
        Whatever ``RenderingPlugin(self.name).get(method_name)`` yields, or
        ``None`` when this renderer was constructed without a ``name``.
    """
    template_name = self.name
    if template_name is None:
        # Unnamed renderers never consult the plugin registry.
        return None

    # Deferred to call time so that loading this module does not trigger a
    # core -> plugins import cycle.
    from ...plugins.model_plugins.rendering import RenderingPlugin

    plugin = RenderingPlugin(template_name)
    return plugin.get(method_name)
|
||||
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[Message],
|
||||
|
|
@ -224,16 +211,6 @@ class Renderer:
|
|||
Returns:
|
||||
ModelInput with input_ids, attention_mask, labels, and loss_weights.
|
||||
"""
|
||||
override = self._override("render_messages")
|
||||
if override is not None:
|
||||
return override(
|
||||
self.processor,
|
||||
messages,
|
||||
tools=tools,
|
||||
is_generate=is_generate,
|
||||
enable_thinking=enable_thinking,
|
||||
)
|
||||
|
||||
return _render_messages(
|
||||
self.processor,
|
||||
messages,
|
||||
|
|
@ -253,10 +230,6 @@ class Renderer:
|
|||
Returns:
|
||||
Parsed Message with typed content blocks.
|
||||
"""
|
||||
override = self._override("parse_message")
|
||||
if override is not None:
|
||||
return override(generated_text)
|
||||
|
||||
return _parse_message(generated_text)
|
||||
|
||||
def get_dummy_media_fragment(self, modality: str) -> dict:
|
||||
|
|
|
|||
|
|
@ -1,77 +0,0 @@
|
|||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
|
||||
from ...utils import logging
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import Message
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class RenderingPlugin(BasePlugin):
    """Registry of per-template replacements for the core ``Renderer`` steps.

    The built-in ``render_messages`` / ``parse_message`` implementations live in
    :class:`~llamafactory.v1.core.rendering.Renderer`. A template can replace either
    step by registering a function under its own name::

        @RenderingPlugin("my_template").register("render_messages")
        def render_my_template(processor, messages, tools=None, *, is_generate=False, enable_thinking=False):
            ...
            return ModelInput(...)

    and then building the renderer as ``Renderer(processor, name="my_template")``.
    Steps that are not registered for a name are expected to be served by the
    built-in default, so overriding only ``parse_message`` is fine.
    """

    # Template names whose ``templates.<name>`` module import was already tried.
    # Declared on the class, so the record is shared by every instance.
    _attempted_template_imports: set[str] = set()

    def _ensure_template_imported(self) -> None:
        """Import ``<package>.templates.<name>`` once so its registrations run.

        The name is recorded before the import is attempted, so a failing
        import is logged a single time and not retried.
        """
        if self.name is None:
            return
        if self.name in self._attempted_template_imports:
            return

        self._attempted_template_imports.add(self.name)
        full_module_name = f"{__package__}.templates.{self.name}"
        try:
            importlib.import_module(full_module_name)
        except Exception as exc:
            # Best effort: a missing or broken template module only produces a warning.
            logger.warning(f"[Template Registry] Failed to import {full_module_name}: {exc}")

    def __getitem__(self, method_name: str):
        """Return the callable registered for ``method_name`` via the base-class lookup."""
        self._ensure_template_imported()
        registered = super().__getitem__(method_name)
        return registered

    def get(self, method_name: str) -> Callable | None:
        """Return the override registered for ``method_name``, or ``None``.

        Meant as the soft counterpart of ``__getitem__``: the caller can fall
        back to the built-in default when nothing custom is registered.
        """
        self._ensure_template_imported()
        template_name = self.name
        return None if template_name is None else self._registry[template_name].get(method_name)

    def render_messages(self, *args, **kwargs):
        """Render messages with the template-specific ``render_messages`` callable."""
        renderer = self["render_messages"]
        return renderer(*args, **kwargs)

    def parse_message(self, generated_text: str) -> Message:
        """Parse generated text with the template-specific ``parse_message`` callable."""
        parser = self["parse_message"]
        return parser(generated_text)
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from ....utils.types import Message
|
||||
from ..rendering import RenderingPlugin
|
||||
|
||||
|
||||
@RenderingPlugin("qwen3").register("parse_message")
def parse_qwen3_message(generated_text: str) -> Message:
    """Split Qwen3-formatted output into typed content blocks.

    ``<think>...</think>`` spans become ``reasoning`` blocks and
    ``<tool_call>...</tool_call>`` spans become ``tool_call`` blocks. Any
    non-blank text before, between or after those spans becomes a ``text``
    block. Blocks keep the order in which they appear.

    Args:
        generated_text: The generated text in the Qwen3 template format.

    Returns:
        An assistant ``Message`` whose content is the list of parsed blocks.

    Raises:
        json.JSONDecodeError: If the body of a ``<tool_call>`` span is not valid JSON.
    """
    tag_re = re.compile(r"<(think|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
    blocks = []

    def add_text(segment: str) -> None:
        # Plain text is only kept when something remains after stripping whitespace.
        stripped = segment.strip()
        if stripped:
            blocks.append({"type": "text", "value": stripped})

    cursor = 0
    for found in tag_re.finditer(generated_text):
        begin, finish = found.span()
        add_text(generated_text[cursor:begin])

        kind = found.group(1)
        payload = found.group(2).strip()
        if kind == "tool_call":
            # Parsed only to validate; the raw JSON string is what gets stored.
            json.loads(payload)
            blocks.append({"type": "tool_call", "value": payload})
        elif kind == "think":
            blocks.append({"type": "reasoning", "value": payload})

        cursor = finish

    # Whatever follows the last tagged span (or the whole text if none matched).
    add_text(generated_text[cursor:])

    return Message(role="assistant", content=blocks)
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from ....utils.types import Message
|
||||
from ..rendering import RenderingPlugin
|
||||
|
||||
|
||||
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen3_nothink_message(generated_text: str) -> Message:
    """Split Qwen3 nothink-formatted output into typed content blocks.

    ``<thinking>...</thinking>`` spans become ``reasoning`` blocks and
    ``<tool_call>...</tool_call>`` spans become ``tool_call`` blocks. Any
    non-blank text around those spans becomes a ``text`` block. Blocks keep
    the order in which they appear.

    Args:
        generated_text: The generated text in the Qwen3 nothink template format.

    Returns:
        An assistant ``Message`` whose content is the list of parsed blocks.

    Raises:
        json.JSONDecodeError: If the body of a ``<tool_call>`` span is not valid JSON.
    """
    # NOTE(review): the reasoning tag matched here is ``<thinking>``, not
    # ``<think>`` -- confirm this is the tag the nothink template emits.
    tag_re = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
    block_type_for_tag = {"thinking": "reasoning", "tool_call": "tool_call"}

    parsed = []
    consumed = 0
    for hit in tag_re.finditer(generated_text):
        hit_start, hit_end = hit.span()

        # Free text between the previous tagged span and this one.
        gap = generated_text[consumed:hit_start].strip()
        if gap:
            parsed.append({"type": "text", "value": gap})

        tag_name = hit.group(1)
        body = hit.group(2).strip()
        if tag_name == "tool_call":
            # Validation only: the stored value stays the raw JSON string.
            json.loads(body)
        parsed.append({"type": block_type_for_tag[tag_name], "value": body})

        consumed = hit_end

    # Trailing free text after the final tagged span.
    tail = generated_text[consumed:].strip()
    if tail:
        parsed.append({"type": "text", "value": tail})

    return Message(role="assistant", content=parsed)
|
||||
Loading…
Add table
Add a link
Reference in a new issue