diff --git a/cyberai/agents/exploit/agent.py b/cyberai/agents/exploit/agent.py index 69a47df..9be6f52 100644 --- a/cyberai/agents/exploit/agent.py +++ b/cyberai/agents/exploit/agent.py @@ -10,6 +10,10 @@ from cyberai.core.base_agent import BaseAgent, Tool from cyberai.core.prompts import EXPLOIT_PROMPT +from cyberai.core.llm_client import ( + format_assistant_tool_turn, + format_tool_results, +) from cyberai.integrations.oob_payloads import get_all_payloads from cyberai.integrations.phantom_grid import PhantomGridClient @@ -37,6 +41,16 @@ def _register_tools(self) -> None: name="analyze_vector", description="Analyze CVSS attack vector for exploitability", func=analyze_attack_vector, + input_schema={ + "type": "object", + "properties": { + "cve_id": { + "type": "string", + "description": "CVE id from the ranked list", + } + }, + "required": ["cve_id"], + }, ) ) self.register_tool( @@ -44,6 +58,16 @@ def _register_tools(self) -> None: name="build_chain", description="Build multi-step exploit chain from CVEs", func=build_exploit_chain, + input_schema={ + "type": "object", + "properties": { + "target": { + "type": "string", + "description": "Target host or IP", + } + }, + "required": ["target"], + }, ) ) @@ -69,7 +93,12 @@ def run(self, target: str, context: Optional[Dict[str, Any]] = None) -> Dict[str attack_paths.extend(build_attack_paths(cve, vector_analysis)) attack_paths.sort(key=lambda p: p.success_probability, reverse=True) - chain = self.call_tool("build_chain", cves=ranked_cves[:3], target=target) + if getattr(self.config, "use_native_tools", False) and self.llm is not None: + chain = self._native_chain_build(target, ranked_cves) or self.call_tool( + "build_chain", cves=ranked_cves[:3], target=target + ) + else: + chain = self.call_tool("build_chain", cves=ranked_cves[:3], target=target) analysis = self._ai_analysis(target, ranked_cves, attack_paths, chain) @@ -91,6 +120,69 @@ def run(self, target: str, context: Optional[Dict[str, Any]] = None) -> Dict[str 
def _native_chain_build(self, target: str, ranked_cves: List[Dict]) -> Optional[Dict[str, Any]]:
    """Flag-gated: let the LLM choose and order tool calls to build the chain.

    ExploitAgent method. Runs up to ``config.max_agent_iterations`` (default
    10) tool-enabled round-trips. Every tool call the model requests is
    executed through ``_exec_native_tool`` and the results are threaded back
    into the conversation. The loop ends when ``build_chain`` has produced a
    result or the model stops requesting tools.

    Args:
        target: Host or IP under analysis; quoted in the prompt and used for
            the final ``build_chain`` call.
        ranked_cves: Ranked CVE dicts. Only the ids of the first five are
            shown to the model.

    Returns:
        The ``build_chain`` output if the model invoked it, else ``None`` so
        the caller falls back to the deterministic path. ``None`` is also
        returned when there is no LLM or when the client reports (via
        ``ValueError``) that tool calling is unsupported for the provider.
    """
    if self.llm is None:
        return None
    provider = self.llm.config.provider
    system = (
        "You are an exploit-chain planner. Use analyze_vector(cve_id) to "
        "assess individual CVEs, then build_chain(target) to assemble the "
        "final chain. Call build_chain exactly once when ready."
    )
    cve_ids = [c.get("cve_id") for c in ranked_cves[:5]]
    messages: List[Dict[str, Any]] = [
        {
            "role": "user",
            "content": (
                f"Target: {target}. Ranked CVE ids: {json.dumps(cve_ids)}. "
                "Plan and build the exploit chain."
            ),
        }
    ]
    tools = list(self.tools.values())
    max_iter = getattr(self.config, "max_agent_iterations", 10)
    chain_result: Optional[Dict[str, Any]] = None
    for _ in range(max_iter):
        try:
            resp = self.llm.call_tools(
                messages, system=system, tools=tools, agent_name=self.AGENT_NAME
            )
        except ValueError as exc:
            # call_tools raises ValueError for providers without tool calling.
            # Honour the documented contract and let the caller fall back
            # instead of aborting the whole run.
            self._log(
                "Native tool calling unavailable; falling back",
                {"error": str(exc)},
            )
            return None
        if not resp.tool_calls:
            break
        messages.append(format_assistant_tool_turn(provider, resp))
        results = []
        for tc in resp.tool_calls:
            output = self._exec_native_tool(tc, ranked_cves, target)
            if tc.name == "build_chain":
                chain_result = output
            # default=str keeps the loop alive when a tool returns values
            # json cannot encode natively; the string is only fed to the model.
            out_str = output if isinstance(output, str) else json.dumps(output, default=str)
            results.append((tc, out_str))
        messages.extend(format_tool_results(provider, results))
        if chain_result is not None:
            break
    return chain_result


def _exec_native_tool(self, tc: Any, ranked_cves: List[Dict], target: str) -> Any:
    """Adapt model-supplied identifiers to the real tool functions.

    ExploitAgent method. The model only passes identifiers; the actual CVE
    data is resolved here from ``ranked_cves`` rather than trusting the model
    to echo full dicts.

    Args:
        tc: Tool call with ``name`` and ``arguments`` attributes.
        ranked_cves: Ranked CVE dicts used to resolve ``cve_id`` and to feed
            ``build_chain`` (first three entries).
        target: Operator-supplied host or IP.

    Returns:
        The tool's output, or ``{"error": ...}`` for an unknown CVE id or an
        unknown tool name.
    """
    # Arguments come from the model and may be missing or not an object.
    args = tc.arguments if isinstance(tc.arguments, dict) else {}
    if tc.name == "analyze_vector":
        cid = args.get("cve_id")
        cve = next((c for c in ranked_cves if c.get("cve_id") == cid), None)
        if cve is None:
            return {"error": f"unknown cve_id: {cid}"}
        return analyze_attack_vector(cve)
    if tc.name == "build_chain":
        # Always build against the operator-supplied target. A model-supplied
        # "target" is untrusted output and must not be able to widen scope.
        return build_exploit_chain(ranked_cves[:3], target)
    return {"error": f"unknown tool: {tc.name}"}
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Dict, Optional, Any

if TYPE_CHECKING:
    from .base_agent import Tool


# ── native tool calling (sync) — LLMClient methods ────────────────────


def call_tools(
    self,
    messages: List[Dict],
    system: Optional[str] = None,
    tools: Optional[List["Tool"]] = None,
    agent_name: str = "unknown",
    cacheable_system: bool = False,
) -> "LLMResponse":
    """Perform one tool-enabled round-trip with the configured provider.

    Args:
        messages: Conversation so far, already in the provider's shape.
        system: Optional system prompt.
        tools: Tools to advertise; ``None`` is treated as no tools.
        agent_name: Label passed to usage recording.
        cacheable_system: Anthropic only; wraps the system prompt with
            ``_wrap_cacheable``. Ignored for OpenAI.

    Returns:
        An ``LLMResponse`` with free text and/or requested tool calls.

    Raises:
        ValueError: The provider is neither "openai" nor "anthropic"
            (Ollama tool calling is unsupported).
    """
    tools = tools or []
    provider = self.config.provider
    if provider == "openai":
        return self._call_tools_openai(messages, system, tools, agent_name)
    if provider == "anthropic":
        return self._call_tools_anthropic(messages, system, tools, agent_name, cacheable_system)
    raise ValueError(f"Tool calling unsupported for provider: {self.config.provider}")


def _call_tools_openai(self, messages, system, tools, agent_name="unknown"):
    """Send one chat.completions request with tools and normalise the reply.

    Records token usage, then converts the first choice into an
    ``LLMResponse``. Tool-call arguments arrive as a JSON string and are
    decoded with ``_parse_tool_arguments``.
    """
    import openai

    client = openai.OpenAI(api_key=self.config.api_key)
    full_messages = []
    if system:
        full_messages.append({"role": "system", "content": system})
    full_messages.extend(messages)
    request: Dict[str, Any] = dict(
        model=self.config.model,
        messages=full_messages,
        max_tokens=self.config.max_tokens,
        temperature=self.config.temperature,
    )
    if tools:
        # Omit the key entirely when there are no tools: the API rejects an
        # empty ``tools`` array.
        request["tools"] = _tools_to_openai_format(tools)
    response = client.chat.completions.create(**request)
    self._record_usage(
        agent_name,
        getattr(response, "model", self.config.model),
        getattr(response.usage, "prompt_tokens", 0),
        getattr(response.usage, "completion_tokens", 0),
    )
    choice = response.choices[0]
    msg = choice.message
    calls = [
        ToolCall(
            id=tc.id,
            name=tc.function.name,
            arguments=_parse_tool_arguments(tc.function.arguments),
        )
        for tc in getattr(msg, "tool_calls", None) or []
    ]
    return LLMResponse(text=msg.content, tool_calls=calls, stop_reason=choice.finish_reason)


def _call_tools_anthropic(
    self, messages, system, tools, agent_name="unknown", cacheable_system=False
):
    """Send one messages.create request with tools and normalise the reply.

    Records token usage (including cache creation/read counters), then
    gathers ``text`` blocks into the response text and ``tool_use`` blocks
    into ``ToolCall`` objects.
    """
    import anthropic

    client = anthropic.Anthropic(api_key=self.config.api_key)
    kwargs: Dict[str, Any] = dict(
        model=self.config.model,
        max_tokens=self.config.max_tokens,
        messages=messages,
        tools=_tools_to_anthropic_format(tools),
    )
    if system:
        kwargs["system"] = _wrap_cacheable(system) if cacheable_system else system
    response = client.messages.create(**kwargs)
    self._record_usage(
        agent_name,
        getattr(response, "model", self.config.model),
        getattr(response.usage, "input_tokens", 0),
        getattr(response.usage, "output_tokens", 0),
        cache_creation_tokens=getattr(response.usage, "cache_creation_input_tokens", 0) or 0,
        cache_read_tokens=getattr(response.usage, "cache_read_input_tokens", 0) or 0,
    )
    text_parts = []
    calls = []
    for block in response.content:
        if block.type == "text":
            text_parts.append(block.text)
        elif block.type == "tool_use":
            calls.append(ToolCall(id=block.id, name=block.name, arguments=dict(block.input)))
    return LLMResponse(
        text="".join(text_parts) or None,
        tool_calls=calls,
        stop_reason=getattr(response, "stop_reason", None),
    )


# ── native tool calling: response types + spec converters ─────────────


@dataclass
class ToolCall:
    """A single tool invocation requested by the model."""

    id: str
    name: str
    arguments: Dict[str, Any] = field(default_factory=dict)


@dataclass
class LLMResponse:
    """Tool-enabled call result: free text and/or requested tool calls."""

    text: Optional[str] = None
    tool_calls: List[ToolCall] = field(default_factory=list)
    stop_reason: Optional[str] = None


def _parse_tool_arguments(raw: Any) -> Dict[str, Any]:
    """Decode model-supplied tool arguments into a dict.

    Returns ``{}`` for empty, malformed, or non-object JSON so callers can
    always use ``.get`` on the result. A dict input is copied through.
    """
    if isinstance(raw, dict):
        return dict(raw)
    try:
        parsed = json.loads(raw or "{}")
    except (TypeError, ValueError):  # JSONDecodeError is a ValueError
        return {}
    return parsed if isinstance(parsed, dict) else {}


def _params_to_schema(tool: "Tool") -> Dict[str, Any]:
    """Return the JSON Schema for a Tool's arguments.

    Prefers ``tool.input_schema`` when set; otherwise derives an all-string
    object schema from ``tool.params`` (arg name -> description), with every
    argument required. ``params`` cannot express non-string or nested types;
    set ``input_schema`` for those.
    """
    explicit = getattr(tool, "input_schema", None)
    if explicit:
        return explicit
    props = {
        name: {"type": "string", "description": desc}
        for name, desc in (getattr(tool, "params", None) or {}).items()
    }
    return {"type": "object", "properties": props, "required": list(props)}


def _tools_to_openai_format(tools: List["Tool"]) -> List[Dict[str, Any]]:
    """Convert Tools to the OpenAI chat.completions ``tools`` shape."""
    return [
        {
            "type": "function",
            "function": {
                "name": t.name,
                "description": t.description,
                "parameters": _params_to_schema(t),
            },
        }
        for t in tools
    ]


def _tools_to_anthropic_format(tools: List["Tool"]) -> List[Dict[str, Any]]:
    """Convert Tools to the Anthropic messages ``tools`` shape.

    Anthropic uses a flat {name, description, input_schema} per tool,
    where input_schema is the JSON Schema for the arguments object.
    """
    return [
        {
            "name": t.name,
            "description": t.description,
            "input_schema": _params_to_schema(t),
        }
        for t in tools
    ]


# ── native tool calling: provider-aware message threading ─────────────


def format_assistant_tool_turn(provider: str, response: "LLMResponse") -> Dict[str, Any]:
    """Rebuild the assistant turn that requested tool calls, for re-sending.

    For "anthropic" the turn is a list of ``text`` / ``tool_use`` content
    blocks. Any other provider gets the OpenAI shape, where each call's
    arguments are re-serialised to a JSON string.
    """
    if provider == "anthropic":
        content: List[Dict[str, Any]] = []
        if response.text:
            content.append({"type": "text", "text": response.text})
        for tc in response.tool_calls:
            content.append(
                {"type": "tool_use", "id": tc.id, "name": tc.name, "input": tc.arguments}
            )
        return {"role": "assistant", "content": content}
    turn: Dict[str, Any] = {"role": "assistant", "content": response.text}
    if response.tool_calls:
        # Only attach the key when there is something to send: an empty
        # ``tool_calls`` array is not a valid assistant message.
        turn["tool_calls"] = [
            {
                "id": tc.id,
                "type": "function",
                "function": {"name": tc.name, "arguments": json.dumps(tc.arguments)},
            }
            for tc in response.tool_calls
        ]
    return turn


def format_tool_results(provider: str, results: List[tuple]) -> List[Dict[str, Any]]:
    """Convert ``(ToolCall, output_str)`` pairs into provider-shaped messages.

    Anthropic gets a single user message holding one ``tool_result`` block
    per call; other providers get one ``role: tool`` message per call. An
    empty ``results`` list yields no messages for either provider.
    """
    if not results:
        return []
    if provider == "anthropic":
        return [
            {
                "role": "user",
                "content": [
                    {"type": "tool_result", "tool_use_id": tc.id, "content": out}
                    for tc, out in results
                ],
            }
        ]
    return [{"role": "tool", "tool_call_id": tc.id, "content": out} for tc, out in results]
schema is returned verbatim — no synthetic string props + assert set(schema["properties"]) == {"target"} + + +# ── spec converters ─────────────────────────────────────────────────── + + +def test_openai_format_shape(flat_tool): + spec = _tools_to_openai_format([flat_tool])[0] + assert spec["type"] == "function" + assert spec["function"]["name"] == "ping" + assert spec["function"]["parameters"]["properties"]["host"] + + +def test_anthropic_format_shape(flat_tool): + spec = _tools_to_anthropic_format([flat_tool])[0] + assert spec["name"] == "ping" + assert "input_schema" in spec + assert spec["input_schema"]["properties"]["host"] + + +# ── provider-aware threading ────────────────────────────────────────── + + +def test_format_assistant_turn_anthropic(): + resp = LLMResponse( + text="thinking", + tool_calls=[ToolCall(id="t1", name="ping", arguments={"host": "x"})], + ) + turn = format_assistant_tool_turn("anthropic", resp) + assert turn["role"] == "assistant" + types = [b["type"] for b in turn["content"]] + assert "tool_use" in types and "text" in types + + +def test_format_assistant_turn_openai(): + resp = LLMResponse( + text=None, + tool_calls=[ToolCall(id="t1", name="ping", arguments={"host": "x"})], + ) + turn = format_assistant_tool_turn("openai", resp) + assert turn["tool_calls"][0]["function"]["name"] == "ping" + # OpenAI arguments must be a JSON string, not a dict + assert json.loads(turn["tool_calls"][0]["function"]["arguments"]) == {"host": "x"} + + +def test_format_tool_results_anthropic(): + tc = ToolCall(id="t1", name="ping", arguments={}) + msgs = format_tool_results("anthropic", [(tc, "result-str")]) + block = msgs[0]["content"][0] + assert block["type"] == "tool_result" + assert block["tool_use_id"] == "t1" + + +def test_format_tool_results_openai(): + tc = ToolCall(id="t1", name="ping", arguments={}) + msgs = format_tool_results("openai", [(tc, "result-str")]) + assert msgs[0]["role"] == "tool" + assert msgs[0]["tool_call_id"] == "t1" + + +# ── 
ExploitAgent native loop (mocked LLM) ───────────────────────────── + + +RANKED = [ + { + "cve_id": "CVE-TEST", + "cvss": 9.8, + "cvss_vector": "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H", + } +] + + +def _make_agent(): + from cyberai.agents.exploit.agent import ExploitAgent + + agent = ExploitAgent.__new__(ExploitAgent) + agent.AGENT_NAME = "exploit" + agent.tools = {} + agent.llm = MagicMock() + agent.llm.config.provider = "anthropic" + agent.config = MagicMock() + agent.config.max_agent_iterations = 5 + agent._register_tools() + return agent + + +def test_exec_native_tool_resolves_cve_id(): + agent = _make_agent() + tc = ToolCall(id="t1", name="analyze_vector", arguments={"cve_id": "CVE-TEST"}) + out = agent._exec_native_tool(tc, RANKED, "127.0.0.1") + assert out["remotely_exploitable"] is True + + +def test_exec_native_tool_unknown_cve_id(): + agent = _make_agent() + tc = ToolCall(id="t1", name="analyze_vector", arguments={"cve_id": "NOPE"}) + out = agent._exec_native_tool(tc, RANKED, "127.0.0.1") + assert "error" in out + + +def test_exec_native_tool_build_chain(): + agent = _make_agent() + tc = ToolCall(id="t2", name="build_chain", arguments={"target": "10.0.0.1"}) + out = agent._exec_native_tool(tc, RANKED, "127.0.0.1") + assert out["chain_length"] == 1 + + +def test_native_chain_build_full_loop(): + agent = _make_agent() + # 1st round: model asks for analyze_vector; 2nd: build_chain; then stop. 
+ agent.llm.call_tools.side_effect = [ + LLMResponse( + text=None, + tool_calls=[ToolCall(id="a1", name="analyze_vector", arguments={"cve_id": "CVE-TEST"})], + ), + LLMResponse( + text=None, + tool_calls=[ToolCall(id="c1", name="build_chain", arguments={"target": "10.0.0.1"})], + ), + ] + chain = agent._native_chain_build("10.0.0.1", RANKED) + assert chain is not None + assert chain["chain_length"] == 1 + # loop stopped right after build_chain — exactly 2 LLM round-trips + assert agent.llm.call_tools.call_count == 2 + + +def test_native_chain_build_no_tool_calls_returns_none(): + agent = _make_agent() + agent.llm.call_tools.return_value = LLMResponse(text="no tools", tool_calls=[]) + assert agent._native_chain_build("10.0.0.1", RANKED) is None