Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 137 additions & 22 deletions mlx/mlx_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import argparse
import json
import re
import sys
import os
import time
Expand All @@ -26,6 +27,8 @@
"35b": "mlx-community/Qwen3.5-35B-A3B-4bit",
}

STOP_STRINGS = ["\x3c|endoftext|\x3e", "\x3c|im_end|\x3e", "\x3c|im_start|\x3e"]

# Global state
model = None
tokenizer = None
Expand All @@ -52,11 +55,22 @@ def load_model(model_key="9b"):
return model, tokenizer


def generate(messages, max_tokens=2000, temperature=0.7, stream=False):
"""Generate a response from the model."""
def clean_response(text):
"""Strip special tokens and thinking tags from response."""
for stop in STOP_STRINGS:
if stop in text:
text = text[:text.index(stop)]

if "\x3cthink\x3e" in text:
text = re.sub(r'\x3cthink\x3e.*?\x3c/think\x3e', '', text, flags=re.DOTALL)

return text.strip()


def generate(messages, max_tokens=2000, temperature=0.7):
"""Generate a complete response (non-streaming)."""
from mlx_lm import generate as mlx_generate

# Format messages into prompt
prompt = format_chat(messages)

t0 = time.time()
Expand All @@ -66,18 +80,7 @@ def generate(messages, max_tokens=2000, temperature=0.7, stream=False):
)
elapsed = time.time() - t0

# Strip special tokens
for stop in ["<|endoftext|>", "<|im_end|>", "<|im_start|>"]:
if stop in response:
response = response[:response.index(stop)]

# Strip thinking tags — extract content after </think>
import re
if "<think>" in response:
# Remove everything between <think> and </think>
response = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
response = response.strip()

response = clean_response(response)
tokens = len(tokenizer.encode(response)) if response else 0
speed = tokens / elapsed if elapsed > 0 else 0

Expand All @@ -89,21 +92,82 @@ def generate(messages, max_tokens=2000, temperature=0.7, stream=False):
}


def generate_stream(messages, max_tokens=2000, temperature=0.7):
"""Stream tokens one at a time using mlx_lm.stream_generate."""
from mlx_lm import stream_generate

prompt = format_chat(messages)

# Track state for filtering thinking tags
in_think = False
think_done = False
buffer = ""

for resp in stream_generate(
model, tokenizer, prompt=prompt,
max_tokens=max_tokens,
):
text = resp.text
token_id = resp.token
finish = resp.finish_reason

# Skip empty text
if not text:
if finish:
yield "", finish, resp
continue

# Filter out thinking content
if not think_done:
buffer += text
# Check if we've seen the end of thinking
if "\x3c/think\x3e" in buffer:
# Extract content after </think>
after = buffer.split("\x3c/think\x3e", 1)[-1]
think_done = True
buffer = ""
if after.strip():
yield after, finish, resp
continue
# Still in thinking region, don't yield
if "\x3cthink\x3e" in buffer or in_think:
in_think = True
continue
# No thinking tags seen, yield normally
think_done = True
yield buffer, finish, resp
buffer = ""
continue

# Check for stop tokens
skip = False
for stop in STOP_STRINGS:
if stop in text:
text = text[:text.index(stop)]
skip = True

if text:
yield text, finish, resp

if skip or finish:
yield "", finish or "stop", resp
return


def format_chat(messages):
"""Format chat messages into a prompt string."""
# Use Qwen chat template
parts = []
for msg in messages:
role = msg["role"]
content = msg["content"]
if role == "system":
parts.append(f"<|im_start|>system\n{content}<|im_end|>")
parts.append(f"\x3c|im_start|\x3esystem\n{content}\x3c|im_end|\x3e")
elif role == "user":
parts.append(f"<|im_start|>user\n{content}<|im_end|>")
parts.append(f"\x3c|im_start|\x3euser\n{content}\x3c|im_end|\x3e")
elif role == "assistant":
parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
parts.append(f"\x3c|im_start|\x3eassistant\n{content}\x3c|im_end|\x3e")
# Add empty thinking block to skip reasoning mode
parts.append("<|im_start|>assistant\n<think>\n\n</think>\n\n")
parts.append("\x3c|im_start|\x3eassistant\n\x3cthink\x3e\n\n\x3c/think\x3e\n\n")
return "\n".join(parts)


Expand Down Expand Up @@ -190,7 +254,7 @@ def do_GET(self):
self._send_json({"status": "ok", "model": model_name})
elif path == "/props":
self._send_json({
"model_alias": f"Qwen3.5-{model_name}-MLX",
"model_alias": f"Qwen3.5-{model_name.upper()}-MLX",
"model_path": MODELS.get(model_name, ""),
})
elif path == "/v1/context/list":
Expand Down Expand Up @@ -250,7 +314,15 @@ def _handle_chat(self):
messages = body.get("messages", [])
max_tokens = body.get("max_tokens", 2000)
temperature = body.get("temperature", 0.7)
stream = body.get("stream", False)

if stream:
self._handle_chat_stream(messages, max_tokens, temperature)
else:
self._handle_chat_normal(messages, max_tokens, temperature)

def _handle_chat_normal(self, messages, max_tokens, temperature):
"""Non-streaming: return complete response as single JSON."""
try:
result = generate(messages, max_tokens, temperature)

Expand All @@ -276,6 +348,48 @@ def _handle_chat(self):
except Exception as e:
self._send_json({"error": {"message": str(e)}}, status=500)

def _handle_chat_stream(self, messages, max_tokens, temperature):
"""Streaming: return SSE (Server-Sent Events) format for agent.py."""
try:
self.send_response(200)
self.send_header("Content-Type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self.send_header("Connection", "keep-alive")
self.send_header("Access-Control-Allow-Origin", "*")
self.end_headers()

for text, finish, resp in generate_stream(messages, max_tokens, temperature):
chunk = {
"choices": [{
"delta": {},
"finish_reason": finish,
}]
}

if text:
chunk["choices"][0]["delta"]["content"] = text

if finish:
chunk["choices"][0]["delta"] = {}
chunk["choices"][0]["finish_reason"] = finish

line = f"data: {json.dumps(chunk)}\n\n"
self.wfile.write(line.encode())
self.wfile.flush()

# Send [DONE] marker
self.wfile.write(b"data: [DONE]\n\n")
self.wfile.flush()

except Exception as e:
try:
error_chunk = f"data: {json.dumps({'error': str(e)})}\n\n"
self.wfile.write(error_chunk.encode())
self.wfile.write(b"data: [DONE]\n\n")
self.wfile.flush()
except Exception:
pass

def _send_json(self, data, status=200):
self.send_response(status)
self.send_header("Content-Type", "application/json")
Expand All @@ -298,7 +412,7 @@ def main():
parser.add_argument("--load-context", help="Load KV cache before serving")
args = parser.parse_args()

print(f"\n 🍎 mac code MLX engine")
print(f"\n \U0001f34e mac code MLX engine")
print(f" Model: {MODELS[args.model]}")
print(f" Port: {args.port}")
print()
Expand All @@ -319,6 +433,7 @@ def main():
# Start server
print(f" Server: http://localhost:{args.port}")
print(f" KV cache: persistent context enabled")
print(f" Streaming: enabled")
print()

server = HTTPServer(("127.0.0.1", args.port), APIHandler)
Expand Down