diff --git a/mlx/mlx_engine.py b/mlx/mlx_engine.py index cd67b24..5755e52 100644 --- a/mlx/mlx_engine.py +++ b/mlx/mlx_engine.py @@ -12,6 +12,7 @@ import argparse import json +import re import sys import os import time @@ -26,6 +27,8 @@ "35b": "mlx-community/Qwen3.5-35B-A3B-4bit", } +STOP_STRINGS = ["\x3c|endoftext|\x3e", "\x3c|im_end|\x3e", "\x3c|im_start|\x3e"] + # Global state model = None tokenizer = None @@ -52,11 +55,22 @@ def load_model(model_key="9b"): return model, tokenizer -def generate(messages, max_tokens=2000, temperature=0.7, stream=False): - """Generate a response from the model.""" +def clean_response(text): + """Strip special tokens and thinking tags from response.""" + for stop in STOP_STRINGS: + if stop in text: + text = text[:text.index(stop)] + + if "\x3cthink\x3e" in text: + text = re.sub(r'\x3cthink\x3e.*?\x3c/think\x3e', '', text, flags=re.DOTALL) + + return text.strip() + + +def generate(messages, max_tokens=2000, temperature=0.7): + """Generate a complete response (non-streaming).""" from mlx_lm import generate as mlx_generate - # Format messages into prompt prompt = format_chat(messages) t0 = time.time() @@ -66,18 +80,7 @@ def generate(messages, max_tokens=2000, temperature=0.7, stream=False): ) elapsed = time.time() - t0 - # Strip special tokens - for stop in ["<|endoftext|>", "<|im_end|>", "<|im_start|>"]: - if stop in response: - response = response[:response.index(stop)] - - # Strip thinking tags — extract content after - import re - if "" in response: - # Remove everything between and - response = re.sub(r'.*?', '', response, flags=re.DOTALL) - response = response.strip() - + response = clean_response(response) tokens = len(tokenizer.encode(response)) if response else 0 speed = tokens / elapsed if elapsed > 0 else 0 @@ -89,21 +92,82 @@ def generate(messages, max_tokens=2000, temperature=0.7, stream=False): } +def generate_stream(messages, max_tokens=2000, temperature=0.7): + """Stream tokens one at a time using mlx_lm.stream_generate.""" + from mlx_lm import stream_generate + + prompt = format_chat(messages) + + # Track state for filtering thinking tags + in_think = False + think_done = False + buffer = "" + + for resp in stream_generate( + model, tokenizer, prompt=prompt, + max_tokens=max_tokens, + ): + text = resp.text + token_id = resp.token + finish = resp.finish_reason + + # Skip empty text + if not text: + if finish: + yield "", finish, resp + continue + + # Filter out thinking content + if not think_done: + buffer += text + # Check if we've seen the end of thinking + if "\x3c/think\x3e" in buffer: + # Extract content after + after = buffer.split("\x3c/think\x3e", 1)[-1] + think_done = True + buffer = "" + if after.strip(): + yield after, finish, resp + continue + # Still in thinking region, don't yield + if "\x3cthink\x3e" in buffer or in_think: + in_think = True + continue + # No thinking tags seen, yield normally + think_done = True + yield buffer, finish, resp + buffer = "" + continue + + # Check for stop tokens + skip = False + for stop in STOP_STRINGS: + if stop in text: + text = text[:text.index(stop)] + skip = True + + if text: + yield text, finish, resp + + if skip or finish: + yield "", finish or "stop", resp + return + + def format_chat(messages): """Format chat messages into a prompt string.""" - # Use Qwen chat template parts = [] for msg in messages: role = msg["role"] content = msg["content"] if role == "system": - parts.append(f"<|im_start|>system\n{content}<|im_end|>") + parts.append(f"\x3c|im_start|\x3esystem\n{content}\x3c|im_end|\x3e") elif role == "user": - parts.append(f"<|im_start|>user\n{content}<|im_end|>") + parts.append(f"\x3c|im_start|\x3euser\n{content}\x3c|im_end|\x3e") elif role == "assistant": - parts.append(f"<|im_start|>assistant\n{content}<|im_end|>") + parts.append(f"\x3c|im_start|\x3eassistant\n{content}\x3c|im_end|\x3e") # Add empty thinking block to skip reasoning mode - parts.append("<|im_start|>assistant\n\n\n\n\n") + parts.append("\x3c|im_start|\x3eassistant\n\x3cthink\x3e\n\n\x3c/think\x3e\n\n") return "\n".join(parts) @@ -190,7 +254,7 @@ def do_GET(self): self._send_json({"status": "ok", "model": model_name}) elif path == "/props": self._send_json({ - "model_alias": f"Qwen3.5-{model_name}-MLX", + "model_alias": f"Qwen3.5-{model_name.upper()}-MLX", "model_path": MODELS.get(model_name, ""), }) elif path == "/v1/context/list": @@ -250,7 +314,15 @@ def _handle_chat(self): messages = body.get("messages", []) max_tokens = body.get("max_tokens", 2000) temperature = body.get("temperature", 0.7) + stream = body.get("stream", False) + + if stream: + self._handle_chat_stream(messages, max_tokens, temperature) + else: + self._handle_chat_normal(messages, max_tokens, temperature) + def _handle_chat_normal(self, messages, max_tokens, temperature): + """Non-streaming: return complete response as single JSON.""" try: result = generate(messages, max_tokens, temperature) @@ -276,6 +348,48 @@ def _handle_chat(self): except Exception as e: self._send_json({"error": {"message": str(e)}}, status=500) + def _handle_chat_stream(self, messages, max_tokens, temperature): + """Streaming: return SSE (Server-Sent Events) format for agent.py.""" + try: + self.send_response(200) + self.send_header("Content-Type", "text/event-stream") + self.send_header("Cache-Control", "no-cache") + self.send_header("Connection", "keep-alive") + self.send_header("Access-Control-Allow-Origin", "*") + self.end_headers() + + for text, finish, resp in generate_stream(messages, max_tokens, temperature): + chunk = { + "choices": [{ + "delta": {}, + "finish_reason": finish, + }] + } + + if text: + chunk["choices"][0]["delta"]["content"] = text + + if finish: + chunk["choices"][0]["delta"] = {} + chunk["choices"][0]["finish_reason"] = finish + + line = f"data: {json.dumps(chunk)}\n\n" + self.wfile.write(line.encode()) + self.wfile.flush() + + # Send [DONE] marker + self.wfile.write(b"data: [DONE]\n\n") + self.wfile.flush() + + except Exception as e: + try: + error_chunk = f"data: {json.dumps({'error': str(e)})}\n\n" + self.wfile.write(error_chunk.encode()) + self.wfile.write(b"data: [DONE]\n\n") + self.wfile.flush() + except Exception: + pass + def _send_json(self, data, status=200): self.send_response(status) self.send_header("Content-Type", "application/json") @@ -298,7 +412,7 @@ def main(): parser.add_argument("--load-context", help="Load KV cache before serving") args = parser.parse_args() - print(f"\n 🍎 mac code MLX engine") + print(f"\n \U0001f34e mac code MLX engine") print(f" Model: {MODELS[args.model]}") print(f" Port: {args.port}") print() @@ -319,6 +433,7 @@ def main(): # Start server print(f" Server: http://localhost:{args.port}") print(f" KV cache: persistent context enabled") + print(f" Streaming: enabled") print() server = HTTPServer(("127.0.0.1", args.port), APIHandler)