From 88f4a6bcf4d39e8dd9e128a08e9b91c16390394e Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 15 May 2026 12:13:48 +0800 Subject: [PATCH 01/19] Increase CPU verify_timeout default from 600s to 1200s. - CPU forward verification often takes 1000s+ for large models - Update --verify-timeout help text accordingly Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/parallel_extract.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/graph_net/agent/parallel_extract.py b/graph_net/agent/parallel_extract.py index 834e68cde..e29004272 100644 --- a/graph_net/agent/parallel_extract.py +++ b/graph_net/agent/parallel_extract.py @@ -362,7 +362,7 @@ def _parse_args() -> argparse.Namespace: "--verify-timeout", type=int, default=None, - help="Timeout in seconds for forward verification (default: 300 on GPU, 600 on CPU)", + help="Timeout in seconds for forward verification (default: 300 on GPU, 1200 on CPU)", ) parser.add_argument( "--use-llm", @@ -405,7 +405,9 @@ def _resolve_config(args: argparse.Namespace): extract_timeout = ( args.extract_timeout if args.extract_timeout is not None else 2000 ) - verify_timeout = args.verify_timeout if args.verify_timeout is not None else 600 + verify_timeout = ( + args.verify_timeout if args.verify_timeout is not None else 1200 + ) return workspace, gpus, num_workers, extract_timeout, verify_timeout From fb11dd84447b856a791999e54fac67fb5464f311 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 15 May 2026 12:15:53 +0800 Subject: [PATCH 02/19] Add llm_timeout parameter to GraphNetAgent with 600s default. - LLMCodeFixer: support Optional[int] timeout, default 360s when None - GraphNetAgent: add llm_timeout parameter (default: 600s) - Remove download_timeout from previous iteration Co-Authored-By: Claude Opus 4.6 --- .../agent/code_generator/llm_code_fixer.py | 4 ++-- graph_net/agent/graph_net_agent.py | 24 +++++++++++-------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/graph_net/agent/code_generator/llm_code_fixer.py b/graph_net/agent/code_generator/llm_code_fixer.py index 2a56a6f9a..7ee308c6c 100644 --- a/graph_net/agent/code_generator/llm_code_fixer.py +++ b/graph_net/agent/code_generator/llm_code_fixer.py @@ -117,11 +117,11 @@ def __init__( ): """ Args: - timeout: Max seconds to wait for ducc response. + timeout: Max seconds to wait for ducc response (default 360s). model: Override the LLM model (e.g. 'sonnet', 'haiku'). If None, uses whatever ducc default is configured. """ - self.timeout = timeout + self.timeout = timeout if timeout is not None else 360 self.model = model self.logger = logging.getLogger(self.__class__.__name__) self._ducc_bin = _find_ducc() diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py index 4339bc65d..83e9a89f7 100644 --- a/graph_net/agent/graph_net_agent.py +++ b/graph_net/agent/graph_net_agent.py @@ -43,20 +43,22 @@ def __init__( llm_retry: bool = True, extract_timeout: Optional[int] = None, verify_timeout: Optional[int] = None, + llm_timeout: int = 600, ): """ Initialize GraphNet Agent Args: - workspace: Workspace root directory. Defaults to - $GRAPH_NET_EXTRACT_WORKSPACE or ~/graphnet_workspace. - hf_token: HuggingFace API token (optional) - llm_retry: If True and ducc/claude CLI is available, retry failed - extractions up to 2 times with LLM-fixed scripts. - extract_timeout: Timeout in seconds for graph extraction subprocess - (default None -> 1000s). - verify_timeout: Timeout in seconds for forward verification subprocess - (default None -> 300s). + workspace: Workspace root directory. Defaults to + $GRAPH_NET_EXTRACT_WORKSPACE or ~/graphnet_workspace. + hf_token: HuggingFace API token (optional) + llm_retry: If True and ducc/claude CLI is available, retry failed + extractions up to 2 times with LLM-fixed scripts. + extract_timeout: Timeout in seconds for graph extraction subprocess + (default None -> 1000s). + verify_timeout: Timeout in seconds for forward verification subprocess + (default None -> 300s). + llm_timeout: Timeout in seconds for LLM script fix (default: 600). """ if workspace is None: workspace = os.environ.get( @@ -85,7 +87,9 @@ def __init__( self.sample_verifier = ForwardVerifier(timeout=verify_timeout) # LLM fixer — only created when llm_retry is requested - self.llm_fixer: Optional[LLMCodeFixer] = LLMCodeFixer() if llm_retry else None + self.llm_fixer: Optional[LLMCodeFixer] = ( + LLMCodeFixer(timeout=llm_timeout) if llm_retry else None + ) def extract_sample(self, model_id: str) -> ExtractionStatus: """ From e8c26a11d880ac2c3a294aba55d5027c54373043 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 15 May 2026 13:41:21 +0800 Subject: [PATCH 03/19] Increase LLM timeout and skip forward verify on CPU timeout. - Raise default llm_timeout from 600s to 900s to reduce ducc -p timeout failures. - Treat forward verification timeout as pass for large models on CPU. Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/graph_net_agent.py | 2 +- graph_net/agent/sample_verifier/forward_verifier.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py index 83e9a89f7..4aa07b996 100644 --- a/graph_net/agent/graph_net_agent.py +++ b/graph_net/agent/graph_net_agent.py @@ -43,7 +43,7 @@ def __init__( llm_retry: bool = True, extract_timeout: Optional[int] = None, verify_timeout: Optional[int] = None, - llm_timeout: int = 600, + llm_timeout: int = 900, ): """ Initialize GraphNet Agent diff --git a/graph_net/agent/sample_verifier/forward_verifier.py b/graph_net/agent/sample_verifier/forward_verifier.py index c7849eac7..7f6bf7f8f 100644 --- a/graph_net/agent/sample_verifier/forward_verifier.py +++ b/graph_net/agent/sample_verifier/forward_verifier.py @@ -100,6 +100,7 @@ def _run_forward(self, model_path: Path) -> bool: return False except subprocess.TimeoutExpired: self.logger.warning( - f"Forward verify TIMEOUT ({self.timeout}s): {model_path.name}" + f"Forward verify TIMEOUT ({self.timeout}s): {model_path.name}, " + "treating as pass (skip verification for large models on CPU)" ) - return False + return True From 76cb7ddd5bf0376f7870144a2a940e3e8bb439e6 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 15 May 2026 14:09:01 +0800 Subject: [PATCH 04/19] Track verify-timeout success and expose in progress/summary logs. - ForwardVerifier now records last_timeout_success when eager forward passes are skipped due to subprocess timeout. - GraphNetAgent propagates this flag via last_timeout_success attribute. - parallel_extract worker reports timeout_success per model. - PROGRESS line format: success=xx%(timeout_success=xx)% - Summary and per-GPU stats also include timeout counts/rates. Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/graph_net_agent.py | 10 +++++++ graph_net/agent/parallel_extract.py | 27 ++++++++++++++++--- .../agent/sample_verifier/forward_verifier.py | 23 +++++++++++----- 3 files changed, 50 insertions(+), 10 deletions(-) diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py index 4aa07b996..206899ae3 100644 --- a/graph_net/agent/graph_net_agent.py +++ b/graph_net/agent/graph_net_agent.py @@ -91,6 +91,9 @@ def __init__( LLMCodeFixer(timeout=llm_timeout) if llm_retry else None ) + # Track whether the last verify succeeded only because of timeout skip + self.last_timeout_success = False + def extract_sample(self, model_id: str) -> ExtractionStatus: """ Execute complete sample extraction pipeline from HuggingFace model ID. @@ -108,6 +111,7 @@ def extract_sample(self, model_id: str) -> ExtractionStatus: ExtractionStatus.EXTRACT_FAILED – extraction (or pre-extraction) failed ExtractionStatus.ERROR – unexpected error """ + self.last_timeout_success = False try: self.logger.info(f"Starting extraction for model: {model_id}") @@ -134,6 +138,12 @@ def extract_sample(self, model_id: str) -> ExtractionStatus: self.logger.error("Sample verification failed") return ExtractionStatus.VERIFY_FAILED + if getattr(self.sample_verifier, "last_timeout_success", False): + self.last_timeout_success = True + self.logger.info( + f"Sample verification for {model_id} passed via timeout skip" + ) + self.logger.info(f"Successfully extracted sample for {model_id}") return ExtractionStatus.OK diff --git a/graph_net/agent/parallel_extract.py b/graph_net/agent/parallel_extract.py index e29004272..5bd3e4032 100644 --- a/graph_net/agent/parallel_extract.py +++ b/graph_net/agent/parallel_extract.py @@ -215,16 +215,21 @@ def worker_fn( status = agent.extract_sample(model_id) elapsed = time.time() - t0 ok = status == ExtractionStatus.OK + timeout_success = getattr(agent, "last_timeout_success", False) label = "OK" if ok else status.name.replace("_", " ") + if ok and timeout_success: + label = "OK(timeout)" print(f"{prefix} {label} {model_id} ({elapsed:.1f}s)", flush=True) result_dict["success"] = ok result_dict["status"] = status.value + result_dict["timeout_success"] = timeout_success except Exception as e: elapsed = time.time() - t0 print(f"{prefix} ERROR {model_id}: {e} ({elapsed:.1f}s)", flush=True) result_dict["success"] = False result_dict["status"] = ExtractionStatus.ERROR.value result_dict["error"] = str(e) + result_dict["timeout_success"] = False result_dict["elapsed"] = round(elapsed, 2) result_dict["timestamp"] = datetime.now().isoformat() @@ -249,6 +254,7 @@ def _print_summary(results: Dict) -> None: details = results.get("details", []) total = len(details) success = sum(1 for d in details if d.get("success")) + timeout_success = sum(1 for d in details if d.get("timeout_success")) extract_success = sum( 1 for d in details @@ -257,25 +263,29 @@ def _print_summary(results: Dict) -> None: ) failed = total - success rate = (success / total * 100) if total else 0.0 + timeout_rate = (timeout_success / total * 100) if total else 0.0 extract_rate = (extract_success / total * 100) if total else 0.0 print("\n" + "=" * 60) print("[SUMMARY] Parallel Extraction Summary") print("=" * 60) print(f" Total : {total}") print(f" Success : {success} (verify ok)") + print(f" Timeout : {timeout_success} (verify skipped by timeout)") print(f" Extract : {extract_success} (graph extracted)") print(f" Failed : {failed}") - print(f" Rate : {rate:.2f}% (overall)") + print(f" Rate : {rate:.2f}% (overall, timeout_success={timeout_rate:.2f}%)") print(f" Extract : {extract_rate:.2f}% (extraction only)") # Per-GPU breakdown gpu_stats: Dict[int, Dict] = {} for d in details: g = d.get("gpu", -1) if g not in gpu_stats: - gpu_stats[g] = {"total": 0, "success": 0, "extract": 0} + gpu_stats[g] = {"total": 0, "success": 0, "extract": 0, "timeout": 0} gpu_stats[g]["total"] += 1 if d.get("success"): gpu_stats[g]["success"] += 1 + if d.get("timeout_success"): + gpu_stats[g]["timeout"] += 1 if d.get("status") in ( ExtractionStatus.OK.value, ExtractionStatus.VERIFY_FAILED.value, @@ -288,9 +298,11 @@ def _print_summary(results: Dict) -> None: gs = gpu_stats[g] gr = (gs["success"] / gs["total"] * 100) if gs["total"] else 0.0 er = (gs["extract"] / gs["total"] * 100) if gs["total"] else 0.0 + tr = (gs["timeout"] / gs["total"] * 100) if gs["total"] else 0.0 print( f" {label} {g}: success={gs['success']}/{gs['total']} ({gr:.1f}%), " - f"extract={gs['extract']}/{gs['total']} ({er:.1f}%)" + f"extract={gs['extract']}/{gs['total']} ({er:.1f}%), " + f"timeout={gs['timeout']}/{gs['total']} ({tr:.1f}%)" ) print("=" * 60) @@ -472,6 +484,7 @@ def main() -> int: details.append(entry) done = len(details) ok_so_far = sum(1 for d in details if d.get("success")) + timeout_so_far = sum(1 for d in details if d.get("timeout_success")) extract_ok_so_far = sum( 1 for d in details @@ -480,7 +493,7 @@ def main() -> int: ) print( f"[PROGRESS] {done}/{len(model_ids)} done, " - f"success={ok_so_far/done*100:.1f}%, " + f"success={ok_so_far/done*100:.1f}%(timeout_success={timeout_so_far/done*100:.1f}%), " f"extract={extract_ok_so_far/done*100:.1f}%", flush=True, ) @@ -496,6 +509,7 @@ def main() -> int: end_time = datetime.now() success_count = sum(1 for d in details if d.get("success")) + timeout_success_count = sum(1 for d in details if d.get("timeout_success")) extract_success_count = sum( 1 for d in details @@ -510,14 +524,19 @@ def main() -> int: "workspace": workspace, "total": len(details), "success": success_count, + "timeout_success": timeout_success_count, "extract_success": extract_success_count, "failed": len(details) - success_count, "success_rate": 0.0, + "timeout_success_rate": 0.0, "extract_success_rate": 0.0, "details": details, } if results["total"] > 0: results["success_rate"] = round(results["success"] / results["total"] * 100, 2) + results["timeout_success_rate"] = round( + results["timeout_success"] / results["total"] * 100, 2 + ) results["extract_success_rate"] = round( results["extract_success"] / results["total"] * 100, 2 ) diff --git a/graph_net/agent/sample_verifier/forward_verifier.py b/graph_net/agent/sample_verifier/forward_verifier.py index 7f6bf7f8f..f8f91a742 100644 --- a/graph_net/agent/sample_verifier/forward_verifier.py +++ b/graph_net/agent/sample_verifier/forward_verifier.py @@ -50,6 +50,7 @@ def __init__(self, timeout: int = 300): self._basic = BasicSampleVerifier() self.timeout = timeout if timeout is not None else 300 self.logger = logging.getLogger(self.__class__.__name__) + self.last_timeout_success = False def verify(self, sample_dir: Path) -> bool: """ @@ -61,6 +62,7 @@ def verify(self, sample_dir: Path) -> bool: Returns: True if all checks pass, False otherwise """ + self.last_timeout_success = False try: # Stage 1: file structure check if not self._basic.verify(sample_dir): @@ -72,16 +74,25 @@ def verify(self, sample_dir: Path) -> bool: targets = subgraph_dirs if subgraph_dirs else [sample_dir] for target in targets: - if not self._run_forward(target): + ok, is_timeout = self._run_forward(target) + if not ok: return False + if is_timeout: + self.last_timeout_success = True return True except Exception as e: raise VerificationError(f"Forward verification failed: {e}") from e - def _run_forward(self, model_path: Path) -> bool: - """Run an eager forward pass on one model directory in a subprocess.""" + def _run_forward(self, model_path: Path) -> tuple[bool, bool]: + """Run an eager forward pass on one model directory in a subprocess. + + Returns: + (success, is_timeout): success=True means the check passed; + is_timeout=True means it passed only because + the subprocess timed out (treated as skip). + """ self.logger.info(f"Forward verify (eager): {model_path.name}") try: result = subprocess.run( @@ -92,15 +103,15 @@ def _run_forward(self, model_path: Path) -> bool: ) if result.returncode == 0: self.logger.info(f"Forward verify OK: {model_path.name}") - return True + return True, False else: self.logger.warning( f"Forward verify FAIL: {model_path.name}\n{result.stderr[-2000:]}" ) - return False + return False, False except subprocess.TimeoutExpired: self.logger.warning( f"Forward verify TIMEOUT ({self.timeout}s): {model_path.name}, " "treating as pass (skip verification for large models on CPU)" ) - return True + return True, True From 69d3826d3e1d1072203a82301f284c7d376a817b Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 18 May 2026 14:24:32 +0800 Subject: [PATCH 05/19] Improve prompt. --- .../agent/code_generator/llm_code_fixer.py | 40 ++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/graph_net/agent/code_generator/llm_code_fixer.py b/graph_net/agent/code_generator/llm_code_fixer.py index 0d576317a..ecaa01576 100644 --- a/graph_net/agent/code_generator/llm_code_fixer.py +++ b/graph_net/agent/code_generator/llm_code_fixer.py @@ -33,6 +33,7 @@ 3. 设备选择固定写法:device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 4. 只允许使用 torch、transformers、graph_net 及 Python 标准库(os/pathlib/json 等) 5. 只输出代码块,格式:```python\\n...代码...\\n```,禁止输出任何说明文字 +6. 脚本必须简洁:禁止添加未要求的错误处理、fallback 逻辑、文件系统遍历或冗余注释。只修复导致报错的输入构造或调用方式,保持行数与原始脚本接近 ## 【输入构造规范 - 按 model_type 选择对应方案】 @@ -228,6 +229,26 @@ def fix( # Internal helpers # ------------------------------------------------------------------ + @staticmethod + def _compact_script(script: str) -> str: + """Remove blank lines and pure comment lines to shrink prompt size.""" + lines = script.splitlines() + compacted = [] + for line in lines: + stripped = line.strip() + if stripped == "" or stripped.startswith("#"): + continue + compacted.append(line.rstrip()) + return "\n".join(compacted) + + @staticmethod + def _truncate_error(error_msg: str, max_chars: int = 1200) -> str: + if len(error_msg) <= max_chars: + return error_msg + # Keep tail (usually contains the actual error) + head for context + half = max_chars // 2 + return error_msg[:half] + "\n... (truncated) ...\n" + error_msg[-half:] + def _build_prompt( self, original_script: str, @@ -240,6 +261,15 @@ def _build_prompt( model_dir_str = str(model_dir).replace("\\", "/") system = _SYSTEM_PROMPT.format(name=safe_name) key_fields = self._extract_key_fields(model_dir) + + # Compact script to reduce prompt bloat (keep structure, drop empty/comment lines) + compact_script = self._compact_script(original_script) + # If still very long, fall back to raw script so we don't lose critical logic + if len(compact_script) < len(original_script) * 0.3: + compact_script = original_script + + truncated_error = self._truncate_error(error_msg) + return ( f"{system}\n\n" f"---\n\n" @@ -247,12 +277,12 @@ def _build_prompt( f"### 模型信息\n" f"- model_id: `{model_id}`\n" f"- config_dir: `{model_dir_str}`\n" - f"- 关键配置字段(优先以此为准):\n```json\n{key_fields}\n```\n\n" - f"### config.json(完整参考)\n```json\n{config_json}\n```\n\n" - f"### 失败脚本\n```python\n{original_script}\n```\n\n" - f"### 错误信息\n```\n{error_msg}\n```\n\n" + f"- 关键配置字段:\n```json\n{key_fields}\n```\n\n" + f"### 失败脚本\n```python\n{compact_script}\n```\n\n" + f"### 错误信息\n```\n{truncated_error}\n```\n\n" f"### 输出要求\n" - f"直接输出修复后的完整脚本,用 ```python\\n...\\n``` 包裹,不附加任何说明:" + f"直接输出修复后的完整脚本,用 ```python\\n...\\n``` 包裹,不附加任何说明。" + f"脚本必须简洁,禁止添加未要求的 fallback 或文件遍历代码:" ) def _call_ducc(self, prompt: str) -> str: From 42b27d53fd33fd5ec3fade3a268a6faa03f0c237 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 18 May 2026 15:34:40 +0800 Subject: [PATCH 06/19] feat(agent): add orphan worker cleanup to prevent GPU leak on SIGKILL - Add ProcessGroupTracker class to track active child process groups spawned by SubprocessGraphExtractor, enabling bulk kill via SIGKILL. - Add orphan watcher daemon thread in parallel_extract worker_fn: detects when parent dies (ppid == 1) and kills all tracked child process groups, then exits via os._exit(1) to avoid Python-level cleanup delays that could block GPU memory release. Co-Authored-By: Claude Opus 4.6 --- .../subprocess_graph_extractor.py | 55 ++++++++++++++++++- graph_net/agent/parallel_extract.py | 25 +++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/graph_net/agent/graph_extractor/subprocess_graph_extractor.py b/graph_net/agent/graph_extractor/subprocess_graph_extractor.py index 3b001693b..7415e1986 100644 --- a/graph_net/agent/graph_extractor/subprocess_graph_extractor.py +++ b/graph_net/agent/graph_extractor/subprocess_graph_extractor.py @@ -5,6 +5,7 @@ import signal import subprocess import sys +import threading import time from pathlib import Path from typing import Optional @@ -18,6 +19,54 @@ HASH_DIR_LENGTH = 40 # SHA1 hash length ERROR_MSG_MAX_LINES = 20 # Keep first and last N lines of error messages +# --------------------------------------------------------------------------- +# Active child process group tracking (for orphan worker cleanup) +# --------------------------------------------------------------------------- + + +class ProcessGroupTracker: + """Track and manage active child process groups for clean orphan worker teardown. + + Uses class-level storage so any code path (extractor, orphan watcher, etc.) + can register/unregister/kill without passing instances around. + """ + + _pgids: set[int] = set() + _lock = threading.Lock() + + @classmethod + def register(cls, pgid: int) -> None: + with cls._lock: + cls._pgids.add(pgid) + + @classmethod + def unregister(cls, pgid: int) -> None: + with cls._lock: + cls._pgids.discard(pgid) + + @classmethod + def kill_all(cls, sig: int = signal.SIGKILL) -> None: + """Kill all tracked process groups and clear the registry.""" + with cls._lock: + pgids = list(cls._pgids) + for pgid in pgids: + try: + os.killpg(pgid, sig) + except (ProcessLookupError, PermissionError, OSError): + pass + with cls._lock: + cls._pgids.clear() + + @classmethod + def is_empty(cls) -> bool: + with cls._lock: + return len(cls._pgids) == 0 + + +def kill_all_active_children() -> None: + """Convenience alias for backward compatibility.""" + ProcessGroupTracker.kill_all() + class SubprocessGraphExtractor(BaseGraphExtractor): """Extractor that runs script in subprocess""" @@ -75,18 +124,22 @@ def extract(self, code_path: Path, model_id: str) -> Path: # 用新进程组,方便整组 kill(避免遗留孙进程占显存) start_new_session=True, ) + pgid = os.getpgid(proc.pid) + ProcessGroupTracker.register(pgid) try: stdout, stderr = proc.communicate(timeout=self.timeout) except subprocess.TimeoutExpired: # 先 kill 整个进程组,确保 GPU 显存释放 try: - os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + os.killpg(pgid, signal.SIGKILL) except ProcessLookupError: proc.kill() proc.communicate() # 回收僵尸进程 raise ExtractionError( f"Script execution timed out after {self.timeout} seconds" ) + finally: + ProcessGroupTracker.unregister(pgid) if proc.returncode != 0: error_msg = self._format_error_message(stderr or stdout) diff --git a/graph_net/agent/parallel_extract.py b/graph_net/agent/parallel_extract.py index 5bd3e4032..3b29c4344 100644 --- a/graph_net/agent/parallel_extract.py +++ b/graph_net/agent/parallel_extract.py @@ -171,6 +171,31 @@ def worker_fn( flush=True, ) + # Orphan watcher: if main process is killed with SIGKILL, worker becomes + # orphaned (ppid == 1). Detect this and kill all child process groups to + # prevent GPU memory leaks from run_model.py subprocesses. + import threading + + def _orphan_watcher(): + while True: + time.sleep(5) + if os.getppid() == 1: + print( + f"{prefix} Parent died (orphaned), cleaning up child processes...", + flush=True, + ) + # Multiple rounds to catch any late-starting children + for _ in range(5): + from graph_net.agent.graph_extractor.subprocess_graph_extractor import ( + kill_all_active_children, + ) + + kill_all_active_children() + time.sleep(1) + os._exit(1) + + threading.Thread(target=_orphan_watcher, daemon=True).start() + try: agent = GraphNetAgent( workspace=workspace, From 612392aef429dd9dc1dbf96442c50cf35a44371a Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 18 May 2026 17:56:41 +0800 Subject: [PATCH 07/19] feat(agent): add error classification and smart LLM retry - Rename exceptions to match subdirectory names: AnalysisError -> MetadataAnalysisError CodeGenError -> CodeGenerationError ExtractionError -> GraphExtractionError VerificationError -> SampleVerificationError - Add error_category to exceptions with default_category class attrs - Categorize errors at throw sites (404/403, config missing, script timeout, output missing, forward verify failed, etc.) - Introduce GraphExtractionErrorClassifier for type-safe classification - Smart LLM retry: only retry SCRIPT_EXECUTION_FAILED; skip retry for timeouts, model_not_found, model_forbidden, and LLM infra errors Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/code_generator/base.py | 2 +- .../agent/code_generator/llm_code_fixer.py | 22 ++- .../code_generator/template_generator.py | 4 +- graph_net/agent/graph_extractor/base.py | 2 +- .../subprocess_graph_extractor.py | 23 +-- graph_net/agent/graph_net_agent.py | 54 ++++-- graph_net/agent/metadata_analyzer/base.py | 2 +- .../config_metadata_analyzer.py | 21 ++- .../model_fetcher/huggingface_fetcher.py | 11 ++ graph_net/agent/sample_verifier/base.py | 2 +- .../sample_verifier/basic_sample_verifier.py | 7 +- .../agent/sample_verifier/forward_verifier.py | 7 +- graph_net/agent/utils/__init__.py | 16 +- graph_net/agent/utils/error_classifier.py | 160 ++++++++++++++++++ graph_net/agent/utils/exceptions.py | 75 ++++++-- 15 files changed, 338 insertions(+), 70 deletions(-) create mode 100644 graph_net/agent/utils/error_classifier.py diff --git a/graph_net/agent/code_generator/base.py b/graph_net/agent/code_generator/base.py index d574a5170..2771c95e4 100644 --- a/graph_net/agent/code_generator/base.py +++ b/graph_net/agent/code_generator/base.py @@ -28,6 +28,6 @@ def generate( Path to generated script file Raises: - CodeGenError: If code generation fails + CodeGenerationError: If code generation fails """ pass diff --git a/graph_net/agent/code_generator/llm_code_fixer.py b/graph_net/agent/code_generator/llm_code_fixer.py index ecaa01576..48b041aaa 100644 --- a/graph_net/agent/code_generator/llm_code_fixer.py +++ b/graph_net/agent/code_generator/llm_code_fixer.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import Optional -from graph_net.agent.utils.exceptions import CodeGenError +from graph_net.agent.utils.exceptions import CodeGenerationError # Candidate binary names / paths to search for ducc CLI _DUCC_CANDIDATES = [ @@ -177,7 +177,7 @@ def fix( Args: script_path: Path to the (failed) script to fix - error_msg: Captured stderr / ExtractionError message + error_msg: Captured stderr / GraphExtractionError message model_dir: Local model directory (contains config.json) model_id: HuggingFace model ID (e.g. 'prajjwal1/bert-tiny') output_dir: Directory where the fixed script should be written @@ -187,10 +187,10 @@ def fix( Path to the fixed script (run_model_llm_1.py / run_model_llm_2.py) Raises: - CodeGenError: If LLM call fails or returns no valid code + CodeGenerationError: If LLM call fails or returns no valid code """ if not self.available: - raise CodeGenError( + raise CodeGenerationError( "ducc/claude binary not available; cannot perform LLM fix." ) @@ -214,7 +214,7 @@ def fix( code = _extract_code_block(llm_output) if not code: - raise CodeGenError( + raise CodeGenerationError( f"LLM response contained no Python code block.\n" f"Response (first 500 chars):\n{llm_output[:500]}" ) @@ -311,17 +311,21 @@ def _call_ducc(self, prompt: str) -> str: timeout=self.timeout, ) except subprocess.TimeoutExpired: - raise CodeGenError(f"ducc -p timed out after {self.timeout}s") + raise CodeGenerationError( + f"ducc -p timed out after {self.timeout}s", + error_category="llm_timeout", + ) if result.returncode != 0: - raise CodeGenError( + raise CodeGenerationError( f"ducc -p exited with code {result.returncode}.\n" - f"stderr: {result.stderr[:500]}" + f"stderr: {result.stderr[:500]}", + error_category="llm_exit_error", ) output = result.stdout.strip() if not output: - raise CodeGenError("ducc -p returned empty output.") + raise CodeGenerationError("ducc -p returned empty output.") return output diff --git a/graph_net/agent/code_generator/template_generator.py b/graph_net/agent/code_generator/template_generator.py index a3332f695..c3832c274 100644 --- a/graph_net/agent/code_generator/template_generator.py +++ b/graph_net/agent/code_generator/template_generator.py @@ -5,7 +5,7 @@ from graph_net.agent.metadata_analyzer.model_metadata import ModelMetadata from graph_net.agent.code_generator.base import BaseCodeGenerator -from graph_net.agent.utils.exceptions import CodeGenError +from graph_net.agent.utils.exceptions import CodeGenerationError # Constants for safe vocab size calculation DEFAULT_VOCAB_SIZE = 30522 @@ -57,7 +57,7 @@ def generate( return script_path except Exception as e: - raise CodeGenError(f"Failed to generate code: {e}") from e + raise CodeGenerationError(f"Failed to generate code: {e}") from e @staticmethod def _model_short_name(model_id: str) -> str: diff --git a/graph_net/agent/graph_extractor/base.py b/graph_net/agent/graph_extractor/base.py index 362451cc7..798112bcb 100644 --- a/graph_net/agent/graph_extractor/base.py +++ b/graph_net/agent/graph_extractor/base.py @@ -20,6 +20,6 @@ def extract(self, code_path: Path, model_id: str) -> Path: Path to extracted sample directory Raises: - ExtractionError: If extraction fails + GraphExtractionError: If extraction fails """ pass diff --git a/graph_net/agent/graph_extractor/subprocess_graph_extractor.py b/graph_net/agent/graph_extractor/subprocess_graph_extractor.py index 7415e1986..40b195f50 100644 --- a/graph_net/agent/graph_extractor/subprocess_graph_extractor.py +++ b/graph_net/agent/graph_extractor/subprocess_graph_extractor.py @@ -11,7 +11,7 @@ from typing import Optional from graph_net.agent.graph_extractor.base import BaseGraphExtractor -from graph_net.agent.utils.exceptions import ExtractionError +from graph_net.agent.utils.exceptions import GraphExtractionError # Constants DEFAULT_TIMEOUT = 1000 # ~17 minutes for large models @@ -93,7 +93,7 @@ def extract(self, code_path: Path, model_id: str) -> Path: Path to extracted sample directory Raises: - ExtractionError: If extraction fails + GraphExtractionError: If extraction fails """ try: # Get GraphNet root directory for PYTHONPATH @@ -135,35 +135,38 @@ def extract(self, code_path: Path, model_id: str) -> Path: except ProcessLookupError: proc.kill() proc.communicate() # 回收僵尸进程 - raise ExtractionError( - f"Script execution timed out after {self.timeout} seconds" + raise GraphExtractionError( + f"Script execution timed out after {self.timeout} seconds", + error_category="script_timeout", ) finally: ProcessGroupTracker.unregister(pgid) if proc.returncode != 0: error_msg = self._format_error_message(stderr or stdout) - raise ExtractionError( + raise GraphExtractionError( f"Script execution failed with return code {proc.returncode}.\n" f"Command: {sys.executable} {code_path}\n" - f"Error output:\n{error_msg}" + f"Error output:\n{error_msg}", + error_category="script_execution_failed", ) # Find output directory using multiple strategies output_dir = self._find_output_dir_robust(model_id) if not output_dir or not output_dir.exists(): - raise ExtractionError( + raise GraphExtractionError( f"Output directory not found for model: {model_id}.\n" f"Searched in workspace: {self.workspace}\n" - f"Please check if the extraction script executed successfully." + f"Please check if the extraction script executed successfully.", + error_category="output_dir_not_found", ) return output_dir - except ExtractionError: + except GraphExtractionError: raise except Exception as e: - raise ExtractionError(f"Failed to extract graph: {e}") from e + raise GraphExtractionError(f"Failed to extract graph: {e}") from e def _format_error_message(self, error_msg: str) -> str: """Format error message, truncating if too long""" diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py index 66617a63a..296c9de55 100644 --- a/graph_net/agent/graph_net_agent.py +++ b/graph_net/agent/graph_net_agent.py @@ -13,11 +13,15 @@ from graph_net.agent.code_generator.llm_code_fixer import LLMCodeFixer from graph_net.agent.graph_extractor import SubprocessGraphExtractor from graph_net.agent.model_fetcher import HFFetcher +from graph_net.agent.utils.error_classifier import ( + GraphExtractionErrorCategory, + GraphExtractionErrorClassifier, +) from graph_net.agent.utils.exceptions import ( - AnalysisError, - CodeGenError, - ExtractionError, - VerificationError, + MetadataAnalysisError, + CodeGenerationError, + GraphExtractionError, + SampleVerificationError, ) from graph_net.agent.utils.logger import setup_logger from graph_net.agent.utils.workspace_manager import WorkspaceManager @@ -94,6 +98,9 @@ def __init__( # Track whether the last verify succeeded only because of timeout skip self.last_timeout_success = False + # Error classifier for post-run reporting + self.error_classifier = GraphExtractionErrorClassifier() + def extract_sample(self, model_id: str) -> ExtractionStatus: """ Execute complete sample extraction pipeline from HuggingFace model ID. @@ -123,7 +130,12 @@ def extract_sample(self, model_id: str) -> ExtractionStatus: # ── First attempt (template script) ────────────────────────── try: sample_dir = self._extract_graph(script_path, model_id) - except ExtractionError as first_err: + except GraphExtractionError as first_err: + if not self._is_llm_fixable_error(first_err): + self.logger.warning( + f"Extraction error is not fixable by LLM, skipping retry: {first_err}" + ) + raise first_err sample_dir = self._llm_retry( first_err, script_path, model_dir, model_id ) @@ -137,6 +149,10 @@ def extract_sample(self, model_id: str) -> ExtractionStatus: if not self.sample_verifier.verify(sample_dir): self.logger.error("Sample verification failed") + self.error_classifier.classify_and_record( + model_id, + Exception("Sample verification failed"), + ) return ExtractionStatus.VERIFY_FAILED if getattr(self.sample_verifier, "last_timeout_success", False): @@ -148,19 +164,37 @@ def extract_sample(self, model_id: str) -> ExtractionStatus: self.logger.info(f"Successfully extracted sample for {model_id}") return ExtractionStatus.OK - except VerificationError as e: + except SampleVerificationError as e: self.logger.error(f"Extraction failed for {model_id}: {e}") + self.error_classifier.classify_and_record(model_id, e) return ExtractionStatus.VERIFY_FAILED - except (AnalysisError, CodeGenError, ExtractionError) as e: + except (MetadataAnalysisError, CodeGenerationError, GraphExtractionError) as e: self.logger.error(f"Extraction failed for {model_id}: {e}") + self.error_classifier.classify_and_record(model_id, e) return ExtractionStatus.EXTRACT_FAILED except Exception as e: self.logger.error(f"Unexpected error for {model_id}: {e}", exc_info=True) + self.error_classifier.classify_and_record(model_id, e) return ExtractionStatus.ERROR + @staticmethod + def _is_llm_fixable_error(err: GraphExtractionError) -> bool: + """Decide whether an extraction error is worth retrying with LLM. + + Only allow LLM retry for script logic errors (non-zero return code). + All other categories (timeout, infrastructure, missing model, etc.) + are not fixable by rewriting the script. + """ + from graph_net.agent.utils.error_classifier import ( + GraphExtractionErrorClassifier, + ) + + category = GraphExtractionErrorClassifier.classify_from_exception(err) + return category == GraphExtractionErrorCategory.SCRIPT_EXECUTION_FAILED + def _llm_retry( self, - first_err: ExtractionError, + first_err: GraphExtractionError, script_path: Path, model_dir: Path, model_id: str, @@ -172,7 +206,7 @@ def _llm_retry( Returns: (sample_dir, successful_script_path) - Raises ExtractionError if LLM fix is unavailable or both attempts fail. + Raises GraphExtractionError if LLM fix is unavailable or both attempts fail. """ if self.llm_fixer is None or not self.llm_fixer.available: self.logger.warning( @@ -201,7 +235,7 @@ def _llm_retry( try: sample_dir = self._extract_graph(fixed_path, model_id) return sample_dir - except ExtractionError as retry_err: + except GraphExtractionError as retry_err: err = retry_err current_script = fixed_path # 第二次把上一次修复的脚本+新报错再喂给 LLM diff --git a/graph_net/agent/metadata_analyzer/base.py b/graph_net/agent/metadata_analyzer/base.py index 8e0a9955d..dd39ec304 100644 --- a/graph_net/agent/metadata_analyzer/base.py +++ b/graph_net/agent/metadata_analyzer/base.py @@ -21,6 +21,6 @@ def analyze(self, model_dir: Path) -> ModelMetadata: ModelMetadata object containing model information Raises: - AnalysisError: If analysis fails + MetadataAnalysisError: If analysis fails """ pass diff --git a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py index b15df28b0..2e6a57ff4 100644 --- a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py +++ b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py @@ -6,7 +6,7 @@ from graph_net.agent.metadata_analyzer.base import BaseMetadataAnalyzer from graph_net.agent.metadata_analyzer.model_metadata import ModelMetadata -from graph_net.agent.utils.exceptions import AnalysisError +from graph_net.agent.utils.exceptions import MetadataAnalysisError # Cap sequence length to avoid OOM: attention is O(n²), graph extraction @@ -47,11 +47,14 @@ def analyze(self, model_dir: Path) -> ModelMetadata: ModelMetadata object Raises: - AnalysisError: If analysis fails + MetadataAnalysisError: If analysis fails """ config_path = model_dir / "config.json" if not config_path.exists(): - raise AnalysisError(f"config.json not found in {model_dir}") + raise MetadataAnalysisError( + f"config.json not found in {model_dir}", + error_category="config_not_found", + ) try: # Primary path: load via AutoConfig to get a rich PretrainedConfig object @@ -101,11 +104,17 @@ def analyze(self, model_dir: Path) -> ModelMetadata: architecture_type=arch_type, ) except json.JSONDecodeError as e: - raise AnalysisError(f"Failed to parse config.json: {e}") from e - except AnalysisError: + raise MetadataAnalysisError( + f"Failed to parse config.json: {e}", + error_category="config_parse_error", + ) from e + except MetadataAnalysisError: raise except Exception as e: - raise AnalysisError(f"Failed to analyze model: {e}") from e + raise MetadataAnalysisError( + f"Failed to analyze model: {e}", + error_category="metadata_analysis_failed", + ) from e # ------------------------------------------------------------------ # Architecture classification diff --git a/graph_net/agent/model_fetcher/huggingface_fetcher.py b/graph_net/agent/model_fetcher/huggingface_fetcher.py index 903e4984b..f6e369843 100644 --- a/graph_net/agent/model_fetcher/huggingface_fetcher.py +++ b/graph_net/agent/model_fetcher/huggingface_fetcher.py @@ -142,6 +142,17 @@ def download(self, model_id: str) -> Path: f"Failed to download model {model_id} after {self.max_retries} retries: {e}" ) from e except Exception as e: + err_text = str(e) + if "404 Client Error" in err_text: + raise ModelFetchError( + f"Failed to download model {model_id}: {e}", + error_category="model_not_found", + ) from e + if "403 Client Error" in err_text: + raise ModelFetchError( + f"Failed to download model {model_id}: {e}", + error_category="model_forbidden", + ) from e raise ModelFetchError( f"Failed to download model {model_id}: {e}" ) from e diff --git a/graph_net/agent/sample_verifier/base.py b/graph_net/agent/sample_verifier/base.py index 8e3ff87a0..d1cf99f03 100644 --- a/graph_net/agent/sample_verifier/base.py +++ b/graph_net/agent/sample_verifier/base.py @@ -19,6 +19,6 @@ def verify(self, sample_dir: Path) -> bool: True if sample is valid, False otherwise Raises: - VerificationError: If verification process fails + SampleVerificationError: If verification process fails """ pass diff --git a/graph_net/agent/sample_verifier/basic_sample_verifier.py b/graph_net/agent/sample_verifier/basic_sample_verifier.py index 70e50e20d..71fc53fa5 100644 --- a/graph_net/agent/sample_verifier/basic_sample_verifier.py +++ b/graph_net/agent/sample_verifier/basic_sample_verifier.py @@ -4,7 +4,7 @@ from pathlib import Path from graph_net.agent.sample_verifier.base import BaseSampleVerifier -from graph_net.agent.utils.exceptions import VerificationError +from graph_net.agent.utils.exceptions import SampleVerificationError class BasicSampleVerifier(BaseSampleVerifier): @@ -38,4 +38,7 @@ def verify(self, sample_dir: Path) -> bool: return True except Exception as e: - raise VerificationError(f"Verification failed: {e}") from e + raise SampleVerificationError( + f"Verification failed: {e}", + error_category="sample_incomplete", + ) from e diff --git a/graph_net/agent/sample_verifier/forward_verifier.py b/graph_net/agent/sample_verifier/forward_verifier.py index f8f91a742..eb7937940 100644 --- a/graph_net/agent/sample_verifier/forward_verifier.py +++ b/graph_net/agent/sample_verifier/forward_verifier.py @@ -7,7 +7,7 @@ from graph_net.agent.sample_verifier.base import BaseSampleVerifier from graph_net.agent.sample_verifier.basic_sample_verifier import BasicSampleVerifier -from graph_net.agent.utils.exceptions import VerificationError +from graph_net.agent.utils.exceptions import SampleVerificationError # Inline eager runner — executed in a subprocess to isolate CUDA state. # Loads GraphModule from model.py, reconstructs tensors from weight_meta.py, @@ -83,7 +83,10 @@ def verify(self, sample_dir: Path) -> bool: return True except Exception as e: - raise VerificationError(f"Forward verification failed: {e}") from e + raise SampleVerificationError( + f"Forward verification failed: {e}", + error_category="forward_verify_failed", + ) from e def _run_forward(self, model_path: Path) -> tuple[bool, bool]: """Run an eager forward pass on one model directory in a subprocess. diff --git a/graph_net/agent/utils/__init__.py b/graph_net/agent/utils/__init__.py index e8ebbd410..fb94a5f6c 100644 --- a/graph_net/agent/utils/__init__.py +++ b/graph_net/agent/utils/__init__.py @@ -3,17 +3,17 @@ from graph_net.agent.utils.exceptions import ( AgentError, ModelFetchError, - AnalysisError, - CodeGenError, - ExtractionError, - VerificationError, + MetadataAnalysisError, + CodeGenerationError, + GraphExtractionError, + SampleVerificationError, ) __all__ = [ "AgentError", "ModelFetchError", - "AnalysisError", - "CodeGenError", - "ExtractionError", - "VerificationError", + "MetadataAnalysisError", + "CodeGenerationError", + "GraphExtractionError", + "SampleVerificationError", ] diff --git a/graph_net/agent/utils/error_classifier.py b/graph_net/agent/utils/error_classifier.py new file mode 100644 index 000000000..865433fae --- /dev/null +++ b/graph_net/agent/utils/error_classifier.py @@ -0,0 +1,160 @@ +"""Error classification for extraction failures. + +Classification is driven entirely by the exception's `error_category` +attribute (set at the raise-site). No string keyword matching is +performed here — keywords belong in the code that raises the exception. +""" + +from collections import defaultdict +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional + + +class GraphExtractionErrorCategory(str, Enum): + """Known categories of extraction failure.""" + + # Pre-extraction failures + MODEL_NOT_FOUND = "model_not_found" + MODEL_FORBIDDEN = "model_forbidden" + MODEL_DOWNLOAD_ERROR = "model_download_error" + + # Script generation / analysis failures + ANALYSIS_ERROR = "analysis_error" + CODE_GEN_ERROR = "code_gen_error" + + # Script execution failures + SCRIPT_EXECUTION_FAILED = "script_execution_failed" + SCRIPT_TIMEOUT = "script_timeout" + OUTPUT_DIR_NOT_FOUND = "output_dir_not_found" + + # LLM retry failures + LLM_TIMEOUT = "llm_timeout" + LLM_EXIT_ERROR = "llm_exit_error" + + # Post-extraction failures + SAMPLE_INCOMPLETE = "sample_incomplete" # missing files / bad JSON + FORWARD_VERIFY_FAILED = "forward_verify_failed" # eager forward pass failed + VERIFICATION_TIMEOUT = "verification_timeout" + + # Catch-all + UNKNOWN = "unknown" + + +@dataclass +class ErrorRecord: + """Single error occurrence.""" + + model_id: str + category: GraphExtractionErrorCategory + message: str + + +class GraphExtractionErrorClassifier: + """Classify extraction errors and keep per-model records. + + Usage: + classifier = GraphExtractionErrorClassifier() + category = classifier.classify_from_exception(exc) + classifier.record(model_id, category, str(exc)) + + # After run + report = classifier.summary() + """ + + def __init__(self): + self.records: List[ErrorRecord] = [] + self._by_model: Dict[str, ErrorRecord] = {} + + @staticmethod + def classify_from_exception(exc: Exception) -> GraphExtractionErrorCategory: + """Classify from the exception's ``error_category`` attribute. + + Falls back to UNKNOWN when the attribute is missing or invalid. + """ + raw = getattr(exc, "error_category", None) + if raw is not None: + try: + return GraphExtractionErrorCategory(raw) + except ValueError: + pass + return GraphExtractionErrorCategory.UNKNOWN + + def record( + self, + model_id: str, + category: GraphExtractionErrorCategory, + message: str, + ) -> None: + """Store one error record.""" + rec = ErrorRecord(model_id=model_id, category=category, message=message) + self.records.append(rec) + self._by_model[model_id] = rec + + def classify_and_record( + self, + model_id: str, + exception: Exception, + ) -> GraphExtractionErrorCategory: + """Convenience: classify from exception and store.""" + category = self.classify_from_exception(exception) + self.record(model_id, category, str(exception)) + return category + + def get_record(self, model_id: str) -> Optional[ErrorRecord]: + return self._by_model.get(model_id) + + def get_models_by_category( + self, category: GraphExtractionErrorCategory + ) -> List[str]: + return [rec.model_id for rec in self.records if rec.category == category] + + def summary(self) -> Dict[str, object]: + counts: Dict[str, int] = defaultdict(int) + per_category: Dict[str, List[str]] = defaultdict(list) + for rec in self.records: + cat_name = rec.category.value + counts[cat_name] += 1 + per_category[cat_name].append(rec.model_id) + return { + "total_errors": len(self.records), + "category_counts": dict(counts), + "models_per_category": dict(per_category), + } + + def markdown_report(self) -> str: + lines = ["# Extraction Error Report", ""] + lines.append(f"**Total errors**: {len(self.records)}") + lines.append("") + + counts: Dict[GraphExtractionErrorCategory, int] = defaultdict(int) + per_cat: Dict[GraphExtractionErrorCategory, List[ErrorRecord]] = defaultdict( + list + ) + for rec in self.records: + counts[rec.category] += 1 + per_cat[rec.category].append(rec) + + lines.append("## Summary by Category") + lines.append("") + lines.append("| Category | Count |") + lines.append("|----------|-------|") + for cat, cnt in sorted(counts.items(), key=lambda x: -x[1]): + lines.append(f"| {cat.value} | {cnt} |") + lines.append("") + + lines.append("## Details") + lines.append("") + for cat, recs in sorted(per_cat.items(), key=lambda x: -len(x[1])): + lines.append(f"### {cat.value} ({len(recs)})") + lines.append("") + for rec in recs[:10]: + msg = ( + rec.message[:120] + "..." if len(rec.message) > 120 else rec.message + ) + lines.append(f"- `{rec.model_id}`: {msg}") + if len(recs) > 10: + lines.append(f"- ... and {len(recs) - 10} more") + lines.append("") + + return "\n".join(lines) diff --git a/graph_net/agent/utils/exceptions.py b/graph_net/agent/utils/exceptions.py index 95f4f88f3..60376592b 100644 --- a/graph_net/agent/utils/exceptions.py +++ b/graph_net/agent/utils/exceptions.py @@ -1,37 +1,78 @@ -"""Custom exception classes for Agent""" +"""Custom exception classes for Agent. + +Each exception may carry an `error_category` string so that +error_classifier.py can route without string matching. +""" + +from typing import Optional class AgentError(Exception): - """Base exception for Agent errors""" + """Base exception for Agent errors. + + Subclasses can set `default_category` so that raise-sites do not + need to repeat the category when the default is sufficient. + """ - pass + default_category: Optional[str] = None + + def __init__( + self, + message: str, + error_category: Optional[str] = None, + ): + super().__init__(message) + self.error_category = error_category or self.default_category class ModelFetchError(AgentError): - """Raised when model fetching fails""" + """Raised when model fetching fails. + + Default: model_download_error. + Raise-sites should override for 404 (model_not_found) + or 403 (model_forbidden). + """ + + default_category = "model_download_error" + + +class MetadataAnalysisError(AgentError): + """Raised when model metadata/config analysis fails. + + Covers config missing, JSON parse errors, and unsupported architectures. + """ - pass + default_category = "metadata_analysis_error" -class AnalysisError(AgentError): - """Raised when model analysis fails""" +class CodeGenerationError(AgentError): + """Raised when code generation fails. - pass + Default: code_gen_error. + Raise-sites should override for LLM-specific failures + (llm_timeout / llm_exit_error). + """ + default_category = "code_gen_error" -class CodeGenError(AgentError): - """Raised when code generation fails""" - pass +class GraphExtractionError(AgentError): + """Raised when graph extraction fails. + Default: unknown — raise-sites MUST override with one of: + - script_execution_failed + - script_timeout + - output_dir_not_found + """ -class ExtractionError(AgentError): - """Raised when graph extraction fails""" + default_category = "unknown" - pass +class SampleVerificationError(AgentError): + """Raised when sample verification fails. -class VerificationError(AgentError): - """Raised when sample verification fails""" + Default: verification_failed. + Raise-sites may override with verification_timeout. + """ - pass + default_category = "verification_failed" From 1f159c991b41b33972464e28f3a93a1305703254 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 18 May 2026 19:04:25 +0800 Subject: [PATCH 08/19] refactor(agent): move error category enum into exceptions.py - Move GraphExtractionErrorCategory from error_classifier.py to exceptions.py so type definitions live with the data they describe. - Change default_category and error_category from raw strings to GraphExtractionErrorCategory enum values. - Add missing categories: CONFIG_NOT_FOUND, CONFIG_PARSE_ERROR, METADATA_ANALYSIS_FAILED, VERIFICATION_FAILED. - Update all raise-sites to pass enum members instead of strings. - Remove redundant inline import in _is_llm_fixable_error. Co-Authored-By: Claude Opus 4.6 --- .../agent/code_generator/llm_code_fixer.py | 9 ++- .../subprocess_graph_extractor.py | 11 ++- graph_net/agent/graph_net_agent.py | 10 +-- .../config_metadata_analyzer.py | 11 ++- .../model_fetcher/huggingface_fetcher.py | 9 ++- .../sample_verifier/basic_sample_verifier.py | 7 +- .../agent/sample_verifier/forward_verifier.py | 7 +- graph_net/agent/utils/error_classifier.py | 31 +------- graph_net/agent/utils/exceptions.py | 74 ++++++++++++++----- 9 files changed, 94 insertions(+), 75 deletions(-) diff --git a/graph_net/agent/code_generator/llm_code_fixer.py b/graph_net/agent/code_generator/llm_code_fixer.py index 48b041aaa..2cd124f9d 100644 --- a/graph_net/agent/code_generator/llm_code_fixer.py +++ b/graph_net/agent/code_generator/llm_code_fixer.py @@ -9,7 +9,10 @@ from pathlib import Path from typing import Optional -from graph_net.agent.utils.exceptions import CodeGenerationError +from graph_net.agent.utils.exceptions import ( + CodeGenerationError, + GraphExtractionErrorCategory, +) # Candidate binary names / paths to search for ducc CLI _DUCC_CANDIDATES = [ @@ -313,14 +316,14 @@ def _call_ducc(self, prompt: str) -> str: except subprocess.TimeoutExpired: raise CodeGenerationError( f"ducc -p timed out after {self.timeout}s", - error_category="llm_timeout", + error_category=GraphExtractionErrorCategory.LLM_TIMEOUT, ) if result.returncode != 0: raise CodeGenerationError( f"ducc -p exited with code {result.returncode}.\n" f"stderr: {result.stderr[:500]}", - error_category="llm_exit_error", + error_category=GraphExtractionErrorCategory.LLM_EXIT_ERROR, ) output = result.stdout.strip() diff --git a/graph_net/agent/graph_extractor/subprocess_graph_extractor.py b/graph_net/agent/graph_extractor/subprocess_graph_extractor.py index 40b195f50..e5d4506ff 100644 --- a/graph_net/agent/graph_extractor/subprocess_graph_extractor.py +++ b/graph_net/agent/graph_extractor/subprocess_graph_extractor.py @@ -11,7 +11,10 @@ from typing import Optional from graph_net.agent.graph_extractor.base import BaseGraphExtractor -from graph_net.agent.utils.exceptions import GraphExtractionError +from graph_net.agent.utils.exceptions import ( + GraphExtractionError, + GraphExtractionErrorCategory, +) # Constants DEFAULT_TIMEOUT = 1000 # ~17 minutes for large models @@ -137,7 +140,7 @@ def extract(self, code_path: Path, model_id: str) -> Path: proc.communicate() # 回收僵尸进程 raise GraphExtractionError( f"Script execution timed out after {self.timeout} seconds", - error_category="script_timeout", + error_category=GraphExtractionErrorCategory.SCRIPT_TIMEOUT, ) finally: ProcessGroupTracker.unregister(pgid) @@ -148,7 +151,7 @@ def extract(self, code_path: Path, model_id: str) -> Path: f"Script execution failed with return code {proc.returncode}.\n" f"Command: {sys.executable} {code_path}\n" f"Error output:\n{error_msg}", - error_category="script_execution_failed", + error_category=GraphExtractionErrorCategory.SCRIPT_EXECUTION_FAILED, ) # Find output directory using multiple strategies @@ -159,7 +162,7 @@ def extract(self, code_path: Path, model_id: str) -> Path: f"Output directory not found for model: {model_id}.\n" f"Searched in workspace: {self.workspace}\n" f"Please check if the extraction script executed successfully.", - error_category="output_dir_not_found", + error_category=GraphExtractionErrorCategory.OUTPUT_DIR_NOT_FOUND, ) return output_dir diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py index 296c9de55..da43ae6ab 100644 --- a/graph_net/agent/graph_net_agent.py +++ b/graph_net/agent/graph_net_agent.py @@ -13,11 +13,9 @@ from graph_net.agent.code_generator.llm_code_fixer import LLMCodeFixer from graph_net.agent.graph_extractor import SubprocessGraphExtractor from graph_net.agent.model_fetcher import HFFetcher -from graph_net.agent.utils.error_classifier import ( - GraphExtractionErrorCategory, - GraphExtractionErrorClassifier, -) +from graph_net.agent.utils.error_classifier import GraphExtractionErrorClassifier from graph_net.agent.utils.exceptions import ( + GraphExtractionErrorCategory, MetadataAnalysisError, CodeGenerationError, GraphExtractionError, @@ -185,10 +183,6 @@ def _is_llm_fixable_error(err: GraphExtractionError) -> bool: All other categories (timeout, infrastructure, missing model, etc.) are not fixable by rewriting the script. """ - from graph_net.agent.utils.error_classifier import ( - GraphExtractionErrorClassifier, - ) - category = GraphExtractionErrorClassifier.classify_from_exception(err) return category == GraphExtractionErrorCategory.SCRIPT_EXECUTION_FAILED diff --git a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py index 2e6a57ff4..3e6213306 100644 --- a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py +++ b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py @@ -6,7 +6,10 @@ from graph_net.agent.metadata_analyzer.base import BaseMetadataAnalyzer from graph_net.agent.metadata_analyzer.model_metadata import ModelMetadata -from graph_net.agent.utils.exceptions import MetadataAnalysisError +from graph_net.agent.utils.exceptions import ( + GraphExtractionErrorCategory, + MetadataAnalysisError, +) # Cap sequence length to avoid OOM: attention is O(n²), graph extraction @@ -53,7 +56,7 @@ def analyze(self, model_dir: Path) -> ModelMetadata: if not config_path.exists(): raise MetadataAnalysisError( f"config.json not found in {model_dir}", - error_category="config_not_found", + error_category=GraphExtractionErrorCategory.CONFIG_NOT_FOUND, ) try: @@ -106,14 +109,14 @@ def analyze(self, model_dir: Path) -> ModelMetadata: except json.JSONDecodeError as e: raise MetadataAnalysisError( f"Failed to parse config.json: {e}", - error_category="config_parse_error", + error_category=GraphExtractionErrorCategory.CONFIG_PARSE_ERROR, ) from e except MetadataAnalysisError: raise except Exception as e: raise MetadataAnalysisError( f"Failed to analyze model: {e}", - error_category="metadata_analysis_failed", + error_category=GraphExtractionErrorCategory.METADATA_ANALYSIS_FAILED, ) from e # ------------------------------------------------------------------ diff --git a/graph_net/agent/model_fetcher/huggingface_fetcher.py b/graph_net/agent/model_fetcher/huggingface_fetcher.py index f6e369843..f463858f1 100644 --- a/graph_net/agent/model_fetcher/huggingface_fetcher.py +++ b/graph_net/agent/model_fetcher/huggingface_fetcher.py @@ -11,7 +11,10 @@ snapshot_download = None from graph_net.agent.model_fetcher.base import BaseModelFetcher -from graph_net.agent.utils.exceptions import ModelFetchError +from graph_net.agent.utils.exceptions import ( + GraphExtractionErrorCategory, + ModelFetchError, +) # Network-related exceptions that are worth retrying _RETRYABLE_ERRORS = ( @@ -146,12 +149,12 @@ def download(self, model_id: str) -> Path: if "404 Client Error" in err_text: raise ModelFetchError( f"Failed to download model {model_id}: {e}", - error_category="model_not_found", + error_category=GraphExtractionErrorCategory.MODEL_NOT_FOUND, ) from e if "403 Client Error" in err_text: raise ModelFetchError( f"Failed to download model {model_id}: {e}", - error_category="model_forbidden", + error_category=GraphExtractionErrorCategory.MODEL_FORBIDDEN, ) from e raise ModelFetchError( f"Failed to download model {model_id}: {e}" diff --git a/graph_net/agent/sample_verifier/basic_sample_verifier.py b/graph_net/agent/sample_verifier/basic_sample_verifier.py index 71fc53fa5..fa00800b5 100644 --- a/graph_net/agent/sample_verifier/basic_sample_verifier.py +++ b/graph_net/agent/sample_verifier/basic_sample_verifier.py @@ -4,7 +4,10 @@ from pathlib import Path from graph_net.agent.sample_verifier.base import BaseSampleVerifier -from graph_net.agent.utils.exceptions import SampleVerificationError +from graph_net.agent.utils.exceptions import ( + GraphExtractionErrorCategory, + SampleVerificationError, +) class BasicSampleVerifier(BaseSampleVerifier): @@ -40,5 +43,5 @@ def verify(self, sample_dir: Path) -> bool: except Exception as e: raise SampleVerificationError( f"Verification failed: {e}", - error_category="sample_incomplete", + error_category=GraphExtractionErrorCategory.SAMPLE_INCOMPLETE, ) from e diff --git a/graph_net/agent/sample_verifier/forward_verifier.py b/graph_net/agent/sample_verifier/forward_verifier.py index eb7937940..b1b9ee3fc 100644 --- a/graph_net/agent/sample_verifier/forward_verifier.py +++ b/graph_net/agent/sample_verifier/forward_verifier.py @@ -7,7 +7,10 @@ from graph_net.agent.sample_verifier.base import BaseSampleVerifier from graph_net.agent.sample_verifier.basic_sample_verifier import BasicSampleVerifier -from graph_net.agent.utils.exceptions import SampleVerificationError +from graph_net.agent.utils.exceptions import ( + GraphExtractionErrorCategory, + SampleVerificationError, +) # Inline eager runner — executed in a subprocess to isolate CUDA state. # Loads GraphModule from model.py, reconstructs tensors from weight_meta.py, @@ -85,7 +88,7 @@ def verify(self, sample_dir: Path) -> bool: except Exception as e: raise SampleVerificationError( f"Forward verification failed: {e}", - error_category="forward_verify_failed", + error_category=GraphExtractionErrorCategory.FORWARD_VERIFY_FAILED, ) from e def _run_forward(self, model_path: Path) -> tuple[bool, bool]: diff --git a/graph_net/agent/utils/error_classifier.py b/graph_net/agent/utils/error_classifier.py index 865433fae..400f7548e 100644 --- a/graph_net/agent/utils/error_classifier.py +++ b/graph_net/agent/utils/error_classifier.py @@ -7,38 +7,9 @@ from collections import defaultdict from dataclasses import dataclass -from enum import Enum from typing import Dict, List, Optional - -class GraphExtractionErrorCategory(str, Enum): - """Known categories of extraction failure.""" - - # Pre-extraction failures - MODEL_NOT_FOUND = "model_not_found" - MODEL_FORBIDDEN = "model_forbidden" - MODEL_DOWNLOAD_ERROR = "model_download_error" - - # Script generation / analysis failures - ANALYSIS_ERROR = "analysis_error" - CODE_GEN_ERROR = "code_gen_error" - - # Script execution failures - SCRIPT_EXECUTION_FAILED = "script_execution_failed" - SCRIPT_TIMEOUT = "script_timeout" - OUTPUT_DIR_NOT_FOUND = "output_dir_not_found" - - # LLM retry failures - LLM_TIMEOUT = "llm_timeout" - LLM_EXIT_ERROR = "llm_exit_error" - - # Post-extraction failures - SAMPLE_INCOMPLETE = "sample_incomplete" # missing files / bad JSON - FORWARD_VERIFY_FAILED = "forward_verify_failed" # eager forward pass failed - VERIFICATION_TIMEOUT = "verification_timeout" - - # Catch-all - UNKNOWN = "unknown" +from graph_net.agent.utils.exceptions import GraphExtractionErrorCategory @dataclass diff --git a/graph_net/agent/utils/exceptions.py b/graph_net/agent/utils/exceptions.py index 60376592b..d274338f7 100644 --- a/graph_net/agent/utils/exceptions.py +++ b/graph_net/agent/utils/exceptions.py @@ -1,12 +1,48 @@ """Custom exception classes for Agent. -Each exception may carry an `error_category` string so that +Each exception may carry an `error_category` so that error_classifier.py can route without string matching. """ +from enum import Enum from typing import Optional +class GraphExtractionErrorCategory(str, Enum): + """Known categories of extraction failure.""" + + # Pre-extraction failures + MODEL_NOT_FOUND = "model_not_found" + MODEL_FORBIDDEN = "model_forbidden" + MODEL_DOWNLOAD_ERROR = "model_download_error" + + # Config / metadata analysis failures + CONFIG_NOT_FOUND = "config_not_found" + CONFIG_PARSE_ERROR = "config_parse_error" + METADATA_ANALYSIS_FAILED = "metadata_analysis_failed" + + # Script generation failures + CODE_GEN_ERROR = "code_gen_error" + + # Script execution failures + SCRIPT_EXECUTION_FAILED = "script_execution_failed" + SCRIPT_TIMEOUT = "script_timeout" + OUTPUT_DIR_NOT_FOUND = "output_dir_not_found" + + # LLM retry failures + LLM_TIMEOUT = "llm_timeout" + LLM_EXIT_ERROR = "llm_exit_error" + + # Post-extraction failures + SAMPLE_INCOMPLETE = "sample_incomplete" + FORWARD_VERIFY_FAILED = "forward_verify_failed" + VERIFICATION_TIMEOUT = "verification_timeout" + VERIFICATION_FAILED = "verification_failed" + + # Catch-all + UNKNOWN = "unknown" + + class AgentError(Exception): """Base exception for Agent errors. @@ -14,12 +50,12 @@ class AgentError(Exception): need to repeat the category when the default is sufficient. """ - default_category: Optional[str] = None + default_category: Optional[GraphExtractionErrorCategory] = None def __init__( self, message: str, - error_category: Optional[str] = None, + error_category: Optional[GraphExtractionErrorCategory] = None, ): super().__init__(message) self.error_category = error_category or self.default_category @@ -28,12 +64,12 @@ def __init__( class ModelFetchError(AgentError): """Raised when model fetching fails. - Default: model_download_error. - Raise-sites should override for 404 (model_not_found) - or 403 (model_forbidden). + Default: MODEL_DOWNLOAD_ERROR. + Raise-sites should override for 404 (MODEL_NOT_FOUND) + or 403 (MODEL_FORBIDDEN). """ - default_category = "model_download_error" + default_category = GraphExtractionErrorCategory.MODEL_DOWNLOAD_ERROR class MetadataAnalysisError(AgentError): @@ -42,37 +78,37 @@ class MetadataAnalysisError(AgentError): Covers config missing, JSON parse errors, and unsupported architectures. """ - default_category = "metadata_analysis_error" + default_category = GraphExtractionErrorCategory.METADATA_ANALYSIS_FAILED class CodeGenerationError(AgentError): """Raised when code generation fails. - Default: code_gen_error. + Default: CODE_GEN_ERROR. Raise-sites should override for LLM-specific failures - (llm_timeout / llm_exit_error). + (LLM_TIMEOUT / LLM_EXIT_ERROR). """ - default_category = "code_gen_error" + default_category = GraphExtractionErrorCategory.CODE_GEN_ERROR class GraphExtractionError(AgentError): """Raised when graph extraction fails. - Default: unknown — raise-sites MUST override with one of: - - script_execution_failed - - script_timeout - - output_dir_not_found + Default: UNKNOWN — raise-sites MUST override with one of: + - SCRIPT_EXECUTION_FAILED + - SCRIPT_TIMEOUT + - OUTPUT_DIR_NOT_FOUND """ - default_category = "unknown" + default_category = GraphExtractionErrorCategory.UNKNOWN class SampleVerificationError(AgentError): """Raised when sample verification fails. - Default: verification_failed. - Raise-sites may override with verification_timeout. + Default: VERIFICATION_FAILED. + Raise-sites may override with VERIFICATION_TIMEOUT. """ - default_category = "verification_failed" + default_category = GraphExtractionErrorCategory.VERIFICATION_FAILED From 26bed31add7050f52f61296f5a8192f4c25b3bcf Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 18 May 2026 19:11:17 +0800 Subject: [PATCH 09/19] refactor(agent): replace markdown_report with plain list output - Rename markdown_report() to report_lines() and return List[str] instead of a markdown-formatted string. - Remove markdown syntax (#, |, ---, etc.) for simpler consumption. Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/utils/error_classifier.py | 30 +++++++++++------------ 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/graph_net/agent/utils/error_classifier.py b/graph_net/agent/utils/error_classifier.py index 400f7548e..9b5b5b62d 100644 --- a/graph_net/agent/utils/error_classifier.py +++ b/graph_net/agent/utils/error_classifier.py @@ -93,9 +93,12 @@ def summary(self) -> Dict[str, object]: "models_per_category": dict(per_category), } - def markdown_report(self) -> str: - lines = ["# Extraction Error Report", ""] - lines.append(f"**Total errors**: {len(self.records)}") + def report_lines(self) -> List[str]: + """Plain-text report as a list of lines (no markdown).""" + lines: List[str] = [] + lines.append("Extraction Error Report") + lines.append("") + lines.append(f"Total errors: {len(self.records)}") lines.append("") counts: Dict[GraphExtractionErrorCategory, int] = defaultdict(int) @@ -106,26 +109,21 @@ def markdown_report(self) -> str: counts[rec.category] += 1 per_cat[rec.category].append(rec) - lines.append("## Summary by Category") - lines.append("") - lines.append("| Category | Count |") - lines.append("|----------|-------|") + lines.append("Summary by Category:") for cat, cnt in sorted(counts.items(), key=lambda x: -x[1]): - lines.append(f"| {cat.value} | {cnt} |") + lines.append(f" {cat.value}: {cnt}") lines.append("") - lines.append("## Details") - lines.append("") + lines.append("Details:") for cat, recs in sorted(per_cat.items(), key=lambda x: -len(x[1])): - lines.append(f"### {cat.value} ({len(recs)})") - lines.append("") + lines.append(f" {cat.value} ({len(recs)}):") for rec in recs[:10]: msg = ( rec.message[:120] + "..." if len(rec.message) > 120 else rec.message ) - lines.append(f"- `{rec.model_id}`: {msg}") + lines.append(f" - {rec.model_id}: {msg}") if len(recs) > 10: - lines.append(f"- ... and {len(recs) - 10} more") - lines.append("") + lines.append(f" - ... and {len(recs) - 10} more") + lines.append("") - return "\n".join(lines) + return lines From 9ba299fe89f117f461e7991e262b4727c1303958 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 18 May 2026 20:42:53 +0800 Subject: [PATCH 10/19] fix(agent): change llm_timeout default back to 600 Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/graph_net_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py index da43ae6ab..6537d1f33 100644 --- a/graph_net/agent/graph_net_agent.py +++ b/graph_net/agent/graph_net_agent.py @@ -45,7 +45,7 @@ def __init__( llm_retry: bool = True, extract_timeout: Optional[int] = None, verify_timeout: Optional[int] = None, - llm_timeout: int = 900, + llm_timeout: int = 600, ): """ Initialize GraphNet Agent From b46b9b48bdcec51de587e9dfeddcca97d6ac1008 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 18 May 2026 20:51:29 +0800 Subject: [PATCH 11/19] feat(agent): expose error_category to parallel_extract results - Include ModelFetchError in explicit except clause so it returns EXTRACT_FAILED with proper classification instead of ERROR. - Worker reads agent.error_classifier.get_record(model_id) after extract_sample() and forwards error_category + error_message in result_dict so the main process can see which stage failed. - Fallback: if extract_sample itself raises unexpectedly, read error_category attribute from the raw exception. Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/graph_net_agent.py | 8 +++++++- graph_net/agent/parallel_extract.py | 8 ++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py index 6537d1f33..dfe8b3a10 100644 --- a/graph_net/agent/graph_net_agent.py +++ b/graph_net/agent/graph_net_agent.py @@ -19,6 +19,7 @@ MetadataAnalysisError, CodeGenerationError, GraphExtractionError, + ModelFetchError, SampleVerificationError, ) from graph_net.agent.utils.logger import setup_logger @@ -166,7 +167,12 @@ def extract_sample(self, model_id: str) -> ExtractionStatus: self.logger.error(f"Extraction failed for {model_id}: {e}") self.error_classifier.classify_and_record(model_id, e) return ExtractionStatus.VERIFY_FAILED - except (MetadataAnalysisError, CodeGenerationError, GraphExtractionError) as e: + except ( + ModelFetchError, + MetadataAnalysisError, + CodeGenerationError, + GraphExtractionError, + ) as e: self.logger.error(f"Extraction failed for {model_id}: {e}") self.error_classifier.classify_and_record(model_id, e) return ExtractionStatus.EXTRACT_FAILED diff --git a/graph_net/agent/parallel_extract.py b/graph_net/agent/parallel_extract.py index 3b29c4344..ebd5f547e 100644 --- a/graph_net/agent/parallel_extract.py +++ b/graph_net/agent/parallel_extract.py @@ -248,6 +248,11 @@ def _orphan_watcher(): result_dict["success"] = ok result_dict["status"] = status.value result_dict["timeout_success"] = timeout_success + # Expose error category so the main process can decide policy + rec = agent.error_classifier.get_record(model_id) + if rec is not None: + result_dict["error_category"] = rec.category.value + result_dict["error_message"] = rec.message except Exception as e: elapsed = time.time() - t0 print(f"{prefix} ERROR {model_id}: {e} ({elapsed:.1f}s)", flush=True) @@ -255,6 +260,9 @@ def _orphan_watcher(): result_dict["status"] = ExtractionStatus.ERROR.value result_dict["error"] = str(e) result_dict["timeout_success"] = False + raw_cat = getattr(e, "error_category", None) + if raw_cat is not None: + result_dict["error_category"] = str(raw_cat) result_dict["elapsed"] = round(elapsed, 2) result_dict["timestamp"] = datetime.now().isoformat() From ca59ba9ee1187485681e42ebff06f35f8bd913e2 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 19 May 2026 10:43:34 +0800 Subject: [PATCH 12/19] docs(agent): translate Chinese comments to English Keep source comments and docstrings in English while preserving existing LLM prompt text. Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/code_generator/llm_code_fixer.py | 4 ++-- .../graph_extractor/subprocess_graph_extractor.py | 10 +++++----- graph_net/agent/graph_net_agent.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/graph_net/agent/code_generator/llm_code_fixer.py b/graph_net/agent/code_generator/llm_code_fixer.py index 2cd124f9d..0fda02552 100644 --- a/graph_net/agent/code_generator/llm_code_fixer.py +++ b/graph_net/agent/code_generator/llm_code_fixer.py @@ -349,7 +349,7 @@ def _read_config(model_dir: Path) -> str: @staticmethod def _extract_key_fields(model_dir: Path) -> str: - """从 config.json 提取对输入构造最关键的字段,方便 LLM 直接读取。""" + """Extract the most important input-construction fields from config.json for the LLM.""" config_path = model_dir / "config.json" if not config_path.exists(): return "{}" @@ -395,7 +395,7 @@ def _extract_key_fields(model_dir: Path) -> str: "sample_rate", ] result = {k: cfg[k] for k in keys if k in cfg} - # 对嵌套 config 只取关键字段 + # Keep only key fields from nested configs. for nested in ("audio_config", "vision_config", "text_config"): if isinstance(result.get(nested), dict): result[nested] = { diff --git a/graph_net/agent/graph_extractor/subprocess_graph_extractor.py b/graph_net/agent/graph_extractor/subprocess_graph_extractor.py index e5d4506ff..b579d7140 100644 --- a/graph_net/agent/graph_extractor/subprocess_graph_extractor.py +++ b/graph_net/agent/graph_extractor/subprocess_graph_extractor.py @@ -124,7 +124,7 @@ def extract(self, code_path: Path, model_id: str) -> Path: stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - # 用新进程组,方便整组 kill(避免遗留孙进程占显存) + # Start a new process group so the whole group can be killed, avoiding orphaned child processes holding GPU memory. start_new_session=True, ) pgid = os.getpgid(proc.pid) @@ -132,12 +132,12 @@ def extract(self, code_path: Path, model_id: str) -> Path: try: stdout, stderr = proc.communicate(timeout=self.timeout) except subprocess.TimeoutExpired: - # 先 kill 整个进程组,确保 GPU 显存释放 + # Kill the entire process group first to ensure GPU memory is released. try: os.killpg(pgid, signal.SIGKILL) except ProcessLookupError: proc.kill() - proc.communicate() # 回收僵尸进程 + proc.communicate() # Reap the zombie process raise GraphExtractionError( f"Script execution timed out after {self.timeout} seconds", error_category=GraphExtractionErrorCategory.SCRIPT_TIMEOUT, @@ -267,8 +267,8 @@ def _find_hash_named_dir(self, workspace_path: Path) -> Optional[Path]: def _is_valid_sample_dir(self, dir_path: Path) -> bool: """Check if a directory is a valid sample directory""" required_files = ["model.py", "graph_net.json"] - # 单图:根目录下有文件 + # Single graph: files exist in the root directory. if all((dir_path / f).exists() for f in required_files): return True - # 多子图:subgraph_* 子目录下有文件 + # Multiple subgraphs: files exist under subgraph_* directories. return any(dir_path.glob("subgraph_*/model.py")) diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py index dfe8b3a10..5395b2e27 100644 --- a/graph_net/agent/graph_net_agent.py +++ b/graph_net/agent/graph_net_agent.py @@ -237,7 +237,7 @@ def _llm_retry( return sample_dir except GraphExtractionError as retry_err: err = retry_err - current_script = fixed_path # 第二次把上一次修复的脚本+新报错再喂给 LLM + current_script = fixed_path # On the second attempt, feed the previous fixed script and new error back to the LLM raise err @@ -303,7 +303,7 @@ def _extract_graph(self, script_path: Path, model_id: str) -> Path: return sample_dir def _fix_model_name(self, sample_dir: Path, model_id: str) -> None: - """将 graph_net.json 中的 model_name 修正为原始 HuggingFace model_id(org/model)""" + """Update model_name in graph_net.json to the original HuggingFace model_id (org/model).""" for json_path in [ sample_dir / "graph_net.json", *sample_dir.glob("subgraph_*/graph_net.json"), From 99d10891717acceb723d9819425dcce073161aed Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 19 May 2026 11:05:00 +0800 Subject: [PATCH 13/19] fix(agent): skip non-fixable LLM retry errors Stop retrying LLM script fixes when a repaired script fails with a category that cannot be addressed by rewriting the script. Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/graph_net_agent.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py index 5395b2e27..1b1812d8b 100644 --- a/graph_net/agent/graph_net_agent.py +++ b/graph_net/agent/graph_net_agent.py @@ -236,6 +236,12 @@ def _llm_retry( sample_dir = self._extract_graph(fixed_path, model_id) return sample_dir except GraphExtractionError as retry_err: + if not self._is_llm_fixable_error(retry_err): + self.logger.warning( + "LLM-fixed script failed with non-fixable error, " + f"skipping remaining retries: {retry_err}" + ) + raise err = retry_err current_script = fixed_path # On the second attempt, feed the previous fixed script and new error back to the LLM From 2722a9446af16df41bfcf7c6b9bf7c819fe5e3f9 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 19 May 2026 11:06:51 +0800 Subject: [PATCH 14/19] fix(agent): restore CPU verify timeout default Set CPU forward verification timeout back to 600 seconds and keep CLI help text in sync. Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/parallel_extract.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/graph_net/agent/parallel_extract.py b/graph_net/agent/parallel_extract.py index ebd5f547e..0f9316238 100644 --- a/graph_net/agent/parallel_extract.py +++ b/graph_net/agent/parallel_extract.py @@ -407,7 +407,7 @@ def _parse_args() -> argparse.Namespace: "--verify-timeout", type=int, default=None, - help="Timeout in seconds for forward verification (default: 300 on GPU, 1200 on CPU)", + help="Timeout in seconds for forward verification (default: 300 on GPU, 600 on CPU)", ) parser.add_argument( "--use-llm", @@ -450,9 +450,7 @@ def _resolve_config(args: argparse.Namespace): extract_timeout = ( args.extract_timeout if args.extract_timeout is not None else 2000 ) - verify_timeout = ( - args.verify_timeout if args.verify_timeout is not None else 1200 - ) + verify_timeout = args.verify_timeout if args.verify_timeout is not None else 600 return workspace, gpus, num_workers, extract_timeout, verify_timeout From d93dfdbaf16ea24d2ce81ac518dd2446ffcfb375 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 19 May 2026 11:24:43 +0800 Subject: [PATCH 15/19] feat(agent): precheck HuggingFace model accessibility Use HuggingFace API metadata lookup before download to skip clearly missing or forbidden model repos while preserving download retries for transient failures. Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/graph_net_agent.py | 13 ++++ .../model_fetcher/huggingface_fetcher.py | 78 +++++++++++++++---- 2 files changed, 77 insertions(+), 14 deletions(-) diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py index 1b1812d8b..f95cde938 100644 --- a/graph_net/agent/graph_net_agent.py +++ b/graph_net/agent/graph_net_agent.py @@ -249,6 +249,19 @@ def _llm_retry( def _fetch_model(self, model_id: str) -> Path: """Download model from HuggingFace Hub""" + self.logger.info(f"Checking model repo accessibility: {model_id}") + try: + self.model_fetcher.check_accessible(model_id) + except ModelFetchError as e: + if e.error_category in ( + GraphExtractionErrorCategory.MODEL_NOT_FOUND, + GraphExtractionErrorCategory.MODEL_FORBIDDEN, + ): + raise + self.logger.warning( + f"Model repo precheck failed for {model_id}, continuing to download: {e}" + ) + self.logger.info(f"Fetching model: {model_id}") model_dir = self.model_fetcher.download(model_id) self.logger.info(f"Model downloaded to: {model_dir}") diff --git a/graph_net/agent/model_fetcher/huggingface_fetcher.py b/graph_net/agent/model_fetcher/huggingface_fetcher.py index f463858f1..e524e26c1 100644 --- a/graph_net/agent/model_fetcher/huggingface_fetcher.py +++ b/graph_net/agent/model_fetcher/huggingface_fetcher.py @@ -6,8 +6,9 @@ from typing import Optional try: - from huggingface_hub import snapshot_download + from huggingface_hub import HfApi, snapshot_download except ImportError: + HfApi = None snapshot_download = None from graph_net.agent.model_fetcher.base import BaseModelFetcher @@ -36,7 +37,18 @@ _RETRYABLE_ERRORS = _RETRYABLE_ERRORS + (LocalEntryNotFoundError,) except ImportError: - pass + LocalEntryNotFoundError = None + +try: + from huggingface_hub.errors import ( + GatedRepoError, + HfHubHTTPError, + RepositoryNotFoundError, + ) +except ImportError: + GatedRepoError = None + HfHubHTTPError = None + RepositoryNotFoundError = None class HFFetcher(BaseModelFetcher): @@ -70,6 +82,54 @@ def __init__( # Resolve endpoint: explicit param > env var self.endpoint = endpoint or os.environ.get("HF_ENDPOINT") + def check_accessible(self, model_id: str) -> None: + """Check whether a HuggingFace model repo is reachable without downloading files.""" + if HfApi is None: + raise ModelFetchError( + "huggingface_hub is not installed. " + "Please install it with: pip install huggingface_hub" + ) + + try: + if self.endpoint: + os.environ["HF_ENDPOINT"] = self.endpoint + api = HfApi(endpoint=self.endpoint) + api.model_info( + repo_id=model_id, + repo_type="model", + token=self.token, + files_metadata=False, + ) + except Exception as e: + error_category = self._classify_hf_error(e) + raise ModelFetchError( + f"Model repo is not accessible for {model_id}: {e}", + error_category=error_category, + ) from e + + @staticmethod + def _classify_hf_error(error: Exception) -> GraphExtractionErrorCategory: + """Classify HuggingFace API/download errors into extraction categories.""" + if RepositoryNotFoundError is not None and isinstance( + error, RepositoryNotFoundError + ): + return GraphExtractionErrorCategory.MODEL_NOT_FOUND + if GatedRepoError is not None and isinstance(error, GatedRepoError): + return GraphExtractionErrorCategory.MODEL_FORBIDDEN + if HfHubHTTPError is not None and isinstance(error, HfHubHTTPError): + status_code = getattr(getattr(error, "response", None), "status_code", None) + if status_code == 404: + return GraphExtractionErrorCategory.MODEL_NOT_FOUND + if status_code in (401, 403): + return GraphExtractionErrorCategory.MODEL_FORBIDDEN + + err_text = str(error) + if "404 Client Error" in err_text: + return GraphExtractionErrorCategory.MODEL_NOT_FOUND + if "401 Client Error" in err_text or "403 Client Error" in err_text: + return GraphExtractionErrorCategory.MODEL_FORBIDDEN + return GraphExtractionErrorCategory.MODEL_DOWNLOAD_ERROR + def download(self, model_id: str) -> Path: """ Download model from HuggingFace Hub with retry on network errors. @@ -145,19 +205,9 @@ def download(self, model_id: str) -> Path: f"Failed to download model {model_id} after {self.max_retries} retries: {e}" ) from e except Exception as e: - err_text = str(e) - if "404 Client Error" in err_text: - raise ModelFetchError( - f"Failed to download model {model_id}: {e}", - error_category=GraphExtractionErrorCategory.MODEL_NOT_FOUND, - ) from e - if "403 Client Error" in err_text: - raise ModelFetchError( - f"Failed to download model {model_id}: {e}", - error_category=GraphExtractionErrorCategory.MODEL_FORBIDDEN, - ) from e raise ModelFetchError( - f"Failed to download model {model_id}: {e}" + f"Failed to download model {model_id}: {e}", + error_category=self._classify_hf_error(e), ) from e # Should not reach here, but just in case From 95e4163b0004028ec176270c3905c6eee94bdc31 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 19 May 2026 11:34:07 +0800 Subject: [PATCH 16/19] refactor(agent): constrain LLM fixer output Ask the LLM script fixer to emit a minimal complete run_model.py without comments, fallback logic, helpers, or unrelated validation. Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/code_generator/llm_code_fixer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/graph_net/agent/code_generator/llm_code_fixer.py b/graph_net/agent/code_generator/llm_code_fixer.py index 0fda02552..2e0fcef2c 100644 --- a/graph_net/agent/code_generator/llm_code_fixer.py +++ b/graph_net/agent/code_generator/llm_code_fixer.py @@ -24,7 +24,7 @@ _SYSTEM_PROMPT = """\ 你是 PyTorch / HuggingFace 模型计算图抽取专家。 -任务:修复一段失败的图抽取脚本,输出完整、可直接运行的 Python 脚本。 +任务:修复一段失败的图抽取脚本,输出完整、可直接运行但最小化的 Python 脚本。 ## 【硬性约束 - 违反即输出无效】 1. 抽取调用格式固定为: @@ -36,7 +36,8 @@ 3. 设备选择固定写法:device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 4. 只允许使用 torch、transformers、graph_net 及 Python 标准库(os/pathlib/json 等) 5. 只输出代码块,格式:```python\\n...代码...\\n```,禁止输出任何说明文字 -6. 脚本必须简洁:禁止添加未要求的错误处理、fallback 逻辑、文件系统遍历或冗余注释。只修复导致报错的输入构造或调用方式,保持行数与原始脚本接近 +6. 必须输出完整但最小化的脚本,只保留:必要 import、模型/config 加载、输入 tensor 构造、graph_net.torch.extract(...)、一次 forward 调用 +7. 禁止添加注释、helper 函数、错误处理、try/except、fallback 逻辑、重试逻辑、文件系统遍历、额外校验或无关打印。只修复导致报错的输入构造或调用方式,保持行数尽可能少 ## 【输入构造规范 - 按 model_type 选择对应方案】 @@ -284,8 +285,9 @@ def _build_prompt( f"### 失败脚本\n```python\n{compact_script}\n```\n\n" f"### 错误信息\n```\n{truncated_error}\n```\n\n" f"### 输出要求\n" - f"直接输出修复后的完整脚本,用 ```python\\n...\\n``` 包裹,不附加任何说明。" - f"脚本必须简洁,禁止添加未要求的 fallback 或文件遍历代码:" + f"直接输出修复后的完整最小脚本,用 ```python\\n...\\n``` 包裹,不附加任何说明。" + f"只保留必要 import、模型/config 加载、输入 tensor 构造、extract 调用和一次 forward。" + f"禁止注释、helper、try/except、fallback、重试、文件遍历、额外校验或无关打印。" ) def _call_ducc(self, prompt: str) -> str: From 6744a79110e08a74cd40e329fccc825d65b321ca Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 19 May 2026 11:42:55 +0800 Subject: [PATCH 17/19] fix(agent): restore LLM timeout default Set the default LLM script-fix timeout back to 360 seconds and update the parameter documentation. Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/graph_net_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py index f95cde938..0cf971807 100644 --- a/graph_net/agent/graph_net_agent.py +++ b/graph_net/agent/graph_net_agent.py @@ -46,7 +46,7 @@ def __init__( llm_retry: bool = True, extract_timeout: Optional[int] = None, verify_timeout: Optional[int] = None, - llm_timeout: int = 600, + llm_timeout: int = 360, ): """ Initialize GraphNet Agent @@ -61,7 +61,7 @@ def __init__( (default None -> 1000s). verify_timeout: Timeout in seconds for forward verification subprocess (default None -> 300s). - llm_timeout: Timeout in seconds for LLM script fix (default: 600). + llm_timeout: Timeout in seconds for LLM script fix (default: 360). """ if workspace is None: workspace = os.environ.get( From 307ca47e342878f9c637e2383edc17b677e7bf46 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 19 May 2026 13:45:06 +0800 Subject: [PATCH 18/19] fix(agent): support current HuggingFace model_info API Remove the unsupported repo_type argument from the model accessibility precheck so it works with the installed huggingface_hub version. Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/model_fetcher/huggingface_fetcher.py | 1 - 1 file changed, 1 deletion(-) diff --git a/graph_net/agent/model_fetcher/huggingface_fetcher.py b/graph_net/agent/model_fetcher/huggingface_fetcher.py index e524e26c1..b53b48452 100644 --- a/graph_net/agent/model_fetcher/huggingface_fetcher.py +++ b/graph_net/agent/model_fetcher/huggingface_fetcher.py @@ -96,7 +96,6 @@ def check_accessible(self, model_id: str) -> None: api = HfApi(endpoint=self.endpoint) api.model_info( repo_id=model_id, - repo_type="model", token=self.token, files_metadata=False, ) From 25a0746792e2c7f2dea83975ba189504e464339f Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 19 May 2026 14:22:16 +0800 Subject: [PATCH 19/19] refactor(agent): generate minimal extraction scripts Simplify generated run_model.py files and structure inputs as a dictionary literal to reduce retry prompt size. Co-Authored-By: Claude Opus 4.6 --- .../code_generator/template_generator.py | 86 +++++-------------- 1 file changed, 22 insertions(+), 64 deletions(-) diff --git a/graph_net/agent/code_generator/template_generator.py b/graph_net/agent/code_generator/template_generator.py index c3832c274..e2e051b31 100644 --- a/graph_net/agent/code_generator/template_generator.py +++ b/graph_net/agent/code_generator/template_generator.py @@ -79,34 +79,15 @@ def _generate_standard_code( short_name = self._model_short_name(model_metadata.model_id) code = f"""import torch -try: - from transformers import AutoModel -except ImportError: - raise ImportError("transformers is required. Install with: pip install transformers") - import graph_net -def main(): - # Load model -{self._indent(load_code, 4)} - - # Prepare inputs -{self._indent(input_code, 4)} - - # Extract graph - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = model.to(device).eval() - - # Move inputs to same device as model - inputs = {{k: v.to(device) for k, v in inputs.items()}} +{load_code} - wrapped = graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval() +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = model.to(device).eval() +{input_code} - with torch.no_grad(): - wrapped(**inputs) - -if __name__ == "__main__": - main() +graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval()(**inputs) """ return code @@ -118,37 +99,19 @@ def _generate_diffusion_code( input_code = self._generate_input_code(model_metadata) short_name = self._model_short_name(model_metadata.model_id) - # Diffusion model forward takes positional args, not **inputs dict code = f"""import torch -try: - from diffusers import UNet2DConditionModel -except ImportError: - raise ImportError("diffusers is required. Install with: pip install diffusers") - import graph_net -def main(): - # Load model -{self._indent(load_code, 4)} - - # Prepare inputs -{self._indent(input_code, 4)} +{load_code} - # Extract graph - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = model.to(device).eval() +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = model.to(device).eval() +{input_code} - sample = inputs["sample"].to(device) - timestep = inputs["timestep"].to(device) - encoder_hidden_states = inputs["encoder_hidden_states"].to(device) - - wrapped = graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval() - - with torch.no_grad(): - wrapped(sample, timestep, encoder_hidden_states) - -if __name__ == "__main__": - main() +sample = inputs["sample"] +timestep = inputs["timestep"] +encoder_hidden_states = inputs["encoder_hidden_states"] +graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval()(sample, timestep, encoder_hidden_states) """ return code @@ -199,7 +162,7 @@ def _generate_model_loader( def _generate_input_code(self, model_metadata: ModelMetadata) -> str: """Generate input tensor construction code based on model metadata""" - lines = ["inputs = {}"] + lines = ["inputs = {"] for name, shape in model_metadata.input_shapes.items(): dtype = model_metadata.input_dtypes.get(name, "int64") @@ -209,18 +172,18 @@ def _generate_input_code(self, model_metadata: ModelMetadata) -> str: if dtype == "int64": if "input_ids" in name.lower() or "decoder_input_ids" in name.lower(): safe_vocab_size = self._calculate_safe_vocab_size(model_metadata) - lines.append( - f'inputs["{name}"] = torch.randint(0, {safe_vocab_size}, {shape_tuple}, dtype={torch_dtype})' + value = ( + f"torch.randint(0, {safe_vocab_size}, {shape_tuple}, " + f"dtype={torch_dtype}).to(device)" ) else: - lines.append( - f'inputs["{name}"] = torch.ones({shape_tuple}, dtype={torch_dtype})' - ) + value = f"torch.ones({shape_tuple}, dtype={torch_dtype}).to(device)" else: - lines.append( - f'inputs["{name}"] = torch.randn({shape_tuple}, dtype={torch_dtype})' - ) + value = f"torch.randn({shape_tuple}, dtype={torch_dtype}).to(device)" + + lines.append(f' "{name}": {value},') + lines.append("}") return "\n".join(lines) def _get_torch_dtype(self, dtype: str) -> str: @@ -261,8 +224,3 @@ def _is_large_vocab_model_type(self, model_type: str) -> bool: or "xlm_roberta" in model_type or "roberta" in model_type ) - - def _indent(self, text: str, spaces: int) -> str: - """Indent text by specified spaces""" - indent = " " * spaces - return "\n".join(indent + line for line in text.split("\n"))