diff --git a/mlx/mlx_engine.py b/mlx/mlx_engine.py
index cd67b24..5755e52 100644
--- a/mlx/mlx_engine.py
+++ b/mlx/mlx_engine.py
@@ -12,6 +12,7 @@
import argparse
import json
+import re
import sys
import os
import time
@@ -26,6 +27,8 @@
"35b": "mlx-community/Qwen3.5-35B-A3B-4bit",
}
+# Special chat-template marker strings. clean_response() and generate_stream()
+# cut model output at the first occurrence of any of them. The \x3c / \x3e
+# escapes are the characters "<" and ">", so each entry equals the literal token.
+STOP_STRINGS = ["\x3c|endoftext|\x3e", "\x3c|im_end|\x3e", "\x3c|im_start|\x3e"]
+
# Global state
model = None
tokenizer = None
@@ -52,11 +55,22 @@ def load_model(model_key="9b"):
return model, tokenizer
-def generate(messages, max_tokens=2000, temperature=0.7, stream=False):
- """Generate a response from the model."""
+def clean_response(text):
+    """Return the user-visible part of a raw model completion.
+
+    The text is cut at the earliest occurrence of any entry in STOP_STRINGS,
+    thinking blocks are removed, and surrounding whitespace is stripped.
+    A thinking block that was opened but never closed (for example because
+    generation stopped inside it) is removed through the end of the text
+    instead of being returned to the caller.
+    """
+    # partition() returns the whole string as its head when the marker is
+    # absent, so no separate membership test is needed.
+    for stop in STOP_STRINGS:
+        text = text.partition(stop)[0]
+
+    # The closing tag is optional at end-of-text so that an unterminated
+    # block does not leak raw reasoning into the response.
+    text = re.sub(
+        r'\x3cthink\x3e.*?(?:\x3c/think\x3e|\Z)', '', text, flags=re.DOTALL
+    )
+
+    return text.strip()
+
+
+def generate(messages, max_tokens=2000, temperature=0.7):
+ """Generate a complete response (non-streaming)."""
from mlx_lm import generate as mlx_generate
- # Format messages into prompt
prompt = format_chat(messages)
t0 = time.time()
@@ -66,18 +80,7 @@ def generate(messages, max_tokens=2000, temperature=0.7, stream=False):
)
elapsed = time.time() - t0
- # Strip special tokens
- for stop in ["<|endoftext|>", "<|im_end|>", "<|im_start|>"]:
- if stop in response:
- response = response[:response.index(stop)]
-
- # Strip thinking tags — extract content after
- import re
- if "" in response:
- # Remove everything between and
- response = re.sub(r'.*?', '', response, flags=re.DOTALL)
- response = response.strip()
-
+ response = clean_response(response)
tokens = len(tokenizer.encode(response)) if response else 0
speed = tokens / elapsed if elapsed > 0 else 0
@@ -89,21 +92,82 @@ def generate(messages, max_tokens=2000, temperature=0.7, stream=False):
}
+def generate_stream(messages, max_tokens=2000, temperature=0.7):
+    """Stream the visible response text using mlx_lm.stream_generate.
+
+    Yields (text, finish_reason, resp) tuples, where resp is the object
+    produced by stream_generate for that step:
+
+    * content chunks carry non-empty text and finish_reason None;
+    * exactly one terminal chunk carries empty text and a finish reason:
+      the one reported by stream_generate, or "stop" when a STOP_STRINGS
+      entry was found in the output. The generator ends after it.
+
+    Thinking content is suppressed: once an opening think tag has been seen,
+    nothing is yielded until the closing tag arrives, and only the text
+    after the closing tag is passed on.
+    """
+    from mlx_lm import stream_generate
+
+    prompt = format_chat(messages)
+    think_open = "\x3cthink\x3e"
+    think_close = "\x3c/think\x3e"
+
+    # NOTE(review): temperature is accepted but not forwarded to
+    # stream_generate here. Confirm the sampler API of the installed mlx_lm
+    # before wiring it in.
+
+    # State for filtering the leading thinking region.
+    in_think = False
+    think_done = False
+    buffer = ""
+
+    for resp in stream_generate(
+        model, tokenizer, prompt=prompt,
+        max_tokens=max_tokens,
+    ):
+        text = resp.text or ""
+        finish = resp.finish_reason
+
+        if text and not think_done:
+            buffer += text
+            if think_close in buffer:
+                # Thinking is over: keep only what follows the closing tag,
+                # and drop it if it is whitespace only.
+                text = buffer.split(think_close, 1)[-1]
+                if not text.strip():
+                    text = ""
+                think_done = True
+                buffer = ""
+            elif in_think or think_open in buffer:
+                # Still inside the thinking region: emit nothing, but fall
+                # through so a finish reason on this step is not lost.
+                in_think = True
+                text = ""
+            else:
+                # No thinking tag at the start of the output: pass through.
+                text = buffer
+                think_done = True
+                buffer = ""
+
+        # Cut at stop strings on every path, including the first visible
+        # text released by the thinking filter above.
+        hit_stop = False
+        for stop in STOP_STRINGS:
+            if stop in text:
+                text = text[:text.index(stop)]
+                hit_stop = True
+
+        if text:
+            # Content never shares a chunk with the finish reason, so a
+            # consumer that resets the delta on finish cannot drop text.
+            yield text, None, resp
+
+        if hit_stop or finish:
+            yield "", finish or "stop", resp
+            return
+
+
def format_chat(messages):
"""Format chat messages into a prompt string."""
- # Use Qwen chat template
parts = []
+ # Each message is read via msg["role"] and msg["content"]. Roles other than
+ # system/user/assistant match no branch below and are left out of the prompt.
for msg in messages:
role = msg["role"]
content = msg["content"]
if role == "system":
- parts.append(f"<|im_start|>system\n{content}<|im_end|>")
+ parts.append(f"\x3c|im_start|\x3esystem\n{content}\x3c|im_end|\x3e")
elif role == "user":
- parts.append(f"<|im_start|>user\n{content}<|im_end|>")
+ parts.append(f"\x3c|im_start|\x3euser\n{content}\x3c|im_end|\x3e")
elif role == "assistant":
- parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
+ parts.append(f"\x3c|im_start|\x3eassistant\n{content}\x3c|im_end|\x3e")
# Add empty thinking block to skip reasoning mode
- parts.append("<|im_start|>assistant\n\n\n\n\n")
+ parts.append("\x3c|im_start|\x3eassistant\n\x3cthink\x3e\n\n\x3c/think\x3e\n\n")
+ # The final assistant turn has no closing marker, so the joined prompt ends
+ # inside the assistant reply. Turns are separated by a single newline.
return "\n".join(parts)
@@ -190,7 +254,7 @@ def do_GET(self):
self._send_json({"status": "ok", "model": model_name})
elif path == "/props":
self._send_json({
- "model_alias": f"Qwen3.5-{model_name}-MLX",
+ "model_alias": f"Qwen3.5-{model_name.upper()}-MLX",
"model_path": MODELS.get(model_name, ""),
})
elif path == "/v1/context/list":
@@ -250,7 +314,15 @@ def _handle_chat(self):
messages = body.get("messages", [])
max_tokens = body.get("max_tokens", 2000)
temperature = body.get("temperature", 0.7)
+ stream = body.get("stream", False)
+
+ if stream:
+ self._handle_chat_stream(messages, max_tokens, temperature)
+ else:
+ self._handle_chat_normal(messages, max_tokens, temperature)
+ def _handle_chat_normal(self, messages, max_tokens, temperature):
+ """Non-streaming: return complete response as single JSON."""
try:
result = generate(messages, max_tokens, temperature)
@@ -276,6 +348,48 @@ def _handle_chat(self):
except Exception as e:
self._send_json({"error": {"message": str(e)}}, status=500)
+    def _handle_chat_stream(self, messages, max_tokens, temperature):
+        """Stream a chat completion to the client as Server-Sent Events.
+
+        Sends a 200 response with SSE headers, then one "data:" event per
+        tuple yielded by generate_stream(), followed by a "data: [DONE]"
+        marker. Each event is a JSON object with a single choice holding a
+        "delta" (with "content" when the chunk has text) and a
+        "finish_reason".
+
+        If an exception is raised while streaming, it is reported in-band as
+        an error event followed by [DONE], on a best-effort basis.
+        """
+        try:
+            self.send_response(200)
+            self.send_header("Content-Type", "text/event-stream")
+            self.send_header("Cache-Control", "no-cache")
+            self.send_header("Connection", "keep-alive")
+            self.send_header("Access-Control-Allow-Origin", "*")
+            self.end_headers()
+
+            for text, finish, _resp in generate_stream(messages, max_tokens, temperature):
+                # Keep any text that arrives together with a finish reason;
+                # clearing the delta on finish would silently drop it.
+                delta = {"content": text} if text else {}
+                chunk = {
+                    "choices": [{
+                        "delta": delta,
+                        "finish_reason": finish,
+                    }]
+                }
+
+                line = f"data: {json.dumps(chunk)}\n\n"
+                self.wfile.write(line.encode())
+                # Flush per event so the client sees tokens as they arrive.
+                self.wfile.flush()
+
+            # Send [DONE] marker
+            self.wfile.write(b"data: [DONE]\n\n")
+            self.wfile.flush()
+
+        except Exception as e:
+            # The response has normally been started by now, so the failure
+            # is reported as an SSE event rather than an HTTP error status.
+            try:
+                error_chunk = f"data: {json.dumps({'error': str(e)})}\n\n"
+                self.wfile.write(error_chunk.encode())
+                self.wfile.write(b"data: [DONE]\n\n")
+                self.wfile.flush()
+            except (OSError, ValueError):
+                # Writing failed as well (most likely the client went away);
+                # there is nobody left to notify.
+                pass
+
def _send_json(self, data, status=200):
self.send_response(status)
self.send_header("Content-Type", "application/json")
@@ -298,7 +412,7 @@ def main():
parser.add_argument("--load-context", help="Load KV cache before serving")
args = parser.parse_args()
- print(f"\n 🍎 mac code MLX engine")
+ print(f"\n \U0001f34e mac code MLX engine")
print(f" Model: {MODELS[args.model]}")
print(f" Port: {args.port}")
print()
@@ -319,6 +433,7 @@ def main():
# Start server
print(f" Server: http://localhost:{args.port}")
print(f" KV cache: persistent context enabled")
+ print(f" Streaming: enabled")
print()
server = HTTPServer(("127.0.0.1", args.port), APIHandler)