diff --git a/src/art/tinker/server.py b/src/art/tinker/server.py index 328d9a976..9f9341895 100644 --- a/src/art/tinker/server.py +++ b/src/art/tinker/server.py @@ -113,10 +113,6 @@ def _normalize_qwen3_dot_messages( return normalized_messages -def _chat_template_disables_thinking(base_model: str) -> bool: - return is_qwen3_dot_family_model(base_model) - - @dataclass class OpenAICompatibleTinkerServer: host: str | None = None @@ -556,19 +552,18 @@ async def prompt_tokens( ) -> list[int]: normalized_messages = _normalize_qwen3_dot_messages(base_model, messages) tokenizer = self._get_renderer(base_model).tokenizer - if _chat_template_disables_thinking(base_model): - encoding = tokenizer.apply_chat_template( - cast(Any, normalized_messages), - tools=cast(Any, tools), - add_generation_prompt=True, - enable_thinking=False, - ) - else: - encoding = tokenizer.apply_chat_template( - cast(Any, normalized_messages), - tools=cast(Any, tools), - add_generation_prompt=True, - ) + chat_template_kwargs = {} + if isinstance(tokenizer.chat_template, str): + if "enable_thinking" in tokenizer.chat_template: + chat_template_kwargs["enable_thinking"] = False + if "preserve_thinking" in tokenizer.chat_template: + chat_template_kwargs["preserve_thinking"] = True + encoding = tokenizer.apply_chat_template( + cast(Any, normalized_messages), + tools=cast(Any, tools), + add_generation_prompt=True, + **chat_template_kwargs, + ) if isinstance(encoding, BatchEncoding): return encoding.input_ids else: