Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion cyberai/agents/exploit/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

from cyberai.core.base_agent import BaseAgent, Tool
from cyberai.core.prompts import EXPLOIT_PROMPT
from cyberai.core.llm_client import (
format_assistant_tool_turn,
format_tool_results,
)
from cyberai.integrations.oob_payloads import get_all_payloads
from cyberai.integrations.phantom_grid import PhantomGridClient

Expand Down Expand Up @@ -37,13 +41,33 @@ def _register_tools(self) -> None:
name="analyze_vector",
description="Analyze CVSS attack vector for exploitability",
func=analyze_attack_vector,
input_schema={
"type": "object",
"properties": {
"cve_id": {
"type": "string",
"description": "CVE id from the ranked list",
}
},
"required": ["cve_id"],
},
)
)
self.register_tool(
Tool(
name="build_chain",
description="Build multi-step exploit chain from CVEs",
func=build_exploit_chain,
input_schema={
"type": "object",
"properties": {
"target": {
"type": "string",
"description": "Target host or IP",
}
},
"required": ["target"],
},
)
)

Expand All @@ -69,7 +93,12 @@ def run(self, target: str, context: Optional[Dict[str, Any]] = None) -> Dict[str
attack_paths.extend(build_attack_paths(cve, vector_analysis))

attack_paths.sort(key=lambda p: p.success_probability, reverse=True)
chain = self.call_tool("build_chain", cves=ranked_cves[:3], target=target)
if getattr(self.config, "use_native_tools", False) and self.llm is not None:
chain = self._native_chain_build(target, ranked_cves) or self.call_tool(
"build_chain", cves=ranked_cves[:3], target=target
)
else:
chain = self.call_tool("build_chain", cves=ranked_cves[:3], target=target)

analysis = self._ai_analysis(target, ranked_cves, attack_paths, chain)

Expand All @@ -91,6 +120,69 @@ def run(self, target: str, context: Optional[Dict[str, Any]] = None) -> Dict[str
self._log("Exploit analysis complete", {"paths_found": len(attack_paths)})
return result

def _native_chain_build(self, target: str, ranked_cves: List[Dict]) -> Optional[Dict[str, Any]]:
    """Flag-gated: let the LLM choose/order tool calls to build the chain.

    The model is shown the target and the ids of the first five ranked CVEs,
    then repeatedly asked for tool calls until it invokes ``build_chain``,
    stops requesting tools, or the iteration cap is reached.

    Args:
        target: Host or IP the chain is built for.
        ranked_cves: Ranked CVE dicts; each is looked up by its ``cve_id`` key.

    Returns:
        The build_chain output if the model invoked it, else None so the
        caller falls back to the deterministic path. Also None when no LLM
        is configured or the LLM client rejects the tool-calling request
        with ValueError (e.g. a provider without tool-calling support).
    """
    if self.llm is None:
        return None
    provider = self.llm.config.provider
    system = (
        "You are an exploit-chain planner. Use analyze_vector(cve_id) to "
        "assess individual CVEs, then build_chain(target) to assemble the "
        "final chain. Call build_chain exactly once when ready."
    )
    cve_ids = [c.get("cve_id") for c in ranked_cves[:5]]
    messages: List[Dict[str, Any]] = [
        {
            "role": "user",
            "content": (
                f"Target: {target}. Ranked CVE ids: {json.dumps(cve_ids)}. "
                "Plan and build the exploit chain."
            ),
        }
    ]
    tools = list(self.tools.values())
    max_iter = getattr(self.config, "max_agent_iterations", 10)
    chain_result: Optional[Dict[str, Any]] = None
    for _ in range(max_iter):
        try:
            resp = self.llm.call_tools(
                messages, system=system, tools=tools, agent_name=self.AGENT_NAME
            )
        except ValueError as exc:
            # The client signals "provider cannot do tool calling" with
            # ValueError; honour the fallback contract instead of crashing.
            self._log(
                "Native tool calling unavailable; using deterministic chain",
                {"provider": provider, "error": str(exc)},
            )
            return None
        if not resp.tool_calls:
            break
        # Thread the assistant's request back so the provider can pair each
        # tool result with the call that produced it.
        messages.append(format_assistant_tool_turn(provider, resp))
        results = []
        for tc in resp.tool_calls:
            output = self._exec_native_tool(tc, ranked_cves, target)
            if tc.name == "build_chain":
                chain_result = output
            # default=str: this text is only shown to the model, so a value
            # that is not JSON-native must not abort the whole run.
            out_str = output if isinstance(output, str) else json.dumps(output, default=str)
            results.append((tc, out_str))
        messages.extend(format_tool_results(provider, results))
        if chain_result is not None:
            break
    return chain_result

def _exec_native_tool(self, tc: Any, ranked_cves: List[Dict], target: str) -> Any:
    """Run one model-requested tool call against the real implementations.

    The model only supplies identifiers (``cve_id`` / ``target``); the CVE
    record itself is looked up in ``ranked_cves`` here instead of being
    taken from the model's arguments.

    Args:
        tc: Tool call exposing ``name`` and an ``arguments`` mapping.
        ranked_cves: Ranked CVE dicts searched by their ``cve_id`` key.
        target: Target used when the call carries no non-empty ``target``.

    Returns:
        The tool's output, or an ``{"error": ...}`` dict for an unknown
        tool name or an unknown CVE id.
    """
    tool_name = tc.name
    if tool_name == "build_chain":
        # A non-empty model-supplied target takes precedence over ours.
        chosen_target = tc.arguments.get("target") or target
        return build_exploit_chain(ranked_cves[:3], chosen_target)
    if tool_name != "analyze_vector":
        return {"error": f"unknown tool: {tool_name}"}
    wanted = tc.arguments.get("cve_id")
    for candidate in ranked_cves:
        if candidate.get("cve_id") == wanted:
            return analyze_attack_vector(candidate)
    return {"error": f"unknown cve_id: {wanted}"}

def _ai_analysis(
self,
target: str,
Expand Down
3 changes: 3 additions & 0 deletions cyberai/core/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class Tool:
func: Callable
params: Dict[str, str] = field(default_factory=dict)
parameters: Optional[Dict[str, str]] = None
# Explicit JSON Schema for native LLM tool calling. params expresses only
# flat string args; set this for typed/nested/list args (KI: build_chain).
input_schema: Optional[Dict[str, Any]] = None

def __post_init__(self) -> None:
# KI-6: agents pass parameters=...; mirror it into params.
Expand Down
207 changes: 206 additions & 1 deletion cyberai/core/llm_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from typing import List, Dict, Optional, Any
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Dict, Optional, Any

from .config import LLMConfig
from .cost_tracker import CostTracker, BudgetExceeded
import httpx

if TYPE_CHECKING:
from .base_agent import Tool


class LLMClient:
"""
Expand Down Expand Up @@ -127,6 +133,93 @@ def _call_ollama(self, messages: List[Dict], system: Optional[str]) -> str:
response.raise_for_status()
return response.json()["message"]["content"]

# ── native tool calling (sync) ────────────────────────────────────

def call_tools(
    self,
    messages: List[Dict],
    system: Optional[str] = None,
    tools: Optional[List["Tool"]] = None,
    agent_name: str = "unknown",
    cacheable_system: bool = False,
) -> "LLMResponse":
    """Perform a single tool-enabled round-trip with the configured provider.

    Args:
        messages: Conversation so far, already in the provider's shape.
        system: Optional system prompt.
        tools: Tools offered to the model; None is treated as no tools.
        agent_name: Label passed through to the provider-specific call.
        cacheable_system: Forwarded to the Anthropic path only.

    Returns:
        The provider call's result.

    Raises:
        ValueError: For any provider other than "openai" or "anthropic"
            (Ollama tool calling is unsupported).
    """
    tool_list = tools or []
    provider = self.config.provider
    if provider == "anthropic":
        return self._call_tools_anthropic(
            messages, system, tool_list, agent_name, cacheable_system
        )
    if provider == "openai":
        return self._call_tools_openai(messages, system, tool_list, agent_name)
    raise ValueError(f"Tool calling unsupported for provider: {self.config.provider}")

def _call_tools_openai(self, messages, system, tools, agent_name="unknown"):
    """One tool-enabled chat.completions round-trip against OpenAI.

    Args:
        messages: Conversation so far in OpenAI chat shape.
        system: Optional system prompt, sent as a leading system message.
        tools: Tools to offer the model; may be empty.
        agent_name: Label passed to usage recording.

    Returns:
        LLMResponse with the message text, the decoded tool calls and the
        finish reason of the first choice.
    """
    import openai

    client = openai.OpenAI(api_key=self.config.api_key)
    full_messages = []
    if system:
        full_messages.append({"role": "system", "content": system})
    full_messages.extend(messages)
    request: Dict[str, Any] = dict(
        model=self.config.model,
        messages=full_messages,
        max_tokens=self.config.max_tokens,
        temperature=self.config.temperature,
    )
    # The API rejects an empty `tools` array, so only send the key when
    # there is at least one tool to offer.
    if tools:
        request["tools"] = _tools_to_openai_format(tools)
    response = client.chat.completions.create(**request)
    self._record_usage(
        agent_name,
        getattr(response, "model", self.config.model),
        getattr(response.usage, "prompt_tokens", 0),
        getattr(response.usage, "completion_tokens", 0),
    )
    choice = response.choices[0]
    msg = choice.message
    calls = []
    for tc in getattr(msg, "tool_calls", None) or []:
        # Arguments arrive as a model-written JSON string; treat anything
        # that is not a JSON object as "no arguments" so downstream
        # `.get(...)` lookups on ToolCall.arguments stay safe.
        try:
            args = json.loads(tc.function.arguments or "{}")
        except json.JSONDecodeError:
            args = {}
        if not isinstance(args, dict):
            args = {}
        calls.append(ToolCall(id=tc.id, name=tc.function.name, arguments=args))
    return LLMResponse(text=msg.content, tool_calls=calls, stop_reason=choice.finish_reason)

def _call_tools_anthropic(
    self, messages, system, tools, agent_name="unknown", cacheable_system=False
):
    """One tool-enabled messages round-trip against Anthropic.

    Args:
        messages: Conversation so far in Anthropic messages shape.
        system: Optional system prompt.
        tools: Tools to offer the model.
        agent_name: Label passed to usage recording.
        cacheable_system: When true, the system prompt is wrapped via
            ``_wrap_cacheable`` before sending.

    Returns:
        LLMResponse with the concatenated text blocks (None when empty),
        the requested tool calls and the response's stop reason.
    """
    import anthropic

    client = anthropic.Anthropic(api_key=self.config.api_key)
    request: Dict[str, Any] = {
        "model": self.config.model,
        "max_tokens": self.config.max_tokens,
        "messages": messages,
        "tools": _tools_to_anthropic_format(tools),
    }
    if system:
        if cacheable_system:
            request["system"] = _wrap_cacheable(system)
        else:
            request["system"] = system
    response = client.messages.create(**request)

    usage = response.usage
    self._record_usage(
        agent_name,
        getattr(response, "model", self.config.model),
        getattr(usage, "input_tokens", 0),
        getattr(usage, "output_tokens", 0),
        # Cache counters may be absent or None; normalise both to 0.
        cache_creation_tokens=getattr(usage, "cache_creation_input_tokens", 0) or 0,
        cache_read_tokens=getattr(usage, "cache_read_input_tokens", 0) or 0,
    )

    text_chunks = []
    requested = []
    for item in response.content:
        kind = item.type
        if kind == "tool_use":
            requested.append(ToolCall(id=item.id, name=item.name, arguments=dict(item.input)))
        elif kind == "text":
            text_chunks.append(item.text)
    joined = "".join(text_chunks)
    return LLMResponse(
        text=joined or None,
        tool_calls=requested,
        stop_reason=getattr(response, "stop_reason", None),
    )

# ── async API ─────────────────────────────────────────────────────

async def acall(
Expand Down Expand Up @@ -226,3 +319,115 @@ def _wrap_cacheable(system_text: str) -> list[dict]:
"cache_control": {"type": "ephemeral"},
}
]


# ── native tool calling: response types + spec converters ─────────────


@dataclass
class ToolCall:
    """One tool invocation the model asked for."""

    # Provider-issued call id; used to pair the tool's result with this call.
    id: str
    # Name of the tool the model wants to run.
    name: str
    # Argument object supplied by the model; a fresh empty dict when omitted.
    arguments: Dict[str, Any] = field(default_factory=dict)


@dataclass
class LLMResponse:
    """Result of a tool-enabled call: free text, requested tool calls, or both."""

    # Free-text part of the reply, or None when there is none.
    text: Optional[str] = None
    # Tool invocations requested by the model; a fresh empty list by default.
    tool_calls: List[ToolCall] = field(default_factory=list)
    # Provider-reported reason the generation stopped, when available.
    stop_reason: Optional[str] = None


def _params_to_schema(tool: "Tool") -> Dict[str, Any]:
    """Return the JSON Schema describing a Tool's arguments.

    A truthy ``tool.input_schema`` is returned as-is. Otherwise the schema is
    derived from ``tool.params`` (arg name -> description): every argument
    becomes a required string property. ``params`` cannot describe
    non-string or nested arguments; set ``input_schema`` for those.
    """
    schema = getattr(tool, "input_schema", None)
    if schema:
        return schema
    flat_params = getattr(tool, "params", None) or {}
    properties: Dict[str, Any] = {}
    for arg_name, arg_desc in flat_params.items():
        properties[arg_name] = {"type": "string", "description": arg_desc}
    return {
        "type": "object",
        "properties": properties,
        "required": list(properties),
    }


def _tools_to_openai_format(tools: List["Tool"]) -> List[Dict[str, Any]]:
    """Build the OpenAI chat.completions ``tools`` list from Tools.

    Each Tool becomes a ``{"type": "function", "function": {...}}`` entry
    carrying its name, description and argument schema.
    """
    specs: List[Dict[str, Any]] = []
    for tool in tools:
        function_spec = {
            "name": tool.name,
            "description": tool.description,
            "parameters": _params_to_schema(tool),
        }
        specs.append({"type": "function", "function": function_spec})
    return specs


def _tools_to_anthropic_format(tools: List["Tool"]) -> List[Dict[str, Any]]:
    """Build the Anthropic messages ``tools`` list from Tools.

    Anthropic expects one flat ``{name, description, input_schema}`` dict
    per tool, where ``input_schema`` is the JSON Schema of the arguments
    object.
    """
    specs: List[Dict[str, Any]] = []
    for tool in tools:
        specs.append(
            {
                "name": tool.name,
                "description": tool.description,
                "input_schema": _params_to_schema(tool),
            }
        )
    return specs


# ── native tool calling: provider-aware message threading ─────────────


def format_assistant_tool_turn(provider: str, response: "LLMResponse") -> Dict[str, Any]:
    """Recreate the assistant message that requested tool calls, for re-sending.

    For "anthropic" the content is a list of blocks: an optional text block
    followed by one ``tool_use`` block per call. Any other provider gets the
    OpenAI shape: the text as ``content`` plus a ``tool_calls`` list whose
    arguments are JSON-encoded strings.
    """
    if provider != "anthropic":
        function_calls = []
        for call in response.tool_calls:
            function_calls.append(
                {
                    "id": call.id,
                    "type": "function",
                    "function": {"name": call.name, "arguments": json.dumps(call.arguments)},
                }
            )
        return {
            "role": "assistant",
            "content": response.text,
            "tool_calls": function_calls,
        }

    # Anthropic: text (if any) comes first, then the tool_use blocks.
    text_blocks = [{"type": "text", "text": response.text}] if response.text else []
    tool_blocks = [
        {"type": "tool_use", "id": call.id, "name": call.name, "input": call.arguments}
        for call in response.tool_calls
    ]
    return {"role": "assistant", "content": text_blocks + tool_blocks}


def format_tool_results(provider: str, results: List[tuple]) -> List[Dict[str, Any]]:
    """Turn (ToolCall, output_str) pairs into provider-shaped result messages.

    "anthropic" gets a single user message holding one ``tool_result`` block
    per pair; any other provider gets one ``role: tool`` message per pair.
    """
    if provider != "anthropic":
        return [
            {"role": "tool", "tool_call_id": call.id, "content": text}
            for call, text in results
        ]
    result_blocks = []
    for call, text in results:
        result_blocks.append({"type": "tool_result", "tool_use_id": call.id, "content": text})
    return [{"role": "user", "content": result_blocks}]
Loading
Loading