Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions doris_mcp_server/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import asyncio
import logging
import re
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass
Expand All @@ -36,6 +37,22 @@
from .logger import get_logger


_SQL_COMMENT_RE = re.compile(r"/\*.*?\*/|--[^\n]*", re.DOTALL)


def get_first_sql_keyword(sql: str) -> str:
"""Return the first SQL keyword (uppercase), ignoring leading comments/whitespace.

Strips `--` line comments and `/* */` block comments before extracting
the first token. A leading comment must not change how a statement is
classified (e.g. `-- note\\nSELECT 1` is still a SELECT).
"""
if not sql:
return ""
stripped = _SQL_COMMENT_RE.sub("", sql).strip()
if not stripped:
return ""
return stripped.split(None, 1)[0].upper()


@dataclass
Expand Down Expand Up @@ -95,15 +112,9 @@ async def execute(self, sql: str, params: tuple | None = None, auth_context=None
async with self.connection.cursor(aiomysql.DictCursor) as cursor:
await cursor.execute(sql, params)

# Check if it's a query statement (statement that returns result set)
# FIX for Issue #62 Bug 5: Added WITH support for Common Table Expressions (CTE)
sql_upper = sql.strip().upper()
if (sql_upper.startswith("SELECT") or
sql_upper.startswith("SHOW") or
sql_upper.startswith("DESCRIBE") or
sql_upper.startswith("DESC") or
sql_upper.startswith("EXPLAIN") or
sql_upper.startswith("WITH")): # FIX: Support CTE queries
# cursor.description is set by the DB driver for any statement that returns rows,
# avoiding a brittle hardcoded keyword list (e.g. missing WITH/CTE, comments before keywords).
if cursor.description:
data = await cursor.fetchall()
row_count = len(data)
else:
Expand Down
9 changes: 6 additions & 3 deletions doris_mcp_server/utils/query_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

import sqlparse

from .db import DorisConnectionManager, QueryResult
from .db import DorisConnectionManager, QueryResult, get_first_sql_keyword
from .logger import get_logger
from .sql_security_utils import get_auth_context

Expand Down Expand Up @@ -685,8 +685,11 @@ async def execute_sql_for_mcp(
else:
self.logger.warning("Security configuration not found, proceeding without validation")

# Add LIMIT if not present and it's a SELECT query
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
# Add LIMIT if not present and it's a SELECT query.
# get_first_sql_keyword skips leading comments so `-- note\nSELECT ...`
# still gets the LIMIT cap (sql.startswith would silently bypass it).
sql_upper = sql.upper()
if get_first_sql_keyword(sql) == "SELECT" and "LIMIT" not in sql_upper:
if sql.endswith(";"):
sql = sql[:-1]
sql = f"{sql} LIMIT {limit}"
Expand Down
161 changes: 159 additions & 2 deletions test/utils/test_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
import pytest

from doris_mcp_server.utils.db import DorisConnection, DorisSessionCache
from doris_mcp_server.utils.db import (
DorisConnection,
DorisSessionCache,
get_first_sql_keyword,
)


@pytest.fixture
Expand Down Expand Up @@ -76,3 +80,156 @@ def test_clear_cache(self, session_cache):
connection_manager.release_connection.assert_any_call("query", mock_conn1)
connection_manager.release_connection.assert_any_call("system", mock_conn2)
assert connection_manager.release_connection.call_count == 2


class TestGetFirstSqlKeyword:
"""Unit tests for get_first_sql_keyword.

Used by query_executor.py:689 to detect SELECT before cursor.execute
(where cursor.description is not yet available), so the auto-injected
LIMIT {max_rows} cap also works when the SQL is comment-prefixed.
"""

def test_plain_select(self):
assert get_first_sql_keyword("SELECT 1") == "SELECT"

def test_leading_whitespace(self):
assert get_first_sql_keyword(" \n\t SELECT 1") == "SELECT"

def test_lowercase(self):
assert get_first_sql_keyword("select 1") == "SELECT"

def test_line_comment_then_select(self):
sql = "-- a leading note\nSELECT 1"
assert get_first_sql_keyword(sql) == "SELECT"

def test_block_comment_then_select(self):
sql = "/* note */ SELECT 1"
assert get_first_sql_keyword(sql) == "SELECT"

def test_multiline_block_comment_then_select(self):
sql = "/*\n multi\n line\n*/\nSELECT 1"
assert get_first_sql_keyword(sql) == "SELECT"

def test_mixed_whitespace_and_comments(self):
sql = " -- one\n /* two */ \n SELECT 1"
assert get_first_sql_keyword(sql) == "SELECT"

def test_comment_then_with_cte(self):
sql = "-- note\nWITH x AS (SELECT 1) SELECT * FROM x"
assert get_first_sql_keyword(sql) == "WITH"

def test_non_select_unaffected(self):
assert get_first_sql_keyword("INSERT INTO t VALUES (1)") == "INSERT"
assert get_first_sql_keyword("-- c\nINSERT INTO t VALUES (1)") == "INSERT"

def test_empty_and_only_comments(self):
assert get_first_sql_keyword("") == ""
assert get_first_sql_keyword(" ") == ""
assert get_first_sql_keyword("-- only a comment") == ""
assert get_first_sql_keyword("/* only */") == ""


def _make_doris_connection(cursor_description, fetchall_rows, rowcount=0):
"""Build a DorisConnection whose underlying cursor returns the given values.

The driver-level cursor is fully mocked: only `description`, `fetchall()`
and `rowcount` matter for the result-set-detection branch we want to test.
"""
cursor = MagicMock()
cursor.execute = AsyncMock(return_value=None)
cursor.fetchall = AsyncMock(return_value=fetchall_rows)
cursor.description = cursor_description
cursor.rowcount = rowcount

cursor_ctx = MagicMock()
cursor_ctx.__aenter__ = AsyncMock(return_value=cursor)
cursor_ctx.__aexit__ = AsyncMock(return_value=None)

raw_connection = MagicMock()
raw_connection.cursor = MagicMock(return_value=cursor_ctx)

return DorisConnection(connection=raw_connection, session_id="test")


class TestExecuteResultSetDetection:
"""Behavior contract for DorisConnection.execute().

These tests pin the user-facing contract: any statement the driver
reports as producing a result set must have its rows returned, and any
statement that does not produce a result set must report rowcount.

Guards against regression of:
- Issue #62 Bug 5 (CTE / WITH returning empty data)
- The leading-comment bug (SELECT prefixed by `--` or `/* */` returning
empty data while row_count was non-zero)
- Future "missing keyword in the whitelist" bugs of the same class

The tests deliberately do not assert anything about how the SQL text is
parsed — they only assert that when `cursor.description` is populated,
rows are fetched, regardless of the SQL phrasing.
"""

@pytest.mark.parametrize(
"sql",
[
"SELECT 1",
" SELECT 1",
"-- leading line comment\nSELECT 1",
"/* leading block comment */ SELECT 1",
"/*\n multi\n line\n*/\nSELECT 1",
" -- one\n /* two */ \n SELECT 1",
"(SELECT 1)",
"WITH t AS (SELECT 1) SELECT * FROM t",
"-- comment\nWITH t AS (SELECT 1) SELECT * FROM t",
"SHOW TABLES",
"DESC some_table",
"EXPLAIN SELECT 1",
],
ids=[
"plain_select",
"leading_whitespace",
"line_comment_then_select",
"block_comment_then_select",
"multiline_block_comment",
"mixed_whitespace_and_comments",
"parenthesized_select",
"with_cte",
"comment_then_with_cte",
"show",
"desc",
"explain",
],
)
async def test_returns_rows_when_driver_reports_result_set(self, sql):
rows = [{"col": 1}]
conn = _make_doris_connection(
cursor_description=[("col", None, None, None, None, None, None)],
fetchall_rows=rows,
)

result = await conn.execute(sql)

assert result.data == rows
assert result.row_count == len(rows)

@pytest.mark.parametrize(
"sql, affected",
[
("INSERT INTO t VALUES (1)", 1),
("UPDATE t SET x = 1", 5),
("DELETE FROM t WHERE x = 1", 3),
("CREATE TABLE t (x INT)", 0),
],
)
async def test_no_fetch_when_driver_reports_no_result_set(self, sql, affected):
conn = _make_doris_connection(
cursor_description=None,
fetchall_rows=[],
rowcount=affected,
)

result = await conn.execute(sql)

assert result.data == []
assert result.row_count == affected