From 116fff767d23df678db54e8f020995e7abe330e9 Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 00:25:43 +0200 Subject: [PATCH 01/21] refactor(install): split install.py into submodules (#1078 Commit 1a) Split commands/install.py (2085 -> 765 lines) into six install/ submodules: pkg_resolution, apm_packages, install_cmd_phases, mcp_handler, manifest_rollback, cli_context. Reduces statements/branches/complexity and file-length below the issue #1078 Stage 2 targets while preserving the public command surface and @patch test seams via module-level re-exports. - RULE A: moved helpers re-exported from commands.install so @patch('apm_cli.commands.install.*') targets keep resolving. - RULE B: moved call sites route patched symbols through the commands.install module object (_m.) so monkeypatches intercept across module boundaries. - Removed the duplicated inline local-bundle block; wired the extracted _try_local_bundle_install helper (R0801 duplication guardrail green). Behaviour-preserving: 16645 unit + acceptance and 927 install integration tests pass; ruff (std + final thresholds), ruff format, and R0801 all green. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/commands/install.py | 1570 ++------------------- src/apm_cli/install/apm_packages.py | 458 ++++++ src/apm_cli/install/cli_context.py | 56 + src/apm_cli/install/install_cmd_phases.py | 208 +++ src/apm_cli/install/manifest_rollback.py | 59 + src/apm_cli/install/mcp_handler.py | 134 ++ src/apm_cli/install/pkg_resolution.py | 486 +++++++ 7 files changed, 1526 insertions(+), 1445 deletions(-) create mode 100644 src/apm_cli/install/apm_packages.py create mode 100644 src/apm_cli/install/cli_context.py create mode 100644 src/apm_cli/install/install_cmd_phases.py create mode 100644 src/apm_cli/install/manifest_rollback.py create mode 100644 src/apm_cli/install/mcp_handler.py create mode 100644 src/apm_cli/install/pkg_resolution.py diff --git a/src/apm_cli/commands/install.py b/src/apm_cli/commands/install.py index fbeca6a6a..54738fafe 100644 --- a/src/apm_cli/commands/install.py +++ b/src/apm_cli/commands/install.py @@ -1,184 +1,113 @@ """APM install command and dependency installation engine.""" -import builtins +import builtins # noqa: I001 import contextlib -import dataclasses import os import sys import time -from collections.abc import Callable from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional import click -from apm_cli.install.artifactory_resolver import _resolve_artifactory_boundary from apm_cli.install.errors import ( AuthenticationError, DirectDependencyError, - FrozenInstallError, - PolicyViolationError, ) -from apm_cli.install.gitlab_resolver import _try_resolve_gitlab_direct_shorthand +from apm_cli.install.gitlab_resolver import _try_resolve_gitlab_direct_shorthand # noqa: F401 -if TYPE_CHECKING: - from apm_cli.install.plan import UpdatePlan - -# Re-export the pre-deploy security scan so that bare-name call sites inside -# this module and ``tests/unit/test_install_scanning.py``'s direct import -# (``from apm_cli.commands.install import 
_pre_deploy_security_scan``) keep -# working without modification. +# Re-export _pre_deploy_security_scan for bare-name call sites + test imports. from apm_cli.install.helpers.security_scan import _pre_deploy_security_scan # noqa: F401 from apm_cli.install.insecure_policy import ( InsecureDependencyPolicyError, _allow_insecure_host_callback, - _check_insecure_dependencies, - _collect_insecure_dependency_infos, # noqa: F401 -- re-exported; test_architecture_invariants checks importability - _format_insecure_dependency_requirements, - _format_insecure_dependency_warning, # noqa: F401 -- re-exported; test_architecture_invariants checks importability - _get_insecure_dependency_url, - _guard_transitive_insecure_dependencies, # noqa: F401 -- re-exported; test_architecture_invariants checks importability - _InsecureDependencyInfo, # noqa: F401 -- re-exported; test_architecture_invariants checks importability + _check_insecure_dependencies, # noqa: F401 -- RULE B: apm_packages.py uses _m._check_insecure_dependencies + _collect_insecure_dependency_infos, # noqa: F401 -- test_architecture_invariants checks importability + _format_insecure_dependency_warning, # noqa: F401 -- test_architecture_invariants checks importability + _guard_transitive_insecure_dependencies, # noqa: F401 -- test_architecture_invariants checks importability + _InsecureDependencyInfo, # noqa: F401 -- test_architecture_invariants checks importability ) -# Re-export MCP add/build helpers under their underscore-prefixed legacy -# names. Aliases live in mcp/writer.py and mcp/entry.py respectively. +# Re-export MCP add/build helpers under their underscore-prefixed legacy names. 
from apm_cli.install.mcp.entry import _build_mcp_entry # noqa: F401 from apm_cli.install.mcp.writer import _add_mcp_to_apm_yml # noqa: F401 from apm_cli.install.package_resolution import ( GIT_PARENT_USER_SCOPE_ERROR, - dependency_reference_to_yaml_entry, - persist_dependency_list_if_changed, - resolve_parsed_dependency_reference, - update_existing_dependency_entry_if_needed, - user_scope_rejection_reason, + persist_dependency_list_if_changed, # noqa: F401 -- RULE A: re-exported for @patch('apm_cli.commands.install.*') + resolve_parsed_dependency_reference, # noqa: F401 -- RULE A: re-exported for @patch('apm_cli.commands.install.*') + user_scope_rejection_reason, # noqa: F401 -- RULE A: re-exported for @patch('apm_cli.commands.install.*') ) from apm_cli.install.package_selection import only_packages_from_validation -# Re-export local-content leaf helpers so that callers inside this module -# (e.g. _install_apm_dependencies) and any future test patches against -# "apm_cli.commands.install._copy_local_package" keep working. +# Re-export helpers for @patch compatibility and test importability checks. from apm_cli.install.phases.local_content import ( - _copy_local_package, # noqa: F401 -- re-exported; test_architecture_invariants checks importability - _has_local_apm_content, # noqa: F401 -- re-exported; test_architecture_invariants checks importability - _project_has_root_primitives, + _copy_local_package, # noqa: F401 + _has_local_apm_content, # noqa: F401 + _project_has_root_primitives, # noqa: F401 -- RULE B: apm_packages.py uses _m._project_has_root_primitives ) - -# Re-export lockfile hash helper so existing call sites and the regression -# test pinned in #762 (test_hash_deployed_is_module_level_and_works) keep -# working via "apm_cli.commands.install._hash_deployed". 
from apm_cli.install.phases.lockfile import compute_deployed_hashes as _hash_deployed # noqa: F401 - - # Re-export DI-seam helpers from the install services module so that test -# patches against ``apm_cli.commands.install._integrate_*`` keep working. from apm_cli.install.services import ( - _integrate_local_content, # noqa: F401 -- re-exported; test_architecture_invariants checks importability - _integrate_package_primitives, # noqa: F401 -- re-exported; tests import/patch from apm_cli.commands.install + _integrate_local_content, # noqa: F401 + _integrate_package_primitives, # noqa: F401 ) - -# Re-export validation leaf helpers so that existing test patches like -# @patch("apm_cli.commands.install._validate_package_exists") keep working. -# _validate_and_add_packages_to_apm_yml stays here (not moved) because it -# calls _validate_package_exists and _local_path_failure_reason via module- -# level name lookup -- keeping it co-located means @patch on this module -# intercepts those calls without test changes. from apm_cli.install.validation import ( - _local_path_failure_reason, - _local_path_no_markers_hint, # noqa: F401 -- re-exported; test_architecture_invariants checks importability - _validate_package_exists, + _local_path_failure_reason, # noqa: F401 -- RULE A: re-exported for @patch('apm_cli.commands.install.*') + _local_path_no_markers_hint, # noqa: F401 + _validate_package_exists, # noqa: F401 -- RULE B: pkg_resolution.py uses _m._validate_package_exists ) from apm_cli.utils.diagnostics import DiagnosticCollector # noqa: F401 -from ..constants import ( - APM_YML_FILENAME, - InstallMode, +# Re-export manifest rollback helpers so @patch targets keep working. 
+from apm_cli.install.manifest_rollback import ( # noqa: F401 + _maybe_rollback_manifest, + _restore_manifest_from_snapshot, ) -from ..core.auth import AuthResolver -from ..core.command_logger import InstallLogger, _ValidationOutcome -from ..core.target_detection import TargetParamType -# MCP --mcp helpers (module-level re-exports for test patches); must stay at -# import time per comments in the original mid-file block. -from ..install.mcp.command import run_mcp_install as _run_mcp_install -from ..install.mcp.conflicts import ( - validate_mcp_conflicts as _validate_mcp_conflicts, +# Re-export InstallContext so @patch and imports in tests keep working. +from apm_cli.install.cli_context import InstallContext + +# Re-export pkg resolution helpers so @patch('apm_cli.commands.install.*') works. +from apm_cli.install.pkg_resolution import ( # noqa: F401 + _check_package_conflicts, + _merge_packages_into_yml, + _resolve_package_references, + _validate_and_add_packages_to_apm_yml, ) -from ..install.mcp.registry import ( - resolve_registry_url as _resolve_registry_url, + +# install() sub-module re-exports kept here for @patch compatibility. 
+from apm_cli.install.apm_packages import ( # noqa: F401 + _install_apm_dependencies, + _install_apm_packages, + _post_install_summary, ) -from ..install.mcp.registry import ( - validate_mcp_dry_run_entry as _validate_mcp_dry_run_entry, +from apm_cli.install.install_cmd_phases import ( + _compute_argv_pre_dash, + _execute_install_and_summary, + _resolve_protocol_and_fallback, + _resolve_scope_and_paths, + _setup_auth_and_check_manifest, + _try_local_bundle_install, ) +from apm_cli.install.mcp_handler import _McpConnectionParams, _handle_mcp_install + +from ..constants import InstallMode +from ..core.auth import AuthResolver # noqa: F401 -- RULE B: install_cmd_phases.py uses _m.AuthResolver +from ..core.command_logger import InstallLogger, _ValidationOutcome # noqa: F401 +from ..core.target_detection import TargetParamType + +# MCP helpers and console utilities. +from ..install.mcp.command import run_mcp_install as _run_mcp_install # noqa: F401 +from ..install.mcp.conflicts import validate_mcp_conflicts as _validate_mcp_conflicts from ..install.mcp.registry import ( + resolve_registry_url as _resolve_registry_url, # noqa: F401 + validate_mcp_dry_run_entry as _validate_mcp_dry_run_entry, # noqa: F401 validate_registry_url as _validate_registry_url, ) -from ..utils.console import ( # noqa: F401 -- _rich_success re-exported; tests patch commands.install._rich_success +from ..utils.console import ( # noqa: F401 _rich_echo, _rich_error, _rich_info, _rich_success, ) -from ._helpers import ( - _create_minimal_apm_yml, - _get_default_config, -) - -# --------------------------------------------------------------------------- -# Manifest snapshot + rollback (W2-pkg-rollback, #827) -# --------------------------------------------------------------------------- -# When the user runs ``apm install ``, ``_validate_and_add_packages_to_apm_yml`` -# mutates ``apm.yml`` BEFORE the install pipeline runs. If the pipeline fails -# (policy block, download error, etc.) 
the failed package would stay in -# ``apm.yml`` forever. These helpers snapshot the raw bytes before mutation -# and atomically restore on failure. -# --------------------------------------------------------------------------- - - -def _restore_manifest_from_snapshot( - manifest_path: "Path", - snapshot: bytes, -) -> None: - """Atomically restore ``apm.yml`` from a raw-bytes snapshot. - - Uses temp-file + ``os.replace`` to avoid torn writes, mirroring the - W1 cache atomic-write pattern (``discovery.py``). - """ - import os - import tempfile - - fd, tmp_name = tempfile.mkstemp( - prefix="apm-restore-", - dir=str(manifest_path.parent), - ) - try: - with os.fdopen(fd, "wb") as fh: - fh.write(snapshot) - os.replace(tmp_name, str(manifest_path)) - except Exception: - with contextlib.suppress(OSError): - os.unlink(tmp_name) - raise - - -def _maybe_rollback_manifest( - manifest_path: "Path", - snapshot: "bytes | None", - logger: "InstallLogger", -) -> None: - """Restore ``apm.yml`` from *snapshot* if one was captured, then log. - - No-op when *snapshot* is ``None`` (i.e. the command was not - ``apm install `` or the manifest did not exist before mutation). - """ - if snapshot is None: - return - try: - _restore_manifest_from_snapshot(manifest_path, snapshot) - logger.progress("apm.yml restored to its previous state.") - except Exception: - # Best-effort: if the restore itself fails, warn but don't mask - # the original exception that triggered the rollback. - logger.warning("Failed to restore apm.yml to its previous state.") # CRITICAL: Shadow Python builtins that share names with Click commands @@ -188,74 +117,12 @@ def _maybe_rollback_manifest( # --------------------------------------------------------------------------- -# InstallContext -- parameter bundle for the APM install pipeline +# Argv ``--`` boundary helpers (W3 --mcp flag) +# Test seams; _split_argv_at_double_dash separates pre-``--`` packages from +# the post-``--`` stdio-command argv. 
# --------------------------------------------------------------------------- -@dataclasses.dataclass -class InstallContext: - """Bundles install command state to reduce function signatures. - - Created by :func:`install` after argument parsing and scope resolution, - then threaded through :func:`_install_apm_packages` and - :func:`_post_install_summary` to avoid long parameter lists. - """ - - scope: Any # InstallScope - manifest_path: "Path" - manifest_display: str - apm_dir: "Path" - project_root: "Path" - logger: Any # InstallLogger - auth_resolver: Any # AuthResolver - verbose: bool - force: bool - dry_run: bool - update: bool - dev: bool - runtime: str | None - exclude: str | None - target: str | None - parallel_downloads: int - allow_insecure: bool - allow_insecure_hosts: tuple - protocol_pref: Any # ProtocolPreference - allow_protocol_fallback: bool - trust_transitive_mcp: bool - no_policy: bool - install_mode: Any # InstallMode - packages: tuple # Original Click packages - refresh: bool = False - only_packages: builtins.list | None = None - manifest_snapshot: bytes | None = None - snapshot_manifest_path: Optional["Path"] = None - legacy_skill_paths: bool = False - frozen: bool = False - plan_callback: "Callable[[UpdatePlan], bool] | None" = None - skill_subset: "builtins.tuple[str, ...] | None" = None - skill_subset_from_cli: bool = False - audit_override: str | None = None - - -# --------------------------------------------------------------------------- -# Argv `--` boundary helpers (W3 --mcp flag) -# --------------------------------------------------------------------------- -# -# Click's ``nargs=-1`` silently swallows the ``--`` separator and merges -# everything after it into the positional argument tuple. For -# ``apm install --mcp foo -- npx -y srv`` we cannot distinguish that from -# ``apm install --mcp foo npx -y srv`` once Click is done parsing. 
-# -# We therefore inspect ``sys.argv`` ourselves to detect the boundary and -# extract the post-``--`` portion as the stdio command argv. ``--`` IS -# present in ``sys.argv`` even though Click strips it from the parsed -# arguments. The pre-``--`` portion is used to flag conflicts (E1). -# -# ``_get_invocation_argv`` exists as a tiny seam so tests using -# ``CliRunner`` (which does not modify ``sys.argv``) can patch it without -# resorting to ``monkeypatch.setattr('sys.argv', ...)``. - - def _get_invocation_argv(): """Return the process invocation argv. Wrapped for test injection.""" return sys.argv @@ -276,10 +143,10 @@ def _split_argv_at_double_dash(argv): APM_DEPS_AVAILABLE = False _APM_IMPORT_ERROR = None try: - from ..deps.apm_resolver import APMDependencyResolver - from ..deps.lockfile import LockFile, get_lockfile_path, migrate_lockfile_if_needed - from ..integration.mcp_integrator import MCPIntegrator - from ..models.apm_package import APMPackage, DependencyReference + from ..deps.apm_resolver import APMDependencyResolver # noqa: I001 + from ..deps.lockfile import LockFile, get_lockfile_path, migrate_lockfile_if_needed # noqa: F401 -- RULE B via apm_packages.py + from ..integration.mcp_integrator import MCPIntegrator # noqa: F401 + from ..models.apm_package import APMPackage, DependencyReference # noqa: F401 -- RULE B class _ScopedInstallDependencyResolver(APMDependencyResolver): """Install-time resolver; blocks ``git: parent`` expansion at user scope.""" @@ -301,584 +168,7 @@ def expand_parent_repo_decl(self, parent_dep, child_dep): _ScopedInstallDependencyResolver = None # type: ignore[misc,assignment] -# --------------------------------------------------------------------------- -# Package validation helpers (extracted from _validate_and_add_packages_to_apm_yml) -# --------------------------------------------------------------------------- - - -def _check_package_conflicts(current_deps): - """Build identity set from existing deps for duplicate 
detection. - - Parses each entry in *current_deps* (string or dict form) through - :class:`DependencyReference` and collects identity strings. - - Returns: - ``set`` of identity strings for existing dependencies. - """ - existing_identities = builtins.set() - for dep_entry in current_deps: - try: - if isinstance(dep_entry, str): - ref = DependencyReference.parse(dep_entry) - elif isinstance(dep_entry, builtins.dict): - ref = DependencyReference.parse_from_dict(dep_entry) - else: - continue - existing_identities.add(ref.get_identity()) - except (ValueError, TypeError, AttributeError, KeyError): - continue - return existing_identities - - -def _resolve_package_references( - packages, - current_deps, - existing_identities, - *, - auth_resolver=None, - logger=None, - scope=None, - allow_insecure=False, - skill_subset=None, -): - """Validate, canonicalize, and resolve package references. - - Handles marketplace refs, canonical parsing, insecure-URL guards, - local-at-user-scope rejection, and accessibility checks. - - *existing_identities* is mutated (new identities are added to prevent - duplicates within the same batch). - - Returns: - Tuple of ``(valid_outcomes, invalid_outcomes, validated_packages, - marketplace_provenance, apm_yml_entries, dependencies_changed)``. - """ - valid_outcomes = [] # (canonical, already_present) tuples - invalid_outcomes = [] # (package, reason) tuples - _marketplace_provenance = {} # canonical -> {discovered_via, marketplace_plugin_name} - _apm_yml_entries = {} # canonical -> apm.yml entry (str or dict for HTTP deps) - validated_packages = [] - dependencies_changed = False - - if logger: - logger.validation_start(len(packages)) - - for package in packages: - # --- Marketplace pre-parse intercept --- - # If input has no slash and is not a local path, check if it is a - # marketplace ref (NAME@MARKETPLACE). If so, resolve it to a - # canonical owner/repo[#ref] string before entering the standard - # parse path. 
Anything that doesn't match is rejected as an - # invalid format. - marketplace_provenance = None - marketplace_dep_ref = None - if "/" not in package and not DependencyReference.is_local_path(package): - try: - from ..marketplace.resolver import ( - parse_marketplace_ref, - resolve_marketplace_plugin, - ) - - mkt_ref = parse_marketplace_ref(package) - except ImportError: - mkt_ref = None - - if mkt_ref is not None: - plugin_name, marketplace_name, version_spec = mkt_ref - try: - warning_handler = None - if logger: - - def warning_handler(msg): - return logger.warning(msg) - - logger.verbose_detail( - f" Resolving {plugin_name}@{marketplace_name} via marketplace..." - ) - resolution = resolve_marketplace_plugin( - plugin_name, - marketplace_name, - version_spec=version_spec, - auth_resolver=auth_resolver, - warning_handler=warning_handler, - ) - canonical_str, _resolved_plugin = resolution - if logger: - logger.verbose_detail(f" Resolved to: {canonical_str}") - # #1326: dependency-confusion fail-closed gate. - # Bare ``owner/repo`` on *.ghe.com falls back to - # github.com -- refuse before outbound validation so - # no probe reaches a potentially attacker-controlled URL. - # Escape hatch: host-qualify ``repo:`` in marketplace.json. - _risk = resolution.cross_repo_misconfig_risk - if _risk is not None: - _lead = ( - f"refused (dependency-confusion risk #1326): bare" - f" `repo: {_risk.bare_repo_field}` on enterprise" - f" marketplace '{_risk.marketplace_host}' is ambiguous." 
- f" Host-qualify the plugin `repo` field in" - f" marketplace.json to one of:" - ) - reason = "\n".join( - [ - _lead, - f" - '{_risk.suggested_qualified_repo}' (enterprise dep on this marketplace)", - f" - 'github.com/{_risk.bare_repo_field}' (declared cross-host dep on public github.com)", - ] - ) - invalid_outcomes.append((package, reason)) - if logger: - logger.validation_fail(package, reason) - continue - marketplace_provenance = { - "discovered_via": marketplace_name, - "marketplace_plugin_name": plugin_name, - } - package = canonical_str - marketplace_dep_ref = getattr(resolution, "dependency_reference", None) - except Exception as mkt_err: - reason = str(mkt_err) - invalid_outcomes.append((package, reason)) - if logger: - logger.validation_fail(package, reason) - continue - else: - # No slash, not a local path, and not a marketplace ref - reason = "invalid format -- use 'owner/repo' or 'plugin-name@marketplace'" - invalid_outcomes.append((package, reason)) - if logger: - logger.validation_fail(package, reason) - continue - - # Canonicalize input - try: - dep_ref, direct_virtual_resolved = resolve_parsed_dependency_reference( - package, - marketplace_dep_ref, - dependency_reference_cls=DependencyReference, - try_resolve_gitlab_direct_shorthand=_try_resolve_gitlab_direct_shorthand, - resolve_artifactory_boundary=_resolve_artifactory_boundary, - auth_resolver=auth_resolver, - verbose=bool(logger and logger.verbose), - logger=logger, - ) - canonical = dep_ref.to_canonical() - identity = dep_ref.get_identity() - # Attach --skill filter so to_apm_yml_entry() emits the dict form - if skill_subset: - # Normalize: strip whitespace, drop empty strings, deduplicate - # (preserve order) so invalid or redundant names can't persist. 
- _seen: builtins.set[str] = builtins.set() - _normalized: builtins.list[str] = [] - for _s in skill_subset: - _s = _s.strip() - if _s and _s not in _seen: - _seen.add(_s) - _normalized.append(_s) - dep_ref.skill_subset = _normalized - if marketplace_dep_ref is not None or direct_virtual_resolved: - _apm_yml_entries[canonical] = dependency_reference_to_yaml_entry(dep_ref) - except ValueError as e: - reason = str(e) - invalid_outcomes.append((package, reason)) - if logger: - logger.validation_fail(package, reason) - continue - - if dep_ref.is_insecure: - if not allow_insecure: - # The reason string embeds the full URL already, so skip - # logger.validation_fail (which prepends "{package} -- ") to - # avoid rendering the URL twice. Use logger.error directly. - reason = _format_insecure_dependency_requirements( - _get_insecure_dependency_url(dep_ref) - ) - invalid_outcomes.append((package, reason)) - if logger: - logger.error(reason) - continue - dep_ref.allow_insecure = True - _apm_yml_entries[canonical] = dep_ref.to_apm_yml_entry() - - scope_reject = user_scope_rejection_reason(dep_ref, scope) - if scope_reject: - invalid_outcomes.append((package, scope_reject)) - if logger: - logger.validation_fail(package, scope_reject) - continue - - # Ensure structured entry is used for apm.yml persistence when skill - # filter is active (normal non-marketplace/non-insecure path doesn't - # set _apm_yml_entries; _merge_packages_into_yml falls back to the - # plain canonical string without this). 
- if skill_subset and canonical not in _apm_yml_entries: - _apm_yml_entries[canonical] = dep_ref.to_apm_yml_entry() - - # Check if package is already in dependencies (by identity) - already_in_deps = identity in existing_identities - - # Validate package exists and is accessible - verbose = bool(logger and logger.verbose) - if _validate_package_exists( - package, - verbose=verbose, - auth_resolver=auth_resolver, - logger=logger, - dep_ref=dep_ref, - ): - updates_existing_entry = update_existing_dependency_entry_if_needed( - current_deps, - already_in_deps=already_in_deps, - apm_yml_entries=_apm_yml_entries, - canonical=canonical, - dep_ref=dep_ref, - identity=identity, - dependency_reference_cls=DependencyReference, - logger=logger, - ) - valid_outcomes.append((canonical, already_in_deps)) - if logger: - logger.validation_pass(canonical, already_in_deps, updates_existing_entry) - - if not already_in_deps: - validated_packages.append(canonical) - existing_identities.add(identity) # prevent duplicates within batch - dependencies_changed = dependencies_changed or updates_existing_entry - if marketplace_provenance: - _marketplace_provenance[identity] = marketplace_provenance - else: - reason = _local_path_failure_reason(dep_ref) - if not reason: - # Round-4 panel fix (devx-ux): name the four-step probe - # chain explicitly when the validator exhausted it - # (virtual subdirectory + explicit ref). Generic "not - # accessible" hides the failure mode for the precise - # case where the most diagnostics are available. 
- is_subdir_ref_chain = ( - dep_ref.is_virtual - and dep_ref.is_virtual_subdirectory() - and bool(dep_ref.reference) - ) - if is_subdir_ref_chain: - reason = ( - "all probes failed (marker-file, Contents API, " - "git ls-remote, shallow-fetch) -- verify the path " - "and ref exist and that your credentials have " - "read access" - ) - if not verbose: - reason += " (run with --verbose for the full probe log)" - else: - reason = "not accessible or doesn't exist" - if not verbose: - reason += " -- run with --verbose for auth details" - invalid_outcomes.append((package, reason)) - if logger: - logger.validation_fail(package, reason) - - return ( - valid_outcomes, - invalid_outcomes, - validated_packages, - _marketplace_provenance, - _apm_yml_entries, - dependencies_changed, - ) - - -def _merge_packages_into_yml( - validated_packages, - apm_yml_entries, - current_deps, - data, - dep_section, - apm_yml_path, - *, - dev=False, - logger=None, -): - """Append *validated_packages* to the dependency list and write apm.yml. - - Mutates *current_deps* in place and persists the updated manifest to - *apm_yml_path*. 
- """ - dep_label = "devDependencies" if dev else "apm.yml" - for package in validated_packages: - current_deps.append(apm_yml_entries.get(package, package)) - if logger: - logger.verbose_detail(f"Added {package} to {dep_label}") - - # Update dependencies - data[dep_section]["apm"] = current_deps - - # Write back to apm.yml - try: - from ..utils.yaml_io import dump_yaml - - dump_yaml(data, apm_yml_path) - if logger: - logger.success( - f"Updated {APM_YML_FILENAME} with {len(validated_packages)} new package(s)" - ) - except Exception as e: - if logger: - logger.error(f"Failed to write {APM_YML_FILENAME}: {e}") - else: - _rich_error(f"Failed to write {APM_YML_FILENAME}: {e}") - sys.exit(1) - - -def _validate_and_add_packages_to_apm_yml( - packages, - dry_run=False, - dev=False, - logger=None, - manifest_path=None, - auth_resolver=None, - scope=None, - allow_insecure=False, - skill_subset=None, -): - """Validate packages exist and can be accessed, then add to apm.yml dependencies section. - - Implements normalize-on-write: any input form (HTTPS URL, SSH URL, FQDN, shorthand) - is canonicalized before storage. Default host (github.com) is stripped; - non-default hosts are preserved. Duplicates are detected by identity. - - Args: - packages: Package specifiers to validate and add. - dry_run: If True, only show what would be added. - dev: If True, write to devDependencies instead of dependencies. - logger: InstallLogger for structured output. - manifest_path: Explicit path to apm.yml (defaults to cwd/apm.yml). - auth_resolver: Shared auth resolver for caching credentials. - scope: InstallScope controlling project vs user deployment. - - Returns: - Tuple of (validated_packages list, _ValidationOutcome). 
- """ - from pathlib import Path - - apm_yml_path = manifest_path or Path(APM_YML_FILENAME) - - # Read current apm.yml - try: - from ..utils.yaml_io import load_yaml - - data = load_yaml(apm_yml_path) or {} - except Exception as e: - if logger: - logger.error(f"Failed to read {APM_YML_FILENAME}: {e}") - else: - _rich_error(f"Failed to read {APM_YML_FILENAME}: {e}") - sys.exit(1) - - # Ensure dependencies structure exists - dep_section = "devDependencies" if dev else "dependencies" - if dep_section not in data: - data[dep_section] = {} - if "apm" not in data[dep_section]: - data[dep_section]["apm"] = [] - - current_deps = data[dep_section]["apm"] or [] - - # Detect duplicates against existing deps - existing_identities = _check_package_conflicts(current_deps) - - # Validate and canonicalize all package references - ( - valid_outcomes, - invalid_outcomes, - validated_packages, - _marketplace_provenance, - _apm_yml_entries, - dependencies_changed, - ) = _resolve_package_references( - packages, - current_deps, - existing_identities, - auth_resolver=auth_resolver, - logger=logger, - scope=scope, - allow_insecure=allow_insecure, - skill_subset=skill_subset, - ) - - outcome = _ValidationOutcome( - valid=valid_outcomes, - invalid=invalid_outcomes, - marketplace_provenance=_marketplace_provenance or None, - ) - - # Let the logger emit a summary and decide whether to continue - if logger: - should_continue = logger.validation_summary(outcome) - if not should_continue: - return [], outcome - - if not validated_packages: - if dry_run: - if logger: - logger.progress("No new packages to add") - # If all packages already exist in apm.yml, that's OK - we'll reinstall them - persist_dependency_list_if_changed( - dependencies_changed=dependencies_changed, - data=data, - dep_section=dep_section, - current_deps=current_deps, - apm_yml_path=apm_yml_path, - apm_yml_filename=APM_YML_FILENAME, - logger=logger, - rich_error=_rich_error, - sys_exit=sys.exit, - ) - return [], outcome - - if 
dry_run: - if logger: - logger.progress(f"Dry run: Would add {len(validated_packages)} package(s) to apm.yml") - for pkg in validated_packages: - logger.verbose_detail(f" + {pkg}") - return validated_packages, outcome - - # Persist validated packages to apm.yml - _merge_packages_into_yml( - validated_packages, - _apm_yml_entries, - current_deps, - data, - dep_section, - apm_yml_path, - dev=dev, - logger=logger, - ) - - return validated_packages, outcome - - -# --------------------------------------------------------------------------- -# MCP CLI helpers (W3 --mcp flag) -# --------------------------------------------------------------------------- - -# F7 / F5 install-time MCP warnings live in apm_cli/install/mcp/warnings.py -# per LOC budget. Re-bind module-level names for back-compat with tests -# that still patch ``apm_cli.commands.install._warn_*``. - -# MCP registry / dry-run helpers are imported at module top (see -# ``..install.mcp.*`` imports above) so test patches keep working. - -# --------------------------------------------------------------------------- -# install() decomposition: extracted flow helpers -# --------------------------------------------------------------------------- - - -def _handle_mcp_install( - *, - mcp_name, - transport, - url, - env_pairs, - header_pairs, - mcp_version, - command_argv, - dev, - force, - runtime, - exclude, - verbose, - logger, - no_policy, - validated_registry_url, -): - """Execute the ``--mcp`` install path (MCP server add). - - Resolves registry URL, runs policy preflight, handles dry-run, - and delegates to :func:`_run_mcp_install` for the actual installation. - Called from :func:`install` when ``--mcp`` is specified; the caller - returns immediately after this function completes. - """ - from ..core.scope import ( - InstallScope, - get_apm_dir, - get_manifest_path, - ) - - # Apply CLI > env > default precedence; emit override diagnostic. 
- resolved_registry_url, _registry_source = _resolve_registry_url( - validated_registry_url, - logger=logger, - ) - mcp_scope = InstallScope.PROJECT - mcp_manifest_path = get_manifest_path(mcp_scope) - mcp_apm_dir = get_apm_dir(mcp_scope) - # -- W2-mcp-preflight: policy enforcement before MCP install -- - # Build a lightweight MCPDependency for policy evaluation. - # This mirrors _build_mcp_entry routing but we only need the - # fields that policy checks inspect (name, transport, registry). - from ..models.dependency.mcp import MCPDependency as _MCPDep - from ..policy.install_preflight import ( - PolicyBlockError, - run_policy_preflight, - ) - - _is_self_defined = bool(url or command_argv) - _preflight_transport = transport - if _preflight_transport is None: - if command_argv: - _preflight_transport = "stdio" - elif url: - _preflight_transport = "http" - _preflight_dep = _MCPDep( - name=mcp_name, - transport=_preflight_transport, - registry=False if _is_self_defined else None, - url=url, - ) - - try: - _pf_result, _pf_active = run_policy_preflight( - project_root=Path.cwd(), - mcp_deps=[_preflight_dep], - no_policy=no_policy, - logger=logger, - dry_run=logger.dry_run, - ) - except PolicyBlockError: - # Diagnostics already emitted by the helper + logger. - logger.render_summary() - sys.exit(1) - - if logger.dry_run: - # C1: validate eagerly so dry-run rejects what real install would. 
- _validate_mcp_dry_run_entry( - mcp_name, - transport=transport, - url=url, - env=env_pairs, - headers=header_pairs, - version=mcp_version, - command_argv=command_argv, - registry_url=resolved_registry_url, - ) - logger.dry_run_notice(f"would add MCP server '{mcp_name}' to {mcp_manifest_path}") - return - _run_mcp_install( - mcp_name=mcp_name, - transport=transport, - url=url, - env_pairs=env_pairs, - header_pairs=header_pairs, - mcp_version=mcp_version, - command_argv=command_argv, - dev=dev, - force=force, - runtime=runtime, - exclude=exclude, - logger=logger, - apm_dir=mcp_apm_dir, - scope=mcp_scope, - registry_url=validated_registry_url, - ) +# MCP helpers and re-exports are in the imports section above. @click.command( @@ -1195,15 +485,7 @@ def install( # noqa: PLR0913 "--frozen and --update are mutually exclusive. " "Use 'apm update' to refresh refs, then 'apm install --frozen' in CI." ) - # --root: see apm_cli.install.root_redirect.install_root_redirect. - # Conflicts with --global (user scope writes are anchored at $HOME - # and have no concept of an arbitrary deploy root). ``--dry-run`` is - # threaded through so the context manager skips the ``mkdir`` - # side-effect on previews. Entered manually (rather than via - # ``with``) so the existing top-level try/except/finally below does - # not need a full-body re-indent; the matching ``__exit__`` in that - # ``finally`` restores cwd + clears the source-root override on every - # exit path (return, sys.exit -> SystemExit, exception). + # --root: entered manually so the existing try/finally handles __exit__. 
if root and global_: raise click.UsageError("--root is not valid with --global (user scope)") from ..core.install_audit import resolve_audit_override_from_cli @@ -1223,16 +505,8 @@ def install( # noqa: PLR0913 is_partial = bool(packages) logger = InstallLogger(verbose=verbose, dry_run=dry_run, partial=is_partial) - # W2-pkg-rollback (#827): snapshot bytes captured BEFORE - # _validate_and_add_packages_to_apm_yml mutates apm.yml. Initialised - # to None here -- BEFORE any branch that might raise (e.g. the local - # bundle early-exit path below) -- so the `except` handlers at the - # bottom of this function can always reference both names without - # UnboundLocalError. The bug this prevents: an exception raised in - # the local-bundle branch (e.g. a click.Abort from integrity-verify - # failure on Windows) would otherwise be masked by an - # UnboundLocalError inside the handler that calls - # _maybe_rollback_manifest(_snapshot_manifest_path, ...). + # W2-pkg-rollback (#827): init snapshot vars before any branch that might + # raise, so except handlers can always reference them (no UnboundLocalError). _manifest_snapshot: bytes | None = None _snapshot_manifest_path: Path | None = None @@ -1249,75 +523,41 @@ def install( # noqa: PLR0913 # entirely and deploy the bundle's files directly. Local bundles # are imperative deploys -- they do NOT mutate apm.yml. 
# ---------------------------------------------------------------- - if len(packages) == 1 and not mcp_name and (_probe := Path(packages[0])).exists(): - from ..bundle.local_bundle import detect_local_bundle as _detect_lb - from ..install.local_bundle_handler import install_local_bundle as _install_lb - - _bundle_info = _detect_lb(_probe) - if _bundle_info is not None: - _install_lb( - bundle_info=_bundle_info, - bundle_arg=packages[0], - target=target, - global_=global_, - force=force, - dry_run=dry_run, - verbose=verbose, - alias=alias, - logger=logger, - legacy_skill_paths=legacy_skill_paths, - # Rejected-flag context for consolidated UsageError: - rejected_flags={ - "--update": update, - "--only": only, - "--runtime": runtime, - "--exclude": exclude, - "--dev": dev, - "--ssh": use_ssh, - "--https": use_https, - "--allow-protocol-fallback": allow_protocol_fallback, - "--mcp": mcp_name, - "--registry": registry_url, - "--skill": bool(skill_names), - "--parallel-downloads": parallel_downloads != 4, - "--allow-insecure": allow_insecure, - "--allow-insecure-host": bool(allow_insecure_hosts), - "--no-policy": no_policy, - }, - ) - # Local bundle install renders its own summary; mark - # ``summary_rendered = True`` so the finally-block (line ~1423) - # does not emit a misleading "install interrupted" line on the - # success path. See issue #1207 D3. - summary_rendered = True - return - # IM7: path exists but isn't a recognised bundle. For tarball - # extensions (.tar.gz / .tgz) the user clearly meant a bundle - # artifact, so raise a targeted UsageError instead of falling - # through to the registry path (which would try to clone). - # For bare directories we still fall through, because - # ``apm install ./packages/source-pkg`` is a supported local-path - # install that goes through the dependency-resolver pipeline. 
- _suffix = _probe.name.lower() - if _probe.is_file() and (_suffix.endswith(".tar.gz") or _suffix.endswith(".tgz")): - # Distinguish legacy --format apm bundles (apm.lock.yaml - # present, plugin.json absent) from arbitrary tarballs so - # the error message guides the user to the right next step. - from ..bundle.local_bundle import _looks_like_legacy_apm_bundle - - if _looks_like_legacy_apm_bundle(_probe): - raise click.UsageError( - f"'{packages[0]}' was packed with '--format apm' (legacy format). " - "'apm install ' requires the plugin format. " - "Repack with 'apm pack --format plugin --archive', " - "or use 'apm unpack' to deploy the legacy bundle." - ) - raise click.UsageError( - f"'{packages[0]}' is not a valid APM bundle archive " - "(no plugin.json found at the bundle root). " - "Use 'apm install org/package' for registry installs, " - "or repack the source with 'apm pack'." - ) + if _try_local_bundle_install( + packages, + mcp_name, + target, + global_, + force, + dry_run, + verbose, + alias, + logger, + legacy_skill_paths, + rejected_flags={ + "--update": update, + "--only": only, + "--runtime": runtime, + "--exclude": exclude, + "--dev": dev, + "--ssh": use_ssh, + "--https": use_https, + "--allow-protocol-fallback": allow_protocol_fallback, + "--mcp": mcp_name, + "--registry": registry_url, + "--skill": bool(skill_names), + "--parallel-downloads": parallel_downloads != 4, + "--allow-insecure": allow_insecure, + "--allow-insecure-host": bool(allow_insecure_hosts), + "--no-policy": no_policy, + }, + ): + # Local bundle install renders its own summary; mark + # ``summary_rendered = True`` so the finally-block does not emit a + # misleading "install interrupted" line on the success path + # (issue #1207 D3). + summary_rendered = True + return # IM8: --as is only meaningful for local-bundle installs. If we get # here, no local bundle was detected, so reject --as instead of # silently ignoring it. 
@@ -1339,19 +579,9 @@ def install( # noqa: PLR0913 # always reference them without UnboundLocalError. # ---------------------------------------------------------------- - # --mcp branch (W3): when --mcp is set, route to the dedicated - # MCP-add path. We compute the post-`--` argv here BEFORE Click's - # silent handling: see _split_argv_at_double_dash(). + # --mcp branch (W3): compute argv split; route to MCP path if set. # ---------------------------------------------------------------- - _, command_argv = _split_argv_at_double_dash(_get_invocation_argv()) - # `packages` from Click already includes the post-`--` items; the - # pre-`--` portion is what the user typed as positional packages. - if command_argv: - split_idx = len(packages) - len(command_argv) - split_idx = max(split_idx, 0) - pre_dash_packages = builtins.tuple(packages[:split_idx]) - else: - pre_dash_packages = builtins.tuple(packages) + command_argv, pre_dash_packages = _compute_argv_pre_dash(packages) # Validate --registry (raises UsageError on a bad URL). validated_registry_url = _validate_registry_url(registry_url) @@ -1384,11 +614,13 @@ def install( # noqa: PLR0913 if mcp_name is not None: _handle_mcp_install( mcp_name=mcp_name, - transport=transport, - url=url, - env_pairs=env_pairs, - header_pairs=header_pairs, - mcp_version=mcp_version, + mcp_conn=_McpConnectionParams( + transport=transport, + url=url, + env_pairs=env_pairs, + header_pairs=header_pairs, + mcp_version=mcp_version, + ), command_argv=command_argv, dev=dev, force=force, @@ -1401,87 +633,21 @@ def install( # noqa: PLR0913 ) return - # Resolve transport selection inputs. - from ..deps.transport_selection import ( - ProtocolPreference, + # Resolve transport preference, scope, paths, and auth resolver. 
+ protocol_pref, allow_protocol_fallback = _resolve_protocol_and_fallback( + use_ssh, use_https, allow_protocol_fallback ) - - if use_ssh and use_https: + if protocol_pref is None: _rich_error("Options --ssh and --https are mutually exclusive.", symbol="error") sys.exit(2) - if use_ssh: - protocol_pref = ProtocolPreference.SSH - elif use_https: - protocol_pref = ProtocolPreference.HTTPS - else: - # Precedence: APM_GIT_PROTOCOL env var > apm config ssh > git insteadOf - from ..config import get_apm_protocol_pref as _get_apm_protocol_pref - - _pref_str = _get_apm_protocol_pref() - protocol_pref = ProtocolPreference.from_str(_pref_str) - # CLI flag > env var (APM_ALLOW_PROTOCOL_FALLBACK) > apm config > default. - # get_apm_allow_protocol_fallback() already encodes env > config > False. - from ..config import get_apm_allow_protocol_fallback as _get_apm_apf - allow_protocol_fallback = allow_protocol_fallback or _get_apm_apf() - - # Resolve scope - from ..core.scope import ( - InstallScope, - ensure_user_dirs, - get_apm_dir, - get_manifest_path, - warn_unsupported_user_scope, + scope, manifest_path, apm_dir, manifest_display, project_root = _resolve_scope_and_paths( + global_, logger ) - scope = InstallScope.USER if global_ else InstallScope.PROJECT - - if scope is InstallScope.USER: - ensure_user_dirs() - logger.progress("Installing to user scope (~/.apm/)") - _scope_warn = warn_unsupported_user_scope() - if _scope_warn: - logger.warning(_scope_warn) - - # Scope-aware paths - manifest_path = get_manifest_path(scope) - apm_dir = get_apm_dir(scope) - # Display name for messages (short for project scope, full for user scope) - manifest_display = str(manifest_path) if scope is InstallScope.USER else APM_YML_FILENAME - - # Project root for integration (used by both dep and local integration) - from ..core.scope import get_deploy_root - - project_root = get_deploy_root(scope) - - # Create shared auth resolver for all downloads in this CLI invocation - # to ensure credentials 
are cached and reused (prevents duplicate auth popups) - auth_resolver = AuthResolver() - # F2/F3 #856: thread the InstallLogger into AuthResolver so the verbose - # auth-source line and the deferred stale-PAT [!] warning route through - # CommandLogger / DiagnosticCollector instead of stderr/inline writes. - auth_resolver.set_logger(logger) - - # Check if apm.yml exists - apm_yml_exists = manifest_path.exists() - - # Auto-bootstrap: create minimal apm.yml when packages specified but no apm.yml - if not apm_yml_exists and packages: - # Get current directory name as project name - project_name = Path.cwd().name if scope is InstallScope.PROJECT else Path.home().name - config = _get_default_config(project_name) - _create_minimal_apm_yml(config, target_path=manifest_path) - logger.success(f"Created {manifest_display}") - - # Error when NO apm.yml AND NO packages - if not apm_yml_exists and not packages: - logger.error(f"No {manifest_display} found") - if scope is InstallScope.USER: - logger.progress("Run 'apm install -g ' to auto-create + install") - else: - logger.progress("Run 'apm init' to create one, or:") - logger.progress(" apm install to auto-create + install") - sys.exit(1) + auth_resolver = _setup_auth_and_check_manifest( + scope, packages, manifest_path, manifest_display, logger + ) # If packages are specified, validate and add them to apm.yml first outcome = None @@ -1550,33 +716,9 @@ def install( # noqa: PLR0913 skill_subset_from_cli=bool(skill_names), ) - apm_count, mcp_count, lsp_count, apm_diagnostics = _install_apm_packages( - install_ctx, - outcome, - ) - - _post_install_summary( - logger=logger, - apm_count=apm_count, - mcp_count=mcp_count, - lsp_count=lsp_count, - apm_diagnostics=apm_diagnostics, - force=force, - elapsed_seconds=time.perf_counter() - install_started_at, - ) + _execute_install_and_summary(install_ctx, outcome, frozen, install_started_at) summary_rendered = True - if frozen and apm_count > 0: - # --frozen verifies LOCKFILE STRUCTURE 
(every apm.yml dep - # has a lock entry), not on-disk content integrity. Make - # the scope explicit so a CI pipeline that skips - # 'apm audit' on the assumption that --frozen covers SHA - # verification is corrected at the moment of use. - _rich_info( - "Lockfile presence verified. Run 'apm audit' for on-disk content integrity.", - symbol="info", - ) - except InsecureDependencyPolicyError: _maybe_rollback_manifest(_snapshot_manifest_path, _manifest_snapshot, logger) sys.exit(1) @@ -1621,465 +763,3 @@ def install( # noqa: PLR0913 os.environ.pop("APM_VERBOSE", None) else: os.environ["APM_VERBOSE"] = _apm_verbose_prev - - -# --------------------------------------------------------------------------- -# install() decomposition: APM pipeline + post-install summary -# --------------------------------------------------------------------------- - - -def _install_apm_packages(ctx, outcome): - """Execute the APM + transitive MCP installation pipeline. - - Parses ``apm.yml``, installs APM dependencies, collects and installs - transitive MCP servers, and handles lockfile updates. - - Args: - ctx: :class:`InstallContext` with configuration and environment. - outcome: ``_ValidationOutcome`` from package validation (may be - ``None`` when no explicit packages were passed). - - Returns: - Tuple of ``(apm_count, mcp_count, lsp_count, apm_diagnostics)``. 
- """ - logger = ctx.logger - - logger.resolution_start( - to_install_count=len(ctx.only_packages or []) if ctx.packages else 0, - lockfile_count=0, # Refined later inside _install_apm_dependencies - ) - - # Parse apm.yml to get both APM and MCP dependencies - try: - apm_package = APMPackage.from_apm_yml(ctx.manifest_path) - except Exception as e: - logger.error(f"Failed to parse {ctx.manifest_display}: {e}") - sys.exit(1) - - logger.verbose_detail( - f"Parsed {APM_YML_FILENAME}: {len(apm_package.get_apm_dependencies())} APM deps, " - f"{len(apm_package.get_mcp_dependencies())} MCP deps" - + ( - f", {len(apm_package.get_dev_apm_dependencies())} dev deps" - if apm_package.get_dev_apm_dependencies() - else "" - ) - ) - - # Get APM and MCP dependencies - apm_deps = apm_package.get_apm_dependencies() - dev_apm_deps = apm_package.get_dev_apm_dependencies() - has_any_apm_deps = bool(apm_deps) or bool(dev_apm_deps) - mcp_deps = apm_package.get_mcp_dependencies() - - all_apm_deps = list(apm_deps) + list(dev_apm_deps) - _check_insecure_dependencies(all_apm_deps, ctx.allow_insecure, logger) - - # Determine what to install based on install mode - should_install_apm = ctx.install_mode != InstallMode.MCP - should_install_mcp = ctx.install_mode != InstallMode.APM - should_install_lsp = should_install_mcp - - # Show what will be installed if dry run - if ctx.dry_run: - # -- W2-dry-run (#827): policy preflight in preview mode -- - # Runs discovery + checks against direct manifest deps (not - # resolved/transitive -- dry-run does not run the resolver). - # Block-severity violations render as "Would be blocked by - # policy" without raising. Documented limitation: transitive - # deps are NOT evaluated since the resolver does not run. 
- from apm_cli.policy.install_preflight import run_policy_preflight as _dr_preflight - - _dr_apm_deps = builtins.list(apm_deps) + builtins.list(dev_apm_deps) - _dr_preflight( - project_root=ctx.project_root, - apm_deps=_dr_apm_deps, - mcp_deps=mcp_deps if should_install_mcp else None, - no_policy=ctx.no_policy, - logger=logger, - dry_run=True, - ) - - from apm_cli.install.presentation.dry_run import render_and_exit - - render_and_exit( - logger=logger, - should_install_apm=should_install_apm, - apm_deps=apm_deps, - mcp_deps=mcp_deps, - dev_apm_deps=dev_apm_deps, - should_install_mcp=should_install_mcp, - update=ctx.update, - only_packages=ctx.only_packages, - apm_dir=ctx.apm_dir, - ) - return 0, 0, 0, None # render_and_exit exits; this line is defensive - - # Install APM dependencies first (if requested) - apm_count = 0 - - # Migrate legacy apm.lock -> apm.lock.yaml if needed (one-time, transparent) - migrate_lockfile_if_needed(ctx.apm_dir) - - # Capture old MCP servers and configs from lockfile BEFORE - # _install_apm_dependencies regenerates it (which drops the fields). - # We always read this -- even when --only=apm -- so we can restore the - # field after the lockfile is regenerated by the APM install step. - old_mcp_servers: builtins.set = builtins.set() - old_mcp_configs: builtins.dict = {} - _lock_path = get_lockfile_path(ctx.apm_dir) - _existing_lock = LockFile.read(_lock_path) - if _existing_lock: - old_mcp_servers = builtins.set(_existing_lock.mcp_servers) - old_mcp_configs = builtins.dict(_existing_lock.mcp_configs) - - # Enter the APM install path when there are deps, local .apm/ primitives - # (#714), OR orphan deps in the lockfile to clean up (manifest emptied). 
- from apm_cli.core.scope import InstallScope - from apm_cli.core.scope import get_deploy_root as _get_deploy_root - from apm_cli.deps.lockfile import _SELF_KEY as _LOCK_SELF_KEY - - _cli_project_root = _get_deploy_root(ctx.scope) - _has_orphan_deps_in_lock = bool( - _existing_lock - and not has_any_apm_deps - and any(k != _LOCK_SELF_KEY for k in _existing_lock.dependencies) - ) - apm_diagnostics = None - if should_install_apm and ( - has_any_apm_deps - or _project_has_root_primitives(_cli_project_root) - or _has_orphan_deps_in_lock - ): - if not APM_DEPS_AVAILABLE: - logger.error("APM dependency system not available") - logger.progress(f"Import error: {_APM_IMPORT_ERROR}") - sys.exit(1) - - try: - # If specific packages were requested, only install those - # Otherwise install all from apm.yml. - # `only_packages` was computed above so the dry-run preview - # and the actual install share one canonical list. - install_result = _install_apm_dependencies( - apm_package, - ctx.update, - ctx.verbose, - ctx.only_packages, - force=ctx.force, - parallel_downloads=ctx.parallel_downloads, - logger=logger, - scope=ctx.scope, - auth_resolver=ctx.auth_resolver, - target=ctx.target, - allow_insecure=ctx.allow_insecure, - allow_insecure_hosts=ctx.allow_insecure_hosts, - marketplace_provenance=( - outcome.marketplace_provenance if ctx.packages and outcome else None - ), - protocol_pref=ctx.protocol_pref, - allow_protocol_fallback=ctx.allow_protocol_fallback, - no_policy=ctx.no_policy, - audit_override=ctx.audit_override, - legacy_skill_paths=ctx.legacy_skill_paths, - frozen=ctx.frozen, - plan_callback=ctx.plan_callback, - skill_subset=ctx.skill_subset, - skill_subset_from_cli=ctx.skill_subset_from_cli, - refresh=ctx.refresh, - ) - apm_count = install_result.installed_count - apm_diagnostics = install_result.diagnostics - except InsecureDependencyPolicyError: - _maybe_rollback_manifest(ctx.snapshot_manifest_path, ctx.manifest_snapshot, logger) - sys.exit(1) - except 
AuthenticationError as e: - # #1015: render auth diagnostics on the DEFAULT path (not --verbose). - _maybe_rollback_manifest(ctx.snapshot_manifest_path, ctx.manifest_snapshot, logger) - _rich_error(str(e)) - if e.diagnostic_context: - _rich_echo(e.diagnostic_context) - _rich_info("Tip: run 'apm doctor' to diagnose auth and connectivity.", symbol="info") - sys.exit(1) - except FrozenInstallError as e: - _maybe_rollback_manifest(ctx.snapshot_manifest_path, ctx.manifest_snapshot, logger) - _rich_error(str(e)) - for reason in e.reasons: - _rich_echo(reason) - _rich_info( - "Tip: run 'apm outdated' to see what changed, then 'apm update'.", - symbol="info", - ) - sys.exit(1) - except Exception as e: - _maybe_rollback_manifest(ctx.snapshot_manifest_path, ctx.manifest_snapshot, logger) - # #832: surface PolicyViolationError verbatim (no double-nesting). - msg = ( - str(e) - if isinstance(e, PolicyViolationError) - else f"Failed to install APM dependencies: {e}" - ) - logger.error(msg) - if not ctx.verbose: - logger.progress("Run with --verbose for detailed diagnostics") - sys.exit(1) - elif should_install_apm and not has_any_apm_deps: - logger.verbose_detail("No APM dependencies found in apm.yml") - - # When --update is used, package files on disk may have changed. - # Clear the parse cache so transitive MCP collection reads fresh data. 
- if ctx.update: - from apm_cli.models.apm_package import clear_apm_yml_cache - - clear_apm_yml_cache() - - # Collect transitive MCP dependencies from resolved APM packages - transitive_mcp = [] - from ..core.scope import get_modules_dir - - apm_modules_path = get_modules_dir(ctx.scope) - if should_install_mcp and apm_modules_path.exists(): - lock_path = get_lockfile_path(ctx.apm_dir) - transitive_mcp = MCPIntegrator.collect_transitive( - apm_modules_path, - lock_path, - ctx.trust_transitive_mcp, - diagnostics=apm_diagnostics, - ) - if transitive_mcp: - logger.verbose_detail(f"Collected {len(transitive_mcp)} transitive MCP dependency(ies)") - mcp_deps = MCPIntegrator.deduplicate(mcp_deps + transitive_mcp) - - # -- S1/S2 fix (#827-C2/C3): enforce policy on ALL MCP deps ---- - # The pipeline gate phase (policy_gate.py) checks direct APM deps - # and direct MCP deps from apm.yml. However, transitive MCP - # servers (discovered via collect_transitive above) are only known - # after APM packages are installed. Run a second preflight - # against the *merged* MCP set (direct + transitive) BEFORE - # MCPIntegrator writes runtime configs. On PolicyBlockError we - # abort the MCP write but leave already-installed APM packages - # in place (they were approved by the gate phase). - if should_install_mcp and mcp_deps: - from apm_cli.policy.install_preflight import ( - PolicyBlockError as _TransitivePBE, - ) - from apm_cli.policy.install_preflight import ( - run_policy_preflight as _transitive_preflight, - ) - - try: - _transitive_preflight( - project_root=ctx.project_root, - mcp_deps=mcp_deps, - no_policy=ctx.no_policy, - logger=logger, - dry_run=False, - ) - except _TransitivePBE: - logger.error( - "MCP server(s) blocked by org policy. " - "APM packages remain installed; MCP configs were NOT written." 
- ) - logger.render_summary() - sys.exit(1) - - # Continue with MCP installation (existing logic) - mcp_count = 0 - new_mcp_servers: builtins.set = builtins.set() - # Forward only the targets-key the user actually declared so parse_targets_field - # in the gate sees the same dict shape it sees from raw apm.yml. Including a - # `targets: None` placeholder when the user wrote `target:` (singular) would - # falsely trip the conflict-mutex check (see core.apm_yml.parse_targets_field). - # This restores parity with `apm install` for users on the modern `targets:` - # plural form -- without this, `targets:` was silently dropped at the call - # site and the gate fell back to permissive directory detection (#1335). - mcp_apm_config: dict = {"scripts": apm_package.scripts or {}} - if apm_package.targets is not None: - mcp_apm_config["targets"] = apm_package.targets - elif apm_package.target is not None: - mcp_apm_config["target"] = apm_package.target - if should_install_mcp and mcp_deps: - mcp_count = MCPIntegrator.install( - mcp_deps, - ctx.runtime, - ctx.exclude, - ctx.verbose, - stored_mcp_configs=old_mcp_configs, - apm_config=mcp_apm_config, - project_root=ctx.project_root, - user_scope=(ctx.scope is InstallScope.USER), - explicit_target=ctx.target, - diagnostics=apm_diagnostics, - scope=ctx.scope, - ) - new_mcp_servers = MCPIntegrator.get_server_names(mcp_deps) - new_mcp_configs = MCPIntegrator.get_server_configs(mcp_deps) - - # Remove stale MCP servers that are no longer needed - stale_servers = old_mcp_servers - new_mcp_servers - if stale_servers: - MCPIntegrator.remove_stale( - stale_servers, - ctx.runtime, - ctx.exclude, - project_root=ctx.project_root, - user_scope=(ctx.scope is InstallScope.USER), - scope=ctx.scope, - ) - - # Persist the new MCP server set and configs in the lockfile - MCPIntegrator.update_lockfile(new_mcp_servers, _lock_path, mcp_configs=new_mcp_configs) - elif should_install_mcp and not mcp_deps: - # No MCP deps at all -- remove any old 
APM-managed servers - if old_mcp_servers: - MCPIntegrator.remove_stale( - old_mcp_servers, - ctx.runtime, - ctx.exclude, - project_root=ctx.project_root, - user_scope=(ctx.scope is InstallScope.USER), - scope=ctx.scope, - ) - MCPIntegrator.update_lockfile(builtins.set(), _lock_path, mcp_configs={}) - logger.verbose_detail("No MCP dependencies found in apm.yml") - elif not should_install_mcp and old_mcp_servers: - # --only=apm: APM install regenerated the lockfile and dropped - # mcp_servers. Restore the previous set so it is not lost. - MCPIntegrator.update_lockfile(old_mcp_servers, _lock_path, mcp_configs=old_mcp_configs) - - # ------------------------------------------------------------------------- - # LSP integration (extracted to install/lsp/integration.py) - # ------------------------------------------------------------------------- - from apm_cli.install.lsp import run_lsp_integration - - lsp_count = run_lsp_integration( - apm_package=apm_package, - apm_modules_path=apm_modules_path, - lock_path=_lock_path, - existing_lock=_existing_lock, - project_root=ctx.project_root, - user_scope=(ctx.scope is InstallScope.USER), - should_install=should_install_lsp, - logger=logger, - diagnostics=apm_diagnostics, - target_context=(mcp_apm_config, ctx.target, ctx.scope), - ) - - # Local .apm/ content integration is now handled inside the - # install pipeline (phases/integrate.py + phases/post_deps_local.py, - # refactor F3). The duplicate target resolution, integrator - # initialization, and inline stale-cleanup block that lived here - # have been removed. - - return apm_count, mcp_count, lsp_count, apm_diagnostics - - -def _post_install_summary( - *, logger, apm_count, mcp_count, lsp_count=0, apm_diagnostics, force, elapsed_seconds=None -): - """Thin shim forwarding to :func:`apm_cli.install.summary.render_post_install_summary`. 
- - Kept as a module-level alias so existing tests that - ``@patch("apm_cli.commands.install._post_install_summary")`` continue - to work after the extraction (microsoft/apm#1116, F5). - """ - from apm_cli.install.summary import render_post_install_summary - - render_post_install_summary( - logger=logger, - apm_count=apm_count, - mcp_count=mcp_count, - lsp_count=lsp_count, - apm_diagnostics=apm_diagnostics, - force=force, - elapsed_seconds=elapsed_seconds, - ) - - -# --------------------------------------------------------------------------- -# Install engine -# --------------------------------------------------------------------------- - - -# Re-exports for backward compatibility -- the real implementations live -# in apm_cli.install.services (P1 -- DI seam). Tests that -# @patch("apm_cli.commands.install._integrate_package_primitives") still -# work because patching this module-level alias rebinds the name where -# call-sites in this module would look it up. Tests inside this codebase -# now patch the canonical apm_cli.install.services._integrate_package_primitives -# directly to avoid relying on transitive aliasing. - - -# --------------------------------------------------------------------------- -# Pipeline entry point -- thin re-export preserving the patch path -# ``apm_cli.commands.install._install_apm_dependencies`` used by tests. -# -# The real implementation lives in ``apm_cli.install.pipeline`` (F2). 
-# --------------------------------------------------------------------------- -def _install_apm_dependencies( # noqa: PLR0913 - apm_package: "APMPackage", - update_refs: bool = False, - verbose: bool = False, - only_packages: "builtins.list | None" = None, - force: bool = False, - parallel_downloads: int = 4, - logger: "InstallLogger" = None, - scope=None, - auth_resolver: "AuthResolver" = None, - target: str | None = None, - allow_insecure: bool = False, - allow_insecure_hosts=(), - marketplace_provenance: dict = None, - protocol_pref=None, - allow_protocol_fallback: "bool | None" = None, - no_policy: bool = False, - audit_override: "str | None" = None, - skill_subset: "builtins.tuple | None" = None, - skill_subset_from_cli: bool = False, - legacy_skill_paths: bool = False, - frozen: bool = False, - plan_callback=None, - refresh: bool = False, - lockfile_only: bool = False, -): - """Thin wrapper -- builds an :class:`InstallRequest` and delegates to - :class:`apm_cli.install.service.InstallService`. - - Kept here so that ``@patch("apm_cli.commands.install._install_apm_dependencies")`` - continues to intercept calls from the Click handler. The service - itself is the typed Application Service entry point for any future - programmatic callers. 
- """ - if not APM_DEPS_AVAILABLE: - raise RuntimeError("APM dependency system not available") - - from apm_cli.install.request import InstallRequest - from apm_cli.install.service import InstallService - - request = InstallRequest( - apm_package=apm_package, - update_refs=update_refs, - verbose=verbose, - only_packages=only_packages, - force=force, - parallel_downloads=parallel_downloads, - logger=logger, - scope=scope, - auth_resolver=auth_resolver, - target=target, - allow_insecure=allow_insecure, - allow_insecure_hosts=allow_insecure_hosts, - marketplace_provenance=marketplace_provenance, - protocol_pref=protocol_pref, - allow_protocol_fallback=allow_protocol_fallback, - no_policy=no_policy, - audit_override=audit_override, - skill_subset=skill_subset, - skill_subset_from_cli=skill_subset_from_cli, - legacy_skill_paths=legacy_skill_paths, - frozen=frozen, - plan_callback=plan_callback, - refresh=refresh, - lockfile_only=lockfile_only, - ) - return InstallService().run(request) diff --git a/src/apm_cli/install/apm_packages.py b/src/apm_cli/install/apm_packages.py new file mode 100644 index 000000000..eb3683d29 --- /dev/null +++ b/src/apm_cli/install/apm_packages.py @@ -0,0 +1,458 @@ +"""APM package install pipeline and post-install summary helpers.""" + +import builtins +import sys + +from apm_cli.constants import APM_YML_FILENAME, InstallMode +from apm_cli.install.errors import ( + AuthenticationError, + FrozenInstallError, + PolicyViolationError, +) +from apm_cli.install.insecure_policy import InsecureDependencyPolicyError + + +def _install_apm_packages(ctx, outcome): + """Execute the APM + transitive MCP installation pipeline. + + Parses ``apm.yml``, installs APM dependencies, collects and installs + transitive MCP servers, and handles lockfile updates. + + Args: + ctx: :class:`InstallContext` with configuration and environment. + outcome: ``_ValidationOutcome`` from package validation (may be + ``None`` when no explicit packages were passed). 
+ + Returns: + Tuple of ``(apm_count, mcp_count, lsp_count, apm_diagnostics)``. + """ + # RULE B: late import so @patch("apm_cli.commands.install.") intercepts + # calls to APMPackage, LockFile, _install_apm_dependencies, _maybe_rollback_manifest, + # _rich_*, _check_insecure_dependencies, migrate_lockfile_if_needed, + # get_lockfile_path, _project_has_root_primitives, APM_DEPS_AVAILABLE. + import apm_cli.commands.install as _m + + logger = ctx.logger + + logger.resolution_start( + to_install_count=len(ctx.only_packages or []) if ctx.packages else 0, + lockfile_count=0, # Refined later inside _install_apm_dependencies + ) + + # Parse apm.yml to get both APM and MCP dependencies + try: + apm_package = _m.APMPackage.from_apm_yml(ctx.manifest_path) + except Exception as e: + logger.error(f"Failed to parse {ctx.manifest_display}: {e}") + sys.exit(1) + + logger.verbose_detail( + f"Parsed {APM_YML_FILENAME}: {len(apm_package.get_apm_dependencies())} APM deps, " + f"{len(apm_package.get_mcp_dependencies())} MCP deps" + + ( + f", {len(apm_package.get_dev_apm_dependencies())} dev deps" + if apm_package.get_dev_apm_dependencies() + else "" + ) + ) + + # Get APM and MCP dependencies + apm_deps = apm_package.get_apm_dependencies() + dev_apm_deps = apm_package.get_dev_apm_dependencies() + has_any_apm_deps = bool(apm_deps) or bool(dev_apm_deps) + mcp_deps = apm_package.get_mcp_dependencies() + + all_apm_deps = builtins.list(apm_deps) + builtins.list(dev_apm_deps) + _m._check_insecure_dependencies(all_apm_deps, ctx.allow_insecure, logger) + + # Determine what to install based on install mode + should_install_apm = ctx.install_mode != InstallMode.MCP + should_install_mcp = ctx.install_mode != InstallMode.APM + should_install_lsp = should_install_mcp + + # Show what will be installed if dry run + if ctx.dry_run: + # -- W2-dry-run (#827): policy preflight in preview mode -- + # Runs discovery + checks against direct manifest deps (not + # resolved/transitive -- dry-run does not run 
the resolver). + # Block-severity violations render as "Would be blocked by + # policy" without raising. Documented limitation: transitive + # deps are NOT evaluated since the resolver does not run. + from apm_cli.policy.install_preflight import run_policy_preflight as _dr_preflight + + _dr_preflight( + project_root=ctx.project_root, + apm_deps=builtins.list(apm_deps) + builtins.list(dev_apm_deps), + mcp_deps=mcp_deps if should_install_mcp else None, + no_policy=ctx.no_policy, + logger=logger, + dry_run=True, + ) + + from apm_cli.install.presentation.dry_run import render_and_exit + + render_and_exit( + logger=logger, + should_install_apm=should_install_apm, + apm_deps=apm_deps, + mcp_deps=mcp_deps, + dev_apm_deps=dev_apm_deps, + should_install_mcp=should_install_mcp, + update=ctx.update, + only_packages=ctx.only_packages, + apm_dir=ctx.apm_dir, + ) + return 0, 0, 0, None # render_and_exit exits; this line is defensive + + # Install APM dependencies first (if requested) + apm_count = 0 + + # Migrate legacy apm.lock -> apm.lock.yaml if needed (one-time, transparent) + _m.migrate_lockfile_if_needed(ctx.apm_dir) + + # Capture old MCP servers and configs from lockfile BEFORE + # _install_apm_dependencies regenerates it (which drops the fields). + # We always read this -- even when --only=apm -- so we can restore the + # field after the lockfile is regenerated by the APM install step. + old_mcp_servers: builtins.set = builtins.set() + old_mcp_configs: builtins.dict = {} + _lock_path = _m.get_lockfile_path(ctx.apm_dir) + _existing_lock = _m.LockFile.read(_lock_path) + if _existing_lock: + old_mcp_servers = builtins.set(_existing_lock.mcp_servers) + old_mcp_configs = builtins.dict(_existing_lock.mcp_configs) + + # Enter the APM install path when there are deps, local .apm/ primitives + # (#714), OR orphan deps in the lockfile to clean up (manifest emptied). 
+ from apm_cli.core.scope import InstallScope + from apm_cli.core.scope import get_deploy_root as _get_deploy_root + from apm_cli.deps.lockfile import _SELF_KEY as _LOCK_SELF_KEY + + _cli_project_root = _get_deploy_root(ctx.scope) + _has_orphan_deps_in_lock = bool( + _existing_lock + and not has_any_apm_deps + and any(k != _LOCK_SELF_KEY for k in _existing_lock.dependencies) + ) + apm_diagnostics = None + if should_install_apm and ( + has_any_apm_deps + or _m._project_has_root_primitives(_cli_project_root) + or _has_orphan_deps_in_lock + ): + if not _m.APM_DEPS_AVAILABLE: + logger.error("APM dependency system not available") + logger.progress(f"Import error: {_m._APM_IMPORT_ERROR}") + sys.exit(1) + + try: + # If specific packages were requested, only install those + # Otherwise install all from apm.yml. + # `only_packages` was computed above so the dry-run preview + # and the actual install share one canonical list. + install_result = _m._install_apm_dependencies( + apm_package, + ctx.update, + ctx.verbose, + ctx.only_packages, + force=ctx.force, + parallel_downloads=ctx.parallel_downloads, + logger=logger, + scope=ctx.scope, + auth_resolver=ctx.auth_resolver, + target=ctx.target, + allow_insecure=ctx.allow_insecure, + allow_insecure_hosts=ctx.allow_insecure_hosts, + marketplace_provenance=( + outcome.marketplace_provenance if ctx.packages and outcome else None + ), + protocol_pref=ctx.protocol_pref, + allow_protocol_fallback=ctx.allow_protocol_fallback, + no_policy=ctx.no_policy, + audit_override=ctx.audit_override, + legacy_skill_paths=ctx.legacy_skill_paths, + frozen=ctx.frozen, + plan_callback=ctx.plan_callback, + skill_subset=ctx.skill_subset, + skill_subset_from_cli=ctx.skill_subset_from_cli, + refresh=ctx.refresh, + ) + apm_count = install_result.installed_count + apm_diagnostics = install_result.diagnostics + except InsecureDependencyPolicyError: + _m._maybe_rollback_manifest(ctx.snapshot_manifest_path, ctx.manifest_snapshot, logger) + sys.exit(1) + except 
AuthenticationError as e: + # #1015: render auth diagnostics on the DEFAULT path (not --verbose). + _m._maybe_rollback_manifest(ctx.snapshot_manifest_path, ctx.manifest_snapshot, logger) + _m._rich_error(str(e)) + if e.diagnostic_context: + _m._rich_echo(e.diagnostic_context) + _m._rich_info("Tip: run 'apm doctor' to diagnose auth and connectivity.", symbol="info") + sys.exit(1) + except FrozenInstallError as e: + _m._maybe_rollback_manifest(ctx.snapshot_manifest_path, ctx.manifest_snapshot, logger) + _m._rich_error(str(e)) + for reason in e.reasons: + _m._rich_echo(reason) + _m._rich_info( + "Tip: run 'apm outdated' to see what changed, then 'apm update'.", + symbol="info", + ) + sys.exit(1) + except Exception as e: + _m._maybe_rollback_manifest(ctx.snapshot_manifest_path, ctx.manifest_snapshot, logger) + # #832: surface PolicyViolationError verbatim (no double-nesting). + msg = ( + str(e) + if isinstance(e, PolicyViolationError) + else f"Failed to install APM dependencies: {e}" + ) + logger.error(msg) + if not ctx.verbose: + logger.progress("Run with --verbose for detailed diagnostics") + sys.exit(1) + elif should_install_apm and not has_any_apm_deps: + logger.verbose_detail("No APM dependencies found in apm.yml") + + # When --update is used, package files on disk may have changed. + # Clear the parse cache so transitive MCP collection reads fresh data. 
+ if ctx.update: + from apm_cli.models.apm_package import clear_apm_yml_cache + + clear_apm_yml_cache() + + # Collect transitive MCP dependencies from resolved APM packages + transitive_mcp = [] + from apm_cli.core.scope import get_modules_dir + from apm_cli.integration.mcp_integrator import MCPIntegrator + + apm_modules_path = get_modules_dir(ctx.scope) + if should_install_mcp and apm_modules_path.exists(): + lock_path = _m.get_lockfile_path(ctx.apm_dir) + transitive_mcp = MCPIntegrator.collect_transitive( + apm_modules_path, + lock_path, + ctx.trust_transitive_mcp, + diagnostics=apm_diagnostics, + ) + if transitive_mcp: + logger.verbose_detail(f"Collected {len(transitive_mcp)} transitive MCP dependency(ies)") + mcp_deps = MCPIntegrator.deduplicate(mcp_deps + transitive_mcp) + + # -- S1/S2 fix (#827-C2/C3): enforce policy on ALL MCP deps ---- + # The pipeline gate phase (policy_gate.py) checks direct APM deps + # and direct MCP deps from apm.yml. However, transitive MCP + # servers (discovered via collect_transitive above) are only known + # after APM packages are installed. Run a second preflight + # against the *merged* MCP set (direct + transitive) BEFORE + # MCPIntegrator writes runtime configs. On PolicyBlockError we + # abort the MCP write but leave already-installed APM packages + # in place (they were approved by the gate phase). + if should_install_mcp and mcp_deps: + from apm_cli.policy.install_preflight import ( + PolicyBlockError as _TransitivePBE, + ) + from apm_cli.policy.install_preflight import ( + run_policy_preflight as _transitive_preflight, + ) + + try: + _transitive_preflight( + project_root=ctx.project_root, + mcp_deps=mcp_deps, + no_policy=ctx.no_policy, + logger=logger, + dry_run=False, + ) + except _TransitivePBE: + logger.error( + "MCP server(s) blocked by org policy. " + "APM packages remain installed; MCP configs were NOT written." 
+ ) + logger.render_summary() + sys.exit(1) + + # Continue with MCP installation (existing logic) + mcp_count = 0 + new_mcp_servers: builtins.set = builtins.set() + # Forward only the targets-key the user actually declared so parse_targets_field + # in the gate sees the same dict shape it sees from raw apm.yml. Including a + # `targets: None` placeholder when the user wrote `target:` (singular) would + # falsely trip the conflict-mutex check (see core.apm_yml.parse_targets_field). + # This restores parity with `apm install` for users on the modern `targets:` + # plural form -- without this, `targets:` was silently dropped at the call + # site and the gate fell back to permissive directory detection (#1335). + mcp_apm_config: dict = {"scripts": apm_package.scripts or {}} + if apm_package.targets is not None: + mcp_apm_config["targets"] = apm_package.targets + elif apm_package.target is not None: + mcp_apm_config["target"] = apm_package.target + if should_install_mcp and mcp_deps: + mcp_count = MCPIntegrator.install( + mcp_deps, + ctx.runtime, + ctx.exclude, + ctx.verbose, + stored_mcp_configs=old_mcp_configs, + apm_config=mcp_apm_config, + project_root=ctx.project_root, + user_scope=(ctx.scope is InstallScope.USER), + explicit_target=ctx.target, + diagnostics=apm_diagnostics, + scope=ctx.scope, + ) + new_mcp_servers = MCPIntegrator.get_server_names(mcp_deps) + new_mcp_configs = MCPIntegrator.get_server_configs(mcp_deps) + + # Remove stale MCP servers that are no longer needed + stale_servers = old_mcp_servers - new_mcp_servers + if stale_servers: + MCPIntegrator.remove_stale( + stale_servers, + ctx.runtime, + ctx.exclude, + project_root=ctx.project_root, + user_scope=(ctx.scope is InstallScope.USER), + scope=ctx.scope, + ) + + # Persist the new MCP server set and configs in the lockfile + MCPIntegrator.update_lockfile(new_mcp_servers, _lock_path, mcp_configs=new_mcp_configs) + elif should_install_mcp and not mcp_deps: + # No MCP deps at all -- remove any old 
APM-managed servers + if old_mcp_servers: + MCPIntegrator.remove_stale( + old_mcp_servers, + ctx.runtime, + ctx.exclude, + project_root=ctx.project_root, + user_scope=(ctx.scope is InstallScope.USER), + scope=ctx.scope, + ) + MCPIntegrator.update_lockfile(builtins.set(), _lock_path, mcp_configs={}) + logger.verbose_detail("No MCP dependencies found in apm.yml") + elif not should_install_mcp and old_mcp_servers: + # --only=apm: APM install regenerated the lockfile and dropped + # mcp_servers. Restore the previous set so it is not lost. + MCPIntegrator.update_lockfile(old_mcp_servers, _lock_path, mcp_configs=old_mcp_configs) + + # ------------------------------------------------------------------------- + # LSP integration (extracted to install/lsp/integration.py) + # ------------------------------------------------------------------------- + from apm_cli.install.lsp import run_lsp_integration + + lsp_count = run_lsp_integration( + apm_package=apm_package, + apm_modules_path=apm_modules_path, + lock_path=_lock_path, + existing_lock=_existing_lock, + project_root=ctx.project_root, + user_scope=(ctx.scope is InstallScope.USER), + should_install=should_install_lsp, + logger=logger, + diagnostics=apm_diagnostics, + target_context=(mcp_apm_config, ctx.target, ctx.scope), + ) + + # Local .apm/ content integration is now handled inside the + # install pipeline (phases/integrate.py + phases/post_deps_local.py, + # refactor F3). The duplicate target resolution, integrator + # initialization, and inline stale-cleanup block that lived here + # have been removed. + + return apm_count, mcp_count, lsp_count, apm_diagnostics + + +def _post_install_summary( + *, logger, apm_count, mcp_count, lsp_count=0, apm_diagnostics, force, elapsed_seconds=None +): + """Thin shim forwarding to :func:`apm_cli.install.summary.render_post_install_summary`. 
+ + Kept as a module-level alias so existing tests that + ``@patch("apm_cli.commands.install._post_install_summary")`` continue + to work after the extraction (microsoft/apm#1116, F5). + """ + from apm_cli.install.summary import render_post_install_summary + + render_post_install_summary( + logger=logger, + apm_count=apm_count, + mcp_count=mcp_count, + lsp_count=lsp_count, + apm_diagnostics=apm_diagnostics, + force=force, + elapsed_seconds=elapsed_seconds, + ) + + +def _install_apm_dependencies( # noqa: PLR0913 + apm_package, + update_refs: bool = False, + verbose: bool = False, + only_packages=None, + force: bool = False, + parallel_downloads: int = 4, + logger=None, + scope=None, + auth_resolver=None, + target=None, + allow_insecure: bool = False, + allow_insecure_hosts=(), + marketplace_provenance=None, + protocol_pref=None, + allow_protocol_fallback=None, + no_policy: bool = False, + audit_override=None, + skill_subset=None, + skill_subset_from_cli: bool = False, + legacy_skill_paths: bool = False, + frozen: bool = False, + plan_callback=None, + refresh: bool = False, + lockfile_only: bool = False, +): + """Thin wrapper -- builds an :class:`InstallRequest` and delegates to + :class:`apm_cli.install.service.InstallService`. + + Kept here so that ``@patch("apm_cli.commands.install._install_apm_dependencies")`` + continues to intercept calls from the Click handler. The service + itself is the typed Application Service entry point for any future + programmatic callers. + """ + # RULE B: access APM_DEPS_AVAILABLE via install module so patches are honoured. 
+ import apm_cli.commands.install as _m + + if not _m.APM_DEPS_AVAILABLE: + raise RuntimeError("APM dependency system not available") + + from apm_cli.install.request import InstallRequest + from apm_cli.install.service import InstallService + + request = InstallRequest( + apm_package=apm_package, + update_refs=update_refs, + verbose=verbose, + only_packages=only_packages, + force=force, + parallel_downloads=parallel_downloads, + logger=logger, + scope=scope, + auth_resolver=auth_resolver, + target=target, + allow_insecure=allow_insecure, + allow_insecure_hosts=allow_insecure_hosts, + marketplace_provenance=marketplace_provenance, + protocol_pref=protocol_pref, + allow_protocol_fallback=allow_protocol_fallback, + no_policy=no_policy, + audit_override=audit_override, + skill_subset=skill_subset, + skill_subset_from_cli=skill_subset_from_cli, + legacy_skill_paths=legacy_skill_paths, + frozen=frozen, + plan_callback=plan_callback, + refresh=refresh, + lockfile_only=lockfile_only, + ) + return InstallService().run(request) diff --git a/src/apm_cli/install/cli_context.py b/src/apm_cli/install/cli_context.py new file mode 100644 index 000000000..70d152452 --- /dev/null +++ b/src/apm_cli/install/cli_context.py @@ -0,0 +1,56 @@ +"""InstallContext dataclass: parameter bundle for the APM install CLI command.""" + +import builtins +import dataclasses +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from collections.abc import Callable + + from apm_cli.install.plan import UpdatePlan + + +@dataclasses.dataclass +class InstallContext: + """Bundles install command state to reduce function signatures. + + Created by :func:`install` after argument parsing and scope resolution, + then threaded through :func:`_install_apm_packages` and + :func:`_post_install_summary` to avoid long parameter lists. 
+ """ + + scope: Any # InstallScope + manifest_path: "Path" + manifest_display: str + apm_dir: "Path" + project_root: "Path" + logger: Any # InstallLogger + auth_resolver: Any # AuthResolver + verbose: bool + force: bool + dry_run: bool + update: bool + dev: bool + runtime: str | None + exclude: str | None + target: str | None + parallel_downloads: int + allow_insecure: bool + allow_insecure_hosts: tuple + protocol_pref: Any # ProtocolPreference + allow_protocol_fallback: bool + trust_transitive_mcp: bool + no_policy: bool + install_mode: Any # InstallMode + packages: tuple # Original Click packages + refresh: bool = False + only_packages: builtins.list | None = None + manifest_snapshot: bytes | None = None + snapshot_manifest_path: Optional["Path"] = None + legacy_skill_paths: bool = False + frozen: bool = False + plan_callback: "Callable[[UpdatePlan], bool] | None" = None + skill_subset: "builtins.tuple[str, ...] | None" = None + skill_subset_from_cli: bool = False + audit_override: str | None = None diff --git a/src/apm_cli/install/install_cmd_phases.py b/src/apm_cli/install/install_cmd_phases.py new file mode 100644 index 000000000..51fb88ec4 --- /dev/null +++ b/src/apm_cli/install/install_cmd_phases.py @@ -0,0 +1,208 @@ +"""Extracted install command phase helpers to reduce install() complexity.""" + +import builtins +import sys +import time +from pathlib import Path + +import click + + +def _try_local_bundle_install( + packages, + mcp_name, + target, + global_, + force, + dry_run, + verbose, + alias, + logger, + legacy_skill_paths, + rejected_flags, +): + """Detect and install a local bundle; return True if handled (caller should return). + + Raises :class:`click.UsageError` for unrecognised tarballs. Falls through + (returns False) when the sole positional arg is not a recognised bundle. 
+ """ + if not (len(packages) == 1 and not mcp_name and (_probe := Path(packages[0])).exists()): + return False + + from apm_cli.bundle.local_bundle import detect_local_bundle as _detect_lb + from apm_cli.install.local_bundle_handler import install_local_bundle as _install_lb + + _bundle_info = _detect_lb(_probe) + if _bundle_info is not None: + _install_lb( + bundle_info=_bundle_info, + bundle_arg=packages[0], + target=target, + global_=global_, + force=force, + dry_run=dry_run, + verbose=verbose, + alias=alias, + logger=logger, + legacy_skill_paths=legacy_skill_paths, + rejected_flags=rejected_flags, + ) + return True + # IM7: path exists but isn't a recognised bundle. For tarball extensions + # (.tar.gz / .tgz) the user clearly meant a bundle artifact. + _suffix = _probe.name.lower() + if _probe.is_file() and (_suffix.endswith(".tar.gz") or _suffix.endswith(".tgz")): + from apm_cli.bundle.local_bundle import _looks_like_legacy_apm_bundle + + if _looks_like_legacy_apm_bundle(_probe): + raise click.UsageError( + f"'{packages[0]}' was packed with '--format apm' (legacy format). " + "'apm install ' requires the plugin format. " + "Repack with 'apm pack --format plugin --archive', " + "or use 'apm unpack' to deploy the legacy bundle." + ) + raise click.UsageError( + f"'{packages[0]}' is not a valid APM bundle archive " + "(no plugin.json found at the bundle root). " + "Use 'apm install org/package' for registry installs, " + "or repack the source with 'apm pack'." + ) + return False + + +def _resolve_protocol_and_fallback(use_ssh, use_https, allow_protocol_fallback): + """Return ``(protocol_pref, allow_protocol_fallback)`` or ``(None, None)`` on conflict. + + Caller must check for ``(None, None)`` and emit the mutual-exclusion error. 
+ """ + from apm_cli.config import get_apm_allow_protocol_fallback, get_apm_protocol_pref + from apm_cli.deps.transport_selection import ProtocolPreference + + if use_ssh and use_https: + return None, None + if use_ssh: + pref = ProtocolPreference.SSH + elif use_https: + pref = ProtocolPreference.HTTPS + else: + pref = ProtocolPreference.from_str(get_apm_protocol_pref()) + fallback = allow_protocol_fallback or get_apm_allow_protocol_fallback() + return pref, fallback + + +def _resolve_scope_and_paths(global_, logger): + """Resolve install scope, paths, and scope-specific warnings. + + Returns ``(scope, manifest_path, apm_dir, manifest_display, project_root)``. + """ + from apm_cli.constants import APM_YML_FILENAME + from apm_cli.core.scope import ( + InstallScope, + ensure_user_dirs, + get_apm_dir, + get_deploy_root, + get_manifest_path, + warn_unsupported_user_scope, + ) + + scope = InstallScope.USER if global_ else InstallScope.PROJECT + if scope is InstallScope.USER: + ensure_user_dirs() + logger.progress("Installing to user scope (~/.apm/)") + _scope_warn = warn_unsupported_user_scope() + if _scope_warn: + logger.warning(_scope_warn) + + manifest_path = get_manifest_path(scope) + apm_dir = get_apm_dir(scope) + manifest_display = str(manifest_path) if scope is InstallScope.USER else APM_YML_FILENAME + project_root = get_deploy_root(scope) + return scope, manifest_path, apm_dir, manifest_display, project_root + + +def _setup_auth_and_check_manifest(scope, packages, manifest_path, manifest_display, logger): + """Create shared :class:`AuthResolver`; bootstrap or error-check ``apm.yml``. + + May call ``sys.exit(1)`` when no manifest exists and no packages are given. + Returns the newly constructed :class:`AuthResolver` instance. + """ + # RULE B: AuthResolver is patched at apm_cli.commands.install.AuthResolver in tests. 
+ import apm_cli.commands.install as _m + from apm_cli.commands._helpers import _create_minimal_apm_yml, _get_default_config + from apm_cli.core.scope import InstallScope + + auth_resolver = _m.AuthResolver() + # F2/F3 #856: thread the InstallLogger into AuthResolver so the verbose + # auth-source line and the deferred stale-PAT [!] warning route through + # CommandLogger / DiagnosticCollector instead of stderr/inline writes. + auth_resolver.set_logger(logger) + + apm_yml_exists = manifest_path.exists() + + # Auto-bootstrap: create minimal apm.yml when packages specified but no apm.yml + if not apm_yml_exists and packages: + project_name = Path.cwd().name if scope is InstallScope.PROJECT else Path.home().name + config = _get_default_config(project_name) + _create_minimal_apm_yml(config, target_path=manifest_path) + logger.success(f"Created {manifest_display}") + + # Error when NO apm.yml AND NO packages + if not apm_yml_exists and not packages: + logger.error(f"No {manifest_display} found") + if scope is InstallScope.USER: + logger.progress("Run 'apm install -g ' to auto-create + install") + else: + logger.progress("Run 'apm init' to create one, or:") + logger.progress(" apm install to auto-create + install") + sys.exit(1) + + return auth_resolver + + +def _execute_install_and_summary(install_ctx, outcome, frozen, install_started_at): + """Run the install pipeline and emit the post-install summary. + + Returns ``apm_count`` (number of installed APM packages). + """ + # RULE B: _install_apm_packages, _post_install_summary, _rich_info are patched + # at apm_cli.commands.install.* in tests. 
+ import apm_cli.commands.install as _m + + apm_count, mcp_count, lsp_count, apm_diagnostics = _m._install_apm_packages( + install_ctx, outcome + ) + _m._post_install_summary( + logger=install_ctx.logger, + apm_count=apm_count, + mcp_count=mcp_count, + lsp_count=lsp_count, + apm_diagnostics=apm_diagnostics, + force=install_ctx.force, + elapsed_seconds=time.perf_counter() - install_started_at, + ) + if frozen and apm_count > 0: + # --frozen verifies LOCKFILE STRUCTURE (every apm.yml dep has a lock entry), + # not on-disk content integrity. + _m._rich_info( + "Lockfile presence verified. Run 'apm audit' for on-disk content integrity.", + symbol="info", + ) + return apm_count + + +def _compute_argv_pre_dash(packages): + """Return ``(command_argv, pre_dash_packages)`` by splitting sys.argv at ``--``. + + Uses RULE B so ``@patch("apm_cli.commands.install._split_argv_at_double_dash")`` + and ``@patch("apm_cli.commands.install._get_invocation_argv")`` still work. + """ + # RULE B: argv helpers are @patch targets in tests. + import apm_cli.commands.install as _m + + _, command_argv = _m._split_argv_at_double_dash(_m._get_invocation_argv()) + if command_argv: + split_idx = max(len(packages) - len(command_argv), 0) + pre_dash_packages = builtins.tuple(packages[:split_idx]) + else: + pre_dash_packages = builtins.tuple(packages) + return command_argv, pre_dash_packages diff --git a/src/apm_cli/install/manifest_rollback.py b/src/apm_cli/install/manifest_rollback.py new file mode 100644 index 000000000..93f52388b --- /dev/null +++ b/src/apm_cli/install/manifest_rollback.py @@ -0,0 +1,59 @@ +"""Manifest snapshot and rollback helpers for APM install.""" + +import contextlib +import os +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from apm_cli.core.command_logger import InstallLogger + + +def _restore_manifest_from_snapshot( + manifest_path: "Path", + snapshot: bytes, +) -> None: + """Atomically restore ``apm.yml`` from a raw-bytes snapshot. 
+ + Uses temp-file + ``os.replace`` to avoid torn writes, mirroring the + W1 cache atomic-write pattern (``discovery.py``). + """ + import tempfile + + fd, tmp_name = tempfile.mkstemp( + prefix="apm-restore-", + dir=str(manifest_path.parent), + ) + try: + with os.fdopen(fd, "wb") as fh: + fh.write(snapshot) + os.replace(tmp_name, str(manifest_path)) + except Exception: + with contextlib.suppress(OSError): + os.unlink(tmp_name) + raise + + +def _maybe_rollback_manifest( + manifest_path: "Path", + snapshot: "bytes | None", + logger: "InstallLogger", +) -> None: + """Restore ``apm.yml`` from *snapshot* if one was captured, then log. + + No-op when *snapshot* is ``None`` (i.e. the command was not + ``apm install `` or the manifest did not exist before mutation). + """ + if snapshot is None: + return + # RULE B: _restore_manifest_from_snapshot is patched at + # apm_cli.commands.install.* in tests that exercise this function. + import apm_cli.commands.install as _m + + try: + _m._restore_manifest_from_snapshot(manifest_path, snapshot) + logger.progress("apm.yml restored to its previous state.") + except Exception: + # Best-effort: if the restore itself fails, warn but don't mask + # the original exception that triggered the rollback. 
+ logger.warning("Failed to restore apm.yml to its previous state.") diff --git a/src/apm_cli/install/mcp_handler.py b/src/apm_cli/install/mcp_handler.py new file mode 100644 index 000000000..0dc3d88a5 --- /dev/null +++ b/src/apm_cli/install/mcp_handler.py @@ -0,0 +1,134 @@ +"""MCP install handler for the APM CLI install command.""" + +import dataclasses +import sys +from pathlib import Path + + +@dataclasses.dataclass +class _McpConnectionParams: + """Groups MCP server connection/definition params to reduce install arg count.""" + + transport: str | None + url: str | None + env_pairs: tuple + header_pairs: tuple + mcp_version: str | None + + +def _handle_mcp_install( + *, + mcp_name, + mcp_conn: _McpConnectionParams, + command_argv, + dev, + force, + runtime, + exclude, + verbose, + logger, + no_policy, + validated_registry_url, +): + """Execute the ``--mcp`` install path (MCP server add). + + Resolves registry URL, runs policy preflight, handles dry-run, + and delegates to :func:`_run_mcp_install` for the actual installation. + Called from :func:`install` when ``--mcp`` is specified; the caller + returns immediately after this function completes. + """ + # RULE B: _run_mcp_install is patched at apm_cli.commands.install.* in tests. + import apm_cli.commands.install as _m + from apm_cli.install.mcp.registry import resolve_registry_url as _resolve_registry_url + from apm_cli.install.mcp.registry import ( + validate_mcp_dry_run_entry as _validate_mcp_dry_run_entry, + ) + + transport = mcp_conn.transport + url = mcp_conn.url + env_pairs = mcp_conn.env_pairs + header_pairs = mcp_conn.header_pairs + mcp_version = mcp_conn.mcp_version + + from ..core.scope import ( + InstallScope, + get_apm_dir, + get_manifest_path, + ) + + # Apply CLI > env > default precedence; emit override diagnostic. 
+ resolved_registry_url, _registry_source = _resolve_registry_url( + validated_registry_url, + logger=logger, + ) + mcp_scope = InstallScope.PROJECT + mcp_manifest_path = get_manifest_path(mcp_scope) + mcp_apm_dir = get_apm_dir(mcp_scope) + # -- W2-mcp-preflight: policy enforcement before MCP install -- + # Build a lightweight MCPDependency for policy evaluation. + # This mirrors _build_mcp_entry routing but we only need the + # fields that policy checks inspect (name, transport, registry). + from ..models.dependency.mcp import MCPDependency as _MCPDep + from ..policy.install_preflight import ( + PolicyBlockError, + run_policy_preflight, + ) + + _is_self_defined = bool(url or command_argv) + _preflight_transport = transport + if _preflight_transport is None: + if command_argv: + _preflight_transport = "stdio" + elif url: + _preflight_transport = "http" + _preflight_dep = _MCPDep( + name=mcp_name, + transport=_preflight_transport, + registry=False if _is_self_defined else None, + url=url, + ) + + try: + _pf_result, _pf_active = run_policy_preflight( + project_root=Path.cwd(), + mcp_deps=[_preflight_dep], + no_policy=no_policy, + logger=logger, + dry_run=logger.dry_run, + ) + except PolicyBlockError: + # Diagnostics already emitted by the helper + logger. + logger.render_summary() + sys.exit(1) + + if logger.dry_run: + # C1: validate eagerly so dry-run rejects what real install would. 
+ _validate_mcp_dry_run_entry( + mcp_name, + transport=transport, + url=url, + env=env_pairs, + headers=header_pairs, + version=mcp_version, + command_argv=command_argv, + registry_url=resolved_registry_url, + ) + logger.dry_run_notice(f"would add MCP server '{mcp_name}' to {mcp_manifest_path}") + return + _m._run_mcp_install( + mcp_name=mcp_name, + transport=transport, + url=url, + env_pairs=env_pairs, + header_pairs=header_pairs, + mcp_version=mcp_version, + command_argv=command_argv, + dev=dev, + force=force, + runtime=runtime, + exclude=exclude, + logger=logger, + apm_dir=mcp_apm_dir, + scope=mcp_scope, + registry_url=validated_registry_url, + ) diff --git a/src/apm_cli/install/pkg_resolution.py b/src/apm_cli/install/pkg_resolution.py new file mode 100644 index 000000000..cbb4a0259 --- /dev/null +++ b/src/apm_cli/install/pkg_resolution.py @@ -0,0 +1,486 @@ +"""Package validation and manifest update helpers for the APM install command.""" + +import builtins +import sys + +from apm_cli.constants import APM_YML_FILENAME +from apm_cli.install.artifactory_resolver import _resolve_artifactory_boundary +from apm_cli.install.insecure_policy import ( + _format_insecure_dependency_requirements, + _get_insecure_dependency_url, +) +from apm_cli.install.package_resolution import ( + dependency_reference_to_yaml_entry, + update_existing_dependency_entry_if_needed, +) +from apm_cli.install.validation import _local_path_failure_reason + + +def _check_package_conflicts(current_deps): + """Build identity set from existing deps for duplicate detection. + + Parses each entry in *current_deps* (string or dict form) through + :class:`DependencyReference` and collects identity strings. + + Returns: + ``set`` of identity strings for existing dependencies. + """ + # RULE B: DependencyReference is patched at apm_cli.commands.install.* in tests. 
+ import apm_cli.commands.install as _m + + DependencyReference = _m.DependencyReference + + existing_identities = builtins.set() + for dep_entry in current_deps: + try: + if isinstance(dep_entry, str): + ref = DependencyReference.parse(dep_entry) + elif isinstance(dep_entry, builtins.dict): + ref = DependencyReference.parse_from_dict(dep_entry) + else: + continue + existing_identities.add(ref.get_identity()) + except (ValueError, TypeError, AttributeError, KeyError): + continue + return existing_identities + + +def _resolve_package_references( + packages, + current_deps, + existing_identities, + *, + auth_resolver=None, + logger=None, + scope=None, + allow_insecure=False, + skill_subset=None, +): + """Validate, canonicalize, and resolve package references. + + Handles marketplace refs, canonical parsing, insecure-URL guards, + local-at-user-scope rejection, and accessibility checks. + + *existing_identities* is mutated (new identities are added to prevent + duplicates within the same batch). + + Returns: + Tuple of ``(valid_outcomes, invalid_outcomes, validated_packages, + marketplace_provenance, apm_yml_entries, dependencies_changed)``. + """ + # RULE B: DependencyReference and _validate_package_exists are patched at + # apm_cli.commands.install.* in tests that call this function directly. 
+ import apm_cli.commands.install as _m + + DependencyReference = _m.DependencyReference + _vpe = _m._validate_package_exists + + valid_outcomes = [] # (canonical, already_present) tuples + invalid_outcomes = [] # (package, reason) tuples + _marketplace_provenance = {} # canonical -> {discovered_via, marketplace_plugin_name} + _apm_yml_entries = {} # canonical -> apm.yml entry (str or dict for HTTP deps) + validated_packages = [] + dependencies_changed = False + + if logger: + logger.validation_start(len(packages)) + + for package in packages: + # --- Marketplace pre-parse intercept --- + # If input has no slash and is not a local path, check if it is a + # marketplace ref (NAME@MARKETPLACE). If so, resolve it to a + # canonical owner/repo[#ref] string before entering the standard + # parse path. Anything that doesn't match is rejected as an + # invalid format. + marketplace_provenance = None + marketplace_dep_ref = None + if "/" not in package and not DependencyReference.is_local_path(package): + try: + from apm_cli.marketplace.resolver import ( + parse_marketplace_ref, + resolve_marketplace_plugin, + ) + + mkt_ref = parse_marketplace_ref(package) + except ImportError: + mkt_ref = None + + if mkt_ref is not None: + plugin_name, marketplace_name, version_spec = mkt_ref + try: + warning_handler = None + if logger: + + def warning_handler(msg): + return logger.warning(msg) + + logger.verbose_detail( + f" Resolving {plugin_name}@{marketplace_name} via marketplace..." + ) + resolution = resolve_marketplace_plugin( + plugin_name, + marketplace_name, + version_spec=version_spec, + auth_resolver=auth_resolver, + warning_handler=warning_handler, + ) + canonical_str, _resolved_plugin = resolution + if logger: + logger.verbose_detail(f" Resolved to: {canonical_str}") + # #1326: dependency-confusion fail-closed gate. 
+ # Bare ``owner/repo`` on *.ghe.com falls back to + # github.com -- refuse before outbound validation so + # no probe reaches a potentially attacker-controlled URL. + # Escape hatch: host-qualify ``repo:`` in marketplace.json. + _risk = resolution.cross_repo_misconfig_risk + if _risk is not None: + _lead = ( + f"refused (dependency-confusion risk #1326): bare" + f" `repo: {_risk.bare_repo_field}` on enterprise" + f" marketplace '{_risk.marketplace_host}' is ambiguous." + f" Host-qualify the plugin `repo` field in" + f" marketplace.json to one of:" + ) + reason = "\n".join( + [ + _lead, + f" - '{_risk.suggested_qualified_repo}' (enterprise dep on this marketplace)", + f" - 'github.com/{_risk.bare_repo_field}' (declared cross-host dep on public github.com)", + ] + ) + invalid_outcomes.append((package, reason)) + if logger: + logger.validation_fail(package, reason) + continue + marketplace_provenance = { + "discovered_via": marketplace_name, + "marketplace_plugin_name": plugin_name, + } + package = canonical_str + marketplace_dep_ref = getattr(resolution, "dependency_reference", None) + except Exception as mkt_err: + reason = str(mkt_err) + invalid_outcomes.append((package, reason)) + if logger: + logger.validation_fail(package, reason) + continue + else: + # No slash, not a local path, and not a marketplace ref + reason = "invalid format -- use 'owner/repo' or 'plugin-name@marketplace'" + invalid_outcomes.append((package, reason)) + if logger: + logger.validation_fail(package, reason) + continue + + # Canonicalize input + try: + dep_ref, direct_virtual_resolved = _m.resolve_parsed_dependency_reference( + package, + marketplace_dep_ref, + dependency_reference_cls=DependencyReference, + try_resolve_gitlab_direct_shorthand=_m._try_resolve_gitlab_direct_shorthand, + resolve_artifactory_boundary=_resolve_artifactory_boundary, + auth_resolver=auth_resolver, + verbose=bool(logger and logger.verbose), + logger=logger, + ) + canonical = dep_ref.to_canonical() + identity = 
dep_ref.get_identity() + # Attach --skill filter so to_apm_yml_entry() emits the dict form + if skill_subset: + # Normalize: strip whitespace, drop empty strings, deduplicate + # (preserve order) so invalid or redundant names can't persist. + _seen: builtins.set[str] = builtins.set() + _normalized: builtins.list[str] = [] + for _s in skill_subset: + _s = _s.strip() + if _s and _s not in _seen: + _seen.add(_s) + _normalized.append(_s) + dep_ref.skill_subset = _normalized + if marketplace_dep_ref is not None or direct_virtual_resolved: + _apm_yml_entries[canonical] = dependency_reference_to_yaml_entry(dep_ref) + except ValueError as e: + reason = str(e) + invalid_outcomes.append((package, reason)) + if logger: + logger.validation_fail(package, reason) + continue + + if dep_ref.is_insecure: + if not allow_insecure: + # The reason string embeds the full URL already, so skip + # logger.validation_fail (which prepends "{package} -- ") to + # avoid rendering the URL twice. Use logger.error directly. + reason = _format_insecure_dependency_requirements( + _get_insecure_dependency_url(dep_ref) + ) + invalid_outcomes.append((package, reason)) + if logger: + logger.error(reason) + continue + dep_ref.allow_insecure = True + _apm_yml_entries[canonical] = dep_ref.to_apm_yml_entry() + + scope_reject = _m.user_scope_rejection_reason(dep_ref, scope) + if scope_reject: + invalid_outcomes.append((package, scope_reject)) + if logger: + logger.validation_fail(package, scope_reject) + continue + + # Ensure structured entry is used for apm.yml persistence when skill + # filter is active (normal non-marketplace/non-insecure path doesn't + # set _apm_yml_entries; _merge_packages_into_yml falls back to the + # plain canonical string without this). 
+ if skill_subset and canonical not in _apm_yml_entries: + _apm_yml_entries[canonical] = dep_ref.to_apm_yml_entry() + + # Check if package is already in dependencies (by identity) + already_in_deps = identity in existing_identities + + # Validate package exists and is accessible + verbose = bool(logger and logger.verbose) + if _vpe( + package, + verbose=verbose, + auth_resolver=auth_resolver, + logger=logger, + dep_ref=dep_ref, + ): + updates_existing_entry = update_existing_dependency_entry_if_needed( + current_deps, + already_in_deps=already_in_deps, + apm_yml_entries=_apm_yml_entries, + canonical=canonical, + dep_ref=dep_ref, + identity=identity, + dependency_reference_cls=DependencyReference, + logger=logger, + ) + valid_outcomes.append((canonical, already_in_deps)) + if logger: + logger.validation_pass(canonical, already_in_deps, updates_existing_entry) + + if not already_in_deps: + validated_packages.append(canonical) + existing_identities.add(identity) # prevent duplicates within batch + dependencies_changed = dependencies_changed or updates_existing_entry + if marketplace_provenance: + _marketplace_provenance[identity] = marketplace_provenance + else: + reason = _local_path_failure_reason(dep_ref) + if not reason: + # Round-4 panel fix (devx-ux): name the four-step probe + # chain explicitly when the validator exhausted it + # (virtual subdirectory + explicit ref). Generic "not + # accessible" hides the failure mode for the precise + # case where the most diagnostics are available. 
+ is_subdir_ref_chain = ( + dep_ref.is_virtual + and dep_ref.is_virtual_subdirectory() + and bool(dep_ref.reference) + ) + if is_subdir_ref_chain: + reason = ( + "all probes failed (marker-file, Contents API, " + "git ls-remote, shallow-fetch) -- verify the path " + "and ref exist and that your credentials have " + "read access" + ) + if not verbose: + reason += " (run with --verbose for the full probe log)" + else: + reason = "not accessible or doesn't exist" + if not verbose: + reason += " -- run with --verbose for auth details" + invalid_outcomes.append((package, reason)) + if logger: + logger.validation_fail(package, reason) + + return ( + valid_outcomes, + invalid_outcomes, + validated_packages, + _marketplace_provenance, + _apm_yml_entries, + dependencies_changed, + ) + + +def _merge_packages_into_yml( + validated_packages, + apm_yml_entries, + current_deps, + data, + dep_section, + apm_yml_path, + *, + dev=False, + logger=None, +): + """Append *validated_packages* to the dependency list and write apm.yml. + + Mutates *current_deps* in place and persists the updated manifest to + *apm_yml_path*. + """ + # RULE B: _rich_error is patched at apm_cli.commands.install.* in tests. 
+ import apm_cli.commands.install as _m + + dep_label = "devDependencies" if dev else "apm.yml" + for package in validated_packages: + current_deps.append(apm_yml_entries.get(package, package)) + if logger: + logger.verbose_detail(f"Added {package} to {dep_label}") + + # Update dependencies + data[dep_section]["apm"] = current_deps + + # Write back to apm.yml + try: + from apm_cli.utils.yaml_io import dump_yaml + + dump_yaml(data, apm_yml_path) + if logger: + logger.success( + f"Updated {APM_YML_FILENAME} with {len(validated_packages)} new package(s)" + ) + except Exception as e: + if logger: + logger.error(f"Failed to write {APM_YML_FILENAME}: {e}") + else: + _m._rich_error(f"Failed to write {APM_YML_FILENAME}: {e}") + sys.exit(1) + + +def _validate_and_add_packages_to_apm_yml( + packages, + dry_run=False, + dev=False, + logger=None, + manifest_path=None, + auth_resolver=None, + scope=None, + allow_insecure=False, + skill_subset=None, +): + """Validate packages exist and can be accessed, then add to apm.yml dependencies section. + + Implements normalize-on-write: any input form (HTTPS URL, SSH URL, FQDN, shorthand) + is canonicalized before storage. Default host (github.com) is stripped; + non-default hosts are preserved. Duplicates are detected by identity. + + Args: + packages: Package specifiers to validate and add. + dry_run: If True, only show what would be added. + dev: If True, write to devDependencies instead of dependencies. + logger: InstallLogger for structured output. + manifest_path: Explicit path to apm.yml (defaults to cwd/apm.yml). + auth_resolver: Shared auth resolver for caching credentials. + scope: InstallScope controlling project vs user deployment. + + Returns: + Tuple of (validated_packages list, _ValidationOutcome). + """ + from pathlib import Path + + # RULE B: _check_package_conflicts, _resolve_package_references, _merge_packages_into_yml, + # and _rich_error are all patched at apm_cli.commands.install.* in tests. 
+ import apm_cli.commands.install as _m + + apm_yml_path = manifest_path or Path(APM_YML_FILENAME) + + # Read current apm.yml + try: + from apm_cli.utils.yaml_io import load_yaml + + data = load_yaml(apm_yml_path) or {} + except Exception as e: + if logger: + logger.error(f"Failed to read {APM_YML_FILENAME}: {e}") + else: + _m._rich_error(f"Failed to read {APM_YML_FILENAME}: {e}") + sys.exit(1) + + # Ensure dependencies structure exists + dep_section = "devDependencies" if dev else "dependencies" + if dep_section not in data: + data[dep_section] = {} + if "apm" not in data[dep_section]: + data[dep_section]["apm"] = [] + + current_deps = data[dep_section]["apm"] or [] + + # Detect duplicates against existing deps + existing_identities = _m._check_package_conflicts(current_deps) + + # Validate and canonicalize all package references + ( + valid_outcomes, + invalid_outcomes, + validated_packages, + _marketplace_provenance, + _apm_yml_entries, + dependencies_changed, + ) = _m._resolve_package_references( + packages, + current_deps, + existing_identities, + auth_resolver=auth_resolver, + logger=logger, + scope=scope, + allow_insecure=allow_insecure, + skill_subset=skill_subset, + ) + + outcome = _m._ValidationOutcome( + valid=valid_outcomes, + invalid=invalid_outcomes, + marketplace_provenance=_marketplace_provenance or None, + ) + + # Let the logger emit a summary and decide whether to continue + if logger: + should_continue = logger.validation_summary(outcome) + if not should_continue: + return [], outcome + + if not validated_packages: + if dry_run: + if logger: + logger.progress("No new packages to add") + # If all packages already exist in apm.yml, that's OK - we'll reinstall them + persist_dependency_list_if_changed = _m.persist_dependency_list_if_changed + persist_dependency_list_if_changed( + dependencies_changed=dependencies_changed, + data=data, + dep_section=dep_section, + current_deps=current_deps, + apm_yml_path=apm_yml_path, + 
apm_yml_filename=APM_YML_FILENAME, + logger=logger, + rich_error=_m._rich_error, + sys_exit=sys.exit, + ) + return [], outcome + + if dry_run: + if logger: + logger.progress(f"Dry run: Would add {len(validated_packages)} package(s) to apm.yml") + for pkg in validated_packages: + logger.verbose_detail(f" + {pkg}") + return validated_packages, outcome + + # Persist validated packages to apm.yml + _m._merge_packages_into_yml( + validated_packages, + _apm_yml_entries, + current_deps, + data, + dep_section, + apm_yml_path, + dev=dev, + logger=logger, + ) + + return validated_packages, outcome From 59a96012a229aace189657d129cdb98ad82cacf5 Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 00:57:51 +0200 Subject: [PATCH 02/21] refactor(install): cut services.py complexity via bundle/options reuse (#1078) Bring install/services.py under Stage 2 thresholds and <800 lines (859->684) by reusing existing structure instead of moving code: - integrate_local_content: collapse six separate integrator params into the existing IntegratorBundle (15->10 args). - integrate_package_primitives: group optional knobs (skill_subset, scratch_root, policy) into a new frozen IntegrationOptions dataclass (14->12 args); extract per-kind/skill logging + cowork-warn helpers into new sibling install/services_integrate.py to clear C901/branch/stmt. - integrate_local_bundle: extract slug validation helper to clear C901. Monkeypatch seams preserved (both functions stay defined in services.py); public names and backward-compat aliases re-exported. Behaviour unchanged. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/install/drift.py | 9 +- src/apm_cli/install/phases/integrate.py | 16 +- src/apm_cli/install/services.py | 289 ++++-------------- src/apm_cli/install/services_integrate.py | 268 ++++++++++++++++ src/apm_cli/install/template.py | 23 +- .../test_install_services_orchestration.py | 35 ++- .../test_install_services_phase3w5.py | 35 ++- tests/unit/install/test_services_branches.py | 19 +- tests/unit/install/test_services_phase3.py | 19 +- tests/unit/test_local_content_install.py | 17 +- 10 files changed, 410 insertions(+), 320 deletions(-) create mode 100644 src/apm_cli/install/services_integrate.py diff --git a/src/apm_cli/install/drift.py b/src/apm_cli/install/drift.py index 5bd7ca309..2f375cf1b 100644 --- a/src/apm_cli/install/drift.py +++ b/src/apm_cli/install/drift.py @@ -407,7 +407,11 @@ def run_replay(config: ReplayConfig, logger: CheckLogger) -> Path: Surfaced verbatim when a locked dep is not in the cache. 
""" from apm_cli.deps.lockfile import _SELF_KEY, LockFile - from apm_cli.install.services import IntegratorBundle, integrate_package_primitives + from apm_cli.install.services import ( + IntegrationOptions, + IntegratorBundle, + integrate_package_primitives, + ) from apm_cli.integration.targets import resolve_targets from apm_cli.utils.diagnostics import DiagnosticCollector @@ -501,9 +505,8 @@ def run_replay(config: ReplayConfig, logger: CheckLogger) -> Path: package_name=dep_key, logger=None, scope=None, - skill_subset=None, ctx=None, - scratch_root=scratch_root, + options=IntegrationOptions(scratch_root=scratch_root), ) replayed_count += 1 finally: diff --git a/src/apm_cli/install/phases/integrate.py b/src/apm_cli/install/phases/integrate.py index f29ff9ef1..e41507378 100644 --- a/src/apm_cli/install/phases/integrate.py +++ b/src/apm_cli/install/phases/integrate.py @@ -22,7 +22,7 @@ from apm_cli.install.phases._redownload import _should_skip_redownload from apm_cli.install.phases._skip_logic import _compute_skip_download from apm_cli.install.phases.heal import run_heal_chain -from apm_cli.install.services import integrate_local_content +from apm_cli.install.services import IntegratorBundle, integrate_local_content from apm_cli.install.sources import make_dependency_source from apm_cli.install.template import run_integration_template @@ -335,12 +335,14 @@ def _integrate_root_project( _root_result = integrate_local_content( ctx.project_root, targets=ctx.targets, - prompt_integrator=ctx.integrators["prompt"], - agent_integrator=ctx.integrators["agent"], - skill_integrator=ctx.integrators["skill"], - instruction_integrator=ctx.integrators["instruction"], - command_integrator=ctx.integrators["command"], - hook_integrator=ctx.integrators["hook"], + integrators=IntegratorBundle( + prompt=ctx.integrators["prompt"], + agent=ctx.integrators["agent"], + skill=ctx.integrators["skill"], + instruction=ctx.integrators["instruction"], + command=ctx.integrators["command"], + 
hook=ctx.integrators["hook"], + ), force=ctx.force, managed_files=_local_managed, diagnostics=diagnostics, diff --git a/src/apm_cli/install/services.py b/src/apm_cli/install/services.py index 67c25e326..a02626863 100644 --- a/src/apm_cli/install/services.py +++ b/src/apm_cli/install/services.py @@ -19,7 +19,7 @@ from __future__ import annotations import builtins -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Any @@ -30,6 +30,12 @@ from ..integration.base_integrator import BaseIntegrator from ..utils.diagnostics import DiagnosticCollector +from .services_integrate import ( + _log_per_kind_results, + _log_skill_result, + _validate_bundle_slug, + _warn_cowork_nonsupported, +) # CRITICAL: Shadow Python builtins that share names with Click commands so # ``set()`` / ``list()`` / ``dict()`` resolve to the builtins, not Click @@ -45,7 +51,7 @@ class IntegratorBundle: """Groups the six primitive integrators passed to ``integrate_package_primitives``. Using a bundle reduces the public argument count of - ``integrate_package_primitives`` below the PLR0913 threshold (≤15) while + ``integrate_package_primitives`` below the PLR0913 threshold while keeping the integrator objects strongly typed and discoverable. """ @@ -57,6 +63,30 @@ class IntegratorBundle: hook: BaseIntegrator +@dataclass(frozen=True) +class IntegrationOptions: + """Optional configuration knobs for ``integrate_package_primitives``. + + Grouping these advanced/optional parameters into a single object reduces + the public argument count of ``integrate_package_primitives`` without + losing expressiveness at call sites. + + Attributes + ---------- + skill_subset: + When set, only skills whose names appear in this tuple are deployed. + scratch_root: + When set, the caller is replaying integration into an isolated + directory (drift-replay). Must be a parent of *project_root*. 
+ policy: + Enterprise security policy object forwarded to the skill integrator. + """ + + skill_subset: tuple | None = None + scratch_root: Path | None = None + policy: Any = field(default=None, compare=False) + + def _deployed_path_entry( target_path: Path, project_root: Path, @@ -127,10 +157,8 @@ def integrate_package_primitives( package_name: str = "", logger: InstallLogger | None = None, scope: InstallScope | None = None, - skill_subset: tuple | None = None, ctx: InstallContext | None = None, - scratch_root: Path | None = None, - policy: Any = None, + options: IntegrationOptions | None = None, ) -> dict: """Run the full integration pipeline for a single package. @@ -146,12 +174,20 @@ def integrate_package_primitives( (Amendment 6) is emitted once per install run for packages that contain non-skill primitives when the cowork target is active. + Advanced options (skill filtering, drift-replay scratch root, policy) + are grouped in the optional *options* parameter. + Returns a dict with integration counters and the list of deployed file paths. """ from apm_cli.integration.dispatch import get_dispatch_table from ..core.scope import InstallScope + _opts = options or IntegrationOptions() + skill_subset = _opts.skill_subset + scratch_root = _opts.scratch_root + policy = _opts.policy + _dispatch = get_dispatch_table() result = { "prompts": 0, @@ -175,84 +211,15 @@ def integrate_package_primitives( # the caller is replaying integration into an isolated directory. # We assert it exists and is NOT inside ``project_root`` to keep the # read-only contract of ``apm audit --check drift`` enforceable. - # The ``project_root`` passed in will already point at ``scratch_root`` - # (so all writes redirect via target.deploy_path), so this check is - # purely defense-in-depth against accidental misuse. 
# ------------------------------------------------------------------ if scratch_root is not None: from apm_cli.utils.path_security import ensure_path_within scratch_root = Path(scratch_root).resolve() - # ``project_root`` is the redirect target; it must equal scratch_root - # OR sit inside it. ensure_path_within(child, parent) raises if not. ensure_path_within(Path(project_root).resolve(), scratch_root) - # --- Amendment 6: cowork non-skill primitive warning (once per run) --- - _cowork_active = any(t.name == "copilot-cowork" for t in targets) - if _cowork_active and ctx is not None and not ctx.cowork_nonsupported_warned: - _apm_dir = Path(package_info.install_path) / ".apm" - _NON_SKILL_DIRS = { - "agents": "agents", - "prompts": "prompts", - "instructions": "instructions", - "hooks": "hooks", - # Commands live under ``.apm/prompts/`` and cannot be - # distinguished from general prompts at directory level - # without inspecting frontmatter. Omitted to avoid - # misleading duplicate warnings. - } - _found_types = [ - ptype - for ptype, subdir in _NON_SKILL_DIRS.items() - if (_apm_dir / subdir).is_dir() and any((_apm_dir / subdir).iterdir()) - ] - if _found_types: - _pkg_label = package_name or getattr(package_info, "name", "unknown") - _types_str = ", ".join(sorted(builtins.set(_found_types))) - _warn_msg = ( - f"copilot-cowork target only supports skills; " - f"non-skill primitives in {_pkg_label} " - f"({_types_str}) will not deploy to cowork" - ) - if logger: - logger.warning(_warn_msg, symbol="warning") - diagnostics.warn(_warn_msg) - ctx.cowork_nonsupported_warned = True - - def _log_integration(msg): - if logger: - logger.tree_item(msg) - - def _format_target_collapse(paths: list[str], verbose: bool) -> tuple[str, list[str]]: - """Apply the 1/2/3+ multi-target collapse rule. - - Returns a tuple ``(suffix, expansion_lines)``: - - * ``suffix`` -- the text appended after ``-> `` on the aggregate line. 
- * ``expansion_lines`` -- extra `` | -> <path>`` lines emitted - AFTER the aggregate line when ``verbose`` is True. Empty list when - collapsed. - - The rule: - 1 target -> ``<path>`` - 2 targets -> ``<path1>, <path2>`` - 3+ -> ``N targets`` (verbose forces full enumeration) - """ - deduped: list[str] = [] - seen: set = builtins.set() - for p in paths: - if p not in seen: - seen.add(p) - deduped.append(p) - if verbose and len(deduped) >= 2: - return "", [f" | -> {p}" for p in deduped] - if len(deduped) == 0: - return "", [] - if len(deduped) == 1: - return deduped[0], [] - if len(deduped) == 2: - return f"{deduped[0]}, {deduped[1]}", [] - return f"{len(deduped)} targets", [] + # Amendment 6: cowork non-skill primitive warning (once per run). + _warn_cowork_nonsupported(targets, ctx, package_info, package_name, logger, diagnostics) _verbose = bool(getattr(ctx, "verbose", False)) if ctx is not None else False @@ -310,21 +277,11 @@ def _format_target_collapse(paths: list[str], verbose: bool) -> tuple[str, list[ # Treat anything that is not a real int as 0 so we never # invent fake adopt counts. _adopted = _adopted_attr if isinstance(_adopted_attr, int) else 0 - # Show the per-kind line whenever ANY work happened -- either - # a fresh integrate or a silent adopt of pre-existing - # byte-identical files. Adopt-only runs (e.g. re-install - # after lockfile wipe) used to print nothing here, which made - # the install summary look like a no-op even though the - # lockfile WAS being repopulated. Surfacing adopt counts - # restores operator trust in CI. + # Show the per-kind line whenever ANY work happened. if _int_result.files_integrated <= 0 and _adopted <= 0: continue _agg_files += _int_result.files_integrated _agg_adopted += _adopted - # Only count fresh integrations against the package counter - # so totals like "3 prompts integrated" stay truthful; - # adopted files are surfaced separately in the per-kind - # line. 
result[_entry.counter_key] += _int_result.files_integrated _effective_root = _mapping.deploy_root or _target.root_dir _deploy_dir = ( @@ -333,9 +290,6 @@ def _format_target_collapse(paths: list[str], verbose: bool) -> tuple[str, list[ else f"{_effective_root}/" ) if _prim_name == "instructions" and _mapping.output_compare: - # Rule-dir formats (cursor/claude/windsurf) are the - # output_compare set; derive the label from the same flag so a - # new rule format needs no edit here. _label = "rule(s)" elif _prim_name == "instructions": _label = "instruction(s)" @@ -355,40 +309,7 @@ def _format_target_collapse(paths: list[str], verbose: bool) -> tuple[str, list[ "paths": _agg_paths, } - # Emit aggregated per-kind lines in dispatch order so output is stable. - for _prim_name in _dispatch: - if _prim_name not in _per_kind: - continue - _info = _per_kind[_prim_name] - _suffix, _expansion = _format_target_collapse(_info["paths"], _verbose) - # Build the verb + count phrase. When at least one file was - # freshly integrated we lead with "N X integrated"; pure-adopt - # runs (no fresh writes) lead with "N X adopted" so the line - # still appears and the count is truthful. - _files = _info["files"] - _adopted = _info["adopted"] - if _files > 0: - _verb_phrase = f"{_files} {_info['label']} integrated" - if _adopted > 0: - _verb_phrase = f"{_verb_phrase} ({_adopted} adopted)" - else: - _verb_phrase = f"{_adopted} {_info['label']} adopted" - if _expansion: - _log_integration(f" |-- {_verb_phrase}:") - for line in _expansion: - _log_integration(line) - else: - _log_integration(f" |-- {_verb_phrase} -> {_suffix}") - # Emit a one-line "next step" hint when copilot-app workflows - # were integrated: the row lands enabled=0 and the user has to - # flip the toggle in the Copilot App's Workflows tab before the - # schedule fires. This is the "failure mode is the product" - # surface for project-scope ride-along installs where a - # contributor may not have read the integration doc. 
- if any(p.startswith("copilot-app/") for p in _info["paths"]) and _info["files"] > 0: - _log_integration( - " |-- workflows arrive disabled; enable from the Copilot App's Workflows tab" - ) + _log_per_kind_results(_per_kind, _dispatch, _verbose, logger) skill_result = integrators.skill.integrate_package_skill( package_info, @@ -401,63 +322,15 @@ def _format_target_collapse(paths: list[str], verbose: bool) -> tuple[str, list[ scope=scope, policy=policy, ) - _skill_target_dirs: set = builtins.set() - for tp in skill_result.target_paths: - try: - rel = tp.relative_to(project_root) - if rel.parts: - _skill_target_dirs.add(rel.parts[0]) - except ValueError: - # Dynamic-root target (copilot-cowork) -- path is outside project tree. - _skill_target_dirs.add("copilot-cowork") - _skill_target_paths = [f"{d}/skills/" for d in sorted(_skill_target_dirs)] - if not _skill_target_paths: - _skill_target_paths = ["skills/"] - _skill_suffix, _skill_expansion = _format_target_collapse(_skill_target_paths, _verbose) - if skill_result.skill_created: - result["skills"] += 1 - if _skill_expansion: - _log_integration(" |-- Skill integrated:") - for line in _skill_expansion: - _log_integration(line) - else: - _log_integration(f" |-- Skill integrated -> {_skill_suffix}") - if skill_result.sub_skills_promoted > 0: - result["sub_skills"] += skill_result.sub_skills_promoted - if _skill_expansion: - _log_integration(f" |-- {skill_result.sub_skills_promoted} skill(s) integrated:") - for line in _skill_expansion: - _log_integration(line) - else: - _log_integration( - f" |-- {skill_result.sub_skills_promoted} skill(s) integrated -> {_skill_suffix}" - ) - if skill_result.bin_deployed > 0: - _log_integration( - f" |-- {skill_result.bin_deployed} executable(s) deployed to " - f"Claude Code's PATH -> {_skill_suffix} (invoked without confirmation)" - ) - _log_integration(" |-- run /reload-plugins or restart Claude Code to activate") - elif skill_result.bin_skipped_reason == "project_scope": - 
_log_integration( - " |-- plugin ships executables; re-run with -g (global) to deploy them to Claude Code" - ) - elif skill_result.bin_skipped_reason == "no_claude_target": - _log_integration( - " |-- plugin ships executables; no active Claude Code skills target to receive them" - ) - for tp in skill_result.target_paths: - deployed.append(_deployed_path_entry(tp, project_root, targets)) + _log_skill_result(skill_result, result, project_root, targets, _verbose, logger) - # A3: warm-cache visibility. If nothing was integrated for any kind AND - # no skill was created, emit one annotation so the user knows the dep - # was evaluated (the [+] header above already carries the SHA). + # A3: warm-cache visibility. _total_integrated = sum(_info["files"] for _info in _per_kind.values()) _total_integrated += int(skill_result.skill_created) _total_integrated += int(skill_result.sub_skills_promoted) _total_integrated += int(skill_result.bin_deployed) - if _total_integrated == 0: - _log_integration(" |-- (files unchanged)") + if _total_integrated == 0 and logger: + logger.tree_item(" |-- (files unchanged)") return result @@ -466,12 +339,7 @@ def integrate_local_content( project_root: Path, *, targets: Any, - prompt_integrator: Any, - agent_integrator: Any, - skill_integrator: Any, - instruction_integrator: Any, - command_integrator: Any, - hook_integrator: Any, + integrators: IntegratorBundle, force: bool, managed_files: Any, diagnostics: DiagnosticCollector, @@ -494,6 +362,7 @@ def integrate_local_content( project_root: Deploy root -- where ``.claude/``, ``.codex/``, etc. are written. Also used to compute relative paths for tracking deployed files. + integrators: Bundle of the six primitive integrators. source_root: Where to discover the synthetic local package's ``.apm/`` content. Defaults to ``project_root`` when not provided. 
When ``apm install --root`` is in play, @@ -524,14 +393,7 @@ def integrate_local_content( local_info, project_root, targets=targets, - integrators=IntegratorBundle( - prompt=prompt_integrator, - agent=agent_integrator, - skill=skill_integrator, - instruction=instruction_integrator, - command=command_integrator, - hook=hook_integrator, - ), + integrators=integrators, force=force, managed_files=managed_files, diagnostics=diagnostics, @@ -712,41 +574,11 @@ def integrate_local_bundle( # is a no-op for these clients. _first_seg = rel.split("/", 1)[0] if "/" in rel else "" if _first_seg == "instructions" and "instructions" not in (target.primitives or {}): - # Slug must be safe for filesystem path construction -- - # ``package_id`` originates from untrusted ``plugin.json``. - # Enforce a strict character whitelist documented in - # docs/src/content/docs/enterprise/security.md so - # forward slashes, null bytes, spaces, and other - # filesystem-significant characters are rejected before - # any path construction or resolution. + # Slug must be safe for filesystem path construction. + # CR1.5 (#1217 review): _validate_bundle_slug enforces the + # ASCII-only [A-Za-z0-9._-] whitelist and rejects traversal. _slug_str = str(slug) - # CR1.5 (#1217 review): use ASCII-only validation, not - # ``str.isalnum`` (which accepts Unicode letters/digits - # like accented or non-Latin chars and would slip past - # the documented [A-Za-z0-9._-] whitelist). - _ALLOWED = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._-") - _slug_ok = ( - bool(_slug_str) - and all(c in _ALLOWED for c in _slug_str) - and not _slug_str.startswith(".") - and not _slug_str.endswith(".") - and ".." 
not in _slug_str - ) - if not _slug_ok: - if logger is not None: - logger.warning( - f"Skipped instruction staging for unsafe slug {_slug_str!r}: " - "slug must match [A-Za-z0-9._-]+ with no leading/trailing dot, no '..'" - ) - skipped += 1 - continue - try: - validate_path_segments(_slug_str, context="bundle slug") - except PathTraversalError as exc: - if logger is not None: - logger.warning( - f"Skipped instruction staging for unsafe slug {_slug_str!r}: {exc}" - ) + if not _validate_bundle_slug(_slug_str, logger): skipped += 1 continue stage_root = project_root / "apm_modules" / slug / ".apm" / "instructions" @@ -757,14 +589,7 @@ def integrate_local_bundle( logger.warning(f"Skipped unsafe stage root for {slug!r}: {exc}") skipped += 1 continue - # PR #1217 review: preserve nested subdirs under - # ``instructions/`` so two files with the same basename - # (e.g. ``instructions/a/x.md`` and - # ``instructions/b/x.md``) do not collide at the staged - # location. ``rel`` already starts with - # ``instructions/`` so we strip that prefix before - # joining under the stage root (which itself ends in - # ``.apm/instructions``). + # Preserve nested subdirs under ``instructions/`` (PR #1217). _rel_under_instructions = rel.split("/", 1)[1] if "/" in rel else Path(rel).name dest = stage_root / _rel_under_instructions deploy_root = stage_root diff --git a/src/apm_cli/install/services_integrate.py b/src/apm_cli/install/services_integrate.py new file mode 100644 index 000000000..9bc6fe1ae --- /dev/null +++ b/src/apm_cli/install/services_integrate.py @@ -0,0 +1,268 @@ +"""Private helpers extracted from services.py to keep each function under threshold. + +All symbols here are module-private (single underscore prefix) and are only +called from ``apm_cli.install.services``. They are NOT part of the public +API and MUST NOT be imported from outside this package. 
+""" + +from __future__ import annotations + +import builtins +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from ..core.command_logger import InstallLogger + from ..utils.diagnostics import DiagnosticCollector + + +# Shadow builtins shadowed at the top of services.py for the same reason. +set = builtins.set +list = builtins.list +dict = builtins.dict + + +# --------------------------------------------------------------------------- +# _format_target_collapse +# --------------------------------------------------------------------------- + + +def _format_target_collapse(paths: list[str], verbose: bool) -> tuple[str, list[str]]: + """Apply the 1/2/3+ multi-target collapse rule. + + Returns a tuple ``(suffix, expansion_lines)``: + + * ``suffix`` -- the text appended after ``-> `` on the aggregate line. + * ``expansion_lines`` -- extra `` | -> <path>`` lines emitted + AFTER the aggregate line when ``verbose`` is True. Empty list when + collapsed. + + The rule: + 1 target -> ``<path>`` + 2 targets -> ``<path1>, <path2>`` + 3+ -> ``N targets`` (verbose forces full enumeration) + """ + deduped: list[str] = [] + seen: set = builtins.set() + for p in paths: + if p not in seen: + seen.add(p) + deduped.append(p) + if verbose and len(deduped) >= 2: + return "", [f" | -> {p}" for p in deduped] + if len(deduped) == 0: + return "", [] + if len(deduped) == 1: + return deduped[0], [] + if len(deduped) == 2: + return f"{deduped[0]}, {deduped[1]}", [] + return f"{len(deduped)} targets", [] + + +# --------------------------------------------------------------------------- +# _warn_cowork_nonsupported +# --------------------------------------------------------------------------- + + +def _warn_cowork_nonsupported( + targets: Any, + ctx: Any, + package_info: Any, + package_name: str, + logger: InstallLogger | None, + diagnostics: DiagnosticCollector, +) -> None: + """Emit the Amendment-6 cowork non-skill primitive warning (once per run). 
+ + Checks whether the copilot-cowork target is active and whether the package + contains any non-skill primitives. When both conditions hold the warning + is logged via *logger* and recorded in *diagnostics*, then the + ``ctx.cowork_nonsupported_warned`` flag is set to prevent duplicate lines. + """ + import builtins as _builtins + + _cowork_active = any(t.name == "copilot-cowork" for t in targets) + if not (_cowork_active and ctx is not None and not ctx.cowork_nonsupported_warned): + return + _apm_dir = Path(package_info.install_path) / ".apm" + _NON_SKILL_DIRS = { + "agents": "agents", + "prompts": "prompts", + "instructions": "instructions", + "hooks": "hooks", + } + _found_types = [ + ptype + for ptype, subdir in _NON_SKILL_DIRS.items() + if (_apm_dir / subdir).is_dir() and any((_apm_dir / subdir).iterdir()) + ] + if not _found_types: + return + _pkg_label = package_name or getattr(package_info, "name", "unknown") + _types_str = ", ".join(sorted(_builtins.set(_found_types))) + _warn_msg = ( + f"copilot-cowork target only supports skills; " + f"non-skill primitives in {_pkg_label} " + f"({_types_str}) will not deploy to cowork" + ) + if logger: + logger.warning(_warn_msg, symbol="warning") + diagnostics.warn(_warn_msg) + ctx.cowork_nonsupported_warned = True + + +# --------------------------------------------------------------------------- +# _log_per_kind_results +# --------------------------------------------------------------------------- + + +def _log_per_kind_results( + per_kind: dict[str, dict[str, Any]], + dispatch: dict, + verbose: bool, + logger: InstallLogger | None, +) -> None: + """Emit one aggregated log line per primitive kind in dispatch order. + + ``per_kind`` maps primitive name to a sub-dict with keys + ``files``, ``adopted``, ``label``, and ``paths``. Kinds absent from + ``per_kind`` are silently skipped. 
+ """ + for _prim_name in dispatch: + if _prim_name not in per_kind: + continue + _info = per_kind[_prim_name] + _suffix, _expansion = _format_target_collapse(_info["paths"], verbose) + _files = _info["files"] + _adopted = _info["adopted"] + if _files > 0: + _verb_phrase = f"{_files} {_info['label']} integrated" + if _adopted > 0: + _verb_phrase = f"{_verb_phrase} ({_adopted} adopted)" + else: + _verb_phrase = f"{_adopted} {_info['label']} adopted" + if logger is None: + continue + if _expansion: + logger.tree_item(f" |-- {_verb_phrase}:") + for line in _expansion: + logger.tree_item(line) + else: + logger.tree_item(f" |-- {_verb_phrase} -> {_suffix}") + if any(p.startswith("copilot-app/") for p in _info["paths"]) and _files > 0: + logger.tree_item( + " |-- workflows arrive disabled; enable from the Copilot App's Workflows tab" + ) + + +# --------------------------------------------------------------------------- +# _log_skill_result +# --------------------------------------------------------------------------- + + +def _log_skill_result( + skill_result: Any, + result: dict, + project_root: Path, + targets: Any, + verbose: bool, + logger: InstallLogger | None, +) -> None: + """Process skill integration result: update counters and emit log lines. + + Mutates *result* in-place (``skills``, ``sub_skills``, ``deployed_files`` + keys) and emits tree-item log lines via *logger*. 
+ """ + from apm_cli.install.services import _deployed_path_entry + + _skill_target_dirs: set = builtins.set() + for tp in skill_result.target_paths: + try: + rel = tp.relative_to(project_root) + if rel.parts: + _skill_target_dirs.add(rel.parts[0]) + except ValueError: + _skill_target_dirs.add("copilot-cowork") + _skill_target_paths = [f"{d}/skills/" for d in sorted(_skill_target_dirs)] + if not _skill_target_paths: + _skill_target_paths = ["skills/"] + _skill_suffix, _skill_expansion = _format_target_collapse(_skill_target_paths, verbose) + + if skill_result.skill_created: + result["skills"] += 1 + if logger: + if _skill_expansion: + logger.tree_item(" |-- Skill integrated:") + for line in _skill_expansion: + logger.tree_item(line) + else: + logger.tree_item(f" |-- Skill integrated -> {_skill_suffix}") + + if skill_result.sub_skills_promoted > 0: + result["sub_skills"] += skill_result.sub_skills_promoted + if logger: + if _skill_expansion: + logger.tree_item(f" |-- {skill_result.sub_skills_promoted} skill(s) integrated:") + for line in _skill_expansion: + logger.tree_item(line) + else: + logger.tree_item( + f" |-- {skill_result.sub_skills_promoted} skill(s) integrated" + f" -> {_skill_suffix}" + ) + + if skill_result.bin_deployed > 0 and logger: + logger.tree_item( + f" |-- {skill_result.bin_deployed} executable(s) deployed to " + f"Claude Code's PATH -> {_skill_suffix} (invoked without confirmation)" + ) + logger.tree_item(" |-- run /reload-plugins or restart Claude Code to activate") + elif skill_result.bin_skipped_reason == "project_scope" and logger: + logger.tree_item( + " |-- plugin ships executables; re-run with -g (global) to deploy them to Claude Code" + ) + elif skill_result.bin_skipped_reason == "no_claude_target" and logger: + logger.tree_item( + " |-- plugin ships executables; no active Claude Code skills target to receive them" + ) + + for tp in skill_result.target_paths: + result["deployed_files"].append(_deployed_path_entry(tp, project_root, 
targets)) + + +# --------------------------------------------------------------------------- +# _validate_bundle_slug +# --------------------------------------------------------------------------- + + +def _validate_bundle_slug(slug_str: str, logger: InstallLogger | None) -> bool: + """Return True if *slug_str* passes the bundle-slug whitelist check. + + The allowed character set is ``[A-Za-z0-9._-]+`` with no leading or + trailing dot and no ``..`` sequence. Invalid slugs are logged as a + warning and cause the caller to skip the instruction-staging step. + """ + from apm_cli.utils.path_security import PathTraversalError, validate_path_segments + + _ALLOWED = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._-") + _slug_ok = ( + bool(slug_str) + and all(c in _ALLOWED for c in slug_str) + and not slug_str.startswith(".") + and not slug_str.endswith(".") + and ".." not in slug_str + ) + if not _slug_ok: + if logger is not None: + logger.warning( + f"Skipped instruction staging for unsafe slug {slug_str!r}: " + "slug must match [A-Za-z0-9._-]+ with no leading/trailing dot, no '..'" + ) + return False + try: + validate_path_segments(slug_str, context="bundle slug") + except PathTraversalError as exc: + if logger is not None: + logger.warning(f"Skipped instruction staging for unsafe slug {slug_str!r}: {exc}") + return False + return True diff --git a/src/apm_cli/install/template.py b/src/apm_cli/install/template.py index 82d189399..d685b126f 100644 --- a/src/apm_cli/install/template.py +++ b/src/apm_cli/install/template.py @@ -14,7 +14,11 @@ from __future__ import annotations from apm_cli.install.helpers.security_scan import _pre_deploy_security_scan -from apm_cli.install.services import IntegratorBundle, integrate_package_primitives +from apm_cli.install.services import ( + IntegrationOptions, + IntegratorBundle, + integrate_package_primitives, +) from apm_cli.install.sources import DependencySource, Materialization @@ -89,17 +93,14 @@ def 
_integrate_materialization( package_name=dep_key, logger=logger, scope=ctx.scope, - # Per-package effective subset: CLI --skill overrides per-entry - # apm.yml skills:. When CLI is absent (bare reinstall), fall back - # to the dep_ref's persisted skill_subset. - # When CLI explicitly provided (even --skill '*'), use ctx value - # (which is None for '*' = install all). - skill_subset=( - ctx.skill_subset - if ctx.skill_subset_from_cli - else (tuple(dep_ref.skill_subset) if dep_ref.skill_subset else None) - ), ctx=ctx, + options=IntegrationOptions( + skill_subset=( + ctx.skill_subset + if ctx.skill_subset_from_cli + else (tuple(dep_ref.skill_subset) if dep_ref.skill_subset else None) + ), + ), ) mutation_keys = ( "prompts", diff --git a/tests/integration/test_install_services_orchestration.py b/tests/integration/test_install_services_orchestration.py index 80fa9a0aa..0ecc234bb 100644 --- a/tests/integration/test_install_services_orchestration.py +++ b/tests/integration/test_install_services_orchestration.py @@ -7,7 +7,7 @@ import pytest from apm_cli.install import services -from apm_cli.install.services import IntegratorBundle +from apm_cli.install.services import IntegrationOptions, IntegratorBundle from apm_cli.integration.base_integrator import IntegrationResult @@ -171,9 +171,8 @@ def invoke_integrate( package_name=package_name, logger=logger, scope=None, - skill_subset=skill_subset, ctx=ctx, - scratch_root=scratch_root, + options=IntegrationOptions(skill_subset=skill_subset, scratch_root=scratch_root), integrators=IntegratorBundle( prompt=integrators["prompt_integrator"], agent=integrators["agent_integrator"], @@ -937,12 +936,14 @@ def test_calls_integrate_package_primitives_via_module_lookup(self, tmp_path: Pa result = services.integrate_local_content( tmp_path, targets=[], - prompt_integrator=MagicMock(), - agent_integrator=MagicMock(), - skill_integrator=MagicMock(), - instruction_integrator=MagicMock(), - command_integrator=MagicMock(), - 
hook_integrator=MagicMock(), + integrators=IntegratorBundle( + prompt=MagicMock(), + agent=MagicMock(), + skill=MagicMock(), + instruction=MagicMock(), + command=MagicMock(), + hook=MagicMock(), + ), force=False, managed_files=set(), diagnostics=MagicMock(), @@ -959,14 +960,16 @@ def test_missing_apm_dir_means_no_local_integration(self, tmp_path: Path) -> Non result = services.integrate_local_content( tmp_path, targets=[make_target(primitives={})], - prompt_integrator=MagicMock(), - agent_integrator=MagicMock(), - skill_integrator=MagicMock( - integrate_package_skill=MagicMock(return_value=make_skill_result()) + integrators=IntegratorBundle( + prompt=MagicMock(), + agent=MagicMock(), + skill=MagicMock( + integrate_package_skill=MagicMock(return_value=make_skill_result()) + ), + instruction=MagicMock(), + command=MagicMock(), + hook=MagicMock(), ), - instruction_integrator=MagicMock(), - command_integrator=MagicMock(), - hook_integrator=MagicMock(), force=False, managed_files=set(), diagnostics=MagicMock(), diff --git a/tests/integration/test_install_services_phase3w5.py b/tests/integration/test_install_services_phase3w5.py index 80fa9a0aa..0ecc234bb 100644 --- a/tests/integration/test_install_services_phase3w5.py +++ b/tests/integration/test_install_services_phase3w5.py @@ -7,7 +7,7 @@ import pytest from apm_cli.install import services -from apm_cli.install.services import IntegratorBundle +from apm_cli.install.services import IntegrationOptions, IntegratorBundle from apm_cli.integration.base_integrator import IntegrationResult @@ -171,9 +171,8 @@ def invoke_integrate( package_name=package_name, logger=logger, scope=None, - skill_subset=skill_subset, ctx=ctx, - scratch_root=scratch_root, + options=IntegrationOptions(skill_subset=skill_subset, scratch_root=scratch_root), integrators=IntegratorBundle( prompt=integrators["prompt_integrator"], agent=integrators["agent_integrator"], @@ -937,12 +936,14 @@ def 
test_calls_integrate_package_primitives_via_module_lookup(self, tmp_path: Pa result = services.integrate_local_content( tmp_path, targets=[], - prompt_integrator=MagicMock(), - agent_integrator=MagicMock(), - skill_integrator=MagicMock(), - instruction_integrator=MagicMock(), - command_integrator=MagicMock(), - hook_integrator=MagicMock(), + integrators=IntegratorBundle( + prompt=MagicMock(), + agent=MagicMock(), + skill=MagicMock(), + instruction=MagicMock(), + command=MagicMock(), + hook=MagicMock(), + ), force=False, managed_files=set(), diagnostics=MagicMock(), @@ -959,14 +960,16 @@ def test_missing_apm_dir_means_no_local_integration(self, tmp_path: Path) -> Non result = services.integrate_local_content( tmp_path, targets=[make_target(primitives={})], - prompt_integrator=MagicMock(), - agent_integrator=MagicMock(), - skill_integrator=MagicMock( - integrate_package_skill=MagicMock(return_value=make_skill_result()) + integrators=IntegratorBundle( + prompt=MagicMock(), + agent=MagicMock(), + skill=MagicMock( + integrate_package_skill=MagicMock(return_value=make_skill_result()) + ), + instruction=MagicMock(), + command=MagicMock(), + hook=MagicMock(), ), - instruction_integrator=MagicMock(), - command_integrator=MagicMock(), - hook_integrator=MagicMock(), force=False, managed_files=set(), diagnostics=MagicMock(), diff --git a/tests/unit/install/test_services_branches.py b/tests/unit/install/test_services_branches.py index 40005cc2d..13d34ef04 100644 --- a/tests/unit/install/test_services_branches.py +++ b/tests/unit/install/test_services_branches.py @@ -23,6 +23,7 @@ import pytest from apm_cli.install.services import ( + IntegrationOptions, IntegratorBundle, _deployed_path_entry, _integrate_local_content, @@ -214,7 +215,7 @@ def test_scratch_root_inside_itself_is_valid(self, tmp_path: Path) -> None: integrators=_to_bundle(integrators), force=False, managed_files=None, - scratch_root=scratch, + options=IntegrationOptions(scratch_root=scratch), ) assert 
isinstance(result, dict) @@ -245,7 +246,7 @@ def test_scratch_root_outside_raises(self, tmp_path: Path) -> None: integrators=_to_bundle(integrators), force=False, managed_files=None, - scratch_root=scratch, + options=IntegrationOptions(scratch_root=scratch), ) @@ -459,12 +460,7 @@ def test_delegates_to_integrate_package_primitives(self, tmp_path: Path) -> None tmp_path, targets=[KNOWN_TARGETS["copilot"]], diagnostics=MagicMock(), - prompt_integrator=integrators["prompt_integrator"], - agent_integrator=integrators["agent_integrator"], - skill_integrator=integrators["skill_integrator"], - instruction_integrator=integrators["instruction_integrator"], - command_integrator=integrators["command_integrator"], - hook_integrator=integrators["hook_integrator"], + integrators=_to_bundle(integrators), force=False, managed_files=None, ) @@ -492,12 +488,7 @@ def _capture(pkg_info, *args, **kwargs): tmp_path, targets=[KNOWN_TARGETS["copilot"]], diagnostics=MagicMock(), - prompt_integrator=integrators["prompt_integrator"], - agent_integrator=integrators["agent_integrator"], - skill_integrator=integrators["skill_integrator"], - instruction_integrator=integrators["instruction_integrator"], - command_integrator=integrators["command_integrator"], - hook_integrator=integrators["hook_integrator"], + integrators=_to_bundle(integrators), force=False, managed_files=None, ) diff --git a/tests/unit/install/test_services_phase3.py b/tests/unit/install/test_services_phase3.py index 735d2b68f..c08639693 100644 --- a/tests/unit/install/test_services_phase3.py +++ b/tests/unit/install/test_services_phase3.py @@ -23,6 +23,7 @@ import pytest from apm_cli.install.services import ( + IntegrationOptions, IntegratorBundle, _deployed_path_entry, _integrate_local_content, @@ -214,7 +215,7 @@ def test_scratch_root_inside_itself_is_valid(self, tmp_path: Path) -> None: integrators=_to_bundle(integrators), force=False, managed_files=None, - scratch_root=scratch, + 
options=IntegrationOptions(scratch_root=scratch), ) assert isinstance(result, dict) @@ -245,7 +246,7 @@ def test_scratch_root_outside_raises(self, tmp_path: Path) -> None: integrators=_to_bundle(integrators), force=False, managed_files=None, - scratch_root=scratch, + options=IntegrationOptions(scratch_root=scratch), ) @@ -459,12 +460,7 @@ def test_delegates_to_integrate_package_primitives(self, tmp_path: Path) -> None tmp_path, targets=[KNOWN_TARGETS["copilot"]], diagnostics=MagicMock(), - prompt_integrator=integrators["prompt_integrator"], - agent_integrator=integrators["agent_integrator"], - skill_integrator=integrators["skill_integrator"], - instruction_integrator=integrators["instruction_integrator"], - command_integrator=integrators["command_integrator"], - hook_integrator=integrators["hook_integrator"], + integrators=_to_bundle(integrators), force=False, managed_files=None, ) @@ -492,12 +488,7 @@ def _capture(pkg_info, *args, **kwargs): tmp_path, targets=[KNOWN_TARGETS["copilot"]], diagnostics=MagicMock(), - prompt_integrator=integrators["prompt_integrator"], - agent_integrator=integrators["agent_integrator"], - skill_integrator=integrators["skill_integrator"], - instruction_integrator=integrators["instruction_integrator"], - command_integrator=integrators["command_integrator"], - hook_integrator=integrators["hook_integrator"], + integrators=_to_bundle(integrators), force=False, managed_files=None, ) diff --git a/tests/unit/test_local_content_install.py b/tests/unit/test_local_content_install.py index 05972952e..08b490f7b 100644 --- a/tests/unit/test_local_content_install.py +++ b/tests/unit/test_local_content_install.py @@ -12,6 +12,7 @@ from apm_cli.commands.install import _has_local_apm_content, _integrate_local_content from apm_cli.deps.lockfile import LockFile +from apm_cli.install.services import IntegratorBundle # --------------------------------------------------------------------------- # Helpers @@ -19,15 +20,17 @@ def _make_integrators(): - 
"""Return a dict of MagicMock integrators for _integrate_local_content.""" + """Return a dict of kwargs for _integrate_local_content using IntegratorBundle.""" return { "targets": [MagicMock()], - "prompt_integrator": MagicMock(), - "agent_integrator": MagicMock(), - "skill_integrator": MagicMock(), - "instruction_integrator": MagicMock(), - "command_integrator": MagicMock(), - "hook_integrator": MagicMock(), + "integrators": IntegratorBundle( + prompt=MagicMock(), + agent=MagicMock(), + skill=MagicMock(), + instruction=MagicMock(), + command=MagicMock(), + hook=MagicMock(), + ), "force": False, "managed_files": set(), "diagnostics": MagicMock(), From 1aafdc51ddfe2883c394006768e330245bef4257 Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 01:19:53 +0200 Subject: [PATCH 03/21] refactor(install): clear PLR0913 in mcp/lsp via shared MCPRequestSpec + LSPIntegrationContext (#1078) Introduce MCPRequestSpec frozen dataclass (mcp/spec.py) grouping the six shared MCP identity fields; run_mcp_install drops 15->10 args and validate_mcp_conflicts 14->9 args by consuming it. Add LSPIntegrationContext to lsp/integration.py so run_lsp_integration drops 15->8 args. Call sites in mcp_handler.py, commands/install.py, apm_packages.py updated; unit/integration call sites updated to match. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/commands/install.py | 15 +++-- src/apm_cli/install/apm_packages.py | 10 +-- src/apm_cli/install/lsp/__init__.py | 4 +- src/apm_cli/install/lsp/integration.py | 45 ++++++++----- src/apm_cli/install/mcp/command.py | 15 +++-- src/apm_cli/install/mcp/conflicts.py | 16 +++-- src/apm_cli/install/mcp/spec.py | 18 +++++ src/apm_cli/install/mcp_handler.py | 15 +++-- tests/integration/test_coverage_phase4.py | 78 +++++++++++++--------- tests/unit/install/test_lsp_integration.py | 23 +++---- tests/unit/install/test_mcp_conflicts.py | 21 +++--- 11 files changed, 160 insertions(+), 100 deletions(-) create mode 100644 src/apm_cli/install/mcp/spec.py diff --git a/src/apm_cli/commands/install.py b/src/apm_cli/commands/install.py index 54738fafe..5ed5a6917 100644 --- a/src/apm_cli/commands/install.py +++ b/src/apm_cli/commands/install.py @@ -102,6 +102,7 @@ validate_mcp_dry_run_entry as _validate_mcp_dry_run_entry, # noqa: F401 validate_registry_url as _validate_registry_url, ) +from ..install.mcp.spec import MCPRequestSpec as _MCPRequestSpec from ..utils.console import ( # noqa: F401 _rich_echo, _rich_error, @@ -587,20 +588,22 @@ def install( # noqa: PLR0913 validated_registry_url = _validate_registry_url(registry_url) _validate_mcp_conflicts( - mcp_name=mcp_name, + spec=_MCPRequestSpec( + mcp_name=mcp_name, + transport=transport, + url=url, + mcp_version=mcp_version, + command_argv=command_argv, + registry_url=validated_registry_url, + ), packages=packages, pre_dash_packages=pre_dash_packages, - transport=transport, - url=url, env=env_pairs, headers=header_pairs, - mcp_version=mcp_version, - command_argv=command_argv, global_=global_, only=only, update=update, any_transport_flag=use_ssh or use_https or allow_protocol_fallback, - registry_url=validated_registry_url, ) # Normalize --skill: '*' means all (same as absent). Reject with --mcp. 
diff --git a/src/apm_cli/install/apm_packages.py b/src/apm_cli/install/apm_packages.py index eb3683d29..0ebda7bbc 100644 --- a/src/apm_cli/install/apm_packages.py +++ b/src/apm_cli/install/apm_packages.py @@ -340,7 +340,7 @@ def _install_apm_packages(ctx, outcome): # ------------------------------------------------------------------------- # LSP integration (extracted to install/lsp/integration.py) # ------------------------------------------------------------------------- - from apm_cli.install.lsp import run_lsp_integration + from apm_cli.install.lsp import LSPIntegrationContext, run_lsp_integration lsp_count = run_lsp_integration( apm_package=apm_package, @@ -348,11 +348,13 @@ def _install_apm_packages(ctx, outcome): lock_path=_lock_path, existing_lock=_existing_lock, project_root=ctx.project_root, - user_scope=(ctx.scope is InstallScope.USER), - should_install=should_install_lsp, logger=logger, diagnostics=apm_diagnostics, - target_context=(mcp_apm_config, ctx.target, ctx.scope), + ctx=LSPIntegrationContext( + user_scope=(ctx.scope is InstallScope.USER), + should_install=should_install_lsp, + target_context=(mcp_apm_config, ctx.target, ctx.scope), + ), ) # Local .apm/ content integration is now handled inside the diff --git a/src/apm_cli/install/lsp/__init__.py b/src/apm_cli/install/lsp/__init__.py index 2aa884d2c..6e9bc2ab6 100644 --- a/src/apm_cli/install/lsp/__init__.py +++ b/src/apm_cli/install/lsp/__init__.py @@ -1,5 +1,5 @@ """LSP integration for APM install pipeline.""" -from .integration import run_lsp_integration +from .integration import LSPIntegrationContext, run_lsp_integration -__all__ = ["run_lsp_integration"] +__all__ = ["LSPIntegrationContext", "run_lsp_integration"] diff --git a/src/apm_cli/install/lsp/integration.py b/src/apm_cli/install/lsp/integration.py index f93050641..011adb9af 100644 --- a/src/apm_cli/install/lsp/integration.py +++ b/src/apm_cli/install/lsp/integration.py @@ -4,6 +4,7 @@ """ import builtins +from dataclasses import 
dataclass from pathlib import Path from typing import TYPE_CHECKING @@ -12,6 +13,20 @@ from apm_cli.models.apm_package import APMPackage +@dataclass(frozen=True) +class LSPIntegrationContext: + """Lower-frequency context fields for :func:`run_lsp_integration`.""" + + user_scope: bool + should_install: bool + runtime: str | None = None + exclude: str | None = None + apm_config: dict | None = None + explicit_target: str | list[str] | None = None + scope: object = None + target_context: tuple | None = None + + def run_lsp_integration( *, apm_package: "APMPackage", @@ -19,16 +34,9 @@ def run_lsp_integration( lock_path: Path, existing_lock: "LockFile | None", project_root: Path, - user_scope: bool, - should_install: bool, logger, diagnostics=None, - runtime: str | None = None, - exclude: str | None = None, - apm_config: dict | None = None, - explicit_target: str | list[str] | None = None, - scope=None, - target_context: tuple[dict | None, str | list[str] | None, object] | None = None, + ctx: LSPIntegrationContext, ) -> int: """Run LSP server integration after APM package installation. @@ -46,23 +54,26 @@ def run_lsp_integration( lock_path: Path to apm.lock.yaml. existing_lock: Previously loaded lockfile (for old LSP state). project_root: Project root directory. - user_scope: If True, write to user-scope runtime config paths. - should_install: Whether LSP integration should run (same gate as MCP). logger: Install logger instance. diagnostics: Optional DiagnosticCollector. - runtime: Optional runtime override. - exclude: Optional runtime exclusion. - apm_config: Parsed apm.yml target metadata for project-scope gating. - explicit_target: Explicit target selected by CLI or manifest. - scope: Optional InstallScope for user/project filtering. - target_context: Compact `(apm_config, explicit_target, scope)` tuple - used by the install command to keep entry-point glue small. 
+ ctx: Lower-frequency context fields (user_scope, should_install, + runtime, exclude, apm_config, explicit_target, scope, + target_context). Returns: Number of LSP servers configured. """ from apm_cli.integration.lsp_integrator import LSPIntegrator + user_scope = ctx.user_scope + should_install = ctx.should_install + runtime = ctx.runtime + exclude = ctx.exclude + apm_config = ctx.apm_config + explicit_target = ctx.explicit_target + scope = ctx.scope + target_context = ctx.target_context + lsp_deps = apm_package.get_lsp_dependencies() # Capture old LSP servers from lockfile diff --git a/src/apm_cli/install/mcp/command.py b/src/apm_cli/install/mcp/command.py index aaf8964a9..c04670a49 100644 --- a/src/apm_cli/install/mcp/command.py +++ b/src/apm_cli/install/mcp/command.py @@ -18,6 +18,7 @@ from .args import parse_env_pairs, parse_header_pairs from .entry import build_mcp_entry from .registry import registry_env_override +from .spec import MCPRequestSpec from .warnings import warn_shell_metachars, warn_ssrf_url from .writer import add_mcp_to_apm_yml @@ -37,13 +38,9 @@ def run_mcp_install( *, - mcp_name: str, - transport: str | None, - url: str | None, + spec: MCPRequestSpec, env_pairs: Sequence[str] | None, header_pairs: Sequence[str] | None, - mcp_version: str | None, - command_argv: Sequence[str] | None, dev: bool, force: bool, runtime: str | None, @@ -51,13 +48,19 @@ def run_mcp_install( logger, apm_dir: Path, scope: str | None, - registry_url: str | None = None, ) -> None: """Execute the --mcp install path. ``registry_url`` is the validated --registry value; the caller resolved precedence vs MCP_REGISTRY_URL. 
``manifest_path`` is derived from ``apm_dir`` (``apm_dir / 'apm.yml'``).""" from ...constants import APM_YML_FILENAME + mcp_name = spec.mcp_name + transport = spec.transport + url = spec.url + mcp_version = spec.mcp_version + command_argv = spec.command_argv + registry_url = spec.registry_url + manifest_path = apm_dir / APM_YML_FILENAME verbose = logger.verbose from ...models.dependency.mcp import MCPDependency diff --git a/src/apm_cli/install/mcp/conflicts.py b/src/apm_cli/install/mcp/conflicts.py index b67750058..61212a3f9 100644 --- a/src/apm_cli/install/mcp/conflicts.py +++ b/src/apm_cli/install/mcp/conflicts.py @@ -12,6 +12,8 @@ import click +from .spec import MCPRequestSpec + # Mapping for E10: which flags require --mcp. Keyed by attribute-style # name so we can read directly from the Click handler locals. MCP_REQUIRED_FLAGS: tuple[tuple[str, str], ...] = ( @@ -25,26 +27,28 @@ def validate_mcp_conflicts( *, - mcp_name: str | None, + spec: MCPRequestSpec, packages: Sequence[str], pre_dash_packages: Sequence[str], - transport: str | None, - url: str | None, env: Mapping[str, str], headers: Mapping[str, str], - mcp_version: str | None, - command_argv: Sequence[str] | None, global_: bool, only: str | None, update: bool, any_transport_flag: bool, - registry_url: str | None = None, ) -> None: """Apply conflict matrix E1-E15. Raises ``click.UsageError`` on hit. ``any_transport_flag`` should be ``use_ssh or use_https or allow_protocol_fallback`` (pre-evaluated by the caller). """ + mcp_name = spec.mcp_name + transport = spec.transport + url = spec.url + mcp_version = spec.mcp_version + command_argv = spec.command_argv + registry_url = spec.registry_url + # E10: flags require --mcp -- run first so users get the right hint. 
if mcp_name is None: flag_values = { diff --git a/src/apm_cli/install/mcp/spec.py b/src/apm_cli/install/mcp/spec.py new file mode 100644 index 000000000..90a85fce1 --- /dev/null +++ b/src/apm_cli/install/mcp/spec.py @@ -0,0 +1,18 @@ +"""Shared MCP request-identity dataclass.""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass + + +@dataclass(frozen=True) +class MCPRequestSpec: + """Shared MCP request-identity fields common to install + conflict validation.""" + + mcp_name: str | None + transport: str | None + url: str | None + mcp_version: str | None + command_argv: Sequence[str] | None + registry_url: str | None = None diff --git a/src/apm_cli/install/mcp_handler.py b/src/apm_cli/install/mcp_handler.py index 0dc3d88a5..943a29de4 100644 --- a/src/apm_cli/install/mcp_handler.py +++ b/src/apm_cli/install/mcp_handler.py @@ -43,6 +43,7 @@ def _handle_mcp_install( from apm_cli.install.mcp.registry import ( validate_mcp_dry_run_entry as _validate_mcp_dry_run_entry, ) + from apm_cli.install.mcp.spec import MCPRequestSpec transport = mcp_conn.transport url = mcp_conn.url @@ -116,13 +117,16 @@ def _handle_mcp_install( logger.dry_run_notice(f"would add MCP server '{mcp_name}' to {mcp_manifest_path}") return _m._run_mcp_install( - mcp_name=mcp_name, - transport=transport, - url=url, + spec=MCPRequestSpec( + mcp_name=mcp_name, + transport=transport, + url=url, + mcp_version=mcp_version, + command_argv=command_argv, + registry_url=validated_registry_url, + ), env_pairs=env_pairs, header_pairs=header_pairs, - mcp_version=mcp_version, - command_argv=command_argv, dev=dev, force=force, runtime=runtime, @@ -130,5 +134,4 @@ def _handle_mcp_install( logger=logger, apm_dir=mcp_apm_dir, scope=mcp_scope, - registry_url=validated_registry_url, ) diff --git a/tests/integration/test_coverage_phase4.py b/tests/integration/test_coverage_phase4.py index 237845fa2..91c0f5238 100644 --- a/tests/integration/test_coverage_phase4.py 
+++ b/tests/integration/test_coverage_phase4.py @@ -614,6 +614,7 @@ def _make_logger(self) -> MagicMock: def test_skipped_status_returns_early(self, tmp_path: Path) -> None: from apm_cli.install.mcp.command import run_mcp_install + from apm_cli.install.mcp.spec import MCPRequestSpec apm_yml = tmp_path / "apm.yml" apm_yml.write_text(_APM_YML_MINIMAL, encoding="utf-8") @@ -632,13 +633,15 @@ def test_skipped_status_returns_early(self, tmp_path: Path) -> None: patch("apm_cli.install.mcp.command.warn_shell_metachars"), ): run_mcp_install( - mcp_name="my-server", - transport=None, - url=None, + spec=MCPRequestSpec( + mcp_name="my-server", + transport=None, + url=None, + mcp_version=None, + command_argv=None, + ), env_pairs=None, header_pairs=None, - mcp_version=None, - command_argv=None, dev=False, force=False, runtime=None, @@ -653,6 +656,7 @@ def test_skipped_status_returns_early(self, tmp_path: Path) -> None: def test_added_string_entry_no_deps_available(self, tmp_path: Path) -> None: from apm_cli.install.mcp.command import run_mcp_install + from apm_cli.install.mcp.spec import MCPRequestSpec apm_yml = tmp_path / "apm.yml" apm_yml.write_text(_APM_YML_MINIMAL, encoding="utf-8") @@ -672,13 +676,15 @@ def test_added_string_entry_no_deps_available(self, tmp_path: Path) -> None: patch("apm_cli.install.mcp.command.APM_DEPS_AVAILABLE", False), ): run_mcp_install( - mcp_name="registry-server", - transport=None, - url=None, + spec=MCPRequestSpec( + mcp_name="registry-server", + transport=None, + url=None, + mcp_version=None, + command_argv=None, + ), env_pairs=None, header_pairs=None, - mcp_version=None, - command_argv=None, dev=False, force=False, runtime=None, @@ -693,6 +699,7 @@ def test_added_string_entry_no_deps_available(self, tmp_path: Path) -> None: def test_replaced_dict_entry_no_deps_available(self, tmp_path: Path) -> None: from apm_cli.install.mcp.command import run_mcp_install + from apm_cli.install.mcp.spec import MCPRequestSpec apm_yml = tmp_path / "apm.yml" 
apm_yml.write_text(_APM_YML_MINIMAL, encoding="utf-8") @@ -717,13 +724,15 @@ def test_replaced_dict_entry_no_deps_available(self, tmp_path: Path) -> None: patch("apm_cli.install.mcp.command.APM_DEPS_AVAILABLE", False), ): run_mcp_install( - mcp_name="my-srv", - transport="http", - url="http://localhost:8000", + spec=MCPRequestSpec( + mcp_name="my-srv", + transport="http", + url="http://localhost:8000", + mcp_version=None, + command_argv=None, + ), env_pairs=None, header_pairs=None, - mcp_version=None, - command_argv=None, dev=False, force=True, runtime=None, @@ -740,6 +749,7 @@ def test_build_entry_value_error_becomes_usage_error(self, tmp_path: Path) -> No import click from apm_cli.install.mcp.command import run_mcp_install + from apm_cli.install.mcp.spec import MCPRequestSpec apm_yml = tmp_path / "apm.yml" apm_yml.write_text(_APM_YML_MINIMAL, encoding="utf-8") @@ -755,13 +765,15 @@ def test_build_entry_value_error_becomes_usage_error(self, tmp_path: Path) -> No ): with pytest.raises(click.UsageError, match=r"bad entry"): run_mcp_install( - mcp_name="bad-srv", - transport=None, - url=None, + spec=MCPRequestSpec( + mcp_name="bad-srv", + transport=None, + url=None, + mcp_version=None, + command_argv=None, + ), env_pairs=None, header_pairs=None, - mcp_version=None, - command_argv=None, dev=False, force=False, runtime=None, @@ -775,6 +787,7 @@ def test_mcp_integrator_failure_raises_click_exception(self, tmp_path: Path) -> import click from apm_cli.install.mcp.command import run_mcp_install + from apm_cli.install.mcp.spec import MCPRequestSpec apm_yml = tmp_path / "apm.yml" apm_yml.write_text(_APM_YML_MINIMAL, encoding="utf-8") @@ -810,13 +823,15 @@ def test_mcp_integrator_failure_raises_click_exception(self, tmp_path: Path) -> mock_ctx.return_value.__exit__ = MagicMock(return_value=False) with pytest.raises(click.ClickException, match=r"MCP integration failed"): run_mcp_install( - mcp_name="fail-srv", - transport="http", - url="http://x", + spec=MCPRequestSpec( + 
mcp_name="fail-srv", + transport="http", + url="http://x", + mcp_version=None, + command_argv=None, + ), env_pairs=None, header_pairs=None, - mcp_version=None, - command_argv=None, dev=False, force=False, runtime=None, @@ -828,6 +843,7 @@ def test_mcp_integrator_failure_raises_click_exception(self, tmp_path: Path) -> def test_mcp_integrator_success_updates_lockfile(self, tmp_path: Path) -> None: from apm_cli.install.mcp.command import run_mcp_install + from apm_cli.install.mcp.spec import MCPRequestSpec apm_yml = tmp_path / "apm.yml" apm_yml.write_text(_APM_YML_MINIMAL, encoding="utf-8") @@ -866,13 +882,15 @@ def test_mcp_integrator_success_updates_lockfile(self, tmp_path: Path) -> None: mock_ctx.return_value.__enter__ = MagicMock(return_value=None) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) run_mcp_install( - mcp_name="ok-srv", - transport="http", - url="http://ok", + spec=MCPRequestSpec( + mcp_name="ok-srv", + transport="http", + url="http://ok", + mcp_version=None, + command_argv=None, + ), env_pairs=None, header_pairs=None, - mcp_version=None, - command_argv=None, dev=False, force=False, runtime=None, diff --git a/tests/unit/install/test_lsp_integration.py b/tests/unit/install/test_lsp_integration.py index d14931156..c2327871c 100644 --- a/tests/unit/install/test_lsp_integration.py +++ b/tests/unit/install/test_lsp_integration.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch -from apm_cli.install.lsp.integration import run_lsp_integration +from apm_cli.install.lsp.integration import LSPIntegrationContext, run_lsp_integration from apm_cli.models.dependency.lsp import LSPDependency # --------------------------------------------------------------------------- @@ -61,9 +61,8 @@ def test_no_lsp_deps_no_old_servers(self, mock_integrator, tmp_path): lock_path=tmp_path / "apm.lock.yaml", existing_lock=None, project_root=tmp_path, - user_scope=False, - should_install=True, logger=_mock_logger(), + 
ctx=LSPIntegrationContext(user_scope=False, should_install=True), ) assert count == 0 @@ -90,9 +89,8 @@ def test_installs_direct_deps(self, mock_integrator, tmp_path): lock_path=tmp_path / "apm.lock.yaml", existing_lock=None, project_root=tmp_path, - user_scope=False, - should_install=True, logger=_mock_logger(), + ctx=LSPIntegrationContext(user_scope=False, should_install=True), ) assert count == 1 @@ -118,9 +116,8 @@ def test_resolves_targets_for_install(self, mock_integrator, tmp_path): lock_path=tmp_path / "apm.lock.yaml", existing_lock=None, project_root=tmp_path, - user_scope=False, - should_install=True, logger=logger, + ctx=LSPIntegrationContext(user_scope=False, should_install=True), ) assert count == 1 @@ -158,9 +155,8 @@ def test_deduplicates_transitive(self, mock_integrator, tmp_path): lock_path=tmp_path / "apm.lock.yaml", existing_lock=None, project_root=tmp_path, - user_scope=False, - should_install=True, logger=_mock_logger(), + ctx=LSPIntegrationContext(user_scope=False, should_install=True), ) assert count == 2 @@ -195,9 +191,8 @@ def test_removes_stale_servers(self, mock_integrator, tmp_path): lock_path=tmp_path / "apm.lock.yaml", existing_lock=old_lock, project_root=tmp_path, - user_scope=False, - should_install=True, logger=_mock_logger(), + ctx=LSPIntegrationContext(user_scope=False, should_install=True), ) mock_integrator.remove_stale.assert_called_once() @@ -218,9 +213,8 @@ def test_removes_all_old_when_no_deps_remain(self, mock_integrator, tmp_path): lock_path=tmp_path / "apm.lock.yaml", existing_lock=old_lock, project_root=tmp_path, - user_scope=False, - should_install=True, logger=_mock_logger(), + ctx=LSPIntegrationContext(user_scope=False, should_install=True), ) mock_integrator.remove_stale.assert_called_once() @@ -251,9 +245,8 @@ def test_restores_old_lockfile_when_not_installing(self, mock_integrator, tmp_pa lock_path=tmp_path / "apm.lock.yaml", existing_lock=old_lock, project_root=tmp_path, - user_scope=False, - should_install=False, 
logger=_mock_logger(), + ctx=LSPIntegrationContext(user_scope=False, should_install=False), ) mock_integrator.update_lockfile.assert_called_once() diff --git a/tests/unit/install/test_mcp_conflicts.py b/tests/unit/install/test_mcp_conflicts.py index 1fe785a8f..4531d78ac 100644 --- a/tests/unit/install/test_mcp_conflicts.py +++ b/tests/unit/install/test_mcp_conflicts.py @@ -13,6 +13,7 @@ MCP_REQUIRED_FLAGS, validate_mcp_conflicts, ) +from apm_cli.install.mcp.spec import MCPRequestSpec # --------------------------------------------------------------------------- # Helper @@ -21,24 +22,28 @@ def _call(**overrides) -> None: """Call validate_mcp_conflicts with sensible defaults, allowing overrides.""" - defaults: dict = dict( + spec_keys = {"mcp_name", "transport", "url", "mcp_version", "command_argv", "registry_url"} + spec_defaults: dict = dict( mcp_name="my-server", - packages=[], - pre_dash_packages=[], transport=None, url=None, - env={}, - headers={}, mcp_version=None, command_argv=None, + registry_url=None, + ) + other_defaults: dict = dict( + packages=[], + pre_dash_packages=[], + env={}, + headers={}, global_=False, only=None, update=False, any_transport_flag=False, - registry_url=None, ) - defaults.update(overrides) - validate_mcp_conflicts(**defaults) + spec_kwargs = {k: overrides.pop(k, spec_defaults[k]) for k in spec_keys} + other_defaults.update(overrides) + validate_mcp_conflicts(spec=MCPRequestSpec(**spec_kwargs), **other_defaults) # --------------------------------------------------------------------------- From ad88dfa540518cf0e4f8a762ed8c34d8f7a12cec Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 01:37:51 +0200 Subject: [PATCH 04/21] refactor(install): extract resolve download closure into _TransitiveDownloader Move the ~285-line nested download_callback closure out of _resolve_dependencies into a stateful _TransitiveDownloader callable in a new sibling module install/phases/resolve_transitive.py. 
Ruff folds nested closures into the parent function, so this clears the C901/PLR0915 hot spot on _resolve_dependencies and drops resolve.py from 988 to 691 lines (under the 800 file-length target); the new module is 343 lines. The callable keeps parent_pkg in __call__ so APMDependencyResolver's _signature_accepts_parent_pkg introspection still routes the declaring-parent anchor for transitive local deps (#857). The registry branch is extracted into a _download_registry helper, so the dep_ref rebind from _apply_lockfile_registry_name is local to that helper and is no longer visible to the shared except handler in __call__; this is safe because get_unique_key() is invariant to registry_name. The download accumulators (downloaded/failures/transitive_failures) are read back onto ctx by identity, preserving post-resolution behaviour exactly. Refs #1078 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/install/phases/resolve.py | 353 ++---------------- .../install/phases/resolve_transitive.py | 343 +++++++++++++++++ .../phases/test_resolve_tui_callbacks.py | 16 +- 3 files changed, 380 insertions(+), 332 deletions(-) create mode 100644 src/apm_cli/install/phases/resolve_transitive.py diff --git a/src/apm_cli/install/phases/resolve.py b/src/apm_cli/install/phases/resolve.py index 10b507d9d..0a0f9e3ee 100644 --- a/src/apm_cli/install/phases/resolve.py +++ b/src/apm_cli/install/phases/resolve.py @@ -375,9 +375,6 @@ def _resolve_dependencies(ctx: InstallContext) -> None: builds ``ctx.dep_base_dirs``, writes ancillary state to ``ctx``, and cleans up the shared clone cache. 
""" - import threading as _threading - - from apm_cli.core.scope import InstallScope from apm_cli.deps.apm_resolver import APMDependencyResolver from apm_cli.install.insecure_policy import ( _check_insecure_dependencies, @@ -385,7 +382,6 @@ def _resolve_dependencies(ctx: InstallContext) -> None: _guard_transitive_insecure_dependencies, _warn_insecure_dependencies, ) - from apm_cli.install.phases.local_content import _copy_local_package # ------------------------------------------------------------------ # 3b. Dedicated registry resolver (design §3.1, §8) @@ -414,336 +410,36 @@ def _resolve_dependencies(ctx: InstallContext) -> None: ctx.registry_resolver = registry_resolver # ------------------------------------------------------------------ - # 4. Tracking variables (phase-local except where noted) + # 4. Tracking variables + transitive download callback # ------------------------------------------------------------------ - # direct_dep_keys is phase-local (only read inside download_callback) + # direct_dep_keys is phase-local (only read by the download callback). direct_dep_keys = builtins.set(dep.get_unique_key() for dep in ctx.all_apm_deps) - # These three escape to later phases via ctx - callback_downloaded: builtins.dict = {} - transitive_failures: builtins.list = [] - callback_failures: builtins.set = builtins.set() - # F7 (#1116): the resolver may dispatch ``download_callback`` calls - # across a worker pool. CPython's GIL makes individual dict/set/list - # mutations atomic, but logging emission and the read+update on - # ``callback_downloaded`` (e.g. duplicate-key races) are not. A single - # narrow lock around the result-recording sites is sufficient and - # cheap; the heavy I/O work runs OUTSIDE the lock. - callback_lock = _threading.Lock() - - # ------------------------------------------------------------------ - # 5. 
Download callback for transitive resolution - # ------------------------------------------------------------------ - # Capture frequently-used ctx fields as locals for the closure. - # This matches the original code's closure over function-level locals. - scope = ctx.scope + # project_root is reused below when building dep_base_dirs for transitive + # local deps (#857). project_root = ctx.project_root - # Local-path package references in apm.yml are relative to the - # manifest's location (source_root), not the deploy override. - # source_root is required on InstallContext; equals project_root - # when --root is not used. - source_root = ctx.source_root # --refresh implies re-resolution of all refs (but does NOT discard # lockfile entries for packages not in the manifest, unlike --update # which may restructure the whole graph). update_refs = ctx.update_refs or ctx.refresh if ctx.refresh and ctx.logger: ctx.logger.verbose_detail("[*] --refresh: re-resolving all refs") - logger = ctx.logger - existing_lockfile = ctx.existing_lockfile - downloader = ctx.downloader - - # Hoist drift helpers so download_callback avoids per-call sys.modules - # lookups and static analysis can see the dependency. - from apm_cli.drift import build_download_ref, detect_ref_change - - verbose = ctx.verbose # noqa: F841 - - def download_callback(dep_ref, modules_dir, parent_chain="", parent_pkg=None): - """Download a package during dependency resolution. - - Args: - dep_ref: The dependency to download. - modules_dir: Target apm_modules directory. - parent_chain: Human-readable breadcrumb (e.g. "root > mid") - showing which dependency path led to this transitive dep. - parent_pkg: APMPackage that declared *dep_ref*, or None for direct - deps from the root project. For local deps we use its - ``source_path`` as the anchor for relative paths so a - transitive ``../sibling`` resolves against the declaring - package's directory rather than the root consumer (#857). 
- """ - install_path = dep_ref.get_install_path(modules_dir) - # Cache short-circuit: skip the rest of the callback when the - # install path already exists. Exception: for git-source semver - # deps under ``--update`` / ``--refresh`` (``update_refs=True``), - # fall through so ``_maybe_resolve_git_semver`` re-runs - # ``git ls-remote`` and the lockfile gets rewritten with the - # latest matching tag. Matches npm/cargo/bundler: ``--update`` - # is the explicit re-resolve trigger and must not be swallowed - # by the on-disk cache (Bug 1 fix on #1496). The downstream - # ``downloader.download_package`` rmtrees and re-clones the - # install path when the resolved tag changes, so refetching is - # safe. - if install_path.exists(): - _force_semver_resolve = ( - update_refs - and not dep_ref.is_local - and getattr(dep_ref, "source", None) != "registry" - and not getattr(dep_ref, "artifactory_prefix", None) - and getattr(dep_ref, "ref_kind", None) == "semver" - ) - if not _force_semver_resolve: - return install_path - # F1 (#1116): surface a heartbeat BEFORE the network/copy work so - # users see the install advancing past silent transitive lookups. - # Under F7's parallel BFS this callback may run on a worker - # thread, so serialise the emission via ``callback_lock`` to - # keep heartbeat lines from interleaving with each other. - # Workstream B (#1116): when the shared InstallTui is painting - # the Live region, the static heartbeat line would interleave - # with the spinner -- route the heartbeat to the TUI's - # task_started instead and skip the static line. 
- if logger: - with callback_lock: - _display = dep_ref.get_display_name() - _tui = getattr(ctx, "tui", None) - if _tui is not None: - _tui.task_started(dep_ref.get_unique_key(), f"resolve {_display}") - if _tui is None or not _tui.is_animating(): - logger.resolving_heartbeat(_display) - try: - # ─── Registry-sourced dep (design §8) ────────────────────── - # Routed before local/git so the registry resolver owns the - # download for source=="registry" entries. Lockfile re-installs - # may arrive with registry_name=None — look it up by URL prefix - # against the configured registries map. - if dep_ref.source == "registry": - from apm_cli.deps.registry.feature_gate import ( - require_package_registry_enabled, - ) - - require_package_registry_enabled("Registry-sourced downloads") - - if registry_resolver is None: - raise RuntimeError( - f"dep {dep_ref.repo_url!r} is registry-sourced but no " - f"registries: block is configured in apm.yml and the " - f"lockfile carries no resolved_url for it." - ) - dep_ref = _apply_lockfile_registry_name( - dep_ref, - registries_map, - existing_lockfile=existing_lockfile, - ) - # Registry T5: honor lockfile on apm install (mirrors git T5 - # at lines below). When the lockfile has full replay data and - # the manifest range still covers the locked version, fetch - # from the locked URL and verify against the locked hash - # (npm install model — no /versions API call). 
- _locked_reg = ( - existing_lockfile.get_dependency(dep_ref.get_unique_key()) - if existing_lockfile - else None - ) - if ( - not update_refs - and _locked_reg - and _locked_reg.resolved_url - and _locked_reg.resolved_hash - and _locked_reg.version - ): - from apm_cli.drift import detect_ref_change as _detect_ref_change - - if not _detect_ref_change(dep_ref, _locked_reg, update_refs=False): - registry_resolver.download_from_lockfile( - dep_ref, - install_path, - resolved_url=_locked_reg.resolved_url, - resolved_hash=_locked_reg.resolved_hash, - version=_locked_reg.version, - ) - callback_downloaded[dep_ref.get_unique_key()] = None - return install_path - registry_resolver.download_package(dep_ref, install_path) - # Mark as already-downloaded so the parallel pre-download - # phase skips this dep. No SHA for registry deps. - callback_downloaded[dep_ref.get_unique_key()] = None - return install_path - - # Handle local packages: copy instead of git clone - if dep_ref.is_local and dep_ref.local_path: - if ( - scope is InstallScope.USER - and not Path(dep_ref.local_path).expanduser().is_absolute() - ): - # At user scope, relative local paths have no meaningful - # root (cwd is arbitrary, $HOME is not a project). Only - # absolute paths are unambiguous; reject relative refs. - # Note: callback_failures is a set (see line ~105), - # so use .add() rather than dict-style assignment. - with callback_lock: - callback_failures.add(dep_ref.get_unique_key()) - _tui = getattr(ctx, "tui", None) - if _tui is not None: - _tui.task_failed(dep_ref.get_unique_key()) - return None - # Anchor relative paths on the *declaring* package's source - # directory when available (#857). Falls back to project_root - # for direct deps and for parents that predate source_path. - # Direct deps from the root project anchor at ``source_root`` - # (which equals ``project_root`` unless ``apm install --root`` - # redirects writes -- then it stays at $PWD). 
Transitive - # deps from a parent local package anchor at that package's - # source_path, which is already an absolute path and not - # affected by ``--root``. - base_dir = ( - parent_pkg.source_path - if parent_pkg is not None and parent_pkg.source_path is not None - else source_root - ) - result_path = _copy_local_package( - dep_ref, - install_path, - base_dir, - project_root=project_root, - logger=logger, - ) - if result_path: - with callback_lock: - callback_downloaded[dep_ref.get_unique_key()] = None - _tui = getattr(ctx, "tui", None) - if _tui is not None: - _tui.task_completed(dep_ref.get_unique_key()) - return result_path - _tui = getattr(ctx, "tui", None) - if _tui is not None: - _tui.task_failed(dep_ref.get_unique_key()) - return None - - # --- Git-source semver range resolution (issue #1488) --- - # When the manifest carries a semver range as ``ref:`` and - # the dep is non-local, non-registry, and non-proxy, resolve - # it to a concrete tag BEFORE any git operation. The result - # is stashed on ctx so install/sources.py can plumb it into - # the lockfile, and the dep_ref's ``reference`` is replaced - # with the concrete tag so build_download_ref / clone use a - # literal git ref. - _semver_resolution = _maybe_resolve_git_semver( - dep_ref=dep_ref, - existing_lockfile=existing_lockfile, - update_refs=update_refs, - auth_resolver=ctx.auth_resolver, - ) - if _semver_resolution is not None: - with callback_lock: - ctx.git_semver_resolutions[dep_ref.get_unique_key()] = _semver_resolution - # Rewrite the dep_ref's ref to the concrete tag so the - # rest of the pipeline (drift detection, download, etc.) - # operates on a literal git ref. The original constraint - # is preserved in the resolution dataclass. - dep_ref.reference = _semver_resolution.resolved_tag - - # T5: Use locked commit for reproducibility, unless the manifest - # ref has drifted from what the lockfile recorded (spec drift). 
- _locked_dep = ( - existing_lockfile.get_dependency(dep_ref.get_unique_key()) - if existing_lockfile - else None - ) - _ref_changed = detect_ref_change(dep_ref, _locked_dep, update_refs=update_refs) - - # When ref drifts, signal downstream that a content-hash change - # is expected so the supply-chain check in sources.py doesn't - # treat a legitimate re-resolution as an attack. - if _ref_changed: - with callback_lock: - ctx.expected_hash_change_deps.add(dep_ref.get_unique_key()) - if logger: - _old = ( - _locked_dep.resolved_ref or _locked_dep.resolved_commit[:8] - if _locked_dep - else "?" - ) - _new = dep_ref.reference or "HEAD" - logger.verbose_detail( - f" [!] Spec drift: {dep_ref.get_unique_key()} " - f"{_old} -> {_new}, re-resolving" - ) - download_dep = build_download_ref( - dep_ref, - existing_lockfile, - update_refs=update_refs, - ref_changed=_ref_changed, - ) - - # Silent download - no progress display for transitive deps - result = downloader.download_package(download_dep, install_path) - # Capture resolved commit SHA for lockfile - resolved_sha = None - if result and hasattr(result, "resolved_reference") and result.resolved_reference: - resolved_sha = result.resolved_reference.resolved_commit - callback_downloaded_value = resolved_sha - with callback_lock: - callback_downloaded[dep_ref.get_unique_key()] = callback_downloaded_value - _tui = getattr(ctx, "tui", None) - if _tui is not None: - _tui.task_completed(dep_ref.get_unique_key()) - return install_path - except Exception as e: - dep_display = dep_ref.get_display_name() - dep_key = dep_ref.get_unique_key() - is_direct = dep_key in direct_dep_keys - - # Distinguish resolution failures (git-semver no-match) from - # download failures: the dep_ref was rewritten to a concrete - # tag BEFORE clone, so a NoMatchingTagError means we never - # got to the download step. Using "download" as the verb - # would mislead users who are debugging an unsatisfied - # constraint -- nothing was downloaded yet. 
- from apm_cli.deps.git_semver_resolver import NoMatchingTagError - from apm_cli.models.dependency.reference import InvalidSemverRangeError - - if isinstance(e, InvalidSemverRangeError): - if is_direct: - fail_msg = f"Invalid dependency spec for {dep_ref.repo_url}: {e}" - else: - chain_hint = f" (via {parent_chain})" if parent_chain else "" - fail_msg = ( - f"Invalid dependency spec for transitive dep " - f"{dep_ref.repo_url}{chain_hint}: {e}" - ) - elif isinstance(e, NoMatchingTagError): - if is_direct: - fail_msg = f"No matching tag for {dep_ref.repo_url}: {e}" - else: - chain_hint = f" (via {parent_chain})" if parent_chain else "" - fail_msg = ( - f"No matching tag for transitive dep {dep_ref.repo_url}{chain_hint}: {e}" - ) - # Distinguish direct vs transitive failure messages so users - # don't see a misleading "transitive dep" label for top-level deps. - elif is_direct: - fail_msg = f"Failed to download dependency {dep_ref.repo_url}: {e}" - else: - chain_hint = f" (via {parent_chain})" if parent_chain else "" - fail_msg = f"Failed to resolve transitive dep {dep_ref.repo_url}{chain_hint}: {e}" - - # Verbose: inline detail via logger (single output path). - # Deferred diagnostics below cover the non-logger case. - # F7 (#1116): single critical section for both the logger - # emission and the result-recording so concurrent failures - # don't interleave their lines. - with callback_lock: - if logger: - logger.verbose_detail(f" {fail_msg}") - # Collect for deferred diagnostics summary (always, even non-verbose) - callback_failures.add(dep_key) - transitive_failures.append((dep_display, fail_msg)) - _tui = getattr(ctx, "tui", None) - if _tui is not None: - _tui.task_failed(dep_key) - return None + # The former nested ``download_callback`` closure now lives in a stateful + # callable so this function stays within the complexity/statement budget. 
+ # It accumulates downloaded / failures / transitive_failures which are + # folded back onto ctx after resolution (same mutable objects, by + # identity). Constructed lazily to avoid a resolve <-> resolve_transitive + # import cycle. + from apm_cli.install.phases.resolve_transitive import _TransitiveDownloader + + download_cb = _TransitiveDownloader( + ctx, + registry_resolver=registry_resolver, + apply_lockfile_registry_name=_apply_lockfile_registry_name, + registries_map=registries_map, + direct_dep_keys=direct_dep_keys, + update_refs=update_refs, + ) # ------------------------------------------------------------------ # 6. Resolver creation + dependency resolution @@ -757,7 +453,7 @@ def download_callback(dep_ref, modules_dir, parent_chain="", parent_pkg=None): resolver = APMDependencyResolver( apm_modules_dir=ctx.apm_modules_dir, - download_callback=download_callback, + download_callback=download_cb, auth_resolver=ctx.auth_resolver, ) @@ -778,6 +474,13 @@ def download_callback(dep_ref, modules_dir, parent_chain="", parent_pkg=None): ctx.dependency_graph = dependency_graph _fail_on_resolution_errors(ctx, dependency_graph) + # Read back the accumulators populated by the download callback (same + # mutable objects, by identity) so the post-resolution code and the + # ctx.callback_* assignments below operate unchanged. + callback_downloaded = download_cb.downloaded + transitive_failures = download_cb.transitive_failures + callback_failures = download_cb.failures + # Fold remote-parent local_path rejections into ``callback_failures`` so # the integrate phase skips them via the same gate used for download # failures (PR #1111 review C2). 
The resolver has already emitted the diff --git a/src/apm_cli/install/phases/resolve_transitive.py b/src/apm_cli/install/phases/resolve_transitive.py new file mode 100644 index 000000000..6d634d878 --- /dev/null +++ b/src/apm_cli/install/phases/resolve_transitive.py @@ -0,0 +1,343 @@ +"""Transitive download callback for the resolve phase. + +Extracts the former ``download_callback`` closure out of +``resolve._resolve_dependencies`` into a stateful callable so the enclosing +phase function stays within the complexity / statement budget. One instance +is created per resolve run; the resolver (``APMDependencyResolver``) invokes +it once per BFS edge -- possibly across a worker pool, hence the lock. + +The instance accumulates ``downloaded`` / ``failures`` / ``transitive_failures`` +which the resolve phase folds back onto the :class:`InstallContext` after +resolution completes (same mutable objects, read back by identity). +""" + +from __future__ import annotations + +import threading +from pathlib import Path +from typing import TYPE_CHECKING + +from apm_cli.drift import build_download_ref, detect_ref_change + +if TYPE_CHECKING: + from apm_cli.install.context import InstallContext + + +class _TransitiveDownloader: + """Callable that downloads one package during BFS dependency resolution. + + Replaces the former nested ``download_callback`` closure. The + ``__call__`` signature deliberately keeps ``parent_pkg`` so the resolver's + ``_signature_accepts_parent_pkg`` introspection (which inspects + ``__call__`` and skips ``self``) still routes the declaring-parent anchor + through for transitive local deps (#857). + + The registry branch is kept inline in :meth:`__call__` (rather than + extracted) so the ``dep_ref`` reassignment from + ``_apply_lockfile_registry_name`` stays visible to the shared exception + handler, preserving the exact failure-key behaviour of the original + closure. 
+ """ + + def __init__( + self, + ctx: InstallContext, + *, + registry_resolver, + apply_lockfile_registry_name, + registries_map, + direct_dep_keys, + update_refs: bool, + ): + self.ctx = ctx + self.registry_resolver = registry_resolver + self._apply_lockfile_registry_name = apply_lockfile_registry_name + self.registries_map = registries_map + self.direct_dep_keys = direct_dep_keys + self.update_refs = update_refs + # Snapshot the same ctx fields the former closure captured as locals. + self.scope = ctx.scope + self.project_root = ctx.project_root + self.source_root = ctx.source_root + self.logger = ctx.logger + self.existing_lockfile = ctx.existing_lockfile + self.downloader = ctx.downloader + # Accumulators that escape back to ctx after resolution. Mutated in + # place; the resolve phase reads these same objects by identity. + self.downloaded: dict = {} + self.transitive_failures: list = [] + self.failures: set = set() + self.lock = threading.Lock() + + # -- TUI helpers (collapse the repeated getattr/guard blocks) ---------- + # ``ctx.tui`` is read live each call (never cached) to mirror the + # original closure, which re-fetched it at every use site. 
+ def _tui_completed(self, dep_key) -> None: + tui = getattr(self.ctx, "tui", None) + if tui is not None: + tui.task_completed(dep_key) + + def _tui_failed(self, dep_key) -> None: + tui = getattr(self.ctx, "tui", None) + if tui is not None: + tui.task_failed(dep_key) + + def _force_semver_resolve(self, dep_ref) -> bool: + """Pure predicate: should a cached install path fall through for a + git-source semver dep under ``--update`` / ``--refresh``?""" + return ( + self.update_refs + and not dep_ref.is_local + and getattr(dep_ref, "source", None) != "registry" + and not getattr(dep_ref, "artifactory_prefix", None) + and getattr(dep_ref, "ref_kind", None) == "semver" + ) + + def _emit_heartbeat(self, dep_ref) -> None: + """Surface a heartbeat BEFORE the network/copy work so users see the + install advancing past silent transitive lookups (#1116 F1/B).""" + if not self.logger: + return + with self.lock: + _display = dep_ref.get_display_name() + _tui = getattr(self.ctx, "tui", None) + if _tui is not None: + _tui.task_started(dep_ref.get_unique_key(), f"resolve {_display}") + if _tui is None or not _tui.is_animating(): + self.logger.resolving_heartbeat(_display) + + def __call__(self, dep_ref, modules_dir, parent_chain="", parent_pkg=None): + """Download a package during dependency resolution. + + Args: + dep_ref: The dependency to download. + modules_dir: Target apm_modules directory. + parent_chain: Human-readable breadcrumb (e.g. "root > mid") + showing which dependency path led to this transitive dep. + parent_pkg: APMPackage that declared *dep_ref*, or None for direct + deps from the root project. For local deps we use its + ``source_path`` as the anchor for relative paths so a + transitive ``../sibling`` resolves against the declaring + package's directory rather than the root consumer (#857). 
+ """ + install_path = dep_ref.get_install_path(modules_dir) + # Cache short-circuit: skip the rest when the install path already + # exists, unless this is a git-source semver dep under --update / + # --refresh (then fall through so ``_maybe_resolve_git_semver`` + # re-runs ``git ls-remote`` and the lockfile gets rewritten with the + # latest matching tag -- Bug 1 fix on #1496). + if install_path.exists() and not self._force_semver_resolve(dep_ref): + return install_path + self._emit_heartbeat(dep_ref) + try: + # Registry-sourced dep (design 8): routed before local/git so the + # registry resolver owns the download. Kept inline so the + # ``dep_ref`` reassignment below is seen by the except handler. + if dep_ref.source == "registry": + return self._download_registry(dep_ref, install_path) + # Local package: copy instead of git clone. + if dep_ref.is_local and dep_ref.local_path: + return self._download_local(dep_ref, install_path, parent_pkg) + return self._download_git(dep_ref, install_path) + except Exception as e: + self._record_failure(dep_ref, e, parent_chain) + return None + + def _download_registry(self, dep_ref, install_path): + """Download a ``source == "registry"`` dep. + + Reassigns ``dep_ref`` via ``_apply_lockfile_registry_name`` exactly + like the original closure; because ``get_unique_key()`` is invariant + to ``registry_name`` the failure key is unchanged either way, but the + rebind is preserved for parity. + """ + from apm_cli.deps.registry.feature_gate import require_package_registry_enabled + + require_package_registry_enabled("Registry-sourced downloads") + + if self.registry_resolver is None: + raise RuntimeError( + f"dep {dep_ref.repo_url!r} is registry-sourced but no " + f"registries: block is configured in apm.yml and the " + f"lockfile carries no resolved_url for it." 
+ ) + dep_ref = self._apply_lockfile_registry_name( + dep_ref, + self.registries_map, + existing_lockfile=self.existing_lockfile, + ) + # Registry T5: honor lockfile on apm install. When the lockfile has + # full replay data and the manifest range still covers the locked + # version, fetch from the locked URL and verify against the locked + # hash (npm install model -- no /versions API call). + _locked_reg = ( + self.existing_lockfile.get_dependency(dep_ref.get_unique_key()) + if self.existing_lockfile + else None + ) + if ( + not self.update_refs + and _locked_reg + and _locked_reg.resolved_url + and _locked_reg.resolved_hash + and _locked_reg.version + ): + if not detect_ref_change(dep_ref, _locked_reg, update_refs=False): + self.registry_resolver.download_from_lockfile( + dep_ref, + install_path, + resolved_url=_locked_reg.resolved_url, + resolved_hash=_locked_reg.resolved_hash, + version=_locked_reg.version, + ) + self.downloaded[dep_ref.get_unique_key()] = None + return install_path + self.registry_resolver.download_package(dep_ref, install_path) + # Mark as already-downloaded so the parallel pre-download phase skips + # this dep. No SHA for registry deps. + self.downloaded[dep_ref.get_unique_key()] = None + return install_path + + def _download_local(self, dep_ref, install_path, parent_pkg): + """Copy a local-path dep into the modules dir (no git clone).""" + from apm_cli.core.scope import InstallScope + from apm_cli.install.phases.local_content import _copy_local_package + + if ( + self.scope is InstallScope.USER + and not Path(dep_ref.local_path).expanduser().is_absolute() + ): + # At user scope, relative local paths have no meaningful root + # (cwd is arbitrary, $HOME is not a project). Reject them. + with self.lock: + self.failures.add(dep_ref.get_unique_key()) + self._tui_failed(dep_ref.get_unique_key()) + return None + # Anchor relative paths on the declaring package's source directory + # when available (#857); fall back to source_root for direct deps. 
+ base_dir = ( + parent_pkg.source_path + if parent_pkg is not None and parent_pkg.source_path is not None + else self.source_root + ) + result_path = _copy_local_package( + dep_ref, + install_path, + base_dir, + project_root=self.project_root, + logger=self.logger, + ) + if result_path: + with self.lock: + self.downloaded[dep_ref.get_unique_key()] = None + self._tui_completed(dep_ref.get_unique_key()) + return result_path + self._tui_failed(dep_ref.get_unique_key()) + return None + + def _download_git(self, dep_ref, install_path): + """Resolve any git-source semver range, detect spec drift, and clone.""" + from apm_cli.install.phases.resolve import _maybe_resolve_git_semver + + # Git-source semver range resolution (#1488): resolve a semver range + # ``ref:`` to a concrete tag BEFORE any git operation. The result is + # stashed on ctx so sources.py can plumb it into the lockfile; the + # dep_ref's ``reference`` is rewritten in place to the concrete tag. + _semver_resolution = _maybe_resolve_git_semver( + dep_ref=dep_ref, + existing_lockfile=self.existing_lockfile, + update_refs=self.update_refs, + auth_resolver=self.ctx.auth_resolver, + ) + if _semver_resolution is not None: + with self.lock: + self.ctx.git_semver_resolutions[dep_ref.get_unique_key()] = _semver_resolution + dep_ref.reference = _semver_resolution.resolved_tag + + # T5: use locked commit for reproducibility, unless the manifest ref + # has drifted from what the lockfile recorded (spec drift). + _locked_dep = ( + self.existing_lockfile.get_dependency(dep_ref.get_unique_key()) + if self.existing_lockfile + else None + ) + _ref_changed = detect_ref_change(dep_ref, _locked_dep, update_refs=self.update_refs) + + # When ref drifts, signal downstream that a content-hash change is + # expected so the supply-chain check in sources.py doesn't treat a + # legitimate re-resolution as an attack. 
+ if _ref_changed: + with self.lock: + self.ctx.expected_hash_change_deps.add(dep_ref.get_unique_key()) + if self.logger: + _old = ( + _locked_dep.resolved_ref or _locked_dep.resolved_commit[:8] + if _locked_dep + else "?" + ) + _new = dep_ref.reference or "HEAD" + self.logger.verbose_detail( + f" [!] Spec drift: {dep_ref.get_unique_key()} {_old} -> {_new}, re-resolving" + ) + + download_dep = build_download_ref( + dep_ref, + self.existing_lockfile, + update_refs=self.update_refs, + ref_changed=_ref_changed, + ) + + # Silent download - no progress display for transitive deps. + result = self.downloader.download_package(download_dep, install_path) + # Capture resolved commit SHA for lockfile. + resolved_sha = None + if result and hasattr(result, "resolved_reference") and result.resolved_reference: + resolved_sha = result.resolved_reference.resolved_commit + with self.lock: + self.downloaded[dep_ref.get_unique_key()] = resolved_sha + self._tui_completed(dep_ref.get_unique_key()) + return install_path + + def _record_failure(self, dep_ref, e, parent_chain) -> None: + """Record a download/resolution failure for deferred diagnostics.""" + # Distinguish resolution failures (git-semver no-match) from download + # failures: the dep_ref was rewritten to a concrete tag BEFORE clone, + # so a NoMatchingTagError means we never got to the download step. 
+ from apm_cli.deps.git_semver_resolver import NoMatchingTagError + from apm_cli.models.dependency.reference import InvalidSemverRangeError + + dep_display = dep_ref.get_display_name() + dep_key = dep_ref.get_unique_key() + is_direct = dep_key in self.direct_dep_keys + + if isinstance(e, InvalidSemverRangeError): + if is_direct: + fail_msg = f"Invalid dependency spec for {dep_ref.repo_url}: {e}" + else: + chain_hint = f" (via {parent_chain})" if parent_chain else "" + fail_msg = ( + f"Invalid dependency spec for transitive dep " + f"{dep_ref.repo_url}{chain_hint}: {e}" + ) + elif isinstance(e, NoMatchingTagError): + if is_direct: + fail_msg = f"No matching tag for {dep_ref.repo_url}: {e}" + else: + chain_hint = f" (via {parent_chain})" if parent_chain else "" + fail_msg = f"No matching tag for transitive dep {dep_ref.repo_url}{chain_hint}: {e}" + # Distinguish direct vs transitive failure messages so users don't see + # a misleading "transitive dep" label for top-level deps. + elif is_direct: + fail_msg = f"Failed to download dependency {dep_ref.repo_url}: {e}" + else: + chain_hint = f" (via {parent_chain})" if parent_chain else "" + fail_msg = f"Failed to resolve transitive dep {dep_ref.repo_url}{chain_hint}: {e}" + + # F7 (#1116): single critical section for both the logger emission and + # the result-recording so concurrent failures don't interleave. 
+ with self.lock: + if self.logger: + self.logger.verbose_detail(f" {fail_msg}") + self.failures.add(dep_key) + self.transitive_failures.append((dep_display, fail_msg)) + self._tui_failed(dep_key) diff --git a/tests/unit/install/phases/test_resolve_tui_callbacks.py b/tests/unit/install/phases/test_resolve_tui_callbacks.py index f049cf6be..830f9925a 100644 --- a/tests/unit/install/phases/test_resolve_tui_callbacks.py +++ b/tests/unit/install/phases/test_resolve_tui_callbacks.py @@ -73,19 +73,21 @@ def test_task_completed_called_on_local_copy_path() -> None: def test_resolve_module_imports_tui_attr_safely() -> None: - """Resolve uses getattr(ctx, 'tui', None) -- ctx without tui is OK. + """Download callback uses getattr(ctx, 'tui', None) -- ctx without tui is OK. Pins the duck-typed access pattern so older test fixtures - constructing minimal contexts don't break. + constructing minimal contexts don't break. The callback now lives in + ``resolve_transitive`` (extracted from ``resolve.py``'s former nested + closure), so the guard follows the code there. """ - from apm_cli.install.phases import resolve as resolve_mod + from apm_cli.install.phases import resolve_transitive as transitive_mod - # The module must use getattr(ctx, "tui", None) -- not direct + # The module must use getattr(self.ctx, "tui", None) -- not direct # attribute access -- so a missing attr does not raise. 
- src = resolve_mod.__file__ + src = transitive_mod.__file__ with open(src) as fh: text = fh.read() - assert 'getattr(ctx, "tui", None)' in text, ( - "resolve.py must access ctx.tui via getattr(...,None) so " + assert 'getattr(self.ctx, "tui", None)' in text, ( + "resolve_transitive.py must access ctx.tui via getattr(...,None) so " "minimal/older context objects don't trigger AttributeError" ) From abd883d7342e27aaa2f47fbc18147ad6a09da148 Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 01:56:21 +0200 Subject: [PATCH 05/21] refactor(install): split sources.py into base + fresh modules (#1078) Stage 2 file-length tightening for the install dependency-source strategy. sources.py 886 -> 519 lines (<800) with no behavioural change. - New sources_base.py: Materialization value object + DependencySource ABC, now with two hoisted helpers that fold the per-source lockfile bookkeeping that was copy-pasted into every acquire(): * _lockfile_node_fields() -> (depth, resolved_by, is_dev) * _skip_integration(deltas) -> Materialization(package_info=None, ...) - New sources_fresh.py: FreshDependencySource (network path) plus _format_package_type_label, moved verbatim. - sources.py keeps Local + Cached + _rebuild_cached_semver_resolution + make_dependency_source, re-exports the moved symbols (with __all__) so every existing import path keeps working, and uses the new base helpers in Local/Cached. Less code, same functionality: the three identical node-field blocks and two identical skip-integration returns collapse into shared base helpers. R0801 green. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/install/sources.py | 405 ++------------------------- src/apm_cli/install/sources_base.py | 101 +++++++ src/apm_cli/install/sources_fresh.py | 333 ++++++++++++++++++++++ 3 files changed, 453 insertions(+), 386 deletions(-) create mode 100644 src/apm_cli/install/sources_base.py create mode 100644 src/apm_cli/install/sources_fresh.py diff --git a/src/apm_cli/install/sources.py b/src/apm_cli/install/sources.py index 4f8a81d81..4a00a9df4 100644 --- a/src/apm_cli/install/sources.py +++ b/src/apm_cli/install/sources.py @@ -26,45 +26,33 @@ from __future__ import annotations -import sys -from abc import ABC, abstractmethod -from dataclasses import dataclass, field from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING, Any from apm_cli.install.registry_wiring import ( - get_registry_resolver, registry_resolution_for_cached_registry_dep, - resolver_last_registry_resolution, ) -from apm_cli.utils.console import _rich_error, _rich_success +from apm_cli.install.sources_base import DependencySource, Materialization +from apm_cli.install.sources_fresh import ( + FreshDependencySource, + _format_package_type_label, +) from apm_cli.utils.short_sha import format_short_sha if TYPE_CHECKING: from apm_cli.install.context import InstallContext - from apm_cli.models.apm_package import PackageInfo - - -def _format_package_type_label(pkg_type) -> str | None: - """Human-readable label for a detected ``PackageType``. - Centralised so every install path emits the same wording and so - new ``PackageType`` values can be added without grepping for ad-hoc - dicts. Missing ``HOOK_PACKAGE`` from this table is what made - microsoft/apm#780 silent -- keep all classifiable enum members - covered. 
- """ - from apm_cli.models.apm_package import PackageType - - return { - PackageType.CLAUDE_SKILL: "Skill (SKILL.md detected)", - PackageType.MARKETPLACE_PLUGIN: "Marketplace Plugin (plugin.json or agents/skills/commands)", - PackageType.HYBRID: "Hybrid (apm.yml + SKILL.md)", - PackageType.APM_PACKAGE: "APM Package (apm.yml)", - PackageType.HOOK_PACKAGE: "Hook Package (hooks/*.json only)", - PackageType.SKILL_BUNDLE: "Skill Bundle (skills//SKILL.md)", - }.get(pkg_type) +__all__ = [ + "CachedDependencySource", + "DependencySource", + "FreshDependencySource", + "LocalDependencySource", + "Materialization", + "_format_package_type_label", + "_rebuild_cached_semver_resolution", + "make_dependency_source", +] def _rebuild_cached_semver_resolution(dep_locked_chk: Any) -> Any: @@ -101,55 +89,6 @@ def _rebuild_cached_semver_resolution(dep_locked_chk: Any) -> Any: ) -@dataclass -class Materialization: - """Outcome of ``DependencySource.acquire()``. - - Carries everything the integration template needs to run the security - gate + primitive integration on a freshly-acquired package. - """ - - package_info: PackageInfo | None - install_path: Path - dep_key: str - deltas: dict[str, int] = field(default_factory=lambda: {"installed": 1}) - - -class DependencySource(ABC): - """Strategy: acquire one dependency and prepare it for integration. - - Subclasses encapsulate source-specific concerns (filesystem copy, - cache reuse, fresh download with progress + hash verification). - The post-acquire template flow is the same for every source. - """ - - INTEGRATE_ERROR_PREFIX: str = "Failed to integrate primitives" - """Per-source error wording used by the integration template when - ``integrate_package_primitives`` raises. 
Subclasses override to - preserve the legacy diagnostic text shown to users.""" - - def __init__( - self, - ctx: InstallContext, - dep_ref: Any, - install_path: Path, - dep_key: str, - ): - self.ctx = ctx - self.dep_ref = dep_ref - self.install_path = install_path - self.dep_key = dep_key - - @abstractmethod - def acquire(self) -> Materialization | None: - """Materialise the dependency on disk and build PackageInfo. - - Returns ``None`` to skip integration entirely (e.g. local dep at - user scope, copy/download failure). Otherwise returns a - ``Materialization`` consumed by the integration template. - """ - - class LocalDependencySource(DependencySource): """Local (``file://``) dependency: copy from a filesystem path.""" @@ -276,10 +215,7 @@ def acquire(self) -> Materialization | None: normalize_plugin_directory(install_path, plugin_json_path) # Record for lockfile - node = ctx.dependency_graph.dependency_tree.get_node(dep_key) - depth = node.depth if node else 1 - resolved_by = node.parent.dependency_ref.repo_url if node and node.parent else None - _is_dev = node.is_dev if node else False + depth, resolved_by, _is_dev = self._lockfile_node_fields() ctx.installed_packages.append( InstalledPackage( dep_ref=dep_ref, @@ -427,12 +363,7 @@ def acquire(self) -> Materialization | None: # In lockfile_only mode, skip this early return so installed_packages # is populated before we return without deploying any files. if not ctx.targets and not ctx.lockfile_only: - return Materialization( - package_info=None, - install_path=install_path, - dep_key=dep_key, - deltas=deltas, - ) + return self._skip_integration(deltas) # Load package from apm.yml. 
Anchor source_path on the clone location # so transitive ``local_path`` deps inside this remote package resolve @@ -475,10 +406,7 @@ def acquire(self) -> Materialization | None: cached_package_info.package_type = pkg_type # Collect for lockfile - node = ctx.dependency_graph.dependency_tree.get_node(dep_key) - depth = node.depth if node else 1 - resolved_by = node.parent.dependency_ref.repo_url if node and node.parent else None - _is_dev = node.is_dev if node else False + depth, resolved_by, _is_dev = self._lockfile_node_fields() # Determine commit SHA for the cached path. See _resolve_cached_commit # for the invariant ("recorded SHA must match disk identity") and the @@ -533,12 +461,7 @@ def acquire(self) -> Materialization | None: # Return without deploying integration files when the target set is empty. if not ctx.targets: - return Materialization( - package_info=None, - install_path=install_path, - dep_key=dep_key, - deltas=deltas, - ) + return self._skip_integration(deltas) return Materialization( package_info=cached_package_info, @@ -548,296 +471,6 @@ def acquire(self) -> Materialization | None: ) -class FreshDependencySource(DependencySource): - """Fresh dependency: needs a network download. - - Performs supply-chain hash verification (#763) and, on mismatch, - aborts the entire process via ``sys.exit(1)`` -- this matches the - legacy behaviour because content drift from the lockfile is treated - as a possible tampering event. - """ - - # Inherits the default "Failed to integrate primitives" prefix. 
- - def __init__( - self, - ctx: InstallContext, - dep_ref: Any, - install_path: Path, - dep_key: str, - resolved_ref: Any, - dep_locked_chk: Any, - ref_changed: bool, - progress: Any = None, - ): - super().__init__(ctx, dep_ref, install_path, dep_key) - self.resolved_ref = resolved_ref - self.dep_locked_chk = dep_locked_chk - self.ref_changed = ref_changed - self.progress = progress - - def acquire(self) -> Materialization | None: - from apm_cli.deps.installed_package import InstalledPackage - from apm_cli.drift import build_download_ref - from apm_cli.utils.content_hash import compute_package_hash as _compute_hash - from apm_cli.utils.path_security import safe_rmtree - - ctx = self.ctx - dep_ref = self.dep_ref - install_path = self.install_path - dep_key = self.dep_key - dep_locked_chk = self.dep_locked_chk - ref_changed = self.ref_changed - progress = self.progress - diagnostics = ctx.diagnostics - logger = ctx.logger - - try: - display_name = str(dep_ref) if dep_ref.is_virtual else dep_ref.repo_url - short_name = display_name.split("/")[-1] if "/" in display_name else display_name - - # Workstream B (#1116): per-dep progress is owned by the - # shared InstallTui ``ctx.tui``; legacy local Progress is - # only wired when integrate is invoked outside the install - # pipeline (no callers do this today, but the parameter is - # kept for back-compat). 
- task_id = None - if progress is not None: - task_id = progress.add_task( - description=f"Fetching {short_name}", - total=None, - ) - if ctx.tui is not None: - ctx.tui.task_started(dep_key, f"fetch {short_name}") - - download_ref = build_download_ref( - dep_ref, - ctx.existing_lockfile, - update_refs=ctx.update_refs, - ref_changed=ref_changed, - ) - - if dep_key in ctx.pre_download_results: - package_info = ctx.pre_download_results[dep_key] - elif dep_ref.source == "registry": - from apm_cli.deps.registry.feature_gate import ( - require_package_registry_enabled, - ) - - require_package_registry_enabled("Registry-sourced downloads") - - # Registry-sourced dep: dispatch to the dedicated-registry - # resolver instead of the GitHub downloader. This branch - # fires when (a) the BFS callback skipped due to existing - # install path on a re-install, or (b) parallel pre-download - # was skipped (registry deps aren't pre-downloaded). - _registry_resolver = get_registry_resolver(ctx) - if _registry_resolver is None: - raise RuntimeError( - f"dep {dep_ref.repo_url!r} is registry-sourced but " - f"no registry resolver was constructed (apm.yml may " - f"be missing a 'registries:' block)." - ) - # Lockfile re-install path: registry_name might be absent — - # look it up from the lockfile's resolved_url. - from apm_cli.deps.registry.auth import ( - dependency_ref_with_registry_name_from_lockfile, - ) - - _regs = getattr(ctx.apm_package, "registries", None) or {} - download_ref = dependency_ref_with_registry_name_from_lockfile( - download_ref, - _regs, - locked_dep=dep_locked_chk, - ) - # Lockfile replay (npm install model): fetch directly from the - # locked URL and verify against the locked hash when available - # and the manifest range still covers the locked version. 
- if ( - not ctx.update_refs - and dep_locked_chk - and dep_locked_chk.resolved_url - and dep_locked_chk.resolved_hash - and dep_locked_chk.version - and not ref_changed - ): - package_info = _registry_resolver.download_from_lockfile( - download_ref, - install_path, - resolved_url=dep_locked_chk.resolved_url, - resolved_hash=dep_locked_chk.resolved_hash, - version=dep_locked_chk.version, - ) - else: - package_info = _registry_resolver.download_package( - download_ref, - install_path, - ) - else: - package_info = ctx.downloader.download_package( - download_ref, - install_path, - progress_task_id=task_id, - progress_obj=progress, - ) - - # CRITICAL: hide progress BEFORE printing success to avoid overlap - if progress is not None and task_id is not None: - progress.update(task_id, visible=False) - progress.refresh() - if ctx.tui is not None: - ctx.tui.task_completed(dep_key) - - deltas: dict[str, int] = {"installed": 1} - - resolved = getattr(package_info, "resolved_reference", None) - if logger: - _ref = "" - _sha = "" - if resolved: - _ref = resolved.ref_name if resolved.ref_name else "" - # F3 (#1116): centralised hex/sentinel-aware short SHA helper. - _sha = format_short_sha(resolved.resolved_commit) - logger.download_complete(display_name, ref=_ref, sha=_sha) - # Only emit the per-package git auth diagnostic for git deps. - # Registry-sourced deps don't talk to git hosts; resolving - # github.com auth here for them is misleading (and can issue - # network calls via auth.AuthResolver providers). 
- if ctx.auth_resolver and dep_ref.source in (None, "git"): - try: - _host = dep_ref.host or "github.com" - _org = ( - dep_ref.repo_url.split("/")[0] - if dep_ref.repo_url and "/" in dep_ref.repo_url - else None - ) - _ctx = ctx.auth_resolver.resolve(_host, org=_org, port=dep_ref.port) - logger.package_auth(_ctx.source, _ctx.token_type or "none") - except Exception: - pass - else: - _ref_suffix = "" - if resolved: - _r = resolved.ref_name if resolved.ref_name else "" - _s = format_short_sha(resolved.resolved_commit) - if _r and _s: - _ref_suffix = f" #{_r} @{_s}" - elif _r: - _ref_suffix = f" #{_r}" - elif _s: - _ref_suffix = f" @{_s}" - _rich_success(f"[+] {display_name}{_ref_suffix}") - - if not dep_ref.reference: - deltas["unpinned"] = 1 - - # Lockfile bookkeeping - resolved_commit = None - if resolved: - resolved_commit = package_info.resolved_reference.resolved_commit - node = ctx.dependency_graph.dependency_tree.get_node(dep_key) - depth = node.depth if node else 1 - resolved_by = node.parent.dependency_ref.repo_url if node and node.parent else None - _is_dev = node.is_dev if node else False - # Registry-sourced deps: pull the captured resolution out of - # the resolver's per-graph map so the lockfile records - # resolved_url + resolved_hash + version (design §6.1). - _registry_resolution = ( - resolver_last_registry_resolution(ctx, dep_key) - if dep_ref.source == "registry" - else None - ) - # Git-source semver-range deps (#1488): the resolution was - # captured by the BFS download_callback in phases/resolve.py. 
- _git_semver_resolution = ctx.git_semver_resolutions.get(dep_key) - ctx.installed_packages.append( - InstalledPackage( - dep_ref=dep_ref, - resolved_commit=resolved_commit, - depth=depth, - resolved_by=resolved_by, - is_dev=_is_dev, - registry_config=(ctx.registry_config if not dep_ref.is_local else None), - registry_resolution=_registry_resolution, - git_semver_resolution=_git_semver_resolution, - ) - ) - if install_path.is_dir(): - ctx.package_hashes[dep_key] = _compute_hash(install_path) - - # Supply-chain protection: verify content hash on fresh - # downloads when the lockfile already records a hash. - # Skip when ``ctx.expected_hash_change_deps`` marks this dep - # (set by resolve.py's BFS callback and _resolve_download_strategy - # when branch-ref drift or the v<=0.12.2 self-heal forces a - # re-download whose hash is legitimately expected to differ from - # the lockfile record). - # Thread-safety: resolve phase completes before integrate runs, - # so the set is stable here. integrate.py's own .add() is - # idempotent (set semantics) and runs single-threaded. - _expected_hash_deps = ctx.expected_hash_change_deps - if ( - not ctx.update_refs - and dep_key not in _expected_hash_deps - and dep_locked_chk - and dep_locked_chk.content_hash - and dep_key in ctx.package_hashes - ): - _fresh_hash = ctx.package_hashes[dep_key] - if _fresh_hash != dep_locked_chk.content_hash: - safe_rmtree(install_path, ctx.apm_modules_dir) - _rich_error( - f"Content hash mismatch for " - f"{dep_key}: " - f"expected {dep_locked_chk.content_hash}, " - f"got {_fresh_hash}. " - "The downloaded content differs from the " - "lockfile record. This may indicate a " - "supply-chain attack. Use 'apm install " - "--update' to accept new content and " - "update the lockfile." 
- ) - sys.exit(1) - - if hasattr(package_info, "package_type") and package_info.package_type: - ctx.package_types[dep_key] = package_info.package_type.value - - if hasattr(package_info, "package_type"): - package_type = package_info.package_type - _type_label = _format_package_type_label(package_type) - if _type_label and logger: - logger.package_type_info(_type_label) - - # If no targets, skip integration but keep deltas - if not ctx.targets: - return Materialization( - package_info=None, - install_path=install_path, - dep_key=dep_key, - deltas=deltas, - ) - - return Materialization( - package_info=package_info, - install_path=package_info.install_path, - dep_key=dep_key, - deltas=deltas, - ) - - except Exception as e: - display_name = str(dep_ref) if dep_ref.is_virtual else dep_ref.repo_url - # task_id may not exist if progress.add_task failed; guard it. - try: # noqa: SIM105 - progress.remove_task(task_id) # type: ignore[name-defined] - except Exception: - pass - diagnostics.error( - f"Failed to install {display_name}: {e}", - package=dep_key, - ) - return None - - def make_dependency_source( ctx: InstallContext, dep_ref: Any, diff --git a/src/apm_cli/install/sources_base.py b/src/apm_cli/install/sources_base.py new file mode 100644 index 000000000..69e0893db --- /dev/null +++ b/src/apm_cli/install/sources_base.py @@ -0,0 +1,101 @@ +"""Base types for the install dependency-source strategy. + +Holds the pieces shared by every concrete ``DependencySource`` so the +source modules (``sources``, ``sources_fresh``) can depend on them without +importing each other: + +- :class:`Materialization` -- the value object returned by ``acquire()``. +- :class:`DependencySource` -- the strategy base class, including the + helpers that fold the per-source lockfile bookkeeping (which was + previously copy-pasted into each ``acquire()``). + +See ``apm_cli.install.sources`` for the module-level overview of the +strategy flow. 
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from apm_cli.install.context import InstallContext + from apm_cli.models.apm_package import PackageInfo + + +@dataclass +class Materialization: + """Outcome of ``DependencySource.acquire()``. + + Carries everything the integration template needs to run the security + gate + primitive integration on a freshly-acquired package. + """ + + package_info: PackageInfo | None + install_path: Path + dep_key: str + deltas: dict[str, int] = field(default_factory=lambda: {"installed": 1}) + + +class DependencySource(ABC): + """Strategy: acquire one dependency and prepare it for integration. + + Subclasses encapsulate source-specific concerns (filesystem copy, + cache reuse, fresh download with progress + hash verification). + The post-acquire template flow is the same for every source. + """ + + INTEGRATE_ERROR_PREFIX: str = "Failed to integrate primitives" + """Per-source error wording used by the integration template when + ``integrate_package_primitives`` raises. Subclasses override to + preserve the legacy diagnostic text shown to users.""" + + def __init__( + self, + ctx: InstallContext, + dep_ref: Any, + install_path: Path, + dep_key: str, + ): + self.ctx = ctx + self.dep_ref = dep_ref + self.install_path = install_path + self.dep_key = dep_key + + @abstractmethod + def acquire(self) -> Materialization | None: + """Materialise the dependency on disk and build PackageInfo. + + Returns ``None`` to skip integration entirely (e.g. local dep at + user scope, copy/download failure). Otherwise returns a + ``Materialization`` consumed by the integration template. + """ + + def _lockfile_node_fields(self) -> tuple[int, str | None, bool]: + """Return ``(depth, resolved_by, is_dev)`` for this dep's lockfile entry. 
+ + Shared by every source: looks up the dependency-tree node and reads + the three fields the ``InstalledPackage`` record needs, defaulting + gracefully when the node is absent (depth 1, no parent, not dev). + """ + node = self.ctx.dependency_graph.dependency_tree.get_node(self.dep_key) + depth = node.depth if node else 1 + resolved_by = node.parent.dependency_ref.repo_url if node and node.parent else None + is_dev = node.is_dev if node else False + return depth, resolved_by, is_dev + + def _skip_integration(self, deltas: dict[str, int]) -> Materialization: + """Return a ``Materialization`` that signals 'skip integration'. + + Used when the target set is empty: the package is recorded in + lockfile-bound state but no files are deployed. ``package_info=None`` + is the agreed signal to the template to skip the integration pass. + """ + return Materialization( + package_info=None, + install_path=self.install_path, + dep_key=self.dep_key, + deltas=deltas, + ) diff --git a/src/apm_cli/install/sources_fresh.py b/src/apm_cli/install/sources_fresh.py new file mode 100644 index 000000000..75c487e24 --- /dev/null +++ b/src/apm_cli/install/sources_fresh.py @@ -0,0 +1,333 @@ +"""Fresh-download dependency source for the install pipeline. + +Split out of ``apm_cli.install.sources`` to keep that module under the +file-length budget. ``FreshDependencySource`` is the network path: it +downloads a dependency that is not already cached, runs supply-chain hash +verification against the lockfile, and records the install for lockfile +write-back. + +The public class is re-exported from ``apm_cli.install.sources`` so existing +``from apm_cli.install.sources import FreshDependencySource`` imports keep +working. 
+""" + +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING, Any + +from apm_cli.install.registry_wiring import ( + get_registry_resolver, + resolver_last_registry_resolution, +) +from apm_cli.install.sources_base import DependencySource, Materialization +from apm_cli.utils.console import _rich_error, _rich_success +from apm_cli.utils.short_sha import format_short_sha + +if TYPE_CHECKING: + from pathlib import Path + + from apm_cli.install.context import InstallContext + + +def _format_package_type_label(pkg_type) -> str | None: + """Human-readable label for a detected ``PackageType``. + + Centralised so every install path emits the same wording and so + new ``PackageType`` values can be added without grepping for ad-hoc + dicts. Missing ``HOOK_PACKAGE`` from this table is what made + microsoft/apm#780 silent -- keep all classifiable enum members + covered. + """ + from apm_cli.models.apm_package import PackageType + + return { + PackageType.CLAUDE_SKILL: "Skill (SKILL.md detected)", + PackageType.MARKETPLACE_PLUGIN: "Marketplace Plugin (plugin.json or agents/skills/commands)", + PackageType.HYBRID: "Hybrid (apm.yml + SKILL.md)", + PackageType.APM_PACKAGE: "APM Package (apm.yml)", + PackageType.HOOK_PACKAGE: "Hook Package (hooks/*.json only)", + PackageType.SKILL_BUNDLE: "Skill Bundle (skills//SKILL.md)", + }.get(pkg_type) + + +class FreshDependencySource(DependencySource): + """Fresh dependency: needs a network download. + + Performs supply-chain hash verification (#763) and, on mismatch, + aborts the entire process via ``sys.exit(1)`` -- this matches the + legacy behaviour because content drift from the lockfile is treated + as a possible tampering event. + """ + + # Inherits the default "Failed to integrate primitives" prefix. 
+ + def __init__( + self, + ctx: InstallContext, + dep_ref: Any, + install_path: Path, + dep_key: str, + resolved_ref: Any, + dep_locked_chk: Any, + ref_changed: bool, + progress: Any = None, + ): + super().__init__(ctx, dep_ref, install_path, dep_key) + self.resolved_ref = resolved_ref + self.dep_locked_chk = dep_locked_chk + self.ref_changed = ref_changed + self.progress = progress + + def acquire(self) -> Materialization | None: + from apm_cli.deps.installed_package import InstalledPackage + from apm_cli.drift import build_download_ref + from apm_cli.utils.content_hash import compute_package_hash as _compute_hash + from apm_cli.utils.path_security import safe_rmtree + + ctx = self.ctx + dep_ref = self.dep_ref + install_path = self.install_path + dep_key = self.dep_key + dep_locked_chk = self.dep_locked_chk + ref_changed = self.ref_changed + progress = self.progress + diagnostics = ctx.diagnostics + logger = ctx.logger + + try: + display_name = str(dep_ref) if dep_ref.is_virtual else dep_ref.repo_url + short_name = display_name.split("/")[-1] if "/" in display_name else display_name + + # Workstream B (#1116): per-dep progress is owned by the + # shared InstallTui ``ctx.tui``; legacy local Progress is + # only wired when integrate is invoked outside the install + # pipeline (no callers do this today, but the parameter is + # kept for back-compat). 
+ task_id = None + if progress is not None: + task_id = progress.add_task( + description=f"Fetching {short_name}", + total=None, + ) + if ctx.tui is not None: + ctx.tui.task_started(dep_key, f"fetch {short_name}") + + download_ref = build_download_ref( + dep_ref, + ctx.existing_lockfile, + update_refs=ctx.update_refs, + ref_changed=ref_changed, + ) + + if dep_key in ctx.pre_download_results: + package_info = ctx.pre_download_results[dep_key] + elif dep_ref.source == "registry": + from apm_cli.deps.registry.feature_gate import ( + require_package_registry_enabled, + ) + + require_package_registry_enabled("Registry-sourced downloads") + + # Registry-sourced dep: dispatch to the dedicated-registry + # resolver instead of the GitHub downloader. This branch + # fires when (a) the BFS callback skipped due to existing + # install path on a re-install, or (b) parallel pre-download + # was skipped (registry deps aren't pre-downloaded). + _registry_resolver = get_registry_resolver(ctx) + if _registry_resolver is None: + raise RuntimeError( + f"dep {dep_ref.repo_url!r} is registry-sourced but " + f"no registry resolver was constructed (apm.yml may " + f"be missing a 'registries:' block)." + ) + # Lockfile re-install path: registry_name might be absent -- + # look it up from the lockfile's resolved_url. + from apm_cli.deps.registry.auth import ( + dependency_ref_with_registry_name_from_lockfile, + ) + + _regs = getattr(ctx.apm_package, "registries", None) or {} + download_ref = dependency_ref_with_registry_name_from_lockfile( + download_ref, + _regs, + locked_dep=dep_locked_chk, + ) + # Lockfile replay (npm install model): fetch directly from the + # locked URL and verify against the locked hash when available + # and the manifest range still covers the locked version. 
+ if ( + not ctx.update_refs + and dep_locked_chk + and dep_locked_chk.resolved_url + and dep_locked_chk.resolved_hash + and dep_locked_chk.version + and not ref_changed + ): + package_info = _registry_resolver.download_from_lockfile( + download_ref, + install_path, + resolved_url=dep_locked_chk.resolved_url, + resolved_hash=dep_locked_chk.resolved_hash, + version=dep_locked_chk.version, + ) + else: + package_info = _registry_resolver.download_package( + download_ref, + install_path, + ) + else: + package_info = ctx.downloader.download_package( + download_ref, + install_path, + progress_task_id=task_id, + progress_obj=progress, + ) + + # CRITICAL: hide progress BEFORE printing success to avoid overlap + if progress is not None and task_id is not None: + progress.update(task_id, visible=False) + progress.refresh() + if ctx.tui is not None: + ctx.tui.task_completed(dep_key) + + deltas: dict[str, int] = {"installed": 1} + + resolved = getattr(package_info, "resolved_reference", None) + if logger: + _ref = "" + _sha = "" + if resolved: + _ref = resolved.ref_name if resolved.ref_name else "" + # F3 (#1116): centralised hex/sentinel-aware short SHA helper. + _sha = format_short_sha(resolved.resolved_commit) + logger.download_complete(display_name, ref=_ref, sha=_sha) + # Only emit the per-package git auth diagnostic for git deps. + # Registry-sourced deps don't talk to git hosts; resolving + # github.com auth here for them is misleading (and can issue + # network calls via auth.AuthResolver providers). 
+ if ctx.auth_resolver and dep_ref.source in (None, "git"): + try: + _host = dep_ref.host or "github.com" + _org = ( + dep_ref.repo_url.split("/")[0] + if dep_ref.repo_url and "/" in dep_ref.repo_url + else None + ) + _ctx = ctx.auth_resolver.resolve(_host, org=_org, port=dep_ref.port) + logger.package_auth(_ctx.source, _ctx.token_type or "none") + except Exception: + pass + else: + _ref_suffix = "" + if resolved: + _r = resolved.ref_name if resolved.ref_name else "" + _s = format_short_sha(resolved.resolved_commit) + if _r and _s: + _ref_suffix = f" #{_r} @{_s}" + elif _r: + _ref_suffix = f" #{_r}" + elif _s: + _ref_suffix = f" @{_s}" + _rich_success(f"[+] {display_name}{_ref_suffix}") + + if not dep_ref.reference: + deltas["unpinned"] = 1 + + # Lockfile bookkeeping + resolved_commit = None + if resolved: + resolved_commit = package_info.resolved_reference.resolved_commit + depth, resolved_by, _is_dev = self._lockfile_node_fields() + # Registry-sourced deps: pull the captured resolution out of + # the resolver's per-graph map so the lockfile records + # resolved_url + resolved_hash + version (design 6.1). + _registry_resolution = ( + resolver_last_registry_resolution(ctx, dep_key) + if dep_ref.source == "registry" + else None + ) + # Git-source semver-range deps (#1488): the resolution was + # captured by the BFS download_callback in phases/resolve.py. + _git_semver_resolution = ctx.git_semver_resolutions.get(dep_key) + ctx.installed_packages.append( + InstalledPackage( + dep_ref=dep_ref, + resolved_commit=resolved_commit, + depth=depth, + resolved_by=resolved_by, + is_dev=_is_dev, + registry_config=(ctx.registry_config if not dep_ref.is_local else None), + registry_resolution=_registry_resolution, + git_semver_resolution=_git_semver_resolution, + ) + ) + if install_path.is_dir(): + ctx.package_hashes[dep_key] = _compute_hash(install_path) + + # Supply-chain protection: verify content hash on fresh + # downloads when the lockfile already records a hash. 
+ # Skip when ``ctx.expected_hash_change_deps`` marks this dep + # (set by resolve.py's BFS callback and _resolve_download_strategy + # when branch-ref drift or the v<=0.12.2 self-heal forces a + # re-download whose hash is legitimately expected to differ from + # the lockfile record). + # Thread-safety: resolve phase completes before integrate runs, + # so the set is stable here. integrate.py's own .add() is + # idempotent (set semantics) and runs single-threaded. + _expected_hash_deps = ctx.expected_hash_change_deps + if ( + not ctx.update_refs + and dep_key not in _expected_hash_deps + and dep_locked_chk + and dep_locked_chk.content_hash + and dep_key in ctx.package_hashes + ): + _fresh_hash = ctx.package_hashes[dep_key] + if _fresh_hash != dep_locked_chk.content_hash: + safe_rmtree(install_path, ctx.apm_modules_dir) + _rich_error( + f"Content hash mismatch for " + f"{dep_key}: " + f"expected {dep_locked_chk.content_hash}, " + f"got {_fresh_hash}. " + "The downloaded content differs from the " + "lockfile record. This may indicate a " + "supply-chain attack. Use 'apm install " + "--update' to accept new content and " + "update the lockfile." + ) + sys.exit(1) + + if hasattr(package_info, "package_type") and package_info.package_type: + ctx.package_types[dep_key] = package_info.package_type.value + + if hasattr(package_info, "package_type"): + package_type = package_info.package_type + _type_label = _format_package_type_label(package_type) + if _type_label and logger: + logger.package_type_info(_type_label) + + # If no targets, skip integration but keep deltas + if not ctx.targets: + return self._skip_integration(deltas) + + return Materialization( + package_info=package_info, + install_path=package_info.install_path, + dep_key=dep_key, + deltas=deltas, + ) + + except Exception as e: + display_name = str(dep_ref) if dep_ref.is_virtual else dep_ref.repo_url + # task_id may not exist if progress.add_task failed; guard it. 
+ try: # noqa: SIM105 + progress.remove_task(task_id) # type: ignore[name-defined] + except Exception: + pass + diagnostics.error( + f"Failed to install {display_name}: {e}", + package=dep_key, + ) + return None From 548e9e91eaeaf43e9c81fb06f2c6eea4506072bd Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 01:56:37 +0200 Subject: [PATCH 06/21] refactor(install): decompose targets.run() to cut complexity (#1078) Stage 2 complexity tightening for the install targets phase. targets.py run() C901 45 -> ~8, PLR0915 143 -> ~35, PLR0912 50 -> ~7, all well under the Stage 2 thresholds (35/120/40). File 576 -> 585. Extracted five focused module-level helpers from run(): - _gate_cowork_target / _gate_copilot_app_target (kept separate: cowork has Linux-vs-other messaging + a project-scope gate copilot-app lacks) - _resolve_project_scope_targets (v2 resolution block) - _handle_user_scope_targets (user-scope logging + dir creation) - _build_integrators (integrators dict) run() is now a thin orchestrator. No behavioural change: call sites, argument order, log messages, exit codes, and exception types preserved. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/install/phases/targets.py | 597 +++++++++++++------------- 1 file changed, 303 insertions(+), 294 deletions(-) diff --git a/src/apm_cli/install/phases/targets.py b/src/apm_cli/install/phases/targets.py index 4956041b9..cca70b2e5 100644 --- a/src/apm_cli/install/phases/targets.py +++ b/src/apm_cli/install/phases/targets.py @@ -134,92 +134,29 @@ def _create_target_dirs( return created -def run(ctx: InstallContext) -> None: - """Execute the targets phase. - - On return ``ctx.targets`` and ``ctx.integrators`` are populated. +def _gate_cowork_target(ctx, explicit, targets, is_user) -> None: + """Gate the copilot-cowork experimental target. + + Fix 2: explicit --target copilot-cowork with flag OFF must hint at the + experimental enable command. 
+ Fix 3: explicit --target copilot-cowork with flag ON but unresolvable + OneDrive must error. + Only fires when the user explicitly asked for cowork. Auto-detect + silently omits cowork when unavailable. + + Amendment 5: project-scope gate -- ``--target copilot-cowork`` without + ``--global`` is an error (cowork is user-scope only). Aborts before any + filesystem activity. """ - - import click as _click - - from apm_cli.core.scope import InstallScope - from apm_cli.core.target_detection import ( - detect_target, - format_provenance, - ) - from apm_cli.core.target_detection import ( - resolve_targets as _resolve_targets_v2, - ) - from apm_cli.integration import AgentIntegrator, PromptIntegrator - from apm_cli.integration.command_integrator import CommandIntegrator - from apm_cli.integration.copilot_cowork_paths import CoworkResolutionError - from apm_cli.integration.hook_integrator import HookIntegrator - from apm_cli.integration.instruction_integrator import InstructionIntegrator - from apm_cli.integration.skill_integrator import SkillIntegrator - from apm_cli.integration.targets import ( - KNOWN_TARGETS, - ) - from apm_cli.integration.targets import ( - resolve_targets as _resolve_targets_legacy, - ) - - # Get config target from apm.yml if available. - try: - config_target = _package_target_value(ctx.apm_package) - except _click.UsageError as exc: - _raise_target_usage_error(ctx, exc) - - # Resolve effective explicit target: CLI --target wins, then apm.yml - _explicit = ctx.target_override or config_target or None - - # ------------------------------------------------------------------ - # Deprecation warning for legacy '--target agents' alias (cli-review §1) - # Driven by the raw-token flag set in parse_target_field() so that - # multi-token inputs like "--target copilot,agents" still surface the - # warning even after alias resolution collapses "agents" away. 
- # ------------------------------------------------------------------ - from apm_cli.core.target_detection import agents_alias_was_detected - - if agents_alias_was_detected(): - if ctx.logger: - ctx.logger.warning( - "'--target agents' is deprecated -- it maps to 'copilot' (.github/), " - "not '.agents/'. Use '--target copilot' or '--target agent-skills' " - "(.agents/skills/). Removal in v1.0." - ) - - _is_user = ctx.scope is InstallScope.USER - - # Determine active targets using the legacy resolver first. - # This preserves backward compatibility (cowork, user-scope, etc.) - # while v2 adds provenance and stricter error checking. - try: - _targets = _resolve_targets_legacy( - ctx.project_root, - user_scope=_is_user, - explicit_target=_explicit, - ) - except CoworkResolutionError as exc: - if ctx.logger: - ctx.logger.error(str(exc), symbol="cross") - raise SystemExit(1) from exc - - # ------------------------------------------------------------------ - # Fix 2: explicit --target copilot-cowork with flag OFF must error. - # Fix 3: explicit --target copilot-cowork with flag ON but unresolvable - # OneDrive must error. - # Only fire when the user explicitly asked for cowork. Auto-detect - # silently omits cowork when unavailable. 
- # ------------------------------------------------------------------ _user_asked_cowork = False - if _explicit: - if isinstance(_explicit, list): - _user_asked_cowork = "copilot-cowork" in _explicit + if explicit: + if isinstance(explicit, list): + _user_asked_cowork = "copilot-cowork" in explicit else: - _user_asked_cowork = _explicit == "copilot-cowork" + _user_asked_cowork = explicit == "copilot-cowork" if _user_asked_cowork: - _cowork_resolved = any(t.name == "copilot-cowork" for t in _targets) + _cowork_resolved = any(t.name == "copilot-cowork" for t in targets) if not _cowork_resolved: from apm_cli.core.experimental import is_enabled as _is_flag_on @@ -249,13 +186,9 @@ def run(ctx: InstallContext) -> None: ctx.logger.error(_cowork_msg, symbol="cross") raise SystemExit(1) - # ------------------------------------------------------------------ - # Amendment 5: project-scope gate for cowork target. - # `--target copilot-cowork` without `--global` is an error -- cowork is - # user-scope only. Abort before any filesystem activity. - # ------------------------------------------------------------------ - if not _is_user: - _cowork_in_set = any(t.name == "copilot-cowork" for t in _targets) + # Project-scope gate: cowork is user-scope only. + if not is_user: + _cowork_in_set = any(t.name == "copilot-cowork" for t in targets) if _cowork_in_set: if ctx.logger: ctx.logger.error( @@ -264,216 +197,302 @@ def run(ctx: InstallContext) -> None: ) raise SystemExit(1) - # ------------------------------------------------------------------ - # GitHub Copilot App target gating (mirrors cowork rules above): - # explicit --target copilot-app with flag OFF must hint at the - # experimental enable command; with flag ON but no ~/.copilot/data.db - # must error with an actionable install instruction; without --global - # must error because copilot-app is user-scope only. 
- # ------------------------------------------------------------------ + +def _gate_copilot_app_target(ctx, explicit, targets) -> None: + """Gate the copilot-app experimental target. + + GitHub Copilot App target gating (mirrors cowork rules): + explicit --target copilot-app with flag OFF must hint at the + experimental enable command; with flag ON but no ~/.copilot/data.db + must error with an actionable install instruction. + + Note: copilot-app intentionally has no project-scope gate. The DB at + ~/.copilot/data.db is a single user-scoped resource, but the *intent* to + deploy can legitimately come from a project's apm.yml (a team-shared + scheduled prompt belongs in the project that owns the prompt, not in + every developer's user-scope manifest). The experimental flag + (machine-level opt-in) is the consent envelope; the package-namespaced + row id (apm------) prevents collisions across + projects sharing the same package. Rows always arrive enabled=0; users + grant the second consent in the App's Workflows tab before anything runs + on a schedule. + """ _user_asked_copilot_app = False - if _explicit: - if isinstance(_explicit, list): - _user_asked_copilot_app = "copilot-app" in _explicit + if explicit: + if isinstance(explicit, list): + _user_asked_copilot_app = "copilot-app" in explicit else: - _user_asked_copilot_app = _explicit == "copilot-app" + _user_asked_copilot_app = explicit == "copilot-app" - if _user_asked_copilot_app: - _copilot_app_resolved = any(t.name == "copilot-app" for t in _targets) - if not _copilot_app_resolved: - from apm_cli.core.experimental import is_enabled as _is_flag_on + if not _user_asked_copilot_app: + return - if not _is_flag_on("copilot_app"): - if ctx.logger: - ctx.logger.progress( - "The 'copilot-app' target requires an experimental flag. 
" - "Run: apm experimental enable copilot-app", - symbol="info", - ) - else: - _app_msg = ( - "GitHub Copilot desktop App not detected.\n" - "Expected ~/.copilot/data.db but the file is missing.\n" - "Install the app, or omit '--target copilot-app'." + _copilot_app_resolved = any(t.name == "copilot-app" for t in targets) + if not _copilot_app_resolved: + from apm_cli.core.experimental import is_enabled as _is_flag_on + + if not _is_flag_on("copilot_app"): + if ctx.logger: + ctx.logger.progress( + "The 'copilot-app' target requires an experimental flag. " + "Run: apm experimental enable copilot-app", + symbol="info", ) - if ctx.logger: - ctx.logger.error(_app_msg, symbol="cross") - raise SystemExit(1) + else: + _app_msg = ( + "GitHub Copilot desktop App not detected.\n" + "Expected ~/.copilot/data.db but the file is missing.\n" + "Install the app, or omit '--target copilot-app'." + ) + if ctx.logger: + ctx.logger.error(_app_msg, symbol="cross") + raise SystemExit(1) - # NOTE: copilot-app intentionally has no project-scope gate. The DB - # at ~/.copilot/data.db is a single user-scoped resource, but the - # *intent* to deploy can legitimately come from a project's apm.yml - # (a team-shared scheduled prompt belongs in the project that owns - # the prompt, not in every developer's user-scope manifest). The - # experimental flag (machine-level opt-in) is the consent envelope; - # the package-namespaced row id (apm------) - # prevents collisions across projects sharing the same package. - # Rows always arrive enabled=0; users grant the second consent in - # the App's Workflows tab before anything runs on a schedule. - # - # PR A (project-scoping): the integrator now auto-registers a row - # in the App's ``projects`` table for the current repository and - # stamps every workflow with that project_id, so workflows show up - # in the correct project's Workflows tab. 
On the *first* install - # into a repo, the App's webview does not always live-refresh on - # the externally-inserted ``projects`` row (see github/github-app - # #5483); the integrator emits a one-time restart hint so the user - # is not left wondering why the new project is missing from the UI. - # When the App is running, the integrator prefers the live - # WebSocket-IPC surface so the broadcast fires natively and no - # restart is needed; the SQLite path is the fallback for the - # App-closed case (still the common case during install). - # ------------------------------------------------------------------ - # v2 resolution (#1154): signal-based provenance and strict errors. - # Runs AFTER the legacy resolver and cowork gates so existing - # behavior is preserved. The v2 resolver validates signals and - # emits provenance; its target list REPLACES the legacy list for - # project-scope installs (three-guard collapse). - # ------------------------------------------------------------------ - if not _is_user: - # Build flag from CLI --target override, filtering to canonical - # targets only. Non-canonical targets (copilot-cowork) are handled - # exclusively by the legacy resolver + gates above. 
- from apm_cli.core.apm_yml import CANONICAL_TARGETS as _CANONICAL - - _v2_flag: str | list[str] | None = None - if ctx.target_override: - raw_override = ctx.target_override - if isinstance(raw_override, str): - parts = [t.strip() for t in raw_override.split(",") if t.strip()] - else: - parts = list(raw_override) - # Keep only canonical targets for v2 - parts = [p for p in parts if p in _CANONICAL] - if len(parts) == 1: - _v2_flag = parts[0] - elif len(parts) > 1: - _v2_flag = parts - # If no canonical targets remain, skip v2 entirely - # (all targets were non-canonical like copilot-cowork) - - # Read targets from apm.yml (supports both target: and targets:) - _v2_yaml: list[str] | None = None - if _v2_flag is None and not ctx.target_override: - try: - _v2_yaml = _read_yaml_targets(ctx) - except _click.UsageError as exc: - # ConflictingTargetsError (both target: and targets: in - # apm.yml) is a user error -- surface with exit code 2. - _raise_target_usage_error(ctx, exc) - - # Skip v2 entirely when all override targets were non-canonical - # (e.g. copilot-cowork only). Those are fully handled by the - # legacy resolver + cowork gates. - _skip_v2 = _v2_flag is None and _v2_yaml is None and ctx.target_override is not None - - if not _skip_v2: - # Resolve: raises click.UsageError on no-harness, ambiguous, - # unknown target, or schema conflict. When the legacy resolver - # already found targets and the v2 auto-detect disagrees (e.g. - # because the legacy fallback-to-copilot is disabled in v2), - # the v2 error takes precedence -- EXCEPT when the legacy - # targets include non-canonical entries (e.g. copilot-cowork) - # that v2 does not handle. - import click as _click +def _resolve_project_scope_targets(ctx, targets, explicit) -> list: + """Resolve project-scope targets using the v2 resolution algorithm (#1154). 
- try: - _resolved = _resolve_targets_v2( - ctx.project_root, - flag=_v2_flag, - yaml_targets=_v2_yaml, - ) - except _click.UsageError as exc: - # v2 target-resolution errors (NoHarnessError, - # AmbiguousHarnessError, etc.) are intentionally - # STRICTER than the legacy resolver. They always - # take precedence -- the whole point of the overhaul - # is to stop silently falling back to copilot. - # - # The ONLY exception: if ALL legacy targets are - # non-canonical (e.g. copilot-cowork) and v2 was - # invoked without any canonical flag/yaml, the error - # is a false positive because v2 does not handle - # non-canonical targets. That case is already - # guarded by ``_skip_v2`` above, so it never reaches - # this except block. The renderer already emits a - # leading "[x]"; pass an empty symbol so logger.error - # doesn't double-prefix. - if ctx.logger: - ctx.logger.error(str(exc), symbol="") - raise SystemExit(2) from exc - - # Emit provenance BEFORE any mutation. Route via _rich_info so - # the line picks up consistent symbol + color treatment and so - # automated tests can rely on the canonical "[i] Targets: ..." - # rendering (convergence item 1). - from apm_cli.utils.console import _rich_info - - _provenance_msg = format_provenance(_resolved) - _rich_info(_provenance_msg, symbol="info") - - # Map resolved v2 target names to TargetProfile objects, - # materializing deploy directories (three-guard collapse: - # auto_create unconditionally post-resolution). - _v2_targets = [] - for _tname in _resolved.targets: - _profile = KNOWN_TARGETS.get(_tname) - if _profile is None: - continue - _target_dir = ctx.project_root / _profile.root_dir - if not _target_dir.exists(): - try: - _target_dir.mkdir(parents=True, exist_ok=True) - except PermissionError: - if ctx.logger: - ctx.logger.error( - f"Cannot create {_profile.root_dir}/ -- permission denied. " - f"Check directory permissions or use a different --target." 
- ) - raise SystemExit(1) from None + Runs AFTER the legacy resolver and cowork gates so existing behaviour is + preserved. The v2 resolver validates signals and emits provenance; its + target list REPLACES the legacy list for project-scope installs + (three-guard collapse). + + Returns the updated targets list (v2 targets + legacy-only non-canonical + targets that v2 does not handle, e.g. copilot-cowork). Returns the + original *targets* list unchanged when v2 is skipped entirely (all + override targets were non-canonical). + """ + import click as _click + + from apm_cli.core.apm_yml import CANONICAL_TARGETS as _CANONICAL + from apm_cli.core.target_detection import format_provenance + from apm_cli.core.target_detection import resolve_targets as _resolve_targets_v2 + from apm_cli.integration.targets import KNOWN_TARGETS + from apm_cli.utils.console import _rich_info + + # Build flag from CLI --target override, filtering to canonical + # targets only. Non-canonical targets (copilot-cowork) are handled + # exclusively by the legacy resolver + gates above. + _v2_flag: str | list[str] | None = None + if ctx.target_override: + raw_override = ctx.target_override + if isinstance(raw_override, str): + parts = [t.strip() for t in raw_override.split(",") if t.strip()] + else: + parts = list(raw_override) + # Keep only canonical targets for v2 + parts = [p for p in parts if p in _CANONICAL] + if len(parts) == 1: + _v2_flag = parts[0] + elif len(parts) > 1: + _v2_flag = parts + # If no canonical targets remain, skip v2 entirely + # (all targets were non-canonical like copilot-cowork) + + # Read targets from apm.yml (supports both target: and targets:) + _v2_yaml: list[str] | None = None + if _v2_flag is None and not ctx.target_override: + try: + _v2_yaml = _read_yaml_targets(ctx) + except _click.UsageError as exc: + # ConflictingTargetsError (both target: and targets: in + # apm.yml) is a user error -- surface with exit code 2. 
+ _raise_target_usage_error(ctx, exc) + + # Skip v2 entirely when all override targets were non-canonical + # (e.g. copilot-cowork only). Those are fully handled by the + # legacy resolver + cowork gates. + _skip_v2 = _v2_flag is None and _v2_yaml is None and ctx.target_override is not None + + if not _skip_v2: + # Resolve: raises click.UsageError on no-harness, ambiguous, + # unknown target, or schema conflict. When the legacy resolver + # already found targets and the v2 auto-detect disagrees (e.g. + # because the legacy fallback-to-copilot is disabled in v2), + # the v2 error takes precedence -- EXCEPT when the legacy + # targets include non-canonical entries (e.g. copilot-cowork) + # that v2 does not handle. + try: + _resolved = _resolve_targets_v2( + ctx.project_root, + flag=_v2_flag, + yaml_targets=_v2_yaml, + ) + except _click.UsageError as exc: + # v2 target-resolution errors (NoHarnessError, + # AmbiguousHarnessError, etc.) are intentionally + # STRICTER than the legacy resolver. They always + # take precedence -- the whole point of the overhaul + # is to stop silently falling back to copilot. + # + # The ONLY exception: if ALL legacy targets are + # non-canonical (e.g. copilot-cowork) and v2 was + # invoked without any canonical flag/yaml, the error + # is a false positive because v2 does not handle + # non-canonical targets. That case is already + # guarded by ``_skip_v2`` above, so it never reaches + # this except block. The renderer already emits a + # leading "[x]"; pass an empty symbol so logger.error + # doesn't double-prefix. + if ctx.logger: + ctx.logger.error(str(exc), symbol="") + raise SystemExit(2) from exc + + # Emit provenance BEFORE any mutation. Route via _rich_info so + # the line picks up consistent symbol + color treatment and so + # automated tests can rely on the canonical "[i] Targets: ..." + # rendering (convergence item 1). 
+ _provenance_msg = format_provenance(_resolved) + _rich_info(_provenance_msg, symbol="info") + + # Map resolved v2 target names to TargetProfile objects, + # materializing deploy directories (three-guard collapse: + # auto_create unconditionally post-resolution). + _v2_targets = [] + for _tname in _resolved.targets: + _profile = KNOWN_TARGETS.get(_tname) + if _profile is None: + continue + _target_dir = ctx.project_root / _profile.root_dir + if not _target_dir.exists(): + try: + _target_dir.mkdir(parents=True, exist_ok=True) + except PermissionError: if ctx.logger: - ctx.logger.verbose_detail(f"Created {_profile.root_dir}/ ({_tname} target)") - # NOTE: do NOT set resolved_deploy_root on static targets. - # That field is reserved for dynamic-root targets (cowork) - # and is treated as the final deploy destination by - # skill_integrator and base_integrator. Static targets must - # follow the standard primitive-mapping path so that - # ``deploy_root`` (e.g. .agents) and ``subdir`` (e.g. skills) - # are honored. - _v2_targets.append(_profile) - - # Replace legacy targets with v2 targets for project-scope. - # Keep any legacy-only targets (e.g. copilot-cowork) that v2 - # doesn't handle. - _v2_names = {t.name for t in _v2_targets} - _legacy_only = [ - t for t in _targets if t.name not in _v2_names and t.name not in _CANONICAL - ] - _targets = _v2_targets + _legacy_only + ctx.logger.error( + f"Cannot create {_profile.root_dir}/ -- permission denied. " + f"Check directory permissions or use a different --target." + ) + raise SystemExit(1) from None + if ctx.logger: + ctx.logger.verbose_detail(f"Created {_profile.root_dir}/ ({_tname} target)") + # NOTE: do NOT set resolved_deploy_root on static targets. + # That field is reserved for dynamic-root targets (cowork) + # and is treated as the final deploy destination by + # skill_integrator and base_integrator. Static targets must + # follow the standard primitive-mapping path so that + # ``deploy_root`` (e.g. 
.agents) and ``subdir`` (e.g. skills) + # are honored. + _v2_targets.append(_profile) + + # Replace legacy targets with v2 targets for project-scope. + # Keep any legacy-only targets (e.g. copilot-cowork) that v2 + # doesn't handle. + _v2_names = {t.name for t in _v2_targets} + _legacy_only = [t for t in targets if t.name not in _v2_names and t.name not in _CANONICAL] + return _v2_targets + _legacy_only + + return targets + + +def _handle_user_scope_targets(ctx, targets, explicit) -> None: + """Log active user-scope targets and materialise their directories. + + Called from ``run()`` when ``ctx.scope`` is ``InstallScope.USER`` + (i.e. ``--global`` was passed). + """ + if ctx.logger: + if targets: - else: - # User-scope: legacy target directory creation and logging. + def _fmt_target(t): + if t.resolved_deploy_root is not None: + return f"{t.name} ({t.resolved_deploy_root})" + return f"{t.name} (~/{t.root_dir}/)" + + _target_names = ", ".join(_fmt_target(t) for t in targets) + ctx.logger.verbose_detail(f"Active global targets: {_target_names}") + from apm_cli.deps.lockfile import get_lockfile_path + + ctx.logger.verbose_detail(f"Lockfile: {get_lockfile_path(ctx.apm_dir)}") + else: + ctx.logger.warning( + "No global targets resolved -- nothing will be " + "deployed. Check 'target:' in apm.yml or use --target." 
+ ) + + _create_target_dirs(targets, ctx.project_root, explicit, ctx.logger) + + +def _build_integrators() -> dict: + """Construct and return the integrators dict used by downstream phases.""" + from apm_cli.integration import AgentIntegrator, PromptIntegrator + from apm_cli.integration.command_integrator import CommandIntegrator + from apm_cli.integration.hook_integrator import HookIntegrator + from apm_cli.integration.instruction_integrator import InstructionIntegrator + from apm_cli.integration.skill_integrator import SkillIntegrator + + return { + "prompt": PromptIntegrator(), + "agent": AgentIntegrator(), + "skill": SkillIntegrator(), + "command": CommandIntegrator(), + "hook": HookIntegrator(), + "instruction": InstructionIntegrator(), + } + + +def run(ctx: InstallContext) -> None: + """Execute the targets phase. + + On return ``ctx.targets`` and ``ctx.integrators`` are populated. + """ + + import click as _click + + from apm_cli.core.scope import InstallScope + from apm_cli.core.target_detection import agents_alias_was_detected, detect_target + from apm_cli.integration.copilot_cowork_paths import CoworkResolutionError + from apm_cli.integration.targets import resolve_targets as _resolve_targets_legacy + + # Get config target from apm.yml if available. + try: + config_target = _package_target_value(ctx.apm_package) + except _click.UsageError as exc: + _raise_target_usage_error(ctx, exc) + + # Resolve effective explicit target: CLI --target wins, then apm.yml + _explicit = ctx.target_override or config_target or None + + # ------------------------------------------------------------------ + # Deprecation warning for legacy '--target agents' alias (cli-review §1) + # Driven by the raw-token flag set in parse_target_field() so that + # multi-token inputs like "--target copilot,agents" still surface the + # warning even after alias resolution collapses "agents" away. 
+ # ------------------------------------------------------------------ + if agents_alias_was_detected(): if ctx.logger: - if _targets: + ctx.logger.warning( + "'--target agents' is deprecated -- it maps to 'copilot' (.github/), " + "not '.agents/'. Use '--target copilot' or '--target agent-skills' " + "(.agents/skills/). Removal in v1.0." + ) - def _fmt_target(t): - if t.resolved_deploy_root is not None: - return f"{t.name} ({t.resolved_deploy_root})" - return f"{t.name} (~/{t.root_dir}/)" + _is_user = ctx.scope is InstallScope.USER - _target_names = ", ".join(_fmt_target(t) for t in _targets) - ctx.logger.verbose_detail(f"Active global targets: {_target_names}") - from apm_cli.deps.lockfile import get_lockfile_path + # Determine active targets using the legacy resolver first. + # This preserves backward compatibility (cowork, user-scope, etc.) + # while v2 adds provenance and stricter error checking. + try: + _targets = _resolve_targets_legacy( + ctx.project_root, + user_scope=_is_user, + explicit_target=_explicit, + ) + except CoworkResolutionError as exc: + if ctx.logger: + ctx.logger.error(str(exc), symbol="cross") + raise SystemExit(1) from exc - ctx.logger.verbose_detail(f"Lockfile: {get_lockfile_path(ctx.apm_dir)}") - else: - ctx.logger.warning( - "No global targets resolved -- nothing will be " - "deployed. Check 'target:' in apm.yml or use --target." - ) + # Experimental target gates. + _gate_cowork_target(ctx, _explicit, _targets, _is_user) + _gate_copilot_app_target(ctx, _explicit, _targets) - _create_target_dirs(_targets, ctx.project_root, _explicit, ctx.logger) + # v2 resolution (#1154): project-scope uses v2 targets; user-scope uses + # legacy target list with directory creation and logging. 
+ if not _is_user: + _targets = _resolve_project_scope_targets(ctx, _targets, _explicit) + else: + _handle_user_scope_targets(ctx, _targets, _explicit) # Legacy detect_target call -- return values are not consumed by any # downstream code but the call is preserved for behaviour parity with @@ -495,18 +514,8 @@ def _fmt_target(t): _targets = apply_legacy_skill_paths(_targets) - # ------------------------------------------------------------------ - # Initialize integrators - # ------------------------------------------------------------------ ctx.targets = _targets - ctx.integrators = { - "prompt": PromptIntegrator(), - "agent": AgentIntegrator(), - "skill": SkillIntegrator(), - "command": CommandIntegrator(), - "hook": HookIntegrator(), - "instruction": InstructionIntegrator(), - } + ctx.integrators = _build_integrators() def run_targets_phase(ctx) -> None: From e2d112f10bc0204e8bac4e3346b74e95a8be875d Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 02:09:21 +0200 Subject: [PATCH 07/21] refactor(install): split pipeline.py preflight + extract orchestration helpers Move `_preflight_auth_check` into a new sibling `pipeline_preflight.py` (re-exported for import compatibility) and extract three cohesive helpers from `run_install_pipeline` -- `_read_early_lockfile_state`, `_resolve_managed_files`, and `_run_skill_path_migration`. This drops `run_install_pipeline` under the Stage 2 complexity thresholds (statements/branches/mccabe) and pipeline.py under 800 lines, with no behaviour change. `_run_phase` and `run_install_pipeline` remain module-level (test monkeypatch surface preserved). 
Refs #1078 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/install/pipeline.py | 528 +++++++--------------- src/apm_cli/install/pipeline_preflight.py | 240 ++++++++++ 2 files changed, 406 insertions(+), 362 deletions(-) create mode 100644 src/apm_cli/install/pipeline_preflight.py diff --git a/src/apm_cli/install/pipeline.py b/src/apm_cli/install/pipeline.py index b15d098cf..197a15a31 100644 --- a/src/apm_cli/install/pipeline.py +++ b/src/apm_cli/install/pipeline.py @@ -51,6 +51,7 @@ from ..utils.diagnostics import DiagnosticCollector from ..utils.path_security import PathTraversalError from .errors import AuthenticationError, DirectDependencyError +from .pipeline_preflight import _preflight_auth_check if TYPE_CHECKING: from ..core.auth import AuthResolver @@ -92,230 +93,6 @@ def _run_phase(name: str, phase, ctx): logger.verbose_detail(f"Phase: {name} -> {elapsed:.3f}s") -def _preflight_auth_check(ctx, auth_resolver, verbose: bool) -> None: - """Verify auth for every distinct (host, org) before write phases. - - Called only when ``update_refs`` is set, so we know the pipeline is - about to overwrite ``apm.yml``, ``apm.lock.yaml``, and - ``apm_modules/``. A single ``git ls-remote`` per cluster catches - stale tokens before any file is touched. - - For ADO clusters, a stale ``ADO_APM_PAT`` automatically falls back - to an ``az cli`` AAD bearer via :meth:`AuthResolver.execute_with_bearer_fallback` - -- matching the protocol used by the actual clone path. Without this, - ``apm install -g`` (which skipped preflight) would succeed but - ``apm install -g --update`` would fail on the same machine with the - same creds. See #1212. - - For generic hosts, the probe uses the same transport the real clone - would use, mirroring :meth:`TransportSelector.select`: SSH only when - the dep carries an explicit ``ssh://`` scheme; otherwise HTTPS (token - embedded when available, plain HTTPS for anonymous public deps). 
- SSH failures are detected via :func:`is_ssh_auth_failure_signal`; - HTTPS failures via :func:`is_ado_auth_failure_signal`. - - Raises :class:`AuthenticationError` (with ``build_error_context`` - payload) on the first auth failure that survives the fallback. - """ - import os - import subprocess as _sp - - from ..utils.github_host import ( - is_ado_auth_failure_signal, - is_azure_devops_hostname, - is_github_hostname, - is_ssh_auth_failure_signal, - ) - - logger = getattr(ctx, "logger", None) - - def _trace(line: str) -> None: - """Emit a verbose tracing line; best-effort, never raises.""" - if not verbose or logger is None: - return - with contextlib.suppress(Exception): - logger.verbose_detail(line) - - seen: builtins.set = builtins.set() - for dep in ctx.deps_to_install: - host = dep.host - if not host or is_github_hostname(host): - continue # github.com uses API probe with unauth fallback - org = dep.repo_url.split("/")[0] if dep.repo_url and "/" in dep.repo_url else None - key = (host, org) - if key in seen: - continue - seen.add(key) - - dep_ctx = auth_resolver.resolve_for_dep(dep) - _auth_scheme = getattr(dep_ctx, "auth_scheme", "basic") or "basic" - - from ..deps.github_downloader import GitHubPackageDownloader - - _dl = GitHubPackageDownloader(auth_resolver=auth_resolver) - _dl.github_host = host - is_generic = not is_github_hostname(host) and not is_azure_devops_hostname(host) - - # For generic hosts, mirror TransportSelector.select() when picking - # the probe transport: SSH only when the dep carries an explicit - # ssh:// scheme. Shorthand deps (no explicit scheme) default to - # HTTPS regardless of token presence -- TransportSelector's default - # is plain HTTPS without a token and authenticated HTTPS with one. - # Forcing SSH on tokenless generic hosts would break anonymous - # access to public Gitea/Forgejo deps that have neither an HTTPS - # token nor a configured SSH key. 
- _explicit_scheme = (getattr(dep, "explicit_scheme", None) or "").lower() - _use_ssh = is_generic and _explicit_scheme == "ssh" - - probe_url = _dl._build_repo_url( - dep.repo_url, - use_ssh=_use_ssh, - dep_ref=dep, - token=dep_ctx.token, - auth_scheme=_auth_scheme, - ) - _ctx_env = getattr(dep_ctx, "git_env", {}) or {} - probe_env = {**os.environ, **_dl.git_env, **_ctx_env} - # GIT_CONFIG_GLOBAL / GIT_CONFIG_NOSYSTEM carve-out: GitAuthEnvBuilder - # forces an empty global gitconfig for ALL hosts to prevent a user's - # ~/.gitconfig insteadOf rewrites or credential helpers from leaking - # tokens during a clone. But for preflight probes (a single ls-remote - # against the same host the dep targets), the redirection surface is - # nil and killing the user's global config kills Git Credential - # Manager along with it -- the helper most Windows ADO users rely on - # for Entra-cached credentials. For ADO specifically that matters - # because bearer acquisition can fail for reasons unrelated to login - # state (sandbox, proxy, microsoft/apm#1430-style PATH quirks), and - # GCM is the only remaining channel that can save us. Generic hosts - # have the same logic; widening the carve-out to ADO keeps the - # actual clone path isolated (it builds its own clean env) while - # giving the preflight probe the best chance to succeed. - if is_generic or is_azure_devops_hostname(host): - for _key in ("GIT_CONFIG_GLOBAL", "GIT_CONFIG_NOSYSTEM", "GIT_ASKPASS"): - probe_env.pop(_key, None) - - host_display = host if not org else f"{host}/{org}" - - def _run_ls_remote(url, env): - # auth-delegated: invoked via _primary_op/_bearer_op below, both - # routed through auth_resolver.execute_with_bearer_fallback. 
- try: - return _sp.run( - ["git", "ls-remote", "--heads", "--exit-code", url], - capture_output=True, - text=True, - encoding="utf-8", - timeout=30, - env=env, - ) - except _sp.TimeoutExpired: - return None # network timeout sentinel; treated as non-auth - - def _primary_op(url=probe_url, env=probe_env): - return _run_ls_remote(url, env) - - def _bearer_op( - bearer, dep=dep, dep_ctx=dep_ctx, host=host, host_display=host_display, _dl=_dl - ): - # SECURITY: build a CLEAN env via _build_git_env(scheme="bearer") - # rather than {**probe_env, **build_ado_bearer_git_env(bearer)}. - # probe_env carries GIT_TOKEN= from dep_ctx.git_env; - # leaving it set during the bearer attempt would leak the - # rejected PAT into the child-process env table even though the - # GIT_CONFIG_VALUE_0 header carries the bearer. _build_git_env - # explicitly skips GIT_TOKEN for scheme="bearer". - bearer_env = auth_resolver._build_git_env(bearer, scheme="bearer", host_kind="ado") - bearer_url = _dl._build_repo_url( - dep.repo_url, - use_ssh=False, - dep_ref=dep, - token=None, - auth_scheme="bearer", - ) - _trace(f"Preflight: {host_display} -- retrying with az cli bearer") - return _run_ls_remote(bearer_url, bearer_env) - - def _is_auth_failure(outcome): - if outcome is None: - return False # timeout: not an auth failure - if outcome.returncode == 0: - return False - return is_ado_auth_failure_signal(outcome.stderr or "") - - ado_eligible = ( - dep.is_azure_devops() - and _auth_scheme == "basic" - and getattr(dep_ctx, "source", None) == "ADO_APM_PAT" - ) - - if ado_eligible: - fallback_result = auth_resolver.execute_with_bearer_fallback( - dep, - _primary_op, - _bearer_op, - _is_auth_failure, - ) - result = fallback_result.outcome - # bearer_also_failed is True only when the bearer leg actually - # ran AND its outcome still matched the auth-failure signature. 
- # Early returns from execute_with_bearer_fallback (az - # unavailable, JWT acquisition failed) leave bearer_attempted - # False so the diagnostic does not falsely claim an attempt. - bearer_also_failed = ( - fallback_result.bearer_attempted - and result is not None - and result.returncode != 0 - and is_ado_auth_failure_signal(result.stderr or "") - ) - else: - result = _primary_op() - bearer_also_failed = False - - if result is None: - continue # timeout fallthrough -- handled by the real phase - - if result.returncode != 0: - stderr_text = result.stderr or "" - if _use_ssh: - # Generic SSH transport: check SSH-specific failure signals. - if not is_ssh_auth_failure_signal(stderr_text): - continue # non-auth SSH failure (network, unknown host key) -- defer - _trace(f"Preflight: {host_display} -- SSH auth rejected") - raise AuthenticationError( - f"SSH authentication failed for {host}", - diagnostic_context=( - f" SSH authentication was rejected by {host_display}.\n" - f" Ensure your SSH key is loaded in ssh-agent " - f"(ssh-add -l) and that the\n" - f" public key is authorised on the server.\n\n" - f" git output: {stderr_text.strip()}\n\n" - f" No files were modified.\n" - f" apm.yml, apm.lock.yaml, and apm_modules/ are unchanged." - ), - ) - else: - if not is_ado_auth_failure_signal(stderr_text): - continue # non-auth git failure (network, ref-not-found) -- defer - _trace(f"Preflight: {host_display} -- auth rejected") - _diag = auth_resolver.build_error_context( - host, - "install --update", - org=org, - dep_url=dep.repo_url, - bearer_also_failed=bearer_also_failed, - ) - raise AuthenticationError( - f"Authentication failed for {host}", - diagnostic_context=( - _diag - + "\n\n No files were modified." - + "\n apm.yml, apm.lock.yaml, and apm_modules/ are unchanged." - ), - ) - else: - _trace(f"Preflight: {host_display} -- accepted") - - def _write_empty_lockfile_only(apm_dir: Path) -> None: """Materialise an empty ``apm.lock.yaml`` for a depless ``apm lock`` run. 
@@ -354,6 +131,163 @@ def _is_no_work_install( return True +def _read_early_lockfile_state(lockfile_cls, get_path, apm_dir): + """Read prior local-deployed files + orphan-dep flag from the lockfile. + + Returns ``(early_lockfile, old_local_deployed, has_orphan_deps)``. The + orphan flag lets the cleanup phase run even when the user removed every + dependency from apm.yml. + """ + from apm_cli.deps.lockfile import _SELF_KEY + + old_local_deployed: builtins.list = [] + early_lockfile = lockfile_cls.read(get_path(apm_dir)) if apm_dir else None + if early_lockfile: + old_local_deployed = builtins.list(early_lockfile.local_deployed_files) + has_orphan_deps = bool( + early_lockfile and any(k != _SELF_KEY for k in early_lockfile.dependencies) + ) + return early_lockfile, old_local_deployed, has_orphan_deps + + +def _resolve_managed_files(apm_dir, diagnostics): + """Seed managed-files set + resolve registry-proxy config. + + Reads the existing lockfile to seed ``managed_files`` for collision + detection, enforces the PROXY_REGISTRY_ONLY conflict + missing-hash + rules, and normalises path separators. Returns + ``(managed_files, registry_config)``. + """ + from ..deps.lockfile import LockFile, get_lockfile_path + from ..deps.registry_proxy import RegistryConfig + + # Resolve registry proxy configuration once for this install session. + registry_config = RegistryConfig.from_env() + + # Build managed_files from existing lockfile for collision detection + managed_files = builtins.set() + existing_lockfile = LockFile.read(get_lockfile_path(apm_dir)) if apm_dir else None + if existing_lockfile: + for dep in existing_lockfile.dependencies.values(): + managed_files.update(dep.deployed_files) + + # Conflict: registry-only mode requires all locked deps to route + # through the configured proxy. Deps locked to direct VCS sources + # (github.com, GHE Cloud, GHES) are incompatible. 
+ if registry_config and registry_config.enforce_only: + conflicts = registry_config.validate_lockfile_deps( + builtins.list(existing_lockfile.dependencies.values()) + ) + if conflicts: + _rich_error( + "PROXY_REGISTRY_ONLY is set but the lockfile contains " + "dependencies locked to direct VCS sources:" + ) + for dep in conflicts[:10]: + host = dep.host or "github.com" + name = dep.repo_url + if dep.virtual_path: + name = f"{name}/{dep.virtual_path}" + _rich_error(f" - {name} (host: {host})") + _rich_error( + "Re-run with 'apm install --update' to re-resolve " + "through the registry, or unset PROXY_REGISTRY_ONLY." + ) + sys.exit(1) + + # Supply chain warning: registry-proxy entries without a + # content_hash cannot be verified on re-install. + if registry_config and registry_config.enforce_only: + missing = registry_config.find_missing_hashes( + builtins.list(existing_lockfile.dependencies.values()) + ) + if missing: + diagnostics.warn( + "The following registry-proxy dependencies have no " + "content_hash in the lockfile. Run 'apm install " + "--update' to populate hashes for tamper detection.", + package="lockfile", + ) + for dep in missing[:10]: + name = dep.repo_url + if dep.virtual_path: + name = f"{name}/{dep.virtual_path}" + diagnostics.warn( + f" - {name} (host: {dep.host})", + package="lockfile", + ) + + # Normalize path separators once for O(1) lookups in check_collision + from ..integration.base_integrator import BaseIntegrator + + managed_files = BaseIntegrator.normalize_managed_files(managed_files) + return managed_files, registry_config + + +def _run_skill_path_migration(ctx) -> None: + """Auto-migrate legacy skill deployments (#737); no-op when disabled. + + Mirrors the previous inline block: skips entirely unless a prior + lockfile exists and skill-path migration is in effect, then either + reports collisions (error) or executes the migration and summarises + the result. 
+ """ + if ctx.legacy_skill_paths or not ctx.existing_lockfile or ctx.dry_run or ctx.lockfile_only: + return + from apm_cli.utils.console import _rich_info, _rich_warning + + from .skill_path_migration import ( + COLLISION_HEADER_TEMPLATE, + COLLISION_HINT, + MIGRATION_SUMMARY_TEMPLATE, + check_collisions, + detect_legacy_skill_deployments, + execute_migration, + ) + + _migration_plans = detect_legacy_skill_deployments(ctx.existing_lockfile, ctx.project_root) + if _migration_plans: + _collisions = check_collisions(_migration_plans, ctx.project_root) + if _collisions: + # H2: collision is an error, not a warning. + _rich_error( + COLLISION_HEADER_TEMPLATE.format(count=len(_collisions)), + symbol="error", + ) + for _c in _collisions: + _rich_error(f" {_c}", symbol="error") + # H5: actionable next-step hint. + _rich_info(COLLISION_HINT, symbol="info") + # H2: surface via DiagnosticCollector. + if ctx.diagnostics: + for _c in _collisions: + ctx.diagnostics.error( + f"Skill migration collision: {_c}", + package="skill-path-migration", + ) + else: + _migration_result = execute_migration( + _migration_plans, ctx.existing_lockfile, ctx.project_root + ) + _total = len(_migration_result.deleted) + len(_migration_result.skipped_no_file) + if _total > 0: + # H3: suppress info when quiet. + if not (ctx.logger and getattr(ctx.logger, "_quiet", False)): + _rich_info( + MIGRATION_SUMMARY_TEMPLATE.format(count=_total), + symbol="info", + ) + # H4: enumerate deleted paths when verbose. 
+ if ctx.verbose and _migration_result.deleted: + for _dp in _migration_result.deleted: + _rich_info(f" removed {_dp}", symbol="info") + if _migration_result.failed: + _rich_warning( + f" {len(_migration_result.failed)} file(s) could not be deleted (will retry next install)", + symbol="warning", + ) + + def run_install_pipeline( # noqa: PLR0913, RUF100 apm_package: APMPackage, update_refs: bool = False, @@ -444,22 +378,12 @@ def run_install_pipeline( # noqa: PLR0913, RUF100 _root_has_local_primitives = _project_has_root_primitives(source_root) - # Read old local deployed files from the existing lockfile so the - # post-deps-local phase can run stale cleanup even when no current - # local content exists (e.g. .apm/ was deleted but old files remain). - _old_local_deployed: builtins.list = [] - _early_lockfile = LockFile.read(get_lockfile_path(apm_dir)) if apm_dir else None - if _early_lockfile: - _old_local_deployed = builtins.list(_early_lockfile.local_deployed_files) - - # Detect orphan APM dependencies in the previous lockfile so we don't - # short-circuit cleanup when the user removed every dep from apm.yml. - # Without this check, deleting all deps would leave their deployed files - # behind because the cleanup phase never runs. - from apm_cli.deps.lockfile import _SELF_KEY - - _has_orphan_deps = bool( - _early_lockfile and any(k != _SELF_KEY for k in _early_lockfile.dependencies) + # Read old local deployed files + detect orphan deps from the existing + # lockfile so the post-deps-local cleanup phase can run even when no + # current local content exists (e.g. .apm/ deleted but old files remain) + # or the user removed every dep from apm.yml. 
+ _early_lockfile, _old_local_deployed, _has_orphan_deps = _read_early_lockfile_state( + LockFile, get_lockfile_path, apm_dir ) if _is_no_work_install( @@ -641,71 +565,10 @@ def run_install_pipeline( # noqa: PLR0913, RUF100 # Collect installed packages for lockfile generation from ..deps.installed_package import InstalledPackage - from ..deps.lockfile import LockFile, get_lockfile_path - from ..deps.registry_proxy import RegistryConfig installed_packages: builtins.list[InstalledPackage] = [] - # Resolve registry proxy configuration once for this install session. - registry_config = RegistryConfig.from_env() - - # Build managed_files from existing lockfile for collision detection - managed_files = builtins.set() - existing_lockfile = LockFile.read(get_lockfile_path(apm_dir)) if apm_dir else None - if existing_lockfile: - for dep in existing_lockfile.dependencies.values(): - managed_files.update(dep.deployed_files) - - # Conflict: registry-only mode requires all locked deps to route - # through the configured proxy. Deps locked to direct VCS sources - # (github.com, GHE Cloud, GHES) are incompatible. - if registry_config and registry_config.enforce_only: - conflicts = registry_config.validate_lockfile_deps( - builtins.list(existing_lockfile.dependencies.values()) - ) - if conflicts: - _rich_error( - "PROXY_REGISTRY_ONLY is set but the lockfile contains " - "dependencies locked to direct VCS sources:" - ) - for dep in conflicts[:10]: - host = dep.host or "github.com" - name = dep.repo_url - if dep.virtual_path: - name = f"{name}/{dep.virtual_path}" - _rich_error(f" - {name} (host: {host})") - _rich_error( - "Re-run with 'apm install --update' to re-resolve " - "through the registry, or unset PROXY_REGISTRY_ONLY." - ) - sys.exit(1) - - # Supply chain warning: registry-proxy entries without a - # content_hash cannot be verified on re-install. 
- if registry_config and registry_config.enforce_only: - missing = registry_config.find_missing_hashes( - builtins.list(existing_lockfile.dependencies.values()) - ) - if missing: - diagnostics.warn( - "The following registry-proxy dependencies have no " - "content_hash in the lockfile. Run 'apm install " - "--update' to populate hashes for tamper detection.", - package="lockfile", - ) - for dep in missing[:10]: - name = dep.repo_url - if dep.virtual_path: - name = f"{name}/{dep.virtual_path}" - diagnostics.warn( - f" - {name} (host: {dep.host})", - package="lockfile", - ) - - # Normalize path separators once for O(1) lookups in check_collision - from ..integration.base_integrator import BaseIntegrator - - managed_files = BaseIntegrator.normalize_managed_files(managed_files) + managed_files, registry_config = _resolve_managed_files(apm_dir, diagnostics) # -------------------------------------------------------------- # Phase 4 (#171): Parallel package pre-download @@ -761,66 +624,7 @@ def run_install_pipeline( # noqa: PLR0913, RUF100 # Phase: Skill path auto-migration (#737). # Skipped in lockfile_only mode. # ------------------------------------------------------------------ - if ( - not ctx.legacy_skill_paths - and ctx.existing_lockfile - and not ctx.dry_run - and not lockfile_only - ): - from apm_cli.utils.console import _rich_info, _rich_warning - - from .skill_path_migration import ( - COLLISION_HEADER_TEMPLATE, - COLLISION_HINT, - MIGRATION_SUMMARY_TEMPLATE, - check_collisions, - detect_legacy_skill_deployments, - execute_migration, - ) - - _migration_plans = detect_legacy_skill_deployments( - ctx.existing_lockfile, ctx.project_root - ) - if _migration_plans: - _collisions = check_collisions(_migration_plans, ctx.project_root) - if _collisions: - # H2: collision is an error, not a warning. 
- _rich_error( - COLLISION_HEADER_TEMPLATE.format(count=len(_collisions)), - symbol="error", - ) - for _c in _collisions: - _rich_error(f" {_c}", symbol="error") - # H5: actionable next-step hint. - _rich_info(COLLISION_HINT, symbol="info") - # H2: surface via DiagnosticCollector. - if ctx.diagnostics: - for _c in _collisions: - ctx.diagnostics.error( - f"Skill migration collision: {_c}", - package="skill-path-migration", - ) - else: - _migration_result = execute_migration( - _migration_plans, ctx.existing_lockfile, ctx.project_root - ) - _total = len(_migration_result.deleted) + len(_migration_result.skipped_no_file) - if _total > 0: - # H3: suppress info when quiet. - if not (ctx.logger and getattr(ctx.logger, "_quiet", False)): - _rich_info( - MIGRATION_SUMMARY_TEMPLATE.format(count=_total), - symbol="info", - ) - # H4: enumerate deleted paths when verbose. - if ctx.verbose and _migration_result.deleted: - for _dp in _migration_result.deleted: - _rich_info(f" removed {_dp}", symbol="info") - if _migration_result.failed: - _rich_warning( - f" {len(_migration_result.failed)} file(s) could not be deleted (will retry next install)", - symbol="warning", - ) + _run_skill_path_migration(ctx) # Generate apm.lock for reproducible installs (T4: lockfile generation) from .phases.lockfile import LockfileBuilder diff --git a/src/apm_cli/install/pipeline_preflight.py b/src/apm_cli/install/pipeline_preflight.py new file mode 100644 index 000000000..983d45721 --- /dev/null +++ b/src/apm_cli/install/pipeline_preflight.py @@ -0,0 +1,240 @@ +"""Auth pre-flight probe for ``apm install --update`` (split from pipeline.py). + +Verifies credentials for every distinct (host, org) cluster before the +write phases touch ``apm.yml`` / ``apm.lock.yaml`` / ``apm_modules/``. 
+Kept in its own module so ``pipeline.py`` stays within the file-length +budget; re-exported from ``apm_cli.install.pipeline`` so existing +``from apm_cli.install.pipeline import _preflight_auth_check`` imports and +the install-pipeline call site keep working unchanged. +""" + +from __future__ import annotations + +import builtins +import contextlib + +from .errors import AuthenticationError + + +def _preflight_auth_check(ctx, auth_resolver, verbose: bool) -> None: + """Verify auth for every distinct (host, org) before write phases. + + Called only when ``update_refs`` is set, so we know the pipeline is + about to overwrite ``apm.yml``, ``apm.lock.yaml``, and + ``apm_modules/``. A single ``git ls-remote`` per cluster catches + stale tokens before any file is touched. + + For ADO clusters, a stale ``ADO_APM_PAT`` automatically falls back + to an ``az cli`` AAD bearer via :meth:`AuthResolver.execute_with_bearer_fallback` + -- matching the protocol used by the actual clone path. Without this, + ``apm install -g`` (which skipped preflight) would succeed but + ``apm install -g --update`` would fail on the same machine with the + same creds. See #1212. + + For generic hosts, the probe uses the same transport the real clone + would use, mirroring :meth:`TransportSelector.select`: SSH only when + the dep carries an explicit ``ssh://`` scheme; otherwise HTTPS (token + embedded when available, plain HTTPS for anonymous public deps). + SSH failures are detected via :func:`is_ssh_auth_failure_signal`; + HTTPS failures via :func:`is_ado_auth_failure_signal`. + + Raises :class:`AuthenticationError` (with ``build_error_context`` + payload) on the first auth failure that survives the fallback. 
+ """ + import os + import subprocess as _sp + + from ..utils.github_host import ( + is_ado_auth_failure_signal, + is_azure_devops_hostname, + is_github_hostname, + is_ssh_auth_failure_signal, + ) + + logger = getattr(ctx, "logger", None) + + def _trace(line: str) -> None: + """Emit a verbose tracing line; best-effort, never raises.""" + if not verbose or logger is None: + return + with contextlib.suppress(Exception): + logger.verbose_detail(line) + + seen: builtins.set = builtins.set() + for dep in ctx.deps_to_install: + host = dep.host + if not host or is_github_hostname(host): + continue # github.com uses API probe with unauth fallback + org = dep.repo_url.split("/")[0] if dep.repo_url and "/" in dep.repo_url else None + key = (host, org) + if key in seen: + continue + seen.add(key) + + dep_ctx = auth_resolver.resolve_for_dep(dep) + _auth_scheme = getattr(dep_ctx, "auth_scheme", "basic") or "basic" + + from ..deps.github_downloader import GitHubPackageDownloader + + _dl = GitHubPackageDownloader(auth_resolver=auth_resolver) + _dl.github_host = host + is_generic = not is_github_hostname(host) and not is_azure_devops_hostname(host) + + # For generic hosts, mirror TransportSelector.select() when picking + # the probe transport: SSH only when the dep carries an explicit + # ssh:// scheme. Shorthand deps (no explicit scheme) default to + # HTTPS regardless of token presence -- TransportSelector's default + # is plain HTTPS without a token and authenticated HTTPS with one. + # Forcing SSH on tokenless generic hosts would break anonymous + # access to public Gitea/Forgejo deps that have neither an HTTPS + # token nor a configured SSH key. 
+ _explicit_scheme = (getattr(dep, "explicit_scheme", None) or "").lower() + _use_ssh = is_generic and _explicit_scheme == "ssh" + + probe_url = _dl._build_repo_url( + dep.repo_url, + use_ssh=_use_ssh, + dep_ref=dep, + token=dep_ctx.token, + auth_scheme=_auth_scheme, + ) + _ctx_env = getattr(dep_ctx, "git_env", {}) or {} + probe_env = {**os.environ, **_dl.git_env, **_ctx_env} + # GIT_CONFIG_GLOBAL / GIT_CONFIG_NOSYSTEM carve-out: GitAuthEnvBuilder + # forces an empty global gitconfig for ALL hosts to prevent a user's + # ~/.gitconfig insteadOf rewrites or credential helpers from leaking + # tokens during a clone. But for preflight probes (a single ls-remote + # against the same host the dep targets), the redirection surface is + # nil and killing the user's global config kills Git Credential + # Manager along with it -- the helper most Windows ADO users rely on + # for Entra-cached credentials. For ADO specifically that matters + # because bearer acquisition can fail for reasons unrelated to login + # state (sandbox, proxy, microsoft/apm#1430-style PATH quirks), and + # GCM is the only remaining channel that can save us. Generic hosts + # have the same logic; widening the carve-out to ADO keeps the + # actual clone path isolated (it builds its own clean env) while + # giving the preflight probe the best chance to succeed. + if is_generic or is_azure_devops_hostname(host): + for _key in ("GIT_CONFIG_GLOBAL", "GIT_CONFIG_NOSYSTEM", "GIT_ASKPASS"): + probe_env.pop(_key, None) + + host_display = host if not org else f"{host}/{org}" + + def _run_ls_remote(url, env): + # auth-delegated: invoked via _primary_op/_bearer_op below, both + # routed through auth_resolver.execute_with_bearer_fallback. 
+ try: + return _sp.run( + ["git", "ls-remote", "--heads", "--exit-code", url], + capture_output=True, + text=True, + encoding="utf-8", + timeout=30, + env=env, + ) + except _sp.TimeoutExpired: + return None # network timeout sentinel; treated as non-auth + + def _primary_op(url=probe_url, env=probe_env): + return _run_ls_remote(url, env) + + def _bearer_op( + bearer, dep=dep, dep_ctx=dep_ctx, host=host, host_display=host_display, _dl=_dl + ): + # SECURITY: build a CLEAN env via _build_git_env(scheme="bearer") + # rather than {**probe_env, **build_ado_bearer_git_env(bearer)}. + # probe_env carries GIT_TOKEN= from dep_ctx.git_env; + # leaving it set during the bearer attempt would leak the + # rejected PAT into the child-process env table even though the + # GIT_CONFIG_VALUE_0 header carries the bearer. _build_git_env + # explicitly skips GIT_TOKEN for scheme="bearer". + bearer_env = auth_resolver._build_git_env(bearer, scheme="bearer", host_kind="ado") + bearer_url = _dl._build_repo_url( + dep.repo_url, + use_ssh=False, + dep_ref=dep, + token=None, + auth_scheme="bearer", + ) + _trace(f"Preflight: {host_display} -- retrying with az cli bearer") + return _run_ls_remote(bearer_url, bearer_env) + + def _is_auth_failure(outcome): + if outcome is None: + return False # timeout: not an auth failure + if outcome.returncode == 0: + return False + return is_ado_auth_failure_signal(outcome.stderr or "") + + ado_eligible = ( + dep.is_azure_devops() + and _auth_scheme == "basic" + and getattr(dep_ctx, "source", None) == "ADO_APM_PAT" + ) + + if ado_eligible: + fallback_result = auth_resolver.execute_with_bearer_fallback( + dep, + _primary_op, + _bearer_op, + _is_auth_failure, + ) + result = fallback_result.outcome + # bearer_also_failed is True only when the bearer leg actually + # ran AND its outcome still matched the auth-failure signature. 
+ # Early returns from execute_with_bearer_fallback (az + # unavailable, JWT acquisition failed) leave bearer_attempted + # False so the diagnostic does not falsely claim an attempt. + bearer_also_failed = ( + fallback_result.bearer_attempted + and result is not None + and result.returncode != 0 + and is_ado_auth_failure_signal(result.stderr or "") + ) + else: + result = _primary_op() + bearer_also_failed = False + + if result is None: + continue # timeout fallthrough -- handled by the real phase + + if result.returncode != 0: + stderr_text = result.stderr or "" + if _use_ssh: + # Generic SSH transport: check SSH-specific failure signals. + if not is_ssh_auth_failure_signal(stderr_text): + continue # non-auth SSH failure (network, unknown host key) -- defer + _trace(f"Preflight: {host_display} -- SSH auth rejected") + raise AuthenticationError( + f"SSH authentication failed for {host}", + diagnostic_context=( + f" SSH authentication was rejected by {host_display}.\n" + f" Ensure your SSH key is loaded in ssh-agent " + f"(ssh-add -l) and that the\n" + f" public key is authorised on the server.\n\n" + f" git output: {stderr_text.strip()}\n\n" + f" No files were modified.\n" + f" apm.yml, apm.lock.yaml, and apm_modules/ are unchanged." + ), + ) + else: + if not is_ado_auth_failure_signal(stderr_text): + continue # non-auth git failure (network, ref-not-found) -- defer + _trace(f"Preflight: {host_display} -- auth rejected") + _diag = auth_resolver.build_error_context( + host, + "install --update", + org=org, + dep_url=dep.repo_url, + bearer_also_failed=bearer_also_failed, + ) + raise AuthenticationError( + f"Authentication failed for {host}", + diagnostic_context=( + _diag + + "\n\n No files were modified." + + "\n apm.yml, apm.lock.yaml, and apm_modules/ are unchanged." 
+ ), + ) + else: + _trace(f"Preflight: {host_display} -- accepted") From f80b2079e47996f6589eedfffe45b8ce7c826943 Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 02:43:54 +0200 Subject: [PATCH 08/21] refactor(integration): split hook_integrator.py into transforms + merge modules Extract the pure stateless hook transforms into `hook_transforms.py` and the source-marker / dependency-scan / merge helpers into `hook_merge.py`, and decompose `_integrate_merged_hooks` into a thin orchestrator over focused helpers. A shared `_copy_hook_scripts` removes a duplicated script-copy loop. hook_integrator.py drops 1738 -> 769 lines and `_integrate_merged_hooks` falls under the Stage 2 complexity thresholds (mccabe/branches/statements), with no behaviour change. Module-level names tests patch/import (`HookIntegrator`, `_rich_warning` via wrapper, the re-exported transforms) are preserved. Refs #1078 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/integration/hook_integrator.py | 1295 +++----------------- src/apm_cli/integration/hook_merge.py | 711 +++++++++++ src/apm_cli/integration/hook_transforms.py | 514 ++++++++ 3 files changed, 1388 insertions(+), 1132 deletions(-) create mode 100644 src/apm_cli/integration/hook_merge.py create mode 100644 src/apm_cli/integration/hook_transforms.py diff --git a/src/apm_cli/integration/hook_integrator.py b/src/apm_cli/integration/hook_integrator.py index 48838199a..a785e59a0 100644 --- a/src/apm_cli/integration/hook_integrator.py +++ b/src/apm_cli/integration/hook_integrator.py @@ -45,22 +45,46 @@ import json import logging -import re import shutil from dataclasses import dataclass from pathlib import Path -import yaml - from apm_cli.integration.base_integrator import BaseIntegrator, IntegrationResult from apm_cli.utils.console import _rich_warning -from apm_cli.utils.path_security import ( - PathTraversalError, - ensure_path_within, - validate_path_segments, -) +from 
apm_cli.utils.path_security import ensure_path_within from apm_cli.utils.paths import portable_relpath +from .hook_merge import ( + _clean_apm_entries_from_json, + _dependency_hook_sources, + _get_hook_source_marker, + _get_package_name, + _is_root_local_package, + _load_merged_config_and_sidecar, + _merge_hook_file_entries, + _sync_claude_hooks_settings, + _warn_empty_hook_file, + _write_merged_config, +) +from .hook_transforms import ( + _APM_HOOKS_SIDECAR, + _HOOK_EVENT_MAP, + _emit_hook_event_diagnostics, + _filter_hook_files_for_target, + _rewrite_command_for_target, + _rewrite_hooks_data, +) + +# --------------------------------------------------------------------------- +# Re-exports: symbols imported by external callers / tests from this module. +# The ``X as X`` form marks them as intentional public re-exports (PEP 484). +# --------------------------------------------------------------------------- +from .hook_transforms import _HOOK_EVENT_EXPECTED_CASING as _HOOK_EVENT_EXPECTED_CASING +from .hook_transforms import _copilot_keys_to_gemini as _copilot_keys_to_gemini +from .hook_transforms import _detect_event_casing as _detect_event_casing +from .hook_transforms import _reinject_apm_source_from_sidecar as _reinject_apm_source_from_sidecar +from .hook_transforms import _to_gemini_hook_entries as _to_gemini_hook_entries + _log = logging.getLogger(__name__) @@ -94,143 +118,6 @@ class _MergeHookConfig: schema_strict: bool = False # True = strip _apm_source before writing to disk -# Per-target hook event name mapping. Packages are authored with -# Copilot (camelCase) or Claude (PascalCase) names; targets that use -# different conventions get their events renamed during merge. 
-_HOOK_EVENT_MAP: dict[str, dict[str, str]] = { - "claude": { - # Copilot camelCase -> Claude PascalCase - "preToolUse": "PreToolUse", - "postToolUse": "PostToolUse", - }, - "gemini": { - # Copilot / Claude -> Gemini - "PreToolUse": "BeforeTool", - "preToolUse": "BeforeTool", - "PostToolUse": "AfterTool", - "postToolUse": "AfterTool", - "Stop": "SessionEnd", - }, -} - -# Expected hook event naming convention per target. -# Used to warn when a package author deploys events whose casing does not -# match the target's convention AND no explicit rename mapping exists. -_HOOK_EVENT_EXPECTED_CASING: dict[str, str] = { - "copilot": "camelCase", - "vscode": "PascalCase", - "claude": "PascalCase", - "cursor": "PascalCase", - "codex": "PascalCase", - "gemini": "PascalCase", - "windsurf": "PascalCase", -} - - -def _detect_event_casing(name: str) -> str | None: - """Return 'camelCase', 'PascalCase', or None for an event name string.""" - if not name or not name[0].isalpha(): - return None - if name[0].islower() and any(c.isupper() for c in name[1:]): - return "camelCase" - if name[0].isupper(): - return "PascalCase" - return None - - -def _sanitize_event_name(name: str) -> str: - """Return event name with non-printable-ASCII characters stripped, for safe logging.""" - return "".join(c for c in name if 0x20 <= ord(c) <= 0x7E) - - -def _emit_hook_event_diagnostics( - event_names: list[str], - target_key: str, - event_map: dict[str, str], -) -> None: - """Log hook events per-target and warn on unmapped casing mismatches. - - This is informational only -- it never blocks deployment. 
- """ - if not event_names: - return - event_label = "hook event" if len(event_names) == 1 else "hook events" - _log.info( - "target %s: detected %s: %s", - target_key, - event_label, - ", ".join(sorted(_sanitize_event_name(n) for n in event_names)), - ) - expected_casing = _HOOK_EVENT_EXPECTED_CASING.get(target_key) - if not expected_casing: - return - # Warn for events whose detected casing does not match the target convention - # and that are not covered by an explicit rename in event_map. - mismatched = [ - n - for n in event_names - if _detect_event_casing(n) not in (None, expected_casing) and n not in event_map - ] - if mismatched: - example = "preToolUse" if expected_casing == "camelCase" else "PreToolUse" - safe_mismatched = sorted(_sanitize_event_name(n) for n in mismatched) - _rich_warning( - f"Hook events for target '{target_key}' may not be recognized: " - f"{', '.join(safe_mismatched)}. " - f"Target expects {expected_casing} (e.g. {example}). " - f"Rename events to match the {expected_casing} convention, then reinstall." - ) - _log.warning( - "target %s: hook event casing mismatch (no mapping): %s", - target_key, - ", ".join(safe_mismatched), - ) - - -def _to_gemini_hook_entries(entries: list) -> list: - """Transform hook entries into Gemini CLI format. - - Gemini requires ``{"hooks": [...]}`` nesting, uses ``command`` (not - ``bash``), and ``timeout`` in milliseconds (not ``timeoutSec`` in - seconds). Entries already in Claude/Gemini nested format are left - unchanged. 
- """ - result = [] - for entry in entries: - if not isinstance(entry, dict): - result.append(entry) - continue - # Already nested (Claude / Gemini format) -- just fix inner keys - if "hooks" in entry and isinstance(entry["hooks"], list): - for hook in entry["hooks"]: - _copilot_keys_to_gemini(hook) - result.append(entry) - continue - # Flat Copilot entry -- wrap in nested format - inner = dict(entry) - _copilot_keys_to_gemini(inner) - # Pull _apm_source to outer level (set later, but keep if present) - apm_source = inner.pop("_apm_source", None) - outer: dict = {"hooks": [inner]} - if apm_source: - outer["_apm_source"] = apm_source - result.append(outer) - return result - - -def _copilot_keys_to_gemini(hook: dict) -> None: - """Rename Copilot hook keys to Gemini equivalents in-place.""" - # bash / powershell -> command - if "command" not in hook: - for key in ("bash", "powershell", "windows"): - if key in hook: - hook["command"] = hook.pop(key) - break - # timeoutSec (seconds) -> timeout (milliseconds) - if "timeoutSec" in hook: - hook["timeout"] = hook.pop("timeoutSec") * 1000 - - _MERGE_HOOK_TARGETS: dict[str, _MergeHookConfig] = { "claude": _MergeHookConfig( config_filename="settings.json", @@ -260,100 +147,6 @@ def _copilot_keys_to_gemini(hook: dict) -> None: ), } -_APM_HOOKS_SIDECAR = "apm-hooks.json" - - -def _reinject_apm_source_from_sidecar(hooks: dict, sidecar_data: dict) -> None: - """Restore _apm_source markers from sidecar into in-memory hook entries. - - Schema-strict targets (e.g. Claude) do not persist ``_apm_source`` in - their settings file. Instead, ownership metadata is stored in a - sidecar file. This helper re-injects those markers so the rest of - the integration logic can work with them as normal. - - Each sidecar entry is consumed at most once to prevent falsely claiming - user-owned hooks that happen to have identical content to an APM hook. - - Args: - hooks: The ``"hooks"`` dict loaded from the target config file - (mutated in-place). 
- sidecar_data: The dict loaded from the sidecar file. - """ - for event_name, sidecar_entries in sidecar_data.items(): - if event_name not in hooks or not isinstance(sidecar_entries, list): - continue - # Build a dict keyed by normalised content -> list of sources. - # Each source is popped on first match so identical content shared - # between APM and the user is only claimed once. - import json - from collections import deque - - pool: dict[str, deque[str]] = {} - for sc_entry in sidecar_entries: - if isinstance(sc_entry, dict) and "_apm_source" in sc_entry: - cmp = {k: v for k, v in sorted(sc_entry.items()) if k != "_apm_source"} - cmp_key = json.dumps(cmp, sort_keys=True) - pool.setdefault(cmp_key, deque()).append(sc_entry["_apm_source"]) - - for disk_entry in hooks[event_name]: - if not isinstance(disk_entry, dict) or "_apm_source" in disk_entry: - continue - disk_cmp = {k: v for k, v in sorted(disk_entry.items()) if k != "_apm_source"} - disk_key = json.dumps(disk_cmp, sort_keys=True) - sources = pool.get(disk_key) - if sources: - disk_entry["_apm_source"] = sources.popleft() - if not sources: - del pool[disk_key] - - -# Mapping from hook-file stem suffix to the set of target keys that -# should receive the file. Files whose stem does not match any -# suffix are treated as universal and deployed to every target. -_HOOK_FILE_TARGET_SUFFIXES: dict[str, set[str]] = { - "copilot-hooks": {"copilot", "vscode"}, - "cursor-hooks": {"cursor"}, - "claude-hooks": {"claude"}, - "codex-hooks": {"codex"}, - "gemini-hooks": {"gemini"}, - "windsurf-hooks": {"windsurf"}, -} - - -def _filter_hook_files_for_target( - hook_files: list[Path], - target_key: str, -) -> list[Path]: - """Return only hook files intended for *target_key*. - - Routing is based on the file stem (case-insensitive): - - Stems ending with a known ``--hooks`` suffix are - restricted to matching targets. - - All other stems (e.g. 
``hooks``, ``my-custom-hooks``) are - universal and pass through for every target. - - Args: - hook_files: All discovered hook JSON files. - target_key: Lowercase target name (e.g. ``"claude"``, ``"cursor"``). - - Returns: - Filtered list preserving original order. - """ - result: list[Path] = [] - for hf in hook_files: - stem_lower = hf.stem.lower() - matched_suffix: str | None = None - for suffix, allowed_targets in _HOOK_FILE_TARGET_SUFFIXES.items(): - if stem_lower == suffix or stem_lower.endswith(f"-{suffix}"): - matched_suffix = suffix - if target_key in allowed_targets: - result.append(hf) - break - if matched_suffix is None: - # Universal file -- deploy to all targets - result.append(hf) - return result - class HookIntegrator(BaseIntegrator): """Handles integration of APM package hooks into target locations. @@ -365,29 +158,71 @@ class HookIntegrator(BaseIntegrator): - Cursor: Merged into .cursor/hooks.json hooks key + .cursor/hooks// """ - # Superset of all known script-path keys across supported hook specs. - # Every call site in _rewrite_hooks_data() iterates over this tuple, - # so a single addition here propagates everywhere. - # - # "command": Claude Code (primary), VS Code (default/cross-platform), Cursor - # "bash": GitHub Copilot Agent cloud/CLI - # "powershell": GitHub Copilot Agent cloud/CLI - # "windows": VS Code (OS-specific override) - # "linux": VS Code (OS-specific override) - # "osx": VS Code (OS-specific override) - # - # Refs: - # GH Copilot Agent: https://docs.github.com/en/copilot/concepts/agents/coding-agent/about-hooks - # VS Code: https://code.visualstudio.com/docs/copilot/customization/hooks - # Claude Code: https://code.claude.com/docs/en/hooks - HOOK_COMMAND_KEYS: tuple[str, ...] 
= ( - "command", - "bash", - "powershell", - "windows", - "linux", - "osx", - ) + # --------------------------------------------------------------------------- + # Static wrappers -- keep callable via HookIntegrator.xxx() for back-compat + # --------------------------------------------------------------------------- + + @staticmethod + def _is_root_local_package(package_info, project_root: Path | None) -> bool: + """Return True when package_info represents the project's own .apm content.""" + return _is_root_local_package(package_info, project_root) + + @staticmethod + def _dependency_hook_sources(project_root: Path) -> set[str]: + """Return source markers that correspond to installed dependency dirs.""" + return _dependency_hook_sources(project_root) + + @staticmethod + def _clean_apm_entries_from_json(json_path: Path, stats: dict[str, int]) -> None: + """Remove APM-tagged entries from a hooks JSON file.""" + _clean_apm_entries_from_json(json_path, stats) + + def _rewrite_command_for_target( + self, + command: str, + package_path: Path, + package_name: str, + target: str, + hook_file_dir: Path | None = None, + root_dir: str | None = None, + deploy_root: Path | None = None, + ) -> tuple[str, list[tuple[Path, str]]]: + """Rewrite a hook command to use installed script paths.""" + return _rewrite_command_for_target( + command, + package_path, + package_name, + target, + hook_file_dir=hook_file_dir, + root_dir=root_dir, + deploy_root=deploy_root, + _warn=_rich_warning, + ) + + def _rewrite_hooks_data( + self, + data: dict, + package_path: Path, + package_name: str, + target: str, + hook_file_dir: Path | None = None, + root_dir: str | None = None, + deploy_root: Path | None = None, + ) -> tuple[dict, list[tuple[Path, str]]]: + """Rewrite all command paths in a hooks JSON structure.""" + return _rewrite_hooks_data( + data, + package_path, + package_name, + target, + hook_file_dir=hook_file_dir, + root_dir=root_dir, + deploy_root=deploy_root, + ) + + # 
--------------------------------------------------------------------------- + # Hook file discovery + # --------------------------------------------------------------------------- def find_hook_files(self, package_path: Path) -> list[Path]: """Find all hook JSON files in a package. @@ -477,511 +312,45 @@ def _parse_hook_json(self, hook_file: Path) -> dict | None: except (json.JSONDecodeError, OSError): return None - def _rewrite_command_for_target( - self, - command: str, - package_path: Path, - package_name: str, - target: str, - hook_file_dir: Path | None = None, - root_dir: str | None = None, - deploy_root: Path | None = None, - ) -> tuple[str, list[tuple[Path, str]]]: - """Rewrite a hook command to use installed script paths. - - Handles: - - ${CLAUDE_PLUGIN_ROOT}/path references (resolved from package root) - - ./path relative references (resolved from hook file's parent directory) - - Windows backslash variants of both (.\\ and ${CLAUDE_PLUGIN_ROOT}\\) - - Args: - command: Original command string - package_path: Root path of the source package - package_name: Name used for the scripts subdirectory - target: "vscode" or "claude" - hook_file_dir: Directory containing the hook JSON file (for ./path resolution) - root_dir: Override root directory (e.g. ".copilot" for user scope) - deploy_root: Absolute root of the deployment directory. When provided, - rewritten script paths are resolved to absolute paths under this - root so the target (e.g. Claude Code) can execute them regardless - of the working directory. When *None*, rewritten paths stay - relative (backward-compatible behaviour). 
- - Returns: - Tuple of (rewritten_command, list of (source_file, relative_target_path)) - """ - scripts_to_copy = [] - new_command = command - - if target == "vscode": - base_root = root_dir or ".github" - scripts_base = f"{base_root}/hooks/scripts/{package_name}" - elif target == "cursor": - base_root = root_dir or ".cursor" - scripts_base = f"{base_root}/hooks/{package_name}" - elif target == "codex": - base_root = root_dir or ".codex" - scripts_base = f"{base_root}/hooks/{package_name}" - elif target == "windsurf": - base_root = root_dir or ".windsurf" - scripts_base = f"{base_root}/hooks/{package_name}" - else: - base_root = root_dir or ".claude" - scripts_base = f"{base_root}/hooks/{package_name}" - - # Handle plugin root variable references (always relative to package root) - # Match both forward-slash and backslash separators (Windows hook JSON - # may use backslashes: ${CLAUDE_PLUGIN_ROOT}\scripts\scan.ps1) - plugin_root_pattern = ( - r"\$\{(?:CLAUDE_PLUGIN_ROOT|CURSOR_PLUGIN_ROOT|PLUGIN_ROOT)\}([\\/][^\s\"']+)" - ) - for match in re.finditer(plugin_root_pattern, command): - full_var = match.group(0) - # Normalize backslashes to forward slashes before Path construction - # (on Unix, Path treats backslashes as literal filename chars) - rel_path = match.group(1).replace("\\", "/").lstrip("/") - - try: - source_file = ensure_path_within(package_path / rel_path, package_path) - except PathTraversalError: - continue - if source_file.exists() and source_file.is_file(): - target_rel = f"{scripts_base}/{rel_path}" - scripts_to_copy.append((source_file, target_rel)) - resolved_cmd = ( - str((deploy_root / target_rel).resolve()) - if deploy_root is not None - else target_rel - ) - new_command = new_command.replace(full_var, resolved_cmd) - else: - # File absent: always warn so a misconfigured hook is never - # silently deployed. 
For user-scope (deploy_root set) also - # rewrite the unexpanded variable to an absolute source path - # so the target surfaces a clear "file not found". For - # project-scope (deploy_root is None) leave the variable in - # place -- rewriting to an absolute path would re-introduce - # the #1394 portability regression in committed configs. - _rich_warning(f"Hook script not found: {source_file}") - if deploy_root is not None: - new_command = new_command.replace(full_var, str(source_file)) - - # Handle relative ./path and .\path references (safe to run after - # ${CLAUDE_PLUGIN_ROOT} substitution since replacements produce paths - # like ".github/..." not "./" or ".\") - # Match both forward-slash and backslash separators (Windows hook JSON - # may use backslashes: .\scripts\scan.ps1) - # Resolve from hook file's directory if available, else fall back to package root - resolve_base = hook_file_dir if hook_file_dir else package_path - rel_pattern = r"(\.[\\/][^\s\"']+)" - for match in re.finditer(rel_pattern, new_command): - rel_ref = match.group(1) - # Normalize to forward slashes for path resolution - rel_path = rel_ref[2:].replace("\\", "/") - - try: - source_file = ensure_path_within(resolve_base / rel_path, package_path) - except PathTraversalError: - continue - if source_file.exists() and source_file.is_file(): - target_rel = f"{scripts_base}/{rel_path}" - scripts_to_copy.append((source_file, target_rel)) - resolved_cmd = ( - str((deploy_root / target_rel).resolve()) - if deploy_root is not None - else target_rel - ) - new_command = new_command.replace(rel_ref, resolved_cmd) - else: - # File absent: always warn (see ${PLUGIN_ROOT} branch above - # for the project-scope vs user-scope rationale). 
- _rich_warning(f"Hook script not found: {source_file}") - if deploy_root is not None: - new_command = new_command.replace(rel_ref, str(source_file)) - - return new_command, scripts_to_copy + # --------------------------------------------------------------------------- + # Script copy helper + # --------------------------------------------------------------------------- - def _rewrite_hooks_data( + def _copy_hook_scripts( self, - data: dict, - package_path: Path, - package_name: str, - target: str, - hook_file_dir: Path | None = None, - root_dir: str | None = None, - deploy_root: Path | None = None, - ) -> tuple[dict, list[tuple[Path, str]]]: - """Rewrite all command paths in a hooks JSON structure. - - Creates a deep copy and rewrites command paths for the target platform. - - Args: - data: Parsed hook JSON data - package_path: Root path of the source package - package_name: Name for scripts subdirectory - target: "vscode" or "claude" - hook_file_dir: Directory containing the hook JSON file (for ./path resolution) - root_dir: Override root directory (e.g. ".copilot" for user scope) - deploy_root: Absolute root of the deployment directory. When provided, - all rewritten script paths are resolved to absolute paths so the - target can locate scripts regardless of the working directory. - When *None*, paths remain relative (backward-compatible behaviour). + scripts: list[tuple[Path, str]], + project_root: Path, + target_paths: list[Path], + managed_files: set | None, + force: bool, + diagnostics, + ) -> tuple[int, int]: + """Copy referenced hook scripts to their target locations. Returns: - Tuple of (rewritten_data_copy, list of (source_file, target_rel_path)) + (scripts_copied, scripts_adopted) counts. 
""" - import copy - - rewritten = copy.deepcopy(data) - all_scripts: list[tuple[Path, str]] = [] - - hooks = rewritten.get("hooks", {}) - for event_name, matchers in hooks.items(): - if not isinstance(matchers, list): + scripts_copied = 0 + scripts_adopted = 0 + for source_file, target_rel in scripts: + target_script = project_root / target_rel + ensure_path_within(target_script, project_root) + if self.try_adopt_identical(target_script, source_file, target_paths): + scripts_adopted += 1 continue - for matcher in matchers: - if not isinstance(matcher, dict): - continue - # Rewrite script paths in the matcher dict itself - # (GitHub Copilot flat format: bash/powershell/windows keys at this level) - for key in self.HOOK_COMMAND_KEYS: - if key in matcher: - new_cmd, scripts = self._rewrite_command_for_target( - matcher[key], - package_path, - package_name, - target, - hook_file_dir=hook_file_dir, - root_dir=root_dir, - deploy_root=deploy_root, - ) - if scripts: - _log.debug( - "Hook %s/%s: rewrote '%s' key (%d script(s))", - package_name, - event_name, - key, - len(scripts), - ) - matcher[key] = new_cmd - all_scripts.extend(scripts) - - # Rewrite script paths in nested hooks array - # (Claude format: matcher groups with inner hooks array) - for hook in matcher.get("hooks", []): - if not isinstance(hook, dict): - continue - for key in self.HOOK_COMMAND_KEYS: - if key in hook: - new_cmd, scripts = self._rewrite_command_for_target( - hook[key], - package_path, - package_name, - target, - hook_file_dir=hook_file_dir, - root_dir=root_dir, - deploy_root=deploy_root, - ) - if scripts: - _log.debug( - "Hook %s/%s: rewrote '%s' key (%d script(s))", - package_name, - event_name, - key, - len(scripts), - ) - hook[key] = new_cmd - all_scripts.extend(scripts) - - # De-duplicate by target path to avoid redundant copies when - # multiple keys (e.g. command + bash) reference the same script. 
- seen_targets: dict[str, Path] = {} - for source, target_rel in all_scripts: - if target_rel not in seen_targets: - seen_targets[target_rel] = source - unique_scripts = [(src, tgt) for tgt, src in seen_targets.items()] - - return rewritten, unique_scripts - - @staticmethod - def _is_root_local_package(package_info, project_root: Path | None) -> bool: - """Return True when *package_info* represents the project's own .apm content.""" - if project_root is None: - return False - try: - return Path(package_info.install_path).resolve() == Path(project_root).resolve() - except (OSError, RuntimeError): - return False - - @staticmethod - def _safe_source_name(value: str | None, fallback: str = "_local") -> str: - """Return a stable source marker that is also safe for hook script paths.""" - if not isinstance(value, str) or not value: - return fallback - safe = re.sub(r"[^A-Za-z0-9._-]+", "-", value.strip()) - # Collapse any run of 2+ dots to a single dot before stripping edges. - # Embedded sequences like "foo..bar" would otherwise pass through the - # earlier guard and reach downstream Path joins as a parent-dir hop. 
- safe = re.sub(r"\.{2,}", ".", safe).strip(".-_") - if not safe or safe in {".", ".."}: - return fallback - return safe - - @staticmethod - def _get_root_local_package_name(package_info, project_root: Path) -> str: - """Get the stable source marker for root .apm content.""" - apm_yml = Path(project_root) / "apm.yml" - if apm_yml.exists(): - try: - from apm_cli.utils.yaml_io import load_yaml - - data = load_yaml(apm_yml) - if isinstance(data, dict): - manifest_name = HookIntegrator._safe_source_name(data.get("name")) - if manifest_name != "_local": - return manifest_name - except (OSError, ValueError, yaml.YAMLError) as exc: - _log.debug( - "Hook integrator: apm.yml manifest unreadable for %s (%s: %s), " - "falling back to install_path basename", - project_root, - exc.__class__.__name__, - exc, - ) - - package = getattr(package_info, "package", None) - package_name = HookIntegrator._safe_source_name(getattr(package, "name", None)) - if package_name != "_local": - return package_name - return "_local" - - def _get_package_name(self, package_info, project_root: Path | None = None) -> str: - """Get a short package name for use in file/directory naming. - - Args: - package_info: PackageInfo object - project_root: When provided and the package is the project root, - reads ``apm.yml`` ``name`` for a stable source marker instead - of falling back to ``install_path.name`` (which drifts on - directory renames and worktrees). See #1329. 
- - Returns: - str: Package name used as hook source marker and script namespace - """ - if self._is_root_local_package(package_info, project_root): - return HookIntegrator._get_root_local_package_name(package_info, Path(project_root)) - return package_info.install_path.name - - @staticmethod - def _get_hook_source_marker( - package_info, - project_root: Path, - package_name: str, - ) -> str: - """Get the marker stored in merged hook JSON for ownership cleanup.""" - if HookIntegrator._is_root_local_package(package_info, project_root): - if package_name == "_local": - return "_local" - return f"_local/{package_name}" - return package_name - - @staticmethod - def _hook_entry_content_key(entry: dict) -> str: - """Build a stable comparison key excluding APM ownership metadata.""" - comparable = {k: v for k, v in sorted(entry.items()) if k != "_apm_source"} - return json.dumps(comparable, sort_keys=True, separators=(",", ":")) - - @staticmethod - def _dependency_hook_sources(project_root: Path) -> set[str]: - """Return source markers that correspond to installed dependency dirs.""" - apm_modules = project_root / "apm_modules" - if not apm_modules.is_dir(): - return set() - - lockfile_paths, lockfile_readable = HookIntegrator._lockfile_dependency_paths(project_root) - if lockfile_readable: - sources: set[str] = set() - for rel_path in lockfile_paths: - package_path = HookIntegrator._safe_dependency_path(apm_modules, rel_path) - if package_path is None: - continue - HookIntegrator._add_dependency_source(sources, package_path) - return sources - - return HookIntegrator._bounded_dependency_hook_sources(apm_modules) - - @staticmethod - def _lockfile_dependency_paths(project_root: Path) -> tuple[list[str], bool]: - """Return installed dependency paths from a readable lockfile, if present.""" - try: - from apm_cli.deps.lockfile import LEGACY_LOCKFILE_NAME, LockFile, get_lockfile_path - - lockfile_path = get_lockfile_path(project_root) - if not lockfile_path.exists(): - 
legacy_path = project_root / LEGACY_LOCKFILE_NAME - if legacy_path.exists(): - lockfile_path = legacy_path - if not lockfile_path.exists(): - return [], False - lockfile = LockFile.read(lockfile_path) - if lockfile is None: - return [], False - return lockfile.get_installed_paths(project_root / "apm_modules"), True - except (AttributeError, OSError, TypeError, ValueError, KeyError): - return [], False - - @staticmethod - def _safe_dependency_path(apm_modules: Path, rel_path: str) -> Path | None: - """Return a lockfile dependency path without escaping apm_modules.""" - try: - validate_path_segments( - rel_path, - context="lockfile dependency path", - reject_empty=True, - ) - package_path = apm_modules / Path(rel_path) - ensure_path_within(package_path, apm_modules) - if HookIntegrator._has_symlink_component(apm_modules, package_path): - return None - return package_path - except (OSError, PathTraversalError, RuntimeError, TypeError): - return None - - @staticmethod - def _has_symlink_component(apm_modules: Path, package_path: Path) -> bool: - """Return True when any component below apm_modules is a symlink.""" - try: - relative = package_path.relative_to(apm_modules) - current = apm_modules - for part in relative.parts: - current = current / part - if current.is_symlink(): - return True - return False - except (OSError, ValueError): - return True - - @staticmethod - def _is_dependency_package_dir(path: Path) -> bool: - """Return True when *path* looks like an installed package root.""" - try: - hooks = path / "hooks" - apm_hooks = path / ".apm" / "hooks" - apm_yml = path / "apm.yml" - skill_md = path / "SKILL.md" - return ( - (hooks.is_dir() and not hooks.is_symlink()) - or (apm_hooks.is_dir() and not apm_hooks.is_symlink()) - or (apm_yml.is_file() and not apm_yml.is_symlink()) - or (skill_md.is_file() and not skill_md.is_symlink()) - ) - except OSError: - return False - - @staticmethod - def _add_dependency_source(sources: set[str], package_path: Path) -> bool: - 
"""Add package_path.name to sources when package_path is a package root.""" - try: - if ( - not package_path.is_dir() - or package_path.is_symlink() - or not HookIntegrator._is_dependency_package_dir(package_path) + if self.check_collision( + target_script, target_rel, managed_files, force, diagnostics=diagnostics ): - return False - except OSError: - return False - sources.add(package_path.name) - return True - - @staticmethod - def _child_dependency_dirs(path: Path) -> list[Path]: - """Return direct non-hidden child dirs without following symlink roots.""" - try: - if path.is_symlink() or not path.is_dir(): - return [] - return sorted( - [ - child - for child in path.iterdir() - if not child.is_symlink() and child.is_dir() and not child.name.startswith(".") - ], - key=lambda child: child.name, - ) - except OSError: - return [] - - @staticmethod - def _collect_known_subdirectory_sources(sources: set[str], repo_root: Path) -> None: - """Collect dependency sources from known virtual subdirectory layouts.""" - for namespace in ("collections", "skills"): - for package_path in HookIntegrator._child_dependency_dirs(repo_root / namespace): - HookIntegrator._add_dependency_source(sources, package_path) - - apm_dir = repo_root / ".apm" - try: - if apm_dir.is_symlink() or not apm_dir.is_dir(): - return - except OSError: - return - for primitive in ("agents", "commands", "hooks", "instructions", "prompts", "skills"): - for package_path in HookIntegrator._child_dependency_dirs(apm_dir / primitive): - HookIntegrator._add_dependency_source(sources, package_path) - - @staticmethod - def _collect_remote_dependency_sources(sources: set[str], namespace: Path) -> None: - """Collect fallback sources from explicit remote install layouts.""" - if HookIntegrator._add_dependency_source(sources, namespace): - return - - for repo_or_project in HookIntegrator._child_dependency_dirs(namespace): - if HookIntegrator._add_dependency_source(sources, repo_or_project): - continue - - 
HookIntegrator._collect_known_subdirectory_sources(sources, repo_or_project) - - for ado_repo in HookIntegrator._child_dependency_dirs(repo_or_project): - if HookIntegrator._add_dependency_source(sources, ado_repo): - continue - HookIntegrator._collect_known_subdirectory_sources(sources, ado_repo) - - @staticmethod - def _collect_local_dependency_sources(sources: set[str], local_namespace: Path) -> None: - """Collect apm_modules/_local/ package roots only.""" - for local_package in HookIntegrator._child_dependency_dirs(local_namespace): - HookIntegrator._add_dependency_source(sources, local_package) - - @staticmethod - def _bounded_dependency_hook_sources(apm_modules: Path) -> set[str]: - """Fallback source scan limited to known apm_modules package layouts.""" - sources: set[str] = set() - - for package_root in HookIntegrator._child_dependency_dirs(apm_modules): - if package_root.name == "_local": - HookIntegrator._collect_local_dependency_sources(sources, package_root) continue + target_script.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(source_file, target_script) + scripts_copied += 1 + target_paths.append(target_script) + return scripts_copied, scripts_adopted - HookIntegrator._collect_remote_dependency_sources(sources, package_root) - return sources - - @staticmethod - def _should_remove_prior_merged_entry( - entry, - *, - source_marker: str, - fresh_content_keys: set[str], - heal_stale_root_source: bool, - dependency_sources: set[str], - remove_current_source: bool, - ) -> bool: - """Return True when an existing merged-hook entry should be replaced.""" - if not isinstance(entry, dict): - return False - source = entry.get("_apm_source") - if remove_current_source and source == source_marker: - return True - if not heal_stale_root_source or not source or source in dependency_sources: - return False - return HookIntegrator._hook_entry_content_key(entry) in fresh_content_keys + # 
--------------------------------------------------------------------------- + # Copilot (individual-file) integration + # --------------------------------------------------------------------------- def integrate_package_hooks( self, @@ -1022,7 +391,7 @@ def integrate_package_hooks( hooks_dir = project_root / root_dir / "hooks" hooks_dir.mkdir(parents=True, exist_ok=True) - package_name = self._get_package_name(package_info, project_root) + package_name = _get_package_name(package_info, project_root) hooks_integrated = 0 scripts_copied = 0 scripts_adopted = 0 @@ -1034,7 +403,7 @@ def integrate_package_hooks( continue # Rewrite script paths for VSCode target - rewritten, scripts = self._rewrite_hooks_data( + rewritten, scripts = _rewrite_hooks_data( data, package_info.install_path, package_name, @@ -1064,21 +433,11 @@ def integrate_package_hooks( hooks_integrated += 1 target_paths.append(target_path) - # Copy referenced scripts (individual file tracking) - for source_file, target_rel in scripts: - target_script = project_root / target_rel - ensure_path_within(target_script, project_root) - if self.try_adopt_identical(target_script, source_file, target_paths): - scripts_adopted += 1 - continue - if self.check_collision( - target_script, target_rel, managed_files, force, diagnostics=diagnostics - ): - continue - target_script.parent.mkdir(parents=True, exist_ok=True) - shutil.copy2(source_file, target_script) - scripts_copied += 1 - target_paths.append(target_script) + sc, sa = self._copy_hook_scripts( + scripts, project_root, target_paths, managed_files, force, diagnostics + ) + scripts_copied += sc + scripts_adopted += sa return HookIntegrationResult( files_integrated=hooks_integrated, @@ -1125,17 +484,7 @@ def _integrate_merged_hooks( if config.require_dir and not target_dir.exists(): return _empty - # Absolutize hook commands only for user-scope deploys. 
Claude - # Code (and the Codex/Cursor/Gemini equivalents) reads - # ``~/.claude/settings.json`` without a fixed cwd and does not - # expand ``${CLAUDE_PLUGIN_ROOT}`` in that file (see #1310 / #1354), - # so user-scope deploys must write absolute paths. Project-scope - # ``/.claude/settings.json`` is typically checked in and runs - # with cwd at the repo root, where repo-relative paths resolve - # correctly -- baking absolute machine paths into checked-in config - # breaks portability across clones, contributors, and CI (#1394). - # ``user_scope`` is threaded from the caller's ``InstallScope`` so - # the gate is explicit rather than inferred from deploy-root shape. + # Absolutize hook commands only for user-scope deploys. _deploy_root_for_rewrite = project_root if user_scope else None hook_files = self.find_hook_files(package_info.install_path) @@ -1143,65 +492,27 @@ def _integrate_merged_hooks( if not hook_files: return _empty - package_name = self._get_package_name(package_info, project_root) - source_marker = self._get_hook_source_marker(package_info, project_root, package_name) - heal_stale_root_source = self._is_root_local_package(package_info, project_root) - dependency_sources = ( - self._dependency_hook_sources(project_root) if heal_stale_root_source else set() - ) + package_name = _get_package_name(package_info, project_root) + source_marker = _get_hook_source_marker(package_info, project_root, package_name) + heal_stale = _is_root_local_package(package_info, project_root) + dep_sources = _dependency_hook_sources(project_root) if heal_stale else set() + hooks_integrated = 0 scripts_copied = 0 scripts_adopted = 0 target_paths: list[Path] = [] - # Events whose prior-owned entries have already been cleared on - # this install run. Packages can contribute to the same event - # from multiple hook files -- we must only strip once so earlier - # files' fresh entries aren't wiped by later iterations. 
cleared_events: set = set() - # Read existing JSON config json_path = target_dir / config.config_filename - json_config: dict = {} - if json_path.exists(): - try: - with open(json_path, encoding="utf-8") as f: - json_config = json.load(f) - except (json.JSONDecodeError, OSError): - json_config = {} - - # Load sidecar ownership metadata (schema-strict targets) sidecar_path = target_dir / _APM_HOOKS_SIDECAR - sidecar_data: dict = {} - if config.schema_strict and sidecar_path.exists(): - try: - with open(sidecar_path, encoding="utf-8") as f: - _raw = json.load(f) - if isinstance(_raw, dict): - sidecar_data = _raw - else: - _log.warning( - "Sidecar file %s contains non-dict JSON; treating as empty.", - sidecar_path, - ) - sidecar_data = {} - except (json.JSONDecodeError, OSError) as exc: - _log.warning("Failed to read sidecar %s: %s; treating as empty.", sidecar_path, exc) - sidecar_data = {} - - # Re-inject _apm_source from sidecar into matching in-memory entries - if sidecar_data and "hooks" in json_config: - _reinject_apm_source_from_sidecar(json_config["hooks"], sidecar_data) - - if "hooks" not in json_config: - json_config["hooks"] = {} + json_config = _load_merged_config_and_sidecar(json_path, sidecar_path, config.schema_strict) for hook_file in hook_files: data = self._parse_hook_json(hook_file) if data is None: continue - # Rewrite script paths for the target - rewritten, scripts = self._rewrite_hooks_data( + rewritten, scripts = _rewrite_hooks_data( data, package_info.install_path, package_name, @@ -1211,208 +522,34 @@ def _integrate_merged_hooks( deploy_root=_deploy_root_for_rewrite, ) - # Merge hooks into config (additive) hooks = rewritten.get("hooks", {}) event_map = _HOOK_EVENT_MAP.get(config.target_key, {}) - _emit_hook_event_diagnostics(list(hooks.keys()), config.target_key, event_map) - # Build reverse map: normalised name -> set of source aliases - reverse_map: dict[str, set[str]] = {} - for source_name, norm_name in event_map.items(): - 
reverse_map.setdefault(norm_name, set()).add(source_name) - - entries_appended_for_file = False - for raw_event_name, entries in hooks.items(): - if not isinstance(entries, list) or not entries: - continue - event_name = event_map.get(raw_event_name, raw_event_name) - if event_name not in json_config["hooks"]: - json_config["hooks"][event_name] = [] - - # Transform flat Copilot entries to Gemini nested format - if config.target_key == "gemini": - entries = _to_gemini_hook_entries(entries) - - # Mark each entry with APM source for sync/cleanup - for entry in entries: - if isinstance(entry, dict): - entry["_apm_source"] = source_marker - fresh_content_keys = { - self._hook_entry_content_key(entry) - for entry in entries - if isinstance(entry, dict) - } + appended = _merge_hook_file_entries( + json_config, + hooks, + config.target_key, + event_map, + source_marker, + cleared_events, + heal_stale_root_source=heal_stale, + dependency_sources=dep_sources, + ) - # Idempotent upsert: drop any prior entries owned by this - # package before appending fresh ones. Without this, every - # `apm install` re-run duplicates the package's hooks - # because `.extend()` is unconditional. See microsoft/apm#708. - # Only strip once per event per install run -- a package - # with multiple hook files targeting the same event - # contributes each file's entries in turn, and stripping - # on every iteration would erase earlier files' work. 
- remove_current_source = event_name not in cleared_events - if remove_current_source or heal_stale_root_source: - # Clear from the normalised event - prior_entries = json_config["hooks"][event_name] - kept_entries = [ - e - for e in prior_entries - if not self._should_remove_prior_merged_entry( - e, - source_marker=source_marker, - fresh_content_keys=fresh_content_keys, - heal_stale_root_source=heal_stale_root_source, - dependency_sources=dependency_sources, - remove_current_source=remove_current_source, - ) - ] - if heal_stale_root_source: - kept_ids = {id(e) for e in kept_entries} - healed = sum( - 1 - for e in prior_entries - if isinstance(e, dict) - and e.get("_apm_source") - and e.get("_apm_source") != source_marker - and e.get("_apm_source") not in dependency_sources - and id(e) not in kept_ids - ) - if healed: - _log.debug( - "Hook integrator: healed %d stale same-content " - "merged hook entries for source %s in event %s", - healed, - source_marker, - event_name, - ) - json_config["hooks"][event_name] = kept_entries - # Also clear from any alias events that map to - # this normalised name (handles migration from - # corrupted installs with mixed-case event keys). - for alias in reverse_map.get(event_name, set()): - if alias != event_name and alias in json_config["hooks"]: - json_config["hooks"][alias] = [ - e - for e in json_config["hooks"][alias] - if not self._should_remove_prior_merged_entry( - e, - source_marker=source_marker, - fresh_content_keys=fresh_content_keys, - heal_stale_root_source=heal_stale_root_source, - dependency_sources=dependency_sources, - remove_current_source=remove_current_source, - ) - ] - # Remove the alias key entirely if now empty - if not json_config["hooks"][alias]: - del json_config["hooks"][alias] - cleared_events.add(event_name) - json_config["hooks"][event_name].extend(entries) - - # Deduplicate same-package entries by content. 
- # Safety net for edge cases where multiple source files - # produce semantically identical entries. - import json as _json - - seen_keys: set[str] = set() - deduped: list = [] - for entry in json_config["hooks"][event_name]: - if not isinstance(entry, dict): - deduped.append(entry) - continue - cmp = {k: v for k, v in sorted(entry.items()) if k != "_apm_source"} - source = entry.get("_apm_source") - dedup_key = _json.dumps({"s": source, "c": cmp}, sort_keys=True) - if dedup_key not in seen_keys: - seen_keys.add(dedup_key) - deduped.append(entry) - json_config["hooks"][event_name] = deduped - entries_appended_for_file = True - - if entries_appended_for_file: + if appended: hooks_integrated += 1 else: - # Diagnostic for the fail-closed silent-skip path introduced - # by the integrated-counter fix (microsoft/apm#1499): a hook - # file that parsed cleanly but contributed zero entries (all - # events empty / non-list) used to bump the counter and lie - # to the user. Now we skip it -- emit a user-visible warning - # (the original #1499 symptom was that authors saw nothing - # bad AND nothing good, so a structured-logger-only message - # would re-introduce the silent-failure UX) and a parallel - # _log.warning for operators consuming structured logs. - rel_hook = hook_file.name - _rich_warning( - f"Hook file {rel_hook} contributed no entries to " - f"{config.target_key} settings; skipped." 
- ) - _log.warning( - "Hook file %s contributed no entries to %s settings " - "(all events empty or non-list); skipping.", - hook_file, - config.target_key, - ) + _warn_empty_hook_file(hook_file, config.target_key) - # Copy referenced scripts - for source_file, target_rel in scripts: - target_script = project_root / target_rel - ensure_path_within(target_script, project_root) - if self.try_adopt_identical(target_script, source_file, target_paths): - scripts_adopted += 1 - continue - if self.check_collision( - target_script, - target_rel, - managed_files, - force, - diagnostics=diagnostics, - ): - continue - target_script.parent.mkdir(parents=True, exist_ok=True) - shutil.copy2(source_file, target_script) - scripts_copied += 1 - target_paths.append(target_script) - - # Write JSON config back - # Don't track the config file in target_paths -- it's a shared - # file cleaned via _apm_source markers, not file-level deletion - json_path.parent.mkdir(parents=True, exist_ok=True) + copied, adopted = self._copy_hook_scripts( + scripts, project_root, target_paths, managed_files, force, diagnostics + ) + scripts_copied += copied + scripts_adopted += adopted - if config.schema_strict: - # Build sidecar from entries that have _apm_source - sidecar_out: dict = {} - for event_name, entries_list in json_config.get("hooks", {}).items(): - if not isinstance(entries_list, list): - continue - owned = [e for e in entries_list if isinstance(e, dict) and "_apm_source" in e] - if owned: - sidecar_out[event_name] = [dict(e) for e in owned] - - # Strip _apm_source from entries before writing to disk - for entries_list in json_config.get("hooks", {}).values(): - if isinstance(entries_list, list): - for entry in entries_list: - if isinstance(entry, dict): - entry.pop("_apm_source", None) - - # Write sidecar - sidecar_path = target_dir / _APM_HOOKS_SIDECAR - if sidecar_out: - try: - with open(sidecar_path, "w", encoding="utf-8") as f: - json.dump(sidecar_out, f, indent=2) - f.write("\n") - 
except OSError as exc: - _log.warning("Failed to write sidecar %s: %s", sidecar_path, exc) - elif sidecar_path.exists(): - sidecar_path.unlink() - - # Write the (now schema-clean) config - with open(json_path, "w", encoding="utf-8") as f: - json.dump(json_config, f, indent=2) - f.write("\n") + json_path.parent.mkdir(parents=True, exist_ok=True) + _write_merged_config(json_path, sidecar_path, json_config, config.schema_strict) return HookIntegrationResult( files_integrated=hooks_integrated, @@ -1625,114 +762,8 @@ def sync_integration( if config is not None: json_path = project_root / t.root_dir / config.config_filename if t.name == "claude": - # Claude uses settings.json with special structure - if json_path.exists(): - try: - with open(json_path, encoding="utf-8") as f: - settings = json.load(f) - - # Load sidecar to restore _apm_source markers - sidecar_path = json_path.parent / _APM_HOOKS_SIDECAR - sidecar_data: dict = {} - if sidecar_path.exists(): - try: - with open(sidecar_path, encoding="utf-8") as sf: - _raw = json.load(sf) - if isinstance(_raw, dict): - sidecar_data = _raw - else: - _log.warning( - "Sidecar file %s contains non-dict JSON; treating as empty.", - sidecar_path, - ) - sidecar_data = {} - except (json.JSONDecodeError, OSError) as exc: - _log.warning( - "Failed to read sidecar %s: %s; treating as empty.", - sidecar_path, - exc, - ) - sidecar_data = {} - - # Re-inject _apm_source from sidecar - if sidecar_data and "hooks" in settings: - _reinject_apm_source_from_sidecar(settings["hooks"], sidecar_data) - - if "hooks" in settings: - modified = False - for event_name in list(settings["hooks"].keys()): - matchers = settings["hooks"][event_name] - if isinstance(matchers, list): - filtered = [ - m - for m in matchers - if not (isinstance(m, dict) and "_apm_source" in m) - ] - if len(filtered) != len(matchers): - modified = True - settings["hooks"][event_name] = filtered - if not filtered: - del settings["hooks"][event_name] - - if not 
settings["hooks"]: - del settings["hooks"] - - if modified: - with open(json_path, "w", encoding="utf-8") as f: - json.dump(settings, f, indent=2) - f.write("\n") - stats["files_removed"] += 1 - - # Clean up sidecar - if sidecar_path.exists(): - sidecar_path.unlink() - - # Remove stale sidecar when no hooks section remains - if sidecar_path.exists() and "hooks" not in settings: - sidecar_path.unlink() - except (json.JSONDecodeError, OSError): - stats["errors"] += 1 + _sync_claude_hooks_settings(json_path, stats) else: - self._clean_apm_entries_from_json(json_path, stats) + _clean_apm_entries_from_json(json_path, stats) return stats - - @staticmethod - def _clean_apm_entries_from_json(json_path: Path, stats: dict[str, int]) -> None: - """Remove APM-tagged entries from a hooks JSON file. - - Filters out entries with ``_apm_source`` markers and cleans up - empty event arrays and the ``hooks`` key itself. - """ - if not json_path.exists(): - return - try: - with open(json_path, encoding="utf-8") as f: - data = json.load(f) - - if "hooks" not in data: - return - - modified = False - for event_name in list(data["hooks"].keys()): - entries = data["hooks"][event_name] - if isinstance(entries, list): - filtered = [ - e for e in entries if not (isinstance(e, dict) and "_apm_source" in e) - ] - if len(filtered) != len(entries): - modified = True - data["hooks"][event_name] = filtered - if not filtered: - del data["hooks"][event_name] - - if not data["hooks"]: - del data["hooks"] - - if modified: - with open(json_path, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2) - f.write("\n") - stats["files_removed"] += 1 - except (json.JSONDecodeError, OSError): - stats["errors"] += 1 diff --git a/src/apm_cli/integration/hook_merge.py b/src/apm_cli/integration/hook_merge.py new file mode 100644 index 000000000..1cdd34926 --- /dev/null +++ b/src/apm_cli/integration/hook_merge.py @@ -0,0 +1,711 @@ +"""Source-marker utilities, dependency scanning, and merge-operation helpers. 
+ +This module holds pure helper functions extracted from ``HookIntegrator`` +to reduce the complexity of the merge path. Nothing here depends on +the ``HookIntegrator`` class itself. +""" + +import json +import logging +import re +from pathlib import Path + +import yaml + +from apm_cli.utils.console import _rich_warning +from apm_cli.utils.path_security import ( + PathTraversalError, + ensure_path_within, + validate_path_segments, +) + +from .hook_transforms import ( + _APM_HOOKS_SIDECAR, + _reinject_apm_source_from_sidecar, + _to_gemini_hook_entries, +) + +_log = logging.getLogger("apm_cli.integration.hook_integrator") + +# --------------------------------------------------------------------------- +# Package source-marker utilities +# --------------------------------------------------------------------------- + + +def _is_root_local_package(package_info, project_root: Path | None) -> bool: + """Return True when *package_info* represents the project's own .apm content.""" + if project_root is None: + return False + try: + return Path(package_info.install_path).resolve() == Path(project_root).resolve() + except (OSError, RuntimeError): + return False + + +def _safe_source_name(value: str | None, fallback: str = "_local") -> str: + """Return a stable source marker that is also safe for hook script paths.""" + if not isinstance(value, str) or not value: + return fallback + safe = re.sub(r"[^A-Za-z0-9._-]+", "-", value.strip()) + # Collapse any run of 2+ dots to a single dot before stripping edges. + # Embedded sequences like "foo..bar" would otherwise pass through the + # earlier guard and reach downstream Path joins as a parent-dir hop. 
+ safe = re.sub(r"\.{2,}", ".", safe).strip(".-_") + if not safe or safe in {".", ".."}: + return fallback + return safe + + +def _get_root_local_package_name(package_info, project_root: Path) -> str: + """Get the stable source marker for root .apm content.""" + apm_yml = Path(project_root) / "apm.yml" + if apm_yml.exists(): + try: + from apm_cli.utils.yaml_io import load_yaml + + data = load_yaml(apm_yml) + if isinstance(data, dict): + manifest_name = _safe_source_name(data.get("name")) + if manifest_name != "_local": + return manifest_name + except (OSError, ValueError, yaml.YAMLError) as exc: + _log.debug( + "Hook integrator: apm.yml manifest unreadable for %s (%s: %s), " + "falling back to install_path basename", + project_root, + exc.__class__.__name__, + exc, + ) + + package = getattr(package_info, "package", None) + package_name = _safe_source_name(getattr(package, "name", None)) + if package_name != "_local": + return package_name + return "_local" + + +def _get_package_name(package_info, project_root: Path | None = None) -> str: + """Get a short package name for use in file/directory naming. + + Args: + package_info: PackageInfo object + project_root: When provided and the package is the project root, + reads ``apm.yml`` ``name`` for a stable source marker instead + of falling back to ``install_path.name`` (which drifts on + directory renames and worktrees). See #1329. 
+ + Returns: + str: Package name used as hook source marker and script namespace + """ + if _is_root_local_package(package_info, project_root): + return _get_root_local_package_name(package_info, Path(project_root)) + return package_info.install_path.name + + +def _get_hook_source_marker( + package_info, + project_root: Path, + package_name: str, +) -> str: + """Get the marker stored in merged hook JSON for ownership cleanup.""" + if _is_root_local_package(package_info, project_root): + if package_name == "_local": + return "_local" + return f"_local/{package_name}" + return package_name + + +def _hook_entry_content_key(entry: dict) -> str: + """Build a stable comparison key excluding APM ownership metadata.""" + comparable = {k: v for k, v in sorted(entry.items()) if k != "_apm_source"} + return json.dumps(comparable, sort_keys=True, separators=(",", ":")) + + +# --------------------------------------------------------------------------- +# Dependency source scanning +# --------------------------------------------------------------------------- + + +def _dependency_hook_sources(project_root: Path) -> set[str]: + """Return source markers that correspond to installed dependency dirs.""" + apm_modules = project_root / "apm_modules" + if not apm_modules.is_dir(): + return set() + + lockfile_paths, lockfile_readable = _lockfile_dependency_paths(project_root) + if lockfile_readable: + sources: set[str] = set() + for rel_path in lockfile_paths: + package_path = _safe_dependency_path(apm_modules, rel_path) + if package_path is None: + continue + _add_dependency_source(sources, package_path) + return sources + + return _bounded_dependency_hook_sources(apm_modules) + + +def _lockfile_dependency_paths(project_root: Path) -> tuple[list[str], bool]: + """Return installed dependency paths from a readable lockfile, if present.""" + try: + from apm_cli.deps.lockfile import LEGACY_LOCKFILE_NAME, LockFile, get_lockfile_path + + lockfile_path = get_lockfile_path(project_root) + if 
not lockfile_path.exists(): + legacy_path = project_root / LEGACY_LOCKFILE_NAME + if legacy_path.exists(): + lockfile_path = legacy_path + if not lockfile_path.exists(): + return [], False + lockfile = LockFile.read(lockfile_path) + if lockfile is None: + return [], False + return lockfile.get_installed_paths(project_root / "apm_modules"), True + except (AttributeError, OSError, TypeError, ValueError, KeyError): + return [], False + + +def _safe_dependency_path(apm_modules: Path, rel_path: str) -> Path | None: + """Return a lockfile dependency path without escaping apm_modules.""" + try: + validate_path_segments( + rel_path, + context="lockfile dependency path", + reject_empty=True, + ) + package_path = apm_modules / Path(rel_path) + ensure_path_within(package_path, apm_modules) + if _has_symlink_component(apm_modules, package_path): + return None + return package_path + except (OSError, PathTraversalError, RuntimeError, TypeError): + return None + + +def _has_symlink_component(apm_modules: Path, package_path: Path) -> bool: + """Return True when any component below apm_modules is a symlink.""" + try: + relative = package_path.relative_to(apm_modules) + current = apm_modules + for part in relative.parts: + current = current / part + if current.is_symlink(): + return True + return False + except (OSError, ValueError): + return True + + +def _is_dependency_package_dir(path: Path) -> bool: + """Return True when *path* looks like an installed package root.""" + try: + hooks = path / "hooks" + apm_hooks = path / ".apm" / "hooks" + apm_yml = path / "apm.yml" + skill_md = path / "SKILL.md" + return ( + (hooks.is_dir() and not hooks.is_symlink()) + or (apm_hooks.is_dir() and not apm_hooks.is_symlink()) + or (apm_yml.is_file() and not apm_yml.is_symlink()) + or (skill_md.is_file() and not skill_md.is_symlink()) + ) + except OSError: + return False + + +def _add_dependency_source(sources: set[str], package_path: Path) -> bool: + """Add package_path.name to sources when 
package_path is a package root.""" + try: + if ( + not package_path.is_dir() + or package_path.is_symlink() + or not _is_dependency_package_dir(package_path) + ): + return False + except OSError: + return False + sources.add(package_path.name) + return True + + +def _child_dependency_dirs(path: Path) -> list[Path]: + """Return direct non-hidden child dirs without following symlink roots.""" + try: + if path.is_symlink() or not path.is_dir(): + return [] + return sorted( + [ + child + for child in path.iterdir() + if not child.is_symlink() and child.is_dir() and not child.name.startswith(".") + ], + key=lambda child: child.name, + ) + except OSError: + return [] + + +def _collect_known_subdirectory_sources(sources: set[str], repo_root: Path) -> None: + """Collect dependency sources from known virtual subdirectory layouts.""" + for namespace in ("collections", "skills"): + for package_path in _child_dependency_dirs(repo_root / namespace): + _add_dependency_source(sources, package_path) + + apm_dir = repo_root / ".apm" + try: + if apm_dir.is_symlink() or not apm_dir.is_dir(): + return + except OSError: + return + for primitive in ("agents", "commands", "hooks", "instructions", "prompts", "skills"): + for package_path in _child_dependency_dirs(apm_dir / primitive): + _add_dependency_source(sources, package_path) + + +def _collect_remote_dependency_sources(sources: set[str], namespace: Path) -> None: + """Collect fallback sources from explicit remote install layouts.""" + if _add_dependency_source(sources, namespace): + return + + for repo_or_project in _child_dependency_dirs(namespace): + if _add_dependency_source(sources, repo_or_project): + continue + + _collect_known_subdirectory_sources(sources, repo_or_project) + + for ado_repo in _child_dependency_dirs(repo_or_project): + if _add_dependency_source(sources, ado_repo): + continue + _collect_known_subdirectory_sources(sources, ado_repo) + + +def _collect_local_dependency_sources(sources: set[str], local_namespace: 
Path) -> None: + """Collect apm_modules/_local/ package roots only.""" + for local_package in _child_dependency_dirs(local_namespace): + _add_dependency_source(sources, local_package) + + +def _bounded_dependency_hook_sources(apm_modules: Path) -> set[str]: + """Fallback source scan limited to known apm_modules package layouts.""" + sources: set[str] = set() + + for package_root in _child_dependency_dirs(apm_modules): + if package_root.name == "_local": + _collect_local_dependency_sources(sources, package_root) + continue + + _collect_remote_dependency_sources(sources, package_root) + return sources + + +# --------------------------------------------------------------------------- +# Merge-entry filtering +# --------------------------------------------------------------------------- + + +def _should_remove_prior_merged_entry( + entry, + *, + source_marker: str, + fresh_content_keys: set[str], + heal_stale_root_source: bool, + dependency_sources: set[str], + remove_current_source: bool, +) -> bool: + """Return True when an existing merged-hook entry should be replaced.""" + if not isinstance(entry, dict): + return False + source = entry.get("_apm_source") + if remove_current_source and source == source_marker: + return True + if not heal_stale_root_source or not source or source in dependency_sources: + return False + return _hook_entry_content_key(entry) in fresh_content_keys + + +# --------------------------------------------------------------------------- +# Merge operation helpers +# --------------------------------------------------------------------------- + + +def _load_merged_config_and_sidecar( + json_path: Path, + sidecar_path: Path, + schema_strict: bool, +) -> dict: + """Load target config JSON and optionally re-inject sidecar _apm_source markers. + + Returns a json_config dict that always has a ``"hooks"`` key. 
+ """ + json_config: dict = {} + if json_path.exists(): + try: + with open(json_path, encoding="utf-8") as f: + json_config = json.load(f) + except (json.JSONDecodeError, OSError): + json_config = {} + + if schema_strict and sidecar_path.exists(): + sidecar_data: dict = {} + try: + with open(sidecar_path, encoding="utf-8") as f: + _raw = json.load(f) + if isinstance(_raw, dict): + sidecar_data = _raw + else: + _log.warning( + "Sidecar file %s contains non-dict JSON; treating as empty.", + sidecar_path, + ) + except (json.JSONDecodeError, OSError) as exc: + _log.warning("Failed to read sidecar %s: %s; treating as empty.", sidecar_path, exc) + + if sidecar_data and "hooks" in json_config: + _reinject_apm_source_from_sidecar(json_config["hooks"], sidecar_data) + + if "hooks" not in json_config: + json_config["hooks"] = {} + + return json_config + + +def _deduplicate_event_entries(entries: list) -> list: + """Deduplicate hook entries by (source, content) key. + + Safety net for edge cases where multiple source files produce + semantically identical entries. + """ + seen_keys: set[str] = set() + deduped: list = [] + for entry in entries: + if not isinstance(entry, dict): + deduped.append(entry) + continue + cmp = {k: v for k, v in sorted(entry.items()) if k != "_apm_source"} + source = entry.get("_apm_source") + dedup_key = json.dumps({"s": source, "c": cmp}, sort_keys=True) + if dedup_key not in seen_keys: + seen_keys.add(dedup_key) + deduped.append(entry) + return deduped + + +def _merge_hook_file_entries( + json_config: dict, + hooks: dict, + target_key: str, + event_map: dict, + source_marker: str, + cleared_events: set, + *, + heal_stale_root_source: bool, + dependency_sources: set, +) -> bool: + """Merge hook entries from one hook file into json_config["hooks"]. + + Applies Gemini transforms, stamps _apm_source, performs idempotent + upsert (stripping prior same-package entries), and deduplicates. + + Returns True when at least one event received new entries. 
+ """ + # Build reverse map: normalised name -> set of source aliases. + # Used to clean up alias event keys left by mixed-case past installs. + reverse_map: dict[str, set[str]] = {} + for source_name, norm_name in event_map.items(): + reverse_map.setdefault(norm_name, set()).add(source_name) + + entries_appended = False + for raw_event_name, entries in hooks.items(): + if not isinstance(entries, list) or not entries: + continue + event_name = event_map.get(raw_event_name, raw_event_name) + if event_name not in json_config["hooks"]: + json_config["hooks"][event_name] = [] + + # Transform flat Copilot entries to Gemini nested format + if target_key == "gemini": + entries = _to_gemini_hook_entries(entries) + + # Mark each entry with APM source for sync/cleanup + for entry in entries: + if isinstance(entry, dict): + entry["_apm_source"] = source_marker + fresh_content_keys = { + _hook_entry_content_key(entry) for entry in entries if isinstance(entry, dict) + } + + # Idempotent upsert: drop prior entries owned by this package + # before appending fresh ones. Only strip once per event per + # install run -- a package with multiple hook files targeting the + # same event contributes each file's entries in turn, and stripping + # on every iteration would erase earlier files' work. 
+ remove_current_source = event_name not in cleared_events + if remove_current_source or heal_stale_root_source: + _upsert_event_entries( + json_config, + event_name, + source_marker, + fresh_content_keys, + heal_stale_root_source, + dependency_sources, + reverse_map, + remove_current_source, + ) + cleared_events.add(event_name) + + json_config["hooks"][event_name].extend(entries) + json_config["hooks"][event_name] = _deduplicate_event_entries( + json_config["hooks"][event_name] + ) + entries_appended = True + + return entries_appended + + +def _upsert_event_entries( + json_config: dict, + event_name: str, + source_marker: str, + fresh_content_keys: set[str], + heal_stale_root_source: bool, + dependency_sources: set, + reverse_map: dict, + remove_current_source: bool, +) -> None: + """Remove stale same-package entries before fresh ones are appended. + + Mutates json_config["hooks"] in-place. + """ + prior_entries = json_config["hooks"][event_name] + kept_entries = [ + e + for e in prior_entries + if not _should_remove_prior_merged_entry( + e, + source_marker=source_marker, + fresh_content_keys=fresh_content_keys, + heal_stale_root_source=heal_stale_root_source, + dependency_sources=dependency_sources, + remove_current_source=remove_current_source, + ) + ] + if heal_stale_root_source: + kept_ids = {id(e) for e in kept_entries} + healed = sum( + 1 + for e in prior_entries + if isinstance(e, dict) + and e.get("_apm_source") + and e.get("_apm_source") != source_marker + and e.get("_apm_source") not in dependency_sources + and id(e) not in kept_ids + ) + if healed: + _log.debug( + "Hook integrator: healed %d stale same-content " + "merged hook entries for source %s in event %s", + healed, + source_marker, + event_name, + ) + json_config["hooks"][event_name] = kept_entries + + # Also clear from any alias events that map to this normalised name + # (handles migration from corrupted installs with mixed-case event keys). 
+ for alias in reverse_map.get(event_name, set()): + if alias != event_name and alias in json_config["hooks"]: + json_config["hooks"][alias] = [ + e + for e in json_config["hooks"][alias] + if not _should_remove_prior_merged_entry( + e, + source_marker=source_marker, + fresh_content_keys=fresh_content_keys, + heal_stale_root_source=heal_stale_root_source, + dependency_sources=dependency_sources, + remove_current_source=remove_current_source, + ) + ] + # Remove the alias key entirely if now empty + if not json_config["hooks"][alias]: + del json_config["hooks"][alias] + + +def _warn_empty_hook_file(hook_file: Path, target_key: str) -> None: + """Emit user-visible and structured-log warnings for an empty hook file. + + A hook file that parsed cleanly but contributed zero entries (all + events empty / non-list) used to bump the counter and lie to the user. + Now we skip it -- emit a warning so the author notices. + """ + rel_hook = hook_file.name + _rich_warning(f"Hook file {rel_hook} contributed no entries to {target_key} settings; skipped.") + _log.warning( + "Hook file %s contributed no entries to %s settings " + "(all events empty or non-list); skipping.", + hook_file, + target_key, + ) + + +def _write_merged_config( + json_path: Path, + sidecar_path: Path, + json_config: dict, + schema_strict: bool, +) -> None: + """Write the merged config (and optionally the sidecar) to disk. + + For schema-strict targets (e.g. Claude): + - Builds a sidecar from entries that carry ``_apm_source``. + - Strips ``_apm_source`` from the config before writing so the + target's schema validator does not reject the file. + - Writes the sidecar alongside the config, or removes it when empty. 
+ """ + if schema_strict: + # Build sidecar from entries that have _apm_source + sidecar_out: dict = {} + for ev_name, entries_list in json_config.get("hooks", {}).items(): + if not isinstance(entries_list, list): + continue + owned = [e for e in entries_list if isinstance(e, dict) and "_apm_source" in e] + if owned: + sidecar_out[ev_name] = [dict(e) for e in owned] + + # Strip _apm_source from entries before writing to disk + for entries_list in json_config.get("hooks", {}).values(): + if isinstance(entries_list, list): + for entry in entries_list: + if isinstance(entry, dict): + entry.pop("_apm_source", None) + + # Write or remove sidecar + if sidecar_out: + try: + with open(sidecar_path, "w", encoding="utf-8") as f: + json.dump(sidecar_out, f, indent=2) + f.write("\n") + except OSError as exc: + _log.warning("Failed to write sidecar %s: %s", sidecar_path, exc) + elif sidecar_path.exists(): + sidecar_path.unlink() + + # Write the (now schema-clean) config + with open(json_path, "w", encoding="utf-8") as f: + json.dump(json_config, f, indent=2) + f.write("\n") + + +# --------------------------------------------------------------------------- +# Sync helpers +# --------------------------------------------------------------------------- + + +def _sync_claude_hooks_settings(json_path: Path, stats: dict[str, int]) -> None: + """Remove APM-managed hook entries from a Claude settings.json file. + + Loads the sidecar to restore _apm_source markers, filters out all + entries tagged with ``_apm_source``, writes the cleaned config back, + and removes the sidecar when no hooks remain. 
+ """ + if not json_path.exists(): + return + try: + with open(json_path, encoding="utf-8") as f: + settings = json.load(f) + + # Load sidecar to restore _apm_source markers + sidecar_path = json_path.parent / _APM_HOOKS_SIDECAR + sidecar_data: dict = {} + if sidecar_path.exists(): + try: + with open(sidecar_path, encoding="utf-8") as sf: + _raw = json.load(sf) + if isinstance(_raw, dict): + sidecar_data = _raw + else: + _log.warning( + "Sidecar file %s contains non-dict JSON; treating as empty.", + sidecar_path, + ) + except (json.JSONDecodeError, OSError) as exc: + _log.warning( + "Failed to read sidecar %s: %s; treating as empty.", + sidecar_path, + exc, + ) + + # Re-inject _apm_source from sidecar + if sidecar_data and "hooks" in settings: + _reinject_apm_source_from_sidecar(settings["hooks"], sidecar_data) + + if "hooks" in settings: + modified = False + for event_name in list(settings["hooks"].keys()): + matchers = settings["hooks"][event_name] + if isinstance(matchers, list): + filtered = [ + m for m in matchers if not (isinstance(m, dict) and "_apm_source" in m) + ] + if len(filtered) != len(matchers): + modified = True + settings["hooks"][event_name] = filtered + if not filtered: + del settings["hooks"][event_name] + + if not settings["hooks"]: + del settings["hooks"] + + if modified: + with open(json_path, "w", encoding="utf-8") as f: + json.dump(settings, f, indent=2) + f.write("\n") + stats["files_removed"] += 1 + + # Clean up sidecar + if sidecar_path.exists(): + sidecar_path.unlink() + + # Remove stale sidecar when no hooks section remains + if sidecar_path.exists() and "hooks" not in settings: + sidecar_path.unlink() + except (json.JSONDecodeError, OSError): + stats["errors"] += 1 + + +def _clean_apm_entries_from_json(json_path: Path, stats: dict[str, int]) -> None: + """Remove APM-tagged entries from a hooks JSON file. + + Filters out entries with ``_apm_source`` markers and cleans up + empty event arrays and the ``hooks`` key itself. 
+ """ + if not json_path.exists(): + return + try: + with open(json_path, encoding="utf-8") as f: + data = json.load(f) + + if "hooks" not in data: + return + + modified = False + for event_name in list(data["hooks"].keys()): + entries = data["hooks"][event_name] + if isinstance(entries, list): + filtered = [e for e in entries if not (isinstance(e, dict) and "_apm_source" in e)] + if len(filtered) != len(entries): + modified = True + data["hooks"][event_name] = filtered + if not filtered: + del data["hooks"][event_name] + + if not data["hooks"]: + del data["hooks"] + + if modified: + with open(json_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + f.write("\n") + stats["files_removed"] += 1 + except (json.JSONDecodeError, OSError): + stats["errors"] += 1 diff --git a/src/apm_cli/integration/hook_transforms.py b/src/apm_cli/integration/hook_transforms.py new file mode 100644 index 000000000..a1e37f1cd --- /dev/null +++ b/src/apm_cli/integration/hook_transforms.py @@ -0,0 +1,514 @@ +"""Pure transform and rewrite helpers for APM hook integration. + +This module holds stateless transform utilities that are shared by the +hook integration, merge, and sync layers. Nothing here depends on +``HookIntegrator`` or ``_MergeHookConfig``. +""" + +import copy +import json +import logging +import re +from collections import deque +from pathlib import Path + +from apm_cli.utils.console import _rich_warning +from apm_cli.utils.path_security import PathTraversalError, ensure_path_within + +_log = logging.getLogger("apm_cli.integration.hook_integrator") + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +# Superset of all known script-path keys across supported hook specs. 
+# "command": Claude Code (primary), VS Code (default/cross-platform), Cursor +# "bash": GitHub Copilot Agent cloud/CLI +# "powershell": GitHub Copilot Agent cloud/CLI +# "windows": VS Code (OS-specific override) +# "linux": VS Code (OS-specific override) +# "osx": VS Code (OS-specific override) +_HOOK_COMMAND_KEYS: tuple[str, ...] = ( + "command", + "bash", + "powershell", + "windows", + "linux", + "osx", +) + +# Per-target hook event name mapping. Packages are authored with +# Copilot (camelCase) or Claude (PascalCase) names; targets that use +# different conventions get their events renamed during merge. +_HOOK_EVENT_MAP: dict[str, dict[str, str]] = { + "claude": { + # Copilot camelCase -> Claude PascalCase + "preToolUse": "PreToolUse", + "postToolUse": "PostToolUse", + }, + "gemini": { + # Copilot / Claude -> Gemini + "PreToolUse": "BeforeTool", + "preToolUse": "BeforeTool", + "PostToolUse": "AfterTool", + "postToolUse": "AfterTool", + "Stop": "SessionEnd", + }, +} + +# Expected hook event naming convention per target. +# Used to warn when a package author deploys events whose casing does not +# match the target's convention AND no explicit rename mapping exists. +_HOOK_EVENT_EXPECTED_CASING: dict[str, str] = { + "copilot": "camelCase", + "vscode": "PascalCase", + "claude": "PascalCase", + "cursor": "PascalCase", + "codex": "PascalCase", + "gemini": "PascalCase", + "windsurf": "PascalCase", +} + +# Mapping from hook-file stem suffix to the set of target keys that +# should receive the file. Files whose stem does not match any +# suffix are treated as universal and deployed to every target. +_HOOK_FILE_TARGET_SUFFIXES: dict[str, set[str]] = { + "copilot-hooks": {"copilot", "vscode"}, + "cursor-hooks": {"cursor"}, + "claude-hooks": {"claude"}, + "codex-hooks": {"codex"}, + "gemini-hooks": {"gemini"}, + "windsurf-hooks": {"windsurf"}, +} + +# Filename used to persist _apm_source markers for schema-strict targets. 
+_APM_HOOKS_SIDECAR = "apm-hooks.json" + +# --------------------------------------------------------------------------- +# Event name utilities +# --------------------------------------------------------------------------- + + +def _detect_event_casing(name: str) -> str | None: + """Return 'camelCase', 'PascalCase', or None for an event name string.""" + if not name or not name[0].isalpha(): + return None + if name[0].islower() and any(c.isupper() for c in name[1:]): + return "camelCase" + if name[0].isupper(): + return "PascalCase" + return None + + +def _sanitize_event_name(name: str) -> str: + """Return event name with non-printable-ASCII characters stripped, for safe logging.""" + return "".join(c for c in name if 0x20 <= ord(c) <= 0x7E) + + +def _emit_hook_event_diagnostics( + event_names: list[str], + target_key: str, + event_map: dict[str, str], +) -> None: + """Log hook events per-target and warn on unmapped casing mismatches. + + This is informational only -- it never blocks deployment. + """ + if not event_names: + return + event_label = "hook event" if len(event_names) == 1 else "hook events" + _log.info( + "target %s: detected %s: %s", + target_key, + event_label, + ", ".join(sorted(_sanitize_event_name(n) for n in event_names)), + ) + expected_casing = _HOOK_EVENT_EXPECTED_CASING.get(target_key) + if not expected_casing: + return + # Warn for events whose detected casing does not match the target convention + # and that are not covered by an explicit rename in event_map. + mismatched = [ + n + for n in event_names + if _detect_event_casing(n) not in (None, expected_casing) and n not in event_map + ] + if mismatched: + example = "preToolUse" if expected_casing == "camelCase" else "PreToolUse" + safe_mismatched = sorted(_sanitize_event_name(n) for n in mismatched) + _rich_warning( + f"Hook events for target '{target_key}' may not be recognized: " + f"{', '.join(safe_mismatched)}. " + f"Target expects {expected_casing} (e.g. {example}). 
" + f"Rename events to match the {expected_casing} convention, then reinstall." + ) + _log.warning( + "target %s: hook event casing mismatch (no mapping): %s", + target_key, + ", ".join(safe_mismatched), + ) + + +# --------------------------------------------------------------------------- +# Gemini transforms +# --------------------------------------------------------------------------- + + +def _to_gemini_hook_entries(entries: list) -> list: + """Transform hook entries into Gemini CLI format. + + Gemini requires ``{"hooks": [...]}`` nesting, uses ``command`` (not + ``bash``), and ``timeout`` in milliseconds (not ``timeoutSec`` in + seconds). Entries already in Claude/Gemini nested format are left + unchanged. + """ + result = [] + for entry in entries: + if not isinstance(entry, dict): + result.append(entry) + continue + # Already nested (Claude / Gemini format) -- just fix inner keys + if "hooks" in entry and isinstance(entry["hooks"], list): + for hook in entry["hooks"]: + _copilot_keys_to_gemini(hook) + result.append(entry) + continue + # Flat Copilot entry -- wrap in nested format + inner = dict(entry) + _copilot_keys_to_gemini(inner) + # Pull _apm_source to outer level (set later, but keep if present) + apm_source = inner.pop("_apm_source", None) + outer: dict = {"hooks": [inner]} + if apm_source: + outer["_apm_source"] = apm_source + result.append(outer) + return result + + +def _copilot_keys_to_gemini(hook: dict) -> None: + """Rename Copilot hook keys to Gemini equivalents in-place.""" + # bash / powershell -> command + if "command" not in hook: + for key in ("bash", "powershell", "windows"): + if key in hook: + hook["command"] = hook.pop(key) + break + # timeoutSec (seconds) -> timeout (milliseconds) + if "timeoutSec" in hook: + hook["timeout"] = hook.pop("timeoutSec") * 1000 + + +# --------------------------------------------------------------------------- +# Sidecar re-injection +# 
--------------------------------------------------------------------------- + + +def _reinject_apm_source_from_sidecar(hooks: dict, sidecar_data: dict) -> None: + """Restore _apm_source markers from sidecar into in-memory hook entries. + + Schema-strict targets (e.g. Claude) do not persist ``_apm_source`` in + their settings file. Instead, ownership metadata is stored in a + sidecar file. This helper re-injects those markers so the rest of + the integration logic can work with them as normal. + + Each sidecar entry is consumed at most once to prevent falsely claiming + user-owned hooks that happen to have identical content to an APM hook. + + Args: + hooks: The ``"hooks"`` dict loaded from the target config file + (mutated in-place). + sidecar_data: The dict loaded from the sidecar file. + """ + for event_name, sidecar_entries in sidecar_data.items(): + if event_name not in hooks or not isinstance(sidecar_entries, list): + continue + # Build a dict keyed by normalised content -> list of sources. + # Each source is popped on first match so identical content shared + # between APM and the user is only claimed once. 
+ pool: dict[str, deque[str]] = {} + for sc_entry in sidecar_entries: + if isinstance(sc_entry, dict) and "_apm_source" in sc_entry: + cmp = {k: v for k, v in sorted(sc_entry.items()) if k != "_apm_source"} + cmp_key = json.dumps(cmp, sort_keys=True) + pool.setdefault(cmp_key, deque()).append(sc_entry["_apm_source"]) + + for disk_entry in hooks[event_name]: + if not isinstance(disk_entry, dict) or "_apm_source" in disk_entry: + continue + disk_cmp = {k: v for k, v in sorted(disk_entry.items()) if k != "_apm_source"} + disk_key = json.dumps(disk_cmp, sort_keys=True) + sources = pool.get(disk_key) + if sources: + disk_entry["_apm_source"] = sources.popleft() + if not sources: + del pool[disk_key] + + +# --------------------------------------------------------------------------- +# Hook file routing +# --------------------------------------------------------------------------- + + +def _filter_hook_files_for_target( + hook_files: list[Path], + target_key: str, +) -> list[Path]: + """Return only hook files intended for *target_key*. + + Routing is based on the file stem (case-insensitive): + - Stems ending with a known ``-<target>-hooks`` suffix are + restricted to matching targets. + - All other stems (e.g. ``hooks``, ``my-custom-hooks``) are + universal and pass through for every target. + + Args: + hook_files: All discovered hook JSON files. + target_key: Lowercase target name (e.g. ``"claude"``, ``"cursor"``). + + Returns: + Filtered list preserving original order. 
+ """ + result: list[Path] = [] + for hf in hook_files: + stem_lower = hf.stem.lower() + matched_suffix: str | None = None + for suffix, allowed_targets in _HOOK_FILE_TARGET_SUFFIXES.items(): + if stem_lower == suffix or stem_lower.endswith(f"-{suffix}"): + matched_suffix = suffix + if target_key in allowed_targets: + result.append(hf) + break + if matched_suffix is None: + # Universal file -- deploy to all targets + result.append(hf) + return result + + +# --------------------------------------------------------------------------- +# Command path rewriting +# --------------------------------------------------------------------------- + + +def _rewrite_command_for_target( + command: str, + package_path: Path, + package_name: str, + target: str, + hook_file_dir: Path | None = None, + root_dir: str | None = None, + deploy_root: Path | None = None, + _warn=None, +) -> tuple[str, list[tuple[Path, str]]]: + """Rewrite a hook command to use installed script paths. + + Handles: + - ${CLAUDE_PLUGIN_ROOT}/path references (resolved from package root) + - ./path relative references (resolved from hook file's parent directory) + - Windows backslash variants of both (.\\ and ${CLAUDE_PLUGIN_ROOT}\\) + + Args: + command: Original command string + package_path: Root path of the source package + package_name: Name used for the scripts subdirectory + target: "vscode" or "claude" + hook_file_dir: Directory containing the hook JSON file (for ./path resolution) + root_dir: Override root directory (e.g. ".copilot" for user scope) + deploy_root: Absolute root of the deployment directory. When provided, + rewritten script paths are resolved to absolute paths under this + root so the target (e.g. Claude Code) can execute them regardless + of the working directory. When *None*, rewritten paths stay + relative (backward-compatible behaviour). + _warn: Warning callable (defaults to _rich_warning); override in tests + or when the caller needs to intercept warnings. 
+ + Returns: + Tuple of (rewritten_command, list of (source_file, relative_target_path)) + """ + if _warn is None: + _warn = _rich_warning + scripts_to_copy = [] + new_command = command + + if target == "vscode": + base_root = root_dir or ".github" + scripts_base = f"{base_root}/hooks/scripts/{package_name}" + elif target == "cursor": + base_root = root_dir or ".cursor" + scripts_base = f"{base_root}/hooks/{package_name}" + elif target == "codex": + base_root = root_dir or ".codex" + scripts_base = f"{base_root}/hooks/{package_name}" + elif target == "windsurf": + base_root = root_dir or ".windsurf" + scripts_base = f"{base_root}/hooks/{package_name}" + else: + base_root = root_dir or ".claude" + scripts_base = f"{base_root}/hooks/{package_name}" + + # Handle plugin root variable references (always relative to package root) + # Match both forward-slash and backslash separators (Windows hook JSON + # may use backslashes: ${CLAUDE_PLUGIN_ROOT}\scripts\scan.ps1) + plugin_root_pattern = ( + r"\$\{(?:CLAUDE_PLUGIN_ROOT|CURSOR_PLUGIN_ROOT|PLUGIN_ROOT)\}([\\/][^\s\"']+)" + ) + for match in re.finditer(plugin_root_pattern, command): + full_var = match.group(0) + # Normalize backslashes to forward slashes before Path construction + # (on Unix, Path treats backslashes as literal filename chars) + rel_path = match.group(1).replace("\\", "/").lstrip("/") + + try: + source_file = ensure_path_within(package_path / rel_path, package_path) + except PathTraversalError: + continue + if source_file.exists() and source_file.is_file(): + target_rel = f"{scripts_base}/{rel_path}" + scripts_to_copy.append((source_file, target_rel)) + resolved_cmd = ( + str((deploy_root / target_rel).resolve()) if deploy_root is not None else target_rel + ) + new_command = new_command.replace(full_var, resolved_cmd) + else: + # File absent: always warn so a misconfigured hook is never + # silently deployed. 
For user-scope (deploy_root set) also + # rewrite the unexpanded variable to an absolute source path + # so the target surfaces a clear "file not found". For + # project-scope (deploy_root is None) leave the variable in + # place -- rewriting to an absolute path would re-introduce + # the #1394 portability regression in committed configs. + _warn(f"Hook script not found: {source_file}") + if deploy_root is not None: + new_command = new_command.replace(full_var, str(source_file)) + + # Handle relative ./path and .\path references (safe to run after + # ${CLAUDE_PLUGIN_ROOT} substitution since replacements produce paths + # like ".github/..." not "./" or ".\") + # Match both forward-slash and backslash separators (Windows hook JSON + # may use backslashes: .\scripts\scan.ps1) + # Resolve from hook file's directory if available, else fall back to package root + resolve_base = hook_file_dir if hook_file_dir else package_path + rel_pattern = r"(\.[\\/][^\s\"']+)" + for match in re.finditer(rel_pattern, new_command): + rel_ref = match.group(1) + # Normalize to forward slashes for path resolution + rel_path = rel_ref[2:].replace("\\", "/") + + try: + source_file = ensure_path_within(resolve_base / rel_path, package_path) + except PathTraversalError: + continue + if source_file.exists() and source_file.is_file(): + target_rel = f"{scripts_base}/{rel_path}" + scripts_to_copy.append((source_file, target_rel)) + resolved_cmd = ( + str((deploy_root / target_rel).resolve()) if deploy_root is not None else target_rel + ) + new_command = new_command.replace(rel_ref, resolved_cmd) + else: + # File absent: always warn (see ${PLUGIN_ROOT} branch above + # for the project-scope vs user-scope rationale). 
+ _warn(f"Hook script not found: {source_file}") + if deploy_root is not None: + new_command = new_command.replace(rel_ref, str(source_file)) + + return new_command, scripts_to_copy + + +def _rewrite_hooks_data( + data: dict, + package_path: Path, + package_name: str, + target: str, + hook_file_dir: Path | None = None, + root_dir: str | None = None, + deploy_root: Path | None = None, +) -> tuple[dict, list[tuple[Path, str]]]: + """Rewrite all command paths in a hooks JSON structure. + + Creates a deep copy and rewrites command paths for the target platform. + + Args: + data: Parsed hook JSON data + package_path: Root path of the source package + package_name: Name for scripts subdirectory + target: "vscode" or "claude" + hook_file_dir: Directory containing the hook JSON file (for ./path resolution) + root_dir: Override root directory (e.g. ".copilot" for user scope) + deploy_root: Absolute root of the deployment directory. When provided, + all rewritten script paths are resolved to absolute paths so the + target can locate scripts regardless of the working directory. + When *None*, paths remain relative (backward-compatible behaviour). 
+ + Returns: + Tuple of (rewritten_data_copy, list of (source_file, target_rel_path)) + """ + rewritten = copy.deepcopy(data) + all_scripts: list[tuple[Path, str]] = [] + + hooks = rewritten.get("hooks", {}) + for event_name, matchers in hooks.items(): + if not isinstance(matchers, list): + continue + for matcher in matchers: + if not isinstance(matcher, dict): + continue + # Rewrite script paths in the matcher dict itself + # (GitHub Copilot flat format: bash/powershell/windows keys at this level) + for key in _HOOK_COMMAND_KEYS: + if key in matcher: + new_cmd, scripts = _rewrite_command_for_target( + matcher[key], + package_path, + package_name, + target, + hook_file_dir=hook_file_dir, + root_dir=root_dir, + deploy_root=deploy_root, + ) + if scripts: + _log.debug( + "Hook %s/%s: rewrote '%s' key (%d script(s))", + package_name, + event_name, + key, + len(scripts), + ) + matcher[key] = new_cmd + all_scripts.extend(scripts) + + # Rewrite script paths in nested hooks array + # (Claude format: matcher groups with inner hooks array) + for hook in matcher.get("hooks", []): + if not isinstance(hook, dict): + continue + for key in _HOOK_COMMAND_KEYS: + if key in hook: + new_cmd, scripts = _rewrite_command_for_target( + hook[key], + package_path, + package_name, + target, + hook_file_dir=hook_file_dir, + root_dir=root_dir, + deploy_root=deploy_root, + ) + if scripts: + _log.debug( + "Hook %s/%s: rewrote '%s' key (%d script(s))", + package_name, + event_name, + key, + len(scripts), + ) + hook[key] = new_cmd + all_scripts.extend(scripts) + + # De-duplicate by target path to avoid redundant copies when + # multiple keys (e.g. command + bash) reference the same script. 
+ seen_targets: dict[str, Path] = {} + for source, target_rel in all_scripts: + if target_rel not in seen_targets: + seen_targets[target_rel] = source + unique_scripts = [(src, tgt) for tgt, src in seen_targets.items()] + + return rewritten, unique_scripts From 6c6786fb4f0a38634712636e80c61490c1bb58e1 Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 03:19:40 +0200 Subject: [PATCH 09/21] refactor(integration): split skill_integrator.py under Stage 2 thresholds Decompose skill_integrator.py (1827 lines) into cohesive siblings: skill_naming, skill_orchestrate, skill_plugin, skill_sync, skill_deploy. skill_integrator now layers over skill_deploy; get_effective_type and should_install_skill remain module-level to preserve monkeypatch targets. validate_skill_name return count reduced under the Stage 2 max-returns gate. Part of #1078 (Strangler Stage 2). Behaviour-preserving; full unit, acceptance and integration suites green. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/integration/skill_deploy.py | 768 ++++++++ src/apm_cli/integration/skill_integrator.py | 1842 +++--------------- src/apm_cli/integration/skill_naming.py | 49 + src/apm_cli/integration/skill_orchestrate.py | 135 ++ src/apm_cli/integration/skill_plugin.py | 186 ++ src/apm_cli/integration/skill_sync.py | 236 +++ 6 files changed, 1600 insertions(+), 1616 deletions(-) create mode 100644 src/apm_cli/integration/skill_deploy.py create mode 100644 src/apm_cli/integration/skill_naming.py create mode 100644 src/apm_cli/integration/skill_orchestrate.py create mode 100644 src/apm_cli/integration/skill_plugin.py create mode 100644 src/apm_cli/integration/skill_sync.py diff --git a/src/apm_cli/integration/skill_deploy.py b/src/apm_cli/integration/skill_deploy.py new file mode 100644 index 000000000..2f196e51a --- /dev/null +++ b/src/apm_cli/integration/skill_deploy.py @@ -0,0 +1,768 @@ +"""Deployment helpers for skill integration.""" + +from __future__ import 
annotations + +import filecmp +import logging +import shutil +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .skill_naming import normalize_skill_name +from .skill_orchestrate import PackageSkillContext as PackageSkillContext +from .skill_orchestrate import integrate_package_skill as integrate_package_skill +from .skill_plugin import _bin_deploy_denied as _bin_deploy_denied +from .skill_plugin import _copy_plugin_file as _copy_plugin_file +from .skill_plugin import _deploy_bin_files as _deploy_bin_files +from .skill_plugin import _deploy_plugin_bin as _deploy_plugin_bin +from .skill_plugin import _deploy_plugin_manifest as _deploy_plugin_manifest +from .skill_sync import _clean_orphaned_skills as _clean_orphaned_skills +from .skill_sync import _get_lockfile_owned_agent_skills as _get_lockfile_owned_agent_skills +from .skill_sync import _sync_skills_legacy as _sync_skills_legacy +from .skill_sync import _sync_skills_managed_files as _sync_skills_managed_files +from .skill_sync import sync_integration as sync_integration + +_log = logging.getLogger("apm_cli.integration.skill_integrator") + + +@dataclass +class CopySkillContext: + """Dependencies used by the standalone skill-copy helper.""" + + should_install_fn: Callable[[Any], bool] + validate_name_fn: Callable[[str], tuple[bool, str]] + normalize_name_fn: Callable[[str], str] + rewriter_factory: Callable[[], Any] + + +@dataclass +class NativeSkillTargetContext: + """Shared state for deploying one native skill to one target.""" + + package_path: Path + skill_name: str + project_root: Path + current_key: str | None + lockfile_native_owners: dict[str, str] + owned_by: dict[str, str] + sub_skills_dir: Path + seen_skill_dirs: set[Path] + diagnostics: Any + managed_files: set[str] | None + force: bool + logger: Any + link_rewriter: Any + + +@dataclass +class SkillBundleTargetContext: + """Shared state for deploying one skill bundle to one 
target.""" + + skills_dir: Path + parent_name: str + owned_by: dict[str, str] + diagnostics: Any + managed_files: set[str] | None + force: bool + project_root: Path + logger: Any + name_filter: set[str] | None + link_rewriter: Any + seen_skill_dirs: set[Path] + + +def _validate_skill_name(name: str) -> tuple[bool, str]: + """Resolve the public validator lazily to avoid circular imports.""" + from .skill_integrator import validate_skill_name + + return validate_skill_name(name) + + +def find_instruction_files(package_path: Path) -> list[Path]: + """Find all instruction files in a package.""" + instruction_files: list[Path] = [] + apm_instructions = package_path / ".apm" / "instructions" + if apm_instructions.exists(): + instruction_files.extend(apm_instructions.glob("*.instructions.md")) + return instruction_files + + +def find_agent_files(package_path: Path) -> list[Path]: + """Find all agent files in a package.""" + agent_files: list[Path] = [] + apm_agents = package_path / ".apm" / "agents" + if apm_agents.exists(): + agent_files.extend(apm_agents.glob("*.agent.md")) + return agent_files + + +def find_prompt_files(package_path: Path) -> list[Path]: + """Find all prompt files in a package.""" + prompt_files: list[Path] = [] + if package_path.exists(): + prompt_files.extend(package_path.glob("*.prompt.md")) + apm_prompts = package_path / ".apm" / "prompts" + if apm_prompts.exists(): + prompt_files.extend(apm_prompts.glob("*.prompt.md")) + return prompt_files + + +def find_context_files(package_path: Path) -> list[Path]: + """Find all context and memory files in a package.""" + context_files: list[Path] = [] + apm_context = package_path / ".apm" / "context" + if apm_context.exists(): + context_files.extend(apm_context.glob("*.context.md")) + apm_memory = package_path / ".apm" / "memory" + if apm_memory.exists(): + context_files.extend(apm_memory.glob("*.memory.md")) + return context_files + + +def _copy_skill_to_target( + package_info: Any, + source_path: Path, + 
target_base: Path, + targets: Any, + context: CopySkillContext, +) -> list[Path]: + """Copy a skill directory to all active target skills directories.""" + if not context.should_install_fn(package_info): + return [] + + source_skill_md = source_path / "SKILL.md" + if not source_skill_md.exists(): + return [] + + raw_skill_name = source_path.name + is_valid, _ = context.validate_name_fn(raw_skill_name) + skill_name = raw_skill_name if is_valid else context.normalize_name_fn(raw_skill_name) + + deployed: list[Path] = [] + seen_skill_dirs: set[Path] = set() + if targets is None: + from apm_cli.integration.targets import active_targets + + targets = active_targets(target_base) + + for target in targets: + if not target.supports("skills"): + continue + skills_mapping = target.primitives["skills"] + effective_root = skills_mapping.deploy_root or target.root_dir + target_root_dir = target_base / target.root_dir + if not target.auto_create and not target_root_dir.is_dir(): + continue + + skill_dir = target_base / effective_root / "skills" / skill_name + from apm_cli.utils.path_security import ( + PathTraversalError, + ensure_path_within, + validate_path_segments, + ) + + validate_path_segments(skill_name, context="skill name") + if skill_dir.is_symlink(): + raise PathTraversalError( + f"Skill destination {skill_dir} is a symlink -- refusing to deploy" + ) + + resolved_project = target_base.resolve() + resolved_skill_dir = skill_dir.resolve() + if not resolved_skill_dir.is_relative_to(resolved_project): + raise PathTraversalError( + f"Skill directory '{skill_dir}' resolves to '{resolved_skill_dir}' " + f"which is outside the project root '{resolved_project}'" + ) + ensure_path_within(skill_dir, target_base / effective_root / "skills") + + resolved = skill_dir.resolve() + if resolved in seen_skill_dirs: + _log.debug("%s -- already deployed, skipping for %s", skill_dir, target.name) + continue + seen_skill_dirs.add(resolved) + + skill_dir.parent.mkdir(parents=True, 
exist_ok=True) + if skill_dir.exists(): + shutil.rmtree(skill_dir) + rewriter = context.rewriter_factory() + rewriter._copy_source_skill_tree(source_path, skill_dir) + rewriter.init_link_resolver(package_info, target_base) + rewriter._resolve_markdown_links_in_skill_bundle(source_path, skill_dir) + deployed.append(skill_dir) + + return deployed + + +def is_skill_dir_identical_to_source(dir_a: Path, dir_b: Path) -> bool: + """Check if two directory trees have identical file contents.""" + dcmp = filecmp.dircmp(str(dir_a), str(dir_b)) + return _dircmp_equal(dcmp) + + +def _dircmp_equal(dcmp: Any) -> bool: + """Recursively check if dircmp shows identical contents.""" + if dcmp.left_only or dcmp.right_only or dcmp.funny_files: + return False + _, mismatches, errors = filecmp.cmpfiles( + dcmp.left, dcmp.right, dcmp.common_files, shallow=False + ) + if mismatches or errors: + return False + return all(_dircmp_equal(sub_dcmp) for sub_dcmp in dcmp.subdirs.values()) + + +def _resolve_markdown_links_in_skill_bundle( + link_rewriter: Any, + source_root: Path, + target_root: Path, +) -> int: + """Read copied skill markdown from source and write resolved target content.""" + links_resolved = 0 + for target_file in target_root.rglob("*.md"): + if not target_file.is_file() or target_file.is_symlink(): + continue + source_file = source_root / target_file.relative_to(target_root) + if not source_file.is_file() or source_file.is_symlink(): + continue + content = source_file.read_text(encoding="utf-8") + resolved, count = link_rewriter.resolve_links( + content, + source_file, + target_file, + preserved_source_root=source_root, + ) + if count: + target_file.write_text(resolved, encoding="utf-8") + links_resolved += count + return links_resolved + + +def _emit_unmanaged_skill_skip( + sub_name: str, + rel_path: str, + parent_name: str, + diagnostics: Any, + logger: Any, +) -> None: + """Emit the existing unmanaged-skill skip warning.""" + message = ( + f"Skipping skill '{sub_name}' -- 
local skill exists (not managed by APM). " + "Use 'apm install --force' to overwrite." + ) + if diagnostics is not None: + diagnostics.skip(rel_path, package=parent_name) + elif logger: + logger.warning(message) + else: + try: + from apm_cli.utils.console import _rich_warning + + _rich_warning(message) + except ImportError: + pass + + +def _emit_sub_skill_overwrite( + sub_name: str, + rel_path: str, + parent_name: str, + diagnostics: Any, + logger: Any, +) -> None: + """Emit the existing sub-skill overwrite warning.""" + if diagnostics is not None: + diagnostics.overwrite( + path=rel_path, + package=parent_name, + detail=f"Skill '{sub_name}' replaced -- previously from another package", + ) + elif logger: + logger.warning( + f"Sub-skill '{sub_name}' from '{parent_name}' overwrites existing skill at {rel_path}" + ) + else: + try: + from apm_cli.utils.console import _rich_warning + + _rich_warning( + f"Sub-skill '{sub_name}' from '{parent_name}' overwrites existing skill at {rel_path}" + ) + except ImportError: + pass + + +def _target_rel_prefix(target_skills_root: Path, project_root: Path | None) -> str: + """Return a project-relative target prefix when possible.""" + if project_root is None: + return target_skills_root.name + try: + return target_skills_root.relative_to(project_root).as_posix() + except ValueError: + return target_skills_root.name + + +def _promote_sub_skills( + sub_skills_dir: Path, + target_skills_root: Path, + parent_name: str, + *, + warn: bool = True, + owned_by: dict[str, str] | None = None, + diagnostics: Any = None, + managed_files: set[str] | None = None, + force: bool = False, + project_root: Path | None = None, + logger: Any = None, + name_filter: set[str] | None = None, + link_rewriter: Any = None, +) -> tuple[int, list[Path]]: + """Promote sub-skills from a package skill directory.""" + promoted = 0 + deployed: list[Path] = [] + if not sub_skills_dir.is_dir(): + return promoted, deployed + + rel_prefix = 
_target_rel_prefix(target_skills_root, project_root) + for sub_skill_path in sub_skills_dir.iterdir(): + if not sub_skill_path.is_dir() or not (sub_skill_path / "SKILL.md").exists(): + continue + raw_sub_name = sub_skill_path.name + if name_filter is not None and raw_sub_name not in name_filter: + continue + is_valid, _ = _validate_skill_name(raw_sub_name) + sub_name = raw_sub_name if is_valid else normalize_skill_name(raw_sub_name) + target = target_skills_root / sub_name + rel_path = f"{rel_prefix}/{sub_name}" + if target.exists(): + if is_skill_dir_identical_to_source(sub_skill_path, target): + promoted += 1 + deployed.append(target) + continue + + is_managed = managed_files is not None and rel_path.replace("\\", "/") in managed_files + prev_owner = (owned_by or {}).get(sub_name) + is_self_overwrite = prev_owner is not None and prev_owner == parent_name + if managed_files is not None and not is_managed and not is_self_overwrite and not force: + _emit_unmanaged_skill_skip(sub_name, rel_path, parent_name, diagnostics, logger) + continue + if warn and not is_self_overwrite: + _emit_sub_skill_overwrite(sub_name, rel_path, parent_name, diagnostics, logger) + shutil.rmtree(target) + target.mkdir(parents=True, exist_ok=True) + if link_rewriter is not None: + link_rewriter._copy_promoted_skill_tree(sub_skill_path, target) + link_rewriter._resolve_markdown_links_in_skill_bundle(sub_skill_path, target) + else: + from apm_cli.security.gate import ignore_non_content + + shutil.copytree(sub_skill_path, target, dirs_exist_ok=True, ignore=ignore_non_content) + promoted += 1 + deployed.append(target) + return promoted, deployed + + +def _build_ownership_maps(project_root: Path) -> tuple[dict[str, str], dict[str, str]]: + """Read the lockfile once and build sub-skill and native-skill owner maps.""" + from apm_cli.deps.lockfile import LockFile, get_lockfile_path + + owned_by: dict[str, str] = {} + native_owners: dict[str, str] = {} + lockfile = 
LockFile.read(get_lockfile_path(project_root)) + if not lockfile: + return owned_by, native_owners + for dep in lockfile.get_package_dependencies(): + short_owner = (dep.virtual_path or dep.repo_url).rsplit("/", 1)[-1] + unique_key = dep.get_unique_key() + for deployed_path in dep.deployed_files: + normalized = deployed_path.rstrip("/").replace("\\", "/") + skill_name = normalized.rsplit("/", 1)[-1] + owned_by[skill_name] = short_owner + if "/skills/" in normalized: + native_owners[skill_name] = unique_key + return owned_by, native_owners + + +def _target_skills_root(target: Any, project_root: Path) -> Path: + """Return the skills root for a target.""" + if target.resolved_deploy_root is not None: + return target.resolved_deploy_root + skills_mapping = target.primitives["skills"] + effective_root = skills_mapping.deploy_root or target.root_dir + return project_root / effective_root / "skills" + + +def _promote_sub_skills_standalone( + link_rewriter: Any, + package_info: Any, + project_root: Path, + *, + diagnostics: Any = None, + managed_files: set[str] | None = None, + force: bool = False, + logger: Any = None, + targets: Any = None, +) -> tuple[int, list[Path]]: + """Promote sub-skills from a package that is not itself a skill.""" + link_rewriter.init_link_resolver(package_info, project_root) + package_path = package_info.install_path + sub_skills_dir = package_path / ".apm" / "skills" + if not sub_skills_dir.is_dir(): + return 0, [] + if targets is None: + from apm_cli.integration.targets import active_targets + + targets = active_targets(project_root) + + parent_name = package_path.name + owned_by, _ = link_rewriter._build_ownership_maps(project_root) + count = 0 + all_deployed: list[Path] = [] + seen_skill_dirs: set[Path] = set() + for idx, target in enumerate(targets): + if not target.supports("skills"): + continue + is_primary = idx == 0 + target_skills_root = _target_skills_root(target, project_root) + resolved_root = target_skills_root.resolve() + if 
resolved_root in seen_skill_dirs: + if logger: + logger.progress( + f"{target_skills_root} -- already deployed, skipping for {target.name}", + symbol="info", + ) + continue + seen_skill_dirs.add(resolved_root) + target_skills_root.mkdir(parents=True, exist_ok=True) + n, deployed = _promote_sub_skills( + sub_skills_dir, + target_skills_root, + parent_name, + warn=is_primary, + owned_by=owned_by if is_primary else None, + diagnostics=diagnostics if is_primary else None, + managed_files=managed_files if is_primary else None, + force=force, + project_root=project_root, + link_rewriter=link_rewriter, + ) + if is_primary: + count = n + all_deployed.extend(deployed) + return count, all_deployed + + +def _warn_normalized_skill_name( + raw_skill_name: str, + skill_name: str, + error_msg: str, + diagnostics: Any, + logger: Any, +) -> None: + """Emit the existing normalised skill-name warning.""" + message = f"Skill name '{raw_skill_name}' normalized to '{skill_name}' ({error_msg})" + if diagnostics is not None: + diagnostics.warn(message, package=raw_skill_name) + elif logger: + logger.warning(message) + else: + try: + from apm_cli.utils.console import _rich_warning + + _rich_warning(message) + except ImportError: + pass + + +def _native_collision_warning(ctx: NativeSkillTargetContext, target_skill_dir: Path) -> None: + """Emit the existing native-skill collision warning when needed.""" + prev_owner = ctx.lockfile_native_owners.get( + ctx.skill_name + ) or ctx.link_rewriter._native_skill_session_owners.get(ctx.skill_name) + is_self_overwrite = prev_owner is not None and prev_owner == ctx.current_key + if prev_owner is None or is_self_overwrite: + return + try: + rel_prefix = target_skill_dir.parent.relative_to(ctx.project_root).as_posix() + except ValueError: + rel_prefix = "skills" + rel_path = f"{rel_prefix}/{ctx.skill_name}" + detail = ( + f"Skill '{ctx.skill_name}' from '{ctx.current_key}' replaced " + f"'{prev_owner}' -- remove one package to avoid this" + ) + if 
ctx.diagnostics is not None: + ctx.diagnostics.overwrite( + path=rel_path, + package=ctx.current_key or ctx.skill_name, + detail=detail, + ) + elif ctx.logger: + ctx.logger.warning(detail) + else: + from apm_cli.utils.console import _rich_warning + + _rich_warning(detail) + + +def _integrate_native_skill_to_target( + target: Any, + *, + is_primary: bool, + context: NativeSkillTargetContext, +) -> dict[str, Any]: + """Integrate one native skill target and return aggregate updates.""" + if not target.supports("skills"): + return {"target_paths": []} + + skills_mapping = target.primitives["skills"] + if target.resolved_deploy_root is not None: + target_skill_dir = target.resolved_deploy_root / context.skill_name + target_skills_root = target.resolved_deploy_root + else: + effective_root = skills_mapping.deploy_root or target.root_dir + target_skills_root = context.project_root / effective_root / "skills" + target_skill_dir = target_skills_root / context.skill_name + + from apm_cli.utils.path_security import ( + PathTraversalError, + ensure_path_within, + validate_path_segments, + ) + + validate_path_segments(context.skill_name, context="skill name") + if target_skill_dir.is_symlink(): + raise PathTraversalError( + f"Skill destination {target_skill_dir} is a symlink -- refusing to deploy" + ) + if target.resolved_deploy_root is None: + ensure_path_within(target_skill_dir, target_skills_root) + + resolved = target_skill_dir.resolve() + if resolved in context.seen_skill_dirs: + if context.logger: + context.logger.progress( + f"{target_skill_dir} -- already deployed, skipping for {target.name}", + symbol="info", + ) + return {"target_paths": []} + context.seen_skill_dirs.add(resolved) + + result: dict[str, Any] = {"target_paths": []} + if is_primary: + result.update( + skill_created=not target_skill_dir.exists(), + skill_updated=target_skill_dir.exists(), + primary_skill_md=target_skill_dir / "SKILL.md", + ) + + if target_skill_dir.exists(): + if is_primary: + 
_native_collision_warning(context, target_skill_dir) + shutil.rmtree(target_skill_dir) + + target_skill_dir.parent.mkdir(parents=True, exist_ok=True) + context.link_rewriter._copy_native_skill_tree(context.package_path, target_skill_dir) + context.link_rewriter._resolve_markdown_links_in_skill_bundle( + context.package_path, target_skill_dir + ) + result["target_paths"].append(target_skill_dir) + + if is_primary: + result["files_copied"] = sum(1 for path in target_skill_dir.rglob("*") if path.is_file()) + + _, sub_deployed = _promote_sub_skills( + context.sub_skills_dir, + target_skills_root, + context.skill_name, + warn=is_primary, + owned_by=context.owned_by if is_primary else None, + diagnostics=context.diagnostics if is_primary else None, + managed_files=context.managed_files if is_primary else None, + force=context.force, + project_root=context.project_root, + logger=context.logger if is_primary else None, + link_rewriter=context.link_rewriter, + ) + result["target_paths"].extend(sub_deployed) + return result + + +def _integrate_native_skill( + link_rewriter: Any, + package_info: Any, + project_root: Path, + source_skill_md: Path, + *, + diagnostics: Any = None, + managed_files: set[str] | None = None, + force: bool = False, + logger: Any = None, + targets: Any = None, +) -> dict[str, Any]: + """Copy a native skill to all active targets and return result fields.""" + link_rewriter.init_link_resolver(package_info, project_root) + package_path = package_info.install_path + raw_skill_name = package_path.name + is_valid, error_msg = _validate_skill_name(raw_skill_name) + if is_valid: + skill_name = raw_skill_name + else: + skill_name = normalize_skill_name(raw_skill_name) + _warn_normalized_skill_name(raw_skill_name, skill_name, error_msg, diagnostics, logger) + + if targets is None: + from apm_cli.integration.targets import active_targets + + targets = active_targets(project_root) + + dep_ref = package_info.dependency_ref + current_key = dep_ref.get_unique_key() 
if dep_ref is not None else None + owned_by, lockfile_native_owners = link_rewriter._build_ownership_maps(project_root) + context = NativeSkillTargetContext( + package_path=package_path, + skill_name=skill_name, + project_root=project_root, + current_key=current_key, + lockfile_native_owners=lockfile_native_owners, + owned_by=owned_by, + sub_skills_dir=package_path / ".apm" / "skills", + seen_skill_dirs=set(), + diagnostics=diagnostics, + managed_files=managed_files, + force=force, + logger=logger, + link_rewriter=link_rewriter, + ) + + result: dict[str, Any] = { + "skill_created": False, + "skill_updated": False, + "files_copied": 0, + "target_paths": [], + "primary_skill_md": None, + "sub_skills_promoted": 0, + } + for idx, target in enumerate(targets): + target_result = _integrate_native_skill_to_target( + target, + is_primary=idx == 0, + context=context, + ) + result["target_paths"].extend(target_result.get("target_paths", [])) + for key in ("skill_created", "skill_updated", "files_copied", "primary_skill_md"): + if key in target_result: + result[key] = target_result[key] + + if current_key is not None: + link_rewriter._native_skill_session_owners[skill_name] = current_key + + primary_root = project_root / ".github" / "skills" + result["sub_skills_promoted"] = sum( + 1 + for path in result["target_paths"] + if path.parent == primary_root and path.name != skill_name + ) + return result + + +def _integrate_skill_bundle_target( + target: Any, + *, + is_primary: bool, + context: SkillBundleTargetContext, +) -> dict[str, Any]: + """Integrate one skill-bundle target and return aggregate updates.""" + if not target.supports("skills"): + return {"deployed": [], "promoted": 0, "created": False} + + skills_mapping = target.primitives["skills"] + effective_root = skills_mapping.deploy_root or target.root_dir + target_skills_root = context.project_root / effective_root / "skills" + resolved_root = target_skills_root.resolve() + if resolved_root in context.seen_skill_dirs: 
+ if context.logger: + context.logger.progress( + f"{target_skills_root} -- already deployed, skipping for {target.name}", + symbol="info", + ) + return {"deployed": [], "promoted": 0, "created": False} + context.seen_skill_dirs.add(resolved_root) + target_skills_root.mkdir(parents=True, exist_ok=True) + + promoted, deployed = _promote_sub_skills( + context.skills_dir, + target_skills_root, + context.parent_name, + warn=is_primary, + owned_by=context.owned_by if is_primary else None, + diagnostics=context.diagnostics if is_primary else None, + managed_files=context.managed_files if is_primary else None, + force=context.force, + project_root=context.project_root, + logger=context.logger if is_primary else None, + name_filter=context.name_filter, + link_rewriter=context.link_rewriter, + ) + return {"deployed": deployed, "promoted": promoted, "created": is_primary and promoted > 0} + + +def _integrate_skill_bundle( + link_rewriter: Any, + package_info: Any, + project_root: Path, + skills_dir: Path, + *, + diagnostics: Any = None, + managed_files: set[str] | None = None, + force: bool = False, + logger: Any = None, + targets: Any = None, + skill_subset: Any = None, +) -> dict[str, Any]: + """Promote every skill in a skill bundle's top-level skills directory.""" + link_rewriter.init_link_resolver(package_info, project_root) + if targets is None: + from apm_cli.integration.targets import active_targets + + targets = active_targets(project_root) + + owned_by, _ = link_rewriter._build_ownership_maps(project_root) + context = SkillBundleTargetContext( + skills_dir=skills_dir, + parent_name=package_info.install_path.name, + owned_by=owned_by, + diagnostics=diagnostics, + managed_files=managed_files, + force=force, + project_root=project_root, + logger=logger, + name_filter=set(skill_subset) if skill_subset else None, + link_rewriter=link_rewriter, + seen_skill_dirs=set(), + ) + total_promoted = 0 + all_deployed: list[Path] = [] + any_created = False + for idx, target in 
enumerate(targets): + target_result = _integrate_skill_bundle_target( + target, + is_primary=idx == 0, + context=context, + ) + if idx == 0: + total_promoted = target_result["promoted"] + any_created = target_result["created"] + all_deployed.extend(target_result["deployed"]) + return { + "skill_created": any_created, + "skill_updated": False, + "skill_skipped": False, + "skill_path": None, + "references_copied": 0, + "links_resolved": 0, + "sub_skills_promoted": total_promoted, + "target_paths": all_deployed, + } diff --git a/src/apm_cli/integration/skill_integrator.py b/src/apm_cli/integration/skill_integrator.py index 3264bdd37..e9ddce324 100644 --- a/src/apm_cli/integration/skill_integrator.py +++ b/src/apm_cli/integration/skill_integrator.py @@ -1,20 +1,20 @@ -"""Skill integration functionality for APM packages (Claude Code & Cursor support).""" +"""Skill integration functionality for APM packages.""" -import filecmp -import hashlib import re import shutil from dataclasses import dataclass, replace from pathlib import Path +from typing import Any from apm_cli.integration.base_integrator import BaseIntegrator +from . import skill_deploy as _skill_deploy +from .skill_naming import _skill_name_char_error +from .skill_naming import normalize_skill_name as normalize_skill_name +from .skill_naming import should_compile_instructions as should_compile_instructions +from .skill_naming import to_hyphen_case as to_hyphen_case + -# DEPRECATED -- use IntegrationResult directly for new code. -# Kept for backward compatibility. 
The fields map as follows: -# skill_created -> IntegrationResult.skill_created -# sub_skills_promoted -> IntegrationResult.sub_skills_promoted -# skill_path, references_copied -> not mapped (skill-internal) @dataclass class SkillIntegrationResult: """Result of skill integration operation.""" @@ -23,174 +23,39 @@ class SkillIntegrationResult: skill_updated: bool skill_skipped: bool skill_path: Path | None - references_copied: int # Now tracks total files copied to subdirectories - links_resolved: int = 0 # Kept for backwards compatibility - sub_skills_promoted: int = 0 # Number of sub-skills promoted to top-level - bin_deployed: int = 0 # Number of marketplace_plugin bin/ executables deployed - # Why a plugin's bin/ was NOT deployed despite shipping one, so the install - # layer can surface an actionable hint: "project_scope" | "no_claude_target". + references_copied: int + links_resolved: int = 0 + sub_skills_promoted: int = 0 + bin_deployed: int = 0 bin_skipped_reason: str | None = None - target_paths: list[Path] = None # All deployed directories (for deployed_files manifest) + target_paths: list[Path] | None = None - def __post_init__(self): + def __post_init__(self) -> None: if self.target_paths is None: self.target_paths = [] -def to_hyphen_case(name: str) -> str: - """Convert a package name to hyphen-case for Claude Skills spec. 
- - Args: - name: Package name (e.g., "owner/repo" or "MyPackage") - - Returns: - str: Hyphen-case name, max 64 chars (e.g., "owner-repo" or "my-package") - """ - # Extract just the repo name if it's owner/repo format - if "/" in name: - name = name.split("/")[-1] - - # Replace underscores and spaces with hyphens - result = name.replace("_", "-").replace(" ", "-") - - # Insert hyphens before uppercase letters (camelCase to hyphen-case) - result = re.sub(r"([a-z])([A-Z])", r"\1-\2", result) - - # Convert to lowercase and remove any invalid characters - result = re.sub(r"[^a-z0-9-]", "", result.lower()) - - # Remove consecutive hyphens - result = re.sub(r"-+", "-", result) - - # Remove leading/trailing hyphens - result = result.strip("-") - - # Truncate to 64 chars (Claude Skills spec limit) - return result[:64] - - def validate_skill_name(name: str) -> tuple[bool, str]: - """Validate skill name per agentskills.io spec. - - Skill names must: - - Be 1-64 characters long - - Contain only lowercase alphanumeric characters and hyphens (a-z, 0-9, -) - - Not contain consecutive hyphens (--) - - Not start or end with a hyphen - - Args: - name: Skill name to validate - - Returns: - tuple[bool, str]: (is_valid, error_message) - - is_valid: True if name is valid, False otherwise - - error_message: Empty string if valid, descriptive error otherwise - """ - # Check length - if len(name) < 1: + """Validate skill name per agentskills.io spec.""" + if not name: return (False, "Skill name cannot be empty") - if len(name) > 64: return (False, f"Skill name must be 1-64 characters (got {len(name)})") - - # Check for consecutive hyphens if "--" in name: return (False, "Skill name cannot contain consecutive hyphens (--)") - - # Check for leading/trailing hyphens if name.startswith("-"): return (False, "Skill name cannot start with a hyphen") - if name.endswith("-"): return (False, "Skill name cannot end with a hyphen") - - # Check for valid characters (lowercase alphanumeric + hyphens 
only) - # Pattern: must start and end with alphanumeric, with alphanumeric or hyphens in between - pattern = r"^[a-z0-9]([a-z0-9-]*[a-z0-9])?$" - if not re.match(pattern, name): - # Determine specific error - if any(c.isupper() for c in name): - return (False, "Skill name must be lowercase (no uppercase letters)") - - if "_" in name: - return (False, "Skill name cannot contain underscores (use hyphens instead)") - - if " " in name: - return (False, "Skill name cannot contain spaces (use hyphens instead)") - - # Check for other invalid characters - invalid_chars = set(re.findall(r"[^a-z0-9-]", name)) - if invalid_chars: - return ( - False, - f"Skill name contains invalid characters: {', '.join(sorted(invalid_chars))}", - ) - - return (False, "Skill name must be lowercase alphanumeric with hyphens only") - + if not re.match(r"^[a-z0-9]([a-z0-9-]*[a-z0-9])?$", name): + return (False, _skill_name_char_error(name)) return (True, "") -def normalize_skill_name(name: str) -> str: - """Convert any package name to a valid skill name per agentskills.io spec. - - Normalization steps: - 1. Extract repo name if owner/repo format - 2. Convert to lowercase - 3. Replace underscores and spaces with hyphens - 4. Convert camelCase to hyphen-case - 5. Remove invalid characters - 6. Remove consecutive hyphens - 7. Strip leading/trailing hyphens - 8. Truncate to 64 characters - - Args: - name: Package name to normalize (e.g., "owner/MyRepo_Name") - - Returns: - str: Valid skill name (e.g., "my-repo-name") - """ - # Use to_hyphen_case which already handles most normalization - return to_hyphen_case(name) - - -# ============================================================================= -# Package Type Routing Functions (T4) -# ============================================================================= -# These functions determine behavior based on: -# 1. Explicit `type` field in apm.yml (highest priority) -# 2. Presence of SKILL.md at package root (makes it a skill) -# 3. 
Default to INSTRUCTIONS for instruction-only packages -# -# Per skill-strategy.md Decision 2: "Skills are explicit, not implicit" -# - Packages with SKILL.md OR explicit type: skill/hybrid -> become skills -# - Packages with only instructions -> compile to AGENTS.md, NOT skills - - -def get_effective_type(package_info) -> "PackageContentType": - """Get effective package content type based on package structure. - - Determines type by: - 1. Package has SKILL.md (PackageType.CLAUDE_SKILL or HYBRID) -> SKILL - 2. Package is a SKILL_BUNDLE or MARKETPLACE_PLUGIN (has skills/) -> SKILL - 3. Otherwise -> INSTRUCTIONS (compile to AGENTS.md only) - - Args: - package_info: PackageInfo object containing package metadata - - Returns: - PackageContentType: The effective type - """ +def get_effective_type(package_info: Any) -> "PackageContentType": + """Get effective package content type based on package structure.""" from apm_cli.models.apm_package import PackageContentType, PackageType - # Check if package has SKILL.md (via package_type field) - # PackageType.CLAUDE_SKILL = has root SKILL.md only - # PackageType.HYBRID = has both apm.yml AND root SKILL.md - # PackageType.SKILL_BUNDLE = has skills//SKILL.md (nested bundle) - # PackageType.MARKETPLACE_PLUGIN = has plugin manifest (plugin.json or - # .claude-plugin/); may or may not include skills/. The integrator - # path gates on actual skills/ presence, so plugins without skills - # are inert in the SKILL branch. if package_info.package_type in ( PackageType.CLAUDE_SKILL, PackageType.HYBRID, @@ -198,362 +63,97 @@ def get_effective_type(package_info) -> "PackageContentType": PackageType.MARKETPLACE_PLUGIN, ): return PackageContentType.SKILL - - # Default to INSTRUCTIONS for packages without SKILL.md return PackageContentType.INSTRUCTIONS -def should_install_skill(package_info) -> bool: - """Determine if package should be installed as a native skill. 
- - This controls whether a package gets installed to .github/skills/ (or .claude/skills/). - - Per skill-strategy.md Decision 2 - "Skills are explicit, not implicit": - - Returns True for: - - SKILL: Package has SKILL.md or declares type: skill - - HYBRID: Package declares type: hybrid in apm.yml - - Returns False for: - - INSTRUCTIONS: Compile to AGENTS.md only, no skill created - - PROMPTS: Commands/prompts only, no skill created - - Packages without SKILL.md and no explicit type field - - Args: - package_info: PackageInfo object containing package metadata - - Returns: - bool: True if package should be installed as a native skill - """ +def should_install_skill(package_info: Any) -> bool: + """Determine if package should be installed as a native skill.""" from apm_cli.models.apm_package import PackageContentType effective_type = get_effective_type(package_info) - - # SKILL and HYBRID should install as skills - # INSTRUCTIONS and PROMPTS should NOT install as skills return effective_type in (PackageContentType.SKILL, PackageContentType.HYBRID) -def should_compile_instructions(package_info) -> bool: - """Determine if package should compile to AGENTS.md/CLAUDE.md. - - This controls whether a package's instructions are included in compiled output. 
- - Per skill-strategy.md Decision 2: - - Returns True for: - - INSTRUCTIONS: Compile to AGENTS.md only (default for packages without SKILL.md) - - HYBRID: Package declares type: hybrid in apm.yml - - Returns False for: - - SKILL: Install as native skill only, no AGENTS.md compilation - - PROMPTS: Commands/prompts only, no instructions compiled - - Args: - package_info: PackageInfo object containing package metadata - - Returns: - bool: True if package's instructions should be compiled to AGENTS.md/CLAUDE.md - """ - from apm_cli.models.apm_package import PackageContentType - - effective_type = get_effective_type(package_info) - - # INSTRUCTIONS and HYBRID should compile to AGENTS.md - # SKILL and PROMPTS should NOT compile to AGENTS.md - return effective_type in (PackageContentType.INSTRUCTIONS, PackageContentType.HYBRID) - - def copy_skill_to_target( - package_info, + package_info: Any, source_path: Path, target_base: Path, - targets=None, + targets: Any = None, ) -> list[Path]: - """Copy skill directory to all active target skills/ directories. - - This is a standalone function for direct skill copy operations. - It handles: - - Package type routing via should_install_skill() - - Skill name validation/normalization - - Directory structure preservation - - Deployment to every active target that supports skills - - When *targets* is provided, only those targets are used. - Otherwise falls back to ``active_targets()``. - - Source SKILL.md gets no metadata injection; outbound package links are rewritten. - - Copies: - - SKILL.md (required) - - scripts/ (optional) - - references/ (optional) - - assets/ (optional) - - Any other subdirectories the package contains - - Args: - package_info: PackageInfo object with package metadata - source_path: Path to skill in apm_modules/ - target_base: Usually project root - targets: Optional explicit list of TargetProfile objects. - - Returns: - List of all deployed skill directory paths (empty if skipped). 
- """ - # Check if package type allows skill installation (T4 routing) - if not should_install_skill(package_info): - return [] - - # Check for SKILL.md existence - source_skill_md = source_path / "SKILL.md" - if not source_skill_md.exists(): - # No SKILL.md means this package is handled by compilation, not skill copy - return [] - - # Get and validate skill name from folder - raw_skill_name = source_path.name - - is_valid, _ = validate_skill_name(raw_skill_name) - if is_valid: # noqa: SIM108 - skill_name = raw_skill_name - else: - skill_name = normalize_skill_name(raw_skill_name) - - deployed: list[Path] = [] - seen_skill_dirs: set[Path] = set() - - # Deploy to all active targets that support skills. - # When no targets are provided, fall back to project-scope detection. - # Callers responsible for user-scope should pass resolved targets - # from resolve_targets(). - if targets is None: - from apm_cli.integration.targets import active_targets - - targets = active_targets(target_base) - for target in targets: - if not target.supports("skills"): - continue - skills_mapping = target.primitives["skills"] - effective_root = skills_mapping.deploy_root or target.root_dir - - # Skip if target dir does not exist and auto_create is disabled - target_root_dir = target_base / target.root_dir - if not target.auto_create and not target_root_dir.is_dir(): - continue - - skill_dir = target_base / effective_root / "skills" / skill_name - - # Security: reject traversal in skill name and validate containment. - # The containment check resolves the *base* (which may sit behind a - # symlink) but verifies the *unresolved* caller-controlled segment - # (skill_name) has no traversal parts. This prevents a symlink at - # target_base / effective_root from silently redirecting writes - # outside the project root. 
- from apm_cli.utils.path_security import ( - PathTraversalError, - ensure_path_within, - validate_path_segments, - ) - - validate_path_segments(skill_name, context="skill name") - if skill_dir.is_symlink(): - raise PathTraversalError( - f"Skill destination {skill_dir} is a symlink -- refusing to deploy" - ) - - # Verify the resolved skill directory is within the project root. - # This catches the case where an ancestor directory (e.g. - # effective_root) is a symlink pointing outside the project. - resolved_project = target_base.resolve() - resolved_skill_dir = skill_dir.resolve() - if not resolved_skill_dir.is_relative_to(resolved_project): - raise PathTraversalError( - f"Skill directory '{skill_dir}' resolves to '{resolved_skill_dir}' " - f"which is outside the project root '{resolved_project}'" - ) - ensure_path_within(skill_dir, target_base / effective_root / "skills") - - # Dedup: skip if same resolved path already deployed. - resolved = skill_dir.resolve() - if resolved in seen_skill_dirs: - import logging - - logging.getLogger(__name__).debug( - "%s -- already deployed, skipping for %s", skill_dir, target.name - ) - continue - seen_skill_dirs.add(resolved) - - skill_dir.parent.mkdir(parents=True, exist_ok=True) - if skill_dir.exists(): - shutil.rmtree(skill_dir) - from apm_cli.security.gate import ignore_non_content - - shutil.copytree(source_path, skill_dir, ignore=ignore_non_content) - rewriter = SkillIntegrator() - rewriter.init_link_resolver(package_info, target_base) - rewriter._resolve_markdown_links_in_skill_bundle(source_path, skill_dir) - deployed.append(skill_dir) - - return deployed + """Copy a skill directory to all active target skills directories.""" + context = _skill_deploy.CopySkillContext( + should_install_fn=should_install_skill, + validate_name_fn=validate_skill_name, + normalize_name_fn=normalize_skill_name, + rewriter_factory=SkillIntegrator, + ) + return _skill_deploy._copy_skill_to_target( + package_info, source_path, target_base, 
targets, context + ) class SkillIntegrator(BaseIntegrator): - """Handles integration of native SKILL.md files for Claude Code, Cursor, and VS Code. - - Claude Skills Spec: - - SKILL.md files provide structured context for Claude Code - - YAML frontmatter with name, description, and metadata - - Markdown body with instructions and agent definitions - - references/ subdirectory for prompt files - """ + """Handles integration of native skill files for supported targets.""" def __init__(self) -> None: - # In-memory map of skill_name -> dep.get_unique_key() updated as each native - # skill is deployed in the current install run. Complements the lockfile-based - # map so that same-manifest collisions are detected before the lockfile is written. self._native_skill_session_owners: dict[str, str] = {} def find_instruction_files(self, package_path: Path) -> list[Path]: - """Find all instruction files in a package. - - Searches in: - - .apm/instructions/ subdirectory - - Args: - package_path: Path to the package directory - - Returns: - List[Path]: List of absolute paths to instruction files - """ - instruction_files = [] - - # Search in .apm/instructions/ - apm_instructions = package_path / ".apm" / "instructions" - if apm_instructions.exists(): - instruction_files.extend(apm_instructions.glob("*.instructions.md")) - - return instruction_files + """Find all instruction files in a package.""" + return _skill_deploy.find_instruction_files(package_path) def find_agent_files(self, package_path: Path) -> list[Path]: - """Find all agent files in a package. 
- - Searches in: - - .apm/agents/ subdirectory - - Args: - package_path: Path to the package directory - - Returns: - List[Path]: List of absolute paths to agent files - """ - agent_files = [] - - # Search in .apm/agents/ - apm_agents = package_path / ".apm" / "agents" - if apm_agents.exists(): - agent_files.extend(apm_agents.glob("*.agent.md")) - - return agent_files + """Find all agent files in a package.""" + return _skill_deploy.find_agent_files(package_path) def find_prompt_files(self, package_path: Path) -> list[Path]: - """Find all prompt files in a package. - - Searches in: - - Package root directory - - .apm/prompts/ subdirectory - - Args: - package_path: Path to the package directory - - Returns: - List[Path]: List of absolute paths to prompt files - """ - prompt_files = [] - - # Search in package root - if package_path.exists(): - prompt_files.extend(package_path.glob("*.prompt.md")) - - # Search in .apm/prompts/ - apm_prompts = package_path / ".apm" / "prompts" - if apm_prompts.exists(): - prompt_files.extend(apm_prompts.glob("*.prompt.md")) - - return prompt_files + """Find all prompt files in a package.""" + return _skill_deploy.find_prompt_files(package_path) def find_context_files(self, package_path: Path) -> list[Path]: - """Find all context/memory files in a package. 
+ """Find all context and memory files in a package.""" + return _skill_deploy.find_context_files(package_path) - Searches in: - - .apm/context/ subdirectory - - .apm/memory/ subdirectory - - Args: - package_path: Path to the package directory + @staticmethod + def is_skill_dir_identical_to_source(dir_a: Path, dir_b: Path) -> bool: + """Check if two directory trees have identical file contents.""" + return _skill_deploy.is_skill_dir_identical_to_source(dir_a, dir_b) - Returns: - List[Path]: List of absolute paths to context files - """ - context_files = [] + @staticmethod + def _dircmp_equal(dcmp: Any) -> bool: + """Recursively check if dircmp shows identical contents.""" + return _skill_deploy._dircmp_equal(dcmp) - # Search in .apm/context/ - apm_context = package_path / ".apm" / "context" - if apm_context.exists(): - context_files.extend(apm_context.glob("*.context.md")) + def _resolve_markdown_links_in_skill_bundle(self, source_root: Path, target_root: Path) -> int: + """Read copied skill markdown from source and write resolved target content.""" + return _skill_deploy._resolve_markdown_links_in_skill_bundle(self, source_root, target_root) - # Search in .apm/memory/ - apm_memory = package_path / ".apm" / "memory" - if apm_memory.exists(): - context_files.extend(apm_memory.glob("*.memory.md")) + @staticmethod + def _copy_source_skill_tree(source_path: Path, skill_dir: Path) -> None: + """Copy a standalone skill tree while excluding non-content files.""" + from apm_cli.security.gate import ignore_non_content - return context_files + shutil.copytree(source_path, skill_dir, ignore=ignore_non_content) @staticmethod - def is_skill_dir_identical_to_source(dir_a: Path, dir_b: Path) -> bool: - """Check if two directory trees have identical file contents.""" - dcmp = filecmp.dircmp(str(dir_a), str(dir_b)) - return SkillIntegrator._dircmp_equal(dcmp) + def _copy_native_skill_tree(package_path: Path, target_skill_dir: Path) -> None: + """Copy a native skill tree while 
excluding non-content files and .apm.""" + from apm_cli.security.gate import ignore_non_content + + def ignore_non_content_and_apm(directory: str, contents: list[str]) -> list[str]: + ignored = set(ignore_non_content(directory, contents)) + if ".apm" in contents: + ignored.add(".apm") + return list(ignored) + + shutil.copytree(package_path, target_skill_dir, ignore=ignore_non_content_and_apm) @staticmethod - def _dircmp_equal(dcmp) -> bool: - """Recursively check if dircmp shows identical contents.""" - if dcmp.left_only or dcmp.right_only or dcmp.funny_files: - return False - _, mismatches, errors = filecmp.cmpfiles( - dcmp.left, dcmp.right, dcmp.common_files, shallow=False - ) - if mismatches or errors: - return False - for sub_dcmp in dcmp.subdirs.values(): # noqa: SIM110 - if not SkillIntegrator._dircmp_equal(sub_dcmp): - return False - return True + def _copy_promoted_skill_tree(sub_skill_path: Path, target: Path) -> None: + """Copy a promoted sub-skill tree while excluding non-content files.""" + from apm_cli.security.gate import ignore_non_content - def _resolve_markdown_links_in_skill_bundle( - self, - source_root: Path, - target_root: Path, - ) -> int: - """Read copied skill markdown from source and write resolved target content.""" - links_resolved = 0 - for target_file in target_root.rglob("*.md"): - if not target_file.is_file() or target_file.is_symlink(): - continue - source_file = source_root / target_file.relative_to(target_root) - if not source_file.is_file() or source_file.is_symlink(): - continue - content = source_file.read_text(encoding="utf-8") - resolved, count = self.resolve_links( - content, - source_file, - target_file, - preserved_source_root=source_root, - ) - if count: - target_file.write_text(resolved, encoding="utf-8") - links_resolved += count - return links_resolved + shutil.copytree(sub_skill_path, target, dirs_exist_ok=True, ignore=ignore_non_content) @staticmethod def _promote_sub_skills( @@ -563,769 +163,157 @@ def 
_promote_sub_skills( *, warn: bool = True, owned_by: dict[str, str] | None = None, - diagnostics=None, - managed_files=None, + diagnostics: Any = None, + managed_files: set[str] | None = None, force: bool = False, project_root: Path | None = None, - logger=None, - name_filter: "set | None" = None, - link_rewriter: "SkillIntegrator | None" = None, + logger: Any = None, + name_filter: set[str] | None = None, + link_rewriter: Any = None, ) -> tuple[int, list[Path]]: - """Promote sub-skills from .apm/skills/ to top-level skill entries. - - Args: - sub_skills_dir: Path to the .apm/skills/ directory in the source package. - target_skills_root: Root skills directory (e.g. .github/skills/ or .claude/skills/). - parent_name: Name of the parent skill (used in warning messages). - warn: Whether to emit a warning on name collisions. - owned_by: Map of skill_name -> owner_package_name from the lockfile. - When provided, warnings are suppressed for self-overwrites. - diagnostics: Optional DiagnosticCollector for deferred warning output. - project_root: Project root for computing relative diagnostic paths. - - Returns: - tuple[int, list[Path]]: (count of promoted sub-skills, list of deployed dir paths) - """ - promoted = 0 - deployed = [] - if not sub_skills_dir.is_dir(): - return promoted, deployed - - # Compute project-relative prefix for consistent path reporting - if project_root is not None: - try: - rel_prefix = target_skills_root.relative_to(project_root).as_posix() - except ValueError: - # Dynamic-root targets (cowork): use synthetic prefix - # when the skills root lives outside the project tree. 
- rel_prefix = target_skills_root.name - else: - rel_prefix = target_skills_root.name - - for sub_skill_path in sub_skills_dir.iterdir(): - if not sub_skill_path.is_dir(): - continue - if not (sub_skill_path / "SKILL.md").exists(): - continue - raw_sub_name = sub_skill_path.name - # --skill filter: skip skills not in the requested subset - if name_filter is not None and raw_sub_name not in name_filter: - continue - is_valid, _ = validate_skill_name(raw_sub_name) - sub_name = raw_sub_name if is_valid else normalize_skill_name(raw_sub_name) - target = target_skills_root / sub_name - rel_path = f"{rel_prefix}/{sub_name}" - if target.exists(): - # Content-identical → skip entirely (no copy, no warning) - if SkillIntegrator.is_skill_dir_identical_to_source(sub_skill_path, target): - promoted += 1 - deployed.append(target) - continue - - # Check if this is a user-authored skill (not managed by APM) - is_managed = ( - managed_files is not None and rel_path.replace("\\", "/") in managed_files - ) - prev_owner = (owned_by or {}).get(sub_name) - is_self_overwrite = prev_owner is not None and prev_owner == parent_name - - if managed_files is not None and not is_managed and not is_self_overwrite: - # User-authored skill — respect force flag - if not force: - if diagnostics is not None: - diagnostics.skip(rel_path, package=parent_name) - elif logger: - logger.warning( - f"Skipping skill '{sub_name}' -- local skill exists (not managed by APM). " - f"Use 'apm install --force' to overwrite." - ) - else: - try: - from apm_cli.utils.console import _rich_warning - - _rich_warning( - f"Skipping skill '{sub_name}' -- local skill exists (not managed by APM). " - f"Use 'apm install --force' to overwrite." 
- ) - except ImportError: - pass - continue # SKIP — protect user content - - if warn and not is_self_overwrite: - if diagnostics is not None: - diagnostics.overwrite( - path=rel_path, - package=parent_name, - detail=f"Skill '{sub_name}' replaced -- previously from another package", - ) - elif logger: - logger.warning( - f"Sub-skill '{sub_name}' from '{parent_name}' overwrites existing skill at {rel_path}" - ) - else: - try: - from apm_cli.utils.console import _rich_warning - - _rich_warning( - f"Sub-skill '{sub_name}' from '{parent_name}' overwrites existing skill at {rel_path}" - ) - except ImportError: - pass - shutil.rmtree(target) - target.mkdir(parents=True, exist_ok=True) - from apm_cli.security.gate import ignore_non_content - - shutil.copytree(sub_skill_path, target, dirs_exist_ok=True, ignore=ignore_non_content) - if link_rewriter is not None: - link_rewriter._resolve_markdown_links_in_skill_bundle(sub_skill_path, target) - promoted += 1 - deployed.append(target) - return promoted, deployed + """Promote sub-skills from a package skill directory.""" + return _skill_deploy._promote_sub_skills( + sub_skills_dir, + target_skills_root, + parent_name, + warn=warn, + owned_by=owned_by, + diagnostics=diagnostics, + managed_files=managed_files, + force=force, + project_root=project_root, + logger=logger, + name_filter=name_filter, + link_rewriter=link_rewriter, + ) @staticmethod def _build_ownership_maps(project_root: Path) -> tuple[dict[str, str], dict[str, str]]: - """Read the lockfile once and build two ownership maps. - - Returns a tuple of: - - owned_by: skill_name -> last-segment owner name, for sub-skill self-overwrite detection. - - native_owners: skill_name -> dep.get_unique_key(), for native-skill cross-package - collision detection. Only paths under a ``/skills/`` prefix are included to avoid - false attribution from non-skill deployed_files entries (prompts, hooks, commands, etc.). 
- """ - from apm_cli.deps.lockfile import LockFile, get_lockfile_path - - owned_by: dict[str, str] = {} - native_owners: dict[str, str] = {} - lockfile = LockFile.read(get_lockfile_path(project_root)) - if not lockfile: - return owned_by, native_owners - for dep in lockfile.get_package_dependencies(): - short_owner = (dep.virtual_path or dep.repo_url).rsplit("/", 1)[-1] - unique_key = dep.get_unique_key() - for deployed_path in dep.deployed_files: - normalized = deployed_path.rstrip("/").replace("\\", "/") - skill_name = normalized.rsplit("/", 1)[-1] - # Both maps cover all paths for sub-skill self-overwrite tracking. - owned_by[skill_name] = short_owner - # Native-owner map is scoped to skill paths only to avoid false - # attribution from prompts/hooks/commands that share a leaf name. - if "/skills/" in normalized: - native_owners[skill_name] = unique_key - return owned_by, native_owners + """Read the lockfile once and build sub-skill and native-skill owner maps.""" + return _skill_deploy._build_ownership_maps(project_root) @staticmethod def _build_skill_ownership_map(project_root: Path) -> dict[str, str]: - """Build a map of skill_name -> owner_package_name from the lockfile. - - Used to distinguish self-overwrites (no warning) from cross-package - conflicts (warning) when promoting sub-skills. - """ + """Build a map of skill name to owner package name from the lockfile.""" owned_by, _ = SkillIntegrator._build_ownership_maps(project_root) return owned_by @staticmethod def _build_native_skill_owner_map(project_root: Path) -> dict[str, str]: - """Build a map of skill_name -> dep.get_unique_key() from the lockfile. - - Scoped to ``/skills/`` paths only -- see ``_build_ownership_maps`` for details. 
- """ + """Build a map of skill name to dependency key from the lockfile.""" _, native_owners = SkillIntegrator._build_ownership_maps(project_root) return native_owners def _promote_sub_skills_standalone( self, - package_info, + package_info: Any, project_root: Path, - diagnostics=None, - managed_files=None, + diagnostics: Any = None, + managed_files: set[str] | None = None, force: bool = False, - logger=None, - targets=None, + logger: Any = None, + targets: Any = None, ) -> tuple[int, list[Path]]: - """Promote sub-skills from a package that is NOT itself a skill. - - Packages typed as INSTRUCTIONS may still ship sub-skills under - ``.apm/skills/``. This method promotes them to all active targets - that support skills, without creating a top-level skill entry for - the parent package. - - Args: - package_info: PackageInfo object with package metadata. - project_root: Root directory of the project. - targets: Optional explicit list of TargetProfile objects. - - Returns: - tuple[int, list[Path]]: (count of promoted sub-skills, list of deployed dirs) - """ - self.init_link_resolver(package_info, project_root) - package_path = package_info.install_path - sub_skills_dir = package_path / ".apm" / "skills" - if not sub_skills_dir.is_dir(): - return 0, [] - - if targets is None: - from apm_cli.integration.targets import active_targets - - targets = active_targets(project_root) - - parent_name = package_path.name - owned_by = self._build_skill_ownership_map(project_root) - count = 0 - all_deployed: list[Path] = [] - seen_skill_dirs: set[Path] = set() - - for idx, target in enumerate(targets): - if not target.supports("skills"): - continue - - is_primary = idx == 0 # first active target owns diagnostics - skills_mapping = target.primitives["skills"] - # Dynamic-root targets (cowork): use resolved_deploy_root. 
- if target.resolved_deploy_root is not None: - target_skills_root = target.resolved_deploy_root - else: - effective_root = skills_mapping.deploy_root or target.root_dir - target_skills_root = project_root / effective_root / "skills" - - # Dedup: skip if same resolved skills root already processed. - resolved_root = target_skills_root.resolve() - if resolved_root in seen_skill_dirs: - if logger: - logger.progress( - f"{target_skills_root} -- already deployed, skipping for {target.name}", - symbol="info", - ) - continue - seen_skill_dirs.add(resolved_root) - - target_skills_root.mkdir(parents=True, exist_ok=True) - - n, deployed = self._promote_sub_skills( - sub_skills_dir, - target_skills_root, - parent_name, - warn=is_primary, - owned_by=owned_by if is_primary else None, - diagnostics=diagnostics if is_primary else None, - managed_files=managed_files if is_primary else None, - force=force, - project_root=project_root, - link_rewriter=self, - ) - if is_primary: - count = n - all_deployed.extend(deployed) - - return count, all_deployed + """Promote sub-skills from a package that is not itself a skill.""" + return _skill_deploy._promote_sub_skills_standalone( + self, + package_info, + project_root, + diagnostics=diagnostics, + managed_files=managed_files, + force=force, + logger=logger, + targets=targets, + ) def _integrate_native_skill( self, - package_info, + package_info: Any, project_root: Path, source_skill_md: Path, - diagnostics=None, - managed_files=None, + diagnostics: Any = None, + managed_files: set[str] | None = None, force: bool = False, - logger=None, - targets=None, + logger: Any = None, + targets: Any = None, ) -> SkillIntegrationResult: - """Copy a native Skill (with existing SKILL.md) to all active targets. - - For packages that already have a SKILL.md at their root (like those from - awesome-claude-skills), we copy the entire skill folder to every active - target that supports skills (driven by ``active_targets()``). 
- - The skill folder name is the source folder name (e.g., ``mcp-builder``), - validated and normalized per the agentskills.io spec. - - Source SKILL.md gets no metadata injection; outbound package links are rewritten. - Orphan detection uses apm.lock via directory name matching instead. - - Copies: - - SKILL.md (required) - - scripts/ (optional) - - references/ (optional) - - assets/ (optional) - - Any other subdirectories the package contains - - Args: - package_info: PackageInfo object with package metadata - project_root: Root directory of the project - source_skill_md: Path to the source SKILL.md file - - Returns: - SkillIntegrationResult: Results of the integration operation - """ - self.init_link_resolver(package_info, project_root) - package_path = package_info.install_path - - # Use the source folder name as the skill name - # e.g., apm_modules/ComposioHQ/awesome-claude-skills/mcp-builder -> mcp-builder - raw_skill_name = package_path.name - - # Validate skill name per agentskills.io spec - is_valid, error_msg = validate_skill_name(raw_skill_name) - if is_valid: - skill_name = raw_skill_name - else: - # Normalize the name if validation fails - skill_name = normalize_skill_name(raw_skill_name) - if diagnostics is not None: - diagnostics.warn( - f"Skill name '{raw_skill_name}' normalized to '{skill_name}' ({error_msg})", - package=raw_skill_name, - ) - elif logger: - logger.warning( - f"Skill name '{raw_skill_name}' normalized to '{skill_name}' ({error_msg})" - ) - else: - try: - from apm_cli.utils.console import _rich_warning - - _rich_warning( - f"Skill name '{raw_skill_name}' normalized to '{skill_name}' ({error_msg})" - ) - except ImportError: - pass # CLI not available in tests - - # Deploy to all active targets that support skills. - # When *targets* is provided (from --target), use it directly. - # Otherwise auto-detect with copilot as the fallback. 
- if targets is None: - from apm_cli.integration.targets import active_targets - - targets = active_targets(project_root) - skill_created = False - skill_updated = False - files_copied = 0 - all_target_paths: list[Path] = [] - primary_skill_md: Path | None = None - - # Read lockfile once and derive both maps in a single pass. - owned_by, lockfile_native_owners = self._build_ownership_maps(project_root) - sub_skills_dir = package_path / ".apm" / "skills" - - # Full unique key of the package currently being installed. - dep_ref = package_info.dependency_ref - current_key: str | None = dep_ref.get_unique_key() if dep_ref is not None else None - - seen_skill_dirs: set[Path] = set() - - for idx, target in enumerate(targets): - if not target.supports("skills"): - continue - - is_primary = idx == 0 # first active target owns diagnostics - skills_mapping = target.primitives["skills"] - # Dynamic-root targets (cowork): use resolved_deploy_root. - if target.resolved_deploy_root is not None: - target_skill_dir = target.resolved_deploy_root / skill_name - else: - effective_root = skills_mapping.deploy_root or target.root_dir - target_skill_dir = project_root / effective_root / "skills" / skill_name - - # Security: validate name + containment + symlink rejection. - from apm_cli.utils.path_security import ( - PathTraversalError, - ensure_path_within, - validate_path_segments, - ) - - validate_path_segments(skill_name, context="skill name") - if target_skill_dir.is_symlink(): - raise PathTraversalError( - f"Skill destination {target_skill_dir} is a symlink -- refusing to deploy" - ) - if target.resolved_deploy_root is None: - ensure_path_within(target_skill_dir, project_root / effective_root / "skills") - - # Dedup: skip if same resolved path already deployed. 
- resolved = target_skill_dir.resolve() - if resolved in seen_skill_dirs: - if logger: - logger.progress( - f"{target_skill_dir} -- already deployed, skipping for {target.name}", - symbol="info", - ) - continue - seen_skill_dirs.add(resolved) - - if is_primary: - skill_created = not target_skill_dir.exists() - skill_updated = not skill_created - primary_skill_md = target_skill_dir / "SKILL.md" - - if target_skill_dir.exists(): - if is_primary: - # Check both the lockfile (previous runs) and the in-memory session - # map (current run) so that same-manifest collisions are caught even - # before the lockfile has been written for this run. - prev_owner = lockfile_native_owners.get( - skill_name - ) or self._native_skill_session_owners.get(skill_name) - is_self_overwrite = prev_owner is not None and prev_owner == current_key - if prev_owner is not None and not is_self_overwrite: - try: - rel_prefix = target_skill_dir.parent.relative_to( - project_root - ).as_posix() - except ValueError: - # Dynamic-root targets (cowork): directory is - # outside the project tree. - rel_prefix = "skills" - rel_path = f"{rel_prefix}/{skill_name}" - # Issue 1: package= should identify the package causing the - # collision (current_key), not the skill name, so render_summary() - # groups diagnostics by the package responsible. - # Issue 2: message must tell the user what to do ("So What?" test). - detail = ( - f"Skill '{skill_name}' from '{current_key}' replaced " - f"'{prev_owner}' -- remove one package to avoid this" - ) - if diagnostics is not None: - diagnostics.overwrite( - path=rel_path, - package=current_key or skill_name, - detail=detail, - ) - elif logger: - logger.warning(detail) - else: - # Reached when called without diagnostics or logger (e.g. uninstall sync). 
- from apm_cli.utils.console import _rich_warning - - _rich_warning(detail) - shutil.rmtree(target_skill_dir) - - target_skill_dir.parent.mkdir(parents=True, exist_ok=True) - from apm_cli.security.gate import ignore_non_content - - _apm_filter = shutil.ignore_patterns(".apm") - - def _ignore_non_content_and_apm(directory, contents): - return list( - set(ignore_non_content(directory, contents)) - | set(_apm_filter(directory, contents)) # noqa: B023 - ) - - shutil.copytree(package_path, target_skill_dir, ignore=_ignore_non_content_and_apm) - self._resolve_markdown_links_in_skill_bundle(package_path, target_skill_dir) - all_target_paths.append(target_skill_dir) - - if is_primary: - files_copied = sum(1 for _ in target_skill_dir.rglob("*") if _.is_file()) - - # Promote sub-skills for this target - if target.resolved_deploy_root is not None: - target_skills_root = target.resolved_deploy_root - else: - target_skills_root = project_root / effective_root / "skills" - _, sub_deployed = self._promote_sub_skills( - sub_skills_dir, - target_skills_root, - skill_name, - warn=is_primary, - owned_by=owned_by if is_primary else None, - diagnostics=diagnostics if is_primary else None, - managed_files=managed_files if is_primary else None, - force=force, - project_root=project_root, - logger=logger if is_primary else None, - link_rewriter=self, - ) - all_target_paths.extend(sub_deployed) - - # Record ownership in the session map so subsequent packages installed in - # the same run can detect a collision even before the lockfile is written. 
- if current_key is not None: - self._native_skill_session_owners[skill_name] = current_key - - # Count unique sub-skills from primary target only - primary_root = project_root / ".github" / "skills" - sub_skills_count = sum( - 1 for p in all_target_paths if p.parent == primary_root and p.name != skill_name + """Copy a native skill to all active targets.""" + fields = _skill_deploy._integrate_native_skill( + self, + package_info, + project_root, + source_skill_md, + diagnostics=diagnostics, + managed_files=managed_files, + force=force, + logger=logger, + targets=targets, ) - return SkillIntegrationResult( - skill_created=skill_created, - skill_updated=skill_updated, + skill_created=fields["skill_created"], + skill_updated=fields["skill_updated"], skill_skipped=False, - skill_path=primary_skill_md, - references_copied=files_copied, + skill_path=fields["primary_skill_md"], + references_copied=fields["files_copied"], links_resolved=0, - sub_skills_promoted=sub_skills_count, - target_paths=all_target_paths, + sub_skills_promoted=fields["sub_skills_promoted"], + target_paths=fields["target_paths"], ) def _integrate_skill_bundle( self, - package_info, + package_info: Any, project_root: Path, skills_dir: Path, - diagnostics=None, - managed_files=None, + diagnostics: Any = None, + managed_files: set[str] | None = None, force: bool = False, - logger=None, - targets=None, - skill_subset=None, + logger: Any = None, + targets: Any = None, + skill_subset: Any = None, ) -> SkillIntegrationResult: - """Promote every skill in a SKILL_BUNDLE's top-level skills/ directory. - - Reuses the same promotion logic as _promote_sub_skills but sources - from package_root/skills/ instead of .apm/skills/. Each nested - skill directory becomes a top-level skill in every target. - - Args: - package_info: PackageInfo with package metadata. - project_root: Root directory of the project. - skills_dir: The package's skills/ directory. - diagnostics: Optional DiagnosticCollector. 
- managed_files: Set of managed file paths. - force: Whether to overwrite locally-authored files. - logger: Optional InstallLogger. - targets: Optional explicit list of TargetProfile objects. - skill_subset: Optional tuple of skill names to install (None = all). - - Returns: - SkillIntegrationResult with all promoted skills. - """ - self.init_link_resolver(package_info, project_root) - if targets is None: - from apm_cli.integration.targets import active_targets - - targets = active_targets(project_root) - - parent_name = package_info.install_path.name - owned_by, lockfile_native_owners = self._build_ownership_maps(project_root) # noqa: RUF059 - - total_promoted = 0 - all_deployed: list[Path] = [] - any_created = False - seen_skill_dirs: set[Path] = set() - - # Convert skill_subset tuple to a set for O(1) lookup - _name_filter = set(skill_subset) if skill_subset else None - - for idx, target in enumerate(targets): - if not target.supports("skills"): - continue - - is_primary = idx == 0 - skills_mapping = target.primitives["skills"] - effective_root = skills_mapping.deploy_root or target.root_dir - target_skills_root = project_root / effective_root / "skills" - - # Dedup: skip if same resolved skills root already processed. 
- resolved_root = target_skills_root.resolve() - if resolved_root in seen_skill_dirs: - if logger: - logger.progress( - f"{target_skills_root} -- already deployed, skipping for {target.name}", - symbol="info", - ) - continue - seen_skill_dirs.add(resolved_root) - - target_skills_root.mkdir(parents=True, exist_ok=True) - - n, deployed = self._promote_sub_skills( - skills_dir, - target_skills_root, - parent_name, - warn=is_primary, - owned_by=owned_by if is_primary else None, - diagnostics=diagnostics if is_primary else None, - managed_files=managed_files if is_primary else None, - force=force, - project_root=project_root, - logger=logger if is_primary else None, - name_filter=_name_filter, - link_rewriter=self, - ) - if is_primary: - total_promoted = n - if n > 0: - any_created = True - all_deployed.extend(deployed) - - return SkillIntegrationResult( - skill_created=any_created, - skill_updated=False, - skill_skipped=False, - skill_path=None, - references_copied=0, - links_resolved=0, - sub_skills_promoted=total_promoted, - target_paths=all_deployed, + """Promote every skill in a skill bundle's top-level skills directory.""" + fields = _skill_deploy._integrate_skill_bundle( + self, + package_info, + project_root, + skills_dir, + diagnostics=diagnostics, + managed_files=managed_files, + force=force, + logger=logger, + targets=targets, + skill_subset=skill_subset, ) + return SkillIntegrationResult(**fields) def integrate_package_skill( self, - package_info, + package_info: Any, project_root: Path, - diagnostics=None, - managed_files=None, + diagnostics: Any = None, + managed_files: set[str] | None = None, force: bool = False, - logger=None, - targets=None, - skill_subset=None, - scope=None, - policy=None, + logger: Any = None, + targets: Any = None, + skill_subset: Any = None, + scope: Any = None, + policy: Any = None, ) -> SkillIntegrationResult: - """Integrate a package's skill into all active target directories. 
- - Copies native skills (packages with SKILL.md at root) to every active - target that supports skills (e.g. .github/skills/, .claude/skills/, - .opencode/skills/). Also promotes any sub-skills from .apm/skills/. - - When *targets* is provided (e.g. from ``--target cursor``), only those - targets are considered. Otherwise falls back to ``active_targets()``. - - Packages without SKILL.md at root are not installed as skills -- only their - sub-skills (if any) are promoted. - - Args: - package_info: PackageInfo object with package metadata - project_root: Root directory of the project - targets: Optional explicit list of TargetProfile objects. - - Returns: - SkillIntegrationResult: Results of the integration operation - """ - # Check if package type allows skill installation (T4 routing) - # SKILL and HYBRID -> install as skill - # INSTRUCTIONS and PROMPTS -> skip skill installation - if not should_install_skill(package_info): - # Even non-skill packages may ship sub-skills under .apm/skills/. - # Promote them so Copilot can discover them independently. 
- sub_skills_count, sub_deployed = self._promote_sub_skills_standalone( - package_info, - project_root, - diagnostics=diagnostics, - managed_files=managed_files, - force=force, - logger=logger, - targets=targets, - ) - return SkillIntegrationResult( - skill_created=False, - skill_updated=False, - skill_skipped=True, - skill_path=None, - references_copied=0, - links_resolved=0, - sub_skills_promoted=sub_skills_count, - target_paths=sub_deployed, - ) - - # Skip virtual FILE packages - they're individual files, not full packages - # Multiple virtual files from the same repo would collide on skill name - # BUT: subdirectory packages (like Claude Skills) SHOULD generate skills - if package_info.dependency_ref and package_info.dependency_ref.is_virtual: - # Allow subdirectory packages through - they are complete skill packages - if not package_info.dependency_ref.is_virtual_subdirectory(): - return SkillIntegrationResult( - skill_created=False, - skill_updated=False, - skill_skipped=True, - skill_path=None, - references_copied=0, - links_resolved=0, - ) - - package_path = package_info.install_path - - # MARKETPLACE_PLUGIN: deploy bin/ executables + plugin manifest BEFORE - # skill routing. bin/ deployment is orthogonal to whether the plugin - # also ships a root SKILL.md or a skills/ bundle, so it must run for - # every plugin -- not only the no-skill fallback. See issue #1544. 
- bin_paths: list[Path] = [] - bin_skip_reason: str | None = None - from apm_cli.models.apm_package import PackageType as _PackageType - - if package_info.package_type == _PackageType.MARKETPLACE_PLUGIN: - bin_paths, bin_skip_reason = self._deploy_plugin_bin( - package_info, - project_root, - targets, - scope=scope, - policy=policy, - force=force, - logger=logger, - ) - - # Check if this is a native Skill (already has SKILL.md at root) - source_skill_md = package_path / "SKILL.md" - if source_skill_md.exists(): - if skill_subset: - from apm_cli.utils.console import _rich_warning - - _rich_warning( - f"--skill filter ignored for '{package_info.install_path.name}': " - "package is a single CLAUDE_SKILL, not a SKILL_BUNDLE." - ) - return self._merge_bin_paths( - self._integrate_native_skill( - package_info, - project_root, - source_skill_md, - diagnostics=diagnostics, - managed_files=managed_files, - force=force, - logger=logger, - targets=targets, - ), - bin_paths, - bin_skip_reason, - ) - - # SKILL_BUNDLE: promote skills from root-level skills/ directory. - root_skills_dir = package_path / "skills" - if root_skills_dir.is_dir() and any( - (d / "SKILL.md").exists() for d in root_skills_dir.iterdir() if d.is_dir() - ): - return self._merge_bin_paths( - self._integrate_skill_bundle( - package_info, - project_root, - root_skills_dir, - diagnostics=diagnostics, - managed_files=managed_files, - force=force, - logger=logger, - targets=targets, - skill_subset=skill_subset, - ), - bin_paths, - bin_skip_reason, - ) - - # No SKILL.md at root -- not a skill package. - # Still promote any sub-skills shipped under .apm/skills/. 
- sub_skills_count, sub_deployed = self._promote_sub_skills_standalone( - package_info, - project_root, + """Integrate a package's skill into all active target directories.""" + context = _skill_deploy.PackageSkillContext( diagnostics=diagnostics, managed_files=managed_files, force=force, logger=logger, targets=targets, + skill_subset=skill_subset, + scope=scope, + policy=policy, + should_install_fn=should_install_skill, + result_cls=SkillIntegrationResult, ) - return self._merge_bin_paths( - SkillIntegrationResult( - skill_created=False, - skill_updated=False, - skill_skipped=True, - skill_path=None, - references_copied=0, - links_resolved=0, - sub_skills_promoted=sub_skills_count, - target_paths=sub_deployed, - ), - bin_paths, - bin_skip_reason, - ) + return _skill_deploy.integrate_package_skill(self, package_info, project_root, context) @staticmethod def _merge_bin_paths( @@ -1333,20 +321,10 @@ def _merge_bin_paths( bin_paths: list[Path], skip_reason: str | None = None, ) -> SkillIntegrationResult: - """Fold deployed plugin bin/manifest paths into a skill result. - - Pure: returns a NEW result via ``dataclasses.replace`` rather than - mutating the argument, so callers never observe surprise in-place - edits. ``skill_created`` is intentionally left untouched -- deploying - executables is not the same as creating a skill, so reporting and - sync semantics stay honest. When bins were deployed the result is no - longer "skipped" (work happened) and the paths are tracked for the - lockfile / uninstall manifest. ``skip_reason`` records why a plugin - that ships bin/ was NOT deployed, for the install layer to surface. 
- """ + """Fold deployed plugin bin and manifest paths into a skill result.""" if not bin_paths and skip_reason is None: return result - updates: dict = {} + updates: dict[str, Any] = {} if bin_paths: updates["bin_deployed"] = len(bin_paths) updates["skill_skipped"] = False @@ -1357,119 +335,30 @@ def _merge_bin_paths( def _deploy_plugin_bin( self, - package_info, + package_info: Any, project_root: Path, - targets, - scope=None, - policy=None, + targets: Any, + scope: Any = None, + policy: Any = None, force: bool = False, - logger=None, + logger: Any = None, ) -> tuple[list[Path], str | None]: - """Deploy bin/ executables and plugin manifest for a MARKETPLACE_PLUGIN. - - Only activates when ALL of: - - The package has a bin/ directory - - At least one Claude target that supports skills is active - - scope is InstallScope.USER (bin/ deploy is user-scope only, v1) - - policy does not deny the package - - This realizes Claude Code's "skills-directory plugin" contract: a folder - under a skills directory containing ``.claude-plugin/plugin.json`` is - loaded as ``@skills-dir`` and its root ``bin/`` is added to the - Bash tool PATH. The contract is Claude-specific by design; other - harnesses have no equivalent, so only Claude targets are considered. - - Each binary is made executable (user-only +x, stripping group/other - execute bits) on POSIX systems. The deployed root is user-scoped - (~/.claude/skills/), so tighter-than-0o755 permissions are correct. - - Returns ``(deployed_paths, skip_reason)``. ``skip_reason`` is non-None - ONLY when the package ships a bin/ but it could not be deployed for an - actionable reason ("project_scope", "no_claude_target"), so the install - layer can surface a hint. Policy-deny and "no bin/ at all" return - ``None`` -- they are intentional, not traps. 
- """ - from apm_cli.core.scope import InstallScope - from apm_cli.utils.path_security import validate_path_segments - - bin_dir = package_info.install_path / "bin" - if not bin_dir.is_dir(): - return [], None - - # The package ships executables -- from here a non-deploy is a - # reportable skip, not a silent no-op. - if scope is not InstallScope.USER: - if logger and scope is InstallScope.PROJECT: - logger.progress( - "bin/ deploy is user-scope only; skipping for project-scope install", - symbol="info", - ) - return [], "project_scope" - - if self._bin_deploy_denied(package_info, policy, logger): - return [], None - - if targets is None: - from apm_cli.integration.targets import active_targets - - targets = active_targets(project_root) - - # Claude-specific contract: only Claude targets that support skills. - claude_targets = [t for t in targets if t.name == "claude" and t.supports("skills")] - if not claude_targets: - if logger: - logger.progress( - "bin/ present but no active Claude skills target; skipping bin deploy for " - f"{package_info.get_canonical_dependency_string()}", - symbol="warning", - ) - return [], "no_claude_target" - - skill_name = package_info.install_path.name - validate_path_segments(skill_name, context="plugin skill name") - deployed: list[Path] = [] - - for target in claude_targets: - effective_root = target.primitives["skills"].deploy_root or target.root_dir - target_root_dir = project_root / target.root_dir - if not target.auto_create and not target_root_dir.is_dir(): - continue - - skill_base = project_root / effective_root / "skills" / skill_name - rel_prefix = f"{effective_root}/skills/{skill_name}" - deployed.extend(self._deploy_bin_files(bin_dir, skill_base, rel_prefix, force, logger)) - manifest = self._deploy_plugin_manifest( - package_info.install_path, skill_base, rel_prefix, force, logger - ) - if manifest is not None: - deployed.append(manifest) - - return deployed, None + """Deploy bin executables and plugin manifest for a 
marketplace plugin.""" + return _skill_deploy._deploy_plugin_bin( + self, + package_info, + project_root, + targets, + scope=scope, + policy=policy, + force=force, + logger=logger, + ) @staticmethod - def _bin_deploy_denied(package_info, policy, logger) -> bool: - """Return True when policy opts the package out of bin/ deployment.""" - if policy is None: - return False - bd_policy = policy.bin_deploy - if bd_policy is None: - return False - canonical = package_info.get_canonical_dependency_string() - if bd_policy.deny_all: - if logger: - logger.progress( - f"bin_deploy.deny_all: skipping bin deploy for {canonical}", - symbol="info", - ) - return True - if canonical in bd_policy.deny: - if logger: - logger.progress( - f"bin_deploy.deny: skipping bin deploy for {canonical}", - symbol="info", - ) - return True - return False + def _bin_deploy_denied(package_info: Any, policy: Any, logger: Any) -> bool: + """Return True when policy opts the package out of bin deployment.""" + return _skill_deploy._bin_deploy_denied(package_info, policy, logger) def _deploy_bin_files( self, @@ -1477,31 +366,10 @@ def _deploy_bin_files( skill_base: Path, rel_prefix: str, force: bool, - logger, + logger: Any, ) -> list[Path]: - """Copy bin/ executables into ``skill_base/bin`` (chmod +x on POSIX).""" - from apm_cli.utils.path_security import ensure_path_within - - dest_bin = skill_base / "bin" - dest_bin.mkdir(parents=True, exist_ok=True) - deployed: list[Path] = [] - for src_file in bin_dir.iterdir(): - # Reject symlinks -- a malicious package could point a symlink - # at an arbitrary file outside the sandbox. 
- if src_file.is_symlink() or not src_file.is_file(): - continue - dest_file = dest_bin / src_file.name - ensure_path_within(dest_file, dest_bin) - self._copy_plugin_file( - src_file, - dest_file, - force=force, - make_executable=True, - logger=logger, - rel_label=f"{rel_prefix}/bin/{src_file.name}", - ) - deployed.append(dest_file) - return deployed + """Copy bin executables into a deployed skill directory.""" + return _skill_deploy._deploy_bin_files(bin_dir, skill_base, rel_prefix, force, logger) def _deploy_plugin_manifest( self, @@ -1509,23 +377,12 @@ def _deploy_plugin_manifest( skill_base: Path, rel_prefix: str, force: bool, - logger, + logger: Any, ) -> Path | None: - """Copy ``.claude-plugin/plugin.json`` next to the deployed bin/.""" - plugin_manifest = package_path / ".claude-plugin" / "plugin.json" - if plugin_manifest.is_symlink() or not plugin_manifest.is_file(): - return None - dest_manifest = skill_base / ".claude-plugin" / "plugin.json" - dest_manifest.parent.mkdir(parents=True, exist_ok=True) - self._copy_plugin_file( - plugin_manifest, - dest_manifest, - force=force, - make_executable=False, - logger=logger, - rel_label=f"{rel_prefix}/.claude-plugin/plugin.json", + """Copy .claude-plugin/plugin.json next to the deployed bin directory.""" + return _skill_deploy._deploy_plugin_manifest( + package_path, skill_base, rel_prefix, force, logger ) - return dest_manifest @staticmethod def _copy_plugin_file( @@ -1534,294 +391,47 @@ def _copy_plugin_file( *, force: bool, make_executable: bool, - logger, + logger: Any, rel_label: str, ) -> None: - """Hash-gated copy of one plugin file, optionally marking it executable. - - Skips the copy when an identical file already exists (unless *force*), - keeping repeated installs quiet and idempotent. - - When *make_executable* is True, only the owner (user) execute bit is - set; group and other execute bits are explicitly cleared. 
Deployed - files live under ~/.claude/skills/ which is user-scoped, so there is - no reason to grant group/other execute access regardless of what the - source package shipped. - """ - import os - import stat - - skip_copy = False - if dest_file.exists() and not force: - src_hash = hashlib.sha256(src_file.read_bytes()).hexdigest() - dst_hash = hashlib.sha256(dest_file.read_bytes()).hexdigest() - skip_copy = src_hash == dst_hash - - if not skip_copy: - shutil.copy2(src_file, dest_file) - - if make_executable and os.name == "posix": - current = dest_file.stat().st_mode - # User-only execute: set S_IXUSR, clear group and other execute bits. - # Runs for both fresh copies and idempotent re-installs so that files - # previously deployed by older APM versions are hardened in-place. - dest_file.chmod((current & ~(stat.S_IXGRP | stat.S_IXOTH)) | stat.S_IXUSR) - - if not skip_copy and logger: - logger.progress(f"deployed {src_file.name} -> {rel_label}", symbol="check") + """Hash-gated copy of one plugin file, optionally marking it executable.""" + _skill_deploy._copy_plugin_file( + src_file, + dest_file, + force=force, + make_executable=make_executable, + logger=logger, + rel_label=rel_label, + ) def sync_integration( self, - apm_package, + apm_package: Any, project_root: Path, - managed_files: set = None, # noqa: RUF013 - targets=None, + managed_files: set[str] | None = None, + targets: Any = None, ) -> dict[str, int]: - """Sync skill directories with currently installed packages. - - Derives skill prefixes dynamically from *targets* (or - ``KNOWN_TARGETS``) so user-scope paths like ``.copilot/skills/`` - and ``.config/opencode/skills/`` are handled correctly. - - When *managed_files* is provided, only removes skill directories - whose paths appear in the set. Otherwise falls back to - npm-style orphan detection (derives expected names from installed - dependencies). 
- - Args: - apm_package: APMPackage with current dependencies - project_root: Root directory of the project - managed_files: Set of relative paths known to be APM-managed - targets: Optional list of (scope-resolved) TargetProfile objects. - When ``None``, uses ``KNOWN_TARGETS``. - - Returns: - Dict with cleanup statistics - """ - from apm_cli.integration.targets import KNOWN_TARGETS - - source = targets if targets is not None else list(KNOWN_TARGETS.values()) - - stats = {"files_removed": 0, "errors": 0} - - # Build the set of valid skill prefixes from targets - skill_prefixes: list[str] = [] - for t in source: - if not t.supports("skills"): - continue - # Dynamic-root targets (cowork) use cowork:// URI prefix. - if t.user_root_resolver is not None: - from apm_cli.integration.copilot_cowork_paths import COWORK_LOCKFILE_PREFIX - - if COWORK_LOCKFILE_PREFIX not in skill_prefixes: - skill_prefixes.append(COWORK_LOCKFILE_PREFIX) - continue - sm = t.primitives["skills"] - effective_root = sm.deploy_root or t.root_dir - skill_prefixes.append(f"{effective_root}/skills/") - skill_prefix_tuple = tuple(skill_prefixes) - - if managed_files is not None: - # Manifest-based removal -- only remove tracked skill directories - project_root_resolved = project_root.resolve() - - # Lazy-resolve cowork root at most once per invocation - # (mirrors the pattern in cleanup.py and sync_remove_files). - _cowork_root_resolved: bool = False - _cowork_root_cached: Path | None = None - _cowork_skipped: int = 0 - - for rel_path in managed_files: - if not rel_path.startswith(skill_prefix_tuple): - continue - if ".." 
in rel_path: - continue - - # ── Cowork:// paths ────────────────────────────────── - from apm_cli.integration.copilot_cowork_paths import COWORK_URI_SCHEME - - if rel_path.startswith(COWORK_URI_SCHEME): - try: - if not _cowork_root_resolved: - from apm_cli.integration.copilot_cowork_paths import ( - resolve_copilot_cowork_skills_dir, - ) - - _cowork_root_cached = resolve_copilot_cowork_skills_dir() - _cowork_root_resolved = True - if _cowork_root_cached is None: - _cowork_skipped += 1 - continue - from apm_cli.integration.copilot_cowork_paths import from_lockfile_path - - target = from_lockfile_path(rel_path, _cowork_root_cached) - except Exception: - stats["errors"] += 1 - continue - else: - target = project_root / rel_path - if not str(target.resolve()).startswith(str(project_root_resolved)): - continue - - if not target.exists(): - continue - - try: - if target.is_dir(): - shutil.rmtree(target) - else: - target.unlink() - stats["files_removed"] += 1 - except Exception: - stats["errors"] += 1 - - # One-time warning when cowork entries were skipped - # because the OneDrive path is unavailable. 
- if _cowork_skipped > 0: - from apm_cli.utils.console import _rich_warning - - _rich_warning( - f"Cowork: skipping {_cowork_skipped} skill " - f"{'entry' if _cowork_skipped == 1 else 'entries'}" - " -- OneDrive path not detected.\n" - "Run: apm config set copilot-cowork-skills-dir " - "(or set APM_COPILOT_COWORK_SKILLS_DIR)\n" - "to clean up these entries on the next install/uninstall.", - symbol="warning", - ) - - return stats - - # Legacy fallback: npm-style orphan detection - # Build set of expected skill directory names from installed packages - installed_skill_names = set() - for dep in apm_package.get_apm_dependencies(): - raw_name = dep.repo_url.split("/")[-1] - if dep.is_virtual and dep.virtual_path: - raw_name = dep.virtual_path.split("/")[-1] - is_valid, _ = validate_skill_name(raw_name) - skill_name = raw_name if is_valid else normalize_skill_name(raw_name) - installed_skill_names.add(skill_name) - - # Also include promoted sub-skills from installed packages - install_path = dep.get_install_path(project_root / "apm_modules") - sub_skills_dir = install_path / ".apm" / "skills" - if sub_skills_dir.is_dir(): - for sub_skill_path in sub_skills_dir.iterdir(): - if sub_skill_path.is_dir() and (sub_skill_path / "SKILL.md").exists(): - raw_sub = sub_skill_path.name - is_valid, _ = validate_skill_name(raw_sub) - installed_skill_names.add( - raw_sub if is_valid else normalize_skill_name(raw_sub) - ) - - # Clean all target skill directories dynamically - seen_cleanup_dirs: set[Path] = set() - for t in source: - if not t.supports("skills"): - continue - sm = t.primitives["skills"] - effective_root = sm.deploy_root or t.root_dir - - # Special guard for cross-tool deploy_root (.agents/) - # Only clean if the owning target dir exists - if sm.deploy_root: - if not (project_root / t.root_dir).is_dir(): - continue - - skills_dir = project_root / effective_root / "skills" - - # Dedup: skip if same resolved skills dir already cleaned. 
- resolved_skills = skills_dir.resolve() - if resolved_skills in seen_cleanup_dirs: - import logging - - logging.getLogger(__name__).debug( - "%s -- already processed, skipping cleanup for %s", skills_dir, t.name - ) - continue - seen_cleanup_dirs.add(resolved_skills) - - if skills_dir.exists(): - result = self._clean_orphaned_skills( - skills_dir, installed_skill_names, project_root=project_root - ) - stats["files_removed"] += result["files_removed"] - stats["errors"] += result["errors"] - - return stats + """Sync skill directories with currently installed packages.""" + return _skill_deploy.sync_integration( + self, apm_package, project_root, managed_files, targets + ) def _clean_orphaned_skills( self, skills_dir: Path, - installed_skill_names: set, + installed_skill_names: set[str], *, project_root: Path | None = None, ) -> dict[str, int]: - """Clean orphaned skills from a skills directory. - - Uses npm-style approach: any skill directory not matching an installed - package name is considered orphaned and removed. - - For the cross-client ``.agents/skills/`` directory, only removes skill - directories that appear in the lockfile's ``deployed_files`` to avoid - deleting foreign skills placed by other tools (Codex CLI, manual). - - Args: - skills_dir: Path to skills directory (.github/skills/, .claude/skills/, etc.) - installed_skill_names: Set of expected skill directory names - project_root: Project root for lockfile-based ownership check. - - Returns: - Dict with cleanup statistics - """ - files_removed = 0 - errors = 0 - - # For .agents/skills/: only delete skills that APM owns (appear in lockfile). 
- is_agents_dir = skills_dir.parent.name == ".agents" - lockfile_owned_skills: set[str] | None = None - if is_agents_dir and project_root is not None: - lockfile_owned_skills = self._get_lockfile_owned_agent_skills(project_root) - - for skill_subdir in skills_dir.iterdir(): - if skill_subdir.is_dir(): - if skill_subdir.name not in installed_skill_names: - # Ownership check: skip foreign skills in .agents/skills/. - if lockfile_owned_skills is not None: - if skill_subdir.name not in lockfile_owned_skills: - continue - try: - shutil.rmtree(skill_subdir) - files_removed += 1 - except Exception: - errors += 1 - - return {"files_removed": files_removed, "errors": errors} + """Clean orphaned skills from a skills directory.""" + return _skill_deploy._clean_orphaned_skills( + skills_dir, + installed_skill_names, + project_root=project_root, + get_lockfile_owned_fn=self._get_lockfile_owned_agent_skills, + ) @staticmethod def _get_lockfile_owned_agent_skills(project_root: Path) -> set[str]: - """Return the set of skill names under ``.agents/skills/`` in the lockfile. - - Used by ``_clean_orphaned_skills`` to avoid deleting foreign skills - in the cross-client ``.agents/`` directory. 
- """ - owned: set[str] = set() - try: - from apm_cli.deps.lockfile import LockFile, get_lockfile_path - - lockfile = LockFile.read(get_lockfile_path(project_root)) - if lockfile and lockfile.dependencies: - for dep in lockfile.dependencies.values(): - for f in dep.deployed_files: - if f.startswith(".agents/skills/"): - parts = f[len(".agents/skills/") :].split("/") - if parts and parts[0]: - owned.add(parts[0]) - except (FileNotFoundError, OSError, KeyError, ValueError, TypeError, AttributeError) as exc: - import logging - - logging.getLogger(__name__).debug( - "Could not read lockfile for ownership check: %s", exc - ) - return owned + """Return skill names under .agents/skills in the lockfile.""" + return _skill_deploy._get_lockfile_owned_agent_skills(project_root) diff --git a/src/apm_cli/integration/skill_naming.py b/src/apm_cli/integration/skill_naming.py new file mode 100644 index 000000000..72df4ada0 --- /dev/null +++ b/src/apm_cli/integration/skill_naming.py @@ -0,0 +1,49 @@ +"""Skill naming helpers for APM skill integration.""" + +import logging +import re +from typing import Any + +_log = logging.getLogger("apm_cli.integration.skill_integrator") + + +def to_hyphen_case(name: str) -> str: + """Convert a package name to hyphen-case for Claude Skills spec.""" + if "/" in name: + name = name.split("/")[-1] + + result = name.replace("_", "-").replace(" ", "-") + result = re.sub(r"([a-z])([A-Z])", r"\1-\2", result) + result = re.sub(r"[^a-z0-9-]", "", result.lower()) + result = re.sub(r"-+", "-", result) + result = result.strip("-") + return result[:64] + + +def _skill_name_char_error(name: str) -> str: + """Return the precise skill-name character validation error for *name*.""" + if any(c.isupper() for c in name): + return "Skill name must be lowercase (no uppercase letters)" + if "_" in name: + return "Skill name cannot contain underscores (use hyphens instead)" + if " " in name: + return "Skill name cannot contain spaces (use hyphens instead)" + 
invalid_chars = set(re.findall(r"[^a-z0-9-]", name)) + if invalid_chars: + return f"Skill name contains invalid characters: {', '.join(sorted(invalid_chars))}" + return "Skill name must be lowercase alphanumeric with hyphens only" + + +def normalize_skill_name(name: str) -> str: + """Convert any package name to a valid skill name per agentskills.io spec.""" + return to_hyphen_case(name) + + +def should_compile_instructions(package_info: Any) -> bool: + """Determine if package should compile to AGENTS.md/CLAUDE.md.""" + from apm_cli.models.apm_package import PackageContentType + + from .skill_integrator import get_effective_type + + effective_type = get_effective_type(package_info) + return effective_type in (PackageContentType.INSTRUCTIONS, PackageContentType.HYBRID) diff --git a/src/apm_cli/integration/skill_orchestrate.py b/src/apm_cli/integration/skill_orchestrate.py new file mode 100644 index 000000000..9f756dea8 --- /dev/null +++ b/src/apm_cli/integration/skill_orchestrate.py @@ -0,0 +1,135 @@ +"""Package-level skill integration orchestration.""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path +from typing import Any + + +@dataclass +class PackageSkillContext: + """Options and callbacks for package skill integration.""" + + diagnostics: Any = None + managed_files: set[str] | None = None + force: bool = False + logger: Any = None + targets: Any = None + skill_subset: Any = None + scope: Any = None + policy: Any = None + should_install_fn: Callable[[Any], bool] | None = None + result_cls: Any = None + + +def _skipped_result( + result_cls: Any, sub_skills_count: int = 0, target_paths: list[Path] | None = None +) -> Any: + """Build the standard skipped skill integration result.""" + return result_cls( + skill_created=False, + skill_updated=False, + skill_skipped=True, + skill_path=None, + references_copied=0, + links_resolved=0, + sub_skills_promoted=sub_skills_count, + 
target_paths=target_paths or [], + ) + + +def integrate_package_skill( + link_rewriter: Any, + package_info: Any, + project_root: Path, + context: PackageSkillContext, +) -> Any: + """Integrate a package's skill into all active target directories.""" + if context.should_install_fn is None or context.result_cls is None: + raise ValueError("PackageSkillContext requires should_install_fn and result_cls") + + if not context.should_install_fn(package_info): + sub_count, sub_deployed = link_rewriter._promote_sub_skills_standalone( + package_info, + project_root, + diagnostics=context.diagnostics, + managed_files=context.managed_files, + force=context.force, + logger=context.logger, + targets=context.targets, + ) + return _skipped_result(context.result_cls, sub_count, sub_deployed) + + if package_info.dependency_ref and package_info.dependency_ref.is_virtual: + if not package_info.dependency_ref.is_virtual_subdirectory(): + return _skipped_result(context.result_cls) + + package_path = package_info.install_path + bin_paths: list[Path] = [] + bin_skip_reason: str | None = None + from apm_cli.models.apm_package import PackageType as _PackageType + + if package_info.package_type == _PackageType.MARKETPLACE_PLUGIN: + bin_paths, bin_skip_reason = link_rewriter._deploy_plugin_bin( + package_info, + project_root, + context.targets, + scope=context.scope, + policy=context.policy, + force=context.force, + logger=context.logger, + ) + + source_skill_md = package_path / "SKILL.md" + if source_skill_md.exists(): + if context.skill_subset: + from apm_cli.utils.console import _rich_warning + + _rich_warning( + f"--skill filter ignored for '{package_info.install_path.name}': " + "package is a single CLAUDE_SKILL, not a SKILL_BUNDLE." 
+ ) + result = link_rewriter._integrate_native_skill( + package_info, + project_root, + source_skill_md, + diagnostics=context.diagnostics, + managed_files=context.managed_files, + force=context.force, + logger=context.logger, + targets=context.targets, + ) + return link_rewriter._merge_bin_paths(result, bin_paths, bin_skip_reason) + + root_skills_dir = package_path / "skills" + if root_skills_dir.is_dir() and any( + (directory / "SKILL.md").exists() + for directory in root_skills_dir.iterdir() + if directory.is_dir() + ): + result = link_rewriter._integrate_skill_bundle( + package_info, + project_root, + root_skills_dir, + diagnostics=context.diagnostics, + managed_files=context.managed_files, + force=context.force, + logger=context.logger, + targets=context.targets, + skill_subset=context.skill_subset, + ) + return link_rewriter._merge_bin_paths(result, bin_paths, bin_skip_reason) + + sub_count, sub_deployed = link_rewriter._promote_sub_skills_standalone( + package_info, + project_root, + diagnostics=context.diagnostics, + managed_files=context.managed_files, + force=context.force, + logger=context.logger, + targets=context.targets, + ) + result = _skipped_result(context.result_cls, sub_count, sub_deployed) + return link_rewriter._merge_bin_paths(result, bin_paths, bin_skip_reason) diff --git a/src/apm_cli/integration/skill_plugin.py b/src/apm_cli/integration/skill_plugin.py new file mode 100644 index 000000000..8514ede84 --- /dev/null +++ b/src/apm_cli/integration/skill_plugin.py @@ -0,0 +1,186 @@ +"""Plugin bin deployment helpers for skill integration.""" + +from __future__ import annotations + +import hashlib +import os +import shutil +import stat +from pathlib import Path +from typing import Any + + +def _bin_deploy_denied(package_info: Any, policy: Any, logger: Any) -> bool: + """Return True when policy opts the package out of bin deployment.""" + if policy is None: + return False + bd_policy = policy.bin_deploy + if bd_policy is None: + return False + 
canonical = package_info.get_canonical_dependency_string() + if bd_policy.deny_all: + if logger: + logger.progress( + f"bin_deploy.deny_all: skipping bin deploy for {canonical}", + symbol="info", + ) + return True + if canonical in bd_policy.deny: + if logger: + logger.progress( + f"bin_deploy.deny: skipping bin deploy for {canonical}", + symbol="info", + ) + return True + return False + + +def _deploy_plugin_bin( + link_rewriter: Any, + package_info: Any, + project_root: Path, + targets: Any, + *, + scope: Any = None, + policy: Any = None, + force: bool = False, + logger: Any = None, +) -> tuple[list[Path], str | None]: + """Deploy bin executables and plugin manifest for a marketplace plugin.""" + from apm_cli.core.scope import InstallScope + from apm_cli.utils.path_security import validate_path_segments + + bin_dir = package_info.install_path / "bin" + if not bin_dir.is_dir(): + return [], None + + if scope is not InstallScope.USER: + if logger and scope is InstallScope.PROJECT: + logger.progress( + "bin/ deploy is user-scope only; skipping for project-scope install", + symbol="info", + ) + return [], "project_scope" + + if link_rewriter._bin_deploy_denied(package_info, policy, logger): + return [], None + + if targets is None: + from apm_cli.integration.targets import active_targets + + targets = active_targets(project_root) + + claude_targets = [ + target for target in targets if target.name == "claude" and target.supports("skills") + ] + if not claude_targets: + if logger: + logger.progress( + "bin/ present but no active Claude skills target; skipping bin deploy for " + f"{package_info.get_canonical_dependency_string()}", + symbol="warning", + ) + return [], "no_claude_target" + + skill_name = package_info.install_path.name + validate_path_segments(skill_name, context="plugin skill name") + deployed: list[Path] = [] + for target in claude_targets: + effective_root = target.primitives["skills"].deploy_root or target.root_dir + target_root_dir = project_root / 
target.root_dir + if not target.auto_create and not target_root_dir.is_dir(): + continue + + skill_base = project_root / effective_root / "skills" / skill_name + rel_prefix = f"{effective_root}/skills/{skill_name}" + deployed.extend( + link_rewriter._deploy_bin_files(bin_dir, skill_base, rel_prefix, force, logger) + ) + manifest = link_rewriter._deploy_plugin_manifest( + package_info.install_path, skill_base, rel_prefix, force, logger + ) + if manifest is not None: + deployed.append(manifest) + + return deployed, None + + +def _deploy_bin_files( + bin_dir: Path, + skill_base: Path, + rel_prefix: str, + force: bool, + logger: Any, +) -> list[Path]: + """Copy bin executables into a deployed skill directory.""" + from apm_cli.utils.path_security import ensure_path_within + + dest_bin = skill_base / "bin" + dest_bin.mkdir(parents=True, exist_ok=True) + deployed: list[Path] = [] + for src_file in bin_dir.iterdir(): + if src_file.is_symlink() or not src_file.is_file(): + continue + dest_file = dest_bin / src_file.name + ensure_path_within(dest_file, dest_bin) + _copy_plugin_file( + src_file, + dest_file, + force=force, + make_executable=True, + logger=logger, + rel_label=f"{rel_prefix}/bin/{src_file.name}", + ) + deployed.append(dest_file) + return deployed + + +def _deploy_plugin_manifest( + package_path: Path, + skill_base: Path, + rel_prefix: str, + force: bool, + logger: Any, +) -> Path | None: + """Copy .claude-plugin/plugin.json next to the deployed bin directory.""" + plugin_manifest = package_path / ".claude-plugin" / "plugin.json" + if plugin_manifest.is_symlink() or not plugin_manifest.is_file(): + return None + dest_manifest = skill_base / ".claude-plugin" / "plugin.json" + dest_manifest.parent.mkdir(parents=True, exist_ok=True) + _copy_plugin_file( + plugin_manifest, + dest_manifest, + force=force, + make_executable=False, + logger=logger, + rel_label=f"{rel_prefix}/.claude-plugin/plugin.json", + ) + return dest_manifest + + +def _copy_plugin_file( + 
src_file: Path, + dest_file: Path, + *, + force: bool, + make_executable: bool, + logger: Any, + rel_label: str, +) -> None: + """Hash-gated copy of one plugin file, optionally marking it executable.""" + skip_copy = False + if dest_file.exists() and not force: + src_hash = hashlib.sha256(src_file.read_bytes()).hexdigest() + dst_hash = hashlib.sha256(dest_file.read_bytes()).hexdigest() + skip_copy = src_hash == dst_hash + + if not skip_copy: + shutil.copy2(src_file, dest_file) + + if make_executable and os.name == "posix": + current = dest_file.stat().st_mode + dest_file.chmod((current & ~(stat.S_IXGRP | stat.S_IXOTH)) | stat.S_IXUSR) + + if not skip_copy and logger: + logger.progress(f"deployed {src_file.name} -> {rel_label}", symbol="check") diff --git a/src/apm_cli/integration/skill_sync.py b/src/apm_cli/integration/skill_sync.py new file mode 100644 index 000000000..99349dcaa --- /dev/null +++ b/src/apm_cli/integration/skill_sync.py @@ -0,0 +1,236 @@ +"""Skill sync and cleanup helpers.""" + +from __future__ import annotations + +import logging +import shutil +from collections.abc import Callable +from pathlib import Path +from typing import Any + +from .skill_naming import normalize_skill_name + +_log = logging.getLogger("apm_cli.integration.skill_integrator") + + +def _validate_skill_name(name: str) -> tuple[bool, str]: + """Resolve the public validator lazily to avoid circular imports.""" + from .skill_integrator import validate_skill_name + + return validate_skill_name(name) + + +def _build_skill_prefixes(source: Any) -> tuple[str, ...]: + """Build the valid lockfile prefixes for skill targets.""" + skill_prefixes: list[str] = [] + for target in source: + if not target.supports("skills"): + continue + if target.user_root_resolver is not None: + from apm_cli.integration.copilot_cowork_paths import COWORK_LOCKFILE_PREFIX + + if COWORK_LOCKFILE_PREFIX not in skill_prefixes: + skill_prefixes.append(COWORK_LOCKFILE_PREFIX) + continue + skills_mapping = 
target.primitives["skills"] + effective_root = skills_mapping.deploy_root or target.root_dir + skill_prefixes.append(f"{effective_root}/skills/") + return tuple(skill_prefixes) + + +def _resolve_managed_skill_target( + rel_path: str, + project_root: Path, + project_root_resolved: Path, + cowork_state: dict[str, Any], +) -> Path | None: + """Resolve a managed skill lockfile path to a filesystem path.""" + from apm_cli.integration.copilot_cowork_paths import COWORK_URI_SCHEME + + if rel_path.startswith(COWORK_URI_SCHEME): + if not cowork_state["resolved"]: + from apm_cli.integration.copilot_cowork_paths import resolve_copilot_cowork_skills_dir + + cowork_state["root"] = resolve_copilot_cowork_skills_dir() + cowork_state["resolved"] = True + if cowork_state["root"] is None: + cowork_state["skipped"] += 1 + return None + from apm_cli.integration.copilot_cowork_paths import from_lockfile_path + + return from_lockfile_path(rel_path, cowork_state["root"]) + + target = project_root / rel_path + if not str(target.resolve()).startswith(str(project_root_resolved)): + return None + return target + + +def _sync_skills_managed_files( + managed_files: set[str], + project_root: Path, + skill_prefix_tuple: tuple[str, ...], + stats: dict[str, int], + source: Any, +) -> None: + """Remove managed skill paths from the deployment manifest.""" + project_root_resolved = project_root.resolve() + cowork_state: dict[str, Any] = {"resolved": False, "root": None, "skipped": 0} + for rel_path in managed_files: + if not rel_path.startswith(skill_prefix_tuple) or ".." 
in rel_path: + continue + try: + target = _resolve_managed_skill_target( + rel_path, project_root, project_root_resolved, cowork_state + ) + except Exception: + stats["errors"] += 1 + continue + if target is None or not target.exists(): + continue + try: + if target.is_dir(): + shutil.rmtree(target) + else: + target.unlink() + stats["files_removed"] += 1 + except Exception: + stats["errors"] += 1 + + if cowork_state["skipped"] > 0: + from apm_cli.utils.console import _rich_warning + + _rich_warning( + f"Cowork: skipping {cowork_state['skipped']} skill " + f"{'entry' if cowork_state['skipped'] == 1 else 'entries'}" + " -- OneDrive path not detected.\n" + "Run: apm config set copilot-cowork-skills-dir " + "(or set APM_COPILOT_COWORK_SKILLS_DIR)\n" + "to clean up these entries on the next install/uninstall.", + symbol="warning", + ) + + +def _installed_skill_names(apm_package: Any, project_root: Path) -> set[str]: + """Build expected skill directory names from installed packages.""" + installed_skill_names: set[str] = set() + for dep in apm_package.get_apm_dependencies(): + raw_name = dep.repo_url.split("/")[-1] + if dep.is_virtual and dep.virtual_path: + raw_name = dep.virtual_path.split("/")[-1] + is_valid, _ = _validate_skill_name(raw_name) + skill_name = raw_name if is_valid else normalize_skill_name(raw_name) + installed_skill_names.add(skill_name) + + install_path = dep.get_install_path(project_root / "apm_modules") + sub_skills_dir = install_path / ".apm" / "skills" + if sub_skills_dir.is_dir(): + for sub_skill_path in sub_skills_dir.iterdir(): + if sub_skill_path.is_dir() and (sub_skill_path / "SKILL.md").exists(): + raw_sub = sub_skill_path.name + is_valid, _ = _validate_skill_name(raw_sub) + installed_skill_names.add( + raw_sub if is_valid else normalize_skill_name(raw_sub) + ) + return installed_skill_names + + +def _sync_skills_legacy( + apm_package: Any, + project_root: Path, + source: Any, + stats: dict[str, int], + clean_fn: Callable[..., dict[str, 
int]], +) -> None: + """Run legacy npm-style orphan detection for skills.""" + installed_skill_names = _installed_skill_names(apm_package, project_root) + seen_cleanup_dirs: set[Path] = set() + for target in source: + if not target.supports("skills"): + continue + skills_mapping = target.primitives["skills"] + effective_root = skills_mapping.deploy_root or target.root_dir + if skills_mapping.deploy_root and not (project_root / target.root_dir).is_dir(): + continue + + skills_dir = project_root / effective_root / "skills" + resolved_skills = skills_dir.resolve() + if resolved_skills in seen_cleanup_dirs: + _log.debug("%s -- already processed, skipping cleanup for %s", skills_dir, target.name) + continue + seen_cleanup_dirs.add(resolved_skills) + + if skills_dir.exists(): + result = clean_fn(skills_dir, installed_skill_names, project_root=project_root) + stats["files_removed"] += result["files_removed"] + stats["errors"] += result["errors"] + + +def sync_integration( + link_rewriter: Any, + apm_package: Any, + project_root: Path, + managed_files: set[str] | None = None, + targets: Any = None, +) -> dict[str, int]: + """Sync skill directories with currently installed packages.""" + from apm_cli.integration.targets import KNOWN_TARGETS + + source = targets if targets is not None else list(KNOWN_TARGETS.values()) + stats = {"files_removed": 0, "errors": 0} + skill_prefix_tuple = _build_skill_prefixes(source) + if managed_files is not None: + _sync_skills_managed_files(managed_files, project_root, skill_prefix_tuple, stats, source) + return stats + _sync_skills_legacy( + apm_package, project_root, source, stats, link_rewriter._clean_orphaned_skills + ) + return stats + + +def _clean_orphaned_skills( + skills_dir: Path, + installed_skill_names: set[str], + *, + project_root: Path | None = None, + get_lockfile_owned_fn: Callable[[Path], set[str]] | None = None, +) -> dict[str, int]: + """Clean orphaned skills from a skills directory.""" + files_removed = 0 + errors = 0 + 
is_agents_dir = skills_dir.parent.name == ".agents" + lockfile_owned_skills: set[str] | None = None + if is_agents_dir and project_root is not None: + owner_fn = get_lockfile_owned_fn or _get_lockfile_owned_agent_skills + lockfile_owned_skills = owner_fn(project_root) + + for skill_subdir in skills_dir.iterdir(): + if not skill_subdir.is_dir() or skill_subdir.name in installed_skill_names: + continue + if lockfile_owned_skills is not None and skill_subdir.name not in lockfile_owned_skills: + continue + try: + shutil.rmtree(skill_subdir) + files_removed += 1 + except Exception: + errors += 1 + return {"files_removed": files_removed, "errors": errors} + + +def _get_lockfile_owned_agent_skills(project_root: Path) -> set[str]: + """Return skill names under .agents/skills in the lockfile.""" + owned: set[str] = set() + try: + from apm_cli.deps.lockfile import LockFile, get_lockfile_path + + lockfile = LockFile.read(get_lockfile_path(project_root)) + if lockfile and lockfile.dependencies: + for dep in lockfile.dependencies.values(): + for deployed_file in dep.deployed_files: + if deployed_file.startswith(".agents/skills/"): + parts = deployed_file[len(".agents/skills/") :].split("/") + if parts and parts[0]: + owned.add(parts[0]) + except (FileNotFoundError, OSError, KeyError, ValueError, TypeError, AttributeError) as exc: + _log.debug("Could not read lockfile for ownership check: %s", exc) + return owned From 73947a05682c76d6a8dc11c5a380e10d857f6ae9 Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 03:29:19 +0200 Subject: [PATCH 10/21] refactor(integration): split base_integrator I/O primitive, fix return gate Move the TOCTOU-safe _read_bytes_no_follow + _SymlinkRaceError into a cohesive base_integrator_io sibling (re-exported, so the original import/patch path is preserved). 
Collapse validate_deploy_path's tail to a single return (PLR0911 9->8) and reuse the existing partition_bucket_key helper in partition_managed_files instead of inlining the alias lookup twice. base_integrator.py 812->769 lines (under the 800 gate). Part of #1078 (Strangler Stage 2). Behaviour-preserving; full suites green. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/integration/base_integrator.py | 55 ++----------------- src/apm_cli/integration/base_integrator_io.py | 52 ++++++++++++++++++ 2 files changed, 58 insertions(+), 49 deletions(-) create mode 100644 src/apm_cli/integration/base_integrator_io.py diff --git a/src/apm_cli/integration/base_integrator.py b/src/apm_cli/integration/base_integrator.py index e5f890297..6a1feabf3 100644 --- a/src/apm_cli/integration/base_integrator.py +++ b/src/apm_cli/integration/base_integrator.py @@ -1,7 +1,5 @@ """Base integrator with shared collision detection and sync logic.""" -import errno -import os import re from dataclasses import dataclass from pathlib import Path @@ -10,10 +8,9 @@ from apm_cli.primitives.discovery import discover_primitives from apm_cli.utils.console import _rich_warning - -class _SymlinkRaceError(OSError): - """Raised by ``_read_bytes_no_follow`` when the path becomes a symlink - between the pre-check and the open(). Caught locally; never bubbles.""" +# Re-exported so the original ``base_integrator`` import/patch path keeps +# working after the TOCTOU-safe read primitive moved to a sibling module. +from .base_integrator_io import _read_bytes_no_follow, _SymlinkRaceError @dataclass @@ -46,42 +43,6 @@ class IntegrationResult: files_adopted: int = 0 -def _read_bytes_no_follow(path: Path) -> bytes: - """Read *path* with ``O_NOFOLLOW`` semantics where supported. - - On POSIX, opens the file with ``os.O_NOFOLLOW`` so the kernel - rejects the open atomically if the final path component is a - symlink. 
This closes the TOCTOU race between - ``Path.is_symlink()`` and ``Path.read_bytes()`` exploited by a - co-tenant who can swap files for symlinks. - - On Windows (no ``O_NOFOLLOW``), falls back to a plain read; the - caller's upfront ``is_symlink()`` check plus ``ensure_path_within`` - at the integrator call sites provide the containment guarantee. - """ - flags = os.O_RDONLY | getattr(os, "O_BINARY", 0) - nofollow = getattr(os, "O_NOFOLLOW", 0) - flags |= nofollow - try: - fd = os.open(str(path), flags) - except OSError as exc: - # ELOOP is the canonical errno for "O_NOFOLLOW refused to open - # a symlink"; some Linux kernels return EMLINK or ELOOP-equivalent. - if nofollow and exc.errno in (errno.ELOOP, getattr(errno, "EMLINK", -1)): - raise _SymlinkRaceError(exc.errno, f"Refused to follow symlink at {path}") from exc - raise - try: - chunks: list[bytes] = [] - while True: - chunk = os.read(fd, 65536) - if not chunk: - break - chunks.append(chunk) - return b"".join(chunks) - finally: - os.close(fd) - - class BaseIntegrator: """Shared infrastructure for file-level integrators. @@ -358,11 +319,9 @@ def validate_deploy_path( return False target = project_root / rel_path try: - if not target.resolve().is_relative_to(project_root.resolve()): - return False + return target.resolve().is_relative_to(project_root.resolve()) except (ValueError, OSError): return False - return True # Backward-compat aliases mapping raw ``{prim}_{target}`` keys to # the bucket names that existing callers expect. 
Shared between @@ -436,8 +395,7 @@ def partition_managed_files( COPILOT_APP_LOCKFILE_PREFIX, ) - raw_key = f"{prim_name}_{target.name}" - bucket_key = BaseIntegrator._BUCKET_ALIASES.get(raw_key, raw_key) + bucket_key = BaseIntegrator.partition_bucket_key(prim_name, target.name) if bucket_key not in buckets: buckets[bucket_key] = set() prefix_map[COPILOT_APP_LOCKFILE_PREFIX] = bucket_key @@ -453,8 +411,7 @@ def partition_managed_files( elif prim_name == "hooks": hook_prefixes.append(prefix) else: - raw_key = f"{prim_name}_{target.name}" - bucket_key = BaseIntegrator._BUCKET_ALIASES.get(raw_key, raw_key) + bucket_key = BaseIntegrator.partition_bucket_key(prim_name, target.name) if bucket_key not in buckets: buckets[bucket_key] = set() prefix_map[prefix] = bucket_key diff --git a/src/apm_cli/integration/base_integrator_io.py b/src/apm_cli/integration/base_integrator_io.py new file mode 100644 index 000000000..42940eec9 --- /dev/null +++ b/src/apm_cli/integration/base_integrator_io.py @@ -0,0 +1,52 @@ +"""TOCTOU-safe byte reads for integrator content-identity checks. + +Extracted from ``base_integrator`` so the security-sensitive symlink-race +primitive lives in one cohesive module. ``base_integrator`` re-exports both +names, so ``apm_cli.integration.base_integrator._read_bytes_no_follow`` and +``_SymlinkRaceError`` remain importable / patchable from their original path. +""" + +import errno +import os +from pathlib import Path + + +class _SymlinkRaceError(OSError): + """Raised by ``_read_bytes_no_follow`` when the path becomes a symlink + between the pre-check and the open(). Caught locally; never bubbles.""" + + +def _read_bytes_no_follow(path: Path) -> bytes: + """Read *path* with ``O_NOFOLLOW`` semantics where supported. + + On POSIX, opens the file with ``os.O_NOFOLLOW`` so the kernel + rejects the open atomically if the final path component is a + symlink. 
This closes the TOCTOU race between + ``Path.is_symlink()`` and ``Path.read_bytes()`` exploited by a + co-tenant who can swap files for symlinks. + + On Windows (no ``O_NOFOLLOW``), falls back to a plain read; the + caller's upfront ``is_symlink()`` check plus ``ensure_path_within`` + at the integrator call sites provide the containment guarantee. + """ + flags = os.O_RDONLY | getattr(os, "O_BINARY", 0) + nofollow = getattr(os, "O_NOFOLLOW", 0) + flags |= nofollow + try: + fd = os.open(str(path), flags) + except OSError as exc: + # ELOOP is the canonical errno for "O_NOFOLLOW refused to open + # a symlink"; some Linux kernels return EMLINK or ELOOP-equivalent. + if nofollow and exc.errno in (errno.ELOOP, getattr(errno, "EMLINK", -1)): + raise _SymlinkRaceError(exc.errno, f"Refused to follow symlink at {path}") from exc + raise + try: + chunks: list[bytes] = [] + while True: + chunk = os.read(fd, 65536) + if not chunk: + break + chunks.append(chunk) + return b"".join(chunks) + finally: + os.close(fd) From 3e269eabfacd803a75a1310c629599e7a8cadc3d Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 03:38:31 +0200 Subject: [PATCH 11/21] refactor(integration): extract target-profile dataclasses to sibling Move RULE_FORMATS, PrimitiveMapping and TargetProfile out of targets.py into a focused target_profile module; targets.py re-exports all three so the 178x-imported public surface (and KNOWN_TARGETS, which stays put) is unchanged. No behaviour change; the dataclasses are self-contained. targets.py 1036->631 lines (under the 800 gate). Part of #1078 (Strangler Stage 2). Full unit, acceptance and integration suites green. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/integration/target_profile.py | 417 ++++++++++++++++++++++ src/apm_cli/integration/targets.py | 409 +-------------------- 2 files changed, 419 insertions(+), 407 deletions(-) create mode 100644 src/apm_cli/integration/target_profile.py diff --git a/src/apm_cli/integration/target_profile.py b/src/apm_cli/integration/target_profile.py new file mode 100644 index 000000000..59bc38693 --- /dev/null +++ b/src/apm_cli/integration/target_profile.py @@ -0,0 +1,417 @@ +"""Target-profile dataclasses for multi-tool integration. + +``PrimitiveMapping`` and ``TargetProfile`` describe where each APM primitive +type is deployed in a target tool. They are extracted from ``targets`` so the +data registry (``KNOWN_TARGETS``) and the resolver functions live in a focused +module; ``targets`` re-exports all three public names, so the original import +path ``apm_cli.integration.targets`` is preserved. +""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass + +RULE_FORMATS: frozenset[str] = frozenset({"cursor_rules", "claude_rules", "windsurf_rules"}) +"""Canonical set of format-transforming rule ``format_id``s. + +Single home for "which instruction formats transform their source on +deploy". A mapping with one of these ``format_id``s MUST set +``output_compare=True`` (enforced by :meth:`PrimitiveMapping.__post_init__`), +and :meth:`InstructionIntegrator._render_instruction` dispatches on this same +set. Adding a new rule format means: add it here, set ``output_compare=True`` +on the mapping, and add a ``_convert_to_*`` branch in ``_render_instruction``. +""" + + +@dataclass(frozen=True) +class PrimitiveMapping: + """Where a single primitive type is deployed in a target tool.""" + + subdir: str + """Subdirectory under the target root (e.g. ``"rules"``, ``"agents"``).""" + + extension: str + """File extension or suffix for deployed files + (e.g. 
``".mdc"``, ``".agent.md"``).""" + + format_id: str + """Opaque tag used by integrators to select the right + content transformer (e.g. ``"cursor_rules"``).""" + + deploy_root: str | None = None + """Override *root_dir* for this primitive only. + + When set, integrators use ``deploy_root`` instead of + ``target.root_dir`` to compute the deploy directory. + For example, Codex skills deploy to ``.agents/`` (cross-tool + directory) rather than ``.codex/``. Default ``None`` preserves + existing behavior for all other targets. + """ + + output_compare: bool = False + """Whether this primitive's deployed file is a format-transform of its + source, so the integrator must adopt/collision-check against the + rendered *output* rather than the source bytes. + + This is the single source of truth for the rule-dir formats + (``cursor_rules``, ``claude_rules``, ``windsurf_rules``). When ``True``: + + * The deployed file is never byte-identical to its source, so a + source-based adopt always misses (apm#1662). The integrator instead + compares against the rendered output and (re)writes when stale. + * The target is APM-owned per-file (``target_name`` derives 1:1 from a + source instruction), so ``managed_files`` is NOT consulted -- any + existing file at the target path is APM's, not user-authored. + * The deployed filename is renamed from ``.instructions.md`` to + ``{extension}``. + + Adding a future format-transformed rule type requires two coordinated + edits: set ``output_compare=True`` here (add the ``format_id`` to + ``RULE_FORMATS``) *and* add the matching ``_convert_to_*`` branch to + :meth:`InstructionIntegrator._render_instruction`, which dispatches on the + ``format_id`` to perform the transform. + """ + + def __post_init__(self) -> None: + """Keep ``output_compare`` and :data:`RULE_FORMATS` in lockstep. 
+ + A rule ``format_id`` that transforms its source MUST compare against + the rendered output; otherwise the integrator would fall through to a + verbatim copy and silently deploy untransformed content (apm#1662). + The converse is also enforced so the canonical set stays the one home + for "which formats transform". + """ + is_rule = self.format_id in RULE_FORMATS + if is_rule and not self.output_compare: + raise ValueError( + f"PrimitiveMapping(format_id={self.format_id!r}) is a rule " + f"format ({sorted(RULE_FORMATS)}) and must set " + "output_compare=True; otherwise its source is deployed " + "untransformed." + ) + if self.output_compare and not is_rule: + raise ValueError( + f"PrimitiveMapping(format_id={self.format_id!r}) sets " + "output_compare=True but is not a known rule format " + f"({sorted(RULE_FORMATS)}); add it to RULE_FORMATS and a " + "_render_instruction branch, or unset output_compare." + ) + + +@dataclass(frozen=True) +class TargetProfile: + """Capabilities and layout of a single target tool.""" + + name: str + """Short unique identifier (``"copilot"``, ``"claude"``, ``"cursor"``).""" + + root_dir: str + """Top-level directory in the workspace (e.g. ``".github"``).""" + + primitives: dict[str, PrimitiveMapping] + """Mapping from APM primitive name -> deployment spec. + + Only primitives listed here are deployed to this target. + """ + + auto_create: bool = True + """Create *root_dir* if it does not exist (used during fallback or + explicit ``--target`` selection).""" + + detect_by_dir: bool = True + """If ``True``, only deploy when *root_dir* already exists.""" + + # -- user-scope metadata -------------------------------------------------- + + user_supported: bool | str = False + """Whether this target supports user-scope (``~/``) deployment. + + * ``True`` -- fully supported (all primitives work at user scope). + * ``"partial"`` -- some primitives work, others do not. + * ``False`` -- not supported at user scope. 
+ """ + + user_root_dir: str | None = None + """Override for *root_dir* at user scope. + + When ``None`` the normal *root_dir* is used at both project and user + scope. Set this when the tool reads from a different directory at + user level (e.g. Copilot CLI uses ``~/.copilot/`` instead of + ``~/.github/``). + """ + + unsupported_user_primitives: tuple[str, ...] = () + """Primitives that are **not** available at user scope even when the + target itself is partially supported.""" + + user_primitive_overrides: dict[str, PrimitiveMapping] | None = None + """Primitive mapping overrides applied at user scope only. + + When set, these entries replace the corresponding entries in + ``primitives`` after ``unsupported_user_primitives`` filtering in + ``for_scope(user_scope=True)``. + + Use this when a primitive must be deployed to a *different* location + or via a *different* transform at user scope. The canonical example + is the Copilot target: at project scope each ``*.instructions.md`` + file deploys individually to ``.github/instructions/``; at user scope + they are all concatenated into the single file that Copilot CLI reads + (``~/.copilot/copilot-instructions.md``). + """ + + user_root_resolver: Callable[[], Path | None] | None = None # noqa: F821 + """Optional callable that resolves the deploy root at runtime. + + When set, ``for_scope(user_scope=True)`` calls this resolver instead of + using a static ``user_root_dir``. If the resolver returns ``None`` + the target is unavailable in the current environment (same semantics + as ``user_supported=False``). + + The callable must be hashable by reference (plain function or + staticmethod) so ``frozen=True`` is preserved. + """ + + resolved_deploy_root: Path | None = None # noqa: F821 + """Absolute deploy root populated by ``for_scope()`` when + ``user_root_resolver`` returns a concrete ``Path``. + + Downstream code uses ``deploy_path()`` to route filesystem I/O + through this root instead of ``project_root / root_dir``. 
+ """ + + requires_flag: str | None = None + """When set, the target is only returned by ``active_targets`` / + ``active_targets_user_scope`` / ``resolve_targets`` when the named + experimental flag is enabled. The target entry is always visible + in ``KNOWN_TARGETS`` for tooling introspection. + """ + + scope_invariant_resolver: bool = False + """When True, ``user_root_resolver`` runs in BOTH project and user + scope (the resolved deploy root does not depend on install intent). + + Set this for targets whose deploy root is a user-machine resource + that exists regardless of who triggered the install -- e.g. + ``copilot-app`` (the GitHub Copilot desktop App's SQLite DB at + ``~/.copilot/data.db`` is the same path whether a team-shared + workflow comes in via project ``apm.yml`` or user-scope ``--global``). + + Contrast with cowork, where the OneDrive deploy root only makes + sense at user scope; project-scope cowork is intentionally rejected. + """ + + generated_files: tuple[str, ...] = () + """Additional generated files associated with this target. + + These are compile-time outputs that live at the target root but are not + deployed via primitive integrators, e.g. Copilot's root + ``copilot-instructions.md`` file. + """ + + # -- subsystem-specific metadata (single source of truth) ----------------- + # + # The four fields below centralize per-target knowledge that previously + # lived in scattered module-local dicts and ``if/elif`` chains + # (see ``bundle/lockfile_enrichment.py``, ``core/conflict_detector.py``, + # ``commands/compile/cli.py``, ``install/services.py``). Adding a new + # target now requires only a single ``KNOWN_TARGETS`` entry. + + pack_prefixes: tuple[str, ...] = () + """Path prefixes that identify this target's deployed files when packing. + + When empty, ``bundle.lockfile_enrichment`` derives ``(f"{root_dir}/",)`` + from :attr:`root_dir`. Override only when the target deploys to multiple + top-level directories (e.g. 
Codex deploys both ``.codex/`` and + ``.agents/``). + """ + + compile_family: str | None = None + """Compiler family this target belongs to for ``apm compile`` routing. + + Recognised values: + + * ``"vscode"`` -- emits ``.github/copilot-instructions.md`` *and* AGENTS.md. + * ``"claude"`` -- emits ``CLAUDE.md`` and ``.claude/rules/`` files. + * ``"gemini"`` -- emits ``GEMINI.md``. + * ``"agents"`` -- emits AGENTS.md only (cursor, opencode, codex, windsurf). + * ``None`` -- target has no compile output (agent-skills, copilot-cowork). + + Used by :func:`apm_cli.commands.compile.cli._resolve_compile_target` to + derive multi-target routing from the registry instead of hard-coded sets. + """ + + hooks_config_display: str | None = None + """Human-readable path shown in the install log for hooks integration. + + e.g. ``".claude/settings.json"`` for Claude (hooks merge into a settings + file rather than landing in their own subdir). When ``None``, the + install log falls back to the generic ``"{root}/{subdir}/"`` formula. + """ + + @property + def prefix(self) -> str: + """Return the path prefix for this target (e.g. ``".github/"``). + + Used by ``validate_deploy_path`` and ``partition_managed_files``. + """ + return f"{self.root_dir}/" + + @property + def effective_pack_prefixes(self) -> tuple[str, ...]: + """Return the path prefixes used by pack-time file filtering. + + Falls back to ``(self.prefix,)`` when :attr:`pack_prefixes` is empty, + so most targets need not override the field explicitly. + """ + return self.pack_prefixes if self.pack_prefixes else (self.prefix,) + + def supports(self, primitive: str) -> bool: + """Return ``True`` if this target accepts *primitive*.""" + return primitive in self.primitives + + def effective_root(self, user_scope: bool = False) -> str: + """Return the root directory for the given scope. + + At user scope, returns *user_root_dir* when set, otherwise + falls back to the standard *root_dir*. 
+ """ + if user_scope and self.user_root_dir: + return self.user_root_dir + return self.root_dir + + def supports_at_user_scope(self, primitive: str) -> bool: + """Return ``True`` if *primitive* can be deployed at user scope.""" + if not self.user_supported: + return False + if primitive in self.unsupported_user_primitives: + return False + return primitive in self.primitives + + def deploy_path(self, project_root: Path, *parts: str) -> Path: # noqa: F821 + """Return the filesystem path for deployment. + + When ``resolved_deploy_root`` is set (dynamic-root targets like + cowork), the path is rooted there. Otherwise falls back to the + standard ``project_root / root_dir`` pattern. + + Args: + project_root: Workspace or home directory root. + *parts: Additional path segments (e.g. ``"skills"``, ``"my-skill"``). + """ + if self.resolved_deploy_root is not None: + return ( + self.resolved_deploy_root.joinpath(*parts) if parts else self.resolved_deploy_root + ) + base = project_root / self.root_dir + return base.joinpath(*parts) if parts else base + + def for_scope(self, user_scope: bool = False) -> TargetProfile | None: + """Return a scope-resolved copy of this profile. + + When *user_scope* is ``False``, returns ``self`` unchanged. + + When *user_scope* is ``True``: + - If ``user_root_resolver`` is set, calls it. Returns ``None`` + when the resolver returns ``None`` (target unavailable). + Otherwise returns a copy with ``resolved_deploy_root`` set and + primitives filtered for user scope. + - Returns ``None`` if this target does not support user scope. + - Otherwise returns a frozen copy with ``root_dir`` set to + ``user_root_dir`` (or left unchanged when ``user_root_dir`` + is ``None``) and ``primitives`` filtered to exclude entries + listed in ``unsupported_user_primitives``. + + This is the **single place** where scope resolution happens. + All downstream code reads ``target.root_dir`` directly. 
+ """ + if not user_scope: + # Most targets have no project-scope resolver work to do. + # The scope_invariant_resolver opt-in lets a target whose + # deploy root is a user-machine resource (e.g. copilot-app's + # ~/.copilot/data.db) populate resolved_deploy_root even when + # the install intent is project-scope. Downstream lockfile + # enrichment then routes via the dynamic-root URI path. + if self.scope_invariant_resolver and self.user_root_resolver is not None: + resolved_root = self.user_root_resolver() + if resolved_root is None: + return None + from dataclasses import replace + + return replace(self, resolved_deploy_root=resolved_root) + return self + + from dataclasses import replace + + # --- dynamic-root resolver path (cowork) --- + if self.user_root_resolver is not None: + resolved_root = self.user_root_resolver() + if resolved_root is None: + return None + if self.unsupported_user_primitives: + filtered = { + k: v + for k, v in self.primitives.items() + if k not in self.unsupported_user_primitives + } + else: + filtered = self.primitives + if self.user_primitive_overrides: + merged = dict(filtered) + merged.update(self.user_primitive_overrides) + filtered = merged + return replace( + self, + primitives=filtered, + resolved_deploy_root=resolved_root, + ) + + if not self.user_supported: + return None + + new_root = self.user_root_dir or self.root_dir + + # Claude Code honors CLAUDE_CONFIG_DIR (default ~/.claude); mirror + # that at user scope so `apm install -g` lands where Claude reads. + if self.name == "claude": + import os + from pathlib import Path + + env = os.environ.get("CLAUDE_CONFIG_DIR", "").strip() + if env: + # ``resolve`` collapses ``..`` so traversal segments cannot + # leak into ``root_dir`` and escape ``project_root / root_dir``. + abs_path = Path(env).expanduser().resolve(strict=False) + home = Path.home().resolve(strict=False) + try: + # Keep ``root_dir`` home-relative so cleanup prefix matching holds. 
+ new_root = abs_path.relative_to(home).as_posix() + except ValueError: + # Fallback: when CLAUDE_CONFIG_DIR points outside $HOME we + # store an absolute path. ``pathlib.Path / `` is + # ```` so deploy + cleanup write to the right + # place. Caveat: the lockfile path translator + # (``install/services._deployed_path_entry``) calls + # ``relative_to(project_root)`` and raises ``RuntimeError`` + # for out-of-tree paths that are not dynamic-root targets. + # Today this is unreachable because user-scope CLAUDE + # installs do not flow through that translator, but any + # future refactor that lockfiles user-scope deploys must + # treat absolute ``root_dir`` as a dynamic-root case. + new_root = str(abs_path) + + if self.unsupported_user_primitives: + filtered = { + k: v + for k, v in self.primitives.items() + if k not in self.unsupported_user_primitives + } + else: + filtered = self.primitives + + if self.user_primitive_overrides: + merged = dict(filtered) + merged.update(self.user_primitive_overrides) + filtered = merged + + return replace(self, root_dir=new_root, primitives=filtered) diff --git a/src/apm_cli/integration/targets.py b/src/apm_cli/integration/targets.py index 0949f175d..89908471d 100644 --- a/src/apm_cli/integration/targets.py +++ b/src/apm_cli/integration/targets.py @@ -18,413 +18,8 @@ from __future__ import annotations -from collections.abc import Callable -from dataclasses import dataclass - -RULE_FORMATS: frozenset[str] = frozenset({"cursor_rules", "claude_rules", "windsurf_rules"}) -"""Canonical set of format-transforming rule ``format_id``s. - -Single home for "which instruction formats transform their source on -deploy". A mapping with one of these ``format_id``s MUST set -``output_compare=True`` (enforced by :meth:`PrimitiveMapping.__post_init__`), -and :meth:`InstructionIntegrator._render_instruction` dispatches on this same -set. 
Adding a new rule format means: add it here, set ``output_compare=True`` -on the mapping, and add a ``_convert_to_*`` branch in ``_render_instruction``. -""" - - -@dataclass(frozen=True) -class PrimitiveMapping: - """Where a single primitive type is deployed in a target tool.""" - - subdir: str - """Subdirectory under the target root (e.g. ``"rules"``, ``"agents"``).""" - - extension: str - """File extension or suffix for deployed files - (e.g. ``".mdc"``, ``".agent.md"``).""" - - format_id: str - """Opaque tag used by integrators to select the right - content transformer (e.g. ``"cursor_rules"``).""" - - deploy_root: str | None = None - """Override *root_dir* for this primitive only. - - When set, integrators use ``deploy_root`` instead of - ``target.root_dir`` to compute the deploy directory. - For example, Codex skills deploy to ``.agents/`` (cross-tool - directory) rather than ``.codex/``. Default ``None`` preserves - existing behavior for all other targets. - """ - - output_compare: bool = False - """Whether this primitive's deployed file is a format-transform of its - source, so the integrator must adopt/collision-check against the - rendered *output* rather than the source bytes. - - This is the single source of truth for the rule-dir formats - (``cursor_rules``, ``claude_rules``, ``windsurf_rules``). When ``True``: - - * The deployed file is never byte-identical to its source, so a - source-based adopt always misses (apm#1662). The integrator instead - compares against the rendered output and (re)writes when stale. - * The target is APM-owned per-file (``target_name`` derives 1:1 from a - source instruction), so ``managed_files`` is NOT consulted -- any - existing file at the target path is APM's, not user-authored. - * The deployed filename is renamed from ``.instructions.md`` to - ``{extension}``. 
- - Adding a future format-transformed rule type requires two coordinated - edits: set ``output_compare=True`` here (add the ``format_id`` to - ``RULE_FORMATS``) *and* add the matching ``_convert_to_*`` branch to - :meth:`InstructionIntegrator._render_instruction`, which dispatches on the - ``format_id`` to perform the transform. - """ - - def __post_init__(self) -> None: - """Keep ``output_compare`` and :data:`RULE_FORMATS` in lockstep. - - A rule ``format_id`` that transforms its source MUST compare against - the rendered output; otherwise the integrator would fall through to a - verbatim copy and silently deploy untransformed content (apm#1662). - The converse is also enforced so the canonical set stays the one home - for "which formats transform". - """ - is_rule = self.format_id in RULE_FORMATS - if is_rule and not self.output_compare: - raise ValueError( - f"PrimitiveMapping(format_id={self.format_id!r}) is a rule " - f"format ({sorted(RULE_FORMATS)}) and must set " - "output_compare=True; otherwise its source is deployed " - "untransformed." - ) - if self.output_compare and not is_rule: - raise ValueError( - f"PrimitiveMapping(format_id={self.format_id!r}) sets " - "output_compare=True but is not a known rule format " - f"({sorted(RULE_FORMATS)}); add it to RULE_FORMATS and a " - "_render_instruction branch, or unset output_compare." - ) - - -@dataclass(frozen=True) -class TargetProfile: - """Capabilities and layout of a single target tool.""" - - name: str - """Short unique identifier (``"copilot"``, ``"claude"``, ``"cursor"``).""" - - root_dir: str - """Top-level directory in the workspace (e.g. ``".github"``).""" - - primitives: dict[str, PrimitiveMapping] - """Mapping from APM primitive name -> deployment spec. - - Only primitives listed here are deployed to this target. 
- """ - - auto_create: bool = True - """Create *root_dir* if it does not exist (used during fallback or - explicit ``--target`` selection).""" - - detect_by_dir: bool = True - """If ``True``, only deploy when *root_dir* already exists.""" - - # -- user-scope metadata -------------------------------------------------- - - user_supported: bool | str = False - """Whether this target supports user-scope (``~/``) deployment. - - * ``True`` -- fully supported (all primitives work at user scope). - * ``"partial"`` -- some primitives work, others do not. - * ``False`` -- not supported at user scope. - """ - - user_root_dir: str | None = None - """Override for *root_dir* at user scope. - - When ``None`` the normal *root_dir* is used at both project and user - scope. Set this when the tool reads from a different directory at - user level (e.g. Copilot CLI uses ``~/.copilot/`` instead of - ``~/.github/``). - """ - - unsupported_user_primitives: tuple[str, ...] = () - """Primitives that are **not** available at user scope even when the - target itself is partially supported.""" - - user_primitive_overrides: dict[str, PrimitiveMapping] | None = None - """Primitive mapping overrides applied at user scope only. - - When set, these entries replace the corresponding entries in - ``primitives`` after ``unsupported_user_primitives`` filtering in - ``for_scope(user_scope=True)``. - - Use this when a primitive must be deployed to a *different* location - or via a *different* transform at user scope. The canonical example - is the Copilot target: at project scope each ``*.instructions.md`` - file deploys individually to ``.github/instructions/``; at user scope - they are all concatenated into the single file that Copilot CLI reads - (``~/.copilot/copilot-instructions.md``). - """ - - user_root_resolver: Callable[[], Path | None] | None = None # noqa: F821 - """Optional callable that resolves the deploy root at runtime. 
- - When set, ``for_scope(user_scope=True)`` calls this resolver instead of - using a static ``user_root_dir``. If the resolver returns ``None`` - the target is unavailable in the current environment (same semantics - as ``user_supported=False``). - - The callable must be hashable by reference (plain function or - staticmethod) so ``frozen=True`` is preserved. - """ - - resolved_deploy_root: Path | None = None # noqa: F821 - """Absolute deploy root populated by ``for_scope()`` when - ``user_root_resolver`` returns a concrete ``Path``. - - Downstream code uses ``deploy_path()`` to route filesystem I/O - through this root instead of ``project_root / root_dir``. - """ - - requires_flag: str | None = None - """When set, the target is only returned by ``active_targets`` / - ``active_targets_user_scope`` / ``resolve_targets`` when the named - experimental flag is enabled. The target entry is always visible - in ``KNOWN_TARGETS`` for tooling introspection. - """ - - scope_invariant_resolver: bool = False - """When True, ``user_root_resolver`` runs in BOTH project and user - scope (the resolved deploy root does not depend on install intent). - - Set this for targets whose deploy root is a user-machine resource - that exists regardless of who triggered the install -- e.g. - ``copilot-app`` (the GitHub Copilot desktop App's SQLite DB at - ``~/.copilot/data.db`` is the same path whether a team-shared - workflow comes in via project ``apm.yml`` or user-scope ``--global``). - - Contrast with cowork, where the OneDrive deploy root only makes - sense at user scope; project-scope cowork is intentionally rejected. - """ - - generated_files: tuple[str, ...] = () - """Additional generated files associated with this target. - - These are compile-time outputs that live at the target root but are not - deployed via primitive integrators, e.g. Copilot's root - ``copilot-instructions.md`` file. 
- """ - - # -- subsystem-specific metadata (single source of truth) ----------------- - # - # The four fields below centralize per-target knowledge that previously - # lived in scattered module-local dicts and ``if/elif`` chains - # (see ``bundle/lockfile_enrichment.py``, ``core/conflict_detector.py``, - # ``commands/compile/cli.py``, ``install/services.py``). Adding a new - # target now requires only a single ``KNOWN_TARGETS`` entry. - - pack_prefixes: tuple[str, ...] = () - """Path prefixes that identify this target's deployed files when packing. - - When empty, ``bundle.lockfile_enrichment`` derives ``(f"{root_dir}/",)`` - from :attr:`root_dir`. Override only when the target deploys to multiple - top-level directories (e.g. Codex deploys both ``.codex/`` and - ``.agents/``). - """ - - compile_family: str | None = None - """Compiler family this target belongs to for ``apm compile`` routing. - - Recognised values: - - * ``"vscode"`` -- emits ``.github/copilot-instructions.md`` *and* AGENTS.md. - * ``"claude"`` -- emits ``CLAUDE.md`` and ``.claude/rules/`` files. - * ``"gemini"`` -- emits ``GEMINI.md``. - * ``"agents"`` -- emits AGENTS.md only (cursor, opencode, codex, windsurf). - * ``None`` -- target has no compile output (agent-skills, copilot-cowork). - - Used by :func:`apm_cli.commands.compile.cli._resolve_compile_target` to - derive multi-target routing from the registry instead of hard-coded sets. - """ - - hooks_config_display: str | None = None - """Human-readable path shown in the install log for hooks integration. - - e.g. ``".claude/settings.json"`` for Claude (hooks merge into a settings - file rather than landing in their own subdir). When ``None``, the - install log falls back to the generic ``"{root}/{subdir}/"`` formula. - """ - - @property - def prefix(self) -> str: - """Return the path prefix for this target (e.g. ``".github/"``). - - Used by ``validate_deploy_path`` and ``partition_managed_files``. 
- """ - return f"{self.root_dir}/" - - @property - def effective_pack_prefixes(self) -> tuple[str, ...]: - """Return the path prefixes used by pack-time file filtering. - - Falls back to ``(self.prefix,)`` when :attr:`pack_prefixes` is empty, - so most targets need not override the field explicitly. - """ - return self.pack_prefixes if self.pack_prefixes else (self.prefix,) - - def supports(self, primitive: str) -> bool: - """Return ``True`` if this target accepts *primitive*.""" - return primitive in self.primitives - - def effective_root(self, user_scope: bool = False) -> str: - """Return the root directory for the given scope. - - At user scope, returns *user_root_dir* when set, otherwise - falls back to the standard *root_dir*. - """ - if user_scope and self.user_root_dir: - return self.user_root_dir - return self.root_dir - - def supports_at_user_scope(self, primitive: str) -> bool: - """Return ``True`` if *primitive* can be deployed at user scope.""" - if not self.user_supported: - return False - if primitive in self.unsupported_user_primitives: - return False - return primitive in self.primitives - - def deploy_path(self, project_root: Path, *parts: str) -> Path: # noqa: F821 - """Return the filesystem path for deployment. - - When ``resolved_deploy_root`` is set (dynamic-root targets like - cowork), the path is rooted there. Otherwise falls back to the - standard ``project_root / root_dir`` pattern. - - Args: - project_root: Workspace or home directory root. - *parts: Additional path segments (e.g. ``"skills"``, ``"my-skill"``). - """ - if self.resolved_deploy_root is not None: - return ( - self.resolved_deploy_root.joinpath(*parts) if parts else self.resolved_deploy_root - ) - base = project_root / self.root_dir - return base.joinpath(*parts) if parts else base - - def for_scope(self, user_scope: bool = False) -> TargetProfile | None: - """Return a scope-resolved copy of this profile. - - When *user_scope* is ``False``, returns ``self`` unchanged. 
- - When *user_scope* is ``True``: - - If ``user_root_resolver`` is set, calls it. Returns ``None`` - when the resolver returns ``None`` (target unavailable). - Otherwise returns a copy with ``resolved_deploy_root`` set and - primitives filtered for user scope. - - Returns ``None`` if this target does not support user scope. - - Otherwise returns a frozen copy with ``root_dir`` set to - ``user_root_dir`` (or left unchanged when ``user_root_dir`` - is ``None``) and ``primitives`` filtered to exclude entries - listed in ``unsupported_user_primitives``. - - This is the **single place** where scope resolution happens. - All downstream code reads ``target.root_dir`` directly. - """ - if not user_scope: - # Most targets have no project-scope resolver work to do. - # The scope_invariant_resolver opt-in lets a target whose - # deploy root is a user-machine resource (e.g. copilot-app's - # ~/.copilot/data.db) populate resolved_deploy_root even when - # the install intent is project-scope. Downstream lockfile - # enrichment then routes via the dynamic-root URI path. 
- if self.scope_invariant_resolver and self.user_root_resolver is not None: - resolved_root = self.user_root_resolver() - if resolved_root is None: - return None - from dataclasses import replace - - return replace(self, resolved_deploy_root=resolved_root) - return self - - from dataclasses import replace - - # --- dynamic-root resolver path (cowork) --- - if self.user_root_resolver is not None: - resolved_root = self.user_root_resolver() - if resolved_root is None: - return None - if self.unsupported_user_primitives: - filtered = { - k: v - for k, v in self.primitives.items() - if k not in self.unsupported_user_primitives - } - else: - filtered = self.primitives - if self.user_primitive_overrides: - merged = dict(filtered) - merged.update(self.user_primitive_overrides) - filtered = merged - return replace( - self, - primitives=filtered, - resolved_deploy_root=resolved_root, - ) - - if not self.user_supported: - return None - - new_root = self.user_root_dir or self.root_dir - - # Claude Code honors CLAUDE_CONFIG_DIR (default ~/.claude); mirror - # that at user scope so `apm install -g` lands where Claude reads. - if self.name == "claude": - import os - from pathlib import Path - - env = os.environ.get("CLAUDE_CONFIG_DIR", "").strip() - if env: - # ``resolve`` collapses ``..`` so traversal segments cannot - # leak into ``root_dir`` and escape ``project_root / root_dir``. - abs_path = Path(env).expanduser().resolve(strict=False) - home = Path.home().resolve(strict=False) - try: - # Keep ``root_dir`` home-relative so cleanup prefix matching holds. - new_root = abs_path.relative_to(home).as_posix() - except ValueError: - # Fallback: when CLAUDE_CONFIG_DIR points outside $HOME we - # store an absolute path. ``pathlib.Path / <absolute>`` is - # ``<absolute>`` so deploy + cleanup write to the right - # place.
Caveat: the lockfile path translator - # (``install/services._deployed_path_entry``) calls - # ``relative_to(project_root)`` and raises ``RuntimeError`` - # for out-of-tree paths that are not dynamic-root targets. - # Today this is unreachable because user-scope CLAUDE - # installs do not flow through that translator, but any - # future refactor that lockfiles user-scope deploys must - # treat absolute ``root_dir`` as a dynamic-root case. - new_root = str(abs_path) - - if self.unsupported_user_primitives: - filtered = { - k: v - for k, v in self.primitives.items() - if k not in self.unsupported_user_primitives - } - else: - filtered = self.primitives - - if self.user_primitive_overrides: - merged = dict(filtered) - merged.update(self.user_primitive_overrides) - filtered = merged - - return replace(self, root_dir=new_root, primitives=filtered) - +from .target_profile import RULE_FORMATS as RULE_FORMATS +from .target_profile import PrimitiveMapping, TargetProfile # ------------------------------------------------------------------ # Runtime -> canonical target alias map From bccf3eea5b8f46a3c02129953831f58f2873c239 Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 03:51:29 +0200 Subject: [PATCH 12/21] refactor(integration): tighten mcp_integrator_install.py complexity Decompose _resolve_target_runtimes (C901 50, PLR0912 54, PLR0915 132) into in-module patch-safe helpers: _detect_installed_runtimes (+ fallback and a data-driven _runtime_opted_in predicate that collapses the parallel cursor/opencode/gemini/windsurf opt-in branches via _DIR_GATED_RUNTIMES), _intersect_script_runtimes, and _apply_user_scope_filter. find_runtime_binary stays a module global so the 38 monkeypatch sites remain valid. Fix _install_registry_group PLR0913 (13 args) with a cohesive _RegistryDepGroup parameter object (deps/names/dep_map), keeping the callable behaviour stable. All six Stage-2 thresholds now clean for this file; 785 lines (<800). 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../integration/mcp_integrator_install.py | 393 +++++++++--------- 1 file changed, 205 insertions(+), 188 deletions(-) diff --git a/src/apm_cli/integration/mcp_integrator_install.py b/src/apm_cli/integration/mcp_integrator_install.py index a7787044b..a205b9c0b 100644 --- a/src/apm_cli/integration/mcp_integrator_install.py +++ b/src/apm_cli/integration/mcp_integrator_install.py @@ -7,6 +7,7 @@ from __future__ import annotations import builtins +from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any @@ -17,12 +18,28 @@ if TYPE_CHECKING: from apm_cli.core.scope import InstallScope +# Opt-in runtimes gated solely on the presence of a project marker directory. +# Data-driven so the parallel detection branches collapse into one loop. +_DIR_GATED_RUNTIMES: dict[str, str] = { + "cursor": ".cursor", + "opencode": ".opencode", + "gemini": ".gemini", + "windsurf": ".windsurf", +} + + +@dataclass(frozen=True) +class _RegistryDepGroup: + """One group of registry deps sharing a single target registry endpoint.""" + + deps: list + names: list + dep_map: dict + def _install_registry_group( operations: Any, - group_dep_names: list, - group_dep_map: dict, - group_deps: list, + group: _RegistryDepGroup, target_runtimes: list, stored_mcp_configs: dict, servers_to_update: builtins.set, @@ -35,7 +52,7 @@ def _install_registry_group( ) -> int: """Process one group of registry deps through a single ``MCPServerOperations`` instance. - All deps in ``group_deps`` share the same target registry (either the + All deps in ``group.deps`` share the same target registry (either the default or a per-dep override URL). ``servers_to_update`` and ``successful_updates`` are mutated in-place; the function returns the number of servers newly configured or updated in this group. @@ -43,6 +60,10 @@ def _install_registry_group( # Lazy import: only available after MCPIntegrator finishes loading. 
from apm_cli.integration.mcp_integrator import MCPIntegrator + group_dep_names = group.names + group_dep_map = group.dep_map + group_deps = group.deps + configured_count = 0 # Early validation: check all servers exist in registry (fail-fast). @@ -184,6 +205,176 @@ def _install_registry_group( return configured_count +def _runtime_opted_in( + runtime_name: str, + project_root_path: Path, + is_vscode_available, + manager, +) -> bool: + """Decide whether a single runtime should be targeted for this project. + + Opt-in runtimes are gated on a project marker directory (or, for Claude, + a binary on PATH) so a host-wide install does not silently opt every + project into MCP writes. Plain runtimes fall back to availability probing. + """ + if runtime_name == "vscode": + return bool(is_vscode_available(project_root=project_root_path)) + if runtime_name in _DIR_GATED_RUNTIMES: + return (project_root_path / _DIR_GATED_RUNTIMES[runtime_name]).is_dir() + if runtime_name == "claude": + # Project marker OR `claude` on PATH (user-scope writes). + return (project_root_path / ".claude").is_dir() or ( + find_runtime_binary("claude") is not None + ) + if runtime_name == "intellij": + # JetBrains Copilot: the user-scope config dir is created on first run. 
+ from apm_cli.adapters.client.intellij import _intellij_config_dir + + return _intellij_config_dir().is_dir() + return bool(manager.is_runtime_available(runtime_name)) + + +def _detect_installed_runtimes_fallback(project_root_path: Path, is_vscode_available) -> list[str]: + """Binary/marker-probe runtime detection when the ClientFactory stack is absent.""" + installed_runtimes = [rt for rt in ["copilot", "codex"] if find_runtime_binary(rt) is not None] + if is_vscode_available(project_root=project_root_path): + installed_runtimes.append("vscode") + for name, marker in _DIR_GATED_RUNTIMES.items(): + if (project_root_path / marker).is_dir(): + installed_runtimes.append(name) + if (project_root_path / ".claude").is_dir() or (find_runtime_binary("claude") is not None): + installed_runtimes.append("claude") + try: + from apm_cli.adapters.client.intellij import _intellij_config_dir + + if _intellij_config_dir().is_dir(): + installed_runtimes.append("intellij") + except (ImportError, ValueError): + # ValueError (PathTraversalError) when LOCALAPPDATA/XDG_DATA_HOME is + # misconfigured -- treat as "not installed" rather than crash. 
+ pass + return installed_runtimes + + +def _detect_installed_runtimes(project_root_path: Path) -> list[str]: + """Discover all MCP-capable runtimes installed for ``project_root_path``.""" + from apm_cli.integration.mcp_integrator import _is_vscode_available + + try: + from apm_cli.factory import ClientFactory + from apm_cli.runtime.manager import RuntimeManager + + manager = RuntimeManager() + except ImportError: + return _detect_installed_runtimes_fallback(project_root_path, _is_vscode_available) + + installed_runtimes: list[str] = [] + for runtime_name in [ + "copilot", + "codex", + "vscode", + "cursor", + "opencode", + "gemini", + "windsurf", + "claude", + "intellij", + ]: + try: + if _runtime_opted_in(runtime_name, project_root_path, _is_vscode_available, manager): + ClientFactory.create_client(runtime_name) + installed_runtimes.append(runtime_name) + except (ValueError, ImportError): + continue + return installed_runtimes + + +def _intersect_script_runtimes( + installed_runtimes: list[str], + apm_config: dict | None, + verbose: bool, + logger, + console, +) -> list[str]: + """Narrow installed runtimes to those referenced in apm.yml scripts. + + With no script references, all installed runtimes are targeted. 
+ """ + from apm_cli.integration.mcp_integrator import MCPIntegrator + + script_runtimes = MCPIntegrator._detect_runtimes( + apm_config.get("scripts", {}) if apm_config else {} + ) + + if not script_runtimes: + target_runtimes = installed_runtimes + if target_runtimes: + if verbose: + logger.verbose_detail( + f"No scripts detected, using all installed runtimes: " + f"{', '.join(target_runtimes)}" + ) + else: + logger.warning("No MCP-compatible runtimes installed") + logger.progress("Install a runtime with: apm runtime setup copilot") + return target_runtimes + + target_runtimes = [rt for rt in installed_runtimes if rt in script_runtimes] + if verbose: + if console: + console.print(f"| [cyan]{STATUS_SYMBOLS['info']} Runtime Detection[/cyan]") + console.print(f"| +- Installed: {', '.join(installed_runtimes)}") + console.print(f"| +- Used in scripts: {', '.join(script_runtimes)}") + if target_runtimes: + console.print( + f"| +- Target: {', '.join(target_runtimes)} (available + used in scripts)" + ) + console.print("|") + else: + logger.verbose_detail(f"Installed runtimes: {', '.join(installed_runtimes)}") + logger.verbose_detail(f"Script runtimes: {', '.join(script_runtimes)}") + if target_runtimes: + logger.verbose_detail(f"Target runtimes: {', '.join(target_runtimes)}") + if not target_runtimes: + logger.warning("Scripts reference runtimes that are not installed") + logger.progress("Install missing runtimes with: apm runtime setup ") + return target_runtimes + + +def _apply_user_scope_filter(target_runtimes: list[str], scope, logger) -> list[str] | None: + """At USER scope, keep only runtimes that support global installation.""" + from apm_cli.core.scope import InstallScope as _IS + + if scope is not _IS.USER: + return target_runtimes + + from apm_cli.factory import ClientFactory as _CF + + pre_filter = list(target_runtimes) + filtered_runtimes = [] + for rt in target_runtimes: + try: + client = _CF.create_client(rt) + except ValueError: + continue + if 
client.supports_user_scope: + filtered_runtimes.append(rt) + skipped = set(pre_filter) - set(filtered_runtimes) + if skipped: + msg = ( + f"Skipped workspace-only runtimes at user scope: " + f"{', '.join(sorted(skipped))}" + f" -- omit --global to install these" + ) + logger.warning(msg) + if not filtered_runtimes: + logger.warning( + "No runtimes support user-scope MCP installation (supported: copilot, codex, gemini)" + ) + return None + return filtered_runtimes + + def _resolve_target_runtimes( runtime: str | None, exclude: str | None, @@ -202,10 +393,7 @@ def _resolve_target_runtimes( when the caller should immediately return 0 (e.g. all runtimes excluded, no user-scope-capable runtimes available). """ - from apm_cli.integration.mcp_integrator import ( - MCPIntegrator, - _is_vscode_available, - ) + from apm_cli.integration.mcp_integrator import MCPIntegrator if runtime: # Single runtime mode - skip auto-discovery entirely. @@ -224,155 +412,11 @@ def _resolve_target_runtimes( except Exception: apm_config = None - # Step 1: Get all installed runtimes on the system - try: - from apm_cli.factory import ClientFactory - from apm_cli.runtime.manager import RuntimeManager - - manager = RuntimeManager() - installed_runtimes: list[str] = [] - - for runtime_name in [ - "copilot", - "codex", - "vscode", - "cursor", - "opencode", - "gemini", - "windsurf", - "claude", - "intellij", - ]: - try: - if runtime_name == "vscode": - if _is_vscode_available(project_root=project_root_path): - ClientFactory.create_client(runtime_name) - installed_runtimes.append(runtime_name) - elif runtime_name == "cursor": - # Cursor is opt-in: only target when .cursor/ exists - if (project_root_path / ".cursor").is_dir(): - ClientFactory.create_client(runtime_name) - installed_runtimes.append(runtime_name) - elif runtime_name == "opencode": - # OpenCode is opt-in: only target when .opencode/ exists - if (project_root_path / ".opencode").is_dir(): - ClientFactory.create_client(runtime_name) - 
installed_runtimes.append(runtime_name) - elif runtime_name == "gemini": - # Gemini CLI is opt-in: only target when .gemini/ exists - if (project_root_path / ".gemini").is_dir(): - ClientFactory.create_client(runtime_name) - installed_runtimes.append(runtime_name) - elif runtime_name == "windsurf": - # Windsurf is opt-in: only target when .windsurf/ exists - if (project_root_path / ".windsurf").is_dir(): - ClientFactory.create_client(runtime_name) - installed_runtimes.append(runtime_name) - elif runtime_name == "claude": - # Claude Code is opt-in: target when .claude/ exists - # in the project (project-scope writes) OR when the - # `claude` binary is on PATH (user-scope writes). - # The PATH check is the gate that prevents the - # adapter from writing to ~/.claude.json on hosts - # where Claude Code was never installed. - if (project_root_path / ".claude").is_dir() or ( - find_runtime_binary("claude") is not None - ): - ClientFactory.create_client(runtime_name) - installed_runtimes.append(runtime_name) - elif runtime_name == "intellij": - # JetBrains Copilot is opt-in: target when the - # user-scope config directory already exists. This - # directory is created by the JetBrains Copilot - # plugin on first run, so its presence reliably - # signals that the plugin is installed. 
- from apm_cli.adapters.client.intellij import _intellij_config_dir - - if _intellij_config_dir().is_dir(): - ClientFactory.create_client(runtime_name) - installed_runtimes.append(runtime_name) - else: # noqa: PLR5501 - if manager.is_runtime_available(runtime_name): - ClientFactory.create_client(runtime_name) - installed_runtimes.append(runtime_name) - except (ValueError, ImportError): - continue - except ImportError: - installed_runtimes = [ - rt for rt in ["copilot", "codex"] if find_runtime_binary(rt) is not None - ] - # VS Code: check binary on PATH or .vscode/ directory presence - if _is_vscode_available(project_root=project_root_path): - installed_runtimes.append("vscode") - # Cursor is directory-presence based, not binary-based - if (project_root_path / ".cursor").is_dir(): - installed_runtimes.append("cursor") - # OpenCode is directory-presence based - if (project_root_path / ".opencode").is_dir(): - installed_runtimes.append("opencode") - # Gemini CLI is directory-presence based - if (project_root_path / ".gemini").is_dir(): - installed_runtimes.append("gemini") - # Windsurf is directory-presence based - if (project_root_path / ".windsurf").is_dir(): - installed_runtimes.append("windsurf") - # Claude Code: directory-presence OR binary-on-PATH - if (project_root_path / ".claude").is_dir() or ( - find_runtime_binary("claude") is not None - ): - installed_runtimes.append("claude") - # JetBrains Copilot: user-scope config directory presence - try: - from apm_cli.adapters.client.intellij import _intellij_config_dir - - if _intellij_config_dir().is_dir(): - installed_runtimes.append("intellij") - except (ImportError, ValueError): - # ValueError (PathTraversalError) when LOCALAPPDATA/XDG_DATA_HOME - # is misconfigured -- treat as "not installed" rather than crash. 
- pass - - # Step 2: Get runtimes referenced in apm.yml scripts - script_runtimes = MCPIntegrator._detect_runtimes( - apm_config.get("scripts", {}) if apm_config else {} + installed_runtimes = _detect_installed_runtimes(project_root_path) + target_runtimes = _intersect_script_runtimes( + installed_runtimes, apm_config, verbose, logger, console ) - # Step 3: Target runtimes BOTH installed AND referenced in scripts - if script_runtimes: - target_runtimes = [rt for rt in installed_runtimes if rt in script_runtimes] - - if verbose: - if console: - console.print(f"| [cyan]{STATUS_SYMBOLS['info']} Runtime Detection[/cyan]") - console.print(f"| +- Installed: {', '.join(installed_runtimes)}") - console.print(f"| +- Used in scripts: {', '.join(script_runtimes)}") - if target_runtimes: - console.print( - f"| +- Target: {', '.join(target_runtimes)} " - f"(available + used in scripts)" - ) - console.print("|") - else: - logger.verbose_detail(f"Installed runtimes: {', '.join(installed_runtimes)}") - logger.verbose_detail(f"Script runtimes: {', '.join(script_runtimes)}") - if target_runtimes: - logger.verbose_detail(f"Target runtimes: {', '.join(target_runtimes)}") - - if not target_runtimes: - logger.warning("Scripts reference runtimes that are not installed") - logger.progress("Install missing runtimes with: apm runtime setup ") - else: - target_runtimes = installed_runtimes - if target_runtimes: - if verbose: - logger.verbose_detail( - f"No scripts detected, using all installed runtimes: " - f"{', '.join(target_runtimes)}" - ) - else: - logger.warning("No MCP-compatible runtimes installed") - logger.progress("Install a runtime with: apm runtime setup copilot") - # Surface auto-detected runtimes in non-verbose plain-logger mode so # users get a signal about what `apm install --mcp` is targeting -- # notably the machine-scoped JetBrains (intellij) runtime, which is @@ -413,36 +457,7 @@ def _resolve_target_runtimes( # Scope filtering: at USER scope, keep only global-capable 
runtimes. # Applied after both explicit --runtime and auto-discovery paths. - from apm_cli.core.scope import InstallScope as _IS - - if scope is _IS.USER: - from apm_cli.factory import ClientFactory as _CF - - pre_filter = list(target_runtimes) - filtered_runtimes = [] - for rt in target_runtimes: - try: - client = _CF.create_client(rt) - except ValueError: - continue - if client.supports_user_scope: - filtered_runtimes.append(rt) - target_runtimes = filtered_runtimes - skipped = set(pre_filter) - set(target_runtimes) - if skipped: - msg = ( - f"Skipped workspace-only runtimes at user scope: " - f"{', '.join(sorted(skipped))}" - f" -- omit --global to install these" - ) - logger.warning(msg) - if not target_runtimes: - logger.warning( - "No runtimes support user-scope MCP installation (supported: copilot, codex, gemini)" - ) - return None - - return target_runtimes + return _apply_user_scope_filter(target_runtimes, scope, logger) def _install_self_defined_deps( @@ -728,9 +743,11 @@ def run_mcp_install( operations = MCPServerOperations(registry_url=group_registry_url) configured_count += _install_registry_group( operations=operations, - group_dep_names=group_dep_names, - group_dep_map=group_dep_map, - group_deps=group_deps_list, + group=_RegistryDepGroup( + deps=group_deps_list, + names=group_dep_names, + dep_map=group_dep_map, + ), target_runtimes=target_runtimes, stored_mcp_configs=stored_mcp_configs, servers_to_update=servers_to_update, From d8f0bc5c17b8b77495279be85ee17a7baf8e76a2 Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 04:08:44 +0200 Subject: [PATCH 13/21] refactor(integration): split mcp_integrator.py under 800-line budget Extract the JSON/TOML/Claude config-cleanup trio into a new mcp_config_clean.py, and collect_transitive / _install_for_runtime / _gate_project_scoped_runtimes bodies into a new mcp_runtime_ops.py. MCPIntegrator keeps thin delegating staticmethods so the heavily monkeypatched MCPIntegrator. 
call/patch surface is unchanged; patched module globals (_rich_success, Path) are routed back through mcp_integrator from the siblings so the patch targets stay honored. mcp_integrator.py 1150 -> 765 lines (file-length offender cleared). File-length backlog 30 -> 29. Issue: #1078 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/integration/mcp_config_clean.py | 160 +++++++ src/apm_cli/integration/mcp_integrator.py | 465 ++------------------ src/apm_cli/integration/mcp_runtime_ops.py | 335 ++++++++++++++ 3 files changed, 535 insertions(+), 425 deletions(-) create mode 100644 src/apm_cli/integration/mcp_config_clean.py create mode 100644 src/apm_cli/integration/mcp_runtime_ops.py diff --git a/src/apm_cli/integration/mcp_config_clean.py b/src/apm_cli/integration/mcp_config_clean.py new file mode 100644 index 000000000..331076c4b --- /dev/null +++ b/src/apm_cli/integration/mcp_config_clean.py @@ -0,0 +1,160 @@ +"""JSON / TOML / Claude MCP config cleanup helpers. + +Extracted from :mod:`apm_cli.integration.mcp_integrator` to keep that module +under the file-length budget. Removal notices route ``_rich_success`` back +through ``mcp_integrator`` so the module-level patch target +``apm_cli.integration.mcp_integrator._rich_success`` stays honored. +""" + +import builtins +import json +import logging +from pathlib import Path + +_log = logging.getLogger(__name__) + + +def _emit_rich_success(msg: str) -> None: + """Emit a rich success notice via the (patchable) mcp_integrator helper.""" + from apm_cli.integration import mcp_integrator as _mi + + _mi._rich_success(msg, symbol="check") + + +def _clean_json_mcp_config( + config_path: Path, + stale_names: builtins.set, + logger, + label: str, + servers_key: str = "mcpServers", + trailing_newline: bool = False, + use_rich: bool = False, +) -> int: + """Remove stale entries from a JSON-based MCP config file. + + Args: + config_path: Path to the JSON config file. 
+ stale_names: Set of server names to remove (expanded form). + logger: Command logger for progress messages. + label: Human-readable config label used in log messages. + servers_key: Key under which MCP servers are stored (default: ``"mcpServers"``). + trailing_newline: When True, append a trailing newline after JSON serialisation. + use_rich: When True, emit removal notices via ``_rich_success``; otherwise use + ``logger.progress``. + + Returns: + Number of entries removed. + """ + if not config_path.exists(): + return 0 + try: + config = json.loads(config_path.read_text(encoding="utf-8")) + servers = config.get(servers_key, {}) + removed = [n for n in stale_names if n in servers] + for name in removed: + del servers[name] + if removed: + text = json.dumps(config, indent=2) + if trailing_newline: + text += "\n" + config_path.write_text(text, encoding="utf-8") + for name in removed: + msg = f"Removed stale MCP server '{name}' from {label}" + if use_rich: + _emit_rich_success(msg) + else: + logger.progress(msg) + return len(removed) + except Exception: + _log.debug("Failed to clean stale MCP servers from %s", label, exc_info=True) + return 0 + + +def _clean_toml_mcp_config( + config_path: Path, + stale_names: builtins.set, + label: str, + logger=None, + use_rich: bool = True, +) -> int: + """Remove stale entries from a TOML-based MCP config file. + + Args: + config_path: Path to the TOML config file. + stale_names: Set of server names to remove (expanded form). + label: Human-readable config label used in log messages. + logger: Optional command logger for progress messages. When provided + and *use_rich* is False, removal notices use ``logger.progress``. + use_rich: When True (default), emit removal notices via ``_rich_success``; + otherwise use ``logger.progress``. + + Returns: + Number of entries removed. 
+ """ + if not config_path.exists(): + return 0 + try: + import toml as _toml + + config = _toml.loads(config_path.read_text(encoding="utf-8")) + servers = config.get("mcp_servers", {}) + removed = [n for n in stale_names if n in servers] + for name in removed: + del servers[name] + if removed: + config_path.write_text(_toml.dumps(config), encoding="utf-8") + for name in removed: + msg = f"Removed stale MCP server '{name}' from {label}" + if use_rich: + _emit_rich_success(msg) + elif logger is not None: + logger.progress(msg) + return len(removed) + except Exception: + _log.debug("Failed to clean stale MCP servers from %s", label, exc_info=True) + return 0 + + +def _clean_claude_config( + config_path: Path, + stale_names: builtins.set, + logger, + is_user_scope: bool = False, +) -> int: + """Remove stale entries from a Claude Code JSON config file. + + Handles both the project-level ``.mcp.json`` and the user-level + ``~/.claude.json``, which share the same JSON structure but differ in + scope-validation requirements and log labels. + + Args: + config_path: Path to the Claude JSON config file. + stale_names: Set of server names to remove (expanded form). + logger: Command logger for progress messages. + is_user_scope: When True, validates that the top-level config is a dict + (``~/.claude.json`` guard) and uses the user-scope log label. + + Returns: + Number of entries removed. 
+ """ + label = "~/.claude.json" if is_user_scope else ".mcp.json" + if not config_path.exists(): + return 0 + try: + config = json.loads(config_path.read_text(encoding="utf-8")) + if is_user_scope and not isinstance(config, dict): + return 0 + servers = config.get("mcpServers", {}) + if not isinstance(servers, dict): + servers = {} + removed = [n for n in stale_names if n in servers] + for name in removed: + del servers[name] + if removed: + config_path.write_text(json.dumps(config, indent=2) + "\n", encoding="utf-8") + for name in removed: + logger.progress(f"Removed stale MCP server '{name}' from {label}") + return len(removed) + except Exception: + _log.debug("Failed to clean stale MCP servers from %s", label, exc_info=True) + return 0 diff --git a/src/apm_cli/integration/mcp_integrator.py b/src/apm_cli/integration/mcp_integrator.py index 162fb9e07..d37ebe1eb 100644 --- a/src/apm_cli/integration/mcp_integrator.py +++ b/src/apm_cli/integration/mcp_integrator.py @@ -11,7 +11,6 @@ import builtins import copy -import json import logging import re import shutil @@ -21,12 +20,19 @@ from apm_cli.core.null_logger import NullCommandLogger from apm_cli.deps.lockfile import LockFile, get_lockfile_path -from apm_cli.integration._shared import deduplicate_deps, resolve_locked_apm_yml_paths +from apm_cli.integration._shared import deduplicate_deps +from apm_cli.integration.mcp_config_clean import ( + _clean_claude_config as _clean_claude_config, +) +from apm_cli.integration.mcp_config_clean import ( + _clean_json_mcp_config as _clean_json_mcp_config, +) +from apm_cli.integration.mcp_config_clean import ( + _clean_toml_mcp_config as _clean_toml_mcp_config, +) from apm_cli.runtime.utils import find_runtime_binary from apm_cli.utils.console import ( _get_console, # noqa: F401 -- re-exported; mcp_integrator_install imports this via lazy import - _rich_error, - _rich_info, _rich_success, ) @@ -50,145 +56,6 @@ def _is_vscode_available(project_root: Path | str | None = None) -> 
bool: return shutil.which("code") is not None or (root / ".vscode").is_dir() -def _clean_json_mcp_config( - config_path: Path, - stale_names: builtins.set, - logger, - label: str, - servers_key: str = "mcpServers", - trailing_newline: bool = False, - use_rich: bool = False, -) -> int: - """Remove stale entries from a JSON-based MCP config file. - - Args: - config_path: Path to the JSON config file. - stale_names: Set of server names to remove (expanded form). - logger: Command logger for progress messages. - label: Human-readable config label used in log messages. - servers_key: Key under which MCP servers are stored (default: ``"mcpServers"``). - trailing_newline: When True, append a trailing newline after JSON serialisation. - use_rich: When True, emit removal notices via ``_rich_success``; otherwise use - ``logger.progress``. - - Returns: - Number of entries removed. - """ - if not config_path.exists(): - return 0 - try: - config = json.loads(config_path.read_text(encoding="utf-8")) - servers = config.get(servers_key, {}) - removed = [n for n in stale_names if n in servers] - for name in removed: - del servers[name] - if removed: - text = json.dumps(config, indent=2) - if trailing_newline: - text += "\n" - config_path.write_text(text, encoding="utf-8") - for name in removed: - msg = f"Removed stale MCP server '{name}' from {label}" - if use_rich: - _rich_success(msg, symbol="check") - else: - logger.progress(msg) - return len(removed) - except Exception: - _log.debug("Failed to clean stale MCP servers from %s", label, exc_info=True) - return 0 - - -def _clean_toml_mcp_config( - config_path: Path, - stale_names: builtins.set, - label: str, - logger=None, - use_rich: bool = True, -) -> int: - """Remove stale entries from a TOML-based MCP config file. - - Args: - config_path: Path to the TOML config file. - stale_names: Set of server names to remove (expanded form). - label: Human-readable config label used in log messages. 
- logger: Optional command logger for progress messages. When provided - and *use_rich* is False, removal notices use ``logger.progress``. - use_rich: When True (default), emit removal notices via ``_rich_success``; - otherwise use ``logger.progress``. - - Returns: - Number of entries removed. - """ - if not config_path.exists(): - return 0 - try: - import toml as _toml - - config = _toml.loads(config_path.read_text(encoding="utf-8")) - servers = config.get("mcp_servers", {}) - removed = [n for n in stale_names if n in servers] - for name in removed: - del servers[name] - if removed: - config_path.write_text(_toml.dumps(config), encoding="utf-8") - for name in removed: - msg = f"Removed stale MCP server '{name}' from {label}" - if use_rich: - _rich_success(msg, symbol="check") - elif logger is not None: - logger.progress(msg) - return len(removed) - except Exception: - _log.debug("Failed to clean stale MCP servers from %s", label, exc_info=True) - return 0 - - -def _clean_claude_config( - config_path: Path, - stale_names: builtins.set, - logger, - is_user_scope: bool = False, -) -> int: - """Remove stale entries from a Claude Code JSON config file. - - Handles both the project-level ``.mcp.json`` and the user-level - ``~/.claude.json``, which share the same JSON structure but differ in - scope-validation requirements and log labels. - - Args: - config_path: Path to the Claude JSON config file. - stale_names: Set of server names to remove (expanded form). - logger: Command logger for progress messages. - is_user_scope: When True, validates that the top-level config is a dict - (``~/.claude.json`` guard) and uses the user-scope log label. - - Returns: - Number of entries removed. 
- """ - label = "~/.claude.json" if is_user_scope else ".mcp.json" - if not config_path.exists(): - return 0 - try: - config = json.loads(config_path.read_text(encoding="utf-8")) - if is_user_scope and not isinstance(config, dict): - return 0 - servers = config.get("mcpServers", {}) - if not isinstance(servers, dict): - servers = {} - removed = [n for n in stale_names if n in servers] - for name in removed: - del servers[name] - if removed: - config_path.write_text(json.dumps(config, indent=2) + "\n", encoding="utf-8") - for name in removed: - logger.progress(f"Removed stale MCP server '{name}' from {label}") - return len(removed) - except Exception: - _log.debug("Failed to clean stale MCP servers from %s", label, exc_info=True) - return 0 - - class MCPIntegrator: """MCP lifecycle orchestrator -- dependency resolution, installation, and cleanup. @@ -209,66 +76,16 @@ def collect_transitive( logger=None, diagnostics=None, ) -> list: - """Collect MCP dependencies from resolved APM packages listed in apm.lock. - - Only scans apm.yml files for packages present in apm.lock to avoid - picking up stale/orphaned packages from previous installs. - Falls back to scanning all apm.yml files if no lock file is available. - - Self-defined servers (registry: false) from direct dependencies - (depth == 1) are auto-trusted. Self-defined servers from transitive - dependencies (depth > 1) are skipped with a warning unless - *trust_private* is True. 
- """ - if logger is None: - logger = NullCommandLogger() - if not apm_modules_dir.exists(): - return [] + """Collect MCP deps from resolved packages (see mcp_runtime_ops.collect_transitive).""" + from apm_cli.integration import mcp_runtime_ops - from apm_cli.models.apm_package import APMPackage - - # Build set of expected apm.yml paths from apm.lock - resolved, direct_paths = resolve_locked_apm_yml_paths(apm_modules_dir, lock_path) - apm_yml_paths = resolved if resolved is not None else apm_modules_dir.rglob("apm.yml") - - collected = [] - for apm_yml_path in apm_yml_paths: - try: - pkg = APMPackage.from_apm_yml(apm_yml_path) - mcp = pkg.get_mcp_dependencies() - if mcp: - is_direct = apm_yml_path.resolve() in direct_paths - for dep in mcp: - if hasattr(dep, "is_self_defined") and dep.is_self_defined: - if is_direct: - logger.progress( - f"Trusting direct dependency MCP '{dep.name}' from '{pkg.name}'" - ) - elif trust_private: - logger.progress( - f"Trusting self-defined MCP server '{dep.name}' " - f"from transitive package '{pkg.name}' (--trust-transitive-mcp)" - ) - else: - _trust_msg = ( - f"Transitive package '{pkg.name}' declares self-defined " - f"MCP server '{dep.name}' (registry: false). " - f"Re-declare it in your apm.yml or use --trust-transitive-mcp." - ) - if diagnostics: - diagnostics.warn(_trust_msg) - else: - logger.warning(_trust_msg) - continue - collected.append(dep) - except Exception: - _log.debug( - "Skipping package at %s: failed to parse apm.yml", - apm_yml_path, - exc_info=True, - ) - continue - return collected + return mcp_runtime_ops.collect_transitive( + apm_modules_dir, + lock_path=lock_path, + trust_private=trust_private, + logger=logger, + diagnostics=diagnostics, + ) # ------------------------------------------------------------------ # Deduplication @@ -854,78 +671,19 @@ def _install_for_runtime( user_scope: bool = False, logger=None, ) -> bool: - """Install MCP dependencies for a specific runtime. 
+ """Install MCP deps for a runtime (see mcp_runtime_ops.install_for_runtime).""" + from apm_cli.integration import mcp_runtime_ops - Returns True if all deps were configured successfully, False otherwise. - """ - if logger is None: - logger = NullCommandLogger() - try: - from apm_cli.core.operations import install_package - - all_ok = True - for dep in mcp_deps: - logger.verbose_detail(f" Installing {dep}...") - try: - result = install_package( - runtime, - dep, - shared_env_vars=shared_env_vars, - server_info_cache=server_info_cache, - shared_runtime_vars=shared_runtime_vars, - project_root=project_root, - user_scope=user_scope, - ) - if result["failed"]: - logger.error(f" Failed to install {dep}") - all_ok = False - elif logger and runtime == "codex": - from apm_cli.factory import ClientFactory - - config_path = ClientFactory.create_client( - runtime, - project_root=project_root, - user_scope=user_scope, - ).get_config_path() - _log.debug("Codex config written to %s", config_path) - logger.verbose_detail(f" Codex config: {config_path}") - except Exception as install_error: - _log.debug( - "Failed to install MCP dep %s for runtime %s", - dep, - runtime, - exc_info=True, - ) - logger.error(f" Failed to install {dep}: {install_error}") - all_ok = False - - # Emit aggregated post-install diagnostics for runtimes that - # support runtime env-var substitution (currently Copilot CLI). - # Safe no-op for runtimes whose adapter doesn't aggregate state. 
- try: - if runtime == "copilot": - from apm_cli.adapters.client.copilot import CopilotClientAdapter - - CopilotClientAdapter.emit_install_run_summary() - except Exception: - _log.debug("Failed to emit install-run summary", exc_info=True) - - return all_ok - - except ImportError as e: - logger.warning(f"Core operations not available for runtime {runtime}: {e}") - logger.progress(f"Dependencies for {runtime}: {', '.join(mcp_deps)}") - return False - except ValueError as e: - logger.warning(f"Runtime {runtime} not supported: {e}") - logger.progress( - "Supported runtimes: vscode, copilot, codex, cursor, opencode, gemini, claude, windsurf, intellij, llm" - ) - return False - except Exception as e: - _log.debug("Unexpected error installing for runtime %s", runtime, exc_info=True) - logger.error(f"Error installing for runtime {runtime}: {e}") - return False + return mcp_runtime_ops.install_for_runtime( + runtime, + mcp_deps, + shared_env_vars=shared_env_vars, + server_info_cache=server_info_cache, + shared_runtime_vars=shared_runtime_vars, + project_root=project_root, + user_scope=user_scope, + logger=logger, + ) # ------------------------------------------------------------------ # Main orchestrator @@ -940,159 +698,16 @@ def _gate_project_scoped_runtimes( apm_config: dict | None, explicit_target: str | list[str] | None, ) -> list[str]: - """Filter *target_runtimes* against the project's active targets. - - UX parity with ``apm install`` for apm dependencies: the active - target set (explicit ``--target`` > ``targets:`` field > - directory-signal detection) is the whitelist for MCP writes. Any - runtime outside that set is skipped with an info line naming both - what was dropped and the active set, so users can audit the - decision input without re-reading apm.yml (#1335). 
- - Strict resolution model -- mirrors :func:`resolve_targets`, - the same call ``apm install`` uses - (``install/phases/targets.py:233``): - - - flag > yaml-targets > directory signals (no permissive - "fallback to copilot" greenfield default); - - no flag, no ``targets:``, and no harness-signal directory -> - :class:`NoHarnessError` (red ``[x]``, write nothing); - - multiple ambiguous signals with no disambiguation -> - :class:`AmbiguousHarnessError` (same fail-closed shape). - - ``explicit_target`` accepts ``str``, ``list[str]``, or a CSV - string (``"claude,copilot"``) -- the latter is produced by - legacy callers; it is normalized to a list before the resolver - is invoked so the canonical-name validator does not reject it as - one unknown token. - - A malformed ``targets:`` field (conflicting ``target:`` + - ``targets:``, ``targets: []``, or unknown canonical name) likewise - fails closed: nothing is written. - - Exit semantics differ deliberately from ``install/phases/targets.py``: - the canonical install phase calls ``raise SystemExit(2)`` when - resolution fails; this gate may be invoked mid-bundle (see - ``install/local_bundle_handler``) where a hard exit would corrupt - partial state, so we render the same red ``[x]`` voice and return - an empty list (fail-closed-continue). - - ``user_scope=True`` is a deliberate carve-out: user-scope writes - target ``~/.config`` paths the user owns globally, so the - project-level whitelist is irrelevant. Documented in the - consumer install-mcp-servers guide. 
- """ - if user_scope: - return target_runtimes - - from apm_cli.core.apm_yml import ( - ConflictingTargetsError, - EmptyTargetsListError, - UnknownTargetError, - parse_targets_field, - ) - from apm_cli.core.errors import ( - AmbiguousHarnessError, - NoHarnessError, - ) - from apm_cli.core.target_detection import resolve_targets - from apm_cli.integration.targets import RUNTIME_TO_CANONICAL_TARGET + """Filter runtimes against active project targets (see mcp_runtime_ops).""" + from apm_cli.integration import mcp_runtime_ops - # --- step 1: parse declared targets (fail-closed on any invalid form) - yaml_targets: list[str] | None = None - if apm_config: - try: - parsed = parse_targets_field(apm_config) - yaml_targets = parsed if parsed else None - except ( - ConflictingTargetsError, - EmptyTargetsListError, - UnknownTargetError, - ) as exc: - # Voice mirrors the canonical `apm install` skills phase - # (install/phases/targets.py:213): red [x] lead-with-outcome, - # then the structured error body. symbol="" suppresses the - # auto-prefix on the body because the exception text already - # begins with "[x] ..." (see core/errors.py). - _rich_error( - "Skipping all MCP config writes -- apm.yml 'targets' field is invalid.", - symbol="error", - ) - _rich_error(str(exc), symbol="") - _log.debug( - "parse_targets_field failed; failing closed (no MCP writes)", - exc_info=True, - ) - return [] - - # --- step 2: normalize CSV explicit_target sugar to a list ----- - # `_wire_bundle_mcp_servers` historically passes a CSV string; the - # canonical-name validator inside _resolve_targets_v2 would reject - # the whole CSV as one unknown token. Normalize first. - flag: str | list[str] | None - if isinstance(explicit_target, str) and "," in explicit_target: - flag = [t.strip() for t in explicit_target.split(",") if t.strip()] - else: - flag = explicit_target - - # Apply the runtime->canonical-target alias BEFORE passing the flag - # to resolve_targets. 
The canonical-name validator inside the - # resolver only knows about CANONICAL_TARGETS (claude/copilot/...); - # it rejects runtime aliases (vscode/agents) as unknown tokens. - # The MCP gate, however, must accept those aliases because users - # naturally type `--target vscode` for the VS Code Copilot runtime. - if flag is not None: - tokens = [flag] if isinstance(flag, str) else list(flag) - flag = [RUNTIME_TO_CANONICAL_TARGET.get(t, t) for t in tokens] - - # --- step 3: delegate to the canonical v2 resolver ------------- - # This is the same call the `apm install` skills phase makes at - # install/phases/targets.py:233. It enforces the strict - # flag > yaml > signals chain and raises NoHarnessError / - # AmbiguousHarnessError on greenfield / under-disambiguated - # projects -- the ASYMMETRY closed by this PR is that the gate - # used to silently fall back to [copilot] in those cases. - root = project_root or Path.cwd() - try: - resolved = resolve_targets(root, flag=flag, yaml_targets=yaml_targets) - except (NoHarnessError, AmbiguousHarnessError) as exc: - _rich_error( - "Skipping all MCP config writes -- could not resolve active targets.", - symbol="error", - ) - _rich_error(str(exc), symbol="") - _log.debug( - "resolve_targets failed; failing closed (no MCP writes)", - exc_info=True, - ) - return [] - - active = set(resolved.targets) - - # Runtime name "vscode" maps to canonical target "copilot" (same - # alias active_targets honors); shared table prevents drift with - # the alias resolution in integration/targets.py. - out = [rt for rt in target_runtimes if RUNTIME_TO_CANONICAL_TARGET.get(rt, rt) in active] - dropped = sorted(set(target_runtimes) - set(out)) - if dropped: - # Mirror the canonical `Targets: X (source: Y)` provenance shape - # (install/phases/targets.py:265, core/target_detection.py:777): - # double-space before the parenthetical. 
The "or ''" guard is - # defensive -- an empty active set is unreachable when - # _resolve_targets_v2 succeeded, but if a future contract change - # widens that contract we surface "" rather than render - # "(active targets: )" which reads as a renderer bug. - active_csv = ", ".join(sorted(active)) or "" - _rich_info( - f"Skipped MCP config for {', '.join(dropped)} (active targets: {active_csv})", - symbol="info", - ) - _log.debug( - "Active-targets gate dropped: %s (active=%s)", - dropped, - sorted(active), - ) - return out + return mcp_runtime_ops.gate_project_scoped_runtimes( + target_runtimes, + user_scope=user_scope, + project_root=project_root, + apm_config=apm_config, + explicit_target=explicit_target, + ) @staticmethod def install( diff --git a/src/apm_cli/integration/mcp_runtime_ops.py b/src/apm_cli/integration/mcp_runtime_ops.py new file mode 100644 index 000000000..6aac3eac8 --- /dev/null +++ b/src/apm_cli/integration/mcp_runtime_ops.py @@ -0,0 +1,335 @@ +"""Runtime-facing MCP operations extracted from :class:`MCPIntegrator`. + +``collect_transitive`` and ``install_for_runtime`` reference no module global +that is monkeypatched on ``mcp_integrator``; ``gate_project_scoped_runtimes`` +routes its single ``Path.cwd()`` through ``mcp_integrator`` so the patched +``Path`` is honored. ``MCPIntegrator`` keeps thin delegating staticmethods so +the heavily-patched ``MCPIntegrator.`` call/patch surface is unchanged. +""" + +import logging +from pathlib import Path + +from apm_cli.core.null_logger import NullCommandLogger +from apm_cli.integration._shared import resolve_locked_apm_yml_paths +from apm_cli.utils.console import _rich_error, _rich_info + +_log = logging.getLogger(__name__) + + +def collect_transitive( + apm_modules_dir: Path, + lock_path: Path | None = None, + trust_private: bool = False, + logger=None, + diagnostics=None, +) -> list: + """Collect MCP dependencies from resolved APM packages listed in apm.lock. 
+ + Only scans apm.yml files for packages present in apm.lock to avoid + picking up stale/orphaned packages from previous installs. + Falls back to scanning all apm.yml files if no lock file is available. + + Self-defined servers (registry: false) from direct dependencies + (depth == 1) are auto-trusted. Self-defined servers from transitive + dependencies (depth > 1) are skipped with a warning unless + *trust_private* is True. + """ + if logger is None: + logger = NullCommandLogger() + if not apm_modules_dir.exists(): + return [] + + from apm_cli.models.apm_package import APMPackage + + # Build set of expected apm.yml paths from apm.lock + resolved, direct_paths = resolve_locked_apm_yml_paths(apm_modules_dir, lock_path) + apm_yml_paths = resolved if resolved is not None else apm_modules_dir.rglob("apm.yml") + + collected = [] + for apm_yml_path in apm_yml_paths: + try: + pkg = APMPackage.from_apm_yml(apm_yml_path) + mcp = pkg.get_mcp_dependencies() + if mcp: + is_direct = apm_yml_path.resolve() in direct_paths + for dep in mcp: + if hasattr(dep, "is_self_defined") and dep.is_self_defined: + if is_direct: + logger.progress( + f"Trusting direct dependency MCP '{dep.name}' from '{pkg.name}'" + ) + elif trust_private: + logger.progress( + f"Trusting self-defined MCP server '{dep.name}' " + f"from transitive package '{pkg.name}' (--trust-transitive-mcp)" + ) + else: + _trust_msg = ( + f"Transitive package '{pkg.name}' declares self-defined " + f"MCP server '{dep.name}' (registry: false). " + f"Re-declare it in your apm.yml or use --trust-transitive-mcp." 
+ ) + if diagnostics: + diagnostics.warn(_trust_msg) + else: + logger.warning(_trust_msg) + continue + collected.append(dep) + except Exception: + _log.debug( + "Skipping package at %s: failed to parse apm.yml", + apm_yml_path, + exc_info=True, + ) + continue + return collected + + +def install_for_runtime( + runtime: str, + mcp_deps: list[str], + shared_env_vars: dict = None, # noqa: RUF013 + server_info_cache: dict = None, # noqa: RUF013 + shared_runtime_vars: dict = None, # noqa: RUF013 + project_root=None, + user_scope: bool = False, + logger=None, +) -> bool: + """Install MCP dependencies for a specific runtime. + + Returns True if all deps were configured successfully, False otherwise. + """ + if logger is None: + logger = NullCommandLogger() + try: + from apm_cli.core.operations import install_package + + all_ok = True + for dep in mcp_deps: + logger.verbose_detail(f" Installing {dep}...") + try: + result = install_package( + runtime, + dep, + shared_env_vars=shared_env_vars, + server_info_cache=server_info_cache, + shared_runtime_vars=shared_runtime_vars, + project_root=project_root, + user_scope=user_scope, + ) + if result["failed"]: + logger.error(f" Failed to install {dep}") + all_ok = False + elif logger and runtime == "codex": + from apm_cli.factory import ClientFactory + + config_path = ClientFactory.create_client( + runtime, + project_root=project_root, + user_scope=user_scope, + ).get_config_path() + _log.debug("Codex config written to %s", config_path) + logger.verbose_detail(f" Codex config: {config_path}") + except Exception as install_error: + _log.debug( + "Failed to install MCP dep %s for runtime %s", + dep, + runtime, + exc_info=True, + ) + logger.error(f" Failed to install {dep}: {install_error}") + all_ok = False + + # Emit aggregated post-install diagnostics for runtimes that + # support runtime env-var substitution (currently Copilot CLI). + # Safe no-op for runtimes whose adapter doesn't aggregate state. 
+ try: + if runtime == "copilot": + from apm_cli.adapters.client.copilot import CopilotClientAdapter + + CopilotClientAdapter.emit_install_run_summary() + except Exception: + _log.debug("Failed to emit install-run summary", exc_info=True) + + return all_ok + + except ImportError as e: + logger.warning(f"Core operations not available for runtime {runtime}: {e}") + logger.progress(f"Dependencies for {runtime}: {', '.join(mcp_deps)}") + return False + except ValueError as e: + logger.warning(f"Runtime {runtime} not supported: {e}") + logger.progress( + "Supported runtimes: vscode, copilot, codex, cursor, opencode, gemini, claude, windsurf, intellij, llm" + ) + return False + except Exception as e: + _log.debug("Unexpected error installing for runtime %s", runtime, exc_info=True) + logger.error(f"Error installing for runtime {runtime}: {e}") + return False + + +def gate_project_scoped_runtimes( + target_runtimes: list[str], + *, + user_scope: bool, + project_root, + apm_config: dict | None, + explicit_target: str | list[str] | None, +) -> list[str]: + """Filter *target_runtimes* against the project's active targets. + + UX parity with ``apm install`` for apm dependencies: the active + target set (explicit ``--target`` > ``targets:`` field > + directory-signal detection) is the whitelist for MCP writes. Any + runtime outside that set is skipped with an info line naming both + what was dropped and the active set, so users can audit the + decision input without re-reading apm.yml (#1335). + + Strict resolution model -- mirrors :func:`resolve_targets`, + the same call ``apm install`` uses + (``install/phases/targets.py:233``): + + - flag > yaml-targets > directory signals (no permissive + "fallback to copilot" greenfield default); + - no flag, no ``targets:``, and no harness-signal directory -> + :class:`NoHarnessError` (red ``[x]``, write nothing); + - multiple ambiguous signals with no disambiguation -> + :class:`AmbiguousHarnessError` (same fail-closed shape). 
+ + ``explicit_target`` accepts ``str``, ``list[str]``, or a CSV + string (``"claude,copilot"``) -- the latter is produced by + legacy callers; it is normalized to a list before the resolver + is invoked so the canonical-name validator does not reject it as + one unknown token. + + A malformed ``targets:`` field (conflicting ``target:`` + + ``targets:``, ``targets: []``, or unknown canonical name) likewise + fails closed: nothing is written. + + Exit semantics differ deliberately from ``install/phases/targets.py``: + the canonical install phase calls ``raise SystemExit(2)`` when + resolution fails; this gate may be invoked mid-bundle (see + ``install/local_bundle_handler``) where a hard exit would corrupt + partial state, so we render the same red ``[x]`` voice and return + an empty list (fail-closed-continue). + + ``user_scope=True`` is a deliberate carve-out: user-scope writes + target ``~/.config`` paths the user owns globally, so the + project-level whitelist is irrelevant. Documented in the + consumer install-mcp-servers guide. + """ + if user_scope: + return target_runtimes + + from apm_cli.core.apm_yml import ( + ConflictingTargetsError, + EmptyTargetsListError, + UnknownTargetError, + parse_targets_field, + ) + from apm_cli.core.errors import ( + AmbiguousHarnessError, + NoHarnessError, + ) + from apm_cli.core.target_detection import resolve_targets + from apm_cli.integration.targets import RUNTIME_TO_CANONICAL_TARGET + + # --- step 1: parse declared targets (fail-closed on any invalid form) + yaml_targets: list[str] | None = None + if apm_config: + try: + parsed = parse_targets_field(apm_config) + yaml_targets = parsed if parsed else None + except ( + ConflictingTargetsError, + EmptyTargetsListError, + UnknownTargetError, + ) as exc: + # Voice mirrors the canonical `apm install` skills phase + # (install/phases/targets.py:213): red [x] lead-with-outcome, + # then the structured error body. 
symbol="" suppresses the + # auto-prefix on the body because the exception text already + # begins with "[x] ..." (see core/errors.py). + _rich_error( + "Skipping all MCP config writes -- apm.yml 'targets' field is invalid.", + symbol="error", + ) + _rich_error(str(exc), symbol="") + _log.debug( + "parse_targets_field failed; failing closed (no MCP writes)", + exc_info=True, + ) + return [] + + # --- step 2: normalize CSV explicit_target sugar to a list ----- + # `_wire_bundle_mcp_servers` historically passes a CSV string; the + # canonical-name validator inside _resolve_targets_v2 would reject + # the whole CSV as one unknown token. Normalize first. + flag: str | list[str] | None + if isinstance(explicit_target, str) and "," in explicit_target: + flag = [t.strip() for t in explicit_target.split(",") if t.strip()] + else: + flag = explicit_target + + # Apply the runtime->canonical-target alias BEFORE passing the flag + # to resolve_targets. The canonical-name validator inside the + # resolver only knows about CANONICAL_TARGETS (claude/copilot/...); + # it rejects runtime aliases (vscode/agents) as unknown tokens. + # The MCP gate, however, must accept those aliases because users + # naturally type `--target vscode` for the VS Code Copilot runtime. + if flag is not None: + tokens = [flag] if isinstance(flag, str) else list(flag) + flag = [RUNTIME_TO_CANONICAL_TARGET.get(t, t) for t in tokens] + + # --- step 3: delegate to the canonical v2 resolver ------------- + # This is the same call the `apm install` skills phase makes at + # install/phases/targets.py:233. It enforces the strict + # flag > yaml > signals chain and raises NoHarnessError / + # AmbiguousHarnessError on greenfield / under-disambiguated + # projects -- the ASYMMETRY closed by this PR is that the gate + # used to silently fall back to [copilot] in those cases. 
+ from apm_cli.integration import mcp_integrator as _mi + + root = project_root or _mi.Path.cwd() + try: + resolved = resolve_targets(root, flag=flag, yaml_targets=yaml_targets) + except (NoHarnessError, AmbiguousHarnessError) as exc: + _rich_error( + "Skipping all MCP config writes -- could not resolve active targets.", + symbol="error", + ) + _rich_error(str(exc), symbol="") + _log.debug( + "resolve_targets failed; failing closed (no MCP writes)", + exc_info=True, + ) + return [] + + active = set(resolved.targets) + + # Runtime name "vscode" maps to canonical target "copilot" (same + # alias active_targets honors); shared table prevents drift with + # the alias resolution in integration/targets.py. + out = [rt for rt in target_runtimes if RUNTIME_TO_CANONICAL_TARGET.get(rt, rt) in active] + dropped = sorted(set(target_runtimes) - set(out)) + if dropped: + # Mirror the canonical `Targets: X (source: Y)` provenance shape + # (install/phases/targets.py:265, core/target_detection.py:777): + # double-space before the parenthetical. The "or ''" guard is + # defensive -- an empty active set is unreachable when + # _resolve_targets_v2 succeeded, but if a future contract change + # widens that contract we surface "" rather than render + # "(active targets: )" which reads as a renderer bug. 
+ active_csv = ", ".join(sorted(active)) or "" + _rich_info( + f"Skipped MCP config for {', '.join(dropped)} (active targets: {active_csv})", + symbol="info", + ) + _log.debug( + "Active-targets gate dropped: %s (active=%s)", + dropped, + sorted(active), + ) + return out From 9bd613c2fc824341d94d230a205287e3f5d936f5 Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 04:54:03 +0200 Subject: [PATCH 14/21] refactor(deps): split apm_resolver/plugin_parser/bare_cache under 800-line budget Strangler Stage 2 (#1078), Commit 3a: tighten the deps subsystem's lower-risk files toward the 800-line guardrail and the final complexity thresholds, preserving the test monkeypatch surface. File-length offenders cleared (siblings extracted, delegating wrappers keep patch targets intact): - apm_resolver.py 1028 -> 777 (new apm_resolver_helpers.py) - plugin_parser.py 916 -> 632 (new plugin_server_helpers.py) - bare_cache.py 805 -> 724 (new bare_cache_msg.py) Complexity-only (PLR0911) fixes, behaviour-preserving: - git_reference_resolver.py: extract _call_commits_api; merge guard - lockfile.py: combine two equality guards into one return - registry/outdated.py: merge early-exit guards, preserve source strings bare_cache failure-message builder uses a genuine parameter object: build_clone_failure_message now takes a required CloneFailureContext (bundling the six failure-classifier flags) instead of flat kwargs; clone_engine.py (sole caller) constructs it at the call site. No **kwargs threshold-gaming; 13 call args -> 8 params. backlog 29 -> 26; R0801 10.00/10; 16645 unit+acceptance pass; 2776 targeted integration pass. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/deps/apm_resolver.py | 355 ++--------- src/apm_cli/deps/apm_resolver_helpers.py | 374 +++++++++++ src/apm_cli/deps/bare_cache.py | 91 +-- src/apm_cli/deps/bare_cache_msg.py | 125 ++++ src/apm_cli/deps/clone_engine.py | 16 +- src/apm_cli/deps/git_reference_resolver.py | 19 +- src/apm_cli/deps/lockfile.py | 10 +- src/apm_cli/deps/plugin_parser.py | 700 ++++++--------------- src/apm_cli/deps/plugin_server_helpers.py | 358 +++++++++++ src/apm_cli/deps/registry/outdated.py | 33 +- 10 files changed, 1161 insertions(+), 920 deletions(-) create mode 100644 src/apm_cli/deps/apm_resolver_helpers.py create mode 100644 src/apm_cli/deps/bare_cache_msg.py create mode 100644 src/apm_cli/deps/plugin_server_helpers.py diff --git a/src/apm_cli/deps/apm_resolver.py b/src/apm_cli/deps/apm_resolver.py index 81e551077..ad4004c66 100644 --- a/src/apm_cli/deps/apm_resolver.py +++ b/src/apm_cli/deps/apm_resolver.py @@ -1,16 +1,42 @@ """APM dependency resolution engine with recursive resolution and conflict detection.""" -import inspect import logging -import os import threading from collections import deque from concurrent.futures import ThreadPoolExecutor -from dataclasses import replace from pathlib import Path from typing import Optional, Protocol from ..models.apm_package import APMPackage, DependencyReference +from .apm_resolver_helpers import _DEFAULT_RESOLVE_PARALLEL as _DEFAULT_RESOLVE_PARALLEL +from .apm_resolver_helpers import ( + _compute_dep_source_path as _compute_dep_source_path, +) +from .apm_resolver_helpers import ( + _create_resolution_summary as _create_resolution_summary, +) +from .apm_resolver_helpers import ( + _detect_circular_deps, + _expand_parent_repo_decl, + _flatten_dependencies, + _remote_parent_eligible, + _validate_dependency_reference, +) +from .apm_resolver_helpers import ( + _download_dedup_key as _download_dedup_key, +) +from .apm_resolver_helpers import ( + 
_effective_base_dir as _effective_base_dir, +) +from .apm_resolver_helpers import ( + _is_remote_parent as _is_remote_parent, +) +from .apm_resolver_helpers import ( + _resolve_max_parallel as _resolve_max_parallel, +) +from .apm_resolver_helpers import ( + _signature_accepts_parent_pkg as _signature_accepts_parent_pkg, +) from .dependency_graph import ( CircularRef, DependencyGraph, @@ -22,15 +48,6 @@ _logger = logging.getLogger(__name__) -# Default worker pool size for the level-batched BFS download phase. -# Parallel resolution is the CENTRAL execution model (uv-inspired); -# the ``APM_RESOLVE_PARALLEL`` env var exists solely as a diagnostic / -# parity-testing knob (e.g. ``APM_RESOLVE_PARALLEL=1 apm install`` to -# reproduce legacy sequential ordering for diff-debugging). It is NOT -# a user-facing feature toggle. -_DEFAULT_RESOLVE_PARALLEL = 4 - - # Type alias for the download callback. # Takes (dep_ref, apm_modules_dir, parent_chain, parent_pkg) and returns the # install path if successful. ``parent_chain`` is a human-readable breadcrumb @@ -118,51 +135,13 @@ def __init__( @staticmethod def _resolve_max_parallel(explicit: int | None) -> int: - """Compute effective worker count for level-batched parallel BFS. - - Parallel is the default and central execution model. The - override exists for parity testing (``APM_RESOLVE_PARALLEL=1``) - and CI diagnostics, not as a user-facing knob. - - Order of precedence: - 1. Explicit ``max_parallel`` ctor arg. - 2. ``APM_RESOLVE_PARALLEL`` env var (diagnostic/parity knob). - 3. ``_DEFAULT_RESOLVE_PARALLEL``. - - Always coerced to ``>= 1`` so the executor never gets a zero - or negative ``max_workers``. 
- """ - if explicit is not None: - return max(1, int(explicit)) - env = os.environ.get("APM_RESOLVE_PARALLEL", "").strip() - if env: - try: - return max(1, int(env)) - except ValueError: - _logger.debug("Ignoring invalid APM_RESOLVE_PARALLEL=%r", env) - return _DEFAULT_RESOLVE_PARALLEL + """Compute effective worker count; see :func:`_resolve_max_parallel`.""" + return _resolve_max_parallel(explicit) @staticmethod def _signature_accepts_parent_pkg(callback) -> bool: - """Return True if ``callback`` declares a ``parent_pkg`` parameter - (or accepts ``**kwargs``). - - Falls back to False if the signature can't be introspected (e.g. C - extensions, builtins). The conservative fallback is correct: if we - don't know the callback's shape, assume the legacy 3-arg form so - the resolver won't pass an extra positional/keyword that triggers - TypeError and silently drops the dependency (#940 SR1). - """ - try: - sig = inspect.signature(callback) - except (TypeError, ValueError): - return False - for param in sig.parameters.values(): - if param.kind is inspect.Parameter.VAR_KEYWORD: - return True - if param.name == "parent_pkg": - return True - return False + """Return True if callback accepts ``parent_pkg``; see helper.""" + return _signature_accepts_parent_pkg(callback) def resolve_dependencies(self, project_root: Path) -> DependencyGraph: """ @@ -231,54 +210,15 @@ def resolve_dependencies(self, project_root: Path) -> DependencyGraph: def _remote_parent_eligible(self, parent_dep: DependencyReference) -> bool: """Return True if *parent_dep* can serve as the Git repo for ``git: parent`` expansion.""" - if parent_dep.is_azure_devops(): - return bool(parent_dep.ado_repo and parent_dep.repo_url.count("/") >= 2) - return "/" in parent_dep.repo_url + return _remote_parent_eligible(parent_dep) def expand_parent_repo_decl( self, parent_dep: DependencyReference, child_dep: DependencyReference, ) -> DependencyReference: - """Expand ``{ git: parent, path: ... 
}`` using the declaring package's coordinates. - - The child keeps its ``virtual_path`` (monorepo subdirectory), ``alias``, and - optional ``ref`` override; repository identity (host, ``repo_url``, ADO - fields, etc.) is inherited from *parent_dep*. - """ - if not child_dep.is_parent_repo_inheritance: - raise ValueError( - "expand_parent_repo_decl requires child_dep.is_parent_repo_inheritance" - ) - if parent_dep.is_local: - raise ValueError("git: parent cannot inherit from a local path dependency") - if parent_dep.repo_url.startswith("_local/"): - raise ValueError("git: parent cannot inherit from a local path dependency") - if not self._remote_parent_eligible(parent_dep): - raise ValueError("git: parent requires a remote Git parent package dependency") - - merged_ref = ( - child_dep.reference if child_dep.reference is not None else parent_dep.reference - ) - - return replace( - child_dep, - repo_url=parent_dep.repo_url, - host=parent_dep.host, - port=parent_dep.port, - explicit_scheme=parent_dep.explicit_scheme, - ado_organization=parent_dep.ado_organization, - ado_project=parent_dep.ado_project, - ado_repo=parent_dep.ado_repo, - artifactory_prefix=parent_dep.artifactory_prefix, - is_insecure=parent_dep.is_insecure, - allow_insecure=parent_dep.allow_insecure, - reference=merged_ref, - is_virtual=True, - is_parent_repo_inheritance=False, - is_local=False, - local_path=None, - ) + """Expand ``{ git: parent, path: ... }`` using the declaring package's coordinates.""" + return _expand_parent_repo_decl(parent_dep, child_dep) def _resolve_marketplace_dep(self, dep_ref: DependencyReference) -> DependencyReference: """Resolve a marketplace dependency to a concrete DependencyReference. @@ -587,126 +527,16 @@ def build_dependency_tree(self, root_apm_yml: Path) -> DependencyTree: return tree def detect_circular_dependencies(self, tree: DependencyTree) -> list[CircularRef]: - """ - Detect and report circular dependency chains. 
- - Uses depth-first search to detect cycles in the dependency graph. - A cycle is detected when we encounter the same repository URL - in our current traversal path. - - Args: - tree: The dependency tree to analyze - - Returns: - List[CircularRef]: List of detected circular dependencies - """ - circular_deps = [] - visited: set[str] = set() - current_path: list[str] = [] - current_path_set: set[str] = set() # O(1) membership test (#171) - - def dfs_detect_cycles(node: DependencyNode) -> None: - """Recursive DFS function to detect cycles.""" - node_id = node.get_id() - # Use unique key (includes subdirectory path) to distinguish monorepo packages - # e.g., vineethsoma/agent-packages/agents/X vs vineethsoma/agent-packages/skills/Y - unique_key = node.dependency_ref.get_unique_key() - - # Check if this unique key is already in our current path (cycle detected) - if unique_key in current_path_set: - # Found a cycle - create the cycle path - cycle_start_index = current_path.index(unique_key) - cycle_path = current_path[cycle_start_index:] + [unique_key] # noqa: RUF005 - - circular_ref = CircularRef(cycle_path=cycle_path, detected_at_depth=node.depth) - circular_deps.append(circular_ref) - return - - # Mark current node as visited and add unique key to path - visited.add(node_id) - current_path.append(unique_key) - current_path_set.add(unique_key) - - # Check all children - for child in node.children: - child_id = child.get_id() - - # Only recurse if we haven't processed this subtree completely - if ( - child_id not in visited - or child.dependency_ref.get_unique_key() in current_path_set - ): - dfs_detect_cycles(child) - - # Remove from path when backtracking (but keep in visited) - current_path_set.discard(current_path.pop()) - - # Start DFS from all root level dependencies (depth 1) - root_deps = tree.get_nodes_at_depth(1) - for root_dep in root_deps: - if root_dep.get_id() not in visited: - current_path = [] # Reset path for each root - current_path_set = set() - 
dfs_detect_cycles(root_dep) - - return circular_deps + """Detect cycles; see :func:`apm_resolver_helpers._detect_circular_deps`.""" + return _detect_circular_deps(tree) def flatten_dependencies(self, tree: DependencyTree) -> FlatDependencyMap: - """ - Flatten tree to avoid duplicate installations (NPM hoisting). - - Implements "first wins" conflict resolution strategy where the first - declared dependency takes precedence over later conflicting dependencies. - - Args: - tree: The dependency tree to flatten - - Returns: - FlatDependencyMap: Flattened dependencies ready for installation - """ - flat_map = FlatDependencyMap() - seen_keys: set[str] = set() - - # Process dependencies level by level (breadth-first) - # This ensures that dependencies declared earlier in the tree get priority - for depth in range(1, tree.max_depth + 1): - nodes_at_depth = tree.get_nodes_at_depth(depth) - - # Sort nodes by their position in the tree to ensure deterministic ordering - # In a real implementation, this would be based on declaration order - nodes_at_depth.sort(key=lambda node: node.get_id()) - - for node in nodes_at_depth: - unique_key = node.dependency_ref.get_unique_key() - - if unique_key not in seen_keys: - # First occurrence - add without conflict - flat_map.add_dependency(node.dependency_ref, is_conflict=False) - seen_keys.add(unique_key) - else: - # Conflict - record it but keep the first one - flat_map.add_dependency(node.dependency_ref, is_conflict=True) - - return flat_map + """Flatten tree (NPM hoisting); see :func:`apm_resolver_helpers._flatten_dependencies`.""" + return _flatten_dependencies(tree) def _validate_dependency_reference(self, dep_ref: DependencyReference) -> bool: - """ - Validate that a dependency reference is well-formed. 
- - Args: - dep_ref: The dependency reference to validate - - Returns: - bool: True if valid, False otherwise - """ - if not dep_ref.repo_url: - return False - - # Basic validation - in real implementation would be more thorough - if "/" not in dep_ref.repo_url: # noqa: SIM103 - return False - - return True + """Validate that *dep_ref* is well-formed; see helper.""" + return _validate_dependency_reference(dep_ref) def _load_work_item(self, item): """Worker payload for the level-batched parallel BFS. @@ -920,34 +750,8 @@ def _try_load_dependency_package( @staticmethod def _is_remote_parent(parent_pkg: APMPackage | None) -> bool: - """Return True if *parent_pkg* is a REMOTE package (i.e. fetched via - git URL or pinned by ref/path). - - Used to gate ``local_path`` deps: only the root project and other - local packages may legitimately declare them. Remote packages - declaring a local_path is a path-confusion vector. - - SECURITY NOTE: this is a heuristic on the ``source`` field. A - sufficiently adversarial remote could spoof a local-looking source. - The downstream containment check via ``ensure_path_within`` is the - actual security boundary; this gate just produces the user-facing - error early. - """ - if parent_pkg is None or not parent_pkg.source: - return False - src = str(parent_pkg.source) - # Local deps get ``source = "_local/"`` (see DependencyReference - # construction for is_local=True). Treat that prefix as definitively - # local even though it contains a slash. - if src.startswith("_local/"): - return False - # Remote sources look like URLs or owner/repo refs. Local sources - # are filesystem paths the user typed in their apm.yml. 
- return ( - src.startswith(("http://", "https://", "git@", "ssh://", "git+")) - or "://" in src - or (src.count("/") >= 1 and not src.startswith((".", "/", "~"))) - ) + """Return True if *parent_pkg* is a REMOTE package; see helper.""" + return _is_remote_parent(parent_pkg) @staticmethod def _compute_dep_source_path( @@ -955,74 +759,19 @@ def _compute_dep_source_path( parent_pkg: APMPackage | None, install_path: Path, ) -> Path: - """Return the source-path anchor for a dependency. - - For LOCAL deps we return the *original* user source directory so that - transitive ``../sibling`` references inside its apm.yml resolve as a - developer reading the file expects (#857). For REMOTE deps we return - the clone location under apm_modules. - """ - if dep_ref.is_local and dep_ref.local_path: - local = Path(dep_ref.local_path).expanduser() - if not local.is_absolute() and parent_pkg is not None and parent_pkg.source_path: - return (parent_pkg.source_path / local).resolve() - return local.resolve() - return install_path.resolve() + """Return the source-path anchor for a dependency; see helper.""" + return _compute_dep_source_path(dep_ref, parent_pkg, install_path) @staticmethod def _download_dedup_key(dep_ref: DependencyReference, parent_pkg: APMPackage | None) -> str: - """Dedup key for the download cache. - - Includes the parent's source_path so two parents anchoring the same - local dep at different absolute locations don't collide on the first - one's resolved path. For non-local deps, the parent anchor doesn't - affect resolution, so the bare unique key suffices. 
- """ - base = dep_ref.get_unique_key() - if dep_ref.is_local and parent_pkg is not None and parent_pkg.source_path: - return f"{base}@{parent_pkg.source_path}" - return base + """Dedup key for the download cache; see helper.""" + return _download_dedup_key(dep_ref, parent_pkg) @staticmethod def _effective_base_dir(parent_pkg: APMPackage | None, project_root: Path) -> Path: - """Return the directory used to anchor relative ``local_path`` deps. - - For direct (root-declared) deps, this is the project root. For - transitive deps, it is the declaring package's source_path so a - ``../sibling`` written inside the original package directory means - what the author meant (#857). - """ - if parent_pkg is not None and parent_pkg.source_path is not None: - return parent_pkg.source_path - return project_root + """Return the directory used to anchor relative ``local_path`` deps; see helper.""" + return _effective_base_dir(parent_pkg, project_root) def _create_resolution_summary(self, graph: DependencyGraph) -> str: - """ - Create a human-readable summary of the resolution results. 
- - Args: - graph: The resolved dependency graph - - Returns: - str: Summary string - """ - summary = graph.get_summary() - lines = [ - "Dependency Resolution Summary:", - f" Root package: {summary['root_package']}", - f" Total dependencies: {summary['total_dependencies']}", - f" Maximum depth: {summary['max_depth']}", - ] - - if summary["has_conflicts"]: - lines.append(f" Conflicts detected: {summary['conflict_count']}") - - if summary["has_circular_dependencies"]: - lines.append(f" Circular dependencies: {summary['circular_count']}") - - if summary["has_errors"]: - lines.append(f" Resolution errors: {summary['error_count']}") - - lines.append(f" Status: {'[+] Valid' if summary['is_valid'] else '[x] Invalid'}") - - return "\n".join(lines) + """Create a human-readable summary of the resolution results; see helper.""" + return _create_resolution_summary(graph) diff --git a/src/apm_cli/deps/apm_resolver_helpers.py b/src/apm_cli/deps/apm_resolver_helpers.py new file mode 100644 index 000000000..92232e699 --- /dev/null +++ b/src/apm_cli/deps/apm_resolver_helpers.py @@ -0,0 +1,374 @@ +"""Static helper functions extracted from :class:`APMDependencyResolver`. + +Moved to this sibling module to keep :mod:`apm_resolver` under the +file-length guardrail. All functions are pure (no I/O, no class state) +and are re-exported from :mod:`apm_resolver` via thin +``@staticmethod`` / instance-method stubs so existing callers -- +including ``APMDependencyResolver._resolve_max_parallel(7)``-style +test assertions -- are unchanged. + +Rule A: every public name here that was previously accessible as +``apm_cli.deps.apm_resolver.`` is re-exported (redundant-alias +form) from :mod:`apm_resolver` to preserve patch targets. 
+""" + +from __future__ import annotations + +import inspect +import os +from pathlib import Path + +from ..models.apm_package import APMPackage, DependencyReference +from .dependency_graph import ( + CircularRef, + DependencyGraph, + DependencyNode, + DependencyTree, + FlatDependencyMap, +) + +# Default worker pool size for the level-batched BFS download phase; +# re-exported by apm_resolver so existing import paths keep resolving. +_DEFAULT_RESOLVE_PARALLEL = 4 + + +# --------------------------------------------------------------------------- +# Parallel-worker helpers +# --------------------------------------------------------------------------- + + +def _resolve_max_parallel(explicit: int | None) -> int: + """Compute effective worker count for level-batched parallel BFS. + + Parallel is the default and central execution model. The override + exists for parity testing (``APM_RESOLVE_PARALLEL=1``) and CI + diagnostics, not as a user-facing knob. + + Order of precedence: + 1. Explicit ``max_parallel`` ctor arg. + 2. ``APM_RESOLVE_PARALLEL`` env var (diagnostic/parity knob). + 3. ``_DEFAULT_RESOLVE_PARALLEL``. + + Always coerced to ``>= 1`` so the executor never gets a zero or + negative ``max_workers``. + """ + import logging + + if explicit is not None: + return max(1, int(explicit)) + env = os.environ.get("APM_RESOLVE_PARALLEL", "").strip() + if env: + try: + return max(1, int(env)) + except ValueError: + logging.getLogger(__name__).debug("Ignoring invalid APM_RESOLVE_PARALLEL=%r", env) + return _DEFAULT_RESOLVE_PARALLEL + + +def _signature_accepts_parent_pkg(callback) -> bool: + """Return True if ``callback`` declares a ``parent_pkg`` parameter + (or accepts ``**kwargs``). + + Falls back to False if the signature can't be introspected (e.g. C + extensions, builtins).
The conservative fallback is correct: if we + don't know the callback's shape, assume the legacy 3-arg form so + the resolver won't pass an extra positional/keyword that triggers + TypeError and silently drops the dependency (#940 SR1). + """ + try: + sig = inspect.signature(callback) + except (TypeError, ValueError): + return False + for param in sig.parameters.values(): + if param.kind is inspect.Parameter.VAR_KEYWORD: + return True + if param.name == "parent_pkg": + return True + return False + + +# --------------------------------------------------------------------------- +# Dependency-reference guards (no I/O) +# --------------------------------------------------------------------------- + + +def _remote_parent_eligible(parent_dep: DependencyReference) -> bool: + """Return True if *parent_dep* can serve as the Git repo for ``git: parent`` expansion.""" + if parent_dep.is_azure_devops(): + return bool(parent_dep.ado_repo and parent_dep.repo_url.count("/") >= 2) + return "/" in parent_dep.repo_url + + +def _expand_parent_repo_decl( + parent_dep: DependencyReference, + child_dep: DependencyReference, +) -> DependencyReference: + """Expand ``{ git: parent, path: ... }`` using the declaring package's coordinates. + + The child keeps its ``virtual_path`` (monorepo subdirectory), ``alias``, and + optional ``ref`` override; repository identity (host, ``repo_url``, ADO + fields, etc.) is inherited from *parent_dep*. 
+ """ + from dataclasses import replace + + if not child_dep.is_parent_repo_inheritance: + raise ValueError("expand_parent_repo_decl requires child_dep.is_parent_repo_inheritance") + if parent_dep.is_local: + raise ValueError("git: parent cannot inherit from a local path dependency") + if parent_dep.repo_url.startswith("_local/"): + raise ValueError("git: parent cannot inherit from a local path dependency") + if not _remote_parent_eligible(parent_dep): + raise ValueError("git: parent requires a remote Git parent package dependency") + + merged_ref = child_dep.reference if child_dep.reference is not None else parent_dep.reference + + return replace( + child_dep, + repo_url=parent_dep.repo_url, + host=parent_dep.host, + port=parent_dep.port, + explicit_scheme=parent_dep.explicit_scheme, + ado_organization=parent_dep.ado_organization, + ado_project=parent_dep.ado_project, + ado_repo=parent_dep.ado_repo, + artifactory_prefix=parent_dep.artifactory_prefix, + is_insecure=parent_dep.is_insecure, + allow_insecure=parent_dep.allow_insecure, + reference=merged_ref, + is_virtual=True, + is_parent_repo_inheritance=False, + is_local=False, + local_path=None, + ) + + +# --------------------------------------------------------------------------- +# Tree algorithms (pure graph operations -- no package loading) +# --------------------------------------------------------------------------- + + +def _detect_circular_deps(tree: DependencyTree) -> list[CircularRef]: + """Detect and report circular dependency chains. + + Uses depth-first search to detect cycles in the dependency graph. + A cycle is detected when we encounter the same dependency unique key + (which includes the subdirectory path) in our current traversal path. + + Args: + tree: The dependency tree to analyse. + + Returns: + List[CircularRef]: List of detected circular dependencies.
+ """ + circular_deps: list[CircularRef] = [] + visited: set[str] = set() + current_path: list[str] = [] + current_path_set: set[str] = set() # O(1) membership test (#171) + + def dfs_detect_cycles(node: DependencyNode) -> None: + """Recursive DFS function to detect cycles.""" + node_id = node.get_id() + # Use unique key (includes subdirectory path) to distinguish monorepo packages + # e.g., vineethsoma/agent-packages/agents/X vs vineethsoma/agent-packages/skills/Y + unique_key = node.dependency_ref.get_unique_key() + + # Check if this unique key is already in our current path (cycle detected) + if unique_key in current_path_set: + # Found a cycle - create the cycle path + cycle_start_index = current_path.index(unique_key) + cycle_path = current_path[cycle_start_index:] + [unique_key] # noqa: RUF005 + + circular_ref = CircularRef(cycle_path=cycle_path, detected_at_depth=node.depth) + circular_deps.append(circular_ref) + return + + # Mark current node as visited and add unique key to path + visited.add(node_id) + current_path.append(unique_key) + current_path_set.add(unique_key) + + # Check all children + for child in node.children: + child_id = child.get_id() + + # Only recurse if we haven't processed this subtree completely + if child_id not in visited or child.dependency_ref.get_unique_key() in current_path_set: + dfs_detect_cycles(child) + + # Remove from path when backtracking (but keep in visited) + current_path_set.discard(current_path.pop()) + + # Start DFS from all root level dependencies (depth 1) + root_deps = tree.get_nodes_at_depth(1) + for root_dep in root_deps: + if root_dep.get_id() not in visited: + current_path.clear() + current_path_set.clear() + dfs_detect_cycles(root_dep) + + return circular_deps + + +def _flatten_dependencies(tree: DependencyTree) -> FlatDependencyMap: + """Flatten tree to avoid duplicate installations (NPM hoisting). 
+ + Implements "first wins" conflict resolution strategy where the first + declared dependency takes precedence over later conflicting dependencies. + + Args: + tree: The dependency tree to flatten. + + Returns: + FlatDependencyMap: Flattened dependencies ready for installation. + """ + flat_map = FlatDependencyMap() + seen_keys: set[str] = set() + + # Process dependencies level by level (breadth-first) + # This ensures that dependencies declared earlier in the tree get priority + for depth in range(1, tree.max_depth + 1): + nodes_at_depth = tree.get_nodes_at_depth(depth) + + # Sort nodes by their position in the tree to ensure deterministic ordering + nodes_at_depth.sort(key=lambda node: node.get_id()) + + for node in nodes_at_depth: + unique_key = node.dependency_ref.get_unique_key() + + if unique_key not in seen_keys: + # First occurrence - add without conflict + flat_map.add_dependency(node.dependency_ref, is_conflict=False) + seen_keys.add(unique_key) + else: + # Conflict - record it but keep the first one + flat_map.add_dependency(node.dependency_ref, is_conflict=True) + + return flat_map + + +# --------------------------------------------------------------------------- +# Package-loading utilities +# --------------------------------------------------------------------------- + + +def _is_remote_parent(parent_pkg: APMPackage | None) -> bool: + """Return True if *parent_pkg* is a REMOTE package (i.e. fetched via + git URL or pinned by ref/path). + + Used to gate ``local_path`` deps: only the root project and other + local packages may legitimately declare them. Remote packages + declaring a local_path is a path-confusion vector. + + SECURITY NOTE: this is a heuristic on the ``source`` field. A + sufficiently adversarial remote could spoof a local-looking source. + The downstream containment check via ``ensure_path_within`` is the + actual security boundary; this gate just produces the user-facing + error early. 
+ """ + if parent_pkg is None or not parent_pkg.source: + return False + src = str(parent_pkg.source) + # Local deps get ``source = "_local/"`` (see DependencyReference + # construction for is_local=True). Treat that prefix as definitively + # local even though it contains a slash. + if src.startswith("_local/"): + return False + # Remote sources look like URLs or owner/repo refs. Local sources + # are filesystem paths the user typed in their apm.yml. + return ( + src.startswith(("http://", "https://", "git@", "ssh://", "git+")) + or "://" in src + or (src.count("/") >= 1 and not src.startswith((".", "/", "~"))) + ) + + +def _compute_dep_source_path( + dep_ref: DependencyReference, + parent_pkg: APMPackage | None, + install_path: Path, +) -> Path: + """Return the source-path anchor for a dependency. + + For LOCAL deps we return the *original* user source directory so that + transitive ``../sibling`` references inside its apm.yml resolve as a + developer reading the file expects (#857). For REMOTE deps we return + the clone location under apm_modules. + """ + if dep_ref.is_local and dep_ref.local_path: + local = Path(dep_ref.local_path).expanduser() + if not local.is_absolute() and parent_pkg is not None and parent_pkg.source_path: + return (parent_pkg.source_path / local).resolve() + return local.resolve() + return install_path.resolve() + + +def _download_dedup_key(dep_ref: DependencyReference, parent_pkg: APMPackage | None) -> str: + """Dedup key for the download cache. + + Includes the parent's source_path so two parents anchoring the same + local dep at different absolute locations don't collide on the first + one's resolved path. For non-local deps, the parent anchor doesn't + affect resolution, so the bare unique key suffices. 
+ """ + base = dep_ref.get_unique_key() + if dep_ref.is_local and parent_pkg is not None and parent_pkg.source_path: + return f"{base}@{parent_pkg.source_path}" + return base + + +def _effective_base_dir(parent_pkg: APMPackage | None, project_root: Path) -> Path: + """Return the directory used to anchor relative ``local_path`` deps. + + For direct (root-declared) deps, this is the project root. For + transitive deps, it is the declaring package's source_path so a + ``../sibling`` written inside the original package directory means + what the author meant (#857). + """ + if parent_pkg is not None and parent_pkg.source_path is not None: + return parent_pkg.source_path + return project_root + + +# --------------------------------------------------------------------------- +# Summary formatting +# --------------------------------------------------------------------------- + + +def _create_resolution_summary(graph: DependencyGraph) -> str: + """Create a human-readable summary of the resolution results. + + Args: + graph: The resolved dependency graph. + + Returns: + str: Summary string. 
+ """ + summary = graph.get_summary() + lines = [ + "Dependency Resolution Summary:", + f" Root package: {summary['root_package']}", + f" Total dependencies: {summary['total_dependencies']}", + f" Maximum depth: {summary['max_depth']}", + ] + + if summary["has_conflicts"]: + lines.append(f" Conflicts detected: {summary['conflict_count']}") + + if summary["has_circular_dependencies"]: + lines.append(f" Circular dependencies: {summary['circular_count']}") + + if summary["has_errors"]: + lines.append(f" Resolution errors: {summary['error_count']}") + + lines.append(f" Status: {'[+] Valid' if summary['is_valid'] else '[x] Invalid'}") + + return "\n".join(lines) + + +def _validate_dependency_reference(dep_ref: DependencyReference) -> bool: + """Validate that *dep_ref* is well-formed (non-empty repo_url with a slash).""" + if not dep_ref.repo_url: + return False + if "/" not in dep_ref.repo_url: # noqa: SIM103 + return False + return True diff --git a/src/apm_cli/deps/bare_cache.py b/src/apm_cli/deps/bare_cache.py index f13347ec0..b4af14662 100644 --- a/src/apm_cli/deps/bare_cache.py +++ b/src/apm_cli/deps/bare_cache.py @@ -35,6 +35,11 @@ from ..utils.git_sparse import apply_sparse_cone +# Rule A re-export: implementation lives in bare_cache_msg to keep this +# module under the file-length guardrail; names remain resolvable here. 
+from .bare_cache_msg import CloneFailureContext as CloneFailureContext +from .bare_cache_msg import build_clone_failure_message as build_clone_failure_message + if TYPE_CHECKING: from ..models.apm_package import DependencyReference @@ -717,89 +722,3 @@ def _wt_action(url: str, env: dict[str, str], target: Path) -> None: verbose_callback=verbose_callback, ) return repo_holder[0] - - -def build_clone_failure_message( - *, - repo_url_base: str, - plan: Any, - dep_ref: DependencyReference | None, - dep_host: str | None, - is_ado: bool, - is_generic: bool, - has_ado_token: bool, - has_token: bool, - auth_resolver: Any, - configured_github_host: str, - default_host_fn: Callable[[], str], - last_error: Exception | None, - sanitize_git_error: Callable[[str], str], -) -> str: - """Build the aggregate ``RuntimeError`` message for a failed transport plan. - - Extracted from :meth:`GitHubPackageDownloader._execute_transport_plan` - to keep that module under the file-length guardrail. Pure formatting: - no I/O, no clone attempts. - """ - if plan.strict and len(plan.attempts) >= 1: - tried = plan.attempts[0].label - error_msg = f"Failed to clone repository {repo_url_base} via {tried}. " - if plan.fallback_hint: - error_msg += plan.fallback_hint + " " - else: - error_msg = f"Failed to clone repository {repo_url_base} using all available methods. " - if is_ado and not has_ado_token: - host = dep_host or "dev.azure.com" - error_msg += auth_resolver.build_error_context( - host, - "clone", - org=dep_ref.ado_organization if dep_ref else None, - port=dep_ref.port if dep_ref else None, - dep_url=dep_ref.repo_url if dep_ref else None, - ) - elif is_generic: - if dep_host: - host_info = auth_resolver.classify_host( - dep_host, - port=dep_ref.port if dep_ref else None, - ) - host_name = host_info.display_name - else: - host_name = "the target host" - error_msg += ( - f"For private repositories on {host_name}, configure SSH keys or a git credential helper. 
" - f"APM delegates authentication to git for non-GitHub/ADO hosts." - ) - elif ( - configured_github_host - and dep_host - and dep_host == configured_github_host - and configured_github_host != "github.com" - ): - suggested = f"github.com/{repo_url_base}" - if dep_ref and dep_ref.virtual_path: - suggested += f"/{dep_ref.virtual_path}" - error_msg += ( - f"GITHUB_HOST is set to '{configured_github_host}', so shorthand dependencies " - f"(without a hostname) resolve against that host. " - f"If this package lives on a different server (e.g., github.com), " - f"use the full hostname in apm.yml: {suggested}" - ) - elif not has_token: - host = dep_host or default_host_fn() - org = dep_ref.repo_url.split("/")[0] if dep_ref and dep_ref.repo_url else None - error_msg += auth_resolver.build_error_context( - host, - "clone", - org=org, - port=dep_ref.port if dep_ref else None, - dep_url=dep_ref.repo_url if dep_ref else None, - ) - else: - error_msg += "Please check repository access permissions and authentication setup." - - if last_error: - sanitized_error = sanitize_git_error(str(last_error)) - error_msg += f" Last error: {sanitized_error}" - - return error_msg diff --git a/src/apm_cli/deps/bare_cache_msg.py b/src/apm_cli/deps/bare_cache_msg.py new file mode 100644 index 000000000..efd02dfc4 --- /dev/null +++ b/src/apm_cli/deps/bare_cache_msg.py @@ -0,0 +1,125 @@ +"""Clone-failure message builder for the WS2 dedup pipeline. + +Extracted from :mod:`bare_cache` to keep that module under the +file-length guardrail. Re-exported from ``bare_cache`` so callers +see no change. + +Public names: +* :class:`CloneFailureContext` -- frozen dataclass bundling the six + classifier flags / host-info params. +* :func:`build_clone_failure_message` -- aggregate error-message + formatter. 
+""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from ..models.apm_package import DependencyReference + + +@dataclass(frozen=True) +class CloneFailureContext: + """Classifier flags and host info for :func:`build_clone_failure_message`. + + Bundles the six parameters that describe the *kind* of failure so + the caller constructs the context once and passes a single cohesive + argument instead of six separate keyword arguments. + """ + + is_ado: bool + is_generic: bool + has_ado_token: bool + has_token: bool + dep_host: str | None + configured_github_host: str + + +def build_clone_failure_message( + *, + repo_url_base: str, + plan: Any, + dep_ref: DependencyReference | None, + auth_resolver: Any, + default_host_fn: Callable[[], str], + last_error: Exception | None, + sanitize_git_error: Callable[[str], str], + clone_ctx: CloneFailureContext, +) -> str: + """Build the aggregate ``RuntimeError`` message for a failed transport plan. + + Extracted from :meth:`GitHubPackageDownloader._execute_transport_plan` + to keep that module under the file-length guardrail. Pure formatting: + no I/O, no clone attempts. + + ``clone_ctx`` bundles the six failure-classifier flags (``is_ado``, + ``is_generic``, ``has_ado_token``, ``has_token``, ``dep_host``, + ``configured_github_host``); callers construct it before the call. + """ + if plan.strict and len(plan.attempts) >= 1: + tried = plan.attempts[0].label + error_msg = f"Failed to clone repository {repo_url_base} via {tried}. " + if plan.fallback_hint: + error_msg += plan.fallback_hint + " " + else: + error_msg = f"Failed to clone repository {repo_url_base} using all available methods. 
" + + if clone_ctx.is_ado and not clone_ctx.has_ado_token: + host = clone_ctx.dep_host or "dev.azure.com" + error_msg += auth_resolver.build_error_context( + host, + "clone", + org=dep_ref.ado_organization if dep_ref else None, + port=dep_ref.port if dep_ref else None, + dep_url=dep_ref.repo_url if dep_ref else None, + ) + elif clone_ctx.is_generic: + if clone_ctx.dep_host: + host_info = auth_resolver.classify_host( + clone_ctx.dep_host, + port=dep_ref.port if dep_ref else None, + ) + host_name = host_info.display_name + else: + host_name = "the target host" + error_msg += ( + f"For private repositories on {host_name}, configure SSH keys " + f"or a git credential helper. " + f"APM delegates authentication to git for non-GitHub/ADO hosts." + ) + elif ( + clone_ctx.configured_github_host + and clone_ctx.dep_host + and clone_ctx.dep_host == clone_ctx.configured_github_host + and clone_ctx.configured_github_host != "github.com" + ): + suggested = f"github.com/{repo_url_base}" + if dep_ref and dep_ref.virtual_path: + suggested += f"/{dep_ref.virtual_path}" + error_msg += ( + f"GITHUB_HOST is set to '{clone_ctx.configured_github_host}', " + f"so shorthand dependencies (without a hostname) resolve against that host. " + f"If this package lives on a different server (e.g., github.com), " + f"use the full hostname in apm.yml: {suggested}" + ) + elif not clone_ctx.has_token: + host = clone_ctx.dep_host or default_host_fn() + org = dep_ref.repo_url.split("/")[0] if dep_ref and dep_ref.repo_url else None + error_msg += auth_resolver.build_error_context( + host, + "clone", + org=org, + port=dep_ref.port if dep_ref else None, + dep_url=dep_ref.repo_url if dep_ref else None, + ) + else: + error_msg += "Please check repository access permissions and authentication setup." 
+ + if last_error: + sanitized_error = sanitize_git_error(str(last_error)) + error_msg += f" Last error: {sanitized_error}" + + return error_msg diff --git a/src/apm_cli/deps/clone_engine.py b/src/apm_cli/deps/clone_engine.py index b355a34d4..678574d44 100644 --- a/src/apm_cli/deps/clone_engine.py +++ b/src/apm_cli/deps/clone_engine.py @@ -35,7 +35,7 @@ is_ado_auth_failure_signal, is_github_hostname, ) -from .bare_cache import build_clone_failure_message +from .bare_cache import CloneFailureContext, build_clone_failure_message from .transport_selection import TransportAttempt, TransportPlan if TYPE_CHECKING: @@ -327,16 +327,18 @@ def _env_for(attempt: TransportAttempt) -> dict[str, str]: repo_url_base=repo_url_base, plan=plan, dep_ref=dep_ref, - dep_host=dep_host, - is_ado=bool(is_ado), - is_generic=is_generic, - has_ado_token=host.has_ado_token, - has_token=has_token, auth_resolver=host.auth_resolver, - configured_github_host=os.environ.get("GITHUB_HOST", ""), default_host_fn=default_host, last_error=last_error, sanitize_git_error=host._sanitize_git_error, + clone_ctx=CloneFailureContext( + is_ado=bool(is_ado), + is_generic=is_generic, + has_ado_token=host.has_ado_token, + has_token=has_token, + dep_host=dep_host, + configured_github_host=os.environ.get("GITHUB_HOST", ""), + ), ) raise RuntimeError(error_msg) diff --git a/src/apm_cli/deps/git_reference_resolver.py b/src/apm_cli/deps/git_reference_resolver.py index 75acaef3d..bc86830db 100644 --- a/src/apm_cli/deps/git_reference_resolver.py +++ b/src/apm_cli/deps/git_reference_resolver.py @@ -236,9 +236,10 @@ def resolve_commit_sha_for_ref(self, dep_ref: DependencyReference, ref: str) -> host = self._host try: - if dep_ref.is_artifactory() or dep_ref.is_azure_devops(): - return None + is_unsupported = dep_ref.is_artifactory() or dep_ref.is_azure_devops() except Exception: + is_unsupported = True + if is_unsupported: return None target_host = dep_ref.host or default_host() @@ -272,14 +273,22 @@ def 
resolve_commit_sha_for_ref(self, dep_ref: DependencyReference, ref: str) -> if token: headers["Authorization"] = f"token {token}" + return self._call_commits_api(api_url, headers, host) + + def _call_commits_api( + self, api_url: str, headers: dict[str, str], host: _DownloaderContext + ) -> str | None: + """Execute the commits API call and return the validated 40-char SHA or None. + + Extracted from :meth:`resolve_commit_sha_for_ref` to keep that method's + return-statement count within the PLR0911 limit. + """ try: response = host._resilient_get(api_url, headers=headers, timeout=10) if response.status_code != 200: return None body = (response.text or "").strip() - if re.match(r"^[a-f0-9]{40}$", body.lower()): - return body.lower() - return None + return body.lower() if re.match(r"^[a-f0-9]{40}$", body.lower()) else None except Exception: return None diff --git a/src/apm_cli/deps/lockfile.py b/src/apm_cli/deps/lockfile.py index 5de15840c..27b384d6f 100644 --- a/src/apm_cli/deps/lockfile.py +++ b/src/apm_cli/deps/lockfile.py @@ -682,13 +682,13 @@ def is_semantically_equivalent(self, other: LockFile) -> bool: return False if self.lsp_configs != other.lsp_configs: return False - if sorted(self.local_deployed_files) != sorted(other.local_deployed_files): - return False # Issue #887: include hash dict in equivalence so post-install # hash updates persist even when the file list is unchanged. - if dict(self.local_deployed_file_hashes) != dict(other.local_deployed_file_hashes): # noqa: SIM103 - return False - return True + # Combine the last two comparisons into a single return to keep + # PLR0911 (too many return statements) satisfied. 
+ return sorted(self.local_deployed_files) == sorted(other.local_deployed_files) and dict( + self.local_deployed_file_hashes + ) == dict(other.local_deployed_file_hashes) @classmethod def installed_paths_for_project(cls, project_root: Path) -> list[str]: diff --git a/src/apm_cli/deps/plugin_parser.py b/src/apm_cli/deps/plugin_parser.py index 8e34d8b8f..6ba55ff72 100644 --- a/src/apm_cli/deps/plugin_parser.py +++ b/src/apm_cli/deps/plugin_parser.py @@ -14,7 +14,6 @@ import json import logging import os -import shutil from pathlib import Path from typing import Any @@ -23,6 +22,35 @@ from ..utils.console import _rich_warning from ..utils.path_security import PathTraversalError, ensure_path_within +# Rule A re-export: implementations in plugin_server_helpers; names stay resolvable here. +from .plugin_server_helpers import ( + _extract_lsp_servers as _extract_lsp_servers, +) +from .plugin_server_helpers import ( + _extract_mcp_servers as _extract_mcp_servers, +) +from .plugin_server_helpers import ( + _lsp_servers_to_apm_deps as _lsp_servers_to_apm_deps, +) +from .plugin_server_helpers import ( + _mcp_servers_to_apm_deps as _mcp_servers_to_apm_deps, +) +from .plugin_server_helpers import ( + _read_lsp_file as _read_lsp_file, +) +from .plugin_server_helpers import ( + _read_lsp_json as _read_lsp_json, +) +from .plugin_server_helpers import ( + _read_mcp_file as _read_mcp_file, +) +from .plugin_server_helpers import ( + _read_mcp_json as _read_mcp_json, +) +from .plugin_server_helpers import ( + _substitute_plugin_root as _substitute_plugin_root, +) + _logger = logging.getLogger(__name__) @@ -214,519 +242,161 @@ def synthesize_apm_yml_from_plugin(plugin_path: Path, manifest: dict[str, Any]) return apm_yml_path -def _extract_mcp_servers(plugin_path: Path, manifest: dict[str, Any]) -> dict[str, Any]: - """Extract MCP server definitions from a plugin manifest. 
- - Resolves ``mcpServers`` by type (per Claude Code spec): - - ``str`` -> read that file path relative to plugin root, parse JSON, - extract ``mcpServers`` key. - - ``list`` -> read each file path, merge (last-wins on name conflict). - - ``dict`` -> use directly as inline server definitions. - - When ``mcpServers`` is absent and ``.mcp.json`` (or ``.github/.mcp.json``) - exists at plugin root, read it as the default (matches Claude Code - auto-discovery). +# --------------------------------------------------------------------------- +# _map_plugin_artifacts sub-helpers (module-level so they can be tested +# independently and to keep _map_plugin_artifacts below C901=35). +# --------------------------------------------------------------------------- - Security: symlinks are skipped, JSON parse errors are logged as warnings. - ``${CLAUDE_PLUGIN_ROOT}`` in string values is replaced with the absolute - plugin path. - - Args: - plugin_path: Root of the plugin directory. - manifest: Parsed plugin.json dict. +def _resolve_plugin_sources( + plugin_path: Path, manifest: dict[str, Any], component: str, default_dir: str +) -> list[Path]: + """Return list of existing source paths (dirs or files) for *component*. - Returns: - dict mapping server name -> server config. Empty on failure. + Uses ``manifest[component]`` when present (list or str), else falls + back to the ``default_dir`` directory inside *plugin_path*. Every + path is verified to exist, not be a symlink, and resolve inside + *plugin_path* (path-traversal guard). 
""" - logger = logging.getLogger("apm") - mcp_value = manifest.get("mcpServers") - - if mcp_value is not None: - # Manifest explicitly defines mcpServers - if isinstance(mcp_value, dict): - servers = dict(mcp_value) - elif isinstance(mcp_value, str): - servers = _read_mcp_file(plugin_path, mcp_value, logger) - elif isinstance(mcp_value, list): - servers = {} - for entry in mcp_value: - if isinstance(entry, str): - servers.update(_read_mcp_file(plugin_path, entry, logger)) - else: - logger.warning("Ignoring non-string entry in mcpServers array: %s", entry) - else: - logger.warning("Unsupported mcpServers type %s; ignoring", type(mcp_value).__name__) - return {} - else: - # Fall back to auto-discovery: .mcp.json then .github/.mcp.json - servers = {} - for fallback in (".mcp.json", ".github/.mcp.json"): - candidate = plugin_path / fallback - if candidate.exists() and candidate.is_file() and not candidate.is_symlink(): - servers = _read_mcp_json(candidate, logger) - if servers: - break - - # Substitute ${CLAUDE_PLUGIN_ROOT} in all string values - if servers: - abs_root = str(plugin_path.resolve()) - servers = _substitute_plugin_root(servers, abs_root, logger) - - return servers - - -def _read_mcp_file(plugin_path: Path, rel_path: str, logger: logging.Logger) -> dict[str, Any]: - """Read a JSON file relative to *plugin_path* and return its ``mcpServers`` dict.""" - target = (plugin_path / rel_path).resolve() - # Security: must stay inside plugin_path and not be a symlink - try: - target.relative_to(plugin_path.resolve()) - except ValueError: - logger.warning("MCP file path escapes plugin root: %s", rel_path) - return {} - candidate = plugin_path / rel_path - if not candidate.exists() or not candidate.is_file(): - logger.warning("MCP file not found: %s", candidate) - return {} - if candidate.is_symlink(): - logger.warning("Skipping symlinked MCP file: %s", candidate) - return {} - return _read_mcp_json(candidate, logger) - - -def _read_mcp_json(path: Path, logger: 
logging.Logger) -> dict[str, Any]: - """Parse a JSON file and return the ``mcpServers`` mapping.""" - try: - data = json.loads(path.read_text(encoding="utf-8")) - except (json.JSONDecodeError, OSError) as exc: - logger.warning("Failed to read MCP config %s: %s", path, exc) - return {} - if not isinstance(data, dict): - return {} - servers = data.get("mcpServers", {}) - return dict(servers) if isinstance(servers, dict) else {} - - -def _substitute_plugin_root( - servers: dict[str, Any], abs_root: str, logger: logging.Logger -) -> dict[str, Any]: - """Replace ``${CLAUDE_PLUGIN_ROOT}`` in server config string values.""" - placeholder = "${CLAUDE_PLUGIN_ROOT}" - substituted = False - - def _walk(obj: Any) -> Any: - nonlocal substituted - if isinstance(obj, str) and placeholder in obj: - substituted = True - return obj.replace(placeholder, abs_root) - if isinstance(obj, dict): - return {k: _walk(v) for k, v in obj.items()} - if isinstance(obj, list): - return [_walk(item) for item in obj] - return obj - - result = {name: _walk(cfg) for name, cfg in servers.items()} - if substituted: - logger.info("Substituted ${CLAUDE_PLUGIN_ROOT} with %s", abs_root) - return result - - -def _mcp_servers_to_apm_deps(servers: dict[str, Any], plugin_path: Path) -> list[dict[str, Any]]: - """Convert raw MCP server configs to ``dependencies.mcp`` dicts. - - Transport inference: - - ``command`` present -> stdio - - ``url`` present -> http (or ``type`` if it's a valid transport) - - Neither -> skipped with warning - - Every entry gets ``registry: false`` (self-defined, not registry lookups). - - All resulting entries are routed through ``MCPDependency.from_dict()`` - so plugin-synthesized servers must clear the same security validation - chokepoint as CLI-authored or manually edited entries (name shape, URL - scheme allowlist, header CRLF, command path-traversal). 
Entries that - fail validation are skipped with a warning rather than crashing the - plugin install -- a single malformed server should not block the - whole plugin. - - Args: - servers: Mapping of server name -> server config dict. - plugin_path: Plugin root (used for log context only). - - Returns: - List of dicts consumable by ``MCPDependency.from_dict()``. - """ - from ..models.dependency.mcp import MCPDependency - - logger = logging.getLogger("apm") - deps: list[dict[str, Any]] = [] - - for name, cfg in servers.items(): - if not isinstance(cfg, dict): - logger.warning("Skipping non-dict MCP server config '%s'", name) - continue - - dep: dict[str, Any] = {"name": name, "registry": False} - - if "command" in cfg: - dep["transport"] = "stdio" - dep["command"] = cfg["command"] - if "args" in cfg: - dep["args"] = cfg["args"] - elif "url" in cfg: - raw_type = cfg.get("type", "http") - valid_transports = {"http", "sse", "streamable-http"} - dep["transport"] = raw_type if raw_type in valid_transports else "http" - dep["url"] = cfg["url"] - if "headers" in cfg: - dep["headers"] = cfg["headers"] - else: - _surface_warning( - f"Skipping MCP server '{name}' from plugin " - f"'{plugin_path.name}': no 'command' or 'url'", - logger, - ) - continue - - if "env" in cfg: - dep["env"] = cfg["env"] - if "tools" in cfg: - dep["tools"] = cfg["tools"] - - # Route through the validation chokepoint. Plugins are an ingress - # path: a malicious plugin could otherwise smuggle path traversal, - # CRLF, or unsafe URL schemes that bypass MCPDependency.validate(). - # PR #809 follow-up: surface validation errors to the user via the - # rich console (stdlib logger has no handlers configured). 
- try: - MCPDependency.from_dict(dep) - except (ValueError, Exception) as exc: - _surface_warning( - f"Skipping invalid MCP server '{name}' from plugin '{plugin_path.name}': {exc}", - logger, - ) - continue - - deps.append(dep) - - return deps - - -def _extract_lsp_servers(plugin_path: Path, manifest: dict[str, Any]) -> dict[str, Any]: - """Extract LSP server definitions from a plugin manifest. - - Resolves ``lspServers`` by type (per Claude Code spec): - - ``str`` -> read that file path relative to plugin root, parse JSON. - - ``dict`` -> use directly as inline server definitions. - - When ``lspServers`` is absent and ``.lsp.json`` exists at plugin root, - read it as the default (matches Claude Code auto-discovery). - - Security: symlinks are skipped, JSON parse errors are logged as warnings. - - ``${CLAUDE_PLUGIN_ROOT}`` in string values is replaced with the absolute - plugin path. - - Args: - plugin_path: Root of the plugin directory. - manifest: Parsed plugin.json dict. - - Returns: - dict mapping server name -> server config. Empty on failure. 
- """ - logger = logging.getLogger("apm") - lsp_value = manifest.get("lspServers") - - if lsp_value is not None: - if isinstance(lsp_value, dict): - servers = dict(lsp_value) - elif isinstance(lsp_value, str): - servers = _read_lsp_file(plugin_path, lsp_value, logger) - else: - logger.warning("Unsupported lspServers type %s; ignoring", type(lsp_value).__name__) - return {} - else: - # Fall back to auto-discovery: .lsp.json - servers = {} - candidate = plugin_path / ".lsp.json" - if candidate.exists() and candidate.is_file() and not candidate.is_symlink(): - servers = _read_lsp_json(candidate, logger) - - # Substitute ${CLAUDE_PLUGIN_ROOT} in all string values - if servers: - abs_root = str(plugin_path.resolve()) - servers = _substitute_plugin_root(servers, abs_root, logger) + custom = manifest.get(component) + if isinstance(custom, list): + paths = [] + for p in custom: + src = plugin_path / str(p) + if ( + src.exists() + and not src.is_symlink() + and _is_within_plugin(src, plugin_path, component=component) + ): + paths.append(src) + return paths + if isinstance(custom, str): + src = plugin_path / custom + if ( + src.exists() + and not src.is_symlink() + and _is_within_plugin(src, plugin_path, component=component) + ): + return [src] + return [] + default = plugin_path / default_dir + if ( + default.exists() + and not default.is_symlink() + and default.is_dir() + and _is_within_plugin(default, plugin_path, component=component) + ): + return [default] + return [] - return servers +def _is_same_path(src: Path, dst: Path) -> bool: + """Return True when *src* and *dst* resolve to the same filesystem path. 
-def _read_lsp_file(plugin_path: Path, rel_path: str, logger: logging.Logger) -> dict[str, Any]: - """Read a JSON file relative to *plugin_path* and return its LSP server dict.""" - target = (plugin_path / rel_path).resolve() - try: - target.relative_to(plugin_path.resolve()) - except ValueError: - logger.warning("LSP file path escapes plugin root: %s", rel_path) - return {} - candidate = plugin_path / rel_path - if not candidate.exists() or not candidate.is_file(): - logger.warning("LSP file not found: %s", candidate) - return {} - if candidate.is_symlink(): - logger.warning("Skipping symlinked LSP file: %s", candidate) - return {} - return _read_lsp_json(candidate, logger) - - -def _read_lsp_json(path: Path, logger: logging.Logger) -> dict[str, Any]: - """Parse a JSON file and return the LSP servers mapping. - - Unlike .mcp.json which has a wrapper key, .lsp.json uses server names - as top-level keys directly. + Copying onto self raises ``shutil.SameFileError``; callers must skip. """ try: - data = json.loads(path.read_text(encoding="utf-8")) - except (json.JSONDecodeError, OSError) as exc: - logger.warning("Failed to read LSP config %s: %s", path, exc) - return {} - if not isinstance(data, dict): - return {} - return dict(data) - + return src.resolve() == dst.resolve() + except OSError: + return False -def _lsp_servers_to_apm_deps(servers: dict[str, Any], plugin_path: Path) -> list[dict[str, Any]]: - """Convert raw LSP server configs to ``dependencies.lsp`` dicts. 
- Required fields per Claude Code spec: - - ``command``: binary to run - - ``extensionToLanguage``: mapping of file extensions to language IDs +def _copy_plugin_command_file( + source_file: Path, dest_dir: Path, rel_to: Path | None = None +) -> None: + """Copy a command file into *dest_dir*, normalising ``.md`` -> ``.prompt.md``.""" + if rel_to is not None: + relative_path = source_file.relative_to(rel_to) + target_path = dest_dir / relative_path + else: + target_path = dest_dir / source_file.name + if not source_file.name.endswith(".prompt.md") and source_file.suffix == ".md": + target_path = target_path.with_name(f"{source_file.stem}.prompt.md") + target_path.parent.mkdir(parents=True, exist_ok=True) + if _is_same_path(source_file, target_path): + return + import shutil - All resulting entries are routed through ``LSPDependency.from_dict()`` - for validation. Entries that fail validation are skipped with a warning. + shutil.copy2(source_file, target_path) - Args: - servers: Mapping of server name -> server config dict. - plugin_path: Plugin root (used for log context only). - Returns: - List of dicts consumable by ``LSPDependency.from_dict()``. 
- """ - from ..models.dependency.lsp import LSPDependency +def _map_plugin_agents(agent_sources: list[Path], apm_dir: Path) -> None: + """Copy agent sources into ``.apm/agents/``.""" + import shutil - logger = logging.getLogger("apm") - deps: list[dict[str, Any]] = [] + from apm_cli.security.gate import ignore_non_content - for name, cfg in servers.items(): - if not isinstance(cfg, dict): - logger.warning("Skipping non-dict LSP server config '%s'", name) + target_agents = apm_dir / "agents" + _assert_no_symlink_descendants(target_agents) + agent_dirs = [s for s in agent_sources if s.is_dir()] + agent_files = [s for s in agent_sources if s.is_file()] + for d in agent_dirs: + if _is_same_path(d, target_agents): continue + shutil.copytree(d, target_agents, dirs_exist_ok=True, ignore=ignore_non_content) + if agent_files: + target_agents.mkdir(parents=True, exist_ok=True) + for f in agent_files: + dst = target_agents / f.name + if not _is_same_path(f, dst): + shutil.copy2(f, dst) - dep: dict[str, Any] = {"name": name} - - # Copy all recognized fields - for key in ( - "command", - "args", - "extensionToLanguage", - "transport", - "env", - "initializationOptions", - "settings", - "workspaceFolder", - "startupTimeout", - "shutdownTimeout", - "restartOnCrash", - "maxRestarts", - ): - if key in cfg: - dep[key] = cfg[key] - # Route through the validation chokepoint - try: - LSPDependency.from_dict(dep) - except Exception as exc: - _surface_warning( - f"Skipping invalid LSP server '{name}' from plugin '{plugin_path.name}': {exc}", - logger, - ) - continue +def _map_plugin_skills(skill_sources: list[Path], apm_dir: Path, manifest: dict[str, Any]) -> None: + """Copy skill sources into ``.apm/skills/``.""" + import shutil - deps.append(dep) + from apm_cli.security.gate import ignore_non_content - return deps + target_skills = apm_dir / "skills" + _assert_no_symlink_descendants(target_skills) + skill_dirs = [s for s in skill_sources if s.is_dir()] + skill_files = [s for s in 
skill_sources if s.is_file()] + is_custom_list = isinstance(manifest.get("skills"), list) + if is_custom_list and skill_dirs: + target_skills.mkdir(parents=True, exist_ok=True) + for d in skill_dirs: + nested = target_skills / d.name + if not _is_same_path(d, nested): + shutil.copytree(d, nested, ignore=ignore_non_content, dirs_exist_ok=True) + elif skill_dirs: + for d in skill_dirs: + if not _is_same_path(d, target_skills): + shutil.copytree(d, target_skills, dirs_exist_ok=True, ignore=ignore_non_content) + if skill_files: + target_skills.mkdir(parents=True, exist_ok=True) + for f in skill_files: + dst = target_skills / f.name + if not _is_same_path(f, dst): + shutil.copy2(f, dst) -def _map_plugin_artifacts( - plugin_path: Path, apm_dir: Path, manifest: dict[str, Any] | None = None -) -> None: - """Map plugin artifacts to .apm/ subdirectories and copy pass-through files. - - Copies: - - agents/ -> .apm/agents/ - - skills/ -> .apm/skills/ - - commands/ -> .apm/prompts/ (*.md normalized to *.prompt.md) - - hooks/ -> .apm/hooks/ (directory, config file, or inline object) - - .mcp.json -> .apm/.mcp.json (MCP-based plugins need this to function) - - .lsp.json -> .apm/.lsp.json - - settings.json -> .apm/settings.json +def _map_plugin_commands(command_sources: list[Path], apm_dir: Path) -> None: + """Copy command sources into ``.apm/prompts/``, normalising ``.md`` -> ``.prompt.md``.""" + target_prompts = apm_dir / "prompts" + _assert_no_symlink_descendants(target_prompts) + target_prompts.mkdir(parents=True, exist_ok=True) + for source in command_sources: + if source.is_file() and not source.is_symlink(): + _copy_plugin_command_file(source, target_prompts) + elif source.is_dir(): + for source_file in source.rglob("*"): + if not source_file.is_file() or source_file.is_symlink(): + continue + _copy_plugin_command_file(source_file, target_prompts, rel_to=source) - When the manifest specifies custom component paths (e.g. 
``"agents": ["custom/"]``), - those paths are used instead of the defaults. - Symlinks are skipped entirely to prevent content exfiltration attacks. +def _map_plugin_hooks(manifest: dict[str, Any], plugin_path: Path, apm_dir: Path) -> None: + """Map hooks into ``.apm/hooks/``. - Args: - plugin_path: Root of the plugin directory. - apm_dir: Path to the .apm/ directory. - manifest: Optional plugin.json metadata; used for custom component paths. + The spec allows a directory path, a config file path, or an inline + object. All three forms are handled. """ - if manifest is None: - manifest = {} + import json + import shutil from apm_cli.security.gate import ignore_non_content - # Resolve source paths -- use manifest arrays if present, else defaults. - # Custom paths may be directories OR individual files. - # - # Security: every manifest-controlled path is verified to resolve - # inside *plugin_path* before it is copied. Without this guard, a - # malicious plugin could set ``"commands": "/etc/passwd"`` or - # ``"agents": ["../../host"]`` and trick ``apm install`` into copying - # arbitrary host files into the project's ``.apm/`` tree (and from - # there into ``.github/prompts/`` via auto-integration). 
- def _resolve_sources(component: str, default_dir: str): - """Return list of existing source paths (dirs or files) for a component.""" - custom = manifest.get(component) - if isinstance(custom, list): - paths = [] - for p in custom: - raw = str(p) - src = plugin_path / raw - if ( - src.exists() - and not src.is_symlink() - and _is_within_plugin(src, plugin_path, component=component) - ): - paths.append(src) - return paths - elif isinstance(custom, str): - src = plugin_path / custom - if ( - src.exists() - and not src.is_symlink() - and _is_within_plugin(src, plugin_path, component=component) - ): - return [src] - return [] - default = plugin_path / default_dir - if ( - default.exists() - and not default.is_symlink() - and default.is_dir() - and _is_within_plugin(default, plugin_path, component=component) - ): - return [default] - return [] - - # Helper: True when *src* and *dst* resolve to the same filesystem path - # (e.g. a manifest entry pointing at a file already inside the target). - # Copying onto self raises ``shutil.SameFileError`` and ``shutil.copytree`` - # over identical directories triggers it per-file, so callers must skip. - def _is_same_path(src: Path, dst: Path) -> bool: - try: - return src.resolve() == dst.resolve() - except OSError: - return False - - # Map agents/ - # Unlike skills (which are named directories containing SKILL.md), agents - # are flat files -- each .md is one agent. So we always merge directory - # contents directly into .apm/agents/ (no nesting by dir name). 
- agent_sources = _resolve_sources("agents", "agents") - if agent_sources: - target_agents = apm_dir / "agents" - _assert_no_symlink_descendants(target_agents) - agent_dirs = [s for s in agent_sources if s.is_dir()] - agent_files = [s for s in agent_sources if s.is_file()] - for d in agent_dirs: - if _is_same_path(d, target_agents): - continue - shutil.copytree(d, target_agents, dirs_exist_ok=True, ignore=ignore_non_content) - if agent_files: - target_agents.mkdir(parents=True, exist_ok=True) - for f in agent_files: - dst = target_agents / f.name - if _is_same_path(f, dst): - continue - shutil.copy2(f, dst) - - # Map skills/ - skill_sources = _resolve_sources("skills", "skills") - if skill_sources: - target_skills = apm_dir / "skills" - _assert_no_symlink_descendants(target_skills) - skill_dirs = [s for s in skill_sources if s.is_dir()] - skill_files = [s for s in skill_sources if s.is_file()] - - is_custom_list = isinstance(manifest.get("skills"), list) - if is_custom_list and skill_dirs: - target_skills.mkdir(parents=True, exist_ok=True) - for d in skill_dirs: - nested = target_skills / d.name - if _is_same_path(d, nested): - continue - shutil.copytree( - d, - nested, - ignore=ignore_non_content, - dirs_exist_ok=True, - ) - elif skill_dirs: - for d in skill_dirs: - if _is_same_path(d, target_skills): - continue - shutil.copytree(d, target_skills, dirs_exist_ok=True, ignore=ignore_non_content) - if skill_files: - target_skills.mkdir(parents=True, exist_ok=True) - for f in skill_files: - dst = target_skills / f.name - if _is_same_path(f, dst): - continue - shutil.copy2(f, dst) - - # Map commands/ -> .apm/prompts/ (normalize .md -> .prompt.md) - command_sources = _resolve_sources("commands", "commands") - if command_sources: - target_prompts = apm_dir / "prompts" - _assert_no_symlink_descendants(target_prompts) - target_prompts.mkdir(parents=True, exist_ok=True) - - def _copy_command_file(source_file: Path, dest_dir: Path, rel_to: Path = None): # noqa: RUF013 - 
"""Copy a command file, normalizing .md -> .prompt.md.""" - if rel_to: - relative_path = source_file.relative_to(rel_to) - target_path = dest_dir / relative_path - else: - target_path = dest_dir / source_file.name - if not source_file.name.endswith(".prompt.md") and source_file.suffix == ".md": - target_path = target_path.with_name(f"{source_file.stem}.prompt.md") - target_path.parent.mkdir(parents=True, exist_ok=True) - if _is_same_path(source_file, target_path): - return - shutil.copy2(source_file, target_path) - - for source in command_sources: - if source.is_file() and not source.is_symlink(): - _copy_command_file(source, target_prompts) - elif source.is_dir(): - for source_file in source.rglob("*"): - if not source_file.is_file() or source_file.is_symlink(): - continue - _copy_command_file(source_file, target_prompts, rel_to=source) - - # Map hooks/ -- the spec allows a directory path, a config file path, - # or an inline object. Handle all three forms. hooks_value = manifest.get("hooks") if isinstance(hooks_value, dict): # Inline hooks object -> write as .apm/hooks/hooks.json @@ -737,9 +407,9 @@ def _copy_command_file(source_file: Path, dest_dir: Path, rel_to: Path = None): elif isinstance(hooks_value, str) and (plugin_path / hooks_value).is_file(): # Config file path (e.g. 
"hooks": "hooks.json") src_file = plugin_path / hooks_value - if src_file.is_symlink() or not _is_within_plugin(src_file, plugin_path, component="hooks"): - pass - else: + if not src_file.is_symlink() and _is_within_plugin( + src_file, plugin_path, component="hooks" + ): target_hooks = apm_dir / "hooks" _assert_no_symlink_descendants(target_hooks) target_hooks.mkdir(parents=True, exist_ok=True) @@ -748,16 +418,19 @@ def _copy_command_file(source_file: Path, dest_dir: Path, rel_to: Path = None): shutil.copy2(src_file, dst) else: # Directory path(s) -- standard flow - hook_sources = _resolve_sources("hooks", "hooks") + hook_sources = _resolve_plugin_sources(plugin_path, manifest, "hooks", "hooks") if hook_sources: target_hooks = apm_dir / "hooks" _assert_no_symlink_descendants(target_hooks) for d in hook_sources: - if _is_same_path(d, target_hooks): - continue - shutil.copytree(d, target_hooks, dirs_exist_ok=True, ignore=ignore_non_content) + if not _is_same_path(d, target_hooks): + shutil.copytree(d, target_hooks, dirs_exist_ok=True, ignore=ignore_non_content) + + +def _copy_plugin_passthrough_files(plugin_path: Path, apm_dir: Path) -> None: + """Copy ``.mcp.json``, ``.lsp.json``, and ``settings.json`` into *apm_dir*.""" + import shutil - # Pass-through files required for MCP/LSP plugins to function for passthrough in (".mcp.json", ".lsp.json", "settings.json"): source_file = plugin_path / passthrough if source_file.exists() and not source_file.is_symlink(): @@ -770,6 +443,49 @@ def _copy_command_file(source_file: Path, dest_dir: Path, rel_to: Path = None): shutil.copy2(source_file, dst) +def _map_plugin_artifacts( + plugin_path: Path, apm_dir: Path, manifest: dict[str, Any] | None = None +) -> None: + """Map plugin artifacts to .apm/ subdirectories and copy pass-through files. 
+ + Copies: + - agents/ -> .apm/agents/ + - skills/ -> .apm/skills/ + - commands/ -> .apm/prompts/ (*.md normalised to *.prompt.md) + - hooks/ -> .apm/hooks/ (directory, config file, or inline object) + - .mcp.json -> .apm/.mcp.json + - .lsp.json -> .apm/.lsp.json + - settings.json -> .apm/settings.json + + Symlinks are skipped entirely to prevent content exfiltration attacks. + Custom component paths from the manifest are security-validated to + resolve inside *plugin_path* before any copy is attempted. + + Args: + plugin_path: Root of the plugin directory. + apm_dir: Path to the .apm/ directory. + manifest: Optional plugin.json metadata; used for custom component paths. + """ + if manifest is None: + manifest = {} + + agent_sources = _resolve_plugin_sources(plugin_path, manifest, "agents", "agents") + if agent_sources: + _map_plugin_agents(agent_sources, apm_dir) + + skill_sources = _resolve_plugin_sources(plugin_path, manifest, "skills", "skills") + if skill_sources: + _map_plugin_skills(skill_sources, apm_dir, manifest) + + command_sources = _resolve_plugin_sources(plugin_path, manifest, "commands", "commands") + if command_sources: + _map_plugin_commands(command_sources, apm_dir) + + _map_plugin_hooks(manifest, plugin_path, apm_dir) + + _copy_plugin_passthrough_files(plugin_path, apm_dir) + + def _generate_apm_yml(manifest: dict[str, Any]) -> str: """Generate apm.yml content from plugin metadata. diff --git a/src/apm_cli/deps/plugin_server_helpers.py b/src/apm_cli/deps/plugin_server_helpers.py new file mode 100644 index 000000000..c2740cfb2 --- /dev/null +++ b/src/apm_cli/deps/plugin_server_helpers.py @@ -0,0 +1,358 @@ +"""MCP and LSP server-extraction helpers for plugin_parser. + +Extracted from :mod:`plugin_parser` to keep that module under the +file-length guardrail. All public names are re-exported from +``plugin_parser`` so import paths are unchanged. 
+""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Any + + +def _read_mcp_file(plugin_path: Path, rel_path: str, logger: logging.Logger) -> dict[str, Any]: + """Read a JSON file relative to *plugin_path* and return its ``mcpServers`` dict.""" + target = (plugin_path / rel_path).resolve() + # Security: must stay inside plugin_path and not be a symlink + try: + target.relative_to(plugin_path.resolve()) + except ValueError: + logger.warning("MCP file path escapes plugin root: %s", rel_path) + return {} + candidate = plugin_path / rel_path + if not candidate.exists() or not candidate.is_file(): + logger.warning("MCP file not found: %s", candidate) + return {} + if candidate.is_symlink(): + logger.warning("Skipping symlinked MCP file: %s", candidate) + return {} + return _read_mcp_json(candidate, logger) + + +def _read_mcp_json(path: Path, logger: logging.Logger) -> dict[str, Any]: + """Parse a JSON file and return the ``mcpServers`` mapping.""" + try: + data = json.loads(path.read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError) as exc: + logger.warning("Failed to read MCP config %s: %s", path, exc) + return {} + if not isinstance(data, dict): + return {} + servers = data.get("mcpServers", {}) + return dict(servers) if isinstance(servers, dict) else {} + + +def _substitute_plugin_root( + servers: dict[str, Any], abs_root: str, logger: logging.Logger +) -> dict[str, Any]: + """Replace ``${CLAUDE_PLUGIN_ROOT}`` in server config string values.""" + placeholder = "${CLAUDE_PLUGIN_ROOT}" + substituted = False + + def _walk(obj: Any) -> Any: + nonlocal substituted + if isinstance(obj, str) and placeholder in obj: + substituted = True + return obj.replace(placeholder, abs_root) + if isinstance(obj, dict): + return {k: _walk(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_walk(item) for item in obj] + return obj + + result = {name: _walk(cfg) for name, cfg in 
servers.items()} + if substituted: + logger.info("Substituted ${CLAUDE_PLUGIN_ROOT} with %s", abs_root) + return result + + +def _extract_mcp_servers(plugin_path: Path, manifest: dict[str, Any]) -> dict[str, Any]: + """Extract MCP server definitions from a plugin manifest. + + Resolves ``mcpServers`` by type (per Claude Code spec): + - ``str`` -> read that file path relative to plugin root, parse JSON, + extract ``mcpServers`` key. + - ``list`` -> read each file path, merge (last-wins on name conflict). + - ``dict`` -> use directly as inline server definitions. + + When ``mcpServers`` is absent and ``.mcp.json`` (or ``.github/.mcp.json``) + exists at plugin root, read it as the default (matches Claude Code + auto-discovery). + + Security: symlinks are skipped, JSON parse errors are logged as warnings. + + ``${CLAUDE_PLUGIN_ROOT}`` in string values is replaced with the absolute + plugin path. + + Args: + plugin_path: Root of the plugin directory. + manifest: Parsed plugin.json dict. + + Returns: + dict mapping server name -> server config. Empty on failure. 
+ """ + logger = logging.getLogger("apm") + mcp_value = manifest.get("mcpServers") + + if mcp_value is not None: + # Manifest explicitly defines mcpServers + if isinstance(mcp_value, dict): + servers = dict(mcp_value) + elif isinstance(mcp_value, str): + servers = _read_mcp_file(plugin_path, mcp_value, logger) + elif isinstance(mcp_value, list): + servers = {} + for entry in mcp_value: + if isinstance(entry, str): + servers.update(_read_mcp_file(plugin_path, entry, logger)) + else: + logger.warning("Ignoring non-string entry in mcpServers array: %s", entry) + else: + logger.warning("Unsupported mcpServers type %s; ignoring", type(mcp_value).__name__) + return {} + else: + # Fall back to auto-discovery: .mcp.json then .github/.mcp.json + servers = {} + for fallback in (".mcp.json", ".github/.mcp.json"): + candidate = plugin_path / fallback + if candidate.exists() and candidate.is_file() and not candidate.is_symlink(): + servers = _read_mcp_json(candidate, logger) + if servers: + break + + # Substitute ${CLAUDE_PLUGIN_ROOT} in all string values + if servers: + abs_root = str(plugin_path.resolve()) + servers = _substitute_plugin_root(servers, abs_root, logger) + + return servers + + +def _mcp_servers_to_apm_deps(servers: dict[str, Any], plugin_path: Path) -> list[dict[str, Any]]: + """Convert raw MCP server configs to ``dependencies.mcp`` dicts. + + Transport inference: + - ``command`` present -> stdio + - ``url`` present -> http (or ``type`` if it's a valid transport) + - Neither -> skipped with warning + + Every entry gets ``registry: false`` (self-defined, not registry lookups). + + All resulting entries are routed through ``MCPDependency.from_dict()`` + so plugin-synthesised servers must clear the same security validation + chokepoint as CLI-authored or manually edited entries (name shape, URL + scheme allowlist, header CRLF, command path-traversal). 
Entries that + fail validation are skipped with a warning rather than crashing the + plugin install -- a single malformed server should not block the + whole plugin. + + Args: + servers: Mapping of server name -> server config dict. + plugin_path: Plugin root (used for log context only). + + Returns: + List of dicts consumable by ``MCPDependency.from_dict()``. + """ + from ..models.dependency.mcp import MCPDependency + from .plugin_parser import _surface_warning + + logger = logging.getLogger("apm") + deps: list[dict[str, Any]] = [] + + for name, cfg in servers.items(): + if not isinstance(cfg, dict): + logger.warning("Skipping non-dict MCP server config '%s'", name) + continue + + dep: dict[str, Any] = {"name": name, "registry": False} + + if "command" in cfg: + dep["transport"] = "stdio" + dep["command"] = cfg["command"] + if "args" in cfg: + dep["args"] = cfg["args"] + elif "url" in cfg: + raw_type = cfg.get("type", "http") + valid_transports = {"http", "sse", "streamable-http"} + dep["transport"] = raw_type if raw_type in valid_transports else "http" + dep["url"] = cfg["url"] + if "headers" in cfg: + dep["headers"] = cfg["headers"] + else: + _surface_warning( + f"Skipping MCP server '{name}' from plugin " + f"'{plugin_path.name}': no 'command' or 'url'", + logger, + ) + continue + + if "env" in cfg: + dep["env"] = cfg["env"] + if "tools" in cfg: + dep["tools"] = cfg["tools"] + + # Route through the validation chokepoint. Plugins are an ingress + # path: a malicious plugin could otherwise smuggle path traversal, + # CRLF, or unsafe URL schemes that bypass MCPDependency.validate(). + # PR #809 follow-up: surface validation errors to the user via the + # rich console (stdlib logger has no handlers configured). 
+ try: + MCPDependency.from_dict(dep) + except (ValueError, Exception) as exc: + _surface_warning( + f"Skipping invalid MCP server '{name}' from plugin '{plugin_path.name}': {exc}", + logger, + ) + continue + + deps.append(dep) + + return deps + + +def _read_lsp_file(plugin_path: Path, rel_path: str, logger: logging.Logger) -> dict[str, Any]: + """Read a JSON file relative to *plugin_path* and return its LSP server dict.""" + target = (plugin_path / rel_path).resolve() + try: + target.relative_to(plugin_path.resolve()) + except ValueError: + logger.warning("LSP file path escapes plugin root: %s", rel_path) + return {} + candidate = plugin_path / rel_path + if not candidate.exists() or not candidate.is_file(): + logger.warning("LSP file not found: %s", candidate) + return {} + if candidate.is_symlink(): + logger.warning("Skipping symlinked LSP file: %s", candidate) + return {} + return _read_lsp_json(candidate, logger) + + +def _read_lsp_json(path: Path, logger: logging.Logger) -> dict[str, Any]: + """Parse a JSON file and return the LSP servers mapping. + + Unlike .mcp.json which has a wrapper key, .lsp.json uses server names + as top-level keys directly. + """ + try: + data = json.loads(path.read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError) as exc: + logger.warning("Failed to read LSP config %s: %s", path, exc) + return {} + if not isinstance(data, dict): + return {} + return dict(data) + + +def _extract_lsp_servers(plugin_path: Path, manifest: dict[str, Any]) -> dict[str, Any]: + """Extract LSP server definitions from a plugin manifest. + + Resolves ``lspServers`` by type (per Claude Code spec): + - ``str`` -> read that file path relative to plugin root, parse JSON. + - ``dict`` -> use directly as inline server definitions. + + When ``lspServers`` is absent and ``.lsp.json`` exists at plugin root, + read it as the default (matches Claude Code auto-discovery). + + Security: symlinks are skipped, JSON parse errors are logged as warnings. 
+ + ``${CLAUDE_PLUGIN_ROOT}`` in string values is replaced with the absolute + plugin path. + + Args: + plugin_path: Root of the plugin directory. + manifest: Parsed plugin.json dict. + + Returns: + dict mapping server name -> server config. Empty on failure. + """ + logger = logging.getLogger("apm") + lsp_value = manifest.get("lspServers") + + if lsp_value is not None: + if isinstance(lsp_value, dict): + servers = dict(lsp_value) + elif isinstance(lsp_value, str): + servers = _read_lsp_file(plugin_path, lsp_value, logger) + else: + logger.warning("Unsupported lspServers type %s; ignoring", type(lsp_value).__name__) + return {} + else: + # Fall back to auto-discovery: .lsp.json + servers = {} + candidate = plugin_path / ".lsp.json" + if candidate.exists() and candidate.is_file() and not candidate.is_symlink(): + servers = _read_lsp_json(candidate, logger) + + # Substitute ${CLAUDE_PLUGIN_ROOT} in all string values + if servers: + abs_root = str(plugin_path.resolve()) + servers = _substitute_plugin_root(servers, abs_root, logger) + + return servers + + +def _lsp_servers_to_apm_deps(servers: dict[str, Any], plugin_path: Path) -> list[dict[str, Any]]: + """Convert raw LSP server configs to ``dependencies.lsp`` dicts. + + Required fields per Claude Code spec: + - ``command``: binary to run + - ``extensionToLanguage``: mapping of file extensions to language IDs + + All resulting entries are routed through ``LSPDependency.from_dict()`` + for validation. Entries that fail validation are skipped with a warning. + + Args: + servers: Mapping of server name -> server config dict. + plugin_path: Plugin root (used for log context only). + + Returns: + List of dicts consumable by ``LSPDependency.from_dict()``. 
+ """ + from ..models.dependency.lsp import LSPDependency + from .plugin_parser import _surface_warning + + logger = logging.getLogger("apm") + deps: list[dict[str, Any]] = [] + + for name, cfg in servers.items(): + if not isinstance(cfg, dict): + logger.warning("Skipping non-dict LSP server config '%s'", name) + continue + + dep: dict[str, Any] = {"name": name} + + # Copy all recognised fields + for key in ( + "command", + "args", + "extensionToLanguage", + "transport", + "env", + "initializationOptions", + "settings", + "workspaceFolder", + "startupTimeout", + "shutdownTimeout", + "restartOnCrash", + "maxRestarts", + ): + if key in cfg: + dep[key] = cfg[key] + + # Route through the validation chokepoint + try: + LSPDependency.from_dict(dep) + except Exception as exc: + _surface_warning( + f"Skipping invalid LSP server '{name}' from plugin '{plugin_path.name}': {exc}", + logger, + ) + continue + + deps.append(dep) + + return deps diff --git a/src/apm_cli/deps/registry/outdated.py b/src/apm_cli/deps/registry/outdated.py index 59e6c3777..e3fb0064d 100644 --- a/src/apm_cli/deps/registry/outdated.py +++ b/src/apm_cli/deps/registry/outdated.py @@ -163,22 +163,16 @@ def check_registry_locked_dep( package_name = locked.get_unique_key() current = locked.version or "" - if ctx is None: + # Combine two early-exit guards (no context / feature disabled) into one + # return statement to keep PLR0911 satisfied. 
+ if ctx is None or not is_package_registry_enabled(): + source = "registry" if ctx is None else "registry (feature disabled)" return OutdatedRow( package=package_name, current=current or "(none)", latest="-", status="unknown", - source="registry", - ) - - if not is_package_registry_enabled(): - return OutdatedRow( - package=package_name, - current=current or "(none)", - latest="-", - status="unknown", - source="registry (feature disabled)", + source=source, ) manifest_dep = ctx.manifest_index.get(package_name) @@ -205,23 +199,18 @@ def check_registry_locked_dep( ) registry_name = (manifest_dep.registry_name if manifest_dep else None) or ctx.default_registry - if not registry_name: - return OutdatedRow( - package=package_name, - current=current, - latest="-", - status="unknown", - source="registry (no default registry)", + base_url = ctx.registries.get(registry_name) if registry_name else None + # Combine registry-name-missing and base-url-missing into one return. + if not registry_name or not base_url: + source_detail = ( + "no default registry" if not registry_name else f"{registry_name!r} not configured" ) - - base_url = ctx.registries.get(registry_name) - if not base_url: return OutdatedRow( package=package_name, current=current, latest="-", status="unknown", - source=f"registry ({registry_name!r} not configured)", + source=f"registry ({source_detail})", ) source_label = f"registry: {registry_name}" From 2ff44ee513b084b9a49026d53a96d1d3473370cb Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 05:32:35 +0200 Subject: [PATCH 15/21] refactor(deps): split github_downloader/download_strategies under 800-line budget Strangler Stage 2 (#1078), Commit 3b: bring the two largest deps files under the 800-line guardrail and the final complexity thresholds while preserving the heavy test monkeypatch surface. 
File-length offenders cleared via delegating wrappers + ops siblings (every new module is <=525 lines, well under budget): - github_downloader.py 1806 -> 737; bodies moved to github_downloader_setup_ops.py (247), github_downloader_package_ops.py (525), github_downloader_subdir_ops.py (448) - download_strategies.py 1123 -> 587; bodies moved to download_strategies_ops.py (379), download_strategies_backends_ops.py (416); DownloadDelegate class stays in download_strategies.py Complexity offenders decomposed in their new homes (genuine cohesive helpers, distinct paths kept separate -- NOT collapsed): - download_subdirectory_package C901/PLR0912/PLR0915 -> per-cache-tier helpers (persistent / shared-bare / legacy-clone) + extract/build/log - download_github_file C901/PLR0912 -> per-step + per-backend helpers Patch-safety: thin class wrappers keep patched method names; moved bodies route patched module globals via function-level _gh./_ds. aliases (Rule B); inter-ops calls route through the class wrappers so no ops module imports another at module scope (no cycles). No **kwargs gaming, no noqa, no threshold edits. backlog 26 -> 24; R0801 10.00/10; 16645 unit+acceptance pass; 3030 targeted integration pass. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/deps/download_strategies.py | 582 +-------- .../deps/download_strategies_backends_ops.py | 416 ++++++ src/apm_cli/deps/download_strategies_ops.py | 379 ++++++ src/apm_cli/deps/github_downloader.py | 1163 +---------------- .../deps/github_downloader_package_ops.py | 525 ++++++++ .../deps/github_downloader_setup_ops.py | 247 ++++ .../deps/github_downloader_subdir_ops.py | 448 +++++++ 7 files changed, 2085 insertions(+), 1675 deletions(-) create mode 100644 src/apm_cli/deps/download_strategies_backends_ops.py create mode 100644 src/apm_cli/deps/download_strategies_ops.py create mode 100644 src/apm_cli/deps/github_downloader_package_ops.py create mode 100644 src/apm_cli/deps/github_downloader_setup_ops.py create mode 100644 src/apm_cli/deps/github_downloader_subdir_ops.py diff --git a/src/apm_cli/deps/download_strategies.py b/src/apm_cli/deps/download_strategies.py index 239f17561..1e61ecfb2 100644 --- a/src/apm_cli/deps/download_strategies.py +++ b/src/apm_cli/deps/download_strategies.py @@ -14,15 +14,20 @@ import sys import time from pathlib import Path -from urllib.parse import quote +from urllib.parse import quote as quote import requests -from ..core.auth import AuthResolver, HostInfo +from ..core.auth import AuthResolver as AuthResolver +from ..core.auth import HostInfo from ..models.apm_package import DependencyReference from ..utils.github_host import ( - build_ado_api_url, - build_artifactory_archive_url, + build_ado_api_url as build_ado_api_url, +) +from ..utils.github_host import ( + build_artifactory_archive_url as build_artifactory_archive_url, +) +from ..utils.github_host import ( build_https_clone_url, build_raw_content_url, build_ssh_url, @@ -301,83 +306,10 @@ def download_artifactory_archive( target_path: Path, scheme: str = "https", ) -> None: - """Download and extract a zip archive from Artifactory VCS proxy. 
- - Tries multiple URL patterns (GitHub-style and GitLab-style). - GitHub archives contain a single root directory named {repo}-{ref}/; - this method strips that prefix on extraction so files land directly - in *target_path*. + """Download and extract a zip archive from Artifactory VCS proxy.""" + from .download_strategies_backends_ops import download_artifactory_archive as _impl - Raises RuntimeError on failure. - """ - import io - import zipfile - - archive_urls = build_artifactory_archive_url(host, prefix, owner, repo, ref, scheme=scheme) - headers = self.get_artifactory_headers() - - # Guard: reject unreasonably large archives (default 500 MB) - max_archive_bytes = int(os.environ.get("ARTIFACTORY_MAX_ARCHIVE_MB", "500")) * 1024 * 1024 - - last_error = None - for url in archive_urls: - _debug(f"Trying Artifactory archive: {url}") - try: - resp = self._host._resilient_get(url, headers=headers, timeout=60) - if resp.status_code == 200: - if len(resp.content) > max_archive_bytes: - last_error = f"Archive too large ({len(resp.content)} bytes) from {url}" - _debug(last_error) - continue - # Extract zip, stripping the top-level directory - target_path.mkdir(parents=True, exist_ok=True) - with zipfile.ZipFile(io.BytesIO(resp.content)) as zf: - # Identify the root prefix (e.g., "repo-main/") - names = zf.namelist() - if not names: - raise RuntimeError(f"Empty archive from {url}") - root_prefix = names[0] - if not root_prefix.endswith("/"): - # Single file archive; extract as-is - zf.extractall(target_path) - return - for member in zf.infolist(): - # Strip root prefix - if member.filename == root_prefix: - continue - rel = member.filename[len(root_prefix) :] - if not rel: - continue - # Guard: prevent zip path traversal (CWE-22) - dest = target_path / rel - if not dest.resolve().is_relative_to(target_path.resolve()): - _debug(f"Skipping zip entry escaping target: {member.filename}") - continue - unix_mode = (member.external_attr >> 16) & 0xFFFF - if member.is_dir(): - 
dest.mkdir(parents=True, exist_ok=True) - else: - dest.parent.mkdir(parents=True, exist_ok=True) - with zf.open(member) as src, open(dest, "wb") as dst: - dst.write(src.read()) - if unix_mode: - os.chmod(dest, unix_mode & 0o755) - _debug(f"Extracted Artifactory archive to {target_path}") - return - else: - last_error = f"HTTP {resp.status_code} from {url}" - _debug(last_error) - except zipfile.BadZipFile: - last_error = f"Invalid zip archive from {url}" - _debug(last_error) - except requests.RequestException as e: - last_error = str(e) - _debug(f"Request failed: {last_error}") - - raise RuntimeError( - f"Failed to download package {owner}/{repo}#{ref} from Artifactory " - f"({host}/{prefix}). Last error: {last_error}" - ) + return _impl(self, host, prefix, owner, repo, ref, target_path, scheme) def download_file_from_artifactory( self, @@ -389,70 +321,10 @@ def download_file_from_artifactory( ref: str, scheme: str = "https", ) -> bytes: - """Download a single file from Artifactory. - - Tries the Archive Entry Download API first (fetches one file - without downloading the full archive). Falls back to the full - archive approach when the entry API is unavailable or returns an - error. - """ - # Fast path: use the RegistryClient interface for entry download - cfg = self._host.registry_config - if cfg is not None and cfg.host == host: - client = cfg.get_client() - content = client.fetch_file( - owner, - repo, - file_path, - ref, - resilient_get=self._host._resilient_get, - ) - else: - # No RegistryConfig or host mismatch (explicit FQDN mode) -- - # fall back to the standalone helper. 
- from .artifactory_entry import fetch_entry_from_archive - - content = fetch_entry_from_archive( - host, - prefix, - owner, - repo, - file_path, - ref, - scheme=scheme, - headers=self.get_artifactory_headers(), - resilient_get=self._host._resilient_get, - ) - if content is not None: - return content - - # Fallback: download full archive and extract the file - import io - import zipfile + """Download a single file from Artifactory (entry API, then archive).""" + from .download_strategies_backends_ops import download_file_from_artifactory as _impl - archive_urls = build_artifactory_archive_url(host, prefix, owner, repo, ref, scheme=scheme) - headers = self.get_artifactory_headers() - - for url in archive_urls: - try: - resp = self._host._resilient_get(url, headers=headers, timeout=60) - if resp.status_code != 200: - continue - with zipfile.ZipFile(io.BytesIO(resp.content)) as zf: - names = zf.namelist() - root_prefix = names[0] if names else "" - target_name = root_prefix + file_path - if target_name in names: - return zf.read(target_name) - if file_path in names: - return zf.read(file_path) - except (zipfile.BadZipFile, requests.RequestException): - continue - - raise RuntimeError( - f"Failed to download file '{file_path}' from Artifactory " - f"({host}/{prefix}/{owner}/{repo}#{ref})" - ) + return _impl(self, host, prefix, owner, repo, file_path, ref, scheme) # ------------------------------------------------------------------ # Raw / CDN download helper @@ -485,94 +357,10 @@ def download_ado_file( file_path: str, ref: str = "main", ) -> bytes: - """Download a file from Azure DevOps repository. 
- - Args: - dep_ref: Parsed dependency reference with ADO-specific fields - file_path: Path to file within the repository - ref: Git reference (branch, tag, or commit SHA) - - Returns: - bytes: File content - """ - import base64 - - # Validate required ADO fields before proceeding - if not all([dep_ref.ado_organization, dep_ref.ado_project, dep_ref.ado_repo]): - raise ValueError( - "Invalid Azure DevOps dependency reference: missing " - "organization, project, or repo. " - f"Got: org={dep_ref.ado_organization}, " - f"project={dep_ref.ado_project}, repo={dep_ref.ado_repo}" - ) - - host = dep_ref.host or "dev.azure.com" - api_url = build_ado_api_url( - dep_ref.ado_organization, - dep_ref.ado_project, - dep_ref.ado_repo, - file_path, - ref, - host, - ) + """Download a file from an Azure DevOps repository.""" + from .download_strategies_backends_ops import download_ado_file as _impl - # Set up authentication headers - ADO uses Basic auth with PAT - headers: dict[str, str] = {} - if self._host.ado_token: - # ADO uses Basic auth: username can be empty, password is the PAT - auth = base64.b64encode(f":{self._host.ado_token}".encode()).decode() - headers["Authorization"] = f"Basic {auth}" - - try: - response = self._host._resilient_get(api_url, headers=headers, timeout=30) - response.raise_for_status() - return response.content - except requests.exceptions.HTTPError as e: - if e.response.status_code == 404: - # Try fallback branches - if ref not in ["main", "master"]: - raise RuntimeError( - f"File not found: {file_path} at ref '{ref}' in {dep_ref.repo_url}" - ) from e - - fallback_ref = "master" if ref == "main" else "main" - fallback_url = build_ado_api_url( - dep_ref.ado_organization, - dep_ref.ado_project, - dep_ref.ado_repo, - file_path, - fallback_ref, - host, - ) - - try: - response = self._host._resilient_get(fallback_url, headers=headers, timeout=30) - response.raise_for_status() - return response.content - except requests.exceptions.HTTPError as fallback_err: - 
raise RuntimeError( - f"File not found: {file_path} in {dep_ref.repo_url} " - f"(tried refs: {ref}, {fallback_ref})" - ) from fallback_err - elif e.response.status_code in (401, 403): - error_msg = f"Authentication failed for Azure DevOps {dep_ref.repo_url}. " - if not self._host.ado_token: - error_msg += self._host.auth_resolver.build_error_context( - host, - "download", - org=dep_ref.ado_organization if dep_ref else None, - port=dep_ref.port if dep_ref else None, - dep_url=dep_ref.repo_url if dep_ref else None, - ) - else: - error_msg += "Please check your Azure DevOps PAT permissions." - raise RuntimeError(error_msg) from e - else: - raise RuntimeError( - f"Failed to download {file_path}: HTTP {e.response.status_code}" - ) from e - except requests.exceptions.RequestException as e: - raise RuntimeError(f"Network error downloading {file_path}: {e}") from e + return _impl(self, dep_ref, file_path, ref) # ------------------------------------------------------------------ # GitLab file download @@ -586,75 +374,9 @@ def download_gitlab_file( verbose_callback=None, ) -> bytes: """Download a file via GitLab REST v4 ``repository/files/.../raw``.""" - host = dep_ref.host or default_host() - host_info = self._host.auth_resolver.classify_host(host) - project_path = dep_ref.repo_url - if not project_path: - raise RuntimeError("Missing repository path for GitLab file download") - - org = project_path.split("/")[0] - file_ctx = self._host.auth_resolver.resolve(host, org, port=dep_ref.port) - token = file_ctx.token - headers = AuthResolver.gitlab_rest_headers(token) - - api_base = host_info.api_base.rstrip("/") - enc_proj = quote(project_path, safe="") - enc_file = quote(file_path, safe="") - - def _raw_url(r: str) -> str: - return ( - f"{api_base}/projects/{enc_proj}/repository/files/{enc_file}/raw" - f"?ref={quote(r, safe='')}" - ) + from .download_strategies_backends_ops import download_gitlab_file as _impl - api_url = _raw_url(ref) - - try: - response = 
self._host._resilient_get(api_url, headers=headers, timeout=30) - response.raise_for_status() - if verbose_callback: - verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") - return response.content - except requests.exceptions.HTTPError as e: - if e.response is not None and e.response.status_code == 404: - if ref not in ("main", "master"): - raise RuntimeError( - f"File not found: {file_path} at ref '{ref}' in {dep_ref.repo_url}" - ) from e - fallback_ref = "master" if ref == "main" else "main" - fallback_url = _raw_url(fallback_ref) - try: - response = self._host._resilient_get(fallback_url, headers=headers, timeout=30) - response.raise_for_status() - if verbose_callback: - verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") - return response.content - except requests.exceptions.HTTPError as fallback_err: - raise RuntimeError( - f"File not found: {file_path} in {dep_ref.repo_url} " - f"(tried refs: {ref}, {fallback_ref})" - ) from fallback_err - if e.response is not None and e.response.status_code in (401, 403): - error_msg = ( - f"Authentication failed for GitLab {dep_ref.repo_url} " - f"(file: {file_path}, ref: {ref}). " - ) - if not token: - error_msg += self._host.auth_resolver.build_error_context( - host, "download", org=org, port=dep_ref.port - ) - else: - error_msg += ( - "Please verify your token can read this project (required API scope)." 
- ) - raise RuntimeError(error_msg) from e - if e.response is not None: - raise RuntimeError( - f"Failed to download {file_path}: HTTP {e.response.status_code}" - ) from e - raise - except requests.exceptions.RequestException as e: - raise RuntimeError(f"Network error downloading {file_path}: {e}") from e + return _impl(self, dep_ref, file_path, ref, verbose_callback) # ------------------------------------------------------------------ # GitHub file download @@ -667,268 +389,10 @@ def download_github_file( ref: str = "main", verbose_callback=None, ) -> bytes: - """Download a file from GitHub repository. - - For github.com without a token, tries raw.githubusercontent.com first - (CDN, no rate limit) before falling back to the Contents API. - Authenticated requests and non-github.com hosts always use the - Contents API directly. - - Args: - dep_ref: Parsed dependency reference - file_path: Path to file within the repository - ref: Git reference (branch, tag, or commit SHA) - verbose_callback: Optional callable for verbose logging - - Returns: - bytes: File content - """ - host = dep_ref.host or default_host() - - # Parse owner/repo from repo_url - owner, repo = dep_ref.repo_url.split("/", 1) - - # Resolve token via AuthResolver for CDN fast-path decision - org = None - if dep_ref and dep_ref.repo_url: - parts = dep_ref.repo_url.split("/") - if parts: - org = parts[0] - file_ctx = self._host.auth_resolver.resolve(host, org, port=dep_ref.port) - token = file_ctx.token - - # --- CDN fast-path for github.com without a token --- - # raw.githubusercontent.com is served from GitHub's CDN and is not - # subject to the REST API rate limit (60 req/h unauthenticated). - # Only available for github.com -- GHES/GHE-DR have no equivalent. 
- if host.lower() == "github.com" and not token: - content = self.try_raw_download(owner, repo, ref, file_path) - if content is not None: - if verbose_callback: - verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") - return content - # raw download returned 404 -- could be wrong default branch. - # Try the other default branch before falling through to the API. - if ref in ("main", "master"): - fallback_ref = "master" if ref == "main" else "main" - content = self.try_raw_download(owner, repo, fallback_ref, file_path) - if content is not None: - if verbose_callback: - verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") - return content - # All raw attempts failed -- fall through to API path which - # handles private repos, rate-limit messaging, and SAML errors. - - # --- Generic host: raw URL first, then API version negotiation --- - # For non-GitHub non-GHE hosts (Gitea, Gogs, self-hosted git), try the - # raw URL path first, then negotiate API versions v1 -> v3. - is_github_host = is_github_hostname(host) or self._is_configured_ghes(host) - if not is_github_host: - raw_url = f"https://{host}/{owner}/{repo}/raw/{ref}/{file_path}" - raw_headers = self._build_generic_host_auth_headers(host, file_ctx, accept=None) - if verbose_callback: - verbose_callback(f"Trying raw URL on generic host {host}: {raw_url}") - try: - response = self._host._resilient_get(raw_url, headers=raw_headers, timeout=30) - if response.status_code == 200: - if verbose_callback: - verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") - return response.content - except (requests.RequestException, OSError) as raw_err: - if verbose_callback: - verbose_callback( - f"Raw URL on {host} failed for {file_path}@{ref}: " - f"{type(raw_err).__name__}; falling back to Contents API." 
- ) - - # --- Contents API path (authenticated, enterprise, or raw fallback) --- - # Build API URL candidates - format differs by host type - api_url_candidates = self._build_contents_api_urls( - host, owner, repo, file_path, ref, is_github_host=is_github_host - ) - api_url = api_url_candidates[0] - - # Set up authentication headers - # GitHub family: use GitHub raw-media accept header. Generic hosts - # ignore it and may return JSON envelopes -- handle that on read. - accept = "application/vnd.github.v3.raw" if is_github_host else "application/json" - if is_github_host: - headers: dict[str, str] = {"Accept": accept} - if token: - headers["Authorization"] = f"token {token}" - else: - headers = self._build_generic_host_auth_headers(host, file_ctx, accept=accept) + """Download a file from a GitHub repository (CDN fast-path then API).""" + from .download_strategies_ops import download_github_file as _impl - # Try to download with the specified ref - try: - if verbose_callback and not is_github_host: - verbose_callback(f"Trying Contents API on {host}: {api_url}") - response = self._host._resilient_get(api_url, headers=headers, timeout=30) - response.raise_for_status() - if verbose_callback: - verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") - return self._extract_contents_api_payload(response, is_github_host) - except requests.exceptions.HTTPError as e: - if e.response.status_code == 404: - # For generic hosts, try remaining API version candidates before ref fallback - for candidate_url in api_url_candidates[1:]: - try: - if verbose_callback: - verbose_callback( - f"Contents API 404; trying next candidate: {candidate_url}" - ) - candidate_resp = self._host._resilient_get( - candidate_url, headers=headers, timeout=30 - ) - candidate_resp.raise_for_status() - if verbose_callback: - verbose_callback( - f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}" - ) - return self._extract_contents_api_payload(candidate_resp, is_github_host) - 
except requests.exceptions.HTTPError as ce: - if ce.response.status_code != 404: - raise RuntimeError( # noqa: B904 - f"Failed to download {file_path}: HTTP {ce.response.status_code}" - ) - - # Try fallback branches if the specified ref fails - if ref not in ["main", "master"]: - raise RuntimeError( # noqa: B904 - self._build_unsupported_or_missing_error( - host, - dep_ref.repo_url, - file_path, - ref, - api_url_candidates, - is_github_host=is_github_host, - ) - ) - - # Try the other default branch - fallback_ref = "master" if ref == "main" else "main" - fallback_url_candidates = self._build_contents_api_urls( - host, owner, repo, file_path, fallback_ref - ) - - for fallback_url in fallback_url_candidates: - try: - response = self._host._resilient_get( - fallback_url, headers=headers, timeout=30 - ) - response.raise_for_status() - if verbose_callback: - verbose_callback( - f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}" - ) - return self._extract_contents_api_payload(response, is_github_host) - except requests.exceptions.HTTPError as fe: - if fe.response.status_code != 404: - raise RuntimeError( # noqa: B904 - f"Failed to download {file_path}: HTTP {fe.response.status_code}" - ) - - raise RuntimeError( # noqa: B904 - self._build_unsupported_or_missing_error( - host, - dep_ref.repo_url, - file_path, - ref, - api_url_candidates, - is_github_host=is_github_host, - fallback_ref=fallback_ref, - ) - ) - elif e.response.status_code in (401, 403): - # Distinguish rate limiting from auth failure. - # X-RateLimit-* headers are GitHub-specific; treat as - # rate-limit only when the host is in the GitHub family. - is_rate_limit = False - if is_github_host: - try: - rl_remaining = e.response.headers.get("X-RateLimit-Remaining") - if rl_remaining is not None and int(rl_remaining) == 0: - is_rate_limit = True - except (TypeError, ValueError): - pass - - if is_rate_limit: - error_msg = f"GitHub API rate limit exceeded for {dep_ref.repo_url}. 
" - if not token: - error_msg += ( - "Unauthenticated requests are limited to " - "60/hour (shared per IP). " - + self._host.auth_resolver.build_error_context( - host, - "API request (rate limited)", - org=owner, - port=(dep_ref.port if dep_ref else None), - dep_url=(dep_ref.repo_url if dep_ref else None), - ) - ) - else: - error_msg += ( - "Authenticated rate limit exhausted. " - "Wait a few minutes or check your token's " - "rate-limit quota." - ) - raise RuntimeError(error_msg) from e - - # Retry without auth -- the repo might be public. - # GHES/GHE-DR don't support unauthenticated org-scoped retries. - if token and is_github_host and not host.lower().endswith(".ghe.com"): - try: - unauth_headers: dict[str, str] = {"Accept": "application/vnd.github.v3.raw"} - response = self._host._resilient_get( - api_url, headers=unauth_headers, timeout=30 - ) - response.raise_for_status() - if verbose_callback: - verbose_callback( - f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}" - ) - return self._extract_contents_api_payload(response, is_github_host) - except requests.exceptions.HTTPError: - pass # Fall through to the original error - - error_msg = ( - f"Authentication failed for {dep_ref.repo_url} " - f"(file: {file_path}, ref: {ref}). " - ) - if not token: - error_msg += self._host.auth_resolver.build_error_context( - host, - "download", - org=owner, - port=dep_ref.port if dep_ref else None, - dep_url=dep_ref.repo_url if dep_ref else None, - ) - elif is_github_host and not host.lower().endswith(".ghe.com"): - error_msg += ( - "Both authenticated and unauthenticated access " - "were attempted. The repository may be private, " - "or your token may lack SSO/SAML authorization " - "for this organization." - ) - elif is_github_host: - error_msg += "Please check your GitHub token permissions." - else: - # Generic host: don't claim SSO/SAML or "GitHub token". - error_msg += ( - f"Host {host} rejected the request. 
" - "Verify the repository exists and that the token has " - "access. Tokens are sourced from your git credential " - "helper, a per-org GITHUB_APM_PAT_ env var, or " - f"GITHUB_HOST={host} when this host is your GitHub " - "Enterprise Server." - ) - raise RuntimeError(error_msg) # noqa: B904 - else: - raise RuntimeError( - f"Failed to download {file_path}: HTTP {e.response.status_code}" - ) from e - except requests.exceptions.RequestException as e: - raise RuntimeError(f"Network error downloading {file_path}: {e}") # noqa: B904 + return _impl(self, dep_ref, file_path, ref, verbose_callback) # ------------------------------------------------------------------ # Helpers for download_github_file diff --git a/src/apm_cli/deps/download_strategies_backends_ops.py b/src/apm_cli/deps/download_strategies_backends_ops.py new file mode 100644 index 000000000..ce48deecf --- /dev/null +++ b/src/apm_cli/deps/download_strategies_backends_ops.py @@ -0,0 +1,416 @@ +"""Backend (Artifactory / Azure DevOps / GitLab) file-download ops for +:class:`~apm_cli.deps.download_strategies.DownloadDelegate`. + +Moved bodies (kept thin wrappers on the class). Each function takes the +owning ``DownloadDelegate`` as ``delegate``. Names that tests patch on +``apm_cli.deps.download_strategies`` are referenced through a function-level +``_ds`` alias so the patch still applies; this module never imports the +original module at module scope (avoids an import cycle). +""" + +import base64 +import io +import os +import zipfile +from pathlib import Path + +import requests + +from ..models.apm_package import DependencyReference + + +def download_artifactory_archive( + delegate, + host: str, + prefix: str, + owner: str, + repo: str, + ref: str, + target_path: Path, + scheme: str = "https", +) -> None: + """Download and extract a zip archive from Artifactory VCS proxy. + + Tries multiple URL patterns (GitHub-style and GitLab-style). 
+ GitHub archives contain a single root directory named {repo}-{ref}/; + this method strips that prefix on extraction so files land directly + in *target_path*. + + Raises RuntimeError on failure. + """ + from apm_cli.deps import download_strategies as _ds + + archive_urls = _ds.build_artifactory_archive_url(host, prefix, owner, repo, ref, scheme=scheme) + headers = delegate.get_artifactory_headers() + + # Guard: reject unreasonably large archives (default 500 MB) + max_archive_bytes = int(os.environ.get("ARTIFACTORY_MAX_ARCHIVE_MB", "500")) * 1024 * 1024 + + last_error = None + for url in archive_urls: + _ds._debug(f"Trying Artifactory archive: {url}") + try: + resp = delegate._host._resilient_get(url, headers=headers, timeout=60) + if resp.status_code == 200: + if len(resp.content) > max_archive_bytes: + last_error = f"Archive too large ({len(resp.content)} bytes) from {url}" + _ds._debug(last_error) + continue + _extract_stripped_archive(resp.content, target_path, url) + _ds._debug(f"Extracted Artifactory archive to {target_path}") + return + last_error = f"HTTP {resp.status_code} from {url}" + _ds._debug(last_error) + except zipfile.BadZipFile: + last_error = f"Invalid zip archive from {url}" + _ds._debug(last_error) + except requests.RequestException as e: + last_error = str(e) + _ds._debug(f"Request failed: {last_error}") + + raise RuntimeError( + f"Failed to download package {owner}/{repo}#{ref} from Artifactory " + f"({host}/{prefix}). 
Last error: {last_error}" + ) + + +def _extract_stripped_archive(content: bytes, target_path: Path, url: str) -> None: + """Extract a zip archive into *target_path*, stripping the root prefix.""" + from apm_cli.deps import download_strategies as _ds + + target_path.mkdir(parents=True, exist_ok=True) + with zipfile.ZipFile(io.BytesIO(content)) as zf: + names = zf.namelist() + if not names: + raise RuntimeError(f"Empty archive from {url}") + root_prefix = names[0] + if not root_prefix.endswith("/"): + # Single file archive; extract as-is + zf.extractall(target_path) + return + for member in zf.infolist(): + if member.filename == root_prefix: + continue + rel = member.filename[len(root_prefix) :] + if not rel: + continue + # Guard: prevent zip path traversal (CWE-22) + dest = target_path / rel + if not dest.resolve().is_relative_to(target_path.resolve()): + _ds._debug(f"Skipping zip entry escaping target: {member.filename}") + continue + unix_mode = (member.external_attr >> 16) & 0xFFFF + if member.is_dir(): + dest.mkdir(parents=True, exist_ok=True) + else: + dest.parent.mkdir(parents=True, exist_ok=True) + with zf.open(member) as src, open(dest, "wb") as dst: + dst.write(src.read()) + if unix_mode: + os.chmod(dest, unix_mode & 0o755) + + +def download_file_from_artifactory( + delegate, + host: str, + prefix: str, + owner: str, + repo: str, + file_path: str, + ref: str, + scheme: str = "https", +) -> bytes: + """Download a single file from Artifactory. + + Tries the Archive Entry Download API first (fetches one file + without downloading the full archive). Falls back to the full + archive approach when the entry API is unavailable or returns an + error. 
+ """ + from apm_cli.deps import download_strategies as _ds + + # Fast path: use the RegistryClient interface for entry download + cfg = delegate._host.registry_config + if cfg is not None and cfg.host == host: + client = cfg.get_client() + content = client.fetch_file( + owner, + repo, + file_path, + ref, + resilient_get=delegate._host._resilient_get, + ) + else: + # No RegistryConfig or host mismatch (explicit FQDN mode) -- + # fall back to the standalone helper. + from .artifactory_entry import fetch_entry_from_archive + + content = fetch_entry_from_archive( + host, + prefix, + owner, + repo, + file_path, + ref, + scheme=scheme, + headers=delegate.get_artifactory_headers(), + resilient_get=delegate._host._resilient_get, + ) + if content is not None: + return content + + # Fallback: download full archive and extract the file + archive_urls = _ds.build_artifactory_archive_url(host, prefix, owner, repo, ref, scheme=scheme) + headers = delegate.get_artifactory_headers() + + for url in archive_urls: + try: + resp = delegate._host._resilient_get(url, headers=headers, timeout=60) + if resp.status_code != 200: + continue + with zipfile.ZipFile(io.BytesIO(resp.content)) as zf: + names = zf.namelist() + root_prefix = names[0] if names else "" + target_name = root_prefix + file_path + if target_name in names: + return zf.read(target_name) + if file_path in names: + return zf.read(file_path) + except (zipfile.BadZipFile, requests.RequestException): + continue + + raise RuntimeError( + f"Failed to download file '{file_path}' from Artifactory " + f"({host}/{prefix}/{owner}/{repo}#{ref})" + ) + + +def download_ado_file( + delegate, + dep_ref: DependencyReference, + file_path: str, + ref: str = "main", +) -> bytes: + """Download a file from Azure DevOps repository. 
+ + Args: + dep_ref: Parsed dependency reference with ADO-specific fields + file_path: Path to file within the repository + ref: Git reference (branch, tag, or commit SHA) + + Returns: + bytes: File content + """ + from apm_cli.deps import download_strategies as _ds + + # Validate required ADO fields before proceeding + if not all([dep_ref.ado_organization, dep_ref.ado_project, dep_ref.ado_repo]): + raise ValueError( + "Invalid Azure DevOps dependency reference: missing " + "organization, project, or repo. " + f"Got: org={dep_ref.ado_organization}, " + f"project={dep_ref.ado_project}, repo={dep_ref.ado_repo}" + ) + + host = dep_ref.host or "dev.azure.com" + api_url = _ds.build_ado_api_url( + dep_ref.ado_organization, + dep_ref.ado_project, + dep_ref.ado_repo, + file_path, + ref, + host, + ) + + # Set up authentication headers - ADO uses Basic auth with PAT + headers: dict[str, str] = {} + if delegate._host.ado_token: + # ADO uses Basic auth: username can be empty, password is the PAT + auth = base64.b64encode(f":{delegate._host.ado_token}".encode()).decode() + headers["Authorization"] = f"Basic {auth}" + + try: + response = delegate._host._resilient_get(api_url, headers=headers, timeout=30) + response.raise_for_status() + return response.content + except requests.exceptions.HTTPError as e: + if e.response.status_code == 404: + return _ado_handle_404(delegate, e, dep_ref, file_path, ref, host, headers) + if e.response.status_code in (401, 403): + raise RuntimeError(_ado_auth_error_msg(delegate, dep_ref, host)) from e + raise RuntimeError(f"Failed to download {file_path}: HTTP {e.response.status_code}") from e + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Network error downloading {file_path}: {e}") from e + + +def _ado_handle_404( + delegate, + e: requests.exceptions.HTTPError, + dep_ref: DependencyReference, + file_path: str, + ref: str, + host: str, + headers: dict[str, str], +) -> bytes: + """Retry the other default branch when an ADO 
file 404s.""" + from apm_cli.deps import download_strategies as _ds + + if ref not in ["main", "master"]: + raise RuntimeError( + f"File not found: {file_path} at ref '{ref}' in {dep_ref.repo_url}" + ) from e + + fallback_ref = "master" if ref == "main" else "main" + fallback_url = _ds.build_ado_api_url( + dep_ref.ado_organization, + dep_ref.ado_project, + dep_ref.ado_repo, + file_path, + fallback_ref, + host, + ) + + try: + response = delegate._host._resilient_get(fallback_url, headers=headers, timeout=30) + response.raise_for_status() + return response.content + except requests.exceptions.HTTPError as fallback_err: + raise RuntimeError( + f"File not found: {file_path} in {dep_ref.repo_url} (tried refs: {ref}, {fallback_ref})" + ) from fallback_err + + +def _ado_auth_error_msg(delegate, dep_ref: DependencyReference, host: str) -> str: + """Build the auth-failure message for an ADO 401/403.""" + error_msg = f"Authentication failed for Azure DevOps {dep_ref.repo_url}. " + if not delegate._host.ado_token: + error_msg += delegate._host.auth_resolver.build_error_context( + host, + "download", + org=dep_ref.ado_organization if dep_ref else None, + port=dep_ref.port if dep_ref else None, + dep_url=dep_ref.repo_url if dep_ref else None, + ) + else: + error_msg += "Please check your Azure DevOps PAT permissions." 
+ return error_msg + + +def download_gitlab_file( + delegate, + dep_ref: DependencyReference, + file_path: str, + ref: str = "main", + verbose_callback=None, +) -> bytes: + """Download a file via GitLab REST v4 ``repository/files/.../raw``.""" + from apm_cli.deps import download_strategies as _ds + + host = dep_ref.host or _ds.default_host() + host_info = delegate._host.auth_resolver.classify_host(host) + project_path = dep_ref.repo_url + if not project_path: + raise RuntimeError("Missing repository path for GitLab file download") + + org = project_path.split("/")[0] + file_ctx = delegate._host.auth_resolver.resolve(host, org, port=dep_ref.port) + token = file_ctx.token + headers = _ds.AuthResolver.gitlab_rest_headers(token) + + api_base = host_info.api_base.rstrip("/") + enc_proj = _ds.quote(project_path, safe="") + enc_file = _ds.quote(file_path, safe="") + + def _raw_url(r: str) -> str: + return ( + f"{api_base}/projects/{enc_proj}/repository/files/{enc_file}/raw" + f"?ref={_ds.quote(r, safe='')}" + ) + + api_url = _raw_url(ref) + + try: + response = delegate._host._resilient_get(api_url, headers=headers, timeout=30) + response.raise_for_status() + if verbose_callback: + verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") + return response.content + except requests.exceptions.HTTPError as e: + return _gitlab_handle_http_error( + delegate, + e, + dep_ref, + file_path, + ref, + host, + org, + token, + headers, + _raw_url, + verbose_callback, + ) + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Network error downloading {file_path}: {e}") from e + + +def _gitlab_handle_http_error( + delegate, + e: requests.exceptions.HTTPError, + dep_ref: DependencyReference, + file_path: str, + ref: str, + host: str, + org: str, + token: str | None, + headers: dict[str, str], + raw_url_builder, + verbose_callback, +) -> bytes: + """Handle 404/auth/other errors for a GitLab raw file fetch.""" + if e.response is not None and 
e.response.status_code == 404: + if ref not in ("main", "master"): + raise RuntimeError( + f"File not found: {file_path} at ref '{ref}' in {dep_ref.repo_url}" + ) from e + fallback_ref = "master" if ref == "main" else "main" + fallback_url = raw_url_builder(fallback_ref) + try: + response = delegate._host._resilient_get(fallback_url, headers=headers, timeout=30) + response.raise_for_status() + if verbose_callback: + verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") + return response.content + except requests.exceptions.HTTPError as fallback_err: + raise RuntimeError( + f"File not found: {file_path} in {dep_ref.repo_url} " + f"(tried refs: {ref}, {fallback_ref})" + ) from fallback_err + if e.response is not None and e.response.status_code in (401, 403): + raise RuntimeError( + _gitlab_auth_error_msg(delegate, dep_ref, file_path, ref, host, org, token) + ) from e + if e.response is not None: + raise RuntimeError(f"Failed to download {file_path}: HTTP {e.response.status_code}") from e + raise e + + +def _gitlab_auth_error_msg( + delegate, + dep_ref: DependencyReference, + file_path: str, + ref: str, + host: str, + org: str, + token: str | None, +) -> str: + """Build the auth-failure message for a GitLab 401/403.""" + error_msg = ( + f"Authentication failed for GitLab {dep_ref.repo_url} (file: {file_path}, ref: {ref}). " + ) + if not token: + error_msg += delegate._host.auth_resolver.build_error_context( + host, "download", org=org, port=dep_ref.port + ) + else: + error_msg += "Please verify your token can read this project (required API scope)." + return error_msg diff --git a/src/apm_cli/deps/download_strategies_ops.py b/src/apm_cli/deps/download_strategies_ops.py new file mode 100644 index 000000000..8afdc115e --- /dev/null +++ b/src/apm_cli/deps/download_strategies_ops.py @@ -0,0 +1,379 @@ +"""GitHub file-download ops for +:class:`~apm_cli.deps.download_strategies.DownloadDelegate`. 
+ +Moved body of ``download_github_file`` plus its cohesive ``_gh_*`` helpers +(CDN fast-path, generic-host raw attempt, Contents-API request, 404 / auth +handling, message builders). Names that tests patch on +``apm_cli.deps.download_strategies`` are referenced through a function-level +``_ds`` alias so the patch still applies. +""" + +import requests + +from ..models.apm_package import DependencyReference + + +def download_github_file( + delegate, + dep_ref: DependencyReference, + file_path: str, + ref: str = "main", + verbose_callback=None, +) -> bytes: + """Download a file from GitHub repository. + + For github.com without a token, tries raw.githubusercontent.com first + (CDN, no rate limit) before falling back to the Contents API. + Authenticated requests and non-github.com hosts always use the + Contents API directly. + + Args: + dep_ref: Parsed dependency reference + file_path: Path to file within the repository + ref: Git reference (branch, tag, or commit SHA) + verbose_callback: Optional callable for verbose logging + + Returns: + bytes: File content + """ + from apm_cli.deps import download_strategies as _ds + + host = dep_ref.host or _ds.default_host() + + # Parse owner/repo from repo_url + owner, repo = dep_ref.repo_url.split("/", 1) + + # Resolve token via AuthResolver for CDN fast-path decision + org = None + if dep_ref and dep_ref.repo_url: + parts = dep_ref.repo_url.split("/") + if parts: + org = parts[0] + file_ctx = delegate._host.auth_resolver.resolve(host, org, port=dep_ref.port) + token = file_ctx.token + + # --- CDN fast-path for github.com without a token --- + if host.lower() == "github.com" and not token: + content = _gh_cdn_fastpath( + delegate, host, owner, repo, ref, file_path, dep_ref, verbose_callback + ) + if content is not None: + return content + # All raw attempts failed -- fall through to API path which handles + # private repos, rate-limit messaging, and SAML errors. 
+ + # --- Generic host: raw URL first, then API version negotiation --- + is_github_host = _ds.is_github_hostname(host) or delegate._is_configured_ghes(host) + if not is_github_host: + content = _gh_generic_raw_attempt( + delegate, host, owner, repo, ref, file_path, dep_ref, file_ctx, verbose_callback + ) + if content is not None: + return content + + # --- Contents API path (authenticated, enterprise, or raw fallback) --- + return _gh_contents_api( + delegate, + host, + owner, + repo, + file_path, + ref, + dep_ref, + token, + file_ctx, + is_github_host, + verbose_callback, + ) + + +def _gh_cdn_fastpath( + delegate, host, owner, repo, ref, file_path, dep_ref, verbose_callback +) -> bytes | None: + """Try raw.githubusercontent.com (CDN) for github.com, both default branches.""" + content = delegate.try_raw_download(owner, repo, ref, file_path) + if content is not None: + if verbose_callback: + verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") + return content + # raw download returned 404 -- could be wrong default branch. 
+ if ref in ("main", "master"): + fallback_ref = "master" if ref == "main" else "main" + content = delegate.try_raw_download(owner, repo, fallback_ref, file_path) + if content is not None: + if verbose_callback: + verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") + return content + return None + + +def _gh_generic_raw_attempt( + delegate, host, owner, repo, ref, file_path, dep_ref, file_ctx, verbose_callback +) -> bytes | None: + """Try the raw URL path on a generic (non-GitHub) host before the Contents API.""" + raw_url = f"https://{host}/{owner}/{repo}/raw/{ref}/{file_path}" + raw_headers = delegate._build_generic_host_auth_headers(host, file_ctx, accept=None) + if verbose_callback: + verbose_callback(f"Trying raw URL on generic host {host}: {raw_url}") + try: + response = delegate._host._resilient_get(raw_url, headers=raw_headers, timeout=30) + if response.status_code == 200: + if verbose_callback: + verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") + return response.content + except (requests.RequestException, OSError) as raw_err: + if verbose_callback: + verbose_callback( + f"Raw URL on {host} failed for {file_path}@{ref}: " + f"{type(raw_err).__name__}; falling back to Contents API." + ) + return None + + +def _gh_contents_api( + delegate, + host, + owner, + repo, + file_path, + ref, + dep_ref, + token, + file_ctx, + is_github_host, + verbose_callback, +) -> bytes: + """Fetch a file via the GitHub/GHES Contents API, with 404/auth handling.""" + api_url_candidates = delegate._build_contents_api_urls( + host, owner, repo, file_path, ref, is_github_host=is_github_host + ) + api_url = api_url_candidates[0] + + # GitHub family: use GitHub raw-media accept header. Generic hosts + # ignore it and may return JSON envelopes -- handle that on read. 
+ accept = "application/vnd.github.v3.raw" if is_github_host else "application/json" + if is_github_host: + headers: dict[str, str] = {"Accept": accept} + if token: + headers["Authorization"] = f"token {token}" + else: + headers = delegate._build_generic_host_auth_headers(host, file_ctx, accept=accept) + + try: + if verbose_callback and not is_github_host: + verbose_callback(f"Trying Contents API on {host}: {api_url}") + response = delegate._host._resilient_get(api_url, headers=headers, timeout=30) + response.raise_for_status() + if verbose_callback: + verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") + return delegate._extract_contents_api_payload(response, is_github_host) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 404: + return _gh_handle_404( + delegate, + host, + owner, + repo, + file_path, + ref, + dep_ref, + headers, + api_url_candidates, + is_github_host, + verbose_callback, + ) + if e.response.status_code in (401, 403): + return _gh_handle_auth_error( + delegate, + e, + host, + owner, + file_path, + ref, + dep_ref, + api_url, + token, + is_github_host, + verbose_callback, + ) + raise RuntimeError(f"Failed to download {file_path}: HTTP {e.response.status_code}") from e + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Network error downloading {file_path}: {e}") # noqa: B904 + + +def _gh_handle_404( + delegate, + host, + owner, + repo, + file_path, + ref, + dep_ref, + headers, + api_url_candidates, + is_github_host, + verbose_callback, +) -> bytes: + """Handle a Contents-API 404: try remaining candidates, then ref fallback.""" + # For generic hosts, try remaining API version candidates before ref fallback + for candidate_url in api_url_candidates[1:]: + try: + if verbose_callback: + verbose_callback(f"Contents API 404; trying next candidate: {candidate_url}") + candidate_resp = delegate._host._resilient_get( + candidate_url, headers=headers, timeout=30 + ) + 
candidate_resp.raise_for_status() + if verbose_callback: + verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") + return delegate._extract_contents_api_payload(candidate_resp, is_github_host) + except requests.exceptions.HTTPError as ce: + if ce.response.status_code != 404: + raise RuntimeError( # noqa: B904 + f"Failed to download {file_path}: HTTP {ce.response.status_code}" + ) + + # Try fallback branches if the specified ref fails + if ref not in ["main", "master"]: + raise RuntimeError( + delegate._build_unsupported_or_missing_error( + host, + dep_ref.repo_url, + file_path, + ref, + api_url_candidates, + is_github_host=is_github_host, + ) + ) + + # Try the other default branch + fallback_ref = "master" if ref == "main" else "main" + fallback_url_candidates = delegate._build_contents_api_urls( + host, owner, repo, file_path, fallback_ref + ) + + for fallback_url in fallback_url_candidates: + try: + response = delegate._host._resilient_get(fallback_url, headers=headers, timeout=30) + response.raise_for_status() + if verbose_callback: + verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") + return delegate._extract_contents_api_payload(response, is_github_host) + except requests.exceptions.HTTPError as fe: + if fe.response.status_code != 404: + raise RuntimeError( # noqa: B904 + f"Failed to download {file_path}: HTTP {fe.response.status_code}" + ) + + raise RuntimeError( + delegate._build_unsupported_or_missing_error( + host, + dep_ref.repo_url, + file_path, + ref, + api_url_candidates, + is_github_host=is_github_host, + fallback_ref=fallback_ref, + ) + ) + + +def _gh_handle_auth_error( + delegate, + e, + host, + owner, + file_path, + ref, + dep_ref, + api_url, + token, + is_github_host, + verbose_callback, +) -> bytes: + """Handle a Contents-API 401/403: rate-limit vs auth, with unauth retry.""" + # Distinguish rate limiting from auth failure. 
X-RateLimit-* headers are + # GitHub-specific; treat as rate-limit only when host is GitHub family. + is_rate_limit = False + if is_github_host: + try: + rl_remaining = e.response.headers.get("X-RateLimit-Remaining") + if rl_remaining is not None and int(rl_remaining) == 0: + is_rate_limit = True + except (TypeError, ValueError): + pass + + if is_rate_limit: + raise RuntimeError(_gh_rate_limit_msg(delegate, host, owner, dep_ref, token)) from e + + # Retry without auth -- the repo might be public. GHES/GHE-DR don't + # support unauthenticated org-scoped retries. + if token and is_github_host and not host.lower().endswith(".ghe.com"): + try: + unauth_headers: dict[str, str] = {"Accept": "application/vnd.github.v3.raw"} + response = delegate._host._resilient_get(api_url, headers=unauth_headers, timeout=30) + response.raise_for_status() + if verbose_callback: + verbose_callback(f"Downloaded file: {host}/{dep_ref.repo_url}/{file_path}") + return delegate._extract_contents_api_payload(response, is_github_host) + except requests.exceptions.HTTPError: + pass # Fall through to the original error + + raise RuntimeError( + _gh_auth_failed_msg(delegate, host, owner, file_path, ref, dep_ref, token, is_github_host) + ) + + +def _gh_rate_limit_msg(delegate, host, owner, dep_ref, token) -> str: + """Build the rate-limit error message for a GitHub 403.""" + error_msg = f"GitHub API rate limit exceeded for {dep_ref.repo_url}. " + if not token: + error_msg += ( + "Unauthenticated requests are limited to 60/hour (shared per IP). " + + delegate._host.auth_resolver.build_error_context( + host, + "API request (rate limited)", + org=owner, + port=(dep_ref.port if dep_ref else None), + dep_url=(dep_ref.repo_url if dep_ref else None), + ) + ) + else: + error_msg += ( + "Authenticated rate limit exhausted. " + "Wait a few minutes or check your token's rate-limit quota." 
+ ) + return error_msg + + +def _gh_auth_failed_msg( + delegate, host, owner, file_path, ref, dep_ref, token, is_github_host +) -> str: + """Build the auth-failure error message for a GitHub 401/403.""" + error_msg = f"Authentication failed for {dep_ref.repo_url} (file: {file_path}, ref: {ref}). " + if not token: + error_msg += delegate._host.auth_resolver.build_error_context( + host, + "download", + org=owner, + port=dep_ref.port if dep_ref else None, + dep_url=dep_ref.repo_url if dep_ref else None, + ) + elif is_github_host and not host.lower().endswith(".ghe.com"): + error_msg += ( + "Both authenticated and unauthenticated access were attempted. " + "The repository may be private, or your token may lack SSO/SAML " + "authorization for this organization." + ) + elif is_github_host: + error_msg += "Please check your GitHub token permissions." + else: + # Generic host: don't claim SSO/SAML or "GitHub token". + error_msg += ( + f"Host {host} rejected the request. " + "Verify the repository exists and that the token has access. " + "Tokens are sourced from your git credential helper, a per-org " + f"GITHUB_APM_PAT_ env var, or GITHUB_HOST={host} when this " + "host is your GitHub Enterprise Server." + ) + return error_msg diff --git a/src/apm_cli/deps/github_downloader.py b/src/apm_cli/deps/github_downloader.py index e58ee3bab..15ab8fd4a 100644 --- a/src/apm_cli/deps/github_downloader.py +++ b/src/apm_cli/deps/github_downloader.py @@ -2,49 +2,49 @@ import contextlib import os -import re -import subprocess + +# subprocess / tempfile are re-exported (tests patch them on this module) even +# though their only direct users now live in github_downloader_ops, which +# routes back through ``_gh.`` so the patches still apply. 
+import subprocess as subprocess import sys -import tempfile -import threading -import time +import tempfile as tempfile +import time as time from collections.abc import Callable -from datetime import datetime from pathlib import Path from typing import Any, Union import git # noqa: F401 -- re-exported; tests patch apm_cli.deps.github_downloader.git import requests from git import RemoteProgress, Repo -from git.exc import GitCommandError -from ..core.auth import AuthContext, AuthResolver +from ..core.auth import AuthContext +from ..core.auth import AuthResolver as AuthResolver +from ..models.apm_package import APMPackage as APMPackage from ..models.apm_package import ( - APMPackage, DependencyReference, - GitReferenceType, PackageInfo, - PackageType, RemoteRef, ResolvedReference, - validate_apm_package, +) +from ..models.apm_package import ( + validate_apm_package as validate_apm_package, ) from ..utils.console import ( - _rich_warning, # noqa: F401 -- re-exported; tests patch github_downloader._rich_warning + _rich_warning as _rich_warning, ) from ..utils.github_host import ( default_host, is_github_hostname, - sanitize_token_url_in_message, ) -from ..utils.yaml_io import yaml_to_str +from ..utils.yaml_io import yaml_to_str as yaml_to_str from .bare_cache import ( bare_clone_with_fallback, clone_with_fallback, fetch_sha_into_bare, materialize_from_bare, ) -from .download_strategies import DownloadDelegate +from .download_strategies import DownloadDelegate as DownloadDelegate from .git_remote_ops import ( parse_ls_remote_output, semver_sort_key, @@ -188,90 +188,10 @@ def __init__( protocol_pref: ProtocolPreference | None = None, allow_fallback: bool | None = None, ): - """Initialize the GitHub package downloader. - - Args: - auth_resolver: Auth resolver instance. Defaults to a new AuthResolver. - transport_selector: TransportSelector for protocol decisions. - Defaults to a new selector with GitConfigInsteadOfResolver. 
- protocol_pref: User-stated transport preference for shorthand - deps. When None, resolved from ``APM_GIT_PROTOCOL`` env var, - then ``prefer-ssh`` in ``~/.apm/config.json``, then ``None`` - (let git insteadOf rules decide). - allow_fallback: When True, permits cross-protocol fallback - (legacy behavior). When None, resolved from - ``APM_ALLOW_PROTOCOL_FALLBACK`` env var, then - ``allow-protocol-fallback`` in ``~/.apm/config.json``, - then ``False``. - """ - self.auth_resolver = auth_resolver or AuthResolver() - self.token_manager = self.auth_resolver._token_manager # Backward compat - self.git_env = self._setup_git_environment() - self._transport_selector = transport_selector or TransportSelector() - if protocol_pref is not None: - self._protocol_pref = protocol_pref - else: - # Use the config-aware helper (env > apm config > None) so that - # ``apm config set ssh true`` is honoured even when the downloader - # is constructed without explicit args (e.g. in validation.py). - from ..config import get_apm_protocol_pref as _get_pref - from .transport_selection import ProtocolPreference - - _pref_str = _get_pref() - self._protocol_pref = ProtocolPreference.from_str(_pref_str) - if allow_fallback is not None: - self._allow_fallback = allow_fallback - else: - # Config-aware helper (env > apm config > False). - from ..config import get_apm_allow_protocol_fallback as _get_fallback - - self._allow_fallback = _get_fallback() - # Dedup set for the issue #786 cross-protocol port warning: one install - # run calls _clone_with_fallback multiple times per dep (ref-resolution - # clone, then the actual dep clone). We want the warning exactly once - # per (host, repo, port) identity across all those calls. - self._fallback_port_warned: set = set() - self._fallback_port_warned_lock = threading.Lock() - - # Delegate backend-specific download logic to the download delegate. 
- self._strategies = DownloadDelegate(host=self) - - # Artifactory orchestration is encapsulated in a dedicated facade - # (download_package / download_subdirectory) backed by the - # DownloadDelegate's HTTP archive downloader. - from .artifactory_orchestrator import ArtifactoryOrchestrator - from .clone_engine import CloneEngine - from .git_reference_resolver import GitReferenceResolver - - self._artifactory = ArtifactoryOrchestrator(archive_downloader=self._strategies) - self._refs = GitReferenceResolver(host=self) - self._clone_engine = CloneEngine(host=self) - - # WS2a (#1116): per-run shared clone cache for subdirectory dep - # deduplication. Set by the install pipeline before resolution - # starts; None means no dedup (each subdir dep clones independently). - self.shared_clone_cache = None - - # WS3 (#1116): persistent cross-run git cache. When set, the - # download flow checks the on-disk cache before any network clone. - # Set by the install pipeline; None disables persistent caching. - self.persistent_git_cache = None - - # #1369: tiered ref resolver. Attached by resolve.py / outdated.py - # after construction via ``build_tiered_ref_resolver``. When set, - # :meth:`resolve_git_reference` delegates to it before falling - # through to ``self._refs.resolve``. Declared here so the - # attribute is part of the documented downloader surface rather - # than a monkey-patched field. - self._tiered_resolver = None - - # Perf #1433: optional InstallLogger attached by the install - # pipeline. When set, the subdir download path emits structured - # verbose-only [perf] lines (subdir_download_start / - # bare_clone_strategy / materialize_result). None means the - # downloader is being driven outside the install pipeline (e.g. - # tests, marketplace) -- the [perf] channel stays silent. 
- self.install_logger = None + """Initialize the GitHub package downloader (wiring delegated).""" + from .github_downloader_setup_ops import init_downloader as _impl + + _impl(self, auth_resolver, transport_selector, protocol_pref, allow_fallback) def _git_env_dict(self) -> dict[str, str]: """Return a sanitized git env dict for cache-layer subprocess calls. @@ -283,49 +203,10 @@ def _git_env_dict(self) -> dict[str, str]: return GitAuthEnvBuilder.subprocess_env_dict(self.git_env) def _setup_git_environment(self) -> dict[str, Any]: - """Set up Git environment with authentication using centralized token manager. + """Set up Git environment with authentication (delegated).""" + from .github_downloader_setup_ops import setup_git_environment as _impl - Builds the auth-bearing env via :class:`GitAuthEnvBuilder`, then - records token-state attributes on the downloader (these are read - by many other methods on the class). - """ - from .git_auth_env import GitAuthEnvBuilder - - builder = GitAuthEnvBuilder(self.token_manager) - env = builder.setup_environment() - - # IMPORTANT: Do not resolve credentials via helpers at construction time. - # AuthResolver.resolve(...) can trigger OS credential helper UI. If we do - # this eagerly (host-only key) and later resolve per-dependency (host+org), - # users can see duplicate auth prompts. Keep constructor token state env-only - # and resolve lazily per dependency during clone/validate flows. 
- self.github_token = self.token_manager.get_token_for_purpose("modules", env) - self.has_github_token = self.github_token is not None - self._github_token_from_credential_fill = False - - # GitLab (env-only at init; lazy auth resolution happens per dep) - self.gitlab_token = self.token_manager.get_token_for_purpose("gitlab_modules", env) - self.has_gitlab_token = self.gitlab_token is not None - - # Azure DevOps (env-only at init; lazy auth resolution happens per dep) - self.ado_token = self.token_manager.get_token_for_purpose("ado_modules", env) - self.has_ado_token = self.ado_token is not None - - # JFrog Artifactory (not host-based, uses dedicated env var) - self.artifactory_token = self.token_manager.get_token_for_purpose( - "artifactory_modules", env - ) - self.has_artifactory_token = self.artifactory_token is not None - - _debug( - f"Token setup: has_github_token={self.has_github_token}, " - f"has_gitlab_token={self.has_gitlab_token}, " - f"has_ado_token={self.has_ado_token}, " - f"has_artifactory_token={self.has_artifactory_token}" - f"{', source=credential_helper' if self._github_token_from_credential_fill else ''}" - ) - - return env + return _impl(self) # --- Registry proxy support --- @@ -418,50 +299,18 @@ def _parse_artifactory_base_url(self) -> tuple | None: return ArtifactoryRouter.parse_proxy_config() def _resolve_dep_token(self, dep_ref: DependencyReference | None = None) -> str | None: - """Resolve the per-dependency auth token via AuthResolver. - - GitHub, GitLab, and ADO hosts use the token resolved by AuthResolver. - Other generic hosts return None so git credential helpers can provide - credentials instead. + """Resolve the per-dependency auth token via AuthResolver (delegated).""" + from .github_downloader_setup_ops import resolve_dep_token as _impl - Args: - dep_ref: Optional dependency reference for host/org lookup. - - Returns: - Token string or None. 
- """ - if dep_ref is None: - return self.github_token - - if self._is_generic_dependency_host(dep_ref): - return None - - dep_ctx = self.auth_resolver.resolve_for_dep(dep_ref) - return dep_ctx.token + return _impl(self, dep_ref) def _resolve_dep_auth_ctx( self, dep_ref: DependencyReference | None = None ) -> AuthContext | None: - """Resolve the full AuthContext for a dependency. - - Returns the AuthContext from AuthResolver, or None for generic hosts - or when no dep_ref is provided. - """ - if dep_ref is None: - return None + """Resolve the full AuthContext for a dependency (delegated).""" + from .github_downloader_setup_ops import resolve_dep_auth_ctx as _impl - dep_host = dep_ref.host - if self._is_generic_dependency_host(dep_ref): - return None - - ctx = self.auth_resolver.resolve_for_dep(dep_ref) - # Verbose source surfacing (#852): one-time per-host log line so users - # can see which credential source was actually used. Routed through - # AuthResolver.notify_auth_source() (#856 follow-up F2) so the line - # obeys the same verbose-channel logic as every other diagnostic. - if os.environ.get("APM_VERBOSE") == "1": - self.auth_resolver.notify_auth_source(dep_host or "", ctx) - return ctx + return _impl(self, dep_ref) def _build_noninteractive_git_env( self, @@ -490,40 +339,10 @@ def _resilient_get( ) def _sanitize_git_error(self, error_message: str) -> str: - """Sanitize Git error messages to remove potentially sensitive authentication information. 
- - Args: - error_message: Raw error message from Git operations - - Returns: - str: Sanitized error message with sensitive data removed - """ - import re - - # Remove any tokens that might appear in URLs for github hosts (format: https://token@host) - # Sanitize for default host and common enterprise hosts via helper - sanitized = sanitize_token_url_in_message(error_message, host=default_host()) - - # Sanitize Azure DevOps URLs - both cloud (dev.azure.com) and any on-prem server - # Use a generic pattern to catch https://token@anyhost format for all hosts - # This catches: dev.azure.com, ado.company.com, tfs.internal.corp, etc. - sanitized = re.sub(r"https://[^@\s]+@([^\s/]+)", r"https://***@\1", sanitized) - - # Remove any tokens that might appear as standalone values - sanitized = re.sub( - r"(ghp_|gho_|ghu_|ghs_|ghr_|glpat[_-])[a-zA-Z0-9_\-]+", - "***", - sanitized, - ) - - # Remove environment variable values that might contain tokens - sanitized = re.sub( - r"(GITHUB_TOKEN|GITHUB_APM_PAT|ADO_APM_PAT|GH_TOKEN|GITHUB_COPILOT_PAT|GITLAB_APM_PAT|GITLAB_TOKEN)=[^\s]+", - r"\1=***", - sanitized, - ) + """Sanitize Git error messages to remove sensitive auth info (delegated).""" + from .github_downloader_setup_ops import sanitize_git_error as _impl - return sanitized + return _impl(self, error_message) def _build_repo_url( self, @@ -710,56 +529,10 @@ def _resolve_commit_sha_for_ref(self, dep_ref: DependencyReference, ref: str) -> def download_raw_file( self, dep_ref: DependencyReference, file_path: str, ref: str = "main", verbose_callback=None ) -> bytes: - """Download a single file from repository (GitHub or Azure DevOps). + """Download a single file from a repository (delegated).""" + from .github_downloader_setup_ops import download_raw_file as _impl - Args: - dep_ref: Parsed dependency reference - file_path: Path to file within the repository (e.g., "prompts/code-review.prompt.md") - ref: Git reference (branch, tag, or commit SHA). 
Defaults to "main" - verbose_callback: Optional callable for verbose logging (receives str messages) - - Returns: - bytes: File content - - Raises: - RuntimeError: If download fails or file not found - """ - _ = dep_ref.host or default_host() - - # Check if this is Artifactory (Mode 1: explicit FQDN) - if dep_ref.is_artifactory(): - repo_parts = dep_ref.repo_url.split("/") - return self._download_file_from_artifactory( - dep_ref.host, - dep_ref.artifactory_prefix, - repo_parts[0], - repo_parts[1] if len(repo_parts) > 1 else repo_parts[0], - file_path, - ref, - ) - - # Check if this should go through Artifactory proxy (Mode 2) - art_proxy = self._parse_artifactory_base_url() - if art_proxy and self._should_use_artifactory_proxy(dep_ref): - repo_parts = dep_ref.repo_url.split("/") - return self._download_file_from_artifactory( - art_proxy[0], - art_proxy[1], - repo_parts[0], - repo_parts[1] if len(repo_parts) > 1 else repo_parts[0], - file_path, - ref, - scheme=art_proxy[2], - ) - - # Check if this is Azure DevOps - if dep_ref.is_azure_devops(): - return self._download_ado_file(dep_ref, file_path, ref) - - # GitHub API - return self._download_github_file( - dep_ref, file_path, ref, verbose_callback=verbose_callback - ) + return _impl(self, dep_ref, file_path, ref, verbose_callback) def _download_ado_file( self, dep_ref: DependencyReference, file_path: str, ref: str = "main" @@ -866,154 +639,10 @@ def download_virtual_file_package( progress_task_id=None, progress_obj=None, ) -> PackageInfo: - """Download a single file as a virtual APM package. - - Creates a minimal APM package structure with the file placed in the appropriate - .apm/ subdirectory based on its extension. 
- - Args: - dep_ref: Dependency reference with virtual_path set - target_path: Local path where virtual package should be created - progress_task_id: Rich Progress task ID for progress updates - progress_obj: Rich Progress object for progress updates + """Download a single file as a virtual APM package (delegated).""" + from .github_downloader_package_ops import download_virtual_file_package as _impl - Returns: - PackageInfo: Information about the created virtual package - - Raises: - ValueError: If the dependency is not a valid virtual file package - RuntimeError: If download fails - """ - if not dep_ref.is_virtual or not dep_ref.virtual_path: - raise ValueError("Dependency must be a virtual file package") - - if not dep_ref.is_virtual_file(): - raise ValueError( - f"Path '{dep_ref.virtual_path}' is not a valid individual file. " - f"Must end with one of: {', '.join(DependencyReference.VIRTUAL_FILE_EXTENSIONS)}" - ) - - # Determine the ref to use - ref = dep_ref.reference or "main" - - # Resolve the commit SHA cheaply BEFORE the file download. This is one - # short HTTP call (Accept: application/vnd.github.sha returns just the - # 40-char SHA in the body) and the result is propagated into PackageInfo - # so the lockfile and per-dep header can render the SHA suffix instead - # of just the ref name. On non-GitHub hosts or any failure this returns - # None and we fall back to ref-name only -- the install never fails on - # SHA resolution. 
- resolved_commit = self._resolve_commit_sha_for_ref(dep_ref, ref) - - # Update progress - downloading - if progress_obj and progress_task_id is not None: - progress_obj.update(progress_task_id, completed=50, total=100) - - # Download the file content - try: - file_content = self.download_raw_file(dep_ref, dep_ref.virtual_path, ref) - except RuntimeError as e: - raise RuntimeError(f"Failed to download virtual package: {e}") from e - - # Update progress - processing - if progress_obj and progress_task_id is not None: - progress_obj.update(progress_task_id, completed=90, total=100) - - # Create target directory structure - target_path.mkdir(parents=True, exist_ok=True) - - # Determine the subdirectory based on file extension - subdirs = { - ".prompt.md": "prompts", - ".instructions.md": "instructions", - ".chatmode.md": "chatmodes", - ".agent.md": "agents", - } - - subdir = None - filename = dep_ref.virtual_path.split("/")[-1] - for ext, dir_name in subdirs.items(): - if dep_ref.virtual_path.endswith(ext): - subdir = dir_name - break - - if not subdir: - raise ValueError(f"Unknown file extension for {dep_ref.virtual_path}") - - # Create .apm structure - apm_dir = target_path / ".apm" / subdir - apm_dir.mkdir(parents=True, exist_ok=True) - - # Write the file - file_path = apm_dir / filename - file_path.write_bytes(file_content) - - # Generate minimal apm.yml - package_name = dep_ref.get_virtual_package_name() - - # Try to extract description from file frontmatter - description = f"Virtual package containing {filename}" - try: - content_str = file_content.decode("utf-8") - # Simple frontmatter parsing (YAML between --- markers) - if content_str.startswith("---\n"): - end_idx = content_str.find("\n---\n", 4) - if end_idx > 0: - frontmatter = content_str[4:end_idx] - # Look for description field - for line in frontmatter.split("\n"): - if line.startswith("description:"): - description = line.split(":", 1)[1].strip().strip("\"'") - break - except Exception: - # If 
frontmatter parsing fails, use default description - pass - - apm_yml_data = { - "name": package_name, - "version": "1.0.0", - "description": description, - "author": dep_ref.repo_url.split("/")[0], - } - apm_yml_content = yaml_to_str(apm_yml_data) - - apm_yml_path = target_path / "apm.yml" - apm_yml_path.write_text(apm_yml_content, encoding="utf-8") - - # Create APMPackage object - package = APMPackage( - name=package_name, - version="1.0.0", - description=description, - author=dep_ref.repo_url.split("/")[0], - source=dep_ref.to_github_url(), - package_path=target_path, - ) - - # Build the resolved reference. On non-GitHub hosts or SHA-resolve - # failure the resolved_commit stays None and the suffix renders as - # "#ref" only -- matching the existing subdirectory behavior in - # _try_sparse_checkout / _download_subdirectory. - ref_type = ( - GitReferenceType.COMMIT - if re.match(r"^[a-f0-9]{40}$", ref.lower()) - else GitReferenceType.BRANCH - ) - resolved_ref = ResolvedReference( - original_ref=str(dep_ref.reference) if dep_ref.reference else ref, - ref_name=ref, - ref_type=ref_type, - resolved_commit=resolved_commit, - ) - - # Return PackageInfo - return PackageInfo( - package=package, - install_path=target_path, - installed_at=datetime.now().isoformat(), - dependency_ref=dep_ref, # Store for canonical dependency string - resolved_reference=resolved_ref, - ) + return _impl(self, dep_ref, target_path, progress_task_id, progress_obj) def _try_sparse_checkout( self, @@ -1022,64 +651,10 @@ def _try_sparse_checkout( subdir_path: str, ref: str | None = None, ) -> bool: - """Attempt sparse-checkout to download only a subdirectory (git 2.25+). - - Returns True on success. Falls back silently on failure. - """ + """Attempt sparse-checkout to download only a subdirectory (delegated).""" + from .github_downloader_package_ops import try_sparse_checkout as _impl - try: - temp_clone_path.mkdir(parents=True, exist_ok=True) - - # Resolve per-dependency token via AuthResolver. 
- dep_token = self._resolve_dep_token(dep_ref) - dep_auth_ctx = self._resolve_dep_auth_ctx(dep_ref) - dep_auth_scheme = dep_auth_ctx.auth_scheme if dep_auth_ctx else "basic" - - # For ADO bearer, use the AuthContext git_env with header injection - if dep_auth_scheme == "bearer" and dep_auth_ctx is not None: - env = {**os.environ, **(dep_auth_ctx.git_env or {})} - else: - env = {**os.environ, **(self.git_env or {})} - auth_url = self._build_repo_url( - dep_ref.repo_url, - use_ssh=False, - dep_ref=dep_ref, - token=dep_token, - auth_scheme=dep_auth_scheme, - ) - - cmds = [ - ["git", "init"], - ["git", "remote", "add", "origin", auth_url], - ["git", "sparse-checkout", "init", "--cone"], - ["git", "sparse-checkout", "set", subdir_path], - ] - fetch_cmd = ["git", "fetch", "origin"] - fetch_cmd.append(ref or "HEAD") - fetch_cmd.append("--depth=1") - cmds.append(fetch_cmd) - cmds.append(["git", "checkout", "FETCH_HEAD"]) - - for cmd in cmds: - result = subprocess.run( - cmd, - cwd=str(temp_clone_path), - env=env, - capture_output=True, - text=True, - encoding="utf-8", - timeout=120, - ) - if result.returncode != 0: - _debug( - f"Sparse-checkout step failed ({' '.join(cmd)}): {result.stderr.strip()}" - ) - return False - - return True - except Exception as e: - _debug(f"Sparse-checkout failed: {e}") - return False + return _impl(self, dep_ref, temp_clone_path, subdir_path, ref) def download_subdirectory_package( self, @@ -1088,401 +663,10 @@ def download_subdirectory_package( progress_task_id=None, progress_obj=None, ) -> PackageInfo: - """Download a subdirectory from a repo as an APM package. + """Download a subdirectory from a repo as an APM package (delegated).""" + from .github_downloader_subdir_ops import download_subdirectory_package as _impl - Used for Claude Skills or APM packages nested in monorepos. - Clones the repo, extracts the subdirectory, and cleans up. 
- - Args: - dep_ref: Dependency reference with virtual_path set to subdirectory - target_path: Local path where package should be created - progress_task_id: Rich Progress task ID for progress updates - progress_obj: Rich Progress object for progress updates - - Returns: - PackageInfo: Information about the downloaded package - - Raises: - ValueError: If the dependency is not a valid subdirectory package - RuntimeError: If download or validation fails - """ - if not dep_ref.is_virtual or not dep_ref.virtual_path: - raise ValueError("Dependency must be a virtual subdirectory package") - - if not dep_ref.is_virtual_subdirectory(): - raise ValueError(f"Path '{dep_ref.virtual_path}' is not a valid subdirectory package") - - # Use user-specified ref, or None to use repo's default branch - ref = dep_ref.reference # None if not specified - subdir_path = dep_ref.virtual_path - _perf_logger = getattr(self, "install_logger", None) - _dep_display = str(dep_ref) - - # Update progress - starting - if progress_obj and progress_task_id is not None: - progress_obj.update(progress_task_id, completed=10, total=100) - - # WS2a (#1116): attempt shared clone dedup when a per-run cache - # is available. Two subdir deps from the same (host, owner, repo, ref) - # share one clone; different refs always get independent clones. - shared_cache = self.shared_clone_cache - use_shared = shared_cache is not None - # Determine cache key components from the dep_ref. - cache_host = dep_ref.host or default_host() - cache_owner = dep_ref.repo_url.split("/")[0] if "/" in dep_ref.repo_url else "" - cache_repo = dep_ref.repo_url.split("/")[1] if "/" in dep_ref.repo_url else dep_ref.repo_url - - # WS3 (#1116): try persistent cross-run cache first. - # Build a canonical URL for cache key derivation. 
- _persistent_cache = self.persistent_git_cache - _persistent_checkout: Path | None = None - _resolved_sha_for_cache: str | None = None - if _persistent_cache is not None: - _canonical_url = f"https://{cache_host}/{cache_owner}/{cache_repo}" - try: - # Tiered ref resolution (perf #1433 follow-up): resolve - # the ref through the attached TieredRefResolver BEFORE - # calling get_checkout so the cache skips its internal - # ls-remote. Same pattern as the non-subdir path at - # line ~1604 which passes locked_sha=resolved. - try: - _resolved = self.resolve_git_reference(dep_ref) - _resolved_sha_for_cache = _resolved.resolved_commit - except Exception: - _resolved_sha_for_cache = None - # Sparse-cone (#1433): keying the persistent shard by - # (sha, subdir) ensures the cached working tree is the - # subdir only (<2 MB) instead of the full repo - # (~78 MB for dotnet/skills). Different subdirs of the - # same SHA land in separate variant shards; bare cache - # is unchanged so they still share object data. - _persistent_checkout = _persistent_cache.get_checkout( - _canonical_url, - _resolved_sha_for_cache or ref, - locked_sha=_resolved_sha_for_cache, - env=self._git_env_dict(), - sparse_paths=[subdir_path], - ) - except Exception: - # Cache miss or failure -- fall through to normal clone path. - _persistent_checkout = None - - # Use mkdtemp + explicit cleanup so we control when rmtree runs. - # tempfile.TemporaryDirectory().__exit__ calls shutil.rmtree without our - # retry logic, which raises WinError 32 when git processes still hold - # handles at the end of the with-block. - from ..config import get_apm_temp_dir - - temp_dir = None - shared_bare_path: Path | None = None - # WS2 path resolves the SHA from the BARE so we don't pay - # rev-parse twice (or open the working-tree Repo unnecessarily). - # See design.md sec 5.5: _ws2_resolved_commit threads the SHA past - # the generic Repo(temp_clone_path).head.commit.hexsha block below. 
- _ws2_resolved_commit: str | None = None - try: - if _persistent_checkout is not None: - # WS3: persistent cache hit -- use the cached checkout directly. - temp_clone_path = _persistent_checkout - if _perf_logger is not None: - _sha_short = ( - (ref or "")[:12] if ref and re.match(r"^[a-f0-9]{7,40}$", ref) else "" - ) - _perf_logger.subdir_download_start( - _dep_display, - cache_state="persistent-hit", - sha_short=_sha_short, - sparse_paths=[subdir_path], - ) - _perf_logger.materialize_result( - sparse_applied=True, - consumer_size_bytes=_dir_size_bytes(_persistent_checkout), - ) - elif use_shared: - # WS2 (#1126): shared cache holds BARE clones keyed by - # (host, owner, repo, ref). Each consumer materializes its - # own working tree from the bare; this is subdir-agnostic - # so two parallel consumers requesting different - # subdirectories of the same repo+ref can share one bare - # without racing on sparse-checkout. See design.md sec 5.5. - is_commit_sha = ref and re.match(r"^[a-f0-9]{7,40}$", ref) is not None - _perf_t0_bare = time.monotonic() - - def _shared_bare_clone_fn(bare_target: Path) -> None: - self._bare_clone_with_fallback( - dep_ref.repo_url, - bare_target, - dep_ref=dep_ref, - ref=ref, - is_commit_sha=bool(is_commit_sha), - ) - - def _shared_bare_fetch_fn(existing_bare: Path, ref_or_sha: str) -> bool: - # get_or_clone passes `ref` here; for SHA pins it is the SHA. 
- return self._fetch_sha_into_bare( - existing_bare, - ref_or_sha, - dep_ref=dep_ref, - ) - - try: - shared_bare_path = shared_cache.get_or_clone( - cache_host, - cache_owner, - cache_repo, - ref, - _shared_bare_clone_fn, - fetch_fn=_shared_bare_fetch_fn if is_commit_sha else None, - ) - except Exception as e: - raise RuntimeError(f"Failed to clone repository: {e}") from e - _perf_bare_elapsed_ms = int((time.monotonic() - _perf_t0_bare) * 1000) - if _perf_logger is not None: - _strategy = ( - f"init+fetch --depth=1 origin {ref[:12]}" - if is_commit_sha - else f"--depth=1 --branch {ref or ''}" - ) - _perf_logger.subdir_download_start( - _dep_display, - cache_state="shared-bare", - sha_short=ref[:12] if is_commit_sha and ref else "", - sparse_paths=[subdir_path], - ) - _perf_logger.bare_clone_strategy(_strategy, _perf_bare_elapsed_ms) - - # Per-consumer materialization. mkdtemp gives a unique - # path so concurrent consumers do not collide. The bare - # is read-only after this point; only the consumer dir - # is written to. - temp_dir = tempfile.mkdtemp(dir=get_apm_temp_dir()) - temp_clone_path = Path(temp_dir) / "consumer" - try: - _ws2_resolved_commit = self._materialize_from_bare( - shared_bare_path, - temp_clone_path, - ref=ref, - env=self._git_env_dict(), - # Only short-circuit SHA resolution when the user - # pinned a full 40-char SHA. Abbreviated SHAs - # (7-39 chars) must be resolved to the full - # SHA against the bare so resolved_commit - # matches `head.commit.hexsha` (always 40-char) - # in lockfile comparisons. The bare's HEAD has - # already been update-ref'd to the full SHA in - # _bare_action, so rev-parse HEAD returns 40 chars. - # Copilot review finding (#1135). - known_sha=ref if (is_commit_sha and len(ref) == 40) else None, - # Sparse-cone (#1433): materialize ONLY the - # subdirectory we need. Cuts the consumer - # working tree from full-repo to subdir-size - # on a typical monorepo (78 MB -> <2 MB for - # dotnet/skills). 
Bare cache is unchanged - # (subdir-agnostic) so multiple consumers - # requesting different subdirs of the same - # repo+SHA still share the object DB. - sparse_paths=[subdir_path], - ) - except Exception as e: - raise RuntimeError( - f"Failed to prepare dependency from cached clone: {e}" - ) from e - if _perf_logger is not None: - _perf_logger.materialize_result( - sparse_applied=True, - consumer_size_bytes=_dir_size_bytes(temp_clone_path), - ) - else: - # Legacy per-dep clone path (no shared cache). - temp_dir = tempfile.mkdtemp(dir=get_apm_temp_dir()) - # Sparse checkout always targets "repo/". If it fails we clone into - # "repo_clone/" so we never have to rmtree a directory that may still - # have live git handles from the failed subprocess. - sparse_clone_path = Path(temp_dir) / "repo" - temp_clone_path = sparse_clone_path - - # Update progress - cloning - if progress_obj and progress_task_id is not None: - progress_obj.update(progress_task_id, completed=20, total=100) - - # Phase 4 (#171): Try sparse-checkout first (git 2.25+), fall back to full clone - sparse_ok = self._try_sparse_checkout(dep_ref, sparse_clone_path, subdir_path, ref) - - if not sparse_ok: - # Full clone into a fresh subdirectory so we don't have to touch - # the (possibly locked) sparse-checkout directory at all. - temp_clone_path = Path(temp_dir) / "repo_clone" - - package_display_name = subdir_path.split("/")[-1] - progress_reporter = ( - GitProgressReporter(progress_task_id, progress_obj, package_display_name) - if progress_task_id and progress_obj - else None - ) - - # Detect if ref is a commit SHA (can't be used with --branch in shallow clones) - is_commit_sha = ref and re.match(r"^[a-f0-9]{7,40}$", ref) is not None - - clone_kwargs = { - "dep_ref": dep_ref, - } - if is_commit_sha: - # For commit SHAs, clone without checkout then checkout the specific commit. - # Shallow clone doesn't support fetching by arbitrary SHA. 
- clone_kwargs["no_checkout"] = True - else: - clone_kwargs["depth"] = 1 - if ref: - clone_kwargs["branch"] = ref - - try: - self._clone_with_fallback( - dep_ref.repo_url, - temp_clone_path, - progress_reporter=progress_reporter, - **clone_kwargs, - ) - except Exception as e: - raise RuntimeError(f"Failed to clone repository: {e}") from e - - if is_commit_sha: - repo_obj = None - try: - repo_obj = Repo(temp_clone_path) - repo_obj.git.checkout(ref) - except Exception as e: - raise RuntimeError(f"Failed to checkout commit {ref}: {e}") from e - finally: - _close_repo(repo_obj) - - # Disable progress reporter after clone - if progress_reporter: - progress_reporter.disabled = True - - # Update progress - extracting subdirectory - if progress_obj and progress_task_id is not None: - progress_obj.update(progress_task_id, completed=70, total=100) - - # Check if subdirectory exists - source_subdir = temp_clone_path / subdir_path - # Security: ensure subdirectory resolves within the cloned repo - from ..utils.path_security import ensure_path_within - - ensure_path_within(source_subdir, temp_clone_path) - if not source_subdir.exists(): - raise RuntimeError(f"Subdirectory '{subdir_path}' not found in repository") - - if not source_subdir.is_dir(): - raise RuntimeError(f"Path '{subdir_path}' is not a directory") - - # Create target directory - target_path.mkdir(parents=True, exist_ok=True) - - # If target exists and has content, remove it - if target_path.exists() and any(target_path.iterdir()): - _rmtree(target_path) - target_path.mkdir(parents=True, exist_ok=True) - - # Copy subdirectory contents to target (retry on transient - # file-lock errors caused by antivirus scanning on Windows). 
- from ..utils.file_ops import robust_copy2, robust_copytree - - for item in source_subdir.iterdir(): - src = source_subdir / item.name - dst = target_path / item.name - if src.is_dir(): - robust_copytree(src, dst) - else: - robust_copy2(src, dst) - - # Capture commit SHA; close the Repo object immediately so its file - # handles are released before _rmtree() runs in the finally block. - # WS2 path skips this because _materialize_from_bare already - # resolved the SHA from the bare (avoids opening Repo on the - # consumer dir, which leaks a Windows file handle that would - # block the rmtree below; see design.md sec 5.5). - if _ws2_resolved_commit is not None: - resolved_commit = _ws2_resolved_commit - else: - repo = None - try: - repo = Repo(temp_clone_path) - resolved_commit = repo.head.commit.hexsha - except Exception: - resolved_commit = "unknown" - finally: - _close_repo(repo) - - # Update progress - validating - if progress_obj and progress_task_id is not None: - progress_obj.update(progress_task_id, completed=90, total=100) - - except PermissionError as exc: - exc_path = getattr(exc, "filename", None) - # If temp_dir wasn't created (mkdtemp failed) or the error is within - # the temp tree, this is likely a restricted temp directory issue. - if temp_dir is None or (exc_path and str(exc_path).startswith(str(temp_dir))): - raise RuntimeError( - "Access denied in temporary directory" - + (f" '{temp_dir}'" if temp_dir else "") - + ". Corporate security may restrict this path. " - "Fix: apm config set temp-dir " - ) from None - raise - except OSError as exc: - if getattr(exc, "errno", None) == 13 or getattr(exc, "winerror", None) == 5: - exc_path = getattr(exc, "filename", None) - if temp_dir is None or (exc_path and str(exc_path).startswith(str(temp_dir))): - raise RuntimeError( - "Access denied in temporary directory" - + (f" '{temp_dir}'" if temp_dir else "") - + ". Corporate security may restrict this path. 
" - "Fix: apm config set temp-dir " - ) from None - raise - finally: - if temp_dir: - _rmtree(temp_dir) - - # Validate the extracted package (after temp dir is cleaned up) - validation_result = validate_apm_package(target_path) - if not validation_result.is_valid: - error_msgs = "; ".join(validation_result.errors) - raise RuntimeError( - f"Subdirectory is not a valid APM package or Claude Skill: {error_msgs}" - ) - - # Get the resolved reference for metadata - resolved_ref = ResolvedReference( - original_ref=ref or "default", - ref_name=ref or "default", - ref_type=GitReferenceType.BRANCH, - resolved_commit=resolved_commit, - ) - - # For plugins without an explicit version, stamp with the short commit SHA. - package = validation_result.package - from .package_validator import stamp_plugin_version - - stamp_plugin_version( - package, - validation_result.package_type, - resolved_commit, - target_path, - ) - - # Update progress - complete - if progress_obj and progress_task_id is not None: - progress_obj.update(progress_task_id, completed=100, total=100) - - return PackageInfo( - package=package, - install_path=target_path, - resolved_reference=resolved_ref, - installed_at=datetime.now().isoformat(), - dependency_ref=dep_ref, - package_type=validation_result.package_type, - ) + return _impl(self, dep_ref, target_path, progress_task_id, progress_obj) def _download_subdirectory_from_artifactory( self, @@ -1526,263 +710,10 @@ def download_package( progress_obj=None, verbose_callback=None, ) -> PackageInfo: - """Download a GitHub repository and validate it as an APM package. - - For virtual packages (individual files or collections), creates a minimal - package structure instead of cloning the full repository. - - Args: - repo_ref: Repository reference — either a DependencyReference object - or a string (e.g., "user/repo#branch"). Passing the object - directly avoids a lossy parse round-trip for generic git hosts. 
- target_path: Local path where package should be downloaded - progress_task_id: Rich Progress task ID for progress updates - progress_obj: Rich Progress object for progress updates - verbose_callback: Optional callable for verbose logging (receives str messages) - - Returns: - PackageInfo: Information about the downloaded package + """Download a GitHub repository and validate it as an APM package (delegated).""" + from .github_downloader_package_ops import download_package as _impl - Raises: - ValueError: If the repository reference is invalid - RuntimeError: If download or validation fails - """ - # Accept both string and DependencyReference to avoid lossy round-trips - if isinstance(repo_ref, DependencyReference): - dep_ref = repo_ref - else: - try: - dep_ref = DependencyReference.parse(repo_ref) - except ValueError as e: - raise ValueError(f"Invalid repository reference '{repo_ref}': {e}") from e - - # Handle virtual packages differently - if dep_ref.is_virtual: - art_proxy = self._parse_artifactory_base_url() - if self._is_artifactory_only() and not dep_ref.is_artifactory() and not art_proxy: - raise RuntimeError( - f"PROXY_REGISTRY_ONLY is set but no Artifactory proxy is configured for '{repo_ref}'. " - "Set PROXY_REGISTRY_URL or use explicit Artifactory FQDN syntax." - ) - if dep_ref.is_virtual_file(): - return self.download_virtual_file_package( - dep_ref, target_path, progress_task_id, progress_obj - ) - # SUBDIRECTORY (the only other virtual type after #1094 dropped - # the `.collection.yml` form): includes Artifactory modes. 
- if dep_ref.is_artifactory(): - proxy_info = (dep_ref.host, dep_ref.artifactory_prefix, "https") - return self._download_subdirectory_from_artifactory( - dep_ref, target_path, proxy_info, progress_task_id, progress_obj - ) - if self._is_artifactory_only() and art_proxy: - return self._download_subdirectory_from_artifactory( - dep_ref, target_path, art_proxy, progress_task_id, progress_obj - ) - return self.download_subdirectory_package( - dep_ref, target_path, progress_task_id, progress_obj - ) - - # Artifactory download path (Mode 1: explicit FQDN, Mode 2: transparent proxy) - use_artifactory = dep_ref.is_artifactory() - art_proxy = None - if not use_artifactory: - art_proxy = self._parse_artifactory_base_url() - if art_proxy and self._should_use_artifactory_proxy(dep_ref): - use_artifactory = True - - if use_artifactory: - return self._download_package_from_artifactory( - dep_ref, target_path, art_proxy, progress_task_id, progress_obj - ) - - # When PROXY_REGISTRY_ONLY is set but no Artifactory proxy matched, block direct git - if self._is_artifactory_only(): - raise RuntimeError( - f"PROXY_REGISTRY_ONLY is set but no Artifactory proxy is configured for '{dep_ref}'. " - "Set PROXY_REGISTRY_URL or use explicit Artifactory FQDN syntax." - ) - - # Regular package download (existing logic) - resolved_ref = self.resolve_git_reference(dep_ref) - - # Create target directory if it doesn't exist - target_path.mkdir(parents=True, exist_ok=True) - - # If directory already exists and has content, remove it - if target_path.exists() and any(target_path.iterdir()): - _rmtree(target_path) - target_path.mkdir(parents=True, exist_ok=True) - - # WS3 (#1116): persistent cross-run cache fast path for whole-repo - # deps. When a cached checkout exists for the resolved SHA, copy - # files directly into target_path and skip the network clone. 
- _persistent_cache = self.persistent_git_cache - if _persistent_cache is not None: - try: - cache_host = dep_ref.host or default_host() - cache_owner = dep_ref.repo_url.split("/")[0] if "/" in dep_ref.repo_url else "" - cache_repo = ( - dep_ref.repo_url.split("/")[1] if "/" in dep_ref.repo_url else dep_ref.repo_url - ) - _canonical_url = f"https://{cache_host}/{cache_owner}/{cache_repo}" - _cached = _persistent_cache.get_checkout( - _canonical_url, - resolved_ref.resolved_commit or resolved_ref.ref_name, - locked_sha=resolved_ref.resolved_commit, - env=self._git_env_dict(), - ) - from ..utils.file_ops import robust_copy2, robust_copytree - - for item in _cached.iterdir(): - if item.name == ".git": - continue - src = _cached / item.name - dst = target_path / item.name - if src.is_dir(): - robust_copytree(src, dst) - else: - robust_copy2(src, dst) - - # Validate, then return without cloning. - validation_result = validate_apm_package(target_path) - if validation_result.is_valid and validation_result.package: - package = validation_result.package - package.source = dep_ref.to_github_url() - package.resolved_commit = resolved_ref.resolved_commit - if ( - validation_result.package_type == PackageType.MARKETPLACE_PLUGIN - and package.version == "0.0.0" - and resolved_ref.resolved_commit - ): - short_sha = resolved_ref.resolved_commit[:7] - package.version = short_sha - apm_yml_path = target_path / "apm.yml" - if apm_yml_path.exists(): - from ..utils.yaml_io import dump_yaml, load_yaml - - _data = load_yaml(apm_yml_path) or {} - _data["version"] = short_sha - dump_yaml(_data, apm_yml_path) - return PackageInfo( - package=package, - install_path=target_path, - resolved_reference=resolved_ref, - installed_at=datetime.now().isoformat(), - dependency_ref=dep_ref, - package_type=validation_result.package_type, - ) - # Validation failed against cached copy: fall through to a - # fresh clone (cache may be stale or repo structure changed). 
- if target_path.exists() and any(target_path.iterdir()): - _rmtree(target_path) - target_path.mkdir(parents=True, exist_ok=True) - except Exception: - # Any cache failure -> fall back to network clone. - if target_path.exists() and any(target_path.iterdir()): - _rmtree(target_path) - target_path.mkdir(parents=True, exist_ok=True) - - # Store progress reporter so we can disable it after clone - progress_reporter = None - package_display_name = ( - dep_ref.repo_url.split("/")[-1] if "/" in dep_ref.repo_url else dep_ref.repo_url - ) - - try: - # Clone the repository using fallback authentication methods - # Use shallow clone for performance if we have a specific commit - if resolved_ref.ref_type == GitReferenceType.COMMIT: - # For commits, we need to clone and checkout the specific commit - progress_reporter = ( - GitProgressReporter(progress_task_id, progress_obj, package_display_name) - if progress_task_id and progress_obj - else None - ) - repo = self._clone_with_fallback( - dep_ref.repo_url, - target_path, - progress_reporter=progress_reporter, - dep_ref=dep_ref, - verbose_callback=verbose_callback, - ) - repo.git.checkout(resolved_ref.resolved_commit) - else: - # For branches and tags, we can use shallow clone - progress_reporter = ( - GitProgressReporter(progress_task_id, progress_obj, package_display_name) - if progress_task_id and progress_obj - else None - ) - repo = self._clone_with_fallback( - dep_ref.repo_url, - target_path, - progress_reporter=progress_reporter, - dep_ref=dep_ref, - verbose_callback=verbose_callback, - depth=1, - branch=resolved_ref.ref_name, - ) - - # Disable progress reporter to prevent late git updates - if progress_reporter: - progress_reporter.disabled = True - - # Remove .git directory to save space and prevent treating as a Git repository - git_dir = target_path / ".git" - if git_dir.exists(): - _rmtree(git_dir) - - except GitCommandError as e: - # Check if this might be a private repository access issue - if "Authentication 
failed" in str(e) or "remote: Repository not found" in str(e): - error_msg = f"Failed to clone repository {dep_ref.repo_url}. " - host = dep_ref.host or default_host() - org = dep_ref.repo_url.split("/")[0] if dep_ref.repo_url else None - error_msg += self.auth_resolver.build_error_context( - host, - "clone", - org=org, - port=dep_ref.port, - dep_url=dep_ref.repo_url, - ) - raise RuntimeError(error_msg) from e - else: - sanitized_error = self._sanitize_git_error(str(e)) - raise RuntimeError( - f"Failed to clone repository {dep_ref.repo_url}: {sanitized_error}" - ) from e - except RuntimeError: - # Re-raise RuntimeError from _clone_with_fallback - raise - - # Validate the downloaded package - from ._shared import _validate_and_load_package - - validation_result = validate_apm_package(target_path) - package = _validate_and_load_package(validation_result, target_path, dep_ref) - package.resolved_commit = resolved_ref.resolved_commit - - # For plugins without an explicit version, use the short commit SHA so the - # lock file and conflict detection have a meaningful, stable version string. - from .package_validator import stamp_plugin_version - - stamp_plugin_version( - package, - validation_result.package_type, - resolved_ref.resolved_commit, - target_path, - ) - - # Create and return PackageInfo - return PackageInfo( - package=package, - install_path=target_path, - resolved_reference=resolved_ref, - installed_at=datetime.now().isoformat(), - dependency_ref=dep_ref, # Store for canonical dependency string - package_type=validation_result.package_type, # Track if APM, Claude Skill, or Hybrid - ) + return _impl(self, repo_ref, target_path, progress_task_id, progress_obj, verbose_callback) def _get_clone_progress_callback(self): """Get a progress callback for Git clone operations. 
diff --git a/src/apm_cli/deps/github_downloader_package_ops.py b/src/apm_cli/deps/github_downloader_package_ops.py new file mode 100644 index 000000000..073641008 --- /dev/null +++ b/src/apm_cli/deps/github_downloader_package_ops.py @@ -0,0 +1,525 @@ +"""Whole-repo / virtual-file / sparse-checkout ops for :class:`GitHubPackageDownloader`. + +Moved bodies (kept thin wrappers on the class): ``download_package`` and its +``_package_*`` helpers, ``download_virtual_file_package`` and its +``_virtual_*`` helpers, and ``try_sparse_checkout``. Cross-cluster calls +(e.g. routing a virtual dep to the subdirectory handler) go through the +class wrappers via ``downloader.`` so they stay monkeypatch-safe and +form no import cycle. Patched globals are routed through ``_gh.``. +""" + +import os +import re +from datetime import datetime +from pathlib import Path +from typing import Union + +from git.exc import GitCommandError + +from ..models.apm_package import ( + DependencyReference, + GitReferenceType, + PackageInfo, + PackageType, + ResolvedReference, +) + + +def download_virtual_file_package( + downloader, + dep_ref: DependencyReference, + target_path: Path, + progress_task_id=None, + progress_obj=None, +) -> PackageInfo: + """Download a single file as a virtual APM package. + + Creates a minimal APM package structure with the file placed in the + appropriate .apm/ subdirectory based on its extension. + """ + from apm_cli.deps import github_downloader as _gh + + if not dep_ref.is_virtual or not dep_ref.virtual_path: + raise ValueError("Dependency must be a virtual file package") + + if not dep_ref.is_virtual_file(): + raise ValueError( + f"Path '{dep_ref.virtual_path}' is not a valid individual file. " + f"Must end with one of: {', '.join(DependencyReference.VIRTUAL_FILE_EXTENSIONS)}" + ) + + # Determine the ref to use + ref = dep_ref.reference or "main" + + # Resolve the commit SHA cheaply BEFORE the file download (one short HTTP + # call). 
On non-GitHub hosts or any failure this returns None and we fall + # back to ref-name only -- the install never fails on SHA resolution. + resolved_commit = downloader._resolve_commit_sha_for_ref(dep_ref, ref) + + if progress_obj and progress_task_id is not None: + progress_obj.update(progress_task_id, completed=50, total=100) + + try: + file_content = downloader.download_raw_file(dep_ref, dep_ref.virtual_path, ref) + except RuntimeError as e: + raise RuntimeError(f"Failed to download virtual package: {e}") from e + + if progress_obj and progress_task_id is not None: + progress_obj.update(progress_task_id, completed=90, total=100) + + target_path.mkdir(parents=True, exist_ok=True) + + subdir = _virtual_subdir_for(dep_ref.virtual_path) + if not subdir: + raise ValueError(f"Unknown file extension for {dep_ref.virtual_path}") + + filename = dep_ref.virtual_path.split("/")[-1] + apm_dir = target_path / ".apm" / subdir + apm_dir.mkdir(parents=True, exist_ok=True) + + file_path = apm_dir / filename + file_path.write_bytes(file_content) + + package_name = dep_ref.get_virtual_package_name() + description = _virtual_description(file_content, filename) + + apm_yml_data = { + "name": package_name, + "version": "1.0.0", + "description": description, + "author": dep_ref.repo_url.split("/")[0], + } + apm_yml_content = _gh.yaml_to_str(apm_yml_data) + + apm_yml_path = target_path / "apm.yml" + apm_yml_path.write_text(apm_yml_content, encoding="utf-8") + + package = _gh.APMPackage( + name=package_name, + version="1.0.0", + description=description, + author=dep_ref.repo_url.split("/")[0], + source=dep_ref.to_github_url(), + package_path=target_path, + ) + + # Build the resolved reference. On non-GitHub hosts or SHA-resolve failure + # the resolved_commit stays None and the suffix renders as "#ref" only. 
+ ref_type = ( + GitReferenceType.COMMIT + if re.match(r"^[a-f0-9]{40}$", ref.lower()) + else GitReferenceType.BRANCH + ) + resolved_ref = ResolvedReference( + original_ref=str(dep_ref.reference) if dep_ref.reference else ref, + ref_name=ref, + ref_type=ref_type, + resolved_commit=resolved_commit, + ) + + return PackageInfo( + package=package, + install_path=target_path, + installed_at=datetime.now().isoformat(), + dependency_ref=dep_ref, + resolved_reference=resolved_ref, + ) + + +def _virtual_subdir_for(virtual_path: str) -> str | None: + """Map a virtual file path's extension to its .apm subdirectory.""" + subdirs = { + ".prompt.md": "prompts", + ".instructions.md": "instructions", + ".chatmode.md": "chatmodes", + ".agent.md": "agents", + } + for ext, dir_name in subdirs.items(): + if virtual_path.endswith(ext): + return dir_name + return None + + +def _virtual_description(file_content: bytes, filename: str) -> str: + """Extract a description from YAML frontmatter, or build a default one.""" + description = f"Virtual package containing {filename}" + try: + content_str = file_content.decode("utf-8") + if content_str.startswith("---\n"): + end_idx = content_str.find("\n---\n", 4) + if end_idx > 0: + frontmatter = content_str[4:end_idx] + for line in frontmatter.split("\n"): + if line.startswith("description:"): + return line.split(":", 1)[1].strip().strip("\"'") + except Exception: + # If frontmatter parsing fails, use the default description. + pass + return description + + +def try_sparse_checkout( + downloader, + dep_ref: DependencyReference, + temp_clone_path: Path, + subdir_path: str, + ref: str | None = None, +) -> bool: + """Attempt sparse-checkout to download only a subdirectory (git 2.25+). + + Returns True on success. Falls back silently on failure. + """ + from apm_cli.deps import github_downloader as _gh + + try: + temp_clone_path.mkdir(parents=True, exist_ok=True) + + # Resolve per-dependency token via AuthResolver. 
+ dep_token = downloader._resolve_dep_token(dep_ref) + dep_auth_ctx = downloader._resolve_dep_auth_ctx(dep_ref) + dep_auth_scheme = dep_auth_ctx.auth_scheme if dep_auth_ctx else "basic" + + # For ADO bearer, use the AuthContext git_env with header injection. + if dep_auth_scheme == "bearer" and dep_auth_ctx is not None: + env = {**os.environ, **(dep_auth_ctx.git_env or {})} + else: + env = {**os.environ, **(downloader.git_env or {})} + auth_url = downloader._build_repo_url( + dep_ref.repo_url, + use_ssh=False, + dep_ref=dep_ref, + token=dep_token, + auth_scheme=dep_auth_scheme, + ) + + cmds = [ + ["git", "init"], + ["git", "remote", "add", "origin", auth_url], + ["git", "sparse-checkout", "init", "--cone"], + ["git", "sparse-checkout", "set", subdir_path], + ] + fetch_cmd = ["git", "fetch", "origin"] + fetch_cmd.append(ref or "HEAD") + fetch_cmd.append("--depth=1") + cmds.append(fetch_cmd) + cmds.append(["git", "checkout", "FETCH_HEAD"]) + + for cmd in cmds: + result = _gh.subprocess.run( + cmd, + cwd=str(temp_clone_path), + env=env, + capture_output=True, + text=True, + encoding="utf-8", + timeout=120, + ) + if result.returncode != 0: + _gh._debug( + f"Sparse-checkout step failed ({' '.join(cmd)}): {result.stderr.strip()}" + ) + return False + + return True + except Exception as e: + _gh._debug(f"Sparse-checkout failed: {e}") + return False + + +def download_package( + downloader, + repo_ref: Union[str, "DependencyReference"], + target_path: Path, + progress_task_id=None, + progress_obj=None, + verbose_callback=None, +) -> PackageInfo: + """Download a GitHub repository and validate it as an APM package. + + For virtual packages (individual files or subdirectories), creates a + minimal package structure / extracts the subdir instead of cloning the + full repository. Artifactory FQDN/proxy modes route to the Artifactory + orchestrator. A persistent cross-run cache fast-path may skip the clone. 
+ """ + from apm_cli.deps import github_downloader as _gh + + # Accept both string and DependencyReference to avoid lossy round-trips. + if isinstance(repo_ref, DependencyReference): + dep_ref = repo_ref + else: + try: + dep_ref = DependencyReference.parse(repo_ref) + except ValueError as e: + raise ValueError(f"Invalid repository reference '{repo_ref}': {e}") from e + + # Handle virtual packages differently. + if dep_ref.is_virtual: + return _package_download_virtual( + downloader, dep_ref, repo_ref, target_path, progress_task_id, progress_obj + ) + + # Artifactory download path (Mode 1: explicit FQDN, Mode 2: transparent proxy). + use_artifactory = dep_ref.is_artifactory() + art_proxy = None + if not use_artifactory: + art_proxy = downloader._parse_artifactory_base_url() + if art_proxy and downloader._should_use_artifactory_proxy(dep_ref): + use_artifactory = True + + if use_artifactory: + return downloader._download_package_from_artifactory( + dep_ref, target_path, art_proxy, progress_task_id, progress_obj + ) + + # PROXY_REGISTRY_ONLY set but no Artifactory proxy matched -> block direct git. + if downloader._is_artifactory_only(): + raise RuntimeError( + f"PROXY_REGISTRY_ONLY is set but no Artifactory proxy is configured for '{dep_ref}'. " + "Set PROXY_REGISTRY_URL or use explicit Artifactory FQDN syntax." + ) + + # Regular package download. + resolved_ref = downloader.resolve_git_reference(dep_ref) + + target_path.mkdir(parents=True, exist_ok=True) + if target_path.exists() and any(target_path.iterdir()): + _gh._rmtree(target_path) + target_path.mkdir(parents=True, exist_ok=True) + + # WS3: persistent cross-run cache fast path for whole-repo deps. 
+ if downloader.persistent_git_cache is not None: + cached = _package_try_persistent_cache(downloader, dep_ref, resolved_ref, target_path) + if cached is not None: + return cached + + _package_clone_repo( + downloader, + dep_ref, + resolved_ref, + target_path, + progress_task_id, + progress_obj, + verbose_callback, + ) + return _package_finalize(downloader, dep_ref, resolved_ref, target_path) + + +def _package_download_virtual( + downloader, dep_ref, repo_ref, target_path, progress_task_id, progress_obj +) -> PackageInfo: + """Route a virtual dep to file / subdirectory / Artifactory handlers.""" + art_proxy = downloader._parse_artifactory_base_url() + if downloader._is_artifactory_only() and not dep_ref.is_artifactory() and not art_proxy: + raise RuntimeError( + f"PROXY_REGISTRY_ONLY is set but no Artifactory proxy is configured for '{repo_ref}'. " + "Set PROXY_REGISTRY_URL or use explicit Artifactory FQDN syntax." + ) + if dep_ref.is_virtual_file(): + return downloader.download_virtual_file_package( + dep_ref, target_path, progress_task_id, progress_obj + ) + # SUBDIRECTORY (the only other virtual type): includes Artifactory modes. + if dep_ref.is_artifactory(): + proxy_info = (dep_ref.host, dep_ref.artifactory_prefix, "https") + return downloader._download_subdirectory_from_artifactory( + dep_ref, target_path, proxy_info, progress_task_id, progress_obj + ) + if downloader._is_artifactory_only() and art_proxy: + return downloader._download_subdirectory_from_artifactory( + dep_ref, target_path, art_proxy, progress_task_id, progress_obj + ) + return downloader.download_subdirectory_package( + dep_ref, target_path, progress_task_id, progress_obj + ) + + +def _package_try_persistent_cache(downloader, dep_ref, resolved_ref, target_path): + """Copy a cached checkout into target_path, validate, and build PackageInfo. + + Returns the PackageInfo on a usable cache hit, or None to fall through to a + fresh network clone (cache miss, stale copy, or validation failure). 
+ """ + from apm_cli.deps import github_downloader as _gh + + persistent_cache = downloader.persistent_git_cache + try: + cache_host = dep_ref.host or _gh.default_host() + cache_owner = dep_ref.repo_url.split("/")[0] if "/" in dep_ref.repo_url else "" + cache_repo = dep_ref.repo_url.split("/")[1] if "/" in dep_ref.repo_url else dep_ref.repo_url + canonical_url = f"https://{cache_host}/{cache_owner}/{cache_repo}" + cached = persistent_cache.get_checkout( + canonical_url, + resolved_ref.resolved_commit or resolved_ref.ref_name, + locked_sha=resolved_ref.resolved_commit, + env=downloader._git_env_dict(), + ) + from ..utils.file_ops import robust_copy2, robust_copytree + + for item in cached.iterdir(): + if item.name == ".git": + continue + src = cached / item.name + dst = target_path / item.name + if src.is_dir(): + robust_copytree(src, dst) + else: + robust_copy2(src, dst) + + validation_result = _gh.validate_apm_package(target_path) + if validation_result.is_valid and validation_result.package: + return _package_info_from_cache(validation_result, dep_ref, resolved_ref, target_path) + # Validation failed against cached copy: fall through to a fresh clone. + _package_clean_target(target_path) + except Exception: + # Any cache failure -> fall back to network clone. 
+ _package_clean_target(target_path) + return None + + +def _package_info_from_cache(validation_result, dep_ref, resolved_ref, target_path) -> PackageInfo: + """Build PackageInfo from a validated cached checkout (stamping plugin version).""" + package = validation_result.package + package.source = dep_ref.to_github_url() + package.resolved_commit = resolved_ref.resolved_commit + if ( + validation_result.package_type == PackageType.MARKETPLACE_PLUGIN + and package.version == "0.0.0" + and resolved_ref.resolved_commit + ): + short_sha = resolved_ref.resolved_commit[:7] + package.version = short_sha + apm_yml_path = target_path / "apm.yml" + if apm_yml_path.exists(): + from ..utils.yaml_io import dump_yaml, load_yaml + + data = load_yaml(apm_yml_path) or {} + data["version"] = short_sha + dump_yaml(data, apm_yml_path) + return PackageInfo( + package=package, + install_path=target_path, + resolved_reference=resolved_ref, + installed_at=datetime.now().isoformat(), + dependency_ref=dep_ref, + package_type=validation_result.package_type, + ) + + +def _package_clean_target(target_path) -> None: + """Remove target_path contents so a fresh clone starts clean.""" + from apm_cli.deps import github_downloader as _gh + + if target_path.exists() and any(target_path.iterdir()): + _gh._rmtree(target_path) + target_path.mkdir(parents=True, exist_ok=True) + + +def _package_clone_repo( + downloader, + dep_ref, + resolved_ref, + target_path, + progress_task_id, + progress_obj, + verbose_callback, +) -> None: + """Clone the repo (shallow for branches/tags, checkout for commits) and drop .git.""" + from apm_cli.deps import github_downloader as _gh + + progress_reporter = None + package_display_name = ( + dep_ref.repo_url.split("/")[-1] if "/" in dep_ref.repo_url else dep_ref.repo_url + ) + + try: + if resolved_ref.ref_type == GitReferenceType.COMMIT: + progress_reporter = ( + _gh.GitProgressReporter(progress_task_id, progress_obj, package_display_name) + if progress_task_id and 
progress_obj + else None + ) + repo = downloader._clone_with_fallback( + dep_ref.repo_url, + target_path, + progress_reporter=progress_reporter, + dep_ref=dep_ref, + verbose_callback=verbose_callback, + ) + repo.git.checkout(resolved_ref.resolved_commit) + else: + progress_reporter = ( + _gh.GitProgressReporter(progress_task_id, progress_obj, package_display_name) + if progress_task_id and progress_obj + else None + ) + repo = downloader._clone_with_fallback( + dep_ref.repo_url, + target_path, + progress_reporter=progress_reporter, + dep_ref=dep_ref, + verbose_callback=verbose_callback, + depth=1, + branch=resolved_ref.ref_name, + ) + + if progress_reporter: + progress_reporter.disabled = True + + # Remove .git to save space and prevent treating target as a Git repo. + git_dir = target_path / ".git" + if git_dir.exists(): + _gh._rmtree(git_dir) + except GitCommandError as e: + _package_raise_clone_error(downloader, dep_ref, e) + except RuntimeError: + # Re-raise RuntimeError from _clone_with_fallback. + raise + + +def _package_raise_clone_error(downloader, dep_ref, e: GitCommandError) -> None: + """Translate a GitCommandError into an actionable RuntimeError.""" + from apm_cli.deps import github_downloader as _gh + + if "Authentication failed" in str(e) or "remote: Repository not found" in str(e): + error_msg = f"Failed to clone repository {dep_ref.repo_url}. 
" + host = dep_ref.host or _gh.default_host() + org = dep_ref.repo_url.split("/")[0] if dep_ref.repo_url else None + error_msg += downloader.auth_resolver.build_error_context( + host, + "clone", + org=org, + port=dep_ref.port, + dep_url=dep_ref.repo_url, + ) + raise RuntimeError(error_msg) from e + sanitized_error = downloader._sanitize_git_error(str(e)) + raise RuntimeError(f"Failed to clone repository {dep_ref.repo_url}: {sanitized_error}") from e + + +def _package_finalize(downloader, dep_ref, resolved_ref, target_path) -> PackageInfo: + """Validate the cloned package, stamp version, and build PackageInfo.""" + from apm_cli.deps import github_downloader as _gh + + from ._shared import _validate_and_load_package + from .package_validator import stamp_plugin_version + + validation_result = _gh.validate_apm_package(target_path) + package = _validate_and_load_package(validation_result, target_path, dep_ref) + package.resolved_commit = resolved_ref.resolved_commit + + # For plugins without an explicit version, use the short commit SHA. + stamp_plugin_version( + package, + validation_result.package_type, + resolved_ref.resolved_commit, + target_path, + ) + + return PackageInfo( + package=package, + install_path=target_path, + resolved_reference=resolved_ref, + installed_at=datetime.now().isoformat(), + dependency_ref=dep_ref, + package_type=validation_result.package_type, + ) diff --git a/src/apm_cli/deps/github_downloader_setup_ops.py b/src/apm_cli/deps/github_downloader_setup_ops.py new file mode 100644 index 000000000..c0dee900c --- /dev/null +++ b/src/apm_cli/deps/github_downloader_setup_ops.py @@ -0,0 +1,247 @@ +"""Setup / auth / raw-file ops for :class:`GitHubPackageDownloader`. + +Moved bodies (kept thin wrappers on the class): constructor wiring, git-env +setup, error sanitisation, per-dependency auth resolution, and raw-file +download routing. 
Patched globals are routed through a function-level +``from apm_cli.deps import github_downloader as _gh`` alias so monkeypatches +on the original module still apply; no module-scope import of the original +module (avoids an import cycle). +""" + +import os +import re + +from ..models.apm_package import DependencyReference +from ..utils.github_host import sanitize_token_url_in_message + + +def setup_git_environment(downloader) -> dict: + """Set up Git environment with authentication using centralized token manager. + + Builds the auth-bearing env via :class:`GitAuthEnvBuilder`, then records + token-state attributes on the downloader (read by many other methods). + """ + from apm_cli.deps import github_downloader as _gh + + from .git_auth_env import GitAuthEnvBuilder + + builder = GitAuthEnvBuilder(downloader.token_manager) + env = builder.setup_environment() + + # IMPORTANT: Do not resolve credentials via helpers at construction time. + # AuthResolver.resolve(...) can trigger OS credential helper UI. If we do + # this eagerly (host-only key) and later resolve per-dependency (host+org), + # users can see duplicate auth prompts. Keep constructor token state env-only + # and resolve lazily per dependency during clone/validate flows. 
+ downloader.github_token = downloader.token_manager.get_token_for_purpose("modules", env) + downloader.has_github_token = downloader.github_token is not None + downloader._github_token_from_credential_fill = False + + # GitLab (env-only at init; lazy auth resolution happens per dep) + downloader.gitlab_token = downloader.token_manager.get_token_for_purpose("gitlab_modules", env) + downloader.has_gitlab_token = downloader.gitlab_token is not None + + # Azure DevOps (env-only at init; lazy auth resolution happens per dep) + downloader.ado_token = downloader.token_manager.get_token_for_purpose("ado_modules", env) + downloader.has_ado_token = downloader.ado_token is not None + + # JFrog Artifactory (not host-based, uses dedicated env var) + downloader.artifactory_token = downloader.token_manager.get_token_for_purpose( + "artifactory_modules", env + ) + downloader.has_artifactory_token = downloader.artifactory_token is not None + + _gh._debug( + f"Token setup: has_github_token={downloader.has_github_token}, " + f"has_gitlab_token={downloader.has_gitlab_token}, " + f"has_ado_token={downloader.has_ado_token}, " + f"has_artifactory_token={downloader.has_artifactory_token}" + f"{', source=credential_helper' if downloader._github_token_from_credential_fill else ''}" + ) + + return env + + +def sanitize_git_error(downloader, error_message: str) -> str: + """Sanitize Git error messages to remove potentially sensitive auth information.""" + from apm_cli.deps import github_downloader as _gh + + # Remove any tokens that might appear in URLs for github hosts (https://token@host). + sanitized = sanitize_token_url_in_message(error_message, host=_gh.default_host()) + + # Sanitize Azure DevOps URLs - both cloud (dev.azure.com) and any on-prem server. + # Generic pattern catches https://token@anyhost for all hosts. + sanitized = re.sub(r"https://[^@\s]+@([^\s/]+)", r"https://***@\1", sanitized) + + # Remove any tokens that might appear as standalone values. 
+ sanitized = re.sub( + r"(ghp_|gho_|ghu_|ghs_|ghr_|glpat[_-])[a-zA-Z0-9_\-]+", + "***", + sanitized, + ) + + # Remove environment variable values that might contain tokens. + sanitized = re.sub( + r"(GITHUB_TOKEN|GITHUB_APM_PAT|ADO_APM_PAT|GH_TOKEN|GITHUB_COPILOT_PAT|GITLAB_APM_PAT|GITLAB_TOKEN)=[^\s]+", + r"\1=***", + sanitized, + ) + + return sanitized + + +def resolve_dep_token(downloader, dep_ref: DependencyReference | None = None) -> str | None: + """Resolve the per-dependency auth token via AuthResolver. + + GitHub, GitLab, and ADO hosts use the token resolved by AuthResolver. + Other generic hosts return None so git credential helpers can provide + credentials instead. + """ + if dep_ref is None: + return downloader.github_token + + if downloader._is_generic_dependency_host(dep_ref): + return None + + dep_ctx = downloader.auth_resolver.resolve_for_dep(dep_ref) + return dep_ctx.token + + +def resolve_dep_auth_ctx(downloader, dep_ref: DependencyReference | None = None): + """Resolve the full AuthContext for a dependency. + + Returns the AuthContext from AuthResolver, or None for generic hosts or + when no dep_ref is provided. + """ + if dep_ref is None: + return None + + dep_host = dep_ref.host + if downloader._is_generic_dependency_host(dep_ref): + return None + + ctx = downloader.auth_resolver.resolve_for_dep(dep_ref) + # Verbose source surfacing (#852): one-time per-host log line so users can + # see which credential source was actually used. Routed through + # AuthResolver.notify_auth_source() (#856 follow-up F2). 
+ if os.environ.get("APM_VERBOSE") == "1": + downloader.auth_resolver.notify_auth_source(dep_host or "", ctx) + return ctx + + +def download_raw_file( + downloader, + dep_ref: DependencyReference, + file_path: str, + ref: str = "main", + verbose_callback=None, +) -> bytes: + """Download a single file from a repository (GitHub, GitLab, ADO, Artifactory).""" + from apm_cli.deps import github_downloader as _gh + + _ = dep_ref.host or _gh.default_host() + + # Check if this is Artifactory (Mode 1: explicit FQDN) + if dep_ref.is_artifactory(): + repo_parts = dep_ref.repo_url.split("/") + return downloader._download_file_from_artifactory( + dep_ref.host, + dep_ref.artifactory_prefix, + repo_parts[0], + repo_parts[1] if len(repo_parts) > 1 else repo_parts[0], + file_path, + ref, + ) + + # Check if this should go through Artifactory proxy (Mode 2) + art_proxy = downloader._parse_artifactory_base_url() + if art_proxy and downloader._should_use_artifactory_proxy(dep_ref): + repo_parts = dep_ref.repo_url.split("/") + return downloader._download_file_from_artifactory( + art_proxy[0], + art_proxy[1], + repo_parts[0], + repo_parts[1] if len(repo_parts) > 1 else repo_parts[0], + file_path, + ref, + scheme=art_proxy[2], + ) + + # Check if this is Azure DevOps + if dep_ref.is_azure_devops(): + return downloader._download_ado_file(dep_ref, file_path, ref) + + # GitHub API + return downloader._download_github_file( + dep_ref, file_path, ref, verbose_callback=verbose_callback + ) + + +def init_downloader( + downloader, auth_resolver, transport_selector, protocol_pref, allow_fallback +) -> None: + """Wire up a freshly-constructed :class:`GitHubPackageDownloader`. + + Resolves auth/transport defaults, builds the delegate + orchestrator + collaborators, and declares the install-pipeline-attached fields + (shared/persistent caches, tiered resolver, install logger) so they are + part of the documented surface rather than monkey-patched fields. 
+ """ + import threading + + from apm_cli.deps import github_downloader as _gh + + downloader.auth_resolver = auth_resolver or _gh.AuthResolver() + downloader.token_manager = downloader.auth_resolver._token_manager # Backward compat + downloader.git_env = downloader._setup_git_environment() + downloader._transport_selector = transport_selector or _gh.TransportSelector() + if protocol_pref is not None: + downloader._protocol_pref = protocol_pref + else: + # Config-aware helper (env > apm config > None) so ``apm config set ssh + # true`` is honoured even when constructed without explicit args. + from ..config import get_apm_protocol_pref as _get_pref + from .transport_selection import ProtocolPreference + + downloader._protocol_pref = ProtocolPreference.from_str(_get_pref()) + if allow_fallback is not None: + downloader._allow_fallback = allow_fallback + else: + # Config-aware helper (env > apm config > False). + from ..config import get_apm_allow_protocol_fallback as _get_fallback + + downloader._allow_fallback = _get_fallback() + # Dedup set for the issue #786 cross-protocol port warning: one install run + # calls _clone_with_fallback multiple times per dep. We want the warning + # exactly once per (host, repo, port) identity across all those calls. + downloader._fallback_port_warned: set = set() + downloader._fallback_port_warned_lock = threading.Lock() + + # Delegate backend-specific download logic to the download delegate. + downloader._strategies = _gh.DownloadDelegate(host=downloader) + + # Artifactory orchestration is encapsulated in a dedicated facade backed by + # the DownloadDelegate's HTTP archive downloader. 
+ from .artifactory_orchestrator import ArtifactoryOrchestrator + from .clone_engine import CloneEngine + from .git_reference_resolver import GitReferenceResolver + + downloader._artifactory = ArtifactoryOrchestrator(archive_downloader=downloader._strategies) + downloader._refs = GitReferenceResolver(host=downloader) + downloader._clone_engine = CloneEngine(host=downloader) + + # WS2a (#1116): per-run shared clone cache for subdirectory dep dedup. Set + # by the install pipeline before resolution; None means no dedup. + downloader.shared_clone_cache = None + + # WS3 (#1116): persistent cross-run git cache. When set, the download flow + # checks the on-disk cache before any network clone. None disables it. + downloader.persistent_git_cache = None + + # #1369: tiered ref resolver. Attached by resolve.py / outdated.py after + # construction. When set, resolve_git_reference delegates to it. + downloader._tiered_resolver = None + + # Perf #1433: optional InstallLogger attached by the install pipeline. When + # set, the subdir download path emits structured verbose-only [perf] lines. + downloader.install_logger = None diff --git a/src/apm_cli/deps/github_downloader_subdir_ops.py b/src/apm_cli/deps/github_downloader_subdir_ops.py new file mode 100644 index 000000000..e402dfd12 --- /dev/null +++ b/src/apm_cli/deps/github_downloader_subdir_ops.py @@ -0,0 +1,448 @@ +"""Subdirectory-package ops for :class:`GitHubPackageDownloader`. + +Moved body (kept a thin wrapper on the class): ``download_subdirectory_package`` +decomposed per cache tier (persistent WS3, shared-bare WS2, legacy +sparse/plain-clone fallback). The tiers are intentionally distinct and must +not be merged. Patched globals are routed through a function-level +``from apm_cli.deps import github_downloader as _gh`` alias. 
+""" + +import re +from datetime import datetime +from pathlib import Path + +from ..models.apm_package import ( + DependencyReference, + GitReferenceType, + PackageInfo, + ResolvedReference, +) + + +class _SubdirCloneState: + """Mutable lifecycle holder so a tier helper can register its temp dir. + + ``temp_dir`` must be visible to the orchestrator's ``finally`` block the + moment a helper creates it (so a mid-clone failure still cleans up). + ``ws2_resolved_commit`` carries the SHA the bare-cache path already + resolved, letting the extract step skip re-opening the working tree. + """ + + __slots__ = ("temp_dir", "ws2_resolved_commit") + + def __init__(self): + self.temp_dir = None + self.ws2_resolved_commit = None + + +def download_subdirectory_package( + downloader, + dep_ref: DependencyReference, + target_path: Path, + progress_task_id=None, + progress_obj=None, +) -> PackageInfo: + """Download a subdirectory from a repo as an APM package. + + Used for Claude Skills or APM packages nested in monorepos. Clones the + repo (through whichever cache tier applies), extracts the subdirectory, + validates it, and cleans up. + + The cache tiers are intentionally distinct and must not be merged: + persistent cross-run cache (WS3), per-run shared bare clone (WS2), and the + legacy per-dep sparse/full clone fallback. Each is its own helper. 
+ """ + from apm_cli.deps import github_downloader as _gh + + if not dep_ref.is_virtual or not dep_ref.virtual_path: + raise ValueError("Dependency must be a virtual subdirectory package") + + if not dep_ref.is_virtual_subdirectory(): + raise ValueError(f"Path '{dep_ref.virtual_path}' is not a valid subdirectory package") + + # Use user-specified ref, or None to use repo's default branch + ref = dep_ref.reference + subdir_path = dep_ref.virtual_path + perf_logger = getattr(downloader, "install_logger", None) + dep_display = str(dep_ref) + + if progress_obj and progress_task_id is not None: + progress_obj.update(progress_task_id, completed=10, total=100) + + shared_cache = downloader.shared_clone_cache + use_shared = shared_cache is not None + cache_host = dep_ref.host or _gh.default_host() + cache_owner = dep_ref.repo_url.split("/")[0] if "/" in dep_ref.repo_url else "" + cache_repo = dep_ref.repo_url.split("/")[1] if "/" in dep_ref.repo_url else dep_ref.repo_url + + # WS3: try persistent cross-run cache first. + persistent_checkout: Path | None = None + if downloader.persistent_git_cache is not None: + persistent_checkout = _subdir_persistent_checkout( + downloader, dep_ref, ref, subdir_path, cache_host, cache_owner, cache_repo + ) + + state = _SubdirCloneState() + try: + if persistent_checkout is not None: + # WS3: persistent cache hit -- use the cached checkout directly. 
+ temp_clone_path = persistent_checkout + _subdir_log_persistent_hit( + perf_logger, dep_display, ref, subdir_path, persistent_checkout + ) + elif use_shared: + temp_clone_path = _subdir_shared_bare_materialize( + downloader, + dep_ref, + ref, + subdir_path, + cache_host, + cache_owner, + cache_repo, + shared_cache, + perf_logger, + dep_display, + state, + ) + else: + temp_clone_path = _subdir_legacy_clone( + downloader, dep_ref, ref, subdir_path, progress_task_id, progress_obj, state + ) + + if progress_obj and progress_task_id is not None: + progress_obj.update(progress_task_id, completed=70, total=100) + + resolved_commit = _subdir_extract_to_target( + downloader, temp_clone_path, subdir_path, target_path, state.ws2_resolved_commit + ) + + if progress_obj and progress_task_id is not None: + progress_obj.update(progress_task_id, completed=90, total=100) + except PermissionError as exc: + _subdir_reraise_access_error(exc, state.temp_dir) + raise + except OSError as exc: + if getattr(exc, "errno", None) == 13 or getattr(exc, "winerror", None) == 5: + _subdir_reraise_access_error(exc, state.temp_dir) + raise + finally: + if state.temp_dir: + _gh._rmtree(state.temp_dir) + + return _subdir_build_package_info( + target_path, ref, resolved_commit, dep_ref, progress_task_id, progress_obj + ) + + +def _subdir_persistent_checkout( + downloader, dep_ref, ref, subdir_path, cache_host, cache_owner, cache_repo +) -> Path | None: + """WS3: resolve a sparse-keyed checkout from the persistent cross-run cache.""" + persistent_cache = downloader.persistent_git_cache + canonical_url = f"https://{cache_host}/{cache_owner}/{cache_repo}" + try: + # Tiered ref resolution (#1433): resolve the ref BEFORE get_checkout so + # the cache skips its internal ls-remote (same pattern as the non-subdir + # path which passes locked_sha=resolved). 
+ try: + resolved_sha = downloader.resolve_git_reference(dep_ref).resolved_commit + except Exception: + resolved_sha = None + # Sparse-cone (#1433): keying the persistent shard by (sha, subdir) + # ensures the cached working tree is the subdir only (<2 MB) instead of + # the full repo. Bare cache is unchanged so variants share object data. + return persistent_cache.get_checkout( + canonical_url, + resolved_sha or ref, + locked_sha=resolved_sha, + env=downloader._git_env_dict(), + sparse_paths=[subdir_path], + ) + except Exception: + # Cache miss or failure -- fall through to normal clone path. + return None + + +def _subdir_log_persistent_hit(perf_logger, dep_display, ref, subdir_path, checkout) -> None: + """Emit the verbose [perf] lines for a persistent-cache hit.""" + from apm_cli.deps import github_downloader as _gh + + if perf_logger is None: + return + sha_short = (ref or "")[:12] if ref and re.match(r"^[a-f0-9]{7,40}$", ref) else "" + perf_logger.subdir_download_start( + dep_display, + cache_state="persistent-hit", + sha_short=sha_short, + sparse_paths=[subdir_path], + ) + perf_logger.materialize_result( + sparse_applied=True, + consumer_size_bytes=_gh._dir_size_bytes(checkout), + ) + + +def _subdir_shared_bare_materialize( + downloader, + dep_ref, + ref, + subdir_path, + cache_host, + cache_owner, + cache_repo, + shared_cache, + perf_logger, + dep_display, + state, +) -> Path: + """WS2: share a BARE clone keyed by (host, owner, repo, ref); materialize per consumer. + + The bare is subdir-agnostic, so concurrent consumers requesting different + subdirectories of the same repo+ref share one bare without racing on + sparse-checkout. Each consumer materializes its own working tree. 
+ """ + from apm_cli.deps import github_downloader as _gh + + from ..config import get_apm_temp_dir + + is_commit_sha = ref and re.match(r"^[a-f0-9]{7,40}$", ref) is not None + perf_t0_bare = _gh.time.monotonic() + + def _shared_bare_clone_fn(bare_target: Path) -> None: + downloader._bare_clone_with_fallback( + dep_ref.repo_url, + bare_target, + dep_ref=dep_ref, + ref=ref, + is_commit_sha=bool(is_commit_sha), + ) + + def _shared_bare_fetch_fn(existing_bare: Path, ref_or_sha: str) -> bool: + # get_or_clone passes `ref` here; for SHA pins it is the SHA. + return downloader._fetch_sha_into_bare(existing_bare, ref_or_sha, dep_ref=dep_ref) + + try: + shared_bare_path = shared_cache.get_or_clone( + cache_host, + cache_owner, + cache_repo, + ref, + _shared_bare_clone_fn, + fetch_fn=_shared_bare_fetch_fn if is_commit_sha else None, + ) + except Exception as e: + raise RuntimeError(f"Failed to clone repository: {e}") from e + perf_bare_elapsed_ms = int((_gh.time.monotonic() - perf_t0_bare) * 1000) + if perf_logger is not None: + strategy = ( + f"init+fetch --depth=1 origin {ref[:12]}" + if is_commit_sha + else f"--depth=1 --branch {ref or ''}" + ) + perf_logger.subdir_download_start( + dep_display, + cache_state="shared-bare", + sha_short=ref[:12] if is_commit_sha and ref else "", + sparse_paths=[subdir_path], + ) + perf_logger.bare_clone_strategy(strategy, perf_bare_elapsed_ms) + + # Per-consumer materialization. mkdtemp gives a unique path so concurrent + # consumers do not collide. The bare is read-only after this point. + state.temp_dir = _gh.tempfile.mkdtemp(dir=get_apm_temp_dir()) + temp_clone_path = Path(state.temp_dir) / "consumer" + try: + state.ws2_resolved_commit = downloader._materialize_from_bare( + shared_bare_path, + temp_clone_path, + ref=ref, + env=downloader._git_env_dict(), + # Only short-circuit SHA resolution for a full 40-char SHA; + # abbreviated SHAs must be resolved against the bare so + # resolved_commit matches head.commit.hexsha (#1135). 
+ known_sha=ref if (is_commit_sha and len(ref) == 40) else None, + # Sparse-cone (#1433): materialize ONLY the subdirectory we need. + sparse_paths=[subdir_path], + ) + except Exception as e: + raise RuntimeError(f"Failed to prepare dependency from cached clone: {e}") from e + if perf_logger is not None: + perf_logger.materialize_result( + sparse_applied=True, + consumer_size_bytes=_gh._dir_size_bytes(temp_clone_path), + ) + return temp_clone_path + + +def _subdir_legacy_clone( + downloader, dep_ref, ref, subdir_path, progress_task_id, progress_obj, state +) -> Path: + """Legacy per-dep clone path (no shared cache): sparse-checkout then full clone.""" + from apm_cli.deps import github_downloader as _gh + + from ..config import get_apm_temp_dir + + state.temp_dir = _gh.tempfile.mkdtemp(dir=get_apm_temp_dir()) + # Sparse checkout always targets "repo/". If it fails we clone into + # "repo_clone/" so we never have to rmtree a directory that may still have + # live git handles from the failed subprocess. + sparse_clone_path = Path(state.temp_dir) / "repo" + temp_clone_path = sparse_clone_path + + if progress_obj and progress_task_id is not None: + progress_obj.update(progress_task_id, completed=20, total=100) + + # Phase 4 (#171): Try sparse-checkout first (git 2.25+), fall back to full clone. + sparse_ok = downloader._try_sparse_checkout(dep_ref, sparse_clone_path, subdir_path, ref) + if sparse_ok: + return temp_clone_path + + # Full clone into a fresh subdirectory so we don't have to touch the + # (possibly locked) sparse-checkout directory at all. + temp_clone_path = Path(state.temp_dir) / "repo_clone" + + package_display_name = subdir_path.split("/")[-1] + progress_reporter = ( + _gh.GitProgressReporter(progress_task_id, progress_obj, package_display_name) + if progress_task_id and progress_obj + else None + ) + + # Detect if ref is a commit SHA (can't be used with --branch in shallow clones). 
+ is_commit_sha = ref and re.match(r"^[a-f0-9]{7,40}$", ref) is not None + + clone_kwargs = {"dep_ref": dep_ref} + if is_commit_sha: + # For commit SHAs, clone without checkout then checkout the specific commit. + clone_kwargs["no_checkout"] = True + else: + clone_kwargs["depth"] = 1 + if ref: + clone_kwargs["branch"] = ref + + try: + downloader._clone_with_fallback( + dep_ref.repo_url, + temp_clone_path, + progress_reporter=progress_reporter, + **clone_kwargs, + ) + except Exception as e: + raise RuntimeError(f"Failed to clone repository: {e}") from e + + if is_commit_sha: + repo_obj = None + try: + repo_obj = _gh.Repo(temp_clone_path) + repo_obj.git.checkout(ref) + except Exception as e: + raise RuntimeError(f"Failed to checkout commit {ref}: {e}") from e + finally: + _gh._close_repo(repo_obj) + + if progress_reporter: + progress_reporter.disabled = True + return temp_clone_path + + +def _subdir_extract_to_target( + downloader, temp_clone_path, subdir_path, target_path, ws2_resolved_commit +) -> str: + """Copy the subdirectory into target_path and resolve the commit SHA.""" + from apm_cli.deps import github_downloader as _gh + + from ..utils.file_ops import robust_copy2, robust_copytree + from ..utils.path_security import ensure_path_within + + source_subdir = temp_clone_path / subdir_path + # Security: ensure subdirectory resolves within the cloned repo. + ensure_path_within(source_subdir, temp_clone_path) + if not source_subdir.exists(): + raise RuntimeError(f"Subdirectory '{subdir_path}' not found in repository") + if not source_subdir.is_dir(): + raise RuntimeError(f"Path '{subdir_path}' is not a directory") + + target_path.mkdir(parents=True, exist_ok=True) + + # If target exists and has content, remove it. 
+ if target_path.exists() and any(target_path.iterdir()): + _gh._rmtree(target_path) + target_path.mkdir(parents=True, exist_ok=True) + + for item in source_subdir.iterdir(): + src = source_subdir / item.name + dst = target_path / item.name + if src.is_dir(): + robust_copytree(src, dst) + else: + robust_copy2(src, dst) + + # Capture commit SHA; close the Repo immediately so its handles are released + # before _rmtree runs. The WS2 path already resolved the SHA from the bare + # (avoids opening Repo on the consumer dir, which leaks a Windows handle). + if ws2_resolved_commit is not None: + return ws2_resolved_commit + repo = None + try: + repo = _gh.Repo(temp_clone_path) + return repo.head.commit.hexsha + except Exception: + return "unknown" + finally: + _gh._close_repo(repo) + + +def _subdir_reraise_access_error(exc, temp_dir) -> None: + """Translate a temp-dir permission error into an actionable RuntimeError.""" + exc_path = getattr(exc, "filename", None) + # If temp_dir wasn't created (mkdtemp failed) or the error is within the + # temp tree, this is likely a restricted temp directory issue. + if temp_dir is None or (exc_path and str(exc_path).startswith(str(temp_dir))): + raise RuntimeError( + "Access denied in temporary directory" + + (f" '{temp_dir}'" if temp_dir else "") + + ". Corporate security may restrict this path. 
" + "Fix: apm config set temp-dir " + ) from None + + +def _subdir_build_package_info( + target_path, ref, resolved_commit, dep_ref, progress_task_id, progress_obj +) -> PackageInfo: + """Validate the extracted package, stamp version, and build PackageInfo.""" + from apm_cli.deps import github_downloader as _gh + + from .package_validator import stamp_plugin_version + + validation_result = _gh.validate_apm_package(target_path) + if not validation_result.is_valid: + error_msgs = "; ".join(validation_result.errors) + raise RuntimeError(f"Subdirectory is not a valid APM package or Claude Skill: {error_msgs}") + + resolved_ref = ResolvedReference( + original_ref=ref or "default", + ref_name=ref or "default", + ref_type=GitReferenceType.BRANCH, + resolved_commit=resolved_commit, + ) + + # For plugins without an explicit version, stamp with the short commit SHA. + package = validation_result.package + stamp_plugin_version( + package, + validation_result.package_type, + resolved_commit, + target_path, + ) + + if progress_obj and progress_task_id is not None: + progress_obj.update(progress_task_id, completed=100, total=100) + + return PackageInfo( + package=package, + install_path=target_path, + resolved_reference=resolved_ref, + installed_at=datetime.now().isoformat(), + dependency_ref=dep_ref, + package_type=validation_result.package_type, + ) From 12ea829bcfd05cd25c260da5ad281139acfdf6c2 Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 06:24:33 +0200 Subject: [PATCH 16/21] refactor(commands): split commands+bundle under 800-line guardrail (#1078) Strangler Stage 2, Commit 4 of 8. Drive the commands and bundle subsystems under the 800-line file guardrail and clear their second-tier complexity offenders, preserving behaviour and the monkeypatch/import surface via re-export shims and delegating wrappers. 
Length splits (all resulting modules < 800): - marketplace/__init__.py 1601 -> 535 (+_table_ops 350, _registry_cmds 543, _publish_ops 261, _search_cmd 119) - commands/audit.py 1161 -> 758 (+_audit_ops 372) - commands/deps/cli.py 925 -> 673 (+_cli_ops 273) - commands/pack.py 808 -> 694 (+_pack_ops 168) - compile/cli.py 1002 -> 671 (+_run_ops 417) Complexity-only fixes (genuine simplification, no suppression): - compile _run_compilation: 13 args -> 5 via frozen CompilationRunConfig parameter object constructed at the call site - marketplace doctor run_doctor: extracted 8 cohesive check helpers - uninstall, outdated, find, plugin_exporter, local_bundle: merged early-return guards and extracted cohesive helpers to clear PLR0911/PLR0912/PLR0915/C901 File-length backlog 24 -> 19. Complexity gate clean at final Stage-2 thresholds; thresholds themselves flip in the final enforcement commit. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/bundle/local_bundle.py | 66 +- src/apm_cli/bundle/plugin_exporter.py | 55 +- src/apm_cli/commands/_audit_ops.py | 372 ++++++ src/apm_cli/commands/_pack_ops.py | 168 +++ src/apm_cli/commands/audit.py | 421 +----- src/apm_cli/commands/compile/_run_ops.py | 417 ++++++ src/apm_cli/commands/compile/cli.py | 447 +------ src/apm_cli/commands/deps/_cli_ops.py | 273 ++++ src/apm_cli/commands/deps/cli.py | 256 +--- src/apm_cli/commands/find.py | 20 +- src/apm_cli/commands/marketplace/__init__.py | 1180 +---------------- .../commands/marketplace/_publish_ops.py | 261 ++++ .../commands/marketplace/_registry_cmds.py | 543 ++++++++ .../commands/marketplace/_search_cmd.py | 119 ++ .../commands/marketplace/_table_ops.py | 350 +++++ src/apm_cli/commands/marketplace/doctor.py | 260 ++-- src/apm_cli/commands/outdated.py | 185 ++- src/apm_cli/commands/pack.py | 140 +- src/apm_cli/commands/uninstall/cli.py | 99 +- 19 files changed, 2995 insertions(+), 2637 deletions(-) create mode 100644 src/apm_cli/commands/_audit_ops.py 
create mode 100644 src/apm_cli/commands/_pack_ops.py create mode 100644 src/apm_cli/commands/compile/_run_ops.py create mode 100644 src/apm_cli/commands/deps/_cli_ops.py create mode 100644 src/apm_cli/commands/marketplace/_publish_ops.py create mode 100644 src/apm_cli/commands/marketplace/_registry_cmds.py create mode 100644 src/apm_cli/commands/marketplace/_search_cmd.py create mode 100644 src/apm_cli/commands/marketplace/_table_ops.py diff --git a/src/apm_cli/bundle/local_bundle.py b/src/apm_cli/bundle/local_bundle.py index 04acb0498..c157fb48d 100644 --- a/src/apm_cli/bundle/local_bundle.py +++ b/src/apm_cli/bundle/local_bundle.py @@ -197,6 +197,37 @@ def _find_extracted_root(extract_dir: Path) -> Path | None: return None +def _extract_archive_safely(path: Path, temp_dir: Path) -> bool: + """Extract a tarball to *temp_dir* after validating all members. + + Returns ``True`` on success, ``False`` when any member is unsafe or + when extraction itself fails. + """ + try: + with tarfile.open(path, "r:gz") as tar: + for member in tar.getmembers(): + if member.issym() or member.islnk(): + return False + name = member.name + if ( + name.startswith("/") + or PureWindowsPath(name).drive + or PureWindowsPath(name).is_absolute() + ): + return False + try: + validate_path_segments(name, context="tar member") + except PathTraversalError: + return False + if sys.version_info >= (3, 12): + tar.extractall(temp_dir, filter="data") + else: + tar.extractall(temp_dir) # noqa: S202 -- validated above + except (tarfile.TarError, OSError): + return False + return True + + def detect_local_bundle(path: Path) -> LocalBundleInfo | None: """Probe *path*; return :class:`LocalBundleInfo` or ``None``. 
@@ -220,40 +251,7 @@ def detect_local_bundle(path: Path) -> LocalBundleInfo | None: if path.is_file() and _looks_like_archive(path): temp_dir = Path(tempfile.mkdtemp(prefix="apm-local-bundle-")) - try: - with tarfile.open(path, "r:gz") as tar: - # Reject member symlinks/hardlinks and absolute / parent paths - # for safety (analogous to the pack-side filter). Using - # ``validate_path_segments`` normalises backslashes and - # percent-decoding, and ``PureWindowsPath`` catches drive-letter - # absolute forms (e.g. ``C:/foo``) that ``startswith('/')`` misses. - for member in tar.getmembers(): - if member.issym() or member.islnk(): - shutil.rmtree(temp_dir, ignore_errors=True) - return None - name = member.name - if ( - name.startswith("/") - or PureWindowsPath(name).drive - or PureWindowsPath(name).is_absolute() - ): - shutil.rmtree(temp_dir, ignore_errors=True) - return None - try: - validate_path_segments(name, context="tar member") - except PathTraversalError: - shutil.rmtree(temp_dir, ignore_errors=True) - return None - # tarfile.extractall(filter="data") requires Python 3.12+. - # The repo declares requires-python = ">=3.10", so on 3.10/3.11 - # we extract without the filter. The pre-extraction validation - # above is the primary gate (rejects symlinks, absolute paths, - # and any '..' segment), not filter="data". 
- if sys.version_info >= (3, 12): - tar.extractall(temp_dir, filter="data") - else: - tar.extractall(temp_dir) # noqa: S202 -- validated above - except (tarfile.TarError, OSError): + if not _extract_archive_safely(path, temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) return None bundle_root = _find_extracted_root(temp_dir) diff --git a/src/apm_cli/bundle/plugin_exporter.py b/src/apm_cli/bundle/plugin_exporter.py index 4e0585502..7899e4ef8 100644 --- a/src/apm_cli/bundle/plugin_exporter.py +++ b/src/apm_cli/bundle/plugin_exporter.py @@ -397,6 +397,35 @@ def _dep_install_path(dep: LockedDependency, apm_modules_dir: Path) -> Path: return dep_ref.get_install_path(apm_modules_dir) +def _scan_bundle_source_files(file_map: dict, logger) -> None: + """Warn (never block) when source files contain hidden characters.""" + from ..security.gate import WARN_POLICY, SecurityGate + + total = 0 + for _rel, (src, _owner) in file_map.items(): + if src.is_symlink(): + continue + if src.is_dir(): + verdict = SecurityGate.scan_files(src, policy=WARN_POLICY) + total += len(verdict.all_findings) + elif src.is_file(): + try: + text = src.read_text(encoding="utf-8", errors="replace") + except OSError: + continue + verdict = SecurityGate.scan_text(text, str(src), policy=WARN_POLICY) + total += len(verdict.all_findings) + if total: + msg = ( + f"Bundle contains {total} hidden character(s) across " + f"source files — run 'apm audit' to inspect before publishing" + ) + if logger: + logger.warning(msg) + else: + _rich_warning(msg) + + # --------------------------------------------------------------------------- # Main exporter # --------------------------------------------------------------------------- @@ -548,31 +577,7 @@ def export_plugin_bundle( return PackResult(bundle_path=bundle_dir, files=output_files) # 10. 
Security scan (warn-only, never blocks) - from ..security.gate import WARN_POLICY, SecurityGate - - scan_findings_total = 0 - for _rel, (src, _owner) in file_map.items(): - if src.is_symlink(): - continue - if src.is_dir(): - verdict = SecurityGate.scan_files(src, policy=WARN_POLICY) - scan_findings_total += len(verdict.all_findings) - elif src.is_file(): - try: - text = src.read_text(encoding="utf-8", errors="replace") - except OSError: - continue - verdict = SecurityGate.scan_text(text, str(src), policy=WARN_POLICY) - scan_findings_total += len(verdict.all_findings) - if scan_findings_total: - _warn_msg = ( - f"Bundle contains {scan_findings_total} hidden character(s) across " - f"source files — run 'apm audit' to inspect before publishing" - ) - if logger: - logger.warning(_warn_msg) - else: - _rich_warning(_warn_msg) + _scan_bundle_source_files(file_map, logger) # 11. Write files to output directory (clean slate to prevent symlink attacks) if bundle_dir.exists(): diff --git a/src/apm_cli/commands/_audit_ops.py b/src/apm_cli/commands/_audit_ops.py new file mode 100644 index 000000000..eb5342465 --- /dev/null +++ b/src/apm_cli/commands/_audit_ops.py @@ -0,0 +1,372 @@ +"""Audit command heavy-lifting extracted to keep audit.py under 800 lines. + +All patched globals (ContentScanner, get_lockfile_path, scan_lockfile_packages, +_has_actionable_findings, _render_summary, _render_findings_table, _preview_strip, +_apply_strip, _resolve_external_options, _run_external_scanners, _scan_single_file) +are accessed through the original ``audit`` module at call-time so that test +monkey-patches on ``apm_cli.commands.audit.*`` take effect normally. + +No module-level import of ``audit`` here to avoid circular imports; each +function does a function-level ``from apm_cli.commands import audit as _a``. 
+""" + +import sys +from pathlib import Path + +import click + +from ..deps.lockfile import LockFile +from ..utils.console import STATUS_SYMBOLS + +# --------------------------------------------------------------------------- +# _audit_ci_gate +# --------------------------------------------------------------------------- + + +def _audit_ci_gate( + cfg, + policy_source, + no_cache, + no_policy, + no_fail_fast, + no_drift=False, +): + """Handle ``apm audit --ci`` -- lockfile consistency gate.""" + from apm_cli.commands import audit as _a # route patched globals through original module + + logger = cfg.logger + + from ..policy.ci_checks import _check_drift, run_baseline_checks + from ..policy.policy_checks import run_policy_checks + + fail_fast = not no_fail_fast + + ci_result = run_baseline_checks(cfg.project_root, fail_fast=fail_fast, ci_mode=True) + + from ..policy.discovery import discover_policy, discover_policy_with_chain + from ..policy.project_config import read_project_fetch_failure_default + + fetch_result = None + auto_discovered = False + if policy_source and (not fail_fast or ci_result.passed): + fetch_result = discover_policy( + cfg.project_root, + policy_override=policy_source, + no_cache=no_cache, + ) + elif not policy_source and not no_policy and (not fail_fast or ci_result.passed): + fetch_result = discover_policy_with_chain(cfg.project_root) + auto_discovered = True + + if fetch_result is not None: + fetch_failure_outcomes = ( + "malformed", + "cache_miss_fetch_fail", + "garbage_response", + ) + no_policy_outcomes = ("absent", "no_git_remote", "empty") + + if auto_discovered and fetch_result.outcome == "disabled": + click.echo( + "[i] Org-policy auto-discovery disabled by project apm.yml " + "(policy.discovery_enabled=false); no enforcement applied", + err=True, + ) + fetch_result = None + elif ( + fetch_result.outcome in fetch_failure_outcomes + or fetch_result.error + or (auto_discovered and fetch_result.outcome in no_policy_outcomes) + ): + 
project_default = read_project_fetch_failure_default(cfg.project_root) + source = fetch_result.source + err_text = fetch_result.error or fetch_result.fetch_error or fetch_result.outcome + cause = _a._audit_outcome_cause(fetch_result.outcome, source, err_text) + if project_default == "block": + click.echo( + f"[x] {cause} (policy.fetch_failure_default=block)", + err=True, + ) + sys.exit(1) + else: + click.echo( + f"[!] {cause}; enforcement skipped " + "(set policy.fetch_failure_default=block in apm.yml to fail closed)", + err=True, + ) + fetch_result = None + + if fetch_result is not None and fetch_result.found: + policy_obj = fetch_result.policy + + if policy_obj.enforcement == "off": + pass # Policy checks disabled + else: + from ..policy.models import CheckResult + + policy_result = run_policy_checks(cfg.project_root, policy_obj, fail_fast=fail_fast) + if policy_obj.enforcement == "block": + ci_result.checks.extend(policy_result.checks) + else: + for check in policy_result.checks: + ci_result.checks.append( + CheckResult( + name=check.name, + passed=True, + message=check.message + + (" (enforcement: warn)" if not check.passed else ""), + details=check.details, + ) + ) + + drift_findings: list = [] + if not no_drift and (cfg.project_root / "apm.yml").exists(): + lockfile_path = _a.get_lockfile_path(cfg.project_root) + if lockfile_path.exists(): + lockfile = LockFile.read(lockfile_path) + if lockfile is not None: + drift_check, drift_findings = _check_drift( + cfg.project_root, + lockfile, + cache_only=True, + verbose=cfg.verbose, + ) + ci_result.checks.append(drift_check) + elif no_drift and cfg.output_format == "text": + click.echo( + f"{STATUS_SYMBOLS['warning']} drift detection skipped (--no-drift); " + "coverage reduced -- hand-edits and missing integrations will not be caught", + err=True, + ) + + effective_format = cfg.output_format + if cfg.output_path and effective_format == "text": + from ..security.audit_report import detect_format_from_extension + + 
effective_format = detect_format_from_extension(Path(cfg.output_path)) + + if effective_format in ("json", "sarif"): + import json as _json + + from ..install.drift import render_drift_json, render_drift_sarif + + if effective_format == "sarif": + payload = ci_result.to_sarif() + if drift_findings: + payload["runs"][0]["results"].extend(render_drift_sarif(drift_findings)) + else: + payload = ci_result.to_json() + if drift_findings or not no_drift: + payload["drift"] = render_drift_json(drift_findings) + + output = _json.dumps(payload, indent=2) + if cfg.output_path: + Path(cfg.output_path).parent.mkdir(parents=True, exist_ok=True) + Path(cfg.output_path).write_text(output, encoding="utf-8") + logger.success(f"CI audit report written to {cfg.output_path}") + else: + click.echo(output) + else: + _a._render_ci_results(ci_result) + if drift_findings: + from ..install.drift import render_drift_text + + click.echo("") + click.echo(render_drift_text(drift_findings, verbose=cfg.verbose)) + + sys.exit(0 if ci_result.passed else 1) + + +# --------------------------------------------------------------------------- +# _audit_content_scan +# --------------------------------------------------------------------------- + + +def _audit_content_scan( + cfg, + package, + file_path, + strip, + dry_run, + no_drift=False, + external=(), + external_sarif=None, + external_llm=None, + external_args=None, +): + """Handle default ``apm audit`` -- content integrity scanning.""" + from apm_cli.commands import audit as _a # route patched globals through original module + + logger = cfg.logger + project_root = cfg.project_root + + effective_format = cfg.output_format + if cfg.output_path and effective_format == "text": + from ..security.audit_report import detect_format_from_extension + + effective_format = detect_format_from_extension(Path(cfg.output_path)) + + if effective_format != "text" and (strip or dry_run): + logger.error(f"--format {effective_format} cannot be combined with --strip or 
--dry-run") + sys.exit(1) + + if file_path: + findings_by_file, files_scanned = _a._scan_single_file(Path(file_path), logger) + scan_paths = [Path(file_path)] + else: + scan_paths = [project_root] + lockfile_path = _a.get_lockfile_path(project_root) + if not lockfile_path.exists(): + if not external: + logger.progress( + "No apm.lock.yaml found -- nothing to scan. Use --file to scan a specific file." + ) + sys.exit(0) + findings_by_file, files_scanned = {}, 0 + else: + if package: + logger.progress(f"Scanning package: {package}") + else: + logger.start("Scanning all installed packages...") + + findings_by_file, files_scanned = _a.scan_lockfile_packages( + project_root, + package_filter=package, + ) + + if files_scanned == 0 and not external: + if package: + logger.warning( + f"Package '{package}' not found in apm.lock.yaml or has no deployed files" + ) + else: + logger.progress("No deployed files found in apm.lock.yaml") + sys.exit(0) + + if external: + options_by_name = _a._resolve_external_options(external, external_llm, external_args) + external_findings = _a._run_external_scanners( + cfg, external, external_sarif, scan_paths, options_by_name + ) + from ..security.external.runner import merge_findings + + merge_findings(findings_by_file, external_findings) + + if dry_run and not strip: + logger.progress("--dry-run only works with --strip (e.g. 
apm audit --strip --dry-run)") + + if strip: + if not findings_by_file: + logger.progress("Nothing to clean -- no hidden characters found") + sys.exit(0) + if dry_run: + _a._preview_strip(findings_by_file, logger) + sys.exit(0) + modified = _a._apply_strip(findings_by_file, project_root, logger) + if modified > 0: + logger.success(f"Cleaned {modified} file(s)") + else: + logger.progress("Nothing to clean -- no strippable characters found") + sys.exit(0) + + drift_findings: list = [] + drift_failed = False + if ( + not no_drift + and not strip + and not file_path + and not package + and (project_root / "apm.yml").exists() + ): + from ..policy.ci_checks import DRIFT_SKIP_PREFIX, _check_drift + + lockfile_path = _a.get_lockfile_path(project_root) + if lockfile_path.exists(): + lockfile = LockFile.read(lockfile_path) + if lockfile is not None: + drift_check, drift_findings = _check_drift( + project_root, + lockfile, + cache_only=True, + verbose=cfg.verbose, + ) + drift_failed = not drift_check.passed + if drift_failed and not drift_findings: + click.echo( + f"{STATUS_SYMBOLS['warning']} drift check could not run: " + f"{drift_check.message}", + err=True, + ) + elif ( + drift_check.passed + and not drift_findings + and drift_check.message.startswith(DRIFT_SKIP_PREFIX) + ): + click.echo( + f"{STATUS_SYMBOLS['warning']} {drift_check.message}", + err=True, + ) + elif no_drift and cfg.output_format == "text": + click.echo( + f"{STATUS_SYMBOLS['warning']} drift detection skipped (--no-drift); " + "coverage reduced -- hand-edits and missing integrations will not be caught", + err=True, + ) + + if not findings_by_file or not _a._has_actionable_findings(findings_by_file): + exit_code = 0 + else: + all_findings = [f for ff in findings_by_file.values() for f in ff] + exit_code = 1 if _a.ContentScanner.has_critical(all_findings) else 2 + + _ = drift_failed # retained for symmetry; gate path lives in --ci. 
+ + if effective_format == "text": + if cfg.output_path: + logger.error( + "Text format does not support --output. " + "Use --format json, sarif, or markdown to write to a file." + ) + sys.exit(1) + if findings_by_file: + _a._render_findings_table(findings_by_file, verbose=cfg.verbose) + _a._render_summary(findings_by_file, files_scanned, logger) + if drift_findings: + from ..install.drift import render_drift_text + + click.echo("") + click.echo(render_drift_text(drift_findings, verbose=cfg.verbose)) + elif effective_format == "markdown": + from ..security.audit_report import findings_to_markdown + + md_report = findings_to_markdown(findings_by_file, files_scanned=files_scanned) + if cfg.output_path: + Path(cfg.output_path).parent.mkdir(parents=True, exist_ok=True) + Path(cfg.output_path).write_text(md_report, encoding="utf-8") + logger.success(f"Audit report written to {cfg.output_path}") + else: + click.echo(md_report) + else: + from ..security.audit_report import ( + findings_to_json, + findings_to_sarif, + serialize_report, + write_report, + ) + + if effective_format == "sarif": + report = findings_to_sarif(findings_by_file, files_scanned=files_scanned) + else: + report = findings_to_json( + findings_by_file, + files_scanned=files_scanned, + exit_code=exit_code, + ) + + if cfg.output_path: + write_report(report, Path(cfg.output_path)) + logger.success(f"Audit report written to {cfg.output_path}") + else: + click.echo(serialize_report(report)) + + sys.exit(exit_code) diff --git a/src/apm_cli/commands/_pack_ops.py b/src/apm_cli/commands/_pack_ops.py new file mode 100644 index 000000000..e5a38a92a --- /dev/null +++ b/src/apm_cli/commands/_pack_ops.py @@ -0,0 +1,168 @@ +"""Release-gate logic extracted from pack_cmd to reduce complexity. + +``_run_release_gates`` handles --check-versions and --check-clean. +``_emit_drift_recipe`` is the recovery-recipe printer used by drift reporting. 
+ +These functions do not reference any names patched by tests on +``apm_cli.commands.pack`` (BuildOrchestrator is used only in the caller), +so no late-import routing is needed here. +""" + +from __future__ import annotations + +from pathlib import Path + + +def _emit_drift_recipe(logger, out_path: str) -> None: + """Emit the canonical recovery recipe when marketplace.json drift is detected. + + Teaches producers the amend+force-with-lease pattern so they can fix the + drift without a noisy follow-up commit. + """ + logger.info("") + logger.info(" To recover cleanly (fold into the current commit):") + logger.info("") + logger.info(" apm pack # regenerate locally") + logger.info(f" git add -- {out_path}") + logger.info(" git commit --amend --no-edit # fold into the current commit") + logger.info(" git push --force-with-lease # safe re-push") + logger.info("") + logger.info(" Or as a follow-up commit:") + logger.info("") + logger.info(f" apm pack && git add -- {out_path}") + logger.info(" git commit -m 'chore(marketplace): regen'") + logger.info("") + logger.info(" Why this exists: marketplace.json is checked in (lockfile pattern)") + logger.info(" so consumers can resolve packages without running 'apm pack'. CI") + logger.info(" enforces that the checked-in copy matches the apm.yml source of truth.") + + +def _run_release_gates( + ctx, + options, + check_versions: bool, + check_clean: bool, + json_output: bool, + logger, + project_root: Path, +) -> tuple[bool, bool, dict | None, dict | None, list]: + """Run --check-versions and --check-clean release gates. + + Returns ``(version_gate_failed, drift_gate_failed, + version_alignment_payload, drift_payload, gate_errors)``. + + When the marketplace config is absent both gates are skipped with an + info message and the function returns all-clean values. 
+ """ + from ..marketplace.builder import BuildOptions as MktBuildOptions + from ..marketplace.builder import MarketplaceBuilder + from ..marketplace.drift_check import check_marketplace_drift, render_diff_lines + from ..marketplace.migration import ConfigSource, detect_config_source + from ..marketplace.version_check import check_version_alignment + from ..marketplace.yml_schema import MarketplaceYmlError + + # Inline helper to keep this function self-contained + from .pack import _emit_json_error_or_raise + + version_alignment_payload: dict | None = None + drift_payload: dict | None = None + gate_errors: list[dict] = [] + version_gate_failed = False + drift_gate_failed = False + + gate_config = None + try: + source = detect_config_source(project_root) + if source != ConfigSource.NONE: + from ..marketplace.migration import load_marketplace_config + + gate_config = load_marketplace_config(project_root) + except MarketplaceYmlError as exc: + _emit_json_error_or_raise(ctx, json_output, "build_error", str(exc)) + return (False, False, None, None, []) + + if gate_config is None: + if check_versions: + logger.info("Version alignment check skipped: no marketplace block; nothing to check.") + if check_clean: + logger.info("Marketplace drift check skipped: no marketplace block; nothing to check.") + return (False, False, None, None, []) + + if check_versions: + v_report = check_version_alignment(gate_config, project_root) + version_alignment_payload = v_report.to_json_dict() + if v_report.ok: + if not json_output: + if v_report.expected is not None: + logger.success( + f"Version alignment OK [strategy={v_report.strategy}, " + f"expected={v_report.expected}]" + ) + else: + logger.success(f"Version alignment OK [strategy={v_report.strategy}]") + for row in v_report.packages: + tag_str = f" -> tag {row.rendered_tag}" if row.rendered_tag else "" + logger.info(f" {row.path} {row.version}{tag_str} [{row.reason}]") + else: + version_gate_failed = True + if not json_output: + if 
v_report.expected is not None: + logger.error( + f"Version alignment failed [strategy={v_report.strategy}, " + f"expected={v_report.expected}]" + ) + else: + logger.error(f"Version alignment failed [strategy={v_report.strategy}]") + for row in v_report.packages: + tag_str = f" -> tag {row.rendered_tag}" if row.rendered_tag else "" + version_str = row.version if row.version is not None else "" + logger.info(f" {row.path} {version_str}{tag_str} [{row.reason}]") + for msg in v_report.error_messages(): + gate_errors.append({"code": "version_misaligned", "message": msg}) + + if check_clean: + mkt_opts = MktBuildOptions( + dry_run=True, + offline=options.marketplace_offline, + include_prerelease=options.marketplace_include_prerelease, + ) + drift_builder = MarketplaceBuilder.from_config( + gate_config, project_root=project_root, options=mkt_opts + ) + d_report = check_marketplace_drift(drift_builder, gate_config, project_root) + drift_payload = d_report.to_json_dict() + if d_report.ok: + if not json_output: + formats = ", ".join(o.format for o in d_report.outputs) + logger.success(f"Marketplace working tree clean [outputs={formats}]") + for out in d_report.outputs: + logger.info(f" {out.path} [unchanged]") + else: + drift_gate_failed = True + if not json_output: + dirty_formats = ", ".join( + o.format for o in d_report.outputs if o.status != "unchanged" + ) + logger.error(f"Marketplace working tree dirty [outputs={dirty_formats}]") + for out in d_report.outputs: + if out.status == "unchanged": + logger.info(f" {out.path} [unchanged]") + elif out.status == "missing": + logger.info(f" {out.path} [missing on disk; would be created]") + _emit_drift_recipe(logger, out.path) + else: + count = len(out.differences) + logger.info(f" {out.path} [drift: {count} differences]") + for line in render_diff_lines(out): + logger.info(line) + _emit_drift_recipe(logger, out.path) + for msg in d_report.error_messages(): + gate_errors.append({"code": "marketplace_drift", "message": msg}) + + 
return ( + version_gate_failed, + drift_gate_failed, + version_alignment_payload, + drift_payload, + gate_errors, + ) diff --git a/src/apm_cli/commands/audit.py b/src/apm_cli/commands/audit.py index 02f35701f..90022af5e 100644 --- a/src/apm_cli/commands/audit.py +++ b/src/apm_cli/commands/audit.py @@ -19,10 +19,12 @@ import click from ..core.command_logger import CommandLogger -from ..deps.lockfile import LockFile, get_lockfile_path +from ..deps.lockfile import get_lockfile_path # noqa: F401 -- re-exported for test patching from ..policy._help_text import POLICY_SOURCE_FORMS_HELP from ..security.content_scanner import ContentScanner, ScanFinding -from ..security.file_scanner import scan_lockfile_packages +from ..security.file_scanner import ( + scan_lockfile_packages, # noqa: F401 -- re-exported for test patching +) from ..utils.console import ( STATUS_SYMBOLS, _get_console, @@ -30,6 +32,8 @@ _rich_error, _rich_success, ) +from ._audit_ops import _audit_ci_gate as _audit_ci_gate +from ._audit_ops import _audit_content_scan as _audit_content_scan # -- Shared config -------------------------------------------------- @@ -409,197 +413,9 @@ def _render_ci_results(ci_result: "CIAuditResult") -> None: ) -# -- Mode handlers -------------------------------------------------- - - -def _audit_ci_gate( - cfg: _AuditConfig, - policy_source: str | None, - no_cache: bool, - no_policy: bool, - no_fail_fast: bool, - no_drift: bool = False, -) -> None: - """Handle ``apm audit --ci`` -- lockfile consistency gate. - - Runs baseline lockfile checks, drift detection (unless ``--no-drift``), - and (optionally) org-policy checks, then emits a structured report - and exits with 0 (clean) or 1 (violations). 
- """ - logger = cfg.logger - - from ..policy.ci_checks import _check_drift, run_baseline_checks - from ..policy.policy_checks import run_policy_checks - - fail_fast = not no_fail_fast - - # Always run baseline checks - ci_result = run_baseline_checks(cfg.project_root, fail_fast=fail_fast, ci_mode=True) - - # Resolve policy source: explicit --policy wins; otherwise mirror - # install's auto-discovery (closes #827) so CI catches sideloaded - # files via unmanaged-files checks. --no-policy skips discovery. - from ..policy.discovery import discover_policy, discover_policy_with_chain - from ..policy.project_config import ( - read_project_fetch_failure_default, - ) - - fetch_result = None - auto_discovered = False - if policy_source and (not fail_fast or ci_result.passed): - fetch_result = discover_policy( - cfg.project_root, - policy_override=policy_source, - no_cache=no_cache, - ) - elif not policy_source and not no_policy and (not fail_fast or ci_result.passed): - # Auto-discovery (mirror install path) - fetch_result = discover_policy_with_chain(cfg.project_root) - auto_discovered = True - - if fetch_result is not None: - # Honour project-side fetch_failure_default for outcomes that - # mean "no enforcement applied". Pre-#1159, auto-discovery - # silently swallowed `absent` / `no_git_remote` / `empty` / - # `disabled` -- a fail-open governance bypass. Now those - # outcomes are surfaced explicitly: - # - # * malformed / cache_miss_fetch_fail / garbage_response - # -> existing fetch-failure handling (warn unless block); - # applies to BOTH explicit --policy and auto-discovery. - # * absent / no_git_remote / empty (auto-discovery only) - # -> were silently dropped pre-#1159; now surfaced as - # explicit warnings, and honour `block` for parity with - # install. Explicit --policy keeps the legacy fall- - # through so an opt-in pointer at a baseline file does - # not regress. 
- # * disabled (auto-discovery only) - # -> emit a forensic `[i]` breadcrumb in --ci mode so - # audit logs explain WHY no policy ran. - fetch_failure_outcomes = ( - "malformed", - "cache_miss_fetch_fail", - "garbage_response", - ) - no_policy_outcomes = ("absent", "no_git_remote", "empty") - - if auto_discovered and fetch_result.outcome == "disabled": - click.echo( - "[i] Org-policy auto-discovery disabled by project apm.yml " - "(policy.discovery_enabled=false); no enforcement applied", - err=True, - ) - fetch_result = None - elif ( - fetch_result.outcome in fetch_failure_outcomes - or fetch_result.error - or (auto_discovered and fetch_result.outcome in no_policy_outcomes) - ): - project_default = read_project_fetch_failure_default(cfg.project_root) - source = fetch_result.source - err_text = fetch_result.error or fetch_result.fetch_error or fetch_result.outcome - cause = _audit_outcome_cause(fetch_result.outcome, source, err_text) - if project_default == "block": - click.echo( - f"[x] {cause} (policy.fetch_failure_default=block)", - err=True, - ) - sys.exit(1) - else: - click.echo( - f"[!] 
{cause}; enforcement skipped " - "(set policy.fetch_failure_default=block in apm.yml to fail closed)", - err=True, - ) - fetch_result = None - - if fetch_result is not None and fetch_result.found: - policy_obj = fetch_result.policy - - # Respect enforcement level - if policy_obj.enforcement == "off": - pass # Policy checks disabled - else: - from ..policy.models import CheckResult - - policy_result = run_policy_checks(cfg.project_root, policy_obj, fail_fast=fail_fast) - if policy_obj.enforcement == "block": - ci_result.checks.extend(policy_result.checks) - else: - # enforcement == "warn": include results but don't fail - for check in policy_result.checks: - ci_result.checks.append( - CheckResult( - name=check.name, - passed=True, # downgrade to pass - message=check.message - + (" (enforcement: warn)" if not check.passed else ""), - details=check.details, - ) - ) - - # -- Drift detection (default-on per ADR-02) -------------------- - drift_findings: list = [] - if not no_drift and (cfg.project_root / "apm.yml").exists(): - from ..deps.lockfile import LockFile, get_lockfile_path - - lockfile_path = get_lockfile_path(cfg.project_root) - if lockfile_path.exists(): - lockfile = LockFile.read(lockfile_path) - if lockfile is not None: - drift_check, drift_findings = _check_drift( - cfg.project_root, - lockfile, - cache_only=True, - verbose=cfg.verbose, - ) - ci_result.checks.append(drift_check) - elif no_drift and cfg.output_format == "text": - # In structured output (json/sarif), --no-drift is implicit from - # the absence of the drift check entry; no need to pollute output. 
- click.echo( - f"{STATUS_SYMBOLS['warning']} drift detection skipped (--no-drift); " - "coverage reduced -- hand-edits and missing integrations will not be caught", - err=True, - ) - - # Resolve effective format - effective_format = cfg.output_format - if cfg.output_path and effective_format == "text": - from ..security.audit_report import detect_format_from_extension - - effective_format = detect_format_from_extension(Path(cfg.output_path)) - - if effective_format in ("json", "sarif"): - import json as _json - - from ..install.drift import render_drift_json, render_drift_sarif - - if effective_format == "sarif": - payload = ci_result.to_sarif() - if drift_findings: - payload["runs"][0]["results"].extend(render_drift_sarif(drift_findings)) - else: - payload = ci_result.to_json() - if drift_findings or not no_drift: - payload["drift"] = render_drift_json(drift_findings) - - output = _json.dumps(payload, indent=2) - if cfg.output_path: - Path(cfg.output_path).parent.mkdir(parents=True, exist_ok=True) - Path(cfg.output_path).write_text(output, encoding="utf-8") - logger.success(f"CI audit report written to {cfg.output_path}") - else: - click.echo(output) - else: - _render_ci_results(ci_result) - if drift_findings: - from ..install.drift import render_drift_text - - click.echo("") - click.echo(render_drift_text(drift_findings, verbose=cfg.verbose)) - - sys.exit(0 if ci_result.passed else 1) +# -- Mode handlers (implementations live in _audit_ops.py) ---------- +# Re-exported at import time so ``apm_cli.commands.audit._audit_ci_gate`` +# and ``apm_cli.commands.audit._audit_content_scan`` remain patchable. def _resolve_external_options( @@ -684,225 +500,6 @@ def _run_external_scanners( sys.exit(2) -def _audit_content_scan( - cfg: _AuditConfig, - package: str | None, - file_path: str | None, - strip: bool, - dry_run: bool, - no_drift: bool = False, - external: tuple[str, ...] 
= (), - external_sarif: str | None = None, - external_llm: bool | None = None, - external_args: str | None = None, -) -> None: - """Handle default ``apm audit`` -- content integrity scanning. - - Scans deployed prompt files (or a single file via ``--file``) for - hidden Unicode characters, optionally stripping them. - """ - logger = cfg.logger - project_root = cfg.project_root - - # Resolve effective format (auto-detect from extension when needed) - effective_format = cfg.output_format - if cfg.output_path and effective_format == "text": - from ..security.audit_report import detect_format_from_extension - - effective_format = detect_format_from_extension(Path(cfg.output_path)) - - # --format json/sarif/markdown is incompatible with --strip / --dry-run - if effective_format != "text" and (strip or dry_run): - logger.error(f"--format {effective_format} cannot be combined with --strip or --dry-run") - sys.exit(1) - - if file_path: - # -- File mode: scan a single arbitrary file -- - findings_by_file, files_scanned = _scan_single_file(Path(file_path), logger) - scan_paths = [Path(file_path)] - else: - scan_paths = [project_root] - # -- Package mode: scan from lockfile -- - lockfile_path = get_lockfile_path(project_root) - if not lockfile_path.exists(): - if not external: - logger.progress( - "No apm.lock.yaml found -- nothing to scan. Use --file to scan a specific file." - ) - sys.exit(0) - # External scanners are an independent source: proceed with an - # empty native result set so their findings still surface. 
- findings_by_file, files_scanned = {}, 0 - else: - if package: - logger.progress(f"Scanning package: {package}") - else: - logger.start("Scanning all installed packages...") - - findings_by_file, files_scanned = scan_lockfile_packages( - project_root, - package_filter=package, - ) - - if files_scanned == 0 and not external: - if package: - logger.warning( - f"Package '{package}' not found in apm.lock.yaml or has no deployed files" - ) - else: - logger.progress("No deployed files found in apm.lock.yaml") - sys.exit(0) - - # -- External scanners (opt-in, additive) ----------------------- - if external: - options_by_name = _resolve_external_options(external, external_llm, external_args) - external_findings = _run_external_scanners( - cfg, external, external_sarif, scan_paths, options_by_name - ) - from ..security.external.runner import merge_findings - - merge_findings(findings_by_file, external_findings) - - # -- Warn if --dry-run used without --strip -- - if dry_run and not strip: - logger.progress("--dry-run only works with --strip (e.g. apm audit --strip --dry-run)") - - # -- Strip mode -- - if strip: - if not findings_by_file: - logger.progress("Nothing to clean -- no hidden characters found") - sys.exit(0) - if dry_run: - _preview_strip(findings_by_file, logger) - sys.exit(0) - modified = _apply_strip(findings_by_file, project_root, logger) - if modified > 0: - logger.success(f"Cleaned {modified} file(s)") - else: - logger.progress("Nothing to clean -- no strippable characters found") - sys.exit(0) - - # -- Drift detection (default-on per ADR-02) -------------------- - # Drift only applies to whole-project audit (not --file or --strip - # modes; not single-package scoped). Mutex on no_drift+strip/file - # is enforced earlier via UsageError. 
- drift_findings: list = [] - drift_failed = False - if ( - not no_drift - and not strip - and not file_path - and not package - and (project_root / "apm.yml").exists() - ): - from ..policy.ci_checks import DRIFT_SKIP_PREFIX, _check_drift - - lockfile_path = get_lockfile_path(project_root) - if lockfile_path.exists(): - lockfile = LockFile.read(lockfile_path) - if lockfile is not None: - drift_check, drift_findings = _check_drift( - project_root, - lockfile, - cache_only=True, - verbose=cfg.verbose, - ) - drift_failed = not drift_check.passed - # Bare `apm audit` is advisory: drift_failed does not gate - # the exit code (that lives in --ci). But silence on a - # cache-pin / cache-miss skip or failure is a UX trap: the - # user cannot tell whether drift was clean or whether it was - # never attempted. Surface the reason on stderr whenever the - # drift check produced no findings. - if drift_failed and not drift_findings: - click.echo( - f"{STATUS_SYMBOLS['warning']} drift check could not run: " - f"{drift_check.message}", - err=True, - ) - elif ( - drift_check.passed - and not drift_findings - and drift_check.message.startswith(DRIFT_SKIP_PREFIX) - ): - click.echo( - f"{STATUS_SYMBOLS['warning']} {drift_check.message}", - err=True, - ) - elif no_drift and cfg.output_format == "text": - # In structured output (json/sarif), --no-drift is implicit from - # the absence of the drift check entry; no need to pollute output. 
- click.echo( - f"{STATUS_SYMBOLS['warning']} drift detection skipped (--no-drift); " - "coverage reduced -- hand-edits and missing integrations will not be caught", - err=True, - ) - - # -- Display findings -- - # Determine exit code first (shared by all formats) - if not findings_by_file or not _has_actionable_findings(findings_by_file): - exit_code = 0 - else: - all_findings = [f for ff in findings_by_file.values() for f in ff] - exit_code = 1 if ContentScanner.has_critical(all_findings) else 2 - - # Note: bare `apm audit` is advisory for drift; drift findings are - # rendered (text/json/sarif) but DO NOT escalate the exit code. Use - # `apm audit --ci` (handled in _audit_ci_gate) to gate on drift. - _ = drift_failed # retained for symmetry; gate path lives in --ci. - - if effective_format == "text": - if cfg.output_path: - logger.error( - "Text format does not support --output. " - "Use --format json, sarif, or markdown to write to a file." - ) - sys.exit(1) - if findings_by_file: - _render_findings_table(findings_by_file, verbose=cfg.verbose) - _render_summary(findings_by_file, files_scanned, logger) - if drift_findings: - from ..install.drift import render_drift_text - - click.echo("") - click.echo(render_drift_text(drift_findings, verbose=cfg.verbose)) - elif effective_format == "markdown": - from ..security.audit_report import findings_to_markdown - - md_report = findings_to_markdown(findings_by_file, files_scanned=files_scanned) - if cfg.output_path: - Path(cfg.output_path).parent.mkdir(parents=True, exist_ok=True) - Path(cfg.output_path).write_text(md_report, encoding="utf-8") - logger.success(f"Audit report written to {cfg.output_path}") - else: - click.echo(md_report) - else: - from ..security.audit_report import ( - findings_to_json, - findings_to_sarif, - serialize_report, - write_report, - ) - - if effective_format == "sarif": - report = findings_to_sarif(findings_by_file, files_scanned=files_scanned) - else: - report = findings_to_json( - 
findings_by_file, - files_scanned=files_scanned, - exit_code=exit_code, - ) - - if cfg.output_path: - write_report(report, Path(cfg.output_path)) - logger.success(f"Audit report written to {cfg.output_path}") - else: - click.echo(serialize_report(report)) - - # -- Exit code -- - sys.exit(exit_code) - - # -- Command -------------------------------------------------------- diff --git a/src/apm_cli/commands/compile/_run_ops.py b/src/apm_cli/commands/compile/_run_ops.py new file mode 100644 index 000000000..dc1e1f391 --- /dev/null +++ b/src/apm_cli/commands/compile/_run_ops.py @@ -0,0 +1,417 @@ +"""Main compilation flow helpers extracted from ``compile/cli.py``. + +Extracted to keep that module under 800 lines. Contains: +- ``CompilationRunConfig`` -- frozen dataclass grouping compilation options +- ``_run_compilation`` -- main compilation flow (resolves target, compiles, + reports results) + +Rule B (monkeypatch safety): any name that tests patch on the *original* +``cli`` module (``AgentsCompiler``, ``CompilationConfig``, +``_resolve_effective_target``, ``_rich_info``) is loaded via a +function-level late import so patches on ``apm_cli.commands.compile.cli.*`` +still apply. +""" + +from __future__ import annotations + +import dataclasses +import sys +from pathlib import Path + +from ...constants import AGENTS_MD_FILENAME +from ...utils import perf_stats +from ...utils.console import _rich_panel +from .._helpers import ( + _check_orphaned_packages, + _rich_blank_line, +) + +# --------------------------------------------------------------------------- +# Parameter object +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass(frozen=True) +class CompilationRunConfig: + """Compilation options passed to ``_run_compilation``. + + Groups the nine compilation-specific CLI flags so ``_run_compilation`` + only takes five regular arguments instead of thirteen, satisfying + PLR0913 without hiding any parameters from callers. 
+ """ + + target: object # str | list[str] | None + output: str + no_links: bool + chatmode: str | None + with_constitution: bool + single_agents: bool + local_only: bool + clean: bool + no_dedup: bool + + +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- + + +def _coerce_provenance_targets(value): + """Coerce a target value to a list of target-name strings.""" + if value is None: + return [] + if isinstance(value, str): + return [t.strip() for t in value.split(",") if t.strip()] + if isinstance(value, list): + return [str(t) for t in value] + if isinstance(value, frozenset): + return sorted(value) + return [] + + +def _build_compile_provenance(target, config_target, effective_target, detection_reason): + """Return ``(provenance_targets, provenance_source)`` for the info line.""" + if detection_reason == "explicit --target flag": + return _coerce_provenance_targets(target), "--target flag" + if detection_reason == "apm.yml target": + return _coerce_provenance_targets(config_target), "apm.yml" + if isinstance(effective_target, frozenset): + return sorted(effective_target), f"auto-detect ({detection_reason})" + if isinstance(effective_target, str): + return [effective_target], f"auto-detect ({detection_reason})" + return [], f"auto-detect ({detection_reason})" + + +def _show_compile_strategy_progress(logger, run_config, config, effective_target, detection_reason): + """Emit target-aware progress messages before compilation starts.""" + from ...core.target_detection import ( + REASON_NO_TARGET_FOLDER, + get_target_description, + should_compile_agents_md, + should_compile_claude_md, + should_compile_gemini_md, + ) + + if config.strategy == "distributed" and not run_config.single_agents: + if isinstance(effective_target, frozenset): + if isinstance(run_config.target, list): + _target_label = f"--target {','.join(run_config.target)}" + elif 
isinstance(run_config.target, list) or ( + # config_target is not in scope here; re-derive from run_config + False + ): + _target_label = "multi-target" + else: + _target_label = "multi-target" + + _parts = [] + if should_compile_agents_md(effective_target): + _parts.append("AGENTS.md") + if should_compile_claude_md(effective_target): + _parts.append("CLAUDE.md") + if should_compile_gemini_md(effective_target): + _parts.append("GEMINI.md") + logger.progress(f"Compiling for {' + '.join(_parts)} ({_target_label})") + elif ( + isinstance(effective_target, str) + and effective_target == "vscode" + and detection_reason == REASON_NO_TARGET_FOLDER + ): + logger.progress(f"Compiling for AGENTS.md only ({detection_reason})") + logger.progress( + " Create .github/, .claude/, .codex/, .opencode/ or .cursor/ folder" + " for full integration", + symbol="light_bulb", + ) + else: + description = get_target_description(effective_target) + logger.progress(f"Compiling for {description} - {detection_reason}") + + if run_config.dry_run if hasattr(run_config, "dry_run") else False: + logger.dry_run_notice("showing placement without writing files") + else: + logger.progress("Using single-file compilation (legacy mode)", symbol="page") + + +def _check_and_write_output(logger, output_path, final_content): + """Security-scan and write the final compiled content. + + Returns ``True`` if critical security findings were detected. 
+ """ + from ...security.gate import WARN_POLICY, SecurityGate + + has_critical = False + verdict = SecurityGate.scan_text(final_content, str(output_path), policy=WARN_POLICY) + if verdict.has_findings: + actionable = verdict.critical_count + verdict.warning_count + if verdict.has_critical: + has_critical = True + if actionable: + logger.warning( + f"Compiled output contains {actionable} hidden character(s) " + f"-- run 'apm audit --file {output_path}' to inspect" + ) + try: + from ...compilation.output_writer import CompiledOutputWriter + + CompiledOutputWriter().write(output_path, final_content) + except OSError as e: + logger.error(f"Failed to write final AGENTS.md: {e}") + sys.exit(1) + return has_critical + + +def _handle_single_file_success(logger, compiler, config, dry_run, output_str): + """Handle the single-file compilation success path. + + Returns ``True`` if critical security findings were detected. + """ + from apm_cli.commands.compile import cli as _c + + has_critical = False + + intermediate_config = dataclasses.replace(config, dry_run=True, strategy="single-file") + intermediate_result = compiler.compile(intermediate_config) + + if not intermediate_result.success: + return has_critical + + from ...compilation.injector import ConstitutionInjector + + injector = ConstitutionInjector(base_dir=".") + output_path = Path(config.output_path) + final_content, c_status, c_hash = injector.inject( + intermediate_result.content, + with_constitution=config.with_constitution, + output_path=output_path, + ) + + if not dry_run: + if c_status in ("CREATED", "UPDATED", "MISSING"): + has_critical = _check_and_write_output(logger, output_path, final_content) + else: + logger.progress("No changes detected; preserving existing AGENTS.md for idempotency") + + if dry_run: + logger.success( + "Context compilation completed successfully (dry run)", + symbol="check", + ) + else: + logger.success(f"Context compiled successfully to {output_path}") + + stats = 
intermediate_result.stats + _rich_blank_line() + _c._display_single_file_summary(stats, c_status, c_hash, output_path, dry_run) + + if dry_run: + preview = final_content[:500] + ("..." if len(final_content) > 500 else "") + _rich_panel(preview, title=" Generated Content Preview", style="cyan") + else: + _c._display_next_steps(output_str) + + return has_critical + + +def _handle_distributed_success(logger, result, dry_run): + """Handle the distributed compilation success path. + + Returns ``True`` if critical security findings were detected. + """ + has_critical = result.has_critical_security + + if dry_run: + return has_critical + + _files_written = sum( + int(v or 0) + for k, v in result.stats.items() + if k.endswith(("_files_written", "_files_generated")) + ) + if _files_written > 0: + logger.success("Compilation completed successfully!", symbol="check") + else: + logger.warning( + "Compilation completed but produced no output " + "files. Check that target directories exist " + "(e.g. .github/, .claude/) or set 'target:' " + "in apm.yml / pass --target explicitly." + ) + return has_critical + + +# --------------------------------------------------------------------------- +# Main compilation flow +# --------------------------------------------------------------------------- + + +def _run_compilation( + logger, + dry_run: bool, + verbose: bool, + source_root: Path | None, + run_config: CompilationRunConfig, +) -> None: + """Main compilation flow: target resolution, config, compile, and output. + + Handles both distributed (default) and single-file (``--single-agents``) + strategies, emits the canonical target-provenance line, runs the + compiler, reports results, and hard-fails on critical security findings. + """ + # Late imports for names patched by tests on the original cli module. 
+ from apm_cli.commands.compile import cli as _c + + AgentsCompiler = _c.AgentsCompiler + CompilationConfig = _c.CompilationConfig + _resolve_effective_target = _c._resolve_effective_target + _rich_info = _c._rich_info + + from ...core.target_detection import ResolvedTargets, format_provenance + from ...primitives.discovery import clear_discovery_cache + + logger.start("Starting context compilation...", symbol="cogs") + + _src = source_root or Path(".") + + effective_target, detection_reason, config_target = _resolve_effective_target( + run_config.target, source_root=_src + ) + + # Emit canonical provenance line. + _provenance_targets, _provenance_source = _build_compile_provenance( + run_config.target, config_target, effective_target, detection_reason + ) + if _provenance_targets: + _rich_info( + format_provenance( + ResolvedTargets( + targets=sorted(set(_provenance_targets)), + source=_provenance_source, + auto_create=True, + ) + ), + symbol="info", + ) + + # Build compilation config. + config = CompilationConfig.from_apm_yml( + output_path=(run_config.output if run_config.output != AGENTS_MD_FILENAME else None), + chatmode=run_config.chatmode, + resolve_links=not run_config.no_links if run_config.no_links else None, + dry_run=dry_run, + single_agents=run_config.single_agents, + trace=verbose, + local_only=run_config.local_only, + debug=verbose, + clean_orphaned=run_config.clean, + target=effective_target, + no_dedup=run_config.no_dedup, + ) + config.with_constitution = run_config.with_constitution + + # Show target-aware progress for the chosen strategy. 
+ if config.strategy == "distributed" and not run_config.single_agents: + if isinstance(effective_target, frozenset): + from ...core.target_detection import ( + should_compile_agents_md, + should_compile_claude_md, + should_compile_gemini_md, + ) + + if isinstance(run_config.target, list): + _target_label = f"--target {','.join(run_config.target)}" + elif isinstance(config_target, list): + _target_label = f"apm.yml target: [{', '.join(config_target)}]" + else: + _target_label = "multi-target" + + _parts = [] + if should_compile_agents_md(effective_target): + _parts.append("AGENTS.md") + if should_compile_claude_md(effective_target): + _parts.append("CLAUDE.md") + if should_compile_gemini_md(effective_target): + _parts.append("GEMINI.md") + logger.progress(f"Compiling for {' + '.join(_parts)} ({_target_label})") + elif isinstance(effective_target, str) and effective_target == "vscode": + from ...core.target_detection import REASON_NO_TARGET_FOLDER + + if detection_reason == REASON_NO_TARGET_FOLDER: + logger.progress(f"Compiling for AGENTS.md only ({detection_reason})") + logger.progress( + " Create .github/, .claude/, .codex/, .opencode/ or .cursor/ folder" + " for full integration", + symbol="light_bulb", + ) + else: + from ...core.target_detection import get_target_description + + description = get_target_description(effective_target) + logger.progress(f"Compiling for {description} - {detection_reason}") + else: + from ...core.target_detection import get_target_description + + description = get_target_description(effective_target) + logger.progress(f"Compiling for {description} - {detection_reason}") + + if dry_run: + logger.dry_run_notice("showing placement without writing files") + if verbose: + logger.verbose_detail("Verbose mode: showing source attribution and optimizer analysis") + else: + logger.progress("Using single-file compilation (legacy mode)", symbol="page") + + # Perform compilation. 
+ clear_discovery_cache() + perf_stats.reset() + compiler = AgentsCompiler(".", source_dir=str(_src)) + result = compiler.compile(config, logger=logger) + compile_has_critical = result.has_critical_security + + if result.success: + if config.strategy == "distributed" and not run_config.single_agents: + compile_has_critical = _handle_distributed_success(logger, result, dry_run) + else: + single_critical = _handle_single_file_success( + logger, compiler, config, dry_run, run_config.output + ) + if single_critical: + compile_has_critical = True + + # Display warnings and errors for all modes. + if result.warnings: + logger.warning(f"Compilation completed with {len(result.warnings)} warning(s):") + for warning in result.warnings: + logger.warning(f" {warning}") + + if result.errors: + logger.error(f"Compilation failed with {len(result.errors)} errors:") + for error in result.errors: + logger.error(f" {error}") + sys.exit(1) + + # Check for orphaned packages after successful compilation. + try: + orphaned_packages = _check_orphaned_packages() + if orphaned_packages: + _rich_blank_line() + logger.warning( + f"Found {len(orphaned_packages)} orphaned package(s) that were " + "included in compilation:" + ) + for pkg in orphaned_packages: + logger.progress(f" * {pkg}") + logger.progress(" Run 'apm prune' to remove orphaned packages") + except Exception: + pass # Continue if orphan check fails + + # Hard-fail on critical security findings. 
+ if compile_has_critical: + logger.error( + "Compiled output contains critical hidden characters" + " -- run 'apm audit' to inspect, 'apm audit --strip' to clean" + ) + perf_stats.render_summary(logger, project_root=str(_src)) + sys.exit(1) + + perf_stats.render_summary(logger, project_root=str(_src)) diff --git a/src/apm_cli/commands/compile/cli.py b/src/apm_cli/commands/compile/cli.py index 264dfaea4..9357c75e7 100644 --- a/src/apm_cli/commands/compile/cli.py +++ b/src/apm_cli/commands/compile/cli.py @@ -2,7 +2,6 @@ from __future__ import annotations -import dataclasses import sys from pathlib import Path from typing import TYPE_CHECKING @@ -12,7 +11,10 @@ if TYPE_CHECKING: from ...core.target_detection import CompileTargetType -from ...compilation import AgentsCompiler, CompilationConfig +from ...compilation import ( + AgentsCompiler, + CompilationConfig, # noqa: F401 -- patched by tests +) from ...constants import AGENTS_MD_FILENAME, APM_DIR, APM_MODULES_DIR, APM_YML_FILENAME from ...core.command_logger import CommandLogger from ...core.target_detection import TargetParamType @@ -21,13 +23,12 @@ from ...utils.console import ( _rich_error, _rich_info, - _rich_panel, ) from .._helpers import ( - _check_orphaned_packages, _get_console, - _rich_blank_line, ) +from ._run_ops import CompilationRunConfig as CompilationRunConfig +from ._run_ops import _run_compilation as _run_compilation from .watcher import _watch_mode @@ -169,6 +170,46 @@ def _get_validation_suggestion(error_msg): return "Check primitive structure and frontmatter" +def _resolve_list_target(target_list, KNOWN_TARGETS): + """Resolve a list of targets to a compiler family string or frozenset.""" + target_set = set(target_list) + skip = {name for name, profile in KNOWN_TARGETS.items() if profile.compile_family is None} + target_set -= skip + if not target_set: + for sentinel in target_list: + if sentinel in skip: + return sentinel + return None + + def _family_of(name: str) -> str | None: + if name == 
"vscode": + return "vscode" + profile = KNOWN_TARGETS.get(name) + return profile.compile_family if profile else None + + families: set[str] = set() + for name in target_set: + family = _family_of(name) + if family is None: + continue + families.add(family) + if family == "vscode": + families.add("agents") + + if len(families) >= 2: + return "vscode" if families == {"vscode", "agents"} else frozenset(families) + if "claude" in families: + return "claude" + if "gemini" in families: + return "gemini" + if "vscode" in families: + return "vscode" + for name, profile in KNOWN_TARGETS.items(): + if profile.compile_family == "agents" and name in target_set: + return name + return "vscode" # defensive fallback (unreachable) + + def _resolve_compile_target(target): """Map CLI target input to a compiler-understood target. @@ -199,64 +240,7 @@ def _resolve_compile_target(target): if target is None: return None # will trigger detect_target() auto-detection if isinstance(target, list): - target_set = set(target) - # Strip targets with no compile output (compile_family is None); - # they would silently fall through the family resolution otherwise. - # ``vscode`` is a CLI alias for ``copilot`` and shares its profile. - skip = {name for name, profile in KNOWN_TARGETS.items() if profile.compile_family is None} - target_set -= skip - if not target_set: - # Solo agent-skills (or another no-compile target) in a list -- - # pass through as a string so the compiler's no-op path fires. - for sentinel in target: - if sentinel in skip: - return sentinel - return None - - # The "vscode" family handles copilot AND emits AGENTS.md as a - # bonus; the "agents" family emits AGENTS.md only. When both - # appear in a multi-target compile we still need both family - # tokens so the agents compiler routes correctly. 
- def _family_of(name: str) -> str | None: - if name == "vscode": - return "vscode" - profile = KNOWN_TARGETS.get(name) - return profile.compile_family if profile else None - - families: set[str] = set() - for name in target_set: - family = _family_of(name) - if family is None: - continue - families.add(family) - if family == "vscode": - # copilot also emits AGENTS.md; mirror legacy behavior. - families.add("agents") - - if len(families) >= 2: - # Single-target copilot collapses {"vscode","agents"} to bare - # "vscode" for routing parity with single-string -t copilot. - if families == {"vscode", "agents"}: - return "vscode" - return frozenset(families) - if "claude" in families: - return "claude" - if "gemini" in families: - return "gemini" - if "vscode" in families: - return "vscode" - # Bare agents-family target: preserve the original target name so - # single-element list routing matches single-string semantics - # (-t cursor and -t [cursor] both end up as "cursor"). Iterate - # KNOWN_TARGETS in insertion order so priority ties (e.g. - # ["opencode","codex"]) resolve deterministically to the - # earliest-registered target. Adding a new agents-family - # target (e.g. zed, cline) costs zero edits here -- it inherits - # whatever priority position it occupies in the registry. 
- for name, profile in KNOWN_TARGETS.items(): - if profile.compile_family == "agents" and name in target_set: - return name - return "vscode" # defensive fallback (unreachable) + return _resolve_list_target(target, KNOWN_TARGETS) return target # single string pass-through @@ -456,318 +440,6 @@ def _run_watch_mode( ) -def _run_compilation( - logger: CommandLogger, - target: str | list[str] | None, - output: str, - dry_run: bool, - no_links: bool, - chatmode: str | None, - with_constitution: bool, - single_agents: bool, - verbose: bool, - local_only: bool, - clean: bool, - no_dedup: bool, - source_root: Path | None = None, -) -> None: - """Main compilation flow: target resolution, config, compile, and output. - - Handles both distributed (default) and single-file (``--single-agents``) - strategies, emits the canonical target-provenance line, runs the - compiler, reports results, and hard-fails on critical security findings. - """ - from ...core.target_detection import ( - REASON_NO_TARGET_FOLDER, - ResolvedTargets, - format_provenance, - get_target_description, - ) - - logger.start("Starting context compilation...", symbol="cogs") - - _src = source_root or Path(".") - - # Resolve effective target using the shared helper (mirrors watch-mode path). - effective_target, detection_reason, config_target = _resolve_effective_target( - target, source_root=_src - ) - - # Emit canonical provenance line BEFORE compilation -- mirrors - # `apm install` so users see the same `[i] Targets: ... - # (source: ...)` line on both surfaces. Use the user-facing - # source values (target / config_target) NOT the compiler-family - # expansion in effective_target -- install shows the schema names - # the user wrote (e.g. "copilot"), so compile must too, otherwise - # parity drifts (compile would print "agents, vscode" for the - # same input). 
- def _coerce_provenance_targets(value): - if value is None: - return [] - if isinstance(value, str): - return [t.strip() for t in value.split(",") if t.strip()] - if isinstance(value, list): - return [str(t) for t in value] - if isinstance(value, frozenset): - return sorted(value) - return [] - - if detection_reason == "explicit --target flag": - _provenance_targets = _coerce_provenance_targets(target) - _provenance_source = "--target flag" - elif detection_reason == "apm.yml target": - _provenance_targets = _coerce_provenance_targets(config_target) - _provenance_source = "apm.yml" - else: - if isinstance(effective_target, frozenset): - _provenance_targets = sorted(effective_target) - elif isinstance(effective_target, str): - _provenance_targets = [effective_target] - else: - _provenance_targets = [] - _provenance_source = f"auto-detect ({detection_reason})" - - if _provenance_targets: - _rich_info( - format_provenance( - ResolvedTargets( - targets=sorted(set(_provenance_targets)), - source=_provenance_source, - auto_create=True, - ) - ), - symbol="info", - ) - - # Build config with distributed compilation flags (Task 7) - config = CompilationConfig.from_apm_yml( - output_path=output if output != AGENTS_MD_FILENAME else None, - chatmode=chatmode, - resolve_links=not no_links if no_links else None, - dry_run=dry_run, - single_agents=single_agents, - trace=verbose, - local_only=local_only, - debug=verbose, - clean_orphaned=clean, - target=effective_target, - no_dedup=no_dedup, - ) - config.with_constitution = with_constitution - - # Show target-aware progress message for the chosen strategy. - if config.strategy == "distributed" and not single_agents: - if isinstance(effective_target, frozenset): - # Multi-target compile (from CLI `--target a,b` OR apm.yml - # `target: [a, b]`): show what the compiler will produce. 
- if isinstance(target, list): - _target_label = f"--target {','.join(target)}" - elif isinstance(config_target, list): - _target_label = f"apm.yml target: [{', '.join(config_target)}]" - else: - _target_label = "multi-target" - from ...core.target_detection import ( - should_compile_agents_md, - should_compile_claude_md, - should_compile_gemini_md, - ) - - _parts = [] - if should_compile_agents_md(effective_target): - _parts.append("AGENTS.md") - if should_compile_claude_md(effective_target): - _parts.append("CLAUDE.md") - if should_compile_gemini_md(effective_target): - _parts.append("GEMINI.md") - logger.progress(f"Compiling for {' + '.join(_parts)} ({_target_label})") - elif ( - isinstance(effective_target, str) - and effective_target == "vscode" - and detection_reason == REASON_NO_TARGET_FOLDER - ): - logger.progress(f"Compiling for AGENTS.md only ({detection_reason})") - logger.progress( - " Create .github/, .claude/, .codex/, .opencode/ or .cursor/ folder for full integration", - symbol="light_bulb", - ) - else: - description = get_target_description(effective_target) - logger.progress(f"Compiling for {description} - {detection_reason}") - - if dry_run: - logger.dry_run_notice("showing placement without writing files") - if verbose: - logger.verbose_detail("Verbose mode: showing source attribution and optimizer analysis") - else: - logger.progress("Using single-file compilation (legacy mode)", symbol="page") - - # Perform compilation - clear_discovery_cache() - perf_stats.reset() - compiler = AgentsCompiler(".", source_dir=str(_src)) - result = compiler.compile(config, logger=logger) - compile_has_critical = result.has_critical_security - - if result.success: - # Handle different compilation modes - if config.strategy == "distributed" and not single_agents: - # Distributed compilation results - output already shown by professional formatter - # Just show final success message - if dry_run: - # Success message for dry run already included in formatter output 
- pass - else: - # Defense-in-depth (#820): don't claim "completed - # successfully" when zero files were emitted. With - # parse_target_field as the upstream gatekeeper this is - # unreachable in normal flow, but silent zero-effect - # success is the worst-case package-manager DX. - # - # Pattern-based stat scan (instead of a hardcoded key - # list) so new compile-time targets pick up the guard - # automatically: any stat ending in ``_files_written`` - # or ``_files_generated`` contributes to the total. - _files_written = sum( - int(v or 0) - for k, v in result.stats.items() - if k.endswith(("_files_written", "_files_generated")) - ) - if _files_written > 0: - logger.success( - "Compilation completed successfully!", - symbol="check", - ) - else: - # Zero-output compile is the silent-success failure - # mode #820 guards against. Don't claim success; - # surface what the user can act on. The cause is - # usually one of: target dirs not present (auto- - # detect found nothing), explicit target rejected - # by policy, or no primitives in the project. - logger.warning( - "Compilation completed but produced no output " - "files. Check that target directories exist " - "(e.g. .github/, .claude/) or set 'target:' " - "in apm.yml / pass --target explicitly." 
- ) - - else: - # Traditional single-file compilation - keep existing logic - # Perform initial compilation in dry-run to get generated body (without constitution) - intermediate_config = dataclasses.replace( - config, - dry_run=True, - strategy="single-file", - ) - intermediate_result = compiler.compile(intermediate_config) - - if intermediate_result.success: - # Perform constitution injection / preservation - from ...compilation.injector import ConstitutionInjector - - injector = ConstitutionInjector(base_dir=".") - output_path = Path(config.output_path) - final_content, c_status, c_hash = injector.inject( - intermediate_result.content, - with_constitution=config.with_constitution, - output_path=output_path, - ) - - if not dry_run: - # Only rewrite when content materially changes (creation, update, missing constitution case) - if c_status in ("CREATED", "UPDATED", "MISSING"): - # Defense-in-depth: scan compiled output before writing - from ...security.gate import WARN_POLICY, SecurityGate - - verdict = SecurityGate.scan_text( - final_content, str(output_path), policy=WARN_POLICY - ) - if verdict.has_findings: - actionable = verdict.critical_count + verdict.warning_count - if verdict.has_critical: - compile_has_critical = True - if actionable: - logger.warning( - f"Compiled output contains {actionable} hidden character(s) " - f"-- run 'apm audit --file {output_path}' to inspect" - ) - try: - from ...compilation.output_writer import CompiledOutputWriter - - CompiledOutputWriter().write(output_path, final_content) - except OSError as e: - logger.error(f"Failed to write final AGENTS.md: {e}") - sys.exit(1) - else: - logger.progress( - "No changes detected; preserving existing AGENTS.md for idempotency" - ) - - # Report success at the top - if dry_run: - logger.success( - "Context compilation completed successfully (dry run)", - symbol="check", - ) - else: - logger.success( - f"Context compiled successfully to {output_path}", - ) - - stats = ( - 
intermediate_result.stats - ) # timestamp removed; stats remain version + counts - - # Add spacing before summary table - _rich_blank_line() - - _display_single_file_summary(stats, c_status, c_hash, output_path, dry_run) - - if dry_run: - preview = final_content[:500] + ("..." if len(final_content) > 500 else "") - _rich_panel(preview, title=" Generated Content Preview", style="cyan") - else: - _display_next_steps(output) - - # Display warnings for all compilation modes - if result.warnings: - logger.warning(f"Compilation completed with {len(result.warnings)} warning(s):") - for warning in result.warnings: - logger.warning(f" {warning}") - - if result.errors: - logger.error(f"Compilation failed with {len(result.errors)} errors:") - for error in result.errors: - logger.error(f" {error}") - sys.exit(1) - - # Check for orphaned packages after successful compilation - try: - orphaned_packages = _check_orphaned_packages() - if orphaned_packages: - _rich_blank_line() - logger.warning( - f"Found {len(orphaned_packages)} orphaned package(s) that were included in compilation:" - ) - for pkg in orphaned_packages: - logger.progress(f" * {pkg}") - logger.progress(" Run 'apm prune' to remove orphaned packages") - except Exception: - pass # Continue if orphan check fails - - # Hard-fail when critical security findings were detected in compiled - # output. Consistent with apm install and apm unpack behavior. 
- if compile_has_critical: - logger.error( - "Compiled output contains critical hidden characters" - " -- run 'apm audit' to inspect, 'apm audit --strip' to clean" - ) - perf_stats.render_summary(logger, project_root=str(_src)) - sys.exit(1) - - perf_stats.render_summary(logger, project_root=str(_src)) - - @click.command(help="Compile APM context into distributed AGENTS.md files") @click.option( "--output", @@ -973,21 +645,18 @@ def compile( # noqa: PLR0913 -- Click handler ) return - _run_compilation( - logger, - target, - output, - dry_run, - no_links, - chatmode, - with_constitution, - single_agents, - verbose, - local_only, - clean, - no_dedup, - source_root=source_root, + run_config = CompilationRunConfig( + target=target, + output=output, + no_links=no_links, + chatmode=chatmode, + with_constitution=with_constitution, + single_agents=single_agents, + local_only=local_only, + clean=clean, + no_dedup=no_dedup, ) + _run_compilation(logger, dry_run, verbose, source_root, run_config) except ImportError as e: logger.error(f"Compilation module not available: {e}") diff --git a/src/apm_cli/commands/deps/_cli_ops.py b/src/apm_cli/commands/deps/_cli_ops.py new file mode 100644 index 000000000..bac5fa6e5 --- /dev/null +++ b/src/apm_cli/commands/deps/_cli_ops.py @@ -0,0 +1,273 @@ +"""Heavy-lifting for ``apm deps`` extracted to keep cli.py under 800 lines. + +Patched globals on ``apm_cli.commands.deps.cli`` (APMPackage, Path, +_resolve_scope_deps) are accessed via a function-level late import so that +test monkey-patches on ``apm_cli.commands.deps.cli.*`` take effect normally. + +No module-level import of ``cli`` here to avoid circular imports. 
+""" + +import sys + +import click + +from ...constants import APM_YML_FILENAME + +# --------------------------------------------------------------------------- +# _show_scope_deps +# --------------------------------------------------------------------------- + + +def _show_scope_deps(scope_label, apm_dir, logger, console, has_rich, insecure_only=False): + """Display dependencies for a single scope (Project or Global).""" + from apm_cli.commands.deps import cli as _cli # route patched _resolve_scope_deps + + installed_packages, orphaned_packages = _cli._resolve_scope_deps(apm_dir, logger, insecure_only) + + if installed_packages is None: + logger.progress(f"No APM dependencies installed ({scope_label} scope)") + logger.verbose_detail("Run 'apm install' to install dependencies from apm.yml") + return + + if not installed_packages: + if insecure_only: + logger.progress(f"No insecure APM dependencies installed ({scope_label} scope)") + else: + logger.progress( + f"apm_modules/ directory exists but contains no valid packages ({scope_label} scope)" + ) + return + + if has_rich: + from rich.table import Table + + table = Table( + title=( + f" Insecure APM Dependencies ({scope_label})" + if insecure_only + else f" APM Dependencies ({scope_label})" + ), + show_header=True, + header_style="bold cyan", + ) + table.add_column("Package", style="bold white") + table.add_column("Version", style="yellow") + table.add_column("Source", style="blue") + if insecure_only: + table.add_column("Origin", style="bold red") + table.add_column("Prompts", style="magenta", justify="center") + table.add_column("Instructions", style="green", justify="center") + table.add_column("Agents", style="cyan", justify="center") + table.add_column("Skills", style="yellow", justify="center") + table.add_column("Hooks", style="red", justify="center") + + for pkg in installed_packages: + p = pkg["primitives"] + table.add_row( + pkg["name"], + pkg["version"], + pkg["source"], + *([pkg["insecure_via"]] if 
insecure_only else []), + str(p.get("prompts", 0)) if p.get("prompts", 0) > 0 else "-", + str(p.get("instructions", 0)) if p.get("instructions", 0) > 0 else "-", + str(p.get("agents", 0)) if p.get("agents", 0) > 0 else "-", + str(p.get("skills", 0)) if p.get("skills", 0) > 0 else "-", + str(p.get("hooks", 0)) if p.get("hooks", 0) > 0 else "-", + ) + + console.print(table) + + if orphaned_packages: + logger.warning(f"{len(orphaned_packages)} orphaned package(s) found (not in apm.yml):") + for pkg in orphaned_packages: + logger.warning(f" - {pkg}") + logger.info("Run 'apm prune' to remove orphaned packages") + else: + # Fallback text table + if insecure_only: + click.echo(f" Insecure APM Dependencies ({scope_label}):") + click.echo( + f"{'Package':<30} {'Version':<10} {'Source':<12} {'Origin':<18} " + f"{'Prompts':>7} {'Instr':>7} {'Agents':>7} {'Skills':>7} {'Hooks':>7}" + ) + click.echo("-" * 117) + else: + click.echo(f" APM Dependencies ({scope_label}):") + click.echo( + f"{'Package':<30} {'Version':<10} {'Source':<12} {'Prompts':>7} {'Instr':>7} {'Agents':>7} {'Skills':>7} {'Hooks':>7}" + ) + click.echo("-" * 98) + + for pkg in installed_packages: + p = pkg["primitives"] + name = pkg["name"][:28] + version = pkg["version"][:8] + source = pkg["source"][:10] + insecure_via = pkg["insecure_via"][:16] + prompts = str(p.get("prompts", 0)) if p.get("prompts", 0) > 0 else "-" + instructions = str(p.get("instructions", 0)) if p.get("instructions", 0) > 0 else "-" + agents = str(p.get("agents", 0)) if p.get("agents", 0) > 0 else "-" + skills = str(p.get("skills", 0)) if p.get("skills", 0) > 0 else "-" + hooks = str(p.get("hooks", 0)) if p.get("hooks", 0) > 0 else "-" + if insecure_only: + click.echo( + f"{name:<30} {version:<10} {source:<12} {insecure_via:<18} " + f"{prompts:>7} {instructions:>7} {agents:>7} {skills:>7} {hooks:>7}" + ) + else: + click.echo( + f"{name:<30} {version:<10} {source:<12} {prompts:>7} {instructions:>7} {agents:>7} {skills:>7} {hooks:>7}" + ) + + 
if orphaned_packages: + logger.warning(f"{len(orphaned_packages)} orphaned package(s) found (not in apm.yml):") + for pkg in orphaned_packages: + logger.warning(f" - {pkg}") + logger.info("Run 'apm prune' to remove orphaned packages") + + +# --------------------------------------------------------------------------- +# _update_impl (body of the ``deps update`` Click command) +# --------------------------------------------------------------------------- + + +def _update_impl(packages, verbose, force, target, parallel_downloads, global_, legacy_skill_paths): + """Implementation of ``apm deps update``. + + Kept in a separate function so the Click-decorated ``update`` wrapper in + cli.py stays thin and patchable. Patched names (APMPackage, Path) are + accessed through the original cli module at call-time. + """ + from apm_cli.commands.deps import cli as _cli # route patched APMPackage / Path + + from ...core.auth import AuthResolver + from ...core.command_logger import InstallLogger + from ...utils.console import _rich_warning + from ..install import ( + _APM_IMPORT_ERROR, + APM_DEPS_AVAILABLE, + _install_apm_dependencies, + ) + + _rich_warning( + "'apm deps update' is deprecated; use 'apm update' instead. 
" + "'apm update' now supports -g/--global, [PACKAGES]..., --force, and " + "--parallel-downloads, plus an interactive plan, --dry-run, and --yes.", + symbol="warning", + ) + + logger = InstallLogger(verbose=verbose, partial=bool(packages)) + + if not APM_DEPS_AVAILABLE: + logger.error("APM dependency system not available") + if _APM_IMPORT_ERROR: + logger.progress(f"Import error: {_APM_IMPORT_ERROR}") + sys.exit(1) + + from ...core.scope import InstallScope, get_apm_dir + + scope = InstallScope.USER if global_ else InstallScope.PROJECT + project_root = get_apm_dir(scope) + apm_yml_path = project_root / APM_YML_FILENAME + + if not apm_yml_path.exists(): + scope_hint = "~/.apm/" if global_ else "current directory" + logger.error(f"No {APM_YML_FILENAME} found in {scope_hint}") + sys.exit(1) + + try: + apm_package = _cli.APMPackage.from_apm_yml(apm_yml_path) + except Exception as e: + logger.error(f"Failed to parse {APM_YML_FILENAME}: {e}") + sys.exit(1) + + all_deps = apm_package.get_apm_dependencies() + apm_package.get_dev_apm_dependencies() + if not all_deps: + logger.progress("No APM dependencies defined in apm.yml") + return + + from .._helpers import UnknownPackageError, resolve_requested_packages + + try: + only_pkgs = resolve_requested_packages(packages, all_deps) + except UnknownPackageError as e: + logger.error(f"Package '{e.token}' not found in {APM_YML_FILENAME}") + logger.progress(f"Available: {', '.join(e.available)}") + sys.exit(1) + + from ...deps.lockfile import LockFile, get_lockfile_path, migrate_lockfile_if_needed + + lockfile_path = get_lockfile_path(project_root) + migrate_lockfile_if_needed(project_root) + + old_lockfile = LockFile.read(lockfile_path) + had_baseline = old_lockfile is not None + old_shas: dict = {} + if old_lockfile: + for key, dep in old_lockfile.dependencies.items(): + old_shas[key] = dep.resolved_commit + + auth_resolver = AuthResolver() + + noun = f"{len(packages)} package(s)" if packages else f"all {len(all_deps)} 
dependencies" + if not legacy_skill_paths: + from ...integration.targets import should_use_legacy_skill_paths + + legacy_skill_paths = should_use_legacy_skill_paths() + + logger.start(f"Updating {noun}...") + + try: + install_result = _install_apm_dependencies( + apm_package, + update_refs=True, + verbose=verbose, + only_packages=only_pkgs, + force=force, + parallel_downloads=parallel_downloads, + logger=logger, + auth_resolver=auth_resolver, + target=target, + scope=scope, + legacy_skill_paths=legacy_skill_paths, + ) + except Exception as e: + logger.error(f"Update failed: {e}") + if not verbose: + logger.progress("Run with --verbose for detailed diagnostics") + sys.exit(1) + + if install_result.diagnostics and install_result.diagnostics.has_diagnostics: + install_result.diagnostics.render_summary() + + new_lockfile = LockFile.read(lockfile_path) + changed: list = [] + if new_lockfile: + for key, dep in new_lockfile.dependencies.items(): + old_sha = old_shas.get(key) + new_sha = dep.resolved_commit + if old_sha and new_sha and old_sha != new_sha: + changed.append((key, old_sha[:8], new_sha[:8], dep.resolved_ref or "")) + + error_count = 0 + if install_result.diagnostics: + try: + error_count = int(install_result.diagnostics.error_count) + except (TypeError, ValueError): + error_count = 0 + + if changed: + pkg_noun = "package" if len(changed) == 1 else "packages" + if error_count > 0: + logger.warning(f"Updated {len(changed)} {pkg_noun} with {error_count} error(s).") + else: + logger.success(f"Updated {len(changed)} {pkg_noun}:") + for key, old_sha, new_sha, ref in changed: + ref_str = f" ({ref})" if ref else "" + click.echo(f" {key}{ref_str}: {old_sha} -> {new_sha}") + elif error_count > 0: + logger.error(f"Update failed with {error_count} error(s).") + elif not had_baseline: + logger.success("Update complete.") + else: + logger.success("All packages already at latest refs.") diff --git a/src/apm_cli/commands/deps/cli.py b/src/apm_cli/commands/deps/cli.py index 
34d702375..836b81ddd 100644 --- a/src/apm_cli/commands/deps/cli.py +++ b/src/apm_cli/commands/deps/cli.py @@ -12,11 +12,10 @@ from ...core.target_detection import TargetParamType from ...models.apm_package import APMPackage from .._helpers import ( - UnknownPackageError, _expand_with_ancestors, _standalone_installed_packages, - resolve_requested_packages, ) +from ._cli_ops import _show_scope_deps, _update_impl from ._utils import ( _count_primitives, _get_package_display_info, @@ -267,119 +266,6 @@ def deps(): deps.add_command(_why_cmd) -def _show_scope_deps(scope_label, apm_dir, logger, console, has_rich, insecure_only=False): - """Display dependencies for a single scope (Project or Global).""" - installed_packages, orphaned_packages = _resolve_scope_deps(apm_dir, logger, insecure_only) - - if installed_packages is None: - logger.progress(f"No APM dependencies installed ({scope_label} scope)") - logger.verbose_detail("Run 'apm install' to install dependencies from apm.yml") - return - - if not installed_packages: - if insecure_only: - logger.progress(f"No insecure APM dependencies installed ({scope_label} scope)") - else: - logger.progress( - f"apm_modules/ directory exists but contains no valid packages ({scope_label} scope)" - ) - return - - # Display packages in table format - if has_rich: - from rich.table import Table - - table = Table( - title=( - f" Insecure APM Dependencies ({scope_label})" - if insecure_only - else f" APM Dependencies ({scope_label})" - ), - show_header=True, - header_style="bold cyan", - ) - table.add_column("Package", style="bold white") - table.add_column("Version", style="yellow") - table.add_column("Source", style="blue") - if insecure_only: - table.add_column("Origin", style="bold red") - table.add_column("Prompts", style="magenta", justify="center") - table.add_column("Instructions", style="green", justify="center") - table.add_column("Agents", style="cyan", justify="center") - table.add_column("Skills", style="yellow", 
justify="center") - table.add_column("Hooks", style="red", justify="center") - - for pkg in installed_packages: - p = pkg["primitives"] - table.add_row( - pkg["name"], - pkg["version"], - pkg["source"], - *([pkg["insecure_via"]] if insecure_only else []), - str(p.get("prompts", 0)) if p.get("prompts", 0) > 0 else "-", - str(p.get("instructions", 0)) if p.get("instructions", 0) > 0 else "-", - str(p.get("agents", 0)) if p.get("agents", 0) > 0 else "-", - str(p.get("skills", 0)) if p.get("skills", 0) > 0 else "-", - str(p.get("hooks", 0)) if p.get("hooks", 0) > 0 else "-", - ) - - console.print(table) - - # Show orphaned packages warning -- routed through CommandLogger - # so output goes through the central STATUS_SYMBOLS prefix path - # (no raw `[!]` literal that Rich would parse as markup) and so - # behaviour is consistent with prune.py. - if orphaned_packages: - logger.warning(f"{len(orphaned_packages)} orphaned package(s) found (not in apm.yml):") - for pkg in orphaned_packages: - logger.warning(f" - {pkg}") - logger.info("Run 'apm prune' to remove orphaned packages") - else: - # Fallback text table - if insecure_only: - click.echo(f" Insecure APM Dependencies ({scope_label}):") - click.echo( - f"{'Package':<30} {'Version':<10} {'Source':<12} {'Origin':<18} " - f"{'Prompts':>7} {'Instr':>7} {'Agents':>7} {'Skills':>7} {'Hooks':>7}" - ) - click.echo("-" * 117) - else: - click.echo(f" APM Dependencies ({scope_label}):") - click.echo( - f"{'Package':<30} {'Version':<10} {'Source':<12} {'Prompts':>7} {'Instr':>7} {'Agents':>7} {'Skills':>7} {'Hooks':>7}" - ) - click.echo("-" * 98) - - for pkg in installed_packages: - p = pkg["primitives"] - name = pkg["name"][:28] - version = pkg["version"][:8] - source = pkg["source"][:10] - insecure_via = pkg["insecure_via"][:16] - prompts = str(p.get("prompts", 0)) if p.get("prompts", 0) > 0 else "-" - instructions = str(p.get("instructions", 0)) if p.get("instructions", 0) > 0 else "-" - agents = str(p.get("agents", 0)) if 
p.get("agents", 0) > 0 else "-" - skills = str(p.get("skills", 0)) if p.get("skills", 0) > 0 else "-" - hooks = str(p.get("hooks", 0)) if p.get("hooks", 0) > 0 else "-" - if insecure_only: - click.echo( - f"{name:<30} {version:<10} {source:<12} {insecure_via:<18} " - f"{prompts:>7} {instructions:>7} {agents:>7} {skills:>7} {hooks:>7}" - ) - else: - click.echo( - f"{name:<30} {version:<10} {source:<12} {prompts:>7} {instructions:>7} {agents:>7} {skills:>7} {hooks:>7}" - ) - - # Show orphaned packages warning -- route through CommandLogger - # for consistency with the rich branch above and with prune.py. - if orphaned_packages: - logger.warning(f"{len(orphaned_packages)} orphaned package(s) found (not in apm.yml):") - for pkg in orphaned_packages: - logger.warning(f" - {pkg}") - logger.info("Run 'apm prune' to remove orphaned packages") - - @deps.command(name="list", help="List installed APM dependencies") @click.option( "--global", @@ -764,145 +650,7 @@ def update(packages, verbose, force, target, parallel_downloads, global_, legacy apm deps update org/a org/b # Update specific packages apm deps update --verbose # Show detailed progress """ - from ...core.auth import AuthResolver - from ...core.command_logger import InstallLogger - from ...utils.console import _rich_warning - from ..install import ( - _APM_IMPORT_ERROR, - APM_DEPS_AVAILABLE, - _install_apm_dependencies, - ) - - # Soft-deprecation (issue #1525): `apm update` is now a strict superset - # of this command. Kept working for one release; removed in the next - # breaking release. - _rich_warning( - "'apm deps update' is deprecated; use 'apm update' instead. 
" - "'apm update' now supports -g/--global, [PACKAGES]..., --force, and " - "--parallel-downloads, plus an interactive plan, --dry-run, and --yes.", - symbol="warning", - ) - - logger = InstallLogger(verbose=verbose, partial=bool(packages)) - - if not APM_DEPS_AVAILABLE: - logger.error("APM dependency system not available") - if _APM_IMPORT_ERROR: - logger.progress(f"Import error: {_APM_IMPORT_ERROR}") - sys.exit(1) - - from ...core.scope import InstallScope, get_apm_dir - - scope = InstallScope.USER if global_ else InstallScope.PROJECT - project_root = get_apm_dir(scope) - apm_yml_path = project_root / APM_YML_FILENAME - - if not apm_yml_path.exists(): - scope_hint = "~/.apm/" if global_ else "current directory" - logger.error(f"No {APM_YML_FILENAME} found in {scope_hint}") - sys.exit(1) - - try: - apm_package = APMPackage.from_apm_yml(apm_yml_path) - except Exception as e: - logger.error(f"Failed to parse {APM_YML_FILENAME}: {e}") - sys.exit(1) - - all_deps = apm_package.get_apm_dependencies() + apm_package.get_dev_apm_dependencies() - if not all_deps: - logger.progress("No APM dependencies defined in apm.yml") - return - - # Validate and normalize requested packages to canonical dependency keys. - # Shared with `apm update` (see commands/_helpers.py) so the two update - # surfaces resolve short names identically. 
- try: - only_pkgs = resolve_requested_packages(packages, all_deps) - except UnknownPackageError as e: - logger.error(f"Package '{e.token}' not found in {APM_YML_FILENAME}") - logger.progress(f"Available: {', '.join(e.available)}") - sys.exit(1) - - # Migrate legacy lockfile first, then snapshot SHAs for before/after diff - from ...deps.lockfile import LockFile, get_lockfile_path, migrate_lockfile_if_needed - - lockfile_path = get_lockfile_path(project_root) - migrate_lockfile_if_needed(project_root) - - old_lockfile = LockFile.read(lockfile_path) - had_baseline = old_lockfile is not None - old_shas: dict = {} - if old_lockfile: - for key, dep in old_lockfile.dependencies.items(): - old_shas[key] = dep.resolved_commit - - auth_resolver = AuthResolver() - - noun = f"{len(packages)} package(s)" if packages else f"all {len(all_deps)} dependencies" - # Resolve --legacy-skill-paths: CLI flag wins, then env var fallback. - if not legacy_skill_paths: - from ...integration.targets import should_use_legacy_skill_paths - - legacy_skill_paths = should_use_legacy_skill_paths() - - logger.start(f"Updating {noun}...") - - try: - install_result = _install_apm_dependencies( - apm_package, - update_refs=True, - verbose=verbose, - only_packages=only_pkgs, - force=force, - parallel_downloads=parallel_downloads, - logger=logger, - auth_resolver=auth_resolver, - target=target, - scope=scope, - legacy_skill_paths=legacy_skill_paths, - ) - except Exception as e: - logger.error(f"Update failed: {e}") - if not verbose: - logger.progress("Run with --verbose for detailed diagnostics") - sys.exit(1) - - # Show diagnostics if any - if install_result.diagnostics and install_result.diagnostics.has_diagnostics: - install_result.diagnostics.render_summary() - - # Compare old vs new lockfile SHAs to show what changed - new_lockfile = LockFile.read(lockfile_path) - changed: list = [] - if new_lockfile: - for key, dep in new_lockfile.dependencies.items(): - old_sha = old_shas.get(key) - new_sha = 
dep.resolved_commit - if old_sha and new_sha and old_sha != new_sha: - changed.append((key, old_sha[:8], new_sha[:8], dep.resolved_ref or "")) - - error_count = 0 - if install_result.diagnostics: - try: - error_count = int(install_result.diagnostics.error_count) - except (TypeError, ValueError): - error_count = 0 - - if changed: - pkg_noun = "package" if len(changed) == 1 else "packages" - if error_count > 0: - logger.warning(f"Updated {len(changed)} {pkg_noun} with {error_count} error(s).") - else: - logger.success(f"Updated {len(changed)} {pkg_noun}:") - for key, old_sha, new_sha, ref in changed: - ref_str = f" ({ref})" if ref else "" - click.echo(f" {key}{ref_str}: {old_sha} -> {new_sha}") - elif error_count > 0: - logger.error(f"Update failed with {error_count} error(s).") - elif not had_baseline: - logger.success("Update complete.") - else: - logger.success("All packages already at latest refs.") + _update_impl(packages, verbose, force, target, parallel_downloads, global_, legacy_skill_paths) @deps.command(help="Show detailed package information") diff --git a/src/apm_cli/commands/find.py b/src/apm_cli/commands/find.py index 89ac965fa..b018eaedb 100644 --- a/src/apm_cli/commands/find.py +++ b/src/apm_cli/commands/find.py @@ -62,6 +62,11 @@ def build_reverse_index(lockfile: LockFile) -> dict[str, list[str]]: # --------------------------------------------------------------------------- +def _ref_with_url(repo_url: str, ref: str) -> str: + """Return ``repo_url@ref`` when *repo_url* is non-empty, else *ref*.""" + return f"{repo_url}@{ref}" if repo_url else ref + + def _format_origin(dep: LockedDependency) -> str: """Return a human-readable ASCII origin string for *dep*. 
@@ -78,20 +83,11 @@ def _format_origin(dep: LockedDependency) -> str: if dep.source == "local" and dep.local_path: return dep.local_path if dep.resolved_ref: - ref_part = dep.resolved_ref - if dep.repo_url: - return f"{dep.repo_url}@{ref_part}" - return ref_part + return _ref_with_url(dep.repo_url, dep.resolved_ref) if dep.resolved_tag: - tag_part = dep.resolved_tag - if dep.repo_url: - return f"{dep.repo_url}@{tag_part}" - return tag_part + return _ref_with_url(dep.repo_url, dep.resolved_tag) if dep.resolved_commit: - commit = dep.resolved_commit[:12] - if dep.repo_url: - return f"{dep.repo_url}@{commit}" - return commit + return _ref_with_url(dep.repo_url, dep.resolved_commit[:12]) return dep.repo_url diff --git a/src/apm_cli/commands/marketplace/__init__.py b/src/apm_cli/commands/marketplace/__init__.py index 764e223c2..b00e2aed2 100644 --- a/src/apm_cli/commands/marketplace/__init__.py +++ b/src/apm_cli/commands/marketplace/__init__.py @@ -7,17 +7,13 @@ from __future__ import annotations import builtins -import json import logging import re import sys -import traceback from pathlib import Path import click -import yaml -from ...core.command_logger import CommandLogger from ...marketplace.builder import BuildOptions, BuildReport, MarketplaceBuilder, ResolvedPackage from ...marketplace.errors import ( BuildError, @@ -48,7 +44,8 @@ from ...marketplace.semver import SemVer, parse_semver, satisfies_range from ...marketplace.yml_schema import load_marketplace_yml from ...utils.path_security import PathTraversalError, validate_path_segments -from .._helpers import _get_console, _is_interactive +from .._helpers import _get_console as _get_console +from .._helpers import _is_interactive as _is_interactive logger = logging.getLogger(__name__) @@ -215,41 +212,6 @@ def marketplace(ctx): marketplace.add_command(package) -def _check_gitignore_for_marketplace_json(logger): - """Warn if .gitignore contains a rule that would ignore marketplace outputs.""" - gitignore_path = 
Path.cwd() / ".gitignore" - if not gitignore_path.exists(): - return - - try: - lines = gitignore_path.read_text(encoding="utf-8").splitlines() - except OSError: - return - - patterns = { - "marketplace.json", - "**/marketplace.json", - "/marketplace.json", - ".claude-plugin/marketplace.json", - ".agents/plugins/marketplace.json", - "*.json", - } - for line in lines: - stripped = line.strip() - # Skip blank and commented lines - if not stripped or stripped.startswith("#"): - continue - if stripped in patterns: - logger.warning( - "Your .gitignore ignores marketplace.json. " - "Track apm.yml plus generated marketplace files such as " - ".claude-plugin/marketplace.json and .agents/plugins/marketplace.json. " - "Remove the .gitignore rule or add explicit unignore entries.", - symbol="warning", - ) - return - - def _parse_marketplace_source(source: str, host_flag: str | None) -> tuple[str, str, str | None]: """Parse a marketplace source argument into ``(url, kind, embedded_host)``. @@ -449,1091 +411,38 @@ def _expand_local_path(raw: str) -> str: return _osp.abspath(_osp.expanduser(raw)) -# Host-trust classification is owned by AuthResolver.classify_host (see -# core/auth.py). The marketplace command layer routes through it so that the -# credential-leakage guard at registration time uses the same single source of -# truth as the fetch-time guard in marketplace/client.py. Adding a second -# implementation here would create silent drift on a security-critical path. -_TRUSTED_MARKETPLACE_HOST_KINDS = ("github", "ghe_cloud", "ghes", "gitlab") - - -def _marketplace_add_unsupported_host_error( - resolved_host: str, - quoted_repo: str, - quoted_host: str, - host_kind: str, -) -> str: - """User-facing error when ``apm marketplace add`` rejects the resolved host. - - *quoted_repo* and *quoted_host* must already be ``shlex.quote``-safe for shell - copy-paste (see call sites). 
- """ - if host_kind == "ado": - return ( - f"Host '{resolved_host}' is not supported for marketplace registration.\n" - "APM marketplaces must be hosted on GitHub, GitHub Enterprise, or GitLab." - ) - return ( - f"Host '{resolved_host}' is not supported.\n" - "Supported marketplace hosts: github.com, *.ghe.com, " - "GitHub Enterprise Server (configure GITHUB_HOST), " - "and GitLab (gitlab.com or self-managed via GITLAB_HOST or APM_GITLAB_HOSTS).\n\n" - "To use GitHub Enterprise Server on this host:\n" - f" export GITHUB_HOST={quoted_host}\n" - "Then re-run:\n" - f" apm marketplace add {quoted_repo}\n\n" - "To use self-managed GitLab on this host:\n" - f" export GITLAB_HOST={quoted_host}\n" - "(or list the host in APM_GITLAB_HOSTS for multiple instances.)\n" - "Then re-run:\n" - f" apm marketplace add {quoted_repo}\n" - ) - - -_ADD_EPILOG = """ -\b -Examples: - apm marketplace add owner/repo - apm marketplace add github.com/owner/repo - apm marketplace add https://gitlab.com/group/repo - apm marketplace add https://dev.azure.com/org/proj/_git/repo --name apm-mkt - apm marketplace add git@gitea.example.com:org/repo.git --name custom - apm marketplace add /srv/marketplaces/agent-forge --name agent-forge -""" - - -@marketplace.command(help="Register a marketplace", epilog=_ADD_EPILOG) -@click.argument("source", metavar="SOURCE", required=True) -@click.option("--name", "-n", default=None, help="Display name (defaults to repo name)") -@click.option("--ref", "-r", default=None, help="Branch, tag, or commit to use (default: main)") -@click.option("--branch", "-b", default=None, help="Deprecated alias for --ref", hidden=True) -@click.option("--host", default=None, help="Git host FQDN (default: github.com)") -@click.option("--verbose", "-v", is_flag=True, help="Show detailed output") -def add(source, name, ref, branch, host, verbose): - """Register a marketplace. 
- - SOURCE accepts: OWNER/REPO shorthand, HOST/OWNER/REPO shorthand, a full - HTTPS URL (GitHub, GitLab, Azure DevOps, Gitea, Bitbucket Server, or - any self-hosted git server), an SSH URL (``git@host:org/repo.git``), - a local filesystem path, or a ``file://`` URI. - """ - logger = CommandLogger("marketplace-add", verbose=verbose) - try: - from ...marketplace.client import _auto_detect_path, fetch_marketplace - from ...marketplace.models import MarketplaceSource - from ...marketplace.registry import add_marketplace - from ...utils.github_host import is_valid_fqdn - - # --ref / --branch reconciliation. --branch stays as a hidden alias - # for one release so legacy invocations keep working; passing both - # is a hard error so we never silently pick one. - if ref is not None and branch is not None: - logger.error( - "--ref and --branch are mutually exclusive. Use --ref (--branch is a deprecated alias).", - symbol="error", - ) - sys.exit(1) - effective_ref = ref if ref is not None else (branch if branch is not None else "main") - - try: - url, kind, resolved_host = _parse_marketplace_source(source, host) - except PathTraversalError: - logger.error( - f"Invalid source '{source}': contains a path-traversal sequence. " - f"Remove '..', '.', or '~' from each path segment." - ) - sys.exit(1) - except ValueError as exc: - logger.error(str(exc)) - sys.exit(1) - - if host is not None and not is_valid_fqdn(host.strip().lower()): - logger.error( - f"Invalid host: '{host}'. Expected a valid host FQDN (for example, 'github.com').", - symbol="error", - ) - sys.exit(1) - - # --host is meaningful only for shorthand OWNER/REPO inputs. For URL - # / SSH / local-path inputs the host is already embedded; warn that - # --host is being ignored rather than silently overriding. 
- if host is not None and kind == "local": - logger.warning( - "--host is ignored when SOURCE is a local filesystem path.", - symbol="warning", - ) - elif ( - host is not None - and host.strip().lower() != (resolved_host or "").lower() - and kind in ("git", "github", "gitlab") - and (source.startswith(("https://", "git@", "file://"))) - ): - logger.warning( - "--host is ignored when SOURCE is a full URL.", - symbol="warning", - ) - - # Trust gate is now scoped to kinds that would forward an APM token - # via header injection. The subprocess git path (kind == "git") - # never forwards GITHUB_APM_PAT / GITLAB_APM_TOKEN -- AuthResolver - # only emits credentials matching the classified host. Local-kind - # fetches use no credentials at all. - if kind in ("github", "gitlab"): - from ...core.auth import AuthResolver - - host_info = AuthResolver.classify_host(resolved_host or "") - if host_info.kind not in _TRUSTED_MARKETPLACE_HOST_KINDS: - # Should not happen because _host_kind_to_fetcher_kind already - # mapped non-trusted kinds to "git", but defend in depth. - import shlex as _shlex - - quoted_repo = _shlex.quote(source) - quoted_host = _shlex.quote(resolved_host or "") - logger.error( - _marketplace_add_unsupported_host_error( - resolved_host or "", quoted_repo, quoted_host, host_info.kind - ) - ) - sys.exit(1) - - if name is not None and not _is_valid_alias(name): - logger.error( - f"Invalid marketplace name: '{name}'. " - f"Names must only contain letters, digits, '.', '_', and '-' " - f"(required for 'apm install plugin@marketplace' syntax).", - symbol="error", - ) - sys.exit(1) - - # Surface progress before the slow probe + fetch (5-30s for generic-git) - # so the user sees activity instead of staring at a blank terminal. - provisional_label = name or _default_alias_from_url(url) - logger.start(f"Registering marketplace '{provisional_label}'...", symbol="gear") - - # Probe for marketplace.json location. 
The probe source's name is a - # placeholder -- _auto_detect_path only consults url/ref/path/kind. - probe_name = provisional_label - probe_source = MarketplaceSource( - name=probe_name, - url=url, - ref=effective_ref, - ) - detected_path = _auto_detect_path(probe_source) - - if detected_path is None: - logger.error( - f"No marketplace.json found in '{probe_source.display_source}'. " - f"Checked: marketplace.json, .github/plugin/marketplace.json, " - f".claude-plugin/marketplace.json", - symbol="error", - ) - sys.exit(1) - - fetch_source = MarketplaceSource( - name=probe_name, - url=url, - ref=effective_ref, - path=detected_path, - ) - manifest = fetch_marketplace(fetch_source, force_refresh=True) - plugin_count = len(manifest.plugins) - - manifest_name = (manifest.name or "").strip() - if name is not None: - display_name = name - alias_source = "--name flag" - elif manifest_name and _is_valid_alias(manifest_name): - display_name = manifest_name - alias_source = f"manifest.name ('{manifest_name}')" - else: - display_name = probe_name - if manifest_name and not _is_valid_alias(manifest_name): - logger.warning( - f"Manifest declares name '{manifest_name}' which is not a " - f"valid alias (must match [a-zA-Z0-9._-]+). 
" - f"Falling back to repo name.", - symbol="warning", - ) - alias_source = f"derived name (manifest.name '{manifest_name}' invalid)" - else: - alias_source = "derived name (manifest.name missing)" - - assert _is_valid_alias(display_name), ( # noqa: S101 - f"Resolved marketplace alias '{display_name}' failed validation" - ) - - logger.verbose_detail(f" Source: {fetch_source.display_source}") - logger.verbose_detail(f" Kind: {kind}") - logger.verbose_detail(f" Ref: {effective_ref}") - logger.verbose_detail(f" Detected path: {detected_path}") - logger.verbose_detail(f" Alias source: {alias_source}") - - final_source = MarketplaceSource( - name=display_name, - url=url, - ref=effective_ref, - path=detected_path, - ) - add_marketplace(final_source) - - logger.success( - f"Marketplace '{display_name}' registered ({plugin_count} plugins)", - symbol="check", - ) - if manifest.description: - logger.verbose_detail(f" {manifest.description}") - - if name is None and display_name != probe_name: - logger.progress( - f"Install plugins with: apm install @{display_name}", - symbol="info", - ) - - except Exception as e: - logger.error(f"Failed to register marketplace: {e}") - if verbose: - logger.progress(traceback.format_exc(), symbol="info") - sys.exit(1) - - -def _default_alias_from_url(url: str) -> str: - """Derive a default marketplace alias from a parsed URL. - - Strips ``.git`` suffix, trailing slashes, and uses the last - path-segment. For ``file://`` URLs the alias falls back to the - final filesystem segment. - """ - from urllib.parse import urlparse - - parsed = urlparse(url) if "://" in url else None - if parsed and parsed.path: - tail = parsed.path.rstrip("/").rsplit("/", 1)[-1] - else: - tail = url.rstrip("/").rsplit("/", 1)[-1] - if tail.endswith(".git"): - tail = tail[:-4] - # Defensive: alias regex disallows '.' at end + arbitrary characters, - # but it tolerates dots and dashes inside which covers normal repo names. 
- return tail or "marketplace" - - -@marketplace.command(name="list", help="List registered marketplaces") -@click.option("--verbose", "-v", is_flag=True, help="Show detailed output") -def list_cmd(verbose): - """Show all registered marketplaces.""" - logger = CommandLogger("marketplace-list", verbose=verbose) - try: - from ...marketplace.registry import get_registered_marketplaces - - sources = get_registered_marketplaces() - - if not sources: - logger.progress( - "No marketplaces registered. Use 'apm marketplace add SOURCE' to register one " - "(OWNER/REPO, HTTPS URL, SSH URL, or local path).", - symbol="info", - ) - return - - console = _get_console() - if not console: - # Colorama fallback - logger.progress(f"{len(sources)} marketplace(s) registered:", symbol="info") - for s in sources: - logger.tree_item(f" {s.name} ({s.display_source})") - return - - from rich.table import Table - - table = Table( - title="Registered Marketplaces", - show_header=True, - header_style="bold cyan", - border_style="cyan", - ) - table.add_column("Name", style="bold white", no_wrap=True) - table.add_column("Source", style="white") - table.add_column("Ref", style="cyan") - table.add_column("Path", style="dim") - - for s in sources: - table.add_row(s.name, s.display_source, s.ref, s.path) - - console.print() - console.print(table) - logger.progress( - "Use 'apm marketplace browse ' to see plugins", - symbol="info", - ) - - except Exception as e: - logger.error(f"Failed to list marketplaces: {e}") - if verbose: - logger.progress(traceback.format_exc(), symbol="info") - sys.exit(1) - - -@marketplace.command(help="Browse plugins in a marketplace") -@click.argument("name", required=True) -@click.option("--verbose", "-v", is_flag=True, help="Show detailed output") -def browse(name, verbose): - """Show available plugins in a marketplace.""" - logger = CommandLogger("marketplace-browse", verbose=verbose) - try: - from ...marketplace.client import fetch_marketplace - from 
...marketplace.registry import get_marketplace_by_name - - source = get_marketplace_by_name(name) - logger.start(f"Fetching plugins from '{name}'...", symbol="search") - - manifest = fetch_marketplace(source, force_refresh=True) - - if not manifest.plugins: - logger.warning(f"Marketplace '{name}' has no plugins") - return - - console = _get_console() - if not console: - # Colorama fallback - logger.success(f"{len(manifest.plugins)} plugin(s) in '{name}':", symbol="check") - for p in manifest.plugins: - desc = f" -- {p.description}" if p.description else "" - logger.tree_item(f" {p.name}{desc}") - logger.progress(f"Install: apm install @{name}", symbol="info") - return - - from rich.table import Table - - table = Table( - title=f"Plugins in '{name}'", - show_header=True, - header_style="bold cyan", - border_style="cyan", - ) - table.add_column("Plugin", style="bold white", no_wrap=True) - table.add_column("Description", style="white", ratio=1) - table.add_column("Version", style="cyan", justify="center") - table.add_column("Install", style="green") - - for p in manifest.plugins: - desc = p.description or "--" - ver = p.version or "--" - table.add_row(p.name, desc, ver, f"{p.name}@{name}") - - console.print() - console.print(table) - logger.progress( - f"Install a plugin: apm install @{name}", - symbol="info", - ) - - except Exception as e: - logger.error(f"Failed to browse marketplace: {e}") - if verbose: - logger.progress(traceback.format_exc(), symbol="info") - sys.exit(1) - - -@marketplace.command(help="Refresh marketplace cache") -@click.argument("name", required=False, default=None) -@click.option("--verbose", "-v", is_flag=True, help="Show detailed output") -def update(name, verbose): - """Refresh cached marketplace data (one or all).""" - logger = CommandLogger("marketplace-update", verbose=verbose) - try: - from ...marketplace.client import clear_marketplace_cache, fetch_marketplace - from ...marketplace.registry import ( - get_marketplace_by_name, - 
get_registered_marketplaces, - ) - - if name: - source = get_marketplace_by_name(name) - logger.start(f"Refreshing marketplace '{name}'...", symbol="gear") - clear_marketplace_cache(name, host=source.host) - manifest = fetch_marketplace(source, force_refresh=True) - logger.success( - f"Marketplace '{name}' updated ({len(manifest.plugins)} plugins)", - symbol="check", - ) - else: - sources = get_registered_marketplaces() - if not sources: - logger.progress("No marketplaces registered.", symbol="info") - return - logger.start(f"Refreshing {len(sources)} marketplace(s)...", symbol="gear") - for s in sources: - try: - clear_marketplace_cache(s.name, host=s.host) - manifest = fetch_marketplace(s, force_refresh=True) - logger.tree_item(f" {s.name} ({len(manifest.plugins)} plugins)") - except Exception as exc: - logger.warning(f" {s.name}: {exc}") - if verbose: - logger.progress(traceback.format_exc(), symbol="info") - logger.success("Marketplace cache refreshed", symbol="check") - - except Exception as e: - logger.error(f"Failed to update marketplace: {e}") - if verbose: - logger.progress(traceback.format_exc(), symbol="info") - sys.exit(1) - - -@marketplace.command(help="Remove a registered marketplace") -@click.argument("name", required=True) -@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt") -@click.option("--verbose", "-v", is_flag=True, help="Show detailed output") -def remove(name, yes, verbose): - """Unregister a marketplace.""" - logger = CommandLogger("marketplace-remove", verbose=verbose) - try: - from ...marketplace.client import clear_marketplace_cache - from ...marketplace.registry import get_marketplace_by_name, remove_marketplace - - # Verify it exists first - source = get_marketplace_by_name(name) - - if not yes: - if not _is_interactive(): - logger.error( - "Use --yes to skip confirmation in non-interactive mode", - symbol="error", - ) - sys.exit(1) - confirmed = click.confirm( - f"Remove marketplace '{source.name}' 
({source.display_source})?", - default=False, - ) - if not confirmed: - logger.progress("Cancelled", symbol="info") - return - - remove_marketplace(name) - clear_marketplace_cache(name, host=source.host) - logger.success(f"Marketplace '{name}' removed", symbol="check") - - except Exception as e: - logger.error(f"Failed to remove marketplace: {e}") - if verbose: - logger.progress(traceback.format_exc(), symbol="info") - sys.exit(1) - - -def _render_build_error(logger, exc): - """Render a BuildError with actionable hints.""" - if isinstance(exc, GitLsRemoteError): - logger.error(exc.summary_text, symbol="error") - if exc.hint: - logger.progress(f"Hint: {exc.hint}", symbol="info") - elif isinstance(exc, NoMatchingVersionError): - logger.error(str(exc), symbol="error") - logger.progress( - "Check that your version range matches published tags.", - symbol="info", - ) - elif isinstance(exc, RefNotFoundError): - logger.error(str(exc), symbol="error") - logger.progress( - "Verify the ref is spelled correctly and the remote is reachable.", - symbol="info", - ) - elif isinstance(exc, HeadNotAllowedError): - logger.error(str(exc), symbol="error") - elif isinstance(exc, OfflineMissError): - logger.error(str(exc), symbol="error") - logger.progress( - "Run a build online first to populate the cache.", - symbol="info", - ) - else: - logger.error(f"Build failed: {exc}", symbol="error") - - -def _render_build_table(logger, report): - """Render the resolved-packages table (Rich with colorama fallback).""" - console = _get_console() - if not console: - # Colorama fallback - for pkg in report.resolved: - sha_short = pkg.sha[:8] if pkg.sha else "--" - ref_kind = "tag" if not pkg.ref.startswith("refs/heads/") else "branch" - logger.tree_item(f" [+] {pkg.name} {pkg.ref} {sha_short} ({ref_kind})") - return - - from rich.table import Table - from rich.text import Text - - table = Table( - title="Resolved Packages", - show_header=True, - header_style="bold cyan", - border_style="cyan", - ) 
- table.add_column("Status", style="green", no_wrap=True, width=6) - table.add_column("Package", style="bold white", no_wrap=True) - table.add_column("Version", style="cyan") - table.add_column("Commit", style="dim") - table.add_column("Ref Kind", style="white") - - for pkg in report.resolved: - sha_short = pkg.sha[:8] if pkg.sha else "--" - # Determine ref kind - ref_kind = "tag" - if pkg.ref and not parse_semver(pkg.ref.lstrip("vV")): - ref_kind = "ref" - table.add_row(Text("[+]"), pkg.name, pkg.ref, sha_short, ref_kind) - - console.print() - console.print(table) - - -class _OutdatedRow: - """Simple container for outdated table row data.""" - - __slots__ = ( - "current", - "latest_in_range", - "latest_overall", - "name", - "note", - "range_spec", - "status", - ) - - def __init__(self, name, current, range_spec, latest_in_range, latest_overall, status, note): - self.name = name - self.current = current - self.range_spec = range_spec - self.latest_in_range = latest_in_range - self.latest_overall = latest_overall - self.status = status - self.note = note - - -def _load_current_versions(): - """Load current ref versions from marketplace.json if present.""" - mkt_path = Path.cwd() / "marketplace.json" - if not mkt_path.exists(): - return {} - try: - data = json.loads(mkt_path.read_text(encoding="utf-8")) - result = {} - for plugin in data.get("plugins", []): - name = plugin.get("name", "") - src = plugin.get("source", {}) - if isinstance(src, dict): - result[name] = src.get("ref", "--") - return result - except (json.JSONDecodeError, OSError): - return {} - - -def _extract_tag_versions(refs, entry, yml, include_prerelease): - """Extract (SemVer, tag_name) pairs from remote refs for a package entry.""" - from ...marketplace._shared import iter_semver_tags - from ...marketplace.tag_pattern import ( - build_tag_regex, - infer_tag_pattern_from_refs, - ) - - def _collect(pattern: str) -> list: - tag_rx = ( - build_tag_regex(pattern, name=entry.name) - if "{name}" in 
pattern - else build_tag_regex(pattern) - ) - collected = [] - for sv, tag_name, _ in iter_semver_tags(refs, tag_rx): - if sv.is_prerelease and not (include_prerelease or entry.include_prerelease): - continue - collected.append((sv, tag_name)) - return collected - - pattern = entry.tag_pattern or yml.build.tag_pattern - results = _collect(pattern) - if not results: - inferred = infer_tag_pattern_from_refs(refs, entry.name) - if inferred and inferred != pattern: - logger.debug( - "Configured tag pattern %r matched no tags for %s; inferred %r", - pattern, - entry.name, - inferred, - ) - results = _collect(inferred) - return results - - -def _render_outdated_table(logger, rows): - """Render the outdated-packages table.""" - console = _get_console() - if not console: - for row in rows: - note = f" ({row.note})" if row.note else "" - logger.tree_item( - f" {row.status} {row.name} current={row.current} " - f"latest-in-range={row.latest_in_range} " - f"latest={row.latest_overall}{note}" - ) - return - - from rich.table import Table - from rich.text import Text - - table = Table( - title="Package Version Status", - show_header=True, - header_style="bold cyan", - border_style="cyan", - ) - table.add_column("Status", style="green", no_wrap=True, width=6) - table.add_column("Package", style="bold white", no_wrap=True) - table.add_column("Current", style="white") - table.add_column("Range", style="dim") - table.add_column("Latest in Range", style="cyan") - table.add_column("Latest Overall", style="yellow") - - for row in rows: - note = "" - if row.note: - note = f" ({row.note})" - table.add_row( - Text(row.status), - row.name, - row.current, - row.range_spec, - row.latest_in_range + note, - row.latest_overall, - ) - - console.print() - console.print(table) - - -class _CheckResult: - """Container for per-entry check results.""" - - __slots__ = ("error", "name", "reachable", "ref_ok", "version_found") - - def __init__(self, name, reachable, version_found, ref_ok, error): - 
self.name = name - self.reachable = reachable - self.version_found = version_found - self.ref_ok = ref_ok - self.error = error - - -def _render_check_table(logger, results): - """Render the check-results table.""" - console = _get_console() - if not console: - for r in results: - icon = "[+]" if r.ref_ok else "[x]" - detail = r.error if r.error else "OK" - logger.tree_item(f" {icon} {r.name}: {detail}") - return - - from rich.table import Table - from rich.text import Text - - table = Table( - title="Entry Health Check", - show_header=True, - header_style="bold cyan", - border_style="cyan", - ) - table.add_column("Status", no_wrap=True, width=6) - table.add_column("Package", style="bold white", no_wrap=True) - table.add_column("Reachable", style="white", justify="center") - table.add_column("Version Found", style="white", justify="center") - table.add_column("Ref OK", style="white", justify="center") - table.add_column("Detail", style="dim") - - for r in results: - reach = "[+]" if r.reachable else "[x]" - ver = "[+]" if r.version_found else "[x]" - ref = "[+]" if r.ref_ok else "[x]" - detail = r.error if r.error else "OK" - table.add_row( - Text("[+]" if r.ref_ok else "[x]"), - r.name, - Text(reach), - Text(ver), - Text(ref), - detail, - ) - - console.print() - console.print(table) - - -class _DoctorCheck: - """Container for a single doctor check result.""" - - __slots__ = ("detail", "informational", "name", "passed") - - def __init__(self, name, passed, detail, informational=False): - self.name = name - self.passed = passed - self.detail = detail - self.informational = informational - - -def _render_doctor_table(logger, checks): - """Render the doctor results table.""" - console = _get_console() - if not console: - for c in checks: - if c.informational: - icon = "[i]" - elif c.passed: - icon = "[+]" - else: - icon = "[x]" - logger.tree_item(f" {icon} {c.name}: {c.detail}") - return - - from rich.table import Table - from rich.text import Text - - table = Table( - 
title="Environment Diagnostics", - show_header=True, - header_style="bold cyan", - border_style="cyan", - ) - table.add_column("Check", style="bold white", no_wrap=True) - table.add_column("Status", no_wrap=True, width=6) - table.add_column("Detail", style="white") - - for c in checks: - if c.informational: - icon = "[i]" - elif c.passed: - icon = "[+]" - else: - icon = "[x]" - table.add_row(c.name, Text(icon), c.detail) - - console.print() - console.print(table) - - -def _load_targets_file(path): - """Load and validate a consumer-targets YAML file. - - Returns a list of ``ConsumerTarget`` instances. - - Raises ``SystemExit`` on validation failures. - """ - try: - raw = yaml.safe_load(path.read_text(encoding="utf-8")) - except yaml.YAMLError as exc: - return None, f"Invalid YAML in targets file: {exc}" - except OSError as exc: - return None, f"Cannot read targets file: {exc}" - - if not isinstance(raw, dict) or "targets" not in raw: - return None, "Targets file must contain a 'targets' key." - - raw_targets = raw["targets"] - if not isinstance(raw_targets, list) or not raw_targets: - return None, "Targets file must contain a non-empty 'targets' list." - - targets = [] - for idx, entry in enumerate(raw_targets): - if not isinstance(entry, dict): - return None, f"targets[{idx}] must be a mapping." - - repo = entry.get("repo") - if not repo or not isinstance(repo, str): - return None, f"targets[{idx}]: 'repo' is required (owner/name)." - - # Validate repo format: owner/name - parts = repo.split("/") - if len(parts) != 2 or not parts[0] or not parts[1]: - return None, f"targets[{idx}]: 'repo' must be 'owner/name', got '{repo}'." - - branch = entry.get("branch") - if not branch or not isinstance(branch, str): - return None, f"targets[{idx}]: 'branch' is required." - - path_in_repo = entry.get("path_in_repo", "apm.yml") - if not isinstance(path_in_repo, str) or not path_in_repo.strip(): - return None, f"targets[{idx}]: 'path_in_repo' must be a non-empty string." 
- - # Path safety check - try: - validate_path_segments( - path_in_repo, - context=f"targets[{idx}].path_in_repo", - ) - except PathTraversalError as exc: - return None, str(exc) - - targets.append( - ConsumerTarget( - repo=repo.strip(), - branch=branch.strip(), - path_in_repo=path_in_repo.strip(), - ) - ) - - return targets, None - - -def _render_publish_plan(logger, plan): - """Render the publish plan as a Rich panel + target table.""" - console = _get_console() - - plan_text = ( - f"Marketplace: {plan.marketplace_name}\n" - f"New version: {plan.marketplace_version}\n" - f"New ref: {plan.new_ref}\n" - f"Branch: {plan.branch_name}\n" - f"Targets: {len(plan.targets)}" - ) - - if not console: - logger.progress("Publish plan:", symbol="info") - for line in plan_text.splitlines(): - logger.tree_item(f" {line}") - click.echo() - for t in plan.targets: - logger.tree_item(f" [*] {t.repo} branch={t.branch} path={t.path_in_repo}") - return - - from rich.panel import Panel - from rich.table import Table - from rich.text import Text - - console.print() - console.print( - Panel( - plan_text, - title="Publish plan", - border_style="cyan", - ) - ) - - table = Table( - show_header=True, - header_style="bold cyan", - border_style="cyan", - ) - table.add_column("Repo", style="bold white", no_wrap=True) - table.add_column("Branch", style="cyan") - table.add_column("Path", style="dim") - table.add_column("Status", no_wrap=True, width=10) - - for t in plan.targets: - table.add_row(t.repo, t.branch, t.path_in_repo, Text("[*]")) - - console.print(table) - console.print() - - -def _render_publish_summary(logger, results, pr_results, no_pr, dry_run): - """Render the final publish summary table.""" - console = _get_console() - - # Build lookup for PR results by repo - pr_by_repo = {} - for pr_r in pr_results: - pr_by_repo[pr_r.target.repo] = pr_r - - updated_count = sum(1 for r in results if r.outcome == PublishOutcome.UPDATED) - failed_count = sum(1 for r in results if r.outcome == 
PublishOutcome.FAILED) - total = len(results) - - if not console: - click.echo() - for r in results: - icon = _outcome_symbol(r.outcome) - pr_info = "" - if not no_pr: - pr_r = pr_by_repo.get(r.target.repo) - if pr_r: - pr_info = f" PR: {pr_r.state.value}" - if pr_r.pr_number: - pr_info += f" #{pr_r.pr_number}" - logger.tree_item(f" {icon} {r.target.repo}: {r.outcome.value}{pr_info} -- {r.message}") - click.echo() - _render_publish_footer(logger, updated_count, failed_count, total, dry_run) - return - - from rich.table import Table - from rich.text import Text - - table = Table( - title="Publish Results", - show_header=True, - header_style="bold cyan", - border_style="cyan", - ) - table.add_column("Status", no_wrap=True, width=6) - table.add_column("Repo", style="bold white", no_wrap=True) - table.add_column("Outcome", style="white") - - if not no_pr: - table.add_column("PR State", style="white") - table.add_column("PR #", style="cyan", justify="right") - table.add_column("PR URL", style="dim") - - table.add_column("Message", style="dim", ratio=1) - - for r in results: - icon = _outcome_symbol(r.outcome) - row = [Text(icon), r.target.repo, r.outcome.value] - - if not no_pr: - pr_r = pr_by_repo.get(r.target.repo) - if pr_r: - row.append(pr_r.state.value) - row.append(str(pr_r.pr_number) if pr_r.pr_number else "--") - row.append(pr_r.pr_url or "--") - else: - row.extend(["--", "--", "--"]) - - row.append(r.message) - table.add_row(*row) - - console.print() - console.print(table) - console.print() - - _render_publish_footer(logger, updated_count, failed_count, total, dry_run) - - -def _outcome_symbol(outcome): - """Map a ``PublishOutcome`` to a bracket symbol.""" - if outcome == PublishOutcome.UPDATED: - return "[+]" - elif outcome == PublishOutcome.FAILED: - return "[x]" - elif outcome in ( - PublishOutcome.SKIPPED_DOWNGRADE, - PublishOutcome.SKIPPED_REF_CHANGE, - ): - return "[!]" - elif outcome == PublishOutcome.NO_CHANGE: - return "[*]" - return "[*]" - - -def 
_render_publish_footer(logger, updated, failed, total, dry_run): - """Render the footer success/warning line.""" - suffix = " (dry-run)" if dry_run else "" - if failed == 0: - logger.success( - f"Published {updated}/{total} targets{suffix}", - symbol="check", - ) - else: - logger.warning( - f"Published {updated}/{total} targets, {failed} failed{suffix}", - symbol="warning", - ) - +# --------------------------------------------------------------------------- +# Re-exports from siblings (Rule A: keep names patchable on this module) +# --------------------------------------------------------------------------- -@click.command( - name="search", - help="Search plugins in a marketplace (QUERY@MARKETPLACE)", +from ._publish_ops import _load_targets_file as _load_targets_file # noqa: E402 +from ._publish_ops import _outcome_symbol as _outcome_symbol # noqa: E402 +from ._publish_ops import _render_publish_footer as _render_publish_footer # noqa: E402 +from ._publish_ops import _render_publish_plan as _render_publish_plan # noqa: E402 +from ._publish_ops import _render_publish_summary as _render_publish_summary # noqa: E402 +from ._registry_cmds import ( # noqa: E402 + _check_gitignore_for_marketplace_json as _check_gitignore_for_marketplace_json, ) -@click.argument("expression", required=True, metavar="QUERY@MARKETPLACE") -@click.option("--limit", default=20, show_default=True, help="Max results to show") -@click.option("--verbose", "-v", is_flag=True, help="Show detailed output") -def search(expression, limit, verbose): - """Search for plugins in a specific marketplace. - - Use QUERY@MARKETPLACE format, e.g.: apm marketplace search security@skills - """ - logger = CommandLogger("marketplace-search", verbose=verbose) - try: - from ...marketplace.client import search_marketplace - from ...marketplace.registry import get_marketplace_by_name - - if "@" not in expression: - logger.error( - f"Invalid format: '{expression}'. 
" - "Use QUERY@MARKETPLACE, e.g.: apm marketplace search security@skills" - ) - sys.exit(1) - - query, marketplace_name = expression.rsplit("@", 1) - if not query or not marketplace_name: - logger.error( - "Both QUERY and MARKETPLACE are required. " - "Use QUERY@MARKETPLACE, e.g.: apm marketplace search security@skills" - ) - sys.exit(1) - - try: - source = get_marketplace_by_name(marketplace_name) - except MarketplaceNotFoundError: - logger.error( - f"Marketplace '{marketplace_name}' is not registered. " - "Use 'apm marketplace list' to see registered marketplaces." - ) - sys.exit(1) - - logger.start(f"Searching '{marketplace_name}' for '{query}'...", symbol="search") - results = search_marketplace(query, source)[:limit] - - if not results: - logger.warning( - f"No plugins found matching '{query}' in '{marketplace_name}'. " - f"Try 'apm marketplace browse {marketplace_name}' to see all plugins." - ) - return - - console = _get_console() - if not console: - # Colorama fallback - logger.success(f"Found {len(results)} plugin(s):", symbol="check") - for p in results: - desc = f" -- {p.description}" if p.description else "" - logger.tree_item(f" {p.name}@{marketplace_name}{desc}") - logger.progress( - f"Install: apm install @{marketplace_name}", - symbol="info", - ) - return - - from rich.table import Table - - table = Table( - title=f"Search Results: '{query}' in {marketplace_name}", - show_header=True, - header_style="bold cyan", - border_style="cyan", - ) - table.add_column("Plugin", style="bold white", no_wrap=True) - table.add_column("Description", style="white", ratio=1) - table.add_column("Install", style="green") - - for p in results: - desc = p.description or "--" - if len(desc) > 60: - desc = desc[:57] + "..." 
- table.add_row(p.name, desc, f"{p.name}@{marketplace_name}") - - console.print() - console.print(table) - logger.progress( - f"Install: apm install @{marketplace_name}", - symbol="info", - ) - - except SystemExit: - raise - except Exception as e: - logger.error(f"Search failed: {e}") - logger.verbose_detail(traceback.format_exc()) - sys.exit(1) - - +from ._registry_cmds import _default_alias_from_url as _default_alias_from_url # noqa: E402 +from ._registry_cmds import ( # noqa: E402 + _marketplace_add_unsupported_host_error as _marketplace_add_unsupported_host_error, +) +from ._registry_cmds import add as add # noqa: E402 +from ._registry_cmds import browse as browse # noqa: E402 +from ._registry_cmds import list_cmd as list_cmd # noqa: E402 +from ._registry_cmds import remove as remove # noqa: E402 +from ._registry_cmds import update as update # noqa: E402 +from ._search_cmd import search as search # noqa: E402 +from ._table_ops import _CheckResult as _CheckResult # noqa: E402 +from ._table_ops import _DoctorCheck as _DoctorCheck # noqa: E402 +from ._table_ops import _extract_tag_versions as _extract_tag_versions # noqa: E402 +from ._table_ops import _load_current_versions as _load_current_versions # noqa: E402 +from ._table_ops import _OutdatedRow as _OutdatedRow # noqa: E402 +from ._table_ops import _render_build_error as _render_build_error # noqa: E402 +from ._table_ops import _render_build_table as _render_build_table # noqa: E402 +from ._table_ops import _render_check_table as _render_check_table # noqa: E402 +from ._table_ops import _render_doctor_table as _render_doctor_table # noqa: E402 +from ._table_ops import _render_outdated_table as _render_outdated_table # noqa: E402 from .audit import audit # noqa: E402 from .check import check # noqa: E402 from .doctor import doctor # noqa: E402 @@ -1574,6 +483,31 @@ def search(expression, limit, verbose): "ResolvedPackage", "SemVer", "TargetResult", + "_CheckResult", + "_DoctorCheck", + "_OutdatedRow", + 
"_check_gitignore_for_marketplace_json", + "_default_alias_from_url", + "_extract_tag_versions", + "_find_duplicate_names", + "_is_valid_alias", + "_load_config_or_exit", + "_load_current_versions", + "_load_targets_file", + "_load_yml_or_exit", + "_marketplace_add_unsupported_host_error", + "_outcome_symbol", + "_parse_marketplace_repo", + "_parse_marketplace_source", + "_render_build_error", + "_render_build_table", + "_render_check_table", + "_render_doctor_table", + "_render_outdated_table", + "_render_publish_footer", + "_render_publish_plan", + "_render_publish_summary", + "_warn_duplicate_names", "add", "audit", "browse", diff --git a/src/apm_cli/commands/marketplace/_publish_ops.py b/src/apm_cli/commands/marketplace/_publish_ops.py new file mode 100644 index 000000000..5f671c25c --- /dev/null +++ b/src/apm_cli/commands/marketplace/_publish_ops.py @@ -0,0 +1,261 @@ +"""Publishing-related helper functions for the marketplace commands. + +Extracted from ``marketplace/__init__.py`` to keep that module under 800 lines. +All names are re-exported from the package ``__init__`` so existing import +paths keep working. +""" + +from __future__ import annotations + +import click +import yaml + +from ...marketplace.publisher import PublishOutcome +from ...utils.path_security import PathTraversalError, validate_path_segments + + +def _mkt_get_console(): + """Route to marketplace._get_console so test patches apply.""" + from apm_cli.commands import marketplace as _m + + return _m._get_console() + + +# --------------------------------------------------------------------------- +# Targets-file loading +# --------------------------------------------------------------------------- + + +def _validate_target_entry(idx: int, entry: object) -> str | None: + """Validate a single targets-file entry dict. + + Returns an error string on failure or ``None`` when the entry is valid. 
+ """ + from ...marketplace.publisher import ConsumerTarget # noqa: F401 (used by caller) + + if not isinstance(entry, dict): + return f"targets[{idx}] must be a mapping." + + repo = entry.get("repo") + if not repo or not isinstance(repo, str): + return f"targets[{idx}]: 'repo' is required (owner/name)." + + parts = repo.split("/") + if len(parts) != 2 or not parts[0] or not parts[1]: + return f"targets[{idx}]: 'repo' must be 'owner/name', got '{repo}'." + + branch = entry.get("branch") + if not branch or not isinstance(branch, str): + return f"targets[{idx}]: 'branch' is required." + + path_in_repo = entry.get("path_in_repo", "apm.yml") + if not isinstance(path_in_repo, str) or not path_in_repo.strip(): + return f"targets[{idx}]: 'path_in_repo' must be a non-empty string." + + try: + validate_path_segments( + path_in_repo, + context=f"targets[{idx}].path_in_repo", + ) + except PathTraversalError as exc: + return str(exc) + + return None + + +def _load_targets_file(path): + """Load and validate a consumer-targets YAML file. + + Returns a list of ``ConsumerTarget`` instances. + + Raises ``SystemExit`` on validation failures. + """ + from ...marketplace.publisher import ConsumerTarget + + try: + raw = yaml.safe_load(path.read_text(encoding="utf-8")) + except yaml.YAMLError as exc: + return None, f"Invalid YAML in targets file: {exc}" + except OSError as exc: + return None, f"Cannot read targets file: {exc}" + + if not isinstance(raw, dict) or "targets" not in raw: + return None, "Targets file must contain a 'targets' key." + + raw_targets = raw["targets"] + if not isinstance(raw_targets, list) or not raw_targets: + return None, "Targets file must contain a non-empty 'targets' list." 
+ + targets = [] + for idx, entry in enumerate(raw_targets): + err = _validate_target_entry(idx, entry) + if err: + return None, err + + # entry is a valid dict after passing validation above + targets.append( + ConsumerTarget( + repo=entry["repo"].strip(), + branch=entry["branch"].strip(), + path_in_repo=(entry.get("path_in_repo", "apm.yml") or "apm.yml").strip(), + ) + ) + + return targets, None + + +# --------------------------------------------------------------------------- +# Publish plan / summary rendering +# --------------------------------------------------------------------------- + + +def _render_publish_plan(log, plan): + """Render the publish plan as a Rich panel + target table.""" + console = _mkt_get_console() + + plan_text = ( + f"Marketplace: {plan.marketplace_name}\n" + f"New version: {plan.marketplace_version}\n" + f"New ref: {plan.new_ref}\n" + f"Branch: {plan.branch_name}\n" + f"Targets: {len(plan.targets)}" + ) + + if not console: + log.progress("Publish plan:", symbol="info") + for line in plan_text.splitlines(): + log.tree_item(f" {line}") + click.echo() + for t in plan.targets: + log.tree_item(f" [*] {t.repo} branch={t.branch} path={t.path_in_repo}") + return + + from rich.panel import Panel + from rich.table import Table + from rich.text import Text + + console.print() + console.print( + Panel( + plan_text, + title="Publish plan", + border_style="cyan", + ) + ) + + table = Table( + show_header=True, + header_style="bold cyan", + border_style="cyan", + ) + table.add_column("Repo", style="bold white", no_wrap=True) + table.add_column("Branch", style="cyan") + table.add_column("Path", style="dim") + table.add_column("Status", no_wrap=True, width=10) + + for t in plan.targets: + table.add_row(t.repo, t.branch, t.path_in_repo, Text("[*]")) + + console.print(table) + console.print() + + +def _render_publish_summary(log, results, pr_results, no_pr, dry_run): + """Render the final publish summary table.""" + console = _mkt_get_console() + + # 
Build lookup for PR results by repo + pr_by_repo = {} + for pr_r in pr_results: + pr_by_repo[pr_r.target.repo] = pr_r + + updated_count = sum(1 for r in results if r.outcome == PublishOutcome.UPDATED) + failed_count = sum(1 for r in results if r.outcome == PublishOutcome.FAILED) + total = len(results) + + if not console: + click.echo() + for r in results: + icon = _outcome_symbol(r.outcome) + pr_info = "" + if not no_pr: + pr_r = pr_by_repo.get(r.target.repo) + if pr_r: + pr_info = f" PR: {pr_r.state.value}" + if pr_r.pr_number: + pr_info += f" #{pr_r.pr_number}" + log.tree_item(f" {icon} {r.target.repo}: {r.outcome.value}{pr_info} -- {r.message}") + click.echo() + _render_publish_footer(log, updated_count, failed_count, total, dry_run) + return + + from rich.table import Table + from rich.text import Text + + table = Table( + title="Publish Results", + show_header=True, + header_style="bold cyan", + border_style="cyan", + ) + table.add_column("Status", no_wrap=True, width=6) + table.add_column("Repo", style="bold white", no_wrap=True) + table.add_column("Outcome", style="white") + + if not no_pr: + table.add_column("PR State", style="white") + table.add_column("PR #", style="cyan", justify="right") + table.add_column("PR URL", style="dim") + + table.add_column("Message", style="dim", ratio=1) + + for r in results: + icon = _outcome_symbol(r.outcome) + row = [Text(icon), r.target.repo, r.outcome.value] + + if not no_pr: + pr_r = pr_by_repo.get(r.target.repo) + if pr_r: + row.append(pr_r.state.value) + row.append(str(pr_r.pr_number) if pr_r.pr_number else "--") + row.append(pr_r.pr_url or "--") + else: + row.extend(["--", "--", "--"]) + + row.append(r.message) + table.add_row(*row) + + console.print() + console.print(table) + console.print() + + _render_publish_footer(log, updated_count, failed_count, total, dry_run) + + +def _outcome_symbol(outcome): + """Map a ``PublishOutcome`` to a bracket symbol.""" + if outcome == PublishOutcome.UPDATED: + return "[+]" + if 
outcome == PublishOutcome.FAILED: + return "[x]" + if outcome in ( + PublishOutcome.SKIPPED_DOWNGRADE, + PublishOutcome.SKIPPED_REF_CHANGE, + ): + return "[!]" + return "[*]" + + +def _render_publish_footer(log, updated, failed, total, dry_run): + """Render the footer success/warning line.""" + suffix = " (dry-run)" if dry_run else "" + if failed == 0: + log.success( + f"Published {updated}/{total} targets{suffix}", + symbol="check", + ) + else: + log.warning( + f"Published {updated}/{total} targets, {failed} failed{suffix}", + symbol="warning", + ) diff --git a/src/apm_cli/commands/marketplace/_registry_cmds.py b/src/apm_cli/commands/marketplace/_registry_cmds.py new file mode 100644 index 000000000..f252c1d37 --- /dev/null +++ b/src/apm_cli/commands/marketplace/_registry_cmds.py @@ -0,0 +1,543 @@ +"""Registry-management Click commands for the marketplace group. + +Extracted from ``marketplace/__init__.py`` to keep that module under 800 lines. +Contains: ``add``, ``list_cmd``, ``browse``, ``update``, ``remove`` plus their +private helpers. All names are re-exported from the package ``__init__`` so +existing import paths keep working. + +These commands are imported at the *bottom* of ``__init__.py`` (after +``marketplace``, ``_parse_marketplace_source``, and ``_is_valid_alias`` are +defined), so module-scope ``from . import ...`` is safe - the same pattern used +by the existing ``check``, ``outdated``, and ``publish`` sibling modules. +""" + +from __future__ import annotations + +import sys +import traceback +from pathlib import Path + +import click + +from ...utils.path_security import PathTraversalError +from . 
import ( + _is_valid_alias, + _parse_marketplace_source, + marketplace, +) + +# --------------------------------------------------------------------------- +# Constants and helpers used only by the registry commands +# --------------------------------------------------------------------------- + +# Host-trust classification is owned by AuthResolver.classify_host (see +# core/auth.py). The marketplace command layer routes through it so that the +# credential-leakage guard at registration time uses the same single source of +# truth as the fetch-time guard in marketplace/client.py. + + +def _mkt_get_console(): + """Route to marketplace._get_console so test patches apply.""" + from apm_cli.commands import marketplace as _m + + return _m._get_console() + + +def _mkt_is_interactive(): + """Route to ``marketplace._is_interactive`` so test patches apply.""" + from apm_cli.commands import marketplace as _m + + return _m._is_interactive() + + +_TRUSTED_MARKETPLACE_HOST_KINDS = ("github", "ghe_cloud", "ghes", "gitlab") + + +def _check_gitignore_for_marketplace_json(log): + """Warn if .gitignore contains a rule that would ignore marketplace outputs.""" + gitignore_path = Path.cwd() / ".gitignore" + if not gitignore_path.exists(): + return + + try: + lines = gitignore_path.read_text(encoding="utf-8").splitlines() + except OSError: + return + + patterns = { + "marketplace.json", + "**/marketplace.json", + "/marketplace.json", + ".claude-plugin/marketplace.json", + ".agents/plugins/marketplace.json", + "*.json", + } + for line in lines: + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + if stripped in patterns: + log.warning( + "Your .gitignore ignores marketplace.json. " + "Track apm.yml plus generated marketplace files such as " + ".claude-plugin/marketplace.json and .agents/plugins/marketplace.json. 
" + "Remove the .gitignore rule or add explicit unignore entries.", + symbol="warning", + ) + return + + +def _marketplace_add_unsupported_host_error( + resolved_host: str, + quoted_repo: str, + quoted_host: str, + host_kind: str, +) -> str: + """User-facing error when ``apm marketplace add`` rejects the resolved host. + + *quoted_repo* and *quoted_host* must already be ``shlex.quote``-safe for + shell copy-paste (see call sites). + """ + if host_kind == "ado": + return ( + f"Host '{resolved_host}' is not supported for marketplace registration.\n" + "APM marketplaces must be hosted on GitHub, GitHub Enterprise, or GitLab." + ) + return ( + f"Host '{resolved_host}' is not supported.\n" + "Supported marketplace hosts: github.com, *.ghe.com, " + "GitHub Enterprise Server (configure GITHUB_HOST), " + "and GitLab (gitlab.com or self-managed via GITLAB_HOST or APM_GITLAB_HOSTS).\n\n" + "To use GitHub Enterprise Server on this host:\n" + f" export GITHUB_HOST={quoted_host}\n" + "Then re-run:\n" + f" apm marketplace add {quoted_repo}\n\n" + "To use self-managed GitLab on this host:\n" + f" export GITLAB_HOST={quoted_host}\n" + "(or list the host in APM_GITLAB_HOSTS for multiple instances.)\n" + "Then re-run:\n" + f" apm marketplace add {quoted_repo}\n" + ) + + +def _default_alias_from_url(url: str) -> str: + """Derive a default marketplace alias from a parsed URL. + + Strips ``.git`` suffix, trailing slashes, and uses the last + path-segment. For ``file://`` URLs the alias falls back to the + final filesystem segment. 
+ """ + from urllib.parse import urlparse + + parsed = urlparse(url) if "://" in url else None + if parsed and parsed.path: + tail = parsed.path.rstrip("/").rsplit("/", 1)[-1] + else: + tail = url.rstrip("/").rsplit("/", 1)[-1] + if tail.endswith(".git"): + tail = tail[:-4] + return tail or "marketplace" + + +# --------------------------------------------------------------------------- +# Click commands +# --------------------------------------------------------------------------- + +_ADD_EPILOG = """ +\b +Examples: + apm marketplace add owner/repo + apm marketplace add github.com/owner/repo + apm marketplace add https://gitlab.com/group/repo + apm marketplace add https://dev.azure.com/org/proj/_git/repo --name apm-mkt + apm marketplace add git@gitea.example.com:org/repo.git --name custom + apm marketplace add /srv/marketplaces/agent-forge --name agent-forge +""" + + +@marketplace.command(help="Register a marketplace", epilog=_ADD_EPILOG) +@click.argument("source", metavar="SOURCE", required=True) +@click.option("--name", "-n", default=None, help="Display name (defaults to repo name)") +@click.option("--ref", "-r", default=None, help="Branch, tag, or commit to use (default: main)") +@click.option("--branch", "-b", default=None, help="Deprecated alias for --ref", hidden=True) +@click.option("--host", default=None, help="Git host FQDN (default: github.com)") +@click.option("--verbose", "-v", is_flag=True, help="Show detailed output") +def add(source, name, ref, branch, host, verbose): + """Register a marketplace. + + SOURCE accepts: OWNER/REPO shorthand, HOST/OWNER/REPO shorthand, a full + HTTPS URL (GitHub, GitLab, Azure DevOps, Gitea, Bitbucket Server, or + any self-hosted git server), an SSH URL (``git@host:org/repo.git``), + a local filesystem path, or a ``file://`` URI. 
+ """ + from ...core.command_logger import CommandLogger + + log = CommandLogger("marketplace-add", verbose=verbose) + try: + from ...marketplace.client import _auto_detect_path, fetch_marketplace + from ...marketplace.models import MarketplaceSource + from ...marketplace.registry import add_marketplace + from ...utils.github_host import is_valid_fqdn + + if ref is not None and branch is not None: + log.error( + "--ref and --branch are mutually exclusive. " + "Use --ref (--branch is a deprecated alias).", + symbol="error", + ) + sys.exit(1) + effective_ref = ref if ref is not None else (branch if branch is not None else "main") + + try: + url, kind, resolved_host = _parse_marketplace_source(source, host) + except PathTraversalError: + log.error( + f"Invalid source '{source}': contains a path-traversal sequence. " + f"Remove '..', '.', or '~' from each path segment." + ) + sys.exit(1) + except ValueError as exc: + log.error(str(exc)) + sys.exit(1) + + if host is not None and not is_valid_fqdn(host.strip().lower()): + log.error( + f"Invalid host: '{host}'. 
Expected a valid host FQDN (for example, 'github.com').", + symbol="error", + ) + sys.exit(1) + + if host is not None and kind == "local": + log.warning( + "--host is ignored when SOURCE is a local filesystem path.", + symbol="warning", + ) + elif ( + host is not None + and host.strip().lower() != (resolved_host or "").lower() + and kind in ("git", "github", "gitlab") + and (source.startswith(("https://", "git@", "file://"))) + ): + log.warning( + "--host is ignored when SOURCE is a full URL.", + symbol="warning", + ) + + if kind in ("github", "gitlab"): + from ...core.auth import AuthResolver + + host_info = AuthResolver.classify_host(resolved_host or "") + if host_info.kind not in _TRUSTED_MARKETPLACE_HOST_KINDS: + import shlex as _shlex + + quoted_repo = _shlex.quote(source) + quoted_host = _shlex.quote(resolved_host or "") + log.error( + _marketplace_add_unsupported_host_error( + resolved_host or "", quoted_repo, quoted_host, host_info.kind + ) + ) + sys.exit(1) + + if name is not None and not _is_valid_alias(name): + log.error( + f"Invalid marketplace name: '{name}'. " + f"Names must only contain letters, digits, '.', '_', and '-' " + f"(required for 'apm install plugin@marketplace' syntax).", + symbol="error", + ) + sys.exit(1) + + provisional_label = name or _default_alias_from_url(url) + log.start(f"Registering marketplace '{provisional_label}'...", symbol="gear") + + probe_name = provisional_label + probe_source = MarketplaceSource( + name=probe_name, + url=url, + ref=effective_ref, + ) + detected_path = _auto_detect_path(probe_source) + + if detected_path is None: + log.error( + f"No marketplace.json found in '{probe_source.display_source}'. 
" + f"Checked: marketplace.json, .github/plugin/marketplace.json, " + f".claude-plugin/marketplace.json", + symbol="error", + ) + sys.exit(1) + + fetch_source = MarketplaceSource( + name=probe_name, + url=url, + ref=effective_ref, + path=detected_path, + ) + manifest = fetch_marketplace(fetch_source, force_refresh=True) + plugin_count = len(manifest.plugins) + + manifest_name = (manifest.name or "").strip() + if name is not None: + display_name = name + alias_source = "--name flag" + elif manifest_name and _is_valid_alias(manifest_name): + display_name = manifest_name + alias_source = f"manifest.name ('{manifest_name}')" + else: + display_name = probe_name + if manifest_name and not _is_valid_alias(manifest_name): + log.warning( + f"Manifest declares name '{manifest_name}' which is not a " + f"valid alias (must match [a-zA-Z0-9._-]+). " + f"Falling back to repo name.", + symbol="warning", + ) + alias_source = f"derived name (manifest.name '{manifest_name}' invalid)" + else: + alias_source = "derived name (manifest.name missing)" + + assert _is_valid_alias(display_name), ( # noqa: S101 + f"Resolved marketplace alias '{display_name}' failed validation" + ) + + log.verbose_detail(f" Source: {fetch_source.display_source}") + log.verbose_detail(f" Kind: {kind}") + log.verbose_detail(f" Ref: {effective_ref}") + log.verbose_detail(f" Detected path: {detected_path}") + log.verbose_detail(f" Alias source: {alias_source}") + + final_source = MarketplaceSource( + name=display_name, + url=url, + ref=effective_ref, + path=detected_path, + ) + add_marketplace(final_source) + + log.success( + f"Marketplace '{display_name}' registered ({plugin_count} plugins)", + symbol="check", + ) + if manifest.description: + log.verbose_detail(f" {manifest.description}") + + if name is None and display_name != probe_name: + log.progress( + f"Install plugins with: apm install @{display_name}", + symbol="info", + ) + + except Exception as e: + log.error(f"Failed to register marketplace: {e}") + 
if verbose: + log.progress(traceback.format_exc(), symbol="info") + sys.exit(1) + + +@marketplace.command(name="list", help="List registered marketplaces") +@click.option("--verbose", "-v", is_flag=True, help="Show detailed output") +def list_cmd(verbose): + """Show all registered marketplaces.""" + from ...core.command_logger import CommandLogger + + log = CommandLogger("marketplace-list", verbose=verbose) + try: + from ...marketplace.registry import get_registered_marketplaces + + sources = get_registered_marketplaces() + + if not sources: + log.progress( + "No marketplaces registered. Use 'apm marketplace add SOURCE' to register one " + "(OWNER/REPO, HTTPS URL, SSH URL, or local path).", + symbol="info", + ) + return + + console = _mkt_get_console() + if not console: + log.progress(f"{len(sources)} marketplace(s) registered:", symbol="info") + for s in sources: + log.tree_item(f" {s.name} ({s.display_source})") + return + + from rich.table import Table + + table = Table( + title="Registered Marketplaces", + show_header=True, + header_style="bold cyan", + border_style="cyan", + ) + table.add_column("Name", style="bold white", no_wrap=True) + table.add_column("Source", style="white") + table.add_column("Ref", style="cyan") + table.add_column("Path", style="dim") + + for s in sources: + table.add_row(s.name, s.display_source, s.ref, s.path) + + console.print() + console.print(table) + log.progress( + "Use 'apm marketplace browse ' to see plugins", + symbol="info", + ) + + except Exception as e: + log.error(f"Failed to list marketplaces: {e}") + if verbose: + log.progress(traceback.format_exc(), symbol="info") + sys.exit(1) + + +@marketplace.command(help="Browse plugins in a marketplace") +@click.argument("name", required=True) +@click.option("--verbose", "-v", is_flag=True, help="Show detailed output") +def browse(name, verbose): + """Show available plugins in a marketplace.""" + from ...core.command_logger import CommandLogger + + log = 
CommandLogger("marketplace-browse", verbose=verbose) + try: + from ...marketplace.client import fetch_marketplace + from ...marketplace.registry import get_marketplace_by_name + + source = get_marketplace_by_name(name) + log.start(f"Fetching plugins from '{name}'...", symbol="search") + + manifest = fetch_marketplace(source, force_refresh=True) + + if not manifest.plugins: + log.warning(f"Marketplace '{name}' has no plugins") + return + + console = _mkt_get_console() + if not console: + log.success(f"{len(manifest.plugins)} plugin(s) in '{name}':", symbol="check") + for p in manifest.plugins: + desc = f" -- {p.description}" if p.description else "" + log.tree_item(f" {p.name}{desc}") + log.progress(f"Install: apm install @{name}", symbol="info") + return + + from rich.table import Table + + table = Table( + title=f"Plugins in '{name}'", + show_header=True, + header_style="bold cyan", + border_style="cyan", + ) + table.add_column("Plugin", style="bold white", no_wrap=True) + table.add_column("Description", style="white", ratio=1) + table.add_column("Version", style="cyan", justify="center") + table.add_column("Install", style="green") + + for p in manifest.plugins: + desc = p.description or "--" + ver = p.version or "--" + table.add_row(p.name, desc, ver, f"{p.name}@{name}") + + console.print() + console.print(table) + log.progress( + f"Install a plugin: apm install @{name}", + symbol="info", + ) + + except Exception as e: + log.error(f"Failed to browse marketplace: {e}") + if verbose: + log.progress(traceback.format_exc(), symbol="info") + sys.exit(1) + + +@marketplace.command(help="Refresh marketplace cache") +@click.argument("name", required=False, default=None) +@click.option("--verbose", "-v", is_flag=True, help="Show detailed output") +def update(name, verbose): + """Refresh cached marketplace data (one or all).""" + from ...core.command_logger import CommandLogger + + log = CommandLogger("marketplace-update", verbose=verbose) + try: + from 
...marketplace.client import clear_marketplace_cache, fetch_marketplace + from ...marketplace.registry import ( + get_marketplace_by_name, + get_registered_marketplaces, + ) + + if name: + source = get_marketplace_by_name(name) + log.start(f"Refreshing marketplace '{name}'...", symbol="gear") + clear_marketplace_cache(name, host=source.host) + manifest = fetch_marketplace(source, force_refresh=True) + log.success( + f"Marketplace '{name}' updated ({len(manifest.plugins)} plugins)", + symbol="check", + ) + else: + sources = get_registered_marketplaces() + if not sources: + log.progress("No marketplaces registered.", symbol="info") + return + log.start(f"Refreshing {len(sources)} marketplace(s)...", symbol="gear") + for s in sources: + try: + clear_marketplace_cache(s.name, host=s.host) + manifest = fetch_marketplace(s, force_refresh=True) + log.tree_item(f" {s.name} ({len(manifest.plugins)} plugins)") + except Exception as exc: + log.warning(f" {s.name}: {exc}") + if verbose: + log.progress(traceback.format_exc(), symbol="info") + log.success("Marketplace cache refreshed", symbol="check") + + except Exception as e: + log.error(f"Failed to update marketplace: {e}") + if verbose: + log.progress(traceback.format_exc(), symbol="info") + sys.exit(1) + + +@marketplace.command(help="Remove a registered marketplace") +@click.argument("name", required=True) +@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt") +@click.option("--verbose", "-v", is_flag=True, help="Show detailed output") +def remove(name, yes, verbose): + """Unregister a marketplace.""" + from ...core.command_logger import CommandLogger + + log = CommandLogger("marketplace-remove", verbose=verbose) + try: + from ...marketplace.client import clear_marketplace_cache + from ...marketplace.registry import get_marketplace_by_name, remove_marketplace + + source = get_marketplace_by_name(name) + + if not yes: + if not _mkt_is_interactive(): + log.error( + "Use --yes to skip confirmation in 
non-interactive mode", + symbol="error", + ) + sys.exit(1) + confirmed = click.confirm( + f"Remove marketplace '{source.name}' ({source.display_source})?", + default=False, + ) + if not confirmed: + log.progress("Cancelled", symbol="info") + return + + remove_marketplace(name) + clear_marketplace_cache(name, host=source.host) + log.success(f"Marketplace '{name}' removed", symbol="check") + + except Exception as e: + log.error(f"Failed to remove marketplace: {e}") + if verbose: + log.progress(traceback.format_exc(), symbol="info") + sys.exit(1) diff --git a/src/apm_cli/commands/marketplace/_search_cmd.py b/src/apm_cli/commands/marketplace/_search_cmd.py new file mode 100644 index 000000000..490f78b42 --- /dev/null +++ b/src/apm_cli/commands/marketplace/_search_cmd.py @@ -0,0 +1,119 @@ +"""``apm search`` (standalone) command implementation. + +Extracted from ``marketplace/__init__.py`` to keep that module under 800 lines. +Re-exported from the package ``__init__`` so ``from apm_cli.commands.marketplace +import search`` keeps working. +""" + +from __future__ import annotations + +import sys +import traceback + +import click + +from ...core.command_logger import CommandLogger +from ...marketplace.errors import MarketplaceNotFoundError + + +def _mkt_get_console(): + """Route to ``marketplace._get_console`` so test patches apply.""" + from apm_cli.commands import marketplace as _m + + return _m._get_console() + + +@click.command( + name="search", + help="Search plugins in a marketplace (QUERY@MARKETPLACE)", +) +@click.argument("expression", required=True, metavar="QUERY@MARKETPLACE") +@click.option("--limit", default=20, show_default=True, help="Max results to show") +@click.option("--verbose", "-v", is_flag=True, help="Show detailed output") +def search(expression, limit, verbose): + """Search for plugins in a specific marketplace. 
+ + Use QUERY@MARKETPLACE format, e.g.: apm marketplace search security@skills + """ + logger = CommandLogger("marketplace-search", verbose=verbose) + try: + from ...marketplace.client import search_marketplace + from ...marketplace.registry import get_marketplace_by_name + + if "@" not in expression: + logger.error( + f"Invalid format: '{expression}'. " + "Use QUERY@MARKETPLACE, e.g.: apm marketplace search security@skills" + ) + sys.exit(1) + + query, marketplace_name = expression.rsplit("@", 1) + if not query or not marketplace_name: + logger.error( + "Both QUERY and MARKETPLACE are required. " + "Use QUERY@MARKETPLACE, e.g.: apm marketplace search security@skills" + ) + sys.exit(1) + + try: + source = get_marketplace_by_name(marketplace_name) + except MarketplaceNotFoundError: + logger.error( + f"Marketplace '{marketplace_name}' is not registered. " + "Use 'apm marketplace list' to see registered marketplaces." + ) + sys.exit(1) + + logger.start(f"Searching '{marketplace_name}' for '{query}'...", symbol="search") + results = search_marketplace(query, source)[:limit] + + if not results: + logger.warning( + f"No plugins found matching '{query}' in '{marketplace_name}'. " + f"Try 'apm marketplace browse {marketplace_name}' to see all plugins." 
+ ) + return + + console = _mkt_get_console() + if not console: + logger.success(f"Found {len(results)} plugin(s):", symbol="check") + for p in results: + desc = f" -- {p.description}" if p.description else "" + logger.tree_item(f" {p.name}@{marketplace_name}{desc}") + logger.progress( + f"Install: apm install @{marketplace_name}", + symbol="info", + ) + return + + from rich.table import Table + + table = Table( + title=f"Search Results: '{query}' in {marketplace_name}", + show_header=True, + header_style="bold cyan", + border_style="cyan", + ) + table.add_column("Plugin", style="bold white", no_wrap=True) + table.add_column("Description", style="white", ratio=1) + table.add_column("Install", style="green") + + for p in results: + desc = p.description or "--" + if len(desc) > 60: + desc = desc[:57] + "..." + table.add_row(p.name, desc, f"{p.name}@{marketplace_name}") + + console.print() + console.print(table) + logger.progress( + f"Install: apm install @{marketplace_name}", + symbol="info", + ) + + except SystemExit: + raise + except Exception as e: + logger.error(f"Search failed: {e}") + logger.verbose_detail(traceback.format_exc()) + sys.exit(1) diff --git a/src/apm_cli/commands/marketplace/_table_ops.py b/src/apm_cli/commands/marketplace/_table_ops.py new file mode 100644 index 000000000..8c29e5ac9 --- /dev/null +++ b/src/apm_cli/commands/marketplace/_table_ops.py @@ -0,0 +1,350 @@ +"""Table-rendering helpers and data containers for the marketplace commands. + +Extracted from ``marketplace/__init__.py`` to keep that module under 800 lines. +All names are re-exported from the package ``__init__`` so existing import +paths (``from apm_cli.commands.marketplace import _CheckResult``, etc.) keep +working. 
+""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) + + +def _mkt_get_console(): + """Route to ``marketplace._get_console`` so test patches apply.""" + from apm_cli.commands import marketplace as _m + + return _m._get_console() + + +# --------------------------------------------------------------------------- +# Build-error rendering +# --------------------------------------------------------------------------- + + +def _render_build_error(log, exc): + """Render a BuildError with actionable hints.""" + from ...marketplace.errors import ( + GitLsRemoteError, + HeadNotAllowedError, + NoMatchingVersionError, + OfflineMissError, + RefNotFoundError, + ) + + if isinstance(exc, GitLsRemoteError): + log.error(exc.summary_text, symbol="error") + if exc.hint: + log.progress(f"Hint: {exc.hint}", symbol="info") + elif isinstance(exc, NoMatchingVersionError): + log.error(str(exc), symbol="error") + log.progress( + "Check that your version range matches published tags.", + symbol="info", + ) + elif isinstance(exc, RefNotFoundError): + log.error(str(exc), symbol="error") + log.progress( + "Verify the ref is spelled correctly and the remote is reachable.", + symbol="info", + ) + elif isinstance(exc, HeadNotAllowedError): + log.error(str(exc), symbol="error") + elif isinstance(exc, OfflineMissError): + log.error(str(exc), symbol="error") + log.progress( + "Run a build online first to populate the cache.", + symbol="info", + ) + else: + log.error(f"Build failed: {exc}", symbol="error") + + +def _render_build_table(log, report): + """Render the resolved-packages table (Rich with colorama fallback).""" + from ...marketplace.semver import parse_semver + + console = _mkt_get_console() + if not console: + for pkg in report.resolved: + sha_short = pkg.sha[:8] if pkg.sha else "--" + ref_kind = "tag" if not pkg.ref.startswith("refs/heads/") else "branch" + log.tree_item(f" [+] {pkg.name} {pkg.ref} 
{sha_short} ({ref_kind})") + return + + from rich.table import Table + from rich.text import Text + + table = Table( + title="Resolved Packages", + show_header=True, + header_style="bold cyan", + border_style="cyan", + ) + table.add_column("Status", style="green", no_wrap=True, width=6) + table.add_column("Package", style="bold white", no_wrap=True) + table.add_column("Version", style="cyan") + table.add_column("Commit", style="dim") + table.add_column("Ref Kind", style="white") + + for pkg in report.resolved: + sha_short = pkg.sha[:8] if pkg.sha else "--" + ref_kind = "tag" + if pkg.ref and not parse_semver(pkg.ref.lstrip("vV")): + ref_kind = "ref" + table.add_row(Text("[+]"), pkg.name, pkg.ref, sha_short, ref_kind) + + console.print() + console.print(table) + + +# --------------------------------------------------------------------------- +# Outdated-packages helpers +# --------------------------------------------------------------------------- + + +class _OutdatedRow: + """Simple container for outdated table row data.""" + + __slots__ = ( + "current", + "latest_in_range", + "latest_overall", + "name", + "note", + "range_spec", + "status", + ) + + def __init__(self, name, current, range_spec, latest_in_range, latest_overall, status, note): + self.name = name + self.current = current + self.range_spec = range_spec + self.latest_in_range = latest_in_range + self.latest_overall = latest_overall + self.status = status + self.note = note + + +def _load_current_versions(): + """Load current ref versions from marketplace.json if present.""" + mkt_path = Path.cwd() / "marketplace.json" + if not mkt_path.exists(): + return {} + try: + data = json.loads(mkt_path.read_text(encoding="utf-8")) + result = {} + for plugin in data.get("plugins", []): + name = plugin.get("name", "") + src = plugin.get("source", {}) + if isinstance(src, dict): + result[name] = src.get("ref", "--") + return result + except (json.JSONDecodeError, OSError): + return {} + + +def 
_extract_tag_versions(refs, entry, yml, include_prerelease): + """Extract (SemVer, tag_name) pairs from remote refs for a package entry.""" + from ...marketplace._shared import iter_semver_tags + from ...marketplace.tag_pattern import ( + build_tag_regex, + infer_tag_pattern_from_refs, + ) + + def _collect(pattern: str) -> list: + tag_rx = ( + build_tag_regex(pattern, name=entry.name) + if "{name}" in pattern + else build_tag_regex(pattern) + ) + collected = [] + for sv, tag_name, _ in iter_semver_tags(refs, tag_rx): + if sv.is_prerelease and not (include_prerelease or entry.include_prerelease): + continue + collected.append((sv, tag_name)) + return collected + + pattern = entry.tag_pattern or yml.build.tag_pattern + results = _collect(pattern) + if not results: + inferred = infer_tag_pattern_from_refs(refs, entry.name) + if inferred and inferred != pattern: + logger.debug( + "Configured tag pattern %r matched no tags for %s; inferred %r", + pattern, + entry.name, + inferred, + ) + results = _collect(inferred) + return results + + +def _render_outdated_table(log, rows): + """Render the outdated-packages table.""" + console = _mkt_get_console() + if not console: + for row in rows: + note = f" ({row.note})" if row.note else "" + log.tree_item( + f" {row.status} {row.name} current={row.current} " + f"latest-in-range={row.latest_in_range} " + f"latest={row.latest_overall}{note}" + ) + return + + from rich.table import Table + from rich.text import Text + + table = Table( + title="Package Version Status", + show_header=True, + header_style="bold cyan", + border_style="cyan", + ) + table.add_column("Status", style="green", no_wrap=True, width=6) + table.add_column("Package", style="bold white", no_wrap=True) + table.add_column("Current", style="white") + table.add_column("Range", style="dim") + table.add_column("Latest in Range", style="cyan") + table.add_column("Latest Overall", style="yellow") + + for row in rows: + note = "" + if row.note: + note = f" ({row.note})" + 
table.add_row( + Text(row.status), + row.name, + row.current, + row.range_spec, + row.latest_in_range + note, + row.latest_overall, + ) + + console.print() + console.print(table) + + +# --------------------------------------------------------------------------- +# Check-results helpers +# --------------------------------------------------------------------------- + + +class _CheckResult: + """Container for per-entry check results.""" + + __slots__ = ("error", "name", "reachable", "ref_ok", "version_found") + + def __init__(self, name, reachable, version_found, ref_ok, error): + self.name = name + self.reachable = reachable + self.version_found = version_found + self.ref_ok = ref_ok + self.error = error + + +def _render_check_table(log, results): + """Render the check-results table.""" + console = _mkt_get_console() + if not console: + for r in results: + icon = "[+]" if r.ref_ok else "[x]" + detail = r.error if r.error else "OK" + log.tree_item(f" {icon} {r.name}: {detail}") + return + + from rich.table import Table + from rich.text import Text + + table = Table( + title="Entry Health Check", + show_header=True, + header_style="bold cyan", + border_style="cyan", + ) + table.add_column("Status", no_wrap=True, width=6) + table.add_column("Package", style="bold white", no_wrap=True) + table.add_column("Reachable", style="white", justify="center") + table.add_column("Version Found", style="white", justify="center") + table.add_column("Ref OK", style="white", justify="center") + table.add_column("Detail", style="dim") + + for r in results: + reach = "[+]" if r.reachable else "[x]" + ver = "[+]" if r.version_found else "[x]" + ref = "[+]" if r.ref_ok else "[x]" + detail = r.error if r.error else "OK" + table.add_row( + Text("[+]" if r.ref_ok else "[x]"), + r.name, + Text(reach), + Text(ver), + Text(ref), + detail, + ) + + console.print() + console.print(table) + + +# --------------------------------------------------------------------------- +# Doctor-check helpers +# 
--------------------------------------------------------------------------- + + +class _DoctorCheck: + """Container for a single doctor check result.""" + + __slots__ = ("detail", "informational", "name", "passed") + + def __init__(self, name, passed, detail, informational=False): + self.name = name + self.passed = passed + self.detail = detail + self.informational = informational + + +def _render_doctor_table(log, checks): + """Render the doctor results table.""" + console = _mkt_get_console() + if not console: + for c in checks: + if c.informational: + icon = "[i]" + elif c.passed: + icon = "[+]" + else: + icon = "[x]" + log.tree_item(f" {icon} {c.name}: {c.detail}") + return + + from rich.table import Table + from rich.text import Text + + table = Table( + title="Environment Diagnostics", + show_header=True, + header_style="bold cyan", + border_style="cyan", + ) + table.add_column("Check", style="bold white", no_wrap=True) + table.add_column("Status", no_wrap=True, width=6) + table.add_column("Detail", style="white") + + for c in checks: + if c.informational: + icon = "[i]" + elif c.passed: + icon = "[+]" + else: + icon = "[x]" + table.add_row(c.name, Text(icon), c.detail) + + console.print() + console.print(table) diff --git a/src/apm_cli/commands/marketplace/doctor.py b/src/apm_cli/commands/marketplace/doctor.py index 918e70bd4..58e1f4913 100644 --- a/src/apm_cli/commands/marketplace/doctor.py +++ b/src/apm_cli/commands/marketplace/doctor.py @@ -24,18 +24,13 @@ marketplace, ) +# --------------------------------------------------------------------------- +# Individual check helpers +# --------------------------------------------------------------------------- -def run_doctor(verbose: bool, *, logger_name: str = "doctor") -> int: - """Execute the doctor diagnostics and return an exit code. - - Shared between the top-level ``apm doctor`` command and the legacy - ``apm marketplace doctor`` alias so both surfaces produce identical - output. 
Returns ``0`` if all critical checks pass, ``1`` otherwise. - """ - logger = CommandLogger(logger_name, verbose=verbose) - checks = [] - # Check 1: git on PATH +def _check_git() -> _DoctorCheck: + """Check 1: git is available on PATH.""" git_ok = False git_detail = "" try: @@ -56,16 +51,11 @@ def run_doctor(verbose: bool, *, logger_name: str = "doctor") -> int: git_detail = "git --version timed out" except (subprocess.SubprocessError, OSError) as exc: git_detail = str(exc)[:60] + return _DoctorCheck(name="git", passed=git_ok, detail=git_detail) - checks.append( - _DoctorCheck( - name="git", - passed=git_ok, - detail=git_detail, - ) - ) - # Check 2: network reachability +def _check_network() -> _DoctorCheck: + """Check 2: github.com is reachable via git ls-remote.""" net_ok = False net_detail = "" try: @@ -92,36 +82,30 @@ def run_doctor(verbose: bool, *, logger_name: str = "doctor") -> int: net_detail = "git not found; cannot test network" except (subprocess.SubprocessError, OSError) as exc: net_detail = str(exc)[:60] + return _DoctorCheck(name="network", passed=net_ok, detail=net_detail) - checks.append( - _DoctorCheck( - name="network", - passed=net_ok, - detail=net_detail, - ) - ) - # Check 3: auth tokens (delegate to AuthResolver for full coverage) +def _check_auth() -> _DoctorCheck: + """Check 3: auth tokens (informational).""" try: from ...core.auth import AuthResolver resolver = AuthResolver() - # Try to get a token for github.com as a representative check token = resolver.resolve("github.com").token has_token = bool(token) except Exception: has_token = False auth_detail = "Token detected" if has_token else "No token; unauthenticated rate limits apply" - checks.append( - _DoctorCheck( - name="auth", - passed=True, # informational; never fails - detail=auth_detail, - informational=True, - ) + return _DoctorCheck( + name="auth", + passed=True, # informational; never fails + detail=auth_detail, + informational=True, ) - # Check 4: gh CLI availability 
(informational; only needed for publish) + +def _check_gh_cli() -> _DoctorCheck: + """Check 4: gh CLI availability (informational; only needed for publish).""" gh_ok = False gh_detail = "" try: @@ -142,18 +126,20 @@ def run_doctor(verbose: bool, *, logger_name: str = "doctor") -> int: gh_detail = "gh --version timed out" except (subprocess.SubprocessError, OSError) as exc: gh_detail = str(exc)[:60] - - checks.append( - _DoctorCheck( - name="gh CLI", - passed=gh_ok, - detail=gh_detail, - informational=True, - ) + return _DoctorCheck( + name="gh CLI", + passed=gh_ok, + detail=gh_detail, + informational=True, ) - # Check 5: marketplace config presence + parsability - project_root = Path.cwd() + +def _check_marketplace_config(project_root: Path) -> tuple[_DoctorCheck, object]: + """Check 5: marketplace config presence + parsability. + + Returns ``(_DoctorCheck, yml_obj)``; ``yml_obj`` is ``None`` when no + config is found or on parse errors. + """ apm_path = project_root / "apm.yml" legacy_path = project_root / "marketplace.yml" yml_obj = None @@ -185,93 +171,117 @@ def run_doctor(verbose: bool, *, logger_name: str = "doctor") -> int: config_passed = False config_detail = str(exc)[:120] - checks.append( - _DoctorCheck( - name="marketplace config", - passed=config_passed, - detail=config_detail, - informational=True, + check = _DoctorCheck( + name="marketplace config", + passed=config_passed, + detail=config_detail, + informational=True, + ) + return check, yml_obj + + +def _check_format_coverage(yml_obj: object) -> _DoctorCheck: + """Check 6: format coverage (informational).""" + configured = frozenset(getattr(yml_obj, "outputs", ()) or ()) + supported = known_output_names() + missing = sorted(supported - configured) + configured_sorted = sorted(configured) + if not missing: + fc_detail = f"Publishing for all known formats: {', '.join(configured_sorted)}." + fc_passed = True + else: + fc_detail = ( + f"Configured: {', '.join(configured_sorted) or '(none)'}. 
" + f"Also supported: {', '.join(missing)}. " + f"Add e.g. '{missing[0]}: {{}}' under 'marketplace.outputs' " + "in apm.yml to publish for more consumers." ) + fc_passed = True # informational; never fails + return _DoctorCheck( + name="format coverage", + passed=fc_passed, + detail=fc_detail, + informational=True, ) - # Check 6: format coverage (informational; only when config is present) - if yml_obj is not None: - configured = frozenset(getattr(yml_obj, "outputs", ()) or ()) - supported = known_output_names() - missing = sorted(supported - configured) - configured_sorted = sorted(configured) - if not missing: - fc_detail = f"Publishing for all known formats: {', '.join(configured_sorted)}." - fc_passed = True - else: - fc_detail = ( - f"Configured: {', '.join(configured_sorted) or '(none)'}. " - f"Also supported: {', '.join(missing)}. " - f"Add e.g. '{missing[0]}: {{}}' under 'marketplace.outputs' " - "in apm.yml to publish for more consumers." - ) - fc_passed = True # informational; never fails - checks.append( - _DoctorCheck( - name="format coverage", - passed=fc_passed, - detail=fc_detail, - informational=True, - ) + +def _check_duplicate_names(yml_obj: object) -> _DoctorCheck: + """Check 7: duplicate package names (informational).""" + dup_detail = _find_duplicate_names(yml_obj) + if dup_detail: + return _DoctorCheck( + name="duplicate names", + passed=False, + detail=dup_detail, + informational=True, ) + return _DoctorCheck( + name="duplicate names", + passed=True, + detail="No duplicate package names", + informational=True, + ) - # Check 7: duplicate package names (defence-in-depth) - if yml_obj is not None: - dup_detail = _find_duplicate_names(yml_obj) - if dup_detail: - checks.append( - _DoctorCheck( - name="duplicate names", - passed=False, - detail=dup_detail, - informational=True, - ) - ) - else: - checks.append( - _DoctorCheck( - name="duplicate names", - passed=True, - detail="No duplicate package names", - informational=True, - ) - ) - # Check 8: 
version alignment (informational; only when config is present) - if yml_obj is not None and hasattr(yml_obj, "versioning"): - from ...marketplace.version_check import check_version_alignment - - va_report = check_version_alignment(yml_obj, Path.cwd()) - total = len(va_report.packages) - aligned = sum(1 for p in va_report.packages if p.ok) - if total == 0: - va_detail = f"strategy={va_report.strategy}, no local packages to align" - va_passed = True - elif va_report.ok: - va_detail = f"strategy={va_report.strategy}, {aligned}/{total} packages aligned" - va_passed = True - else: - misaligned = [p.path for p in va_report.packages if not p.ok] - misaligned_count = len(misaligned) - va_detail = ( - f"strategy={va_report.strategy}, " - f"{misaligned_count}/{total} packages misaligned: " - f"{misaligned[0]}" - ) - va_passed = False - checks.append( - _DoctorCheck( - name="version alignment", - passed=va_passed, - detail=va_detail, - informational=True, - ) +def _check_version_alignment(yml_obj: object) -> _DoctorCheck: + """Check 8: version alignment (informational).""" + from ...marketplace.version_check import check_version_alignment + + va_report = check_version_alignment(yml_obj, Path.cwd()) + total = len(va_report.packages) + aligned = sum(1 for p in va_report.packages if p.ok) + if total == 0: + va_detail = f"strategy={va_report.strategy}, no local packages to align" + va_passed = True + elif va_report.ok: + va_detail = f"strategy={va_report.strategy}, {aligned}/{total} packages aligned" + va_passed = True + else: + misaligned = [p.path for p in va_report.packages if not p.ok] + misaligned_count = len(misaligned) + va_detail = ( + f"strategy={va_report.strategy}, " + f"{misaligned_count}/{total} packages misaligned: " + f"{misaligned[0]}" ) + va_passed = False + return _DoctorCheck( + name="version alignment", + passed=va_passed, + detail=va_detail, + informational=True, + ) + + +# --------------------------------------------------------------------------- +# Main 
entry point +# --------------------------------------------------------------------------- + + +def run_doctor(verbose: bool, *, logger_name: str = "doctor") -> int: + """Execute the doctor diagnostics and return an exit code. + + Shared between the top-level ``apm doctor`` command and the legacy + ``apm marketplace doctor`` alias so both surfaces produce identical + output. Returns ``0`` if all critical checks pass, ``1`` otherwise. + """ + logger = CommandLogger(logger_name, verbose=verbose) + checks = [] + + checks.append(_check_git()) + checks.append(_check_network()) + checks.append(_check_auth()) + checks.append(_check_gh_cli()) + + project_root = Path.cwd() + config_check, yml_obj = _check_marketplace_config(project_root) + checks.append(config_check) + + if yml_obj is not None: + checks.append(_check_format_coverage(yml_obj)) + checks.append(_check_duplicate_names(yml_obj)) + if hasattr(yml_obj, "versioning"): + checks.append(_check_version_alignment(yml_obj)) _render_doctor_table(logger, checks) diff --git a/src/apm_cli/commands/outdated.py b/src/apm_cli/commands/outdated.py index 7c31604a8..95bd1edde 100644 --- a/src/apm_cli/commands/outdated.py +++ b/src/apm_cli/commands/outdated.py @@ -156,19 +156,11 @@ def _check_marketplace_ref(dep, verbose): return None plugin = manifest.find_plugin(dep.marketplace_plugin_name) - if not plugin: - return None - - # Determine marketplace entry's current ref mkt_ref = None - mkt_version = plugin.version or "" - if isinstance(plugin.source, dict): + mkt_version = plugin.version or "" if plugin else "" + if plugin and isinstance(plugin.source, dict): mkt_ref = plugin.source.get("ref", "") - else: - # String sources are relative paths, not refs -- skip - return None - - if not mkt_ref: + if not plugin or not mkt_ref: return None # Determine installed ref @@ -200,6 +192,91 @@ def _check_marketplace_ref(dep, verbose): ) +def _check_tag_result( + package_name, current_ref, package_basename, tag_pattern, remote_refs, verbose 
+): + """Return an OutdatedRow for a tag-pinned dependency.""" + from ..models.dependency.types import GitReferenceType + from ..utils.version_checker import is_newer_version + + tag_refs = [r for r in remote_refs if r.ref_type == GitReferenceType.TAG] + if not tag_refs: + return OutdatedRow( + package=package_name, + current=current_ref, + latest="-", + status="unknown", + source="git tags", + ) + + from ..marketplace.tag_pattern import parse_tag_version + + candidates = _semver_tag_candidates(tag_refs, tag_pattern, package_basename) + if not candidates: + return OutdatedRow( + package=package_name, + current=current_ref, + latest="-", + status="unknown", + source="git tags", + ) + + _, latest_tag = candidates[0] + current_ver = parse_tag_version(current_ref, tag_pattern, name=package_basename) or _strip_v( + current_ref + ) + latest_ver = parse_tag_version(latest_tag, tag_pattern, name=package_basename) or _strip_v( + latest_tag + ) + + if is_newer_version(current_ver, latest_ver): + extra = [name for _, name in candidates[:10]] if verbose else [] + return OutdatedRow( + package=package_name, + current=current_ref, + latest=latest_tag, + status="outdated", + extra_tags=extra, + source="git tags", + ) + return OutdatedRow( + package=package_name, + current=current_ref, + latest=latest_tag, + status="up-to-date", + source="git tags", + ) + + +def _check_branch_result(package_name, current_ref, locked_sha, remote_refs): + """Return an OutdatedRow for a branch-pinned dependency.""" + remote_tip_sha = _find_remote_tip(current_ref, remote_refs) + if not remote_tip_sha: + return OutdatedRow( + package=package_name, + current=current_ref or "(none)", + latest="-", + status="unknown", + source="git branch", + ) + display_ref = current_ref or "(default)" + if locked_sha and locked_sha != remote_tip_sha: + return OutdatedRow( + package=package_name, + current=display_ref, + latest=remote_tip_sha[:8], + status="outdated", + source="git branch", + ) + return OutdatedRow( + 
package=package_name, + current=display_ref, + latest=remote_tip_sha[:8], + status="up-to-date", + source="git branch", + ) + + def _check_one_dep(dep, downloader, verbose, registry_ctx=None): """Check a single dependency against remote refs. @@ -218,8 +295,6 @@ def _check_one_dep(dep, downloader, verbose, registry_ctx=None): return marketplace_result from ..models.dependency.reference import DependencyReference - from ..models.dependency.types import GitReferenceType - from ..utils.version_checker import is_newer_version current_ref = dep.resolved_ref or "" locked_sha = dep.resolved_commit or "" @@ -227,7 +302,6 @@ def _check_one_dep(dep, downloader, verbose, registry_ctx=None): # Build a DependencyReference to query remote refs try: - # Use parse() to correctly handle all host types (GitHub, ADO, etc.) full_url = f"{dep.host}/{dep.repo_url}" if dep.host else dep.repo_url dep_ref = DependencyReference.parse(full_url) except Exception: @@ -245,87 +319,12 @@ def _check_one_dep(dep, downloader, verbose, registry_ctx=None): package_basename = _package_basename(dep) tag_pattern = _resolve_tag_pattern(current_ref, package_basename) - is_tag = tag_pattern is not None - - if is_tag: - tag_refs = [r for r in remote_refs if r.ref_type == GitReferenceType.TAG] - if not tag_refs: - return OutdatedRow( - package=package_name, - current=current_ref, - latest="-", - status="unknown", - source="git tags", - ) - from ..marketplace.tag_pattern import parse_tag_version - - candidates = _semver_tag_candidates(tag_refs, tag_pattern, package_basename) - if not candidates: - return OutdatedRow( - package=package_name, - current=current_ref, - latest="-", - status="unknown", - source="git tags", - ) - - _, latest_tag = candidates[0] - current_ver = parse_tag_version( - current_ref, tag_pattern, name=package_basename - ) or _strip_v(current_ref) - latest_ver = parse_tag_version(latest_tag, tag_pattern, name=package_basename) or _strip_v( - latest_tag + if tag_pattern is not None: + return 
_check_tag_result( + package_name, current_ref, package_basename, tag_pattern, remote_refs, verbose ) - - if is_newer_version(current_ver, latest_ver): - extra = [name for _, name in candidates[:10]] if verbose else [] - return OutdatedRow( - package=package_name, - current=current_ref, - latest=latest_tag, - status="outdated", - extra_tags=extra, - source="git tags", - ) - else: - return OutdatedRow( - package=package_name, - current=current_ref, - latest=latest_tag, - status="up-to-date", - source="git tags", - ) - else: - remote_tip_sha = _find_remote_tip(current_ref, remote_refs) - - if not remote_tip_sha: - return OutdatedRow( - package=package_name, - current=current_ref or "(none)", - latest="-", - status="unknown", - source="git branch", - ) - - display_ref = current_ref or "(default)" - if locked_sha and locked_sha != remote_tip_sha: - latest_display = remote_tip_sha[:8] - return OutdatedRow( - package=package_name, - current=display_ref, - latest=latest_display, - status="outdated", - source="git branch", - ) - else: - return OutdatedRow( - package=package_name, - current=display_ref, - latest=remote_tip_sha[:8], - status="up-to-date", - source="git branch", - ) + return _check_branch_result(package_name, current_ref, locked_sha, remote_refs) @click.command(name="outdated", help="Show outdated locked dependencies") diff --git a/src/apm_cli/commands/pack.py b/src/apm_cli/commands/pack.py index 47d8663d6..48cdac56f 100644 --- a/src/apm_cli/commands/pack.py +++ b/src/apm_cli/commands/pack.py @@ -16,6 +16,8 @@ from ..core.command_logger import CommandLogger from ..core.target_detection import TargetParamType from ..utils.console import set_console_stderr +from ._pack_ops import _emit_drift_recipe as _emit_drift_recipe +from ._pack_ops import _run_release_gates MARKETPLACE_DOCS_URL = ( "https://microsoft.github.io/apm/producer/publish-to-a-marketplace/#consume-from-any-assistant" @@ -348,110 +350,18 @@ def pack_cmd( # noqa: PLR0913 -- Click handler, one param 
per CLI option drift_gate_failed = False if check_versions or check_clean: - from ..marketplace.builder import BuildOptions as MktBuildOptions - from ..marketplace.builder import MarketplaceBuilder - from ..marketplace.drift_check import check_marketplace_drift, render_diff_lines - from ..marketplace.migration import ( - ConfigSource, - detect_config_source, + ( + version_gate_failed, + drift_gate_failed, + version_alignment_payload, + drift_payload, + gate_errors, + ) = _run_release_gates( + ctx, options, check_versions, check_clean, json_output, logger, project_root ) - from ..marketplace.version_check import check_version_alignment - from ..marketplace.yml_schema import MarketplaceYmlError - - # Try to load the marketplace config; if absent, skip both gates with [i]. - gate_config = None - try: - source = detect_config_source(project_root) - if source != ConfigSource.NONE: - from ..marketplace.migration import load_marketplace_config - - gate_config = load_marketplace_config(project_root) - except MarketplaceYmlError as exc: - _emit_json_error_or_raise(ctx, json_output, "build_error", str(exc)) - return - - if gate_config is None: - if check_versions: - logger.info( - "Version alignment check skipped: no marketplace block; nothing to check." - ) - if check_clean: - logger.info( - "Marketplace drift check skipped: no marketplace block; nothing to check." 
- ) - else: - if check_versions: - v_report = check_version_alignment(gate_config, project_root) - version_alignment_payload = v_report.to_json_dict() - if v_report.ok: - if not json_output: - if v_report.expected is not None: - logger.success( - f"Version alignment OK [strategy={v_report.strategy}, " - f"expected={v_report.expected}]" - ) - else: - logger.success(f"Version alignment OK [strategy={v_report.strategy}]") - for row in v_report.packages: - tag_str = f" -> tag {row.rendered_tag}" if row.rendered_tag else "" - logger.info(f" {row.path} {row.version}{tag_str} [{row.reason}]") - else: - version_gate_failed = True - if not json_output: - if v_report.expected is not None: - logger.error( - f"Version alignment failed [strategy={v_report.strategy}, " - f"expected={v_report.expected}]" - ) - else: - logger.error(f"Version alignment failed [strategy={v_report.strategy}]") - for row in v_report.packages: - tag_str = f" -> tag {row.rendered_tag}" if row.rendered_tag else "" - version_str = row.version if row.version is not None else "" - logger.info(f" {row.path} {version_str}{tag_str} [{row.reason}]") - for msg in v_report.error_messages(): - gate_errors.append({"code": "version_misaligned", "message": msg}) - - if check_clean: - # Use a builder with dry_run=True so the gate itself - # never mutates the working tree. 
- mkt_opts = MktBuildOptions( - dry_run=True, - offline=options.marketplace_offline, - include_prerelease=options.marketplace_include_prerelease, - ) - drift_builder = MarketplaceBuilder.from_config( - gate_config, project_root=project_root, options=mkt_opts - ) - d_report = check_marketplace_drift(drift_builder, gate_config, project_root) - drift_payload = d_report.to_json_dict() - if d_report.ok: - if not json_output: - formats = ", ".join(o.format for o in d_report.outputs) - logger.success(f"Marketplace working tree clean [outputs={formats}]") - for out in d_report.outputs: - logger.info(f" {out.path} [unchanged]") - else: - drift_gate_failed = True - if not json_output: - dirty_formats = ", ".join( - o.format for o in d_report.outputs if o.status != "unchanged" - ) - logger.error(f"Marketplace working tree dirty [outputs={dirty_formats}]") - for out in d_report.outputs: - if out.status == "unchanged": - logger.info(f" {out.path} [unchanged]") - elif out.status == "missing": - logger.info(f" {out.path} [missing on disk; would be created]") - _emit_drift_recipe(logger, out.path) - else: - count = len(out.differences) - logger.info(f" {out.path} [drift: {count} differences]") - for line in render_diff_lines(out): - logger.info(line) - _emit_drift_recipe(logger, out.path) - for msg in d_report.error_messages(): - gate_errors.append({"code": "marketplace_drift", "message": msg}) + if not version_gate_failed and not drift_gate_failed and not gate_errors: + # _run_release_gates may have called ctx.exit() on schema errors + pass # -- JSON output mode: consistent envelope -- if json_output: @@ -496,30 +406,6 @@ def pack_cmd( # noqa: PLR0913 -- Click handler, one param per CLI option ctx.exit(4) -def _emit_drift_recipe(logger, out_path: str) -> None: - """Emit the canonical recovery recipe when marketplace.json drift is detected. - - Teaches producers the amend+force-with-lease pattern so they can fix the - drift without a noisy follow-up commit. 
- """ - logger.info("") - logger.info(" To recover cleanly (fold into the current commit):") - logger.info("") - logger.info(" apm pack # regenerate locally") - logger.info(f" git add -- {out_path}") - logger.info(" git commit --amend --no-edit # fold into the current commit") - logger.info(" git push --force-with-lease # safe re-push") - logger.info("") - logger.info(" Or as a follow-up commit:") - logger.info("") - logger.info(f" apm pack && git add -- {out_path}") - logger.info(" git commit -m 'chore(marketplace): regen'") - logger.info("") - logger.info(" Why this exists: marketplace.json is checked in (lockfile pattern)") - logger.info(" so consumers can resolve packages without running 'apm pack'. CI") - logger.info(" enforces that the checked-in copy matches the apm.yml source of truth.") - - def _render_bundle_result(logger, pack_result, fmt, target, dry_run): """Mirror the legacy ``apm pack`` output for the bundle producer.""" if pack_result is None: diff --git a/src/apm_cli/commands/uninstall/cli.py b/src/apm_cli/commands/uninstall/cli.py index cff62d270..d8418de8e 100644 --- a/src/apm_cli/commands/uninstall/cli.py +++ b/src/apm_cli/commands/uninstall/cli.py @@ -20,6 +20,58 @@ ) +def _collect_deployed_files(packages_to_remove, actual_orphans, lockfile): + """Collect deployed files for removed packages before lockfile mutation.""" + from ...integration.base_integrator import BaseIntegrator + + removed_keys = builtins.set() + for pkg in packages_to_remove: + try: + ref = _parse_dependency_entry(pkg) + removed_keys.add(ref.get_unique_key()) + except (ValueError, TypeError, AttributeError, KeyError): + removed_keys.add(pkg) + removed_keys.update(actual_orphans) + all_deployed_files = builtins.set() + if lockfile: + for dep_key, dep in lockfile.dependencies.items(): + if dep_key in removed_keys: + all_deployed_files.update(dep.deployed_files) + return BaseIntegrator.normalize_managed_files(all_deployed_files) or builtins.set() + + +def 
_update_lockfile_after_remove( + lockfile, packages_to_remove, actual_orphans, lockfile_path, logger +): + """Update or delete the lockfile after package removal.""" + if not lockfile: + return + lockfile_updated = False + for pkg in packages_to_remove: + try: + ref = _parse_dependency_entry(pkg) + key = ref.get_unique_key() + except (ValueError, TypeError, AttributeError, KeyError): + key = pkg + if key in lockfile.dependencies: + del lockfile.dependencies[key] + lockfile_updated = True + for orphan_key in actual_orphans: + if orphan_key in lockfile.dependencies: + del lockfile.dependencies[orphan_key] + lockfile_updated = True + if lockfile_updated: + try: + if lockfile.dependencies: + lockfile.write(lockfile_path) + else: + lockfile_path.unlink(missing_ok=True) + except Exception: + logger.warning( + "Failed to update lockfile -- it may be out of sync with uninstalled packages." + ) + + @click.command(help="Remove APM packages, their integrated files, and apm.yml entries") @click.argument("packages", nargs=-1, required=True) @click.option("--dry-run", is_flag=True, help="Show what would be removed without removing") @@ -176,51 +228,12 @@ def uninstall(ctx, packages, dry_run, verbose, global_): removed_from_modules += orphan_removed # Step 7: Collect deployed files for removed packages (before lockfile mutation) - from ...integration.base_integrator import BaseIntegrator - - removed_keys = builtins.set() - for pkg in packages_to_remove: - try: - ref = _parse_dependency_entry(pkg) - removed_keys.add(ref.get_unique_key()) - except (ValueError, TypeError, AttributeError, KeyError): - removed_keys.add(pkg) - removed_keys.update(actual_orphans) - all_deployed_files = builtins.set() - if lockfile: - for dep_key, dep in lockfile.dependencies.items(): - if dep_key in removed_keys: - all_deployed_files.update(dep.deployed_files) - all_deployed_files = ( - BaseIntegrator.normalize_managed_files(all_deployed_files) or builtins.set() - ) + all_deployed_files = 
_collect_deployed_files(packages_to_remove, actual_orphans, lockfile) # Step 8: Update lockfile - if lockfile: - lockfile_updated = False - for pkg in packages_to_remove: - try: - ref = _parse_dependency_entry(pkg) - key = ref.get_unique_key() - except (ValueError, TypeError, AttributeError, KeyError): - key = pkg - if key in lockfile.dependencies: - del lockfile.dependencies[key] - lockfile_updated = True - for orphan_key in actual_orphans: - if orphan_key in lockfile.dependencies: - del lockfile.dependencies[orphan_key] - lockfile_updated = True - if lockfile_updated: - try: - if lockfile.dependencies: - lockfile.write(lockfile_path) - else: - lockfile_path.unlink(missing_ok=True) - except Exception: - logger.warning( - "Failed to update lockfile -- it may be out of sync with uninstalled packages." - ) + _update_lockfile_after_remove( + lockfile, packages_to_remove, actual_orphans, lockfile_path, logger + ) # Step 9: Sync integrations cleaned = { From 1a7b6eefeb7b1139b4df05cec968f0f47252d061 Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 06:57:14 +0200 Subject: [PATCH 17/21] refactor(models,compilation): split under 800-line guardrail (#1078) Strangler Stage 2, Commit 5 of 8. Drive the models/dependency and compilation subsystems under the 800-line file guardrail and clear their second-tier complexity offenders, preserving behaviour and the monkeypatch/import surface via reusable mixin classes. models/dependency/reference.py 1985 -> 702 (largest file in the repo): - DependencyReference now composes three mixins -- _ReferenceParseMixin (_reference_parse.py), _ReferenceUrlMixin (_reference_url.py), _ReferenceShorthandMixin (_reference_shorthand.py). The dataclass and composed symbol stay defined in reference.py so the patch surface is unchanged. 
- Leaf helpers/constants moved to _reference_util.py and re-exported from reference.py; this breaks the import cycle (mixins import the util, never the parent at module scope; back-references are TYPE_CHECKING-only). - parse_from_dict (C901) decomposed into shape-routed sub-parsers with two shared validators that genuinely de-duplicate the repeated alias/ref grammar. - models/validation.py: validate_apm_package PLR0911 9 -> 4 returns via a type-dispatch helper. compilation: - agents_compiler.py 1498 -> 785 (+_agents_emit, _agents_output mixins) - context_optimizer.py 1328 -> 548 (+_placement_solver, _pattern_matcher) - distributed_compiler.py 804 -> 658 (+_distributed_orphans) - link_resolver.py PLR0911 10 -> 8 returns via merged guards - moved methods that reference patched module globals (resolve_markdown_links, _logger, Path) route back through the origin module at call time so existing monkeypatches still apply. File-length backlog 19 -> 15. Complexity gate clean at final Stage-2 thresholds; thresholds flip in the final enforcement commit. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/compilation/_agents_emit.py | 534 +++++++ src/apm_cli/compilation/_agents_output.py | 261 ++++ .../compilation/_distributed_orphans.py | 164 ++ src/apm_cli/compilation/_pattern_matcher.py | 357 +++++ src/apm_cli/compilation/_placement_solver.py | 482 ++++++ src/apm_cli/compilation/agents_compiler.py | 719 +-------- src/apm_cli/compilation/context_optimizer.py | 786 +--------- .../compilation/distributed_compiler.py | 152 +- src/apm_cli/compilation/link_resolver.py | 8 +- .../models/dependency/_reference_parse.py | 504 ++++++ .../models/dependency/_reference_shorthand.py | 375 +++++ .../models/dependency/_reference_url.py | 550 +++++++ .../models/dependency/_reference_util.py | 53 + src/apm_cli/models/dependency/reference.py | 1377 +---------------- src/apm_cli/models/validation.py | 58 +- 15 files changed, 3373 insertions(+), 3007 deletions(-) create mode 100644 src/apm_cli/compilation/_agents_emit.py create mode 100644 src/apm_cli/compilation/_agents_output.py create mode 100644 src/apm_cli/compilation/_distributed_orphans.py create mode 100644 src/apm_cli/compilation/_pattern_matcher.py create mode 100644 src/apm_cli/compilation/_placement_solver.py create mode 100644 src/apm_cli/models/dependency/_reference_parse.py create mode 100644 src/apm_cli/models/dependency/_reference_shorthand.py create mode 100644 src/apm_cli/models/dependency/_reference_url.py create mode 100644 src/apm_cli/models/dependency/_reference_util.py diff --git a/src/apm_cli/compilation/_agents_emit.py b/src/apm_cli/compilation/_agents_emit.py new file mode 100644 index 000000000..b3baed4be --- /dev/null +++ b/src/apm_cli/compilation/_agents_emit.py @@ -0,0 +1,534 @@ +"""Mixin: emit methods for CLAUDE.md, GEMINI.md, and Copilot root instructions. + +Extracted from agents_compiler.AgentsCompiler to stay under the 800-line +guardrail (Strangler Stage 2 / issue #1078). 
+ +Rule B routing +-------------- +Three module-level names in agents_compiler are patched by tests: + ``resolve_markdown_links``, ``discover_primitives``, ``_logger``. +Any moved method that references them does so via a **function-level** late +import so the mock installed by the test suite is picked up at call time: + + from apm_cli.compilation import agents_compiler as _ac + _ac.resolve_markdown_links(...) + _ac._logger.debug(...) +""" + +from __future__ import annotations + +import hashlib +from typing import TYPE_CHECKING + +from ..primitives.models import PrimitiveCollection +from ..utils.paths import portable_relpath +from ..version import get_version +from .constants import BUILD_ID_PLACEHOLDER + +if TYPE_CHECKING: + from .agents_compiler import CompilationConfig, CompilationResult + + +class _AgentsEmitMixin: + """Mixin: CLAUDE.md / GEMINI.md / copilot-instructions.md emit methods.""" + + # ------------------------------------------------------------------ # + # CLAUDE.md compilation # + # ------------------------------------------------------------------ # + + def _compile_claude_md( + self, + config: CompilationConfig, + primitives: PrimitiveCollection, + ) -> CompilationResult: + """Compile CLAUDE.md files (Claude Code target).""" + from apm_cli.compilation.agents_compiler import CompilationResult + + errors = self.validate_primitives(primitives) + self.errors.extend(errors) + + # Create Claude formatter + from .claude_formatter import ClaudeFormatter + + claude_formatter = ClaudeFormatter(str(self.base_dir), source_dir=str(self.source_dir)) + + # Honor compilation.strategy=single-file (and the --single-agents flag) + # by collapsing all instructions into a single root CLAUDE.md, mirroring + # the gate in _compile_agents_md. Without this, single-file mode is + # silently ignored for the Claude target and per-subdirectory CLAUDE.md + # files are emitted via the distributed placement path (issue #1445). 
+ # + # DistributedAgentsCompiler is only constructed on the distributed + # branch -- single-file mode does not use its placement analysis and + # the later display block guards on `distributed_compiler is not None`. + distributed_compiler = None + if config.strategy != "distributed" or config.single_agents: + placement_map = {self.base_dir: list(primitives.instructions)} + else: + from .distributed_compiler import DistributedAgentsCompiler + + distributed_compiler = DistributedAgentsCompiler( + str(self.base_dir), + exclude_patterns=config.exclude, + source_dir=str(self.source_dir), + ) + # Analyze directory structure and determine placement + directory_map = distributed_compiler.analyze_directory_structure( + primitives.instructions + ) + placement_map = distributed_compiler.determine_agents_placement( + primitives.instructions, + directory_map, + min_instructions=config.min_instructions_per_file, + debug=config.debug, + ) + + # Skip instructions in CLAUDE.md when they are already deployed to + # .claude/rules/ by `apm install` (avoids duplicate context in Claude Code). + # --no-dedup / --force-instructions lets users opt out of this behaviour. 
+ from .agents_compiler import _detect_deployed_instructions + + if config.no_dedup: + skip_instructions = False + self._log( + "progress", + "Including instructions in CLAUDE.md (--no-dedup overrides deduplication)", + symbol="info", + ) + else: + skip_instructions = _detect_deployed_instructions( + self.base_dir / ".claude" / "rules", + self.base_dir, + lambda msg: self._log("warning", msg), + ) + if skip_instructions: + self._log( + "progress", + "Instructions already in .claude/rules/ -- omitting from CLAUDE.md" + " to avoid duplicate context", + symbol="info", + ) + + # Format CLAUDE.md files + claude_config = { + "source_attribution": config.source_attribution, + "debug": config.debug, + "skip_instructions": skip_instructions, + } + claude_result = claude_formatter.format_distributed( + primitives, placement_map, claude_config + ) + + # NOTE: Claude commands are now generated at install time via CommandIntegrator, + # not at compile time. This keeps behavior consistent with VSCode prompt integration. + + # Merge warnings and errors (no command result anymore) + all_warnings = self.warnings + claude_result.warnings + all_errors = self.errors + claude_result.errors + + # Handle dry-run mode + if config.dry_run: + # Generate preview summary + count = len(claude_result.placements) + preview_lines = [ + f"CLAUDE.md Preview: Would generate {count} {'file' if count == 1 else 'files'}" + ] + # Surface the deduplication skip so dry-run is self-explanatory + # for scripted consumers (otherwise "Would generate 0 files" + # looks like a no-op or a bug). The same skip appears in the + # non-dry-run path via the dedicated INFO log line. 
+ if skip_instructions: + preview_lines.append( + " (instructions section skipped: .claude/rules/ already " + "populated -- avoids duplicate content in Claude Code's " + "context window)" + ) + for claude_path in claude_result.content_map.keys(): # noqa: SIM118 + rel_path = portable_relpath(claude_path, self.base_dir) + preview_lines.append(f" {rel_path}") + + return CompilationResult( + success=len(all_errors) == 0, + output_path="Preview mode - CLAUDE.md", + content="\n".join(preview_lines), + warnings=all_warnings, + errors=all_errors, + stats=claude_result.stats, + ) + + # Write CLAUDE.md files + files_written = 0 + critical_security_found = False + # Rule B: _logger is patched at agents_compiler._logger in tests + from apm_cli.compilation import agents_compiler as _ac + + from ..security.gate import WARN_POLICY, SecurityGate + from .output_writer import CompiledOutputWriter + + writer = CompiledOutputWriter() + for claude_path, content in claude_result.content_map.items(): + try: + # Handle constitution injection if enabled + final_content = content + if config.with_constitution: + try: + from .injector import ConstitutionInjector + + injector = ConstitutionInjector(str(claude_path.parent)) + final_content, _, _ = injector.inject( + content, with_constitution=True, output_path=claude_path + ) + except Exception as exc: + _ac._logger.debug( + "Constitution injection failed for %s: %s", claude_path, exc + ) + + # Defense-in-depth: scan compiled output before writing + verdict = SecurityGate.scan_text( + final_content, str(claude_path), policy=WARN_POLICY + ) + actionable = verdict.critical_count + verdict.warning_count + if actionable: + if verdict.has_critical: + critical_security_found = True + all_warnings.append( + f"CLAUDE.md contains {actionable} hidden character(s) " + f"— run 'apm audit --file {claude_path}' to inspect" + ) + + writer.write(claude_path, final_content) + files_written += 1 + except OSError as e: + all_errors.append(f"Failed to write 
{claude_path}: {e!s}") + + # Update stats + stats = claude_result.stats.copy() + stats["claude_files_written"] = files_written + + if files_written == 0 and skip_instructions: + self._log( + "progress", + "CLAUDE.md not generated -- Claude Code reads .claude/rules/ directly," + " no further action needed", + symbol="info", + ) + elif distributed_compiler is None and files_written > 0 and not config.dry_run: + # Single-file strategy bypasses the distributed display formatter + # (which has no analysis to render). Emit a minimal progress line + # so users get a confirmation that single-file mode took effect. + noun = "file" if files_written == 1 else "files" + self._log( + "progress", + f"CLAUDE.md compiled ({files_written} {noun})", + symbol="success", + ) + + # Display CLAUDE.md compilation output using standard formatter + # Get proper compilation results from distributed compiler (has optimization decisions) + # Skip formatter output when deduplication filtered out all placements to + # avoid contradicting the "not generated" log message above. 
+ from ..output.formatters import CompilationFormatter + from ..output.models import CompilationResults + + compilation_results = ( + distributed_compiler.get_compilation_results_for_display(is_dry_run=config.dry_run) + if distributed_compiler is not None + else None + ) + if compilation_results and not (skip_instructions and files_written == 0): + # Update target name for CLAUDE.md output + formatter_results = CompilationResults( + project_analysis=compilation_results.project_analysis, + optimization_decisions=compilation_results.optimization_decisions, + placement_summaries=compilation_results.placement_summaries, + optimization_stats=compilation_results.optimization_stats, + warnings=all_warnings, + errors=all_errors, + is_dry_run=config.dry_run, + target_name="CLAUDE.md", + ) + + # Use the same formatter as AGENTS.md + formatter = CompilationFormatter(use_color=True) + if config.debug or config.trace: + output = formatter.format_verbose(formatter_results) + elif config.dry_run: + output = formatter.format_dry_run(formatter_results) + else: + output = formatter.format_default(formatter_results) + self._log("progress", output) + + # Generate summary content for result object + summary_lines = [ + f"# CLAUDE.md Compilation Summary", # noqa: F541 + f"", # noqa: F541 + f"Generated {files_written} CLAUDE.md files:", + ] + for placement in claude_result.placements: + rel_path = portable_relpath(placement.claude_path, self.base_dir) + summary_lines.append(f"- {rel_path} ({len(placement.instructions)} instructions)") + + return CompilationResult( + success=len(all_errors) == 0, + output_path=f"CLAUDE.md: {files_written} files", + content="\n".join(summary_lines), + warnings=all_warnings, + errors=all_errors, + stats=stats, + has_critical_security=critical_security_found, + ) + + # ------------------------------------------------------------------ # + # GEMINI.md compilation # + # ------------------------------------------------------------------ # + + def 
_compile_gemini_md( + self, + config: CompilationConfig, + primitives: PrimitiveCollection, + ) -> CompilationResult: + """Compile GEMINI.md stub that imports AGENTS.md.""" + from apm_cli.compilation.agents_compiler import CompilationResult + + from .gemini_formatter import GeminiFormatter + + gemini_formatter = GeminiFormatter(str(self.base_dir)) + gemini_result = gemini_formatter.format_distributed(primitives) + + all_warnings = self.warnings + gemini_result.warnings + all_errors = self.errors + gemini_result.errors + + if config.dry_run: + return CompilationResult( + success=len(all_errors) == 0, + output_path="Preview mode - GEMINI.md", + content="GEMINI.md Preview: Would generate stub importing AGENTS.md", + warnings=all_warnings, + errors=all_errors, + stats=gemini_result.stats, + ) + + files_written = 0 + from .output_writer import CompiledOutputWriter + + writer = CompiledOutputWriter() + for gemini_path, content in gemini_result.content_map.items(): + try: + writer.write(gemini_path, content) + files_written += 1 + except OSError as e: + all_errors.append(f"Failed to write {gemini_path}: {e!s}") + + stats = gemini_result.stats.copy() + stats["gemini_files_written"] = files_written + + self._log("progress", "Generated GEMINI.md (imports AGENTS.md)") + + return CompilationResult( + success=len(all_errors) == 0, + output_path=f"GEMINI.md: {files_written} files", + content=f"Generated {files_written} GEMINI.md stub importing AGENTS.md", + warnings=all_warnings, + errors=all_errors, + stats=stats, + ) + + # ------------------------------------------------------------------ # + # Copilot root-instructions emit / cleanup # + # ------------------------------------------------------------------ # + + def _maybe_emit_copilot_root_instructions( + self, + config: CompilationConfig, + primitives: PrimitiveCollection, + result: CompilationResult, + ) -> CompilationResult: + """Generate .github/copilot-instructions.md for Copilot-capable targets.""" + from 
..core.target_detection import should_compile_copilot_instructions_md + from .agents_compiler import _COPILOT_ROOT_GENERATED_MARKER, _VSCODE_TARGET_ALIASES + + routing_target = "vscode" if config.target in _VSCODE_TARGET_ALIASES else config.target + output_path = self.base_dir / ".github" / "copilot-instructions.md" + if not should_compile_copilot_instructions_md(routing_target): + if not config.dry_run: + self._cleanup_copilot_root_instructions(output_path, result) + result.stats.setdefault("copilot_root_instructions_generated", 0) + result.stats.setdefault("copilot_root_instructions_written", 0) + result.stats.setdefault("copilot_root_instructions_unchanged", 0) + result.stats.setdefault("copilot_root_instructions_skipped", 0) + result.stats.setdefault("copilot_root_instructions_removed", 0) + return result + + global_instructions = sorted( + [instruction for instruction in primitives.instructions if not instruction.apply_to], + key=lambda instruction: portable_relpath(instruction.file_path, self.base_dir), + ) + if not global_instructions: + if not config.dry_run: + self._cleanup_copilot_root_instructions(output_path, result) + result.stats.setdefault("copilot_root_instructions_generated", 0) + result.stats.setdefault("copilot_root_instructions_written", 0) + result.stats.setdefault("copilot_root_instructions_unchanged", 0) + result.stats.setdefault("copilot_root_instructions_skipped", 0) + result.stats.setdefault("copilot_root_instructions_removed", 0) + return result + + content = self._generate_copilot_root_instructions_content(global_instructions, config) + + result.stats["copilot_root_instructions_generated"] = 1 + result.stats.setdefault("copilot_root_instructions_skipped", 0) + result.stats.setdefault("copilot_root_instructions_removed", 0) + result.stats.setdefault("copilot_root_instructions_written", 0) + result.stats.setdefault("copilot_root_instructions_unchanged", 0) + + # Inspect any existing file BEFORE the dry-run early-exit so that + # 
`--dry-run` faithfully reports what a real run would do (skip vs + # write vs unchanged). Reading the file here is safe in dry-run mode + # because we never mutate it. + try: + existing = output_path.read_text(encoding="utf-8") if output_path.exists() else None + except OSError as exc: + message = f"Failed to read {output_path}: {exc}" + self.errors.append(message) + result.errors.append(message) + result.success = False + return result + + if existing is not None and _COPILOT_ROOT_GENERATED_MARKER not in existing: + rel_path = portable_relpath(output_path, self.base_dir) + result.warnings.append( + f"Skipped {rel_path}: hand-authored file will not be overwritten. " + "To regenerate, either delete or rename it, or prepend the line " + f"'{_COPILOT_ROOT_GENERATED_MARKER}' to the top of the file. " + "Then re-run 'apm compile'." + ) + # The file was never compared to new content; record as + # 'skipped', not 'unchanged'. Also reset 'generated' since no + # output was actually emitted (or would be, on a real run). 
+ result.stats["copilot_root_instructions_generated"] = 0 + result.stats["copilot_root_instructions_written"] = 0 + result.stats["copilot_root_instructions_skipped"] = 1 + result.stats["copilot_root_instructions_unchanged"] = 0 + return result + + if existing == content: + result.stats["copilot_root_instructions_written"] = 0 + result.stats["copilot_root_instructions_unchanged"] = 1 + return result + + if config.dry_run: + return result + + from ..security.gate import WARN_POLICY, SecurityGate + + verdict = SecurityGate.scan_text(content, str(output_path), policy=WARN_POLICY) + actionable = verdict.critical_count + verdict.warning_count + if actionable: + if verdict.has_critical: + result.has_critical_security = True + result.warnings.append( + f"copilot-instructions.md contains {actionable} hidden character(s) " + f"-- run 'apm audit --file {output_path}' to inspect" + ) + + try: + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(content, encoding="utf-8") + result.stats["copilot_root_instructions_written"] = 1 + result.stats["copilot_root_instructions_unchanged"] = 0 + return result + except OSError as exc: + message = f"Failed to write {output_path}: {exc}" + self.errors.append(message) + result.errors.append(message) + result.success = False + result.stats["copilot_root_instructions_written"] = 0 + result.stats.setdefault("copilot_root_instructions_unchanged", 0) + return result + + def _generate_copilot_root_instructions_content( + self, + instructions, + config: CompilationConfig, + ) -> str: + """Generate root Copilot instructions content from global instruction primitives.""" + from .agents_compiler import _COPILOT_ROOT_GENERATED_MARKER + + # Functional marker and Build ID are always present (injection/drift/cleanup coupling). 
+ sections = [ + _COPILOT_ROOT_GENERATED_MARKER, + BUILD_ID_PLACEHOLDER, + ] + if config.source_attribution: + sections.append(f"") + sections.append("") + + for instruction in instructions: + # instruction.file_path is a source-tree file; relativise it + # against source_dir so `apm compile --root` never leaks + # `../../` or absolute deploy-relative paths into the + # `` provenance comments (sources stay in $PWD + # while writes redirect to base_dir). + rel_path = portable_relpath(instruction.file_path, self.source_dir) + if config.source_attribution: + sections.append(f"") + sections.append(instruction.content.strip()) + if config.source_attribution: + sections.append(f"") + sections.append("") + + if config.source_attribution: + sections.append("---") + sections.append("*This file was generated by APM CLI. Do not edit manually.*") + sections.append("*To regenerate: `apm compile`*") + sections.append("") + + content = "\n".join(sections) + if config.resolve_links: + # Rule B: resolve_markdown_links is patched at agents_compiler in tests + from apm_cli.compilation import agents_compiler as _ac + + content = _ac.resolve_markdown_links(content, self.base_dir) + return self._finalize_build_id(content) + + def _finalize_build_id(self, content: str) -> str: + """Replace the build-id placeholder with a deterministic content hash.""" + lines = content.splitlines() + try: + idx = lines.index(BUILD_ID_PLACEHOLDER) + except ValueError: + return content + + hash_input_lines = [line for i, line in enumerate(lines) if i != idx] + build_id = hashlib.sha256("\n".join(hash_input_lines).encode("utf-8")).hexdigest()[:12] + lines[idx] = f"" + return "\n".join(lines) + ("\n" if content.endswith("\n") else "") + + def _cleanup_copilot_root_instructions( + self, + output_path, + result: CompilationResult, + ) -> CompilationResult: + """Remove stale generated Copilot root instructions when no longer applicable.""" + from .agents_compiler import _COPILOT_ROOT_GENERATED_MARKER + + if not 
output_path.exists(): + result.stats.setdefault("copilot_root_instructions_removed", 0) + return result + + try: + existing = output_path.read_text(encoding="utf-8") + if _COPILOT_ROOT_GENERATED_MARKER not in existing: + result.stats.setdefault("copilot_root_instructions_removed", 0) + return result + + output_path.unlink() + result.stats["copilot_root_instructions_removed"] = 1 + return result + except OSError as exc: + message = f"Failed to remove stale {output_path}: {exc}" + self.errors.append(message) + result.errors.append(message) + result.success = False + result.stats.setdefault("copilot_root_instructions_removed", 0) + return result diff --git a/src/apm_cli/compilation/_agents_output.py b/src/apm_cli/compilation/_agents_output.py new file mode 100644 index 000000000..8d24f2416 --- /dev/null +++ b/src/apm_cli/compilation/_agents_output.py @@ -0,0 +1,261 @@ +"""Mixin: output-writing and display methods for AgentsCompiler. + +Extracted from agents_compiler.AgentsCompiler to stay under the 800-line +guardrail (Strangler Stage 2 / issue #1078). + +Rule B routing +-------------- +``_logger`` is patched at ``apm_cli.compilation.agents_compiler._logger`` in +tests. Any moved method that calls ``_logger`` does so via a function-level +late import: + + from apm_cli.compilation import agents_compiler as _ac + _ac._logger.debug(...) +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .agents_compiler import CompilationConfig + +from ..primitives.models import PrimitiveCollection +from ..utils.paths import portable_relpath +from .constants import BUILD_ID_PLACEHOLDER # noqa: F401 (re-export convenience) +from .template_builder import TemplateData + + +class _AgentsOutputMixin: + """Mixin: file-write, stats, display, and summary methods.""" + + def _write_output_file(self, output_path: str, content: str) -> None: + """Write the generated content to the output file (full-file mode). 
+ + Args: + output_path (str): Path to write the output. + content (str): Content to write. + """ + from .output_writer import CompiledOutputWriter + + try: + CompiledOutputWriter().write(Path(output_path), content) + except OSError as e: + self.errors.append(f"Failed to write output file {output_path}: {e!s}") + + def _write_output_file_with_config( + self, + output_path: str, + content: str, + config: CompilationConfig, + ) -> None: + """Write generated content, honouring agents_md_mode from config. + + In ``full`` mode (default) the entire file is replaced. + In ``managed_section`` mode only the text between the configured + start/end markers is replaced; everything else is preserved. + + Args: + output_path (str): Path to write the output. + content (str): Generated content for this compilation. + config (CompilationConfig): Compilation configuration. + """ + from .managed_section import ManagedSectionError, apply_managed_section + from .output_writer import CompiledOutputWriter + + if config.agents_md_mode == "managed_section": + target = Path(output_path) + if not target.is_file(): + raise ManagedSectionError( + f"{target} does not exist yet. " + "Create it with the managed-section markers first, " + "or set agents_md.mode: full in apm.yml for initial generation." + ) + existing = target.read_text(encoding="utf-8") + try: + content = apply_managed_section( + existing, + content, + config.agents_md_start_marker, + config.agents_md_end_marker, + ) + except ManagedSectionError as exc: + raise ManagedSectionError(f"[{target}] {exc}") from exc + elif config.agents_md_mode != "full": + raise ValueError( + f"Unknown agents_md.mode {config.agents_md_mode!r}. " + "Supported values: 'full', 'managed_section'." 
+ ) + + try: + CompiledOutputWriter().write(Path(output_path), content) + except OSError as e: + self.errors.append(f"Failed to write output file {output_path}: {e!s}") + + def _compile_stats( + self, primitives: PrimitiveCollection, template_data: TemplateData + ) -> dict[str, Any]: + """Compile statistics about the compilation. + + Args: + primitives (PrimitiveCollection): Discovered primitives. + template_data (TemplateData): Generated template data. + + Returns: + Dict[str, Any]: Compilation statistics. + """ + return { + "primitives_found": primitives.count(), + "chatmodes": len(primitives.chatmodes), + "instructions": len(primitives.instructions), + "contexts": len(primitives.contexts), + "content_length": len(template_data.instructions_content), + # timestamp removed + "version": template_data.version, + } + + def _write_distributed_file( + self, + agents_path: Path, + content: str, + config: CompilationConfig, + ) -> None: + """Write a distributed AGENTS.md file with constitution injection support. + + Args: + agents_path (Path): Path to write the AGENTS.md file. + content (str): Content to write. + config (CompilationConfig): Compilation configuration. 
+ """ + try: + # Handle constitution injection for distributed files + final_content = content + + if config.with_constitution: + # Try to inject constitution if available + try: + from .injector import ConstitutionInjector + + injector = ConstitutionInjector(str(agents_path.parent)) + final_content, c_status, c_hash = injector.inject( # noqa: RUF059 + content, with_constitution=True, output_path=agents_path + ) + except Exception as exc: + # Rule B: _logger is patched at agents_compiler._logger in tests + from apm_cli.compilation import agents_compiler as _ac + + _ac._logger.debug("Constitution injection failed for %s: %s", agents_path, exc) + + from .output_writer import CompiledOutputWriter + + CompiledOutputWriter().write(agents_path, final_content) + + except OSError as e: + raise OSError(f"Failed to write distributed AGENTS.md file {agents_path}: {e!s}") # noqa: B904 + + def _display_placement_preview(self, distributed_result) -> None: + """Display placement preview for --show-placement mode. + + Args: + distributed_result: Result from distributed compilation. + """ + self._log("progress", "Distributed AGENTS.md Placement Preview:") + self._log("progress", "") + + for placement in distributed_result.placements: + rel_path = portable_relpath(placement.agents_path, self.base_dir) + self._log("verbose_detail", f"{rel_path}") + self._log("verbose_detail", f" Instructions: {len(placement.instructions)}") + self._log( + "verbose_detail", f" Patterns: {', '.join(sorted(placement.coverage_patterns))}" + ) + if placement.source_attribution: + sources = set(placement.source_attribution.values()) + self._log("verbose_detail", f" Sources: {', '.join(sorted(sources))}") + self._log("verbose_detail", "") + + def _display_trace_info(self, distributed_result, primitives: PrimitiveCollection) -> None: + """Display detailed trace information for --trace mode. + + Args: + distributed_result: Result from distributed compilation. 
+ primitives (PrimitiveCollection): Full primitive collection. + """ + self._log("progress", "Distributed Compilation Trace:") + self._log("progress", "") + + for placement in distributed_result.placements: + rel_path = portable_relpath(placement.agents_path, self.base_dir) + self._log("verbose_detail", f"{rel_path}") + + for instruction in placement.instructions: + source = getattr(instruction, "source", "local") + # instruction.file_path is a source-tree file; relativise + # against source_dir so `apm compile --root` produces + # human-readable paths in verbose output. + inst_path = portable_relpath(instruction.file_path, self.source_dir) + + self._log( + "verbose_detail", + f" * {instruction.apply_to or 'no pattern'} <- {source} {inst_path}", + ) + self._log("verbose_detail", "") + + def _generate_placement_summary(self, distributed_result) -> str: + """Generate a text summary of placement results. + + Args: + distributed_result: Result from distributed compilation. + + Returns: + str: Text summary of placements. + """ + lines = ["Distributed AGENTS.md Placement Summary:", ""] + + for placement in distributed_result.placements: + rel_path = portable_relpath(placement.agents_path, self.base_dir) + lines.append(f"{rel_path}") + lines.append(f" Instructions: {len(placement.instructions)}") + lines.append(f" Patterns: {', '.join(sorted(placement.coverage_patterns))}") + lines.append("") + + lines.append(f"Total AGENTS.md files: {len(distributed_result.placements)}") + return "\n".join(lines) + + def _generate_distributed_summary( + self, + distributed_result, + config: CompilationConfig, + ) -> str: + """Generate a summary of distributed compilation results. + + Args: + distributed_result: Result from distributed compilation. + config (CompilationConfig): Compilation configuration. + + Returns: + str: Summary content. 
+ """ + lines = [ + "# Distributed AGENTS.md Compilation Summary", + "", + f"Generated {len(distributed_result.placements)} AGENTS.md files:", + "", + ] + + for placement in distributed_result.placements: + rel_path = portable_relpath(placement.agents_path, self.base_dir) + lines.append(f"- {rel_path} ({len(placement.instructions)} instructions)") + + lines.extend( + [ + "", + f"Total instructions: {distributed_result.stats.get('total_instructions_placed', 0)}", + f"Total patterns: {distributed_result.stats.get('total_patterns_covered', 0)}", + "", + "Use 'apm compile --single-agents' for traditional single-file compilation.", + ] + ) + + return "\n".join(lines) diff --git a/src/apm_cli/compilation/_distributed_orphans.py b/src/apm_cli/compilation/_distributed_orphans.py new file mode 100644 index 000000000..c8a5c4075 --- /dev/null +++ b/src/apm_cli/compilation/_distributed_orphans.py @@ -0,0 +1,164 @@ +"""Mixin: orphan-file handling methods for DistributedAgentsCompiler. + +Extracted from distributed_compiler.DistributedAgentsCompiler to stay under +the 800-line guardrail (Strangler Stage 2 / issue #1078). + +No Rule B routing is required: none of the methods here reference the +module-level names patched by tests (build_attributed_instructions, +UnifiedLinkResolver, ContextOptimizer, CompilationFormatter, CompilationResults). +""" + +from __future__ import annotations + +import builtins + +from ..utils.paths import portable_relpath + + +class _DistributedOrphansMixin: + """Mixin: orphan AGENTS.md detection/cleanup and coverage validation.""" + + def _find_orphaned_agents_files(self, generated_paths: builtins.list) -> builtins.list: + """Find existing AGENTS.md files that weren't generated in the current compilation. + + Args: + generated_paths (List[Path]): List of AGENTS.md files generated in current run. + + Returns: + List[Path]: List of orphaned AGENTS.md files that should be cleaned up. 
+ """ + orphaned_files = [] + generated_set = builtins.set(generated_paths) + + # Find all existing AGENTS.md files in the project + for agents_file in self.base_dir.rglob("AGENTS.md"): + # Skip files that are outside our project or in special directories + try: + relative_path = agents_file.resolve().relative_to(self.base_dir.resolve()) + + # Skip files in certain directories that shouldn't be cleaned + skip_dirs = { + ".git", + ".apm", + "node_modules", + "__pycache__", + ".pytest_cache", + "apm_modules", + } + if any(part in skip_dirs for part in relative_path.parts): + continue + + # If this existing file wasn't generated in current run, it's orphaned + if agents_file not in generated_set: + orphaned_files.append(agents_file) + + except ValueError: + # File is outside base_dir, skip it + continue + + return orphaned_files + + def _generate_orphan_warnings(self, orphaned_files: builtins.list) -> builtins.list: + """Generate warning messages for orphaned AGENTS.md files. + + Args: + orphaned_files (List[Path]): List of orphaned files to warn about. + + Returns: + List[str]: List of warning messages. 
+ """ + warning_messages = [] + + if not orphaned_files: + return warning_messages + + # Professional warning format with readable list for multiple files + if len(orphaned_files) == 1: + rel_path = portable_relpath(orphaned_files[0], self.base_dir) + warning_messages.append( + f"Orphaned AGENTS.md found: {rel_path} - run 'apm compile --clean' to remove" + ) + else: + # For multiple files, create a single multi-line warning message + file_list = [] + for file_path in orphaned_files[:5]: # Show first 5 + rel_path = portable_relpath(file_path, self.base_dir) + file_list.append(f" * {rel_path}") + if len(orphaned_files) > 5: + file_list.append(f" * ...and {len(orphaned_files) - 5} more") + + # Create one cohesive warning message + files_text = "\n".join(file_list) + warning_messages.append( + f"Found {len(orphaned_files)} orphaned AGENTS.md files:\n{files_text}\n Run 'apm compile --clean' to remove orphaned files" + ) + + return warning_messages + + def _cleanup_orphaned_files( + self, orphaned_files: builtins.list, dry_run: bool = False + ) -> builtins.list: + """Actually remove orphaned AGENTS.md files. + + Args: + orphaned_files (List[Path]): List of orphaned files to remove. + dry_run (bool): If True, don't actually remove files, just report what would be removed. + + Returns: + List[str]: List of cleanup status messages. 
+ """ + cleanup_messages = [] + + if not orphaned_files: + return cleanup_messages + + if dry_run: + # In dry-run mode, just report what would be cleaned + cleanup_messages.append( + f"Would clean up {len(orphaned_files)} orphaned AGENTS.md files" + ) + for file_path in orphaned_files: + rel_path = portable_relpath(file_path, self.base_dir) + cleanup_messages.append(f" * {rel_path}") + else: + # Actually perform the cleanup + cleanup_messages.append(f"Cleaning up {len(orphaned_files)} orphaned AGENTS.md files") + for file_path in orphaned_files: + try: + rel_path = portable_relpath(file_path, self.base_dir) + file_path.unlink() + cleanup_messages.append(f" + Removed {rel_path}") + except Exception as e: + cleanup_messages.append(f" x Failed to remove {rel_path}: {e!s}") + + return cleanup_messages + + def _validate_coverage( + self, + placements: builtins.list, + all_instructions: builtins.list, + ) -> builtins.list: + """Validate that all instructions are covered by placements. + + Args: + placements (List[PlacementResult]): Generated placements. + all_instructions (List[Instruction]): All available instructions. + + Returns: + List[str]: List of coverage warnings. + """ + warnings = [] + placed_instructions = builtins.set() + + for placement in placements: + placed_instructions.update(str(inst.file_path) for inst in placement.instructions) + + all_instruction_paths = builtins.set(str(inst.file_path) for inst in all_instructions) + + missing_instructions = all_instruction_paths - placed_instructions + if missing_instructions: + warnings.append( + f"Instructions not placed in any AGENTS.md: {', '.join(missing_instructions)}" + ) + + return warnings diff --git a/src/apm_cli/compilation/_pattern_matcher.py b/src/apm_cli/compilation/_pattern_matcher.py new file mode 100644 index 000000000..8d283d5c6 --- /dev/null +++ b/src/apm_cli/compilation/_pattern_matcher.py @@ -0,0 +1,357 @@ +"""Mixin: pattern-matching and helper methods for ContextOptimizer. 
class _PatternMatcherMixin:
    """Mixin: pattern-matching, inheritance-chain, and distribution-score helpers.

    The methods read state that the host class is expected to provide:
    ``base_dir``, ``_directory_cache``, ``_pattern_cache``, ``_glob_set_cache``,
    ``_inheritance_cache``, ``_cached_glob`` and ``DIVERSITY_FACTOR_BASE``.
    """

    def _extract_intended_directory_from_pattern(self, pattern: str):
        """Extract the intended directory from a pattern like 'docs/**/*.md' -> 'docs'.

        Args:
            pattern (str): File pattern (may be a comma-separated list).

        Returns:
            Optional[Path]: ``base_dir / <first component>`` when that component
            contains no wildcard and exists as a directory; otherwise None
            (global pattern, empty pattern, or no such directory).
        """
        # For comma-lists, only the first segment is consulted - the
        # placement still flows into a single directory.
        if has_top_level_comma(pattern):
            segments = parse_apply_to(pattern)
            if not segments:
                return None
            pattern = segments[0]

        if not pattern or pattern.startswith("**/"):
            return None  # Global pattern

        if "/" in pattern:
            # Extract the first directory component
            parts = pattern.split("/")
            first_part = parts[0]

            # Skip if it's a wildcard (or empty, i.e. the pattern starts with "/")
            if "*" not in first_part and first_part:
                intended_dir = self.base_dir / first_part
                if intended_dir.exists() and intended_dir.is_dir():
                    return intended_dir

        return None

    def _expand_glob_pattern(self, pattern: str) -> builtins.list:
        """Expand glob pattern with brace expansion, supporting multiple brace groups.

        Only the first ``{...}`` group is expanded per call; the remaining
        groups are handled by recursing on each alternative.

        Args:
            pattern (str): Pattern like '**/*.{css,scss}' or '**/*.{test,spec}.{ts,js}'

        Returns:
            List[str]: Expanded patterns like ['**/*.css', '**/*.scss']
                      or ['**/*.test.ts', '**/*.test.js', '**/*.spec.ts', '**/*.spec.js'].
                      A pattern without braces is returned as a one-element list.
        """
        import re

        # Handle brace expansion like {css,scss}
        brace_match = re.search(r"\{([^}]+)\}", pattern)
        if brace_match:
            alternatives = brace_match.group(1).split(",")
            prefix = pattern[: brace_match.start()]
            suffix = pattern[brace_match.end() :]
            # Recursively expand remaining brace groups in each result
            expanded = []
            for alt in alternatives:
                expanded.extend(self._expand_glob_pattern(prefix + alt + suffix))
            return expanded

        return [pattern]

    def _file_matches_pattern(self, file_path, pattern: str) -> bool:
        """Check if a file matches a given pattern with optimized performance.

        Args:
            file_path (Path): File path to check
            pattern (str): Glob pattern or comma-separated list of globs.

        Returns:
            bool: True if file matches pattern (or any segment of a list).
        """
        # applyTo accepts a comma-separated list of globs; treat any
        # segment match as a hit so list patterns mirror per-glob semantics.
        # Only split on top-level commas - commas inside brace alternation
        # (e.g. ``**/*.{css,scss}``) must stay attached for brace expansion.
        if has_top_level_comma(pattern):
            segments = parse_apply_to(pattern)
            return any(self._file_matches_pattern(file_path, seg) for seg in segments)

        # Expand any brace patterns
        expanded_patterns = self._expand_glob_pattern(pattern)

        for expanded_pattern in expanded_patterns:
            # For patterns with **, use cached glob results
            if "**" in expanded_pattern:
                try:
                    # Resolve both paths to handle symlinks and path inconsistencies
                    resolved_file = file_path.resolve()
                    rel_path = resolved_file.relative_to(self.base_dir.resolve())

                    # Use cached glob results instead of repeated glob calls
                    matches = self._cached_glob(expanded_pattern)
                    # Use cached Set[Path] to avoid recreating on every call
                    if expanded_pattern not in self._glob_set_cache:
                        # Rule B: Path is patched at context_optimizer.Path in tests
                        from apm_cli.compilation import context_optimizer as _co

                        self._glob_set_cache[expanded_pattern] = {
                            _co.Path(match) for match in matches
                        }
                    if rel_path in self._glob_set_cache[expanded_pattern]:
                        return True
                except (ValueError, OSError):
                    # File outside base_dir or unresolvable: treat as "no match"
                    # for this expansion and try the next one.
                    pass
            else:
                # For non-recursive patterns, use fnmatch as before
                try:
                    rel_str = portable_relpath(file_path, self.base_dir)
                    if fnmatch.fnmatch(rel_str, expanded_pattern):
                        return True
                except ValueError:
                    pass

                # Only use filename match for patterns without directory structure
                # This prevents "docs/**/*.md" from matching any "*.md" file anywhere
                if "/" not in expanded_pattern:
                    if fnmatch.fnmatch(file_path.name, expanded_pattern):
                        return True

        return False

    def _find_matching_directories(self, pattern: str) -> builtins.set:
        """Find directories that contain files matching the pattern.

        Only directories present in ``_directory_cache`` are scanned, and only
        their direct, non-hidden files are tested. As a side effect the
        per-directory match count is stored in ``analysis.pattern_matches`` and
        the result set is memoised in ``_pattern_cache``.

        Args:
            pattern (str): File pattern to match.

        Returns:
            Set[Path]: Set of directories with matching files.
        """
        # Use cached result if available
        if pattern in self._pattern_cache:
            return self._pattern_cache[pattern]

        matching_dirs: builtins.set = builtins.set()

        # Use the reliable approach for all patterns
        for directory, analysis in sorted(self._directory_cache.items()):
            try:
                files = [
                    f for f in directory.iterdir() if f.is_file() and not f.name.startswith(".")
                ]

                match_count = 0
                for file_path in files:
                    if self._file_matches_pattern(file_path, pattern):
                        match_count += 1
                        matching_dirs.add(directory)

                if match_count > 0:
                    analysis.pattern_matches[pattern] = match_count
            except OSError:
                # Unreadable directory (PermissionError is an OSError): skip it.
                continue

        self._pattern_cache[pattern] = matching_dirs
        return matching_dirs

    def _calculate_inheritance_pollution(self, directory, pattern: str) -> float:
        """Calculate inheritance pollution score for placing instruction at directory.

        Each direct child directory known to ``_directory_cache`` adds 0.5 when
        its relevance for the pattern is exactly 0.0, or 0.2 when it is below 0.1.

        Args:
            directory (Path): Candidate placement directory.
            pattern (str): Instruction pattern.

        Returns:
            float: Pollution score (higher = more pollution); 0.0 when the
            directory cannot be listed.
        """
        pollution_score = 0.0

        # Optimization: Only check direct children instead of all directories
        # This prevents O(n^2) complexity with unlimited depth analysis
        try:
            direct_children = [
                child
                for child in directory.iterdir()
                if child.is_dir() and child in self._directory_cache
            ]

            # Check only direct child directories for pollution
            for child_dir in direct_children:
                analysis = self._directory_cache[child_dir]

                # If child has no matching files, this creates pollution
                child_relevance = analysis.get_relevance_score(pattern)
                if child_relevance == 0.0:
                    pollution_score += 0.5  # Strong pollution penalty
                elif child_relevance < 0.1:  # Weak relevance threshold
                    pollution_score += 0.2  # Weak pollution penalty
        except OSError:
            # Skip directories we can't read (PermissionError is an OSError)
            pass

        return pollution_score

    def _calculate_distribution_score(self, matching_directories: builtins.set) -> float:
        """Calculate distribution score with diversity factor.

        The score is ``(matching / directories-with-files)`` scaled by
        ``1 + variance(depths) * DIVERSITY_FACTOR_BASE``, where the variance is
        the population variance of the matching directories' depths.

        Args:
            matching_directories: Set of directories with pattern matches; every
                member must be a key of ``_directory_cache``.

        Returns:
            float: Distribution score accounting for spread and depth diversity;
            0.0 when no cached directory contains files.
        """
        total_dirs_with_files = sum(
            1 for analysis in self._directory_cache.values() if analysis.total_files > 0
        )
        if total_dirs_with_files == 0:
            return 0.0

        base_ratio = len(matching_directories) / total_dirs_with_files

        # Calculate diversity factor based on depth distribution
        depths = [self._directory_cache[d].depth for d in matching_directories]
        if not depths:
            return base_ratio

        # Compute the mean once; recomputing it per element made this quadratic.
        mean_depth = sum(depths) / len(depths)
        depth_variance = sum((d - mean_depth) ** 2 for d in depths) / len(depths)
        diversity_factor = 1.0 + (depth_variance * self.DIVERSITY_FACTOR_BASE)

        return base_ratio * diversity_factor

    def _get_inheritance_chain(self, working_directory) -> builtins.list:
        """Get inheritance chain from working directory to root.

        The walk stops at ``base_dir`` (inclusive) or at the filesystem root,
        whichever is reached first. Results are memoised in
        ``_inheritance_cache`` keyed by the unresolved ``working_directory``.

        Args:
            working_directory (Path): Starting directory.

        Returns:
            List[Path]: Inheritance chain (most specific to root).
        """
        cached = self._inheritance_cache.get(working_directory)
        if cached is not None:
            return cached

        chain = []
        # Resolve the starting directory to ensure consistent path comparison
        try:
            current = working_directory.resolve()
        except (OSError, ValueError):
            current = working_directory.absolute()

        seen_paths = builtins.set()  # Track visited paths to prevent infinite loops

        # Build chain from working directory up to (and including) base_dir
        while current not in seen_paths:
            seen_paths.add(current)
            chain.append(current)

            # Stop at base_dir
            if current == self.base_dir:
                break

            # Stop if we can't go higher or hit filesystem root
            try:
                parent = current.parent
                if parent == current:  # We've hit filesystem root
                    break
                current = parent
            except (OSError, ValueError):
                break

        self._inheritance_cache[working_directory] = chain
        return chain

    def _is_child_directory(self, child, parent) -> bool:
        """Check if child is a subdirectory of parent.

        Both paths are resolved before comparison; a path is not considered a
        child of itself.

        Args:
            child (Path): Potential child directory.
            parent (Path): Potential parent directory.

        Returns:
            bool: True if child is a strict subdirectory of parent.
        """
        try:
            child.resolve().relative_to(parent.resolve())
            return child.resolve() != parent.resolve()
        except ValueError:
            # relative_to() raises ValueError when child is not under parent
            return False

    def _is_instruction_relevant(self, instruction, working_directory) -> bool:
        """Check if instruction is relevant for the working directory.

        An instruction without ``apply_to`` is always relevant. Otherwise the
        directory must be present in ``_directory_cache`` and contain at least
        one direct, non-hidden file matching the pattern. The computed count is
        stored in ``analysis.pattern_matches`` for reuse.

        Args:
            instruction (Instruction): Instruction to check.
            working_directory (Path): Directory where agent is working.

        Returns:
            bool: True if instruction is relevant.
        """
        if not instruction.apply_to:
            return True  # Global instructions are always relevant

        pattern = instruction.apply_to

        # Resolve working directory to handle path inconsistencies
        try:
            resolved_working_dir = working_directory.resolve()
        except (OSError, ValueError):
            resolved_working_dir = working_directory.absolute()

        # Check if working directory has files matching the pattern
        analysis = self._directory_cache.get(resolved_working_dir)
        if not analysis:
            return False

        # If pattern already analyzed, use cached result
        if pattern in analysis.pattern_matches:
            return analysis.pattern_matches[pattern] > 0

        # Otherwise, analyze this specific directory for the pattern
        # Only check direct files in this directory (not subdirectories for simplicity)
        matching_files = 0

        try:
            for file in os.listdir(resolved_working_dir):
                if file.startswith("."):
                    continue

                file_path = resolved_working_dir / file
                if file_path.is_file():
                    if self._file_matches_pattern(file_path, pattern):
                        matching_files += 1
        except OSError:
            # Handle case where directory doesn't exist or can't be read
            pass

        # Cache the result
        analysis.pattern_matches[pattern] = matching_files

        return matching_files > 0
class _PlacementSolverMixin:
    """Mixin: mathematical placement-optimisation solver for ContextOptimizer.

    The methods read state and helpers that the host class is expected to
    provide, including ``base_dir``, ``_directory_cache``, ``_warnings``,
    ``_optimization_decisions``, the threshold/weight class constants, and the
    pattern-matching helpers (``_find_matching_directories`` etc.).
    """

    def _find_optimal_placements(
        self, instruction: Instruction, verbose: bool = False
    ) -> builtins.list:
        """Find optimal placement(s) for an instruction using mathematical optimization.

        Thin entry point that delegates to ``_solve_placement_optimization``.

        Args:
            instruction (Instruction): Instruction to place.
            verbose (bool): Collect verbose analysis data.

        Returns:
            List[Path]: List of optimal directory placements.
        """
        return self._solve_placement_optimization(instruction, verbose)

    def _solve_placement_optimization(
        self, instruction: Instruction, verbose: bool = False
    ) -> builtins.list:
        """Mathematical optimization solver for instruction placement.

        Implements the mathematician's objective function:
        minimize: sum(context_pollution x directory_weight)
        subject to: for_all instruction -> exists placement

        When no directory matches the pattern, the instruction is placed in
        the pattern's intended directory (if one exists) or at ``base_dir``,
        and a warning is appended to ``_warnings``. Otherwise the distribution
        score selects one of three strategies via ``LOW_DISTRIBUTION_THRESHOLD``
        and ``HIGH_DISTRIBUTION_THRESHOLD``. In every case an
        ``OptimizationDecision`` is appended to ``_optimization_decisions``.

        Args:
            instruction (Instruction): Instruction to optimize placement for.
            verbose (bool): Collect verbose analysis data (forwarded to the
                strategy helpers).

        Returns:
            List[Path]: Mathematically optimal placement(s).
        """
        pattern = instruction.apply_to

        # Find all directories with matching files
        matching_directories = self._find_matching_directories(pattern)

        if not matching_directories:
            # Smart fallback: Try to place in semantically appropriate directory
            intended_dir = self._extract_intended_directory_from_pattern(pattern)
            name = getattr(instruction, "name", None) or instruction.file_path.stem

            if intended_dir:
                # Place in the intended directory (e.g., docs/ for docs/**/*.md)
                placement = intended_dir
                # Computed once; it is used in both the reasoning and the warning.
                intended_rel = portable_relpath(intended_dir, self.base_dir)
                reasoning = (
                    f"No matching files found, placed in intended directory '{intended_rel}'"
                )
                self._warnings.append(
                    f"applyTo for '{name}' matched no files - placing in '{intended_rel}'"
                )
            else:
                # Fallback to root for global patterns
                placement = self.base_dir
                reasoning = "No matching files found, fallback to root placement"
                self._warnings.append(
                    f"applyTo for '{name}' matched no files - placing at project root"
                )

            # Calculate relevance score for the fallback placement
            relevance_score = 0.0  # No matches means no relevance
            if placement in self._directory_cache:
                relevance_score = self._calculate_coverage_efficiency(placement, pattern)

            decision = OptimizationDecision(
                instruction=instruction,
                pattern=pattern,
                matching_directories=0,
                total_directories=len(self._directory_cache),
                distribution_score=0.0,
                strategy=PlacementStrategy.DISTRIBUTED,
                placement_directories=[placement],
                reasoning=reasoning,
                relevance_score=relevance_score,
            )
            self._optimization_decisions.append(decision)

            return [placement]

        # Calculate distribution score with diversity factor
        distribution_score = self._calculate_distribution_score(matching_directories)

        # Apply three-tier placement strategy based on mathematical analysis
        if distribution_score < self.LOW_DISTRIBUTION_THRESHOLD:
            # Low distribution: Single Point Placement
            strategy = PlacementStrategy.SINGLE_POINT
            placements = self._optimize_single_point_placement(
                matching_directories, instruction, verbose
            )
            reasoning = "Low distribution pattern optimized for minimal pollution"
        elif distribution_score > self.HIGH_DISTRIBUTION_THRESHOLD:
            # High distribution: Distributed Placement
            strategy = PlacementStrategy.DISTRIBUTED
            placements = self._optimize_distributed_placement(
                matching_directories, instruction, verbose
            )
            reasoning = "High distribution pattern placed at root to minimize duplication"
        else:
            # Medium distribution: Selective Multi-Placement
            strategy = PlacementStrategy.SELECTIVE_MULTI
            placements = self._optimize_selective_placement(
                matching_directories, instruction, verbose
            )
            reasoning = "Medium distribution pattern with selective high-relevance placement"

        # Calculate relevance score for the primary placement directory
        relevance_score = 0.0
        if placements:
            primary_placement = placements[0]  # Use first placement as representative
            if primary_placement in self._directory_cache:
                relevance_score = self._calculate_coverage_efficiency(primary_placement, pattern)

        # Record optimization decision
        decision = OptimizationDecision(
            instruction=instruction,
            pattern=pattern,
            matching_directories=len(matching_directories),
            total_directories=len(self._directory_cache),
            distribution_score=distribution_score,
            strategy=strategy,
            placement_directories=placements,
            reasoning=reasoning,
            relevance_score=relevance_score,
        )
        self._optimization_decisions.append(decision)

        return placements

    def _optimize_single_point_placement(
        self,
        matching_directories: builtins.set,
        instruction: Instruction,
        verbose: bool = False,
    ) -> builtins.list:
        """Optimize placement for low distribution patterns.

        Strategy: Ensure mandatory coverage constraint first, then optimize for minimal pollution.
        Coverage guarantee takes priority over efficiency optimization.

        Args:
            matching_directories: Directories containing files that match the pattern.
            instruction (Instruction): Instruction being placed.
            verbose (bool): Unused here; kept for a uniform strategy signature.

        Returns:
            List[Path]: A single-element list with the chosen directory
            (``base_dir`` when no candidate can be generated).
        """
        candidates = self._generate_all_candidates(matching_directories, instruction)

        if not candidates:
            return [self.base_dir]

        # CRITICAL: Mandatory coverage constraint - keep only candidates whose
        # directory hierarchically covers ALL matching directories.
        coverage_candidates = []
        for candidate in candidates:
            covered_directories = self._calculate_hierarchical_coverage(
                [candidate.directory], matching_directories
            )
            if covered_directories == matching_directories:
                # This candidate satisfies the mandatory coverage constraint
                coverage_candidates.append(candidate)

        # If no single candidate provides complete coverage, find minimal coverage placement
        if not coverage_candidates:
            minimal_coverage = self._find_minimal_coverage_placement(matching_directories)
            if minimal_coverage:
                return [minimal_coverage]
            else:
                # Ultimate fallback to root to guarantee coverage
                return [self.base_dir]

        # Among coverage-compliant candidates, select the one with best efficiency/pollution ratio
        best_candidate = max(
            coverage_candidates, key=lambda c: c.coverage_efficiency - c.pollution_score
        )

        return [best_candidate.directory]

    def _optimize_distributed_placement(
        self,
        matching_directories: builtins.set,
        instruction: Instruction,
        verbose: bool = False,
    ) -> builtins.list:
        """Optimize placement for high distribution patterns.

        Strategy: Place at root to minimize duplication while maintaining accessibility.
        All arguments are ignored; they exist for a uniform strategy signature.

        Returns:
            List[Path]: ``[base_dir]``.
        """
        return [self.base_dir]

    def _optimize_selective_placement(
        self,
        matching_directories: builtins.set,
        instruction: Instruction,
        verbose: bool = False,
    ) -> builtins.list:
        """Optimize placement for medium distribution patterns.

        Strategy: Ensure hierarchical coverage - all matching files must be able
        to inherit the instruction through the hierarchical AGENTS.md system.

        Args:
            matching_directories: Directories containing files that match the pattern.
            instruction (Instruction): Instruction being placed.
            verbose (bool): Unused here; kept for a uniform strategy signature.

        Returns:
            List[Path]: Placement directories that together cover every
            matching directory.
        """
        # First check if we can achieve complete coverage with a single high-level placement
        coverage_placement = self._find_minimal_coverage_placement(matching_directories)
        if coverage_placement:
            return [coverage_placement]

        # If single placement doesn't work, use multi-placement strategy
        candidates = self._generate_all_candidates(matching_directories, instruction)

        if not candidates:
            return [self.base_dir]

        # Threshold is the larger of 0.8 and the efficiency found at the
        # top-20% cut-off, so a candidate has to clear both bars.
        high_relevance_threshold = max(
            0.8,
            sorted([c.coverage_efficiency for c in candidates], reverse=True)[
                max(0, len(candidates) // 5)
            ],
        )

        high_relevance_candidates = [
            c for c in candidates if c.coverage_efficiency >= high_relevance_threshold
        ]

        if not high_relevance_candidates:
            # Fallback: use best candidate
            high_relevance_candidates = [max(candidates, key=lambda c: c.total_score)]

        optimal_placements = [c.directory for c in high_relevance_candidates]

        # CRITICAL: Verify hierarchical coverage
        covered_directories = self._calculate_hierarchical_coverage(
            optimal_placements, matching_directories
        )
        uncovered_directories = matching_directories - covered_directories

        if uncovered_directories:
            # Coverage violation! Find minimal placement that covers everything
            minimal_coverage = self._find_minimal_coverage_placement(matching_directories)
            if minimal_coverage:
                return [minimal_coverage]
            else:
                # Fallback to root to ensure no coverage gaps
                return [self.base_dir]

        return optimal_placements

    def _generate_all_candidates(
        self, matching_directories: builtins.set, instruction: Instruction
    ) -> builtins.list:
        """Generate all placement candidates with optimization scores.

        This includes both matching directories AND their common ancestors to ensure
        the mandatory coverage constraint can be satisfied. Directories that are
        not present in ``_directory_cache`` are skipped.

        Args:
            matching_directories: Directories containing files that match the pattern.
            instruction (Instruction): Instruction being placed.

        Returns:
            List[PlacementCandidate]: One scored candidate per eligible
            directory, in sorted directory order.
        """
        candidates = []
        pattern = instruction.apply_to

        # Collect all potential placement directories:
        # 1. The matching directories themselves
        # 2. Their common ancestors (for coverage guarantee)
        potential_directories = builtins.set(matching_directories)

        # Add common ancestor directories to ensure coverage options exist
        if len(matching_directories) > 1:
            # Find common ancestors that could provide coverage
            common_ancestor = self._find_minimal_coverage_placement(matching_directories)
            if common_ancestor:
                potential_directories.add(common_ancestor)

            # Also add any intermediate directories in the inheritance chains
            for directory in matching_directories:
                chain = self._get_inheritance_chain(directory)
                # Add intermediate directories that could provide coverage
                for intermediate in chain:
                    if intermediate != directory and intermediate in self._directory_cache:
                        potential_directories.add(intermediate)

        # PlacementCandidate lives in context_optimizer; import lazily to avoid
        # a cycle. Done once here rather than on every loop iteration.
        from .context_optimizer import PlacementCandidate

        # Generate candidates for all potential directories
        for directory in sorted(potential_directories):
            if directory not in self._directory_cache:
                continue

            analysis = self._directory_cache[directory]

            # Calculate the three optimization objectives
            coverage_efficiency = self._calculate_coverage_efficiency(directory, pattern)
            pollution_score = self._calculate_pollution_minimization(directory, pattern)
            maintenance_locality = self._calculate_maintenance_locality(directory, pattern)

            # Apply depth penalty for excessive nesting (only beyond depth 3)
            depth_penalty = max(0, (analysis.depth - 3) * self.DEPTH_PENALTY_FACTOR)

            # Calculate total objective function score
            total_score = (
                coverage_efficiency * self.COVERAGE_EFFICIENCY_WEIGHT
                + (1.0 - pollution_score) * self.POLLUTION_MINIMIZATION_WEIGHT
                + maintenance_locality * self.MAINTENANCE_LOCALITY_WEIGHT
                - depth_penalty
            )

            candidate = PlacementCandidate(
                instruction=instruction,
                directory=directory,
                direct_relevance=coverage_efficiency,  # Legacy field
                inheritance_pollution=pollution_score,  # Legacy field
                depth_specificity=analysis.depth * 0.1,  # Legacy field
                total_score=0.0,  # Temporary value, will be overwritten
            )

            # Add new optimization fields
            candidate.coverage_efficiency = coverage_efficiency
            candidate.pollution_score = pollution_score
            candidate.maintenance_locality = maintenance_locality

            # Set the mathematical optimization score (after __post_init__ has run)
            candidate.total_score = total_score

            candidates.append(candidate)

        return candidates

    def _find_minimal_coverage_placement(self, matching_directories: builtins.set):
        """Find the deepest common ancestor that covers all matching directories.

        Paths are compared relative to ``base_dir`` (both sides resolved).

        Args:
            matching_directories: Directories that contain files matching the pattern

        Returns:
            None for an empty input; the directory itself for a single
            directory; otherwise ``base_dir`` joined with the shared leading
            path components (``base_dir`` itself when nothing is shared).
        """
        if not matching_directories:
            return None

        # Convert to relative paths for easier analysis
        relative_dirs = [
            d.resolve().relative_to(self.base_dir.resolve()) for d in matching_directories
        ]

        # Find the lowest common ancestor that covers all directories
        if len(relative_dirs) == 1:
            # Single directory - we can place instruction in that directory or any parent
            return next(iter(matching_directories))

        # Find common path prefix for all directories
        common_parts = []
        min_depth = min(len(d.parts) for d in relative_dirs)

        for i in range(min_depth):
            parts_at_level = [d.parts[i] for d in relative_dirs]
            if len(builtins.set(parts_at_level)) == 1:
                # All directories share this path component
                common_parts.append(parts_at_level[0])
            else:
                break

        if common_parts:
            # Found common ancestor.
            # Rule B: Path is patched at context_optimizer.Path in tests.
            from apm_cli.compilation import context_optimizer as _co

            common_ancestor = self.base_dir / _co.Path(*common_parts)
            return common_ancestor
        else:
            # No common ancestor beyond root - place at root
            return self.base_dir

    def _calculate_hierarchical_coverage(
        self, placements: builtins.list, target_directories: builtins.set
    ) -> builtins.set:
        """Calculate which target directories are covered by the given placements.

        A target is covered when at least one placement is the target itself
        or one of its ancestors.

        Args:
            placements: List of directories where AGENTS.md files will be placed
            target_directories: Directories that need to be covered

        Returns:
            Set of target directories that are covered by the placements
        """
        covered = builtins.set()

        for target in target_directories:
            for placement in placements:
                if self._is_hierarchically_covered(target, placement):
                    covered.add(target)
                    break

        return covered

    def _is_hierarchically_covered(self, target_dir, placement_dir) -> bool:
        """Check if target_dir can inherit instructions from placement_dir through hierarchy.

        This is true if placement_dir is target_dir itself or any parent of target_dir.
        """
        try:
            # Check if target is the same as placement or is a subdirectory of placement
            target_dir.resolve().relative_to(placement_dir.resolve())
            return True
        except ValueError:
            # target_dir is not under placement_dir
            return False

    def _calculate_coverage_efficiency(self, directory, pattern: str) -> float:
        """Calculate how well placement covers actual usage.

        Returns the cached analysis' relevance score for ``pattern``.
        ``directory`` must be a key of ``_directory_cache`` (KeyError otherwise).
        """
        analysis = self._directory_cache[directory]
        return analysis.get_relevance_score(pattern)

    def _calculate_pollution_minimization(self, directory, pattern: str) -> float:
        """Calculate pollution score (higher = more pollution)."""
        return self._calculate_inheritance_pollution(directory, pattern)

    def _calculate_maintenance_locality(self, directory, pattern: str) -> float:
        """Calculate maintenance locality score.

        Returns the share of the directory's files already recorded as
        matching ``pattern``, capped at 1.0 (0.0 for an empty directory).
        ``directory`` must be a key of ``_directory_cache``.
        """
        # Simple heuristic: prefer directories with more related files
        analysis = self._directory_cache[directory]
        pattern_matches = analysis.pattern_matches.get(pattern, 0)

        if analysis.total_files == 0:
            return 0.0

        return min(1.0, pattern_matches / analysis.total_files)

    def _select_clean_separation_placements(
        self, candidates: builtins.list, pattern: str
    ) -> builtins.list:
        """Select placements that provide clean separation of concerns.

        A candidate is "isolated" when its directory is neither an ancestor nor
        a descendant of any other candidate's directory.

        Args:
            candidates (List[PlacementCandidate]): Sorted placement candidates.
            pattern (str): Instruction pattern (not used by the selection).

        Returns:
            List[Path]: The isolated directories with ``direct_relevance`` of at
            least 0.1 when there are two or more of them; otherwise an empty
            list, leaving the choice of a single placement to the caller.
        """
        # Look for distinct clusters of files
        clusters = []

        for candidate in candidates:
            # Check if this directory is isolated (not a parent/child of others)
            is_isolated = True

            for other in candidates:
                if candidate.directory == other.directory:
                    continue

                if self._is_child_directory(
                    candidate.directory, other.directory
                ) or self._is_child_directory(other.directory, candidate.directory):
                    is_isolated = False
                    break

            if is_isolated and candidate.direct_relevance >= 0.1:  # Use fixed threshold
                clusters.append(candidate.directory)

        # If we found clean clusters, use them
        if len(clusters) > 1:
            return clusters

        # Otherwise signal "no clean separation" with an empty list
        return []
source_dir: str | None = None): @@ -597,307 +595,6 @@ def _compile_single_file( stats=stats, ) - def _compile_claude_md( - self, config: CompilationConfig, primitives: PrimitiveCollection - ) -> CompilationResult: - """Compile CLAUDE.md files (Claude Code target). - - Uses ClaudeFormatter to generate CLAUDE.md files following Claude's - Memory format with @import syntax, grouped project standards, and - workflows section for agents/roles. - - Args: - config (CompilationConfig): Compilation configuration. - primitives (PrimitiveCollection): Primitives to compile. - - Returns: - CompilationResult: Result of the CLAUDE.md compilation. - """ - errors = self.validate_primitives(primitives) - self.errors.extend(errors) - - # Create Claude formatter - claude_formatter = ClaudeFormatter(str(self.base_dir), source_dir=str(self.source_dir)) - - # Honor compilation.strategy=single-file (and the --single-agents flag) - # by collapsing all instructions into a single root CLAUDE.md, mirroring - # the gate in _compile_agents_md. Without this, single-file mode is - # silently ignored for the Claude target and per-subdirectory CLAUDE.md - # files are emitted via the distributed placement path (issue #1445). - # - # DistributedAgentsCompiler is only constructed on the distributed - # branch -- single-file mode does not use its placement analysis and - # the later display block guards on `distributed_compiler is not None`. 
- distributed_compiler = None - if config.strategy != "distributed" or config.single_agents: - placement_map = {self.base_dir: list(primitives.instructions)} - else: - from .distributed_compiler import DistributedAgentsCompiler - - distributed_compiler = DistributedAgentsCompiler( - str(self.base_dir), - exclude_patterns=config.exclude, - source_dir=str(self.source_dir), - ) - # Analyze directory structure and determine placement - directory_map = distributed_compiler.analyze_directory_structure( - primitives.instructions - ) - placement_map = distributed_compiler.determine_agents_placement( - primitives.instructions, - directory_map, - min_instructions=config.min_instructions_per_file, - debug=config.debug, - ) - - # Skip instructions in CLAUDE.md when they are already deployed to - # .claude/rules/ by `apm install` (avoids duplicate context in Claude Code). - # --no-dedup / --force-instructions lets users opt out of this behaviour. - if config.no_dedup: - skip_instructions = False - self._log( - "progress", - "Including instructions in CLAUDE.md (--no-dedup overrides deduplication)", - symbol="info", - ) - else: - skip_instructions = _detect_deployed_instructions( - self.base_dir / ".claude" / "rules", - self.base_dir, - lambda msg: self._log("warning", msg), - ) - if skip_instructions: - self._log( - "progress", - "Instructions already in .claude/rules/ -- omitting from CLAUDE.md" - " to avoid duplicate context", - symbol="info", - ) - - # Format CLAUDE.md files - claude_config = { - "source_attribution": config.source_attribution, - "debug": config.debug, - "skip_instructions": skip_instructions, - } - claude_result = claude_formatter.format_distributed( - primitives, placement_map, claude_config - ) - - # NOTE: Claude commands are now generated at install time via CommandIntegrator, - # not at compile time. This keeps behavior consistent with VSCode prompt integration. 
- - # Merge warnings and errors (no command result anymore) - all_warnings = self.warnings + claude_result.warnings - all_errors = self.errors + claude_result.errors - - # Handle dry-run mode - if config.dry_run: - # Generate preview summary - count = len(claude_result.placements) - preview_lines = [ - f"CLAUDE.md Preview: Would generate {count} {'file' if count == 1 else 'files'}" - ] - # Surface the deduplication skip so dry-run is self-explanatory - # for scripted consumers (otherwise "Would generate 0 files" - # looks like a no-op or a bug). The same skip appears in the - # non-dry-run path via the dedicated INFO log line. - if skip_instructions: - preview_lines.append( - " (instructions section skipped: .claude/rules/ already " - "populated -- avoids duplicate content in Claude Code's " - "context window)" - ) - for claude_path in claude_result.content_map.keys(): # noqa: SIM118 - rel_path = portable_relpath(claude_path, self.base_dir) - preview_lines.append(f" {rel_path}") - - return CompilationResult( - success=len(all_errors) == 0, - output_path="Preview mode - CLAUDE.md", - content="\n".join(preview_lines), - warnings=all_warnings, - errors=all_errors, - stats=claude_result.stats, - ) - - # Write CLAUDE.md files - files_written = 0 - critical_security_found = False - from ..security.gate import WARN_POLICY, SecurityGate - from .output_writer import CompiledOutputWriter - - writer = CompiledOutputWriter() - for claude_path, content in claude_result.content_map.items(): - try: - # Handle constitution injection if enabled - final_content = content - if config.with_constitution: - try: - from .injector import ConstitutionInjector - - injector = ConstitutionInjector(str(claude_path.parent)) - final_content, _, _ = injector.inject( - content, with_constitution=True, output_path=claude_path - ) - except Exception as exc: - _logger.debug("Constitution injection failed for %s: %s", claude_path, exc) - - # Defense-in-depth: scan compiled output before writing - 
verdict = SecurityGate.scan_text( - final_content, str(claude_path), policy=WARN_POLICY - ) - actionable = verdict.critical_count + verdict.warning_count - if actionable: - if verdict.has_critical: - critical_security_found = True - all_warnings.append( - f"CLAUDE.md contains {actionable} hidden character(s) " - f"— run 'apm audit --file {claude_path}' to inspect" - ) - - writer.write(claude_path, final_content) - files_written += 1 - except OSError as e: - all_errors.append(f"Failed to write {claude_path}: {e!s}") - - # Update stats - stats = claude_result.stats.copy() - stats["claude_files_written"] = files_written - - if files_written == 0 and skip_instructions: - self._log( - "progress", - "CLAUDE.md not generated -- Claude Code reads .claude/rules/ directly," - " no further action needed", - symbol="info", - ) - elif distributed_compiler is None and files_written > 0 and not config.dry_run: - # Single-file strategy bypasses the distributed display formatter - # (which has no analysis to render). Emit a minimal progress line - # so users get a confirmation that single-file mode took effect. - noun = "file" if files_written == 1 else "files" - self._log( - "progress", - f"CLAUDE.md compiled ({files_written} {noun})", - symbol="success", - ) - - # Display CLAUDE.md compilation output using standard formatter - # Get proper compilation results from distributed compiler (has optimization decisions) - # Skip formatter output when deduplication filtered out all placements to - # avoid contradicting the "not generated" log message above. 
- from ..output.formatters import CompilationFormatter - from ..output.models import CompilationResults - - compilation_results = ( - distributed_compiler.get_compilation_results_for_display(is_dry_run=config.dry_run) - if distributed_compiler is not None - else None - ) - if compilation_results and not (skip_instructions and files_written == 0): - # Update target name for CLAUDE.md output - formatter_results = CompilationResults( - project_analysis=compilation_results.project_analysis, - optimization_decisions=compilation_results.optimization_decisions, - placement_summaries=compilation_results.placement_summaries, - optimization_stats=compilation_results.optimization_stats, - warnings=all_warnings, - errors=all_errors, - is_dry_run=config.dry_run, - target_name="CLAUDE.md", - ) - - # Use the same formatter as AGENTS.md - formatter = CompilationFormatter(use_color=True) - if config.debug or config.trace: - output = formatter.format_verbose(formatter_results) - elif config.dry_run: - output = formatter.format_dry_run(formatter_results) - else: - output = formatter.format_default(formatter_results) - self._log("progress", output) - - # Generate summary content for result object - summary_lines = [ - f"# CLAUDE.md Compilation Summary", # noqa: F541 - f"", # noqa: F541 - f"Generated {files_written} CLAUDE.md files:", - ] - for placement in claude_result.placements: - rel_path = portable_relpath(placement.claude_path, self.base_dir) - summary_lines.append(f"- {rel_path} ({len(placement.instructions)} instructions)") - - return CompilationResult( - success=len(all_errors) == 0, - output_path=f"CLAUDE.md: {files_written} files", - content="\n".join(summary_lines), - warnings=all_warnings, - errors=all_errors, - stats=stats, - has_critical_security=critical_security_found, - ) - - def _compile_gemini_md( - self, config: CompilationConfig, primitives: PrimitiveCollection - ) -> CompilationResult: - """Compile GEMINI.md stub that imports AGENTS.md. 
- - Gemini CLI supports ``@./path`` import syntax, so GEMINI.md is a - thin wrapper that pulls in AGENTS.md at load time. The actual - instruction roll-up is handled by the AGENTS.md pipeline (which - is always compiled alongside via ``should_compile_agents_md``). - - Args: - config: Compilation configuration. - primitives: Primitives to compile. - - Returns: - CompilationResult for the GEMINI.md compilation. - """ - from .gemini_formatter import GeminiFormatter - - gemini_formatter = GeminiFormatter(str(self.base_dir)) - gemini_result = gemini_formatter.format_distributed(primitives) - - all_warnings = self.warnings + gemini_result.warnings - all_errors = self.errors + gemini_result.errors - - if config.dry_run: - return CompilationResult( - success=len(all_errors) == 0, - output_path="Preview mode - GEMINI.md", - content="GEMINI.md Preview: Would generate stub importing AGENTS.md", - warnings=all_warnings, - errors=all_errors, - stats=gemini_result.stats, - ) - - files_written = 0 - from .output_writer import CompiledOutputWriter - - writer = CompiledOutputWriter() - for gemini_path, content in gemini_result.content_map.items(): - try: - writer.write(gemini_path, content) - files_written += 1 - except OSError as e: - all_errors.append(f"Failed to write {gemini_path}: {e!s}") - - stats = gemini_result.stats.copy() - stats["gemini_files_written"] = files_written - - self._log("progress", "Generated GEMINI.md (imports AGENTS.md)") - - return CompilationResult( - success=len(all_errors) == 0, - output_path=f"GEMINI.md: {files_written} files", - content=f"Generated {files_written} GEMINI.md stub importing AGENTS.md", - warnings=all_warnings, - errors=all_errors, - stats=stats, - ) - def _merge_results(self, results: list[CompilationResult]) -> CompilationResult: """Merge multiple compilation results into a single result. 
@@ -1050,416 +747,6 @@ def _generate_template_data( chatmode_content=chatmode_content, ) - def _maybe_emit_copilot_root_instructions( - self, - config: CompilationConfig, - primitives: PrimitiveCollection, - result: CompilationResult, - ) -> CompilationResult: - """Generate .github/copilot-instructions.md for Copilot-capable targets. - - Skip semantics: if the file already exists without the APM-generated - marker, it is treated as hand-authored and left untouched. The - ``copilot_root_instructions_skipped`` stat captures this case - explicitly so callers can distinguish it from a genuine no-op - (``copilot_root_instructions_unchanged``) or an unrouted target. - """ - routing_target = "vscode" if config.target in _VSCODE_TARGET_ALIASES else config.target - output_path = self.base_dir / ".github" / "copilot-instructions.md" - if not should_compile_copilot_instructions_md(routing_target): - if not config.dry_run: - self._cleanup_copilot_root_instructions(output_path, result) - result.stats.setdefault("copilot_root_instructions_generated", 0) - result.stats.setdefault("copilot_root_instructions_written", 0) - result.stats.setdefault("copilot_root_instructions_unchanged", 0) - result.stats.setdefault("copilot_root_instructions_skipped", 0) - result.stats.setdefault("copilot_root_instructions_removed", 0) - return result - - global_instructions = sorted( - [instruction for instruction in primitives.instructions if not instruction.apply_to], - key=lambda instruction: portable_relpath(instruction.file_path, self.base_dir), - ) - if not global_instructions: - if not config.dry_run: - self._cleanup_copilot_root_instructions(output_path, result) - result.stats.setdefault("copilot_root_instructions_generated", 0) - result.stats.setdefault("copilot_root_instructions_written", 0) - result.stats.setdefault("copilot_root_instructions_unchanged", 0) - result.stats.setdefault("copilot_root_instructions_skipped", 0) - result.stats.setdefault("copilot_root_instructions_removed", 0) - 
return result - - content = self._generate_copilot_root_instructions_content(global_instructions, config) - - result.stats["copilot_root_instructions_generated"] = 1 - result.stats.setdefault("copilot_root_instructions_skipped", 0) - result.stats.setdefault("copilot_root_instructions_removed", 0) - result.stats.setdefault("copilot_root_instructions_written", 0) - result.stats.setdefault("copilot_root_instructions_unchanged", 0) - - # Inspect any existing file BEFORE the dry-run early-exit so that - # `--dry-run` faithfully reports what a real run would do (skip vs - # write vs unchanged). Reading the file here is safe in dry-run mode - # because we never mutate it. - try: - existing = output_path.read_text(encoding="utf-8") if output_path.exists() else None - except OSError as exc: - message = f"Failed to read {output_path}: {exc}" - self.errors.append(message) - result.errors.append(message) - result.success = False - return result - - if existing is not None and _COPILOT_ROOT_GENERATED_MARKER not in existing: - rel_path = portable_relpath(output_path, self.base_dir) - result.warnings.append( - f"Skipped {rel_path}: hand-authored file will not be overwritten. " - "To regenerate, either delete or rename it, or prepend the line " - f"'{_COPILOT_ROOT_GENERATED_MARKER}' to the top of the file. " - "Then re-run 'apm compile'." - ) - # The file was never compared to new content; record as - # 'skipped', not 'unchanged'. Also reset 'generated' since no - # output was actually emitted (or would be, on a real run). 
- result.stats["copilot_root_instructions_generated"] = 0 - result.stats["copilot_root_instructions_written"] = 0 - result.stats["copilot_root_instructions_skipped"] = 1 - result.stats["copilot_root_instructions_unchanged"] = 0 - return result - - if existing == content: - result.stats["copilot_root_instructions_written"] = 0 - result.stats["copilot_root_instructions_unchanged"] = 1 - return result - - if config.dry_run: - return result - - from ..security.gate import WARN_POLICY, SecurityGate - - verdict = SecurityGate.scan_text(content, str(output_path), policy=WARN_POLICY) - actionable = verdict.critical_count + verdict.warning_count - if actionable: - if verdict.has_critical: - result.has_critical_security = True - result.warnings.append( - f"copilot-instructions.md contains {actionable} hidden character(s) " - f"-- run 'apm audit --file {output_path}' to inspect" - ) - - try: - output_path.parent.mkdir(parents=True, exist_ok=True) - output_path.write_text(content, encoding="utf-8") - result.stats["copilot_root_instructions_written"] = 1 - result.stats["copilot_root_instructions_unchanged"] = 0 - return result - except OSError as exc: - message = f"Failed to write {output_path}: {exc}" - self.errors.append(message) - result.errors.append(message) - result.success = False - result.stats["copilot_root_instructions_written"] = 0 - result.stats.setdefault("copilot_root_instructions_unchanged", 0) - return result - - def _generate_copilot_root_instructions_content( - self, - instructions, - config: CompilationConfig, - ) -> str: - """Generate root Copilot instructions content from global instruction primitives.""" - # Functional marker and Build ID are always present (injection/drift/cleanup coupling). 
- sections = [ - _COPILOT_ROOT_GENERATED_MARKER, - BUILD_ID_PLACEHOLDER, - ] - if config.source_attribution: - sections.append(f"") - sections.append("") - - for instruction in instructions: - # instruction.file_path is a source-tree file; relativise it - # against source_dir so `apm compile --root` never leaks - # `../../` or absolute deploy-relative paths into the - # `` provenance comments (sources stay in $PWD - # while writes redirect to base_dir). - rel_path = portable_relpath(instruction.file_path, self.source_dir) - if config.source_attribution: - sections.append(f"") - sections.append(instruction.content.strip()) - if config.source_attribution: - sections.append(f"") - sections.append("") - - if config.source_attribution: - sections.append("---") - sections.append("*This file was generated by APM CLI. Do not edit manually.*") - sections.append("*To regenerate: `apm compile`*") - sections.append("") - - content = "\n".join(sections) - if config.resolve_links: - content = resolve_markdown_links(content, self.base_dir) - return self._finalize_build_id(content) - - def _finalize_build_id(self, content: str) -> str: - """Replace the build-id placeholder with a deterministic content hash.""" - lines = content.splitlines() - try: - idx = lines.index(BUILD_ID_PLACEHOLDER) - except ValueError: - return content - - hash_input_lines = [line for i, line in enumerate(lines) if i != idx] - build_id = hashlib.sha256("\n".join(hash_input_lines).encode("utf-8")).hexdigest()[:12] - lines[idx] = f"" - return "\n".join(lines) + ("\n" if content.endswith("\n") else "") - - def _cleanup_copilot_root_instructions( - self, - output_path: Path, - result: CompilationResult, - ) -> CompilationResult: - """Remove stale generated Copilot root instructions when no longer applicable.""" - if not output_path.exists(): - result.stats.setdefault("copilot_root_instructions_removed", 0) - return result - - try: - existing = output_path.read_text(encoding="utf-8") - if 
_COPILOT_ROOT_GENERATED_MARKER not in existing: - result.stats.setdefault("copilot_root_instructions_removed", 0) - return result - - output_path.unlink() - result.stats["copilot_root_instructions_removed"] = 1 - return result - except OSError as exc: - message = f"Failed to remove stale {output_path}: {exc}" - self.errors.append(message) - result.errors.append(message) - result.success = False - result.stats.setdefault("copilot_root_instructions_removed", 0) - return result - - def _write_output_file(self, output_path: str, content: str) -> None: - """Write the generated content to the output file (full-file mode). - - Args: - output_path (str): Path to write the output. - content (str): Content to write. - """ - from .output_writer import CompiledOutputWriter - - try: - CompiledOutputWriter().write(Path(output_path), content) - except OSError as e: - self.errors.append(f"Failed to write output file {output_path}: {e!s}") - - def _write_output_file_with_config( - self, output_path: str, content: str, config: "CompilationConfig" - ) -> None: - """Write generated content, honouring agents_md_mode from config. - - In ``full`` mode (default) the entire file is replaced. - In ``managed_section`` mode only the text between the configured - start/end markers is replaced; everything else is preserved. - - Args: - output_path (str): Path to write the output. - content (str): Generated content for this compilation. - config (CompilationConfig): Compilation configuration. - """ - from .managed_section import ManagedSectionError, apply_managed_section - from .output_writer import CompiledOutputWriter - - if config.agents_md_mode == "managed_section": - target = Path(output_path) - if not target.is_file(): - raise ManagedSectionError( - f"{target} does not exist yet. " - "Create it with the managed-section markers first, " - "or set agents_md.mode: full in apm.yml for initial generation." 
- ) - existing = target.read_text(encoding="utf-8") - try: - content = apply_managed_section( - existing, - content, - config.agents_md_start_marker, - config.agents_md_end_marker, - ) - except ManagedSectionError as exc: - raise ManagedSectionError(f"[{target}] {exc}") from exc - elif config.agents_md_mode != "full": - raise ValueError( - f"Unknown agents_md.mode {config.agents_md_mode!r}. " - "Supported values: 'full', 'managed_section'." - ) - - try: - CompiledOutputWriter().write(Path(output_path), content) - except OSError as e: - self.errors.append(f"Failed to write output file {output_path}: {e!s}") - - def _compile_stats( - self, primitives: PrimitiveCollection, template_data: TemplateData - ) -> dict[str, Any]: - """Compile statistics about the compilation. - - Args: - primitives (PrimitiveCollection): Discovered primitives. - template_data (TemplateData): Generated template data. - - Returns: - Dict[str, Any]: Compilation statistics. - """ - return { - "primitives_found": primitives.count(), - "chatmodes": len(primitives.chatmodes), - "instructions": len(primitives.instructions), - "contexts": len(primitives.contexts), - "content_length": len(template_data.instructions_content), - # timestamp removed - "version": template_data.version, - } - - def _write_distributed_file( - self, agents_path: Path, content: str, config: CompilationConfig - ) -> None: - """Write a distributed AGENTS.md file with constitution injection support. - - Args: - agents_path (Path): Path to write the AGENTS.md file. - content (str): Content to write. - config (CompilationConfig): Compilation configuration. 
- """ - try: - # Handle constitution injection for distributed files - final_content = content - - if config.with_constitution: - # Try to inject constitution if available - try: - from .injector import ConstitutionInjector - - injector = ConstitutionInjector(str(agents_path.parent)) - final_content, c_status, c_hash = injector.inject( # noqa: RUF059 - content, with_constitution=True, output_path=agents_path - ) - except Exception as exc: - _logger.debug("Constitution injection failed for %s: %s", agents_path, exc) - - from .output_writer import CompiledOutputWriter - - CompiledOutputWriter().write(agents_path, final_content) - - except OSError as e: - raise OSError(f"Failed to write distributed AGENTS.md file {agents_path}: {e!s}") # noqa: B904 - - def _display_placement_preview(self, distributed_result) -> None: - """Display placement preview for --show-placement mode. - - Args: - distributed_result: Result from distributed compilation. - """ - self._log("progress", "Distributed AGENTS.md Placement Preview:") - self._log("progress", "") - - for placement in distributed_result.placements: - rel_path = portable_relpath(placement.agents_path, self.base_dir) - self._log("verbose_detail", f"{rel_path}") - self._log("verbose_detail", f" Instructions: {len(placement.instructions)}") - self._log( - "verbose_detail", f" Patterns: {', '.join(sorted(placement.coverage_patterns))}" - ) - if placement.source_attribution: - sources = set(placement.source_attribution.values()) - self._log("verbose_detail", f" Sources: {', '.join(sorted(sources))}") - self._log("verbose_detail", "") - - def _display_trace_info(self, distributed_result, primitives: PrimitiveCollection) -> None: - """Display detailed trace information for --trace mode. - - Args: - distributed_result: Result from distributed compilation. - primitives (PrimitiveCollection): Full primitive collection. 
- """ - self._log("progress", "Distributed Compilation Trace:") - self._log("progress", "") - - for placement in distributed_result.placements: - rel_path = portable_relpath(placement.agents_path, self.base_dir) - self._log("verbose_detail", f"{rel_path}") - - for instruction in placement.instructions: - source = getattr(instruction, "source", "local") - # instruction.file_path is a source-tree file; relativise - # against source_dir so `apm compile --root` produces - # human-readable paths in verbose output. - inst_path = portable_relpath(instruction.file_path, self.source_dir) - - self._log( - "verbose_detail", - f" * {instruction.apply_to or 'no pattern'} <- {source} {inst_path}", - ) - self._log("verbose_detail", "") - - def _generate_placement_summary(self, distributed_result) -> str: - """Generate a text summary of placement results. - - Args: - distributed_result: Result from distributed compilation. - - Returns: - str: Text summary of placements. - """ - lines = ["Distributed AGENTS.md Placement Summary:", ""] - - for placement in distributed_result.placements: - rel_path = portable_relpath(placement.agents_path, self.base_dir) - lines.append(f"{rel_path}") - lines.append(f" Instructions: {len(placement.instructions)}") - lines.append(f" Patterns: {', '.join(sorted(placement.coverage_patterns))}") - lines.append("") - - lines.append(f"Total AGENTS.md files: {len(distributed_result.placements)}") - return "\n".join(lines) - - def _generate_distributed_summary(self, distributed_result, config: CompilationConfig) -> str: - """Generate a summary of distributed compilation results. - - Args: - distributed_result: Result from distributed compilation. - config (CompilationConfig): Compilation configuration. - - Returns: - str: Summary content. 
- """ - lines = [ - "# Distributed AGENTS.md Compilation Summary", - "", - f"Generated {len(distributed_result.placements)} AGENTS.md files:", - "", - ] - - for placement in distributed_result.placements: - rel_path = portable_relpath(placement.agents_path, self.base_dir) - lines.append(f"- {rel_path} ({len(placement.instructions)} instructions)") - - lines.extend( - [ - "", - f"Total instructions: {distributed_result.stats.get('total_instructions_placed', 0)}", - f"Total patterns: {distributed_result.stats.get('total_patterns_covered', 0)}", - "", - "Use 'apm compile --single-agents' for traditional single-file compilation.", - ] - ) - - return "\n".join(lines) - def compile_agents_md( primitives: PrimitiveCollection | None = None, diff --git a/src/apm_cli/compilation/context_optimizer.py b/src/apm_cli/compilation/context_optimizer.py index ec36f9355..0430ba84e 100644 --- a/src/apm_cli/compilation/context_optimizer.py +++ b/src/apm_cli/compilation/context_optimizer.py @@ -6,7 +6,6 @@ """ import builtins -import fnmatch import os import time from collections import defaultdict @@ -24,7 +23,8 @@ from ..primitives.models import Instruction from ..utils.exclude import matches_glob, should_exclude, validate_exclude_patterns from ..utils.paths import portable_relpath -from ..utils.patterns import has_top_level_comma, parse_apply_to +from ._pattern_matcher import _PatternMatcherMixin +from ._placement_solver import _PlacementSolverMixin # CRITICAL: Shadow Click commands to prevent namespace collision # When this module is imported during 'apm compile', Click's active context @@ -102,7 +102,7 @@ def __post_init__(self): ) -class ContextOptimizer: +class ContextOptimizer(_PlacementSolverMixin, _PatternMatcherMixin): """Context Optimization Engine for distributed AGENTS.md placement.""" # Mathematical optimization parameters @@ -546,783 +546,3 @@ def _should_exclude_path(self, path: Path) -> bool: True if path should be excluded, False otherwise """ return 
should_exclude(path, self.base_dir, self._exclude_patterns) - - def _find_optimal_placements( - self, instruction: Instruction, verbose: bool = False - ) -> builtins.list[Path]: - """Find optimal placement(s) for an instruction using mathematical optimization. - - This implements constraint satisfaction optimization that guarantees every - instruction gets placed at its mathematically optimal location(s). - - Args: - instruction (Instruction): Instruction to place. - verbose (bool): Collect verbose analysis data. - - Returns: - List[Path]: List of optimal directory placements. - """ - return self._solve_placement_optimization(instruction, verbose) - - def _solve_placement_optimization( - self, instruction: Instruction, verbose: bool = False - ) -> builtins.list[Path]: - """Mathematical optimization solver for instruction placement. - - Implements the mathematician's objective function: - minimize: sum(context_pollution x directory_weight) - subject to: for_all instruction -> exists placement - - Args: - instruction (Instruction): Instruction to optimize placement for. - verbose (bool): Collect verbose analysis data. - - Returns: - List[Path]: Mathematically optimal placement(s). 
- """ - pattern = instruction.apply_to - - # Find all directories with matching files - matching_directories = self._find_matching_directories(pattern) - - if not matching_directories: - # Smart fallback: Try to place in semantically appropriate directory - intended_dir = self._extract_intended_directory_from_pattern(pattern) - name = getattr(instruction, "name", None) or instruction.file_path.stem - - if intended_dir: - # Place in the intended directory (e.g., docs/ for docs/**/*.md) - placement = intended_dir - reasoning = f"No matching files found, placed in intended directory '{portable_relpath(intended_dir, self.base_dir)}'" - self._warnings.append( - f"applyTo for '{name}' matched no files - placing in '{portable_relpath(intended_dir, self.base_dir)}'" - ) - else: - # Fallback to root for global patterns - placement = self.base_dir - reasoning = "No matching files found, fallback to root placement" - self._warnings.append( - f"applyTo for '{name}' matched no files - placing at project root" - ) - - # Calculate relevance score for the fallback placement - relevance_score = 0.0 # No matches means no relevance - if placement in self._directory_cache: - relevance_score = self._calculate_coverage_efficiency(placement, pattern) - - decision = OptimizationDecision( - instruction=instruction, - pattern=pattern, - matching_directories=0, - total_directories=len(self._directory_cache), - distribution_score=0.0, - strategy=PlacementStrategy.DISTRIBUTED, - placement_directories=[placement], - reasoning=reasoning, - relevance_score=relevance_score, - ) - self._optimization_decisions.append(decision) - - return [placement] - - # Calculate distribution score with diversity factor - distribution_score = self._calculate_distribution_score(matching_directories) - - # Apply three-tier placement strategy based on mathematical analysis - if distribution_score < self.LOW_DISTRIBUTION_THRESHOLD: - # Low distribution: Single Point Placement - strategy = 
PlacementStrategy.SINGLE_POINT - placements = self._optimize_single_point_placement( - matching_directories, instruction, verbose - ) - reasoning = "Low distribution pattern optimized for minimal pollution" - elif distribution_score > self.HIGH_DISTRIBUTION_THRESHOLD: - # High distribution: Distributed Placement - strategy = PlacementStrategy.DISTRIBUTED - placements = self._optimize_distributed_placement( - matching_directories, instruction, verbose - ) - reasoning = "High distribution pattern placed at root to minimize duplication" - else: - # Medium distribution: Selective Multi-Placement - strategy = PlacementStrategy.SELECTIVE_MULTI - placements = self._optimize_selective_placement( - matching_directories, instruction, verbose - ) - reasoning = "Medium distribution pattern with selective high-relevance placement" - - # Calculate relevance score for the primary placement directory - relevance_score = 0.0 - if placements: - primary_placement = placements[0] # Use first placement as representative - if primary_placement in self._directory_cache: - relevance_score = self._calculate_coverage_efficiency(primary_placement, pattern) - - # Record optimization decision - decision = OptimizationDecision( - instruction=instruction, - pattern=pattern, - matching_directories=len(matching_directories), - total_directories=len(self._directory_cache), - distribution_score=distribution_score, - strategy=strategy, - placement_directories=placements, - reasoning=reasoning, - relevance_score=relevance_score, - ) - self._optimization_decisions.append(decision) - - return placements - - def _extract_intended_directory_from_pattern(self, pattern: str) -> Path | None: - """Extract the intended directory from a pattern like 'docs/**/*.md' -> 'docs'. - - Args: - pattern (str): File pattern (may be a comma-separated list). - - Returns: - Optional[Path]: Intended directory path, or None if pattern is global. 
- """ - # For comma-lists, only the first segment is consulted - the - # placement still flows into a single directory. - if has_top_level_comma(pattern): - segments = parse_apply_to(pattern) - if not segments: - return None - pattern = segments[0] - - if not pattern or pattern.startswith("**/"): - return None # Global pattern - - if "/" in pattern: - # Extract the first directory component - parts = pattern.split("/") - first_part = parts[0] - - # Skip if it's a wildcard - if "*" not in first_part and first_part: - intended_dir = self.base_dir / first_part - if intended_dir.exists() and intended_dir.is_dir(): - return intended_dir - - return None - - def _expand_glob_pattern(self, pattern: str) -> builtins.list[str]: - """Expand glob pattern with brace expansion, supporting multiple brace groups. - - Args: - pattern (str): Pattern like '**/*.{css,scss}' or '**/*.{test,spec}.{ts,js}' - - Returns: - List[str]: Expanded patterns like ['**/*.css', '**/*.scss'] - or ['**/*.test.ts', '**/*.test.js', '**/*.spec.ts', '**/*.spec.js'] - """ - import re - - # Handle brace expansion like {css,scss} - brace_match = re.search(r"\{([^}]+)\}", pattern) - if brace_match: - alternatives = brace_match.group(1).split(",") - prefix = pattern[: brace_match.start()] - suffix = pattern[brace_match.end() :] - # Recursively expand remaining brace groups in each result - expanded = [] - for alt in alternatives: - expanded.extend(self._expand_glob_pattern(prefix + alt + suffix)) - return expanded - - return [pattern] - - def _file_matches_pattern(self, file_path: Path, pattern: str) -> bool: - """Check if a file matches a given pattern with optimized performance. - - Args: - file_path (Path): File path to check - pattern (str): Glob pattern or comma-separated list of globs. - - Returns: - bool: True if file matches pattern (or any segment of a list). - """ - # applyTo accepts a comma-separated list of globs; treat any - # segment match as a hit so list patterns mirror per-glob semantics. 
- # Only split on top-level commas - commas inside brace alternation - # (e.g. ``**/*.{css,scss}``) must stay attached for brace expansion. - if has_top_level_comma(pattern): - segments = parse_apply_to(pattern) - return any(self._file_matches_pattern(file_path, seg) for seg in segments) - - # Expand any brace patterns - expanded_patterns = self._expand_glob_pattern(pattern) - - for expanded_pattern in expanded_patterns: - # For patterns with **, use cached glob results - if "**" in expanded_pattern: - try: - # Resolve both paths to handle symlinks and path inconsistencies - resolved_file = file_path.resolve() - rel_path = resolved_file.relative_to(self.base_dir.resolve()) - - # Use cached glob results instead of repeated glob calls - matches = self._cached_glob(expanded_pattern) - # Use cached Set[Path] to avoid recreating on every call - if expanded_pattern not in self._glob_set_cache: - self._glob_set_cache[expanded_pattern] = {Path(match) for match in matches} - if rel_path in self._glob_set_cache[expanded_pattern]: - return True - except (ValueError, OSError): - pass - else: - # For non-recursive patterns, use fnmatch as before - try: - rel_str = portable_relpath(file_path, self.base_dir) - if fnmatch.fnmatch(rel_str, expanded_pattern): - return True - except ValueError: - pass - - # Only use filename match for patterns without directory structure - # This prevents "docs/**/*.md" from matching any "*.md" file anywhere - if "/" not in expanded_pattern: - if fnmatch.fnmatch(file_path.name, expanded_pattern): - return True - - return False - - def _find_matching_directories(self, pattern: str) -> builtins.set[Path]: - """Find directories that contain files matching the pattern. - - Args: - pattern (str): File pattern to match. - - Returns: - Set[Path]: Set of directories with matching files. 
- """ - # Use cached result if available - if pattern in self._pattern_cache: - return self._pattern_cache[pattern] - - matching_dirs: builtins.set[Path] = set() - - # Use the reliable approach for all patterns - for directory, analysis in sorted(self._directory_cache.items()): - try: - files = [ - f for f in directory.iterdir() if f.is_file() and not f.name.startswith(".") - ] - - match_count = 0 - for file_path in files: - if self._file_matches_pattern(file_path, pattern): - match_count += 1 - matching_dirs.add(directory) - - if match_count > 0: - analysis.pattern_matches[pattern] = match_count - except (OSError, PermissionError): - continue - - self._pattern_cache[pattern] = matching_dirs - return matching_dirs - - def _calculate_inheritance_pollution(self, directory: Path, pattern: str) -> float: - """Calculate inheritance pollution score for placing instruction at directory. - - Args: - directory (Path): Candidate placement directory. - pattern (str): Instruction pattern. - - Returns: - float: Pollution score (higher = more pollution). 
- """ - pollution_score = 0.0 - - # Optimization: Only check direct children instead of all directories - # This prevents O(n2) complexity with unlimited depth analysis - try: - direct_children = [ - child - for child in directory.iterdir() - if child.is_dir() and child in self._directory_cache - ] - - # Check only direct child directories for pollution - for child_dir in direct_children: - analysis = self._directory_cache[child_dir] - - # If child has no matching files, this creates pollution - child_relevance = analysis.get_relevance_score(pattern) - if child_relevance == 0.0: - pollution_score += 0.5 # Strong pollution penalty - elif child_relevance < 0.1: # Weak relevance threshold - pollution_score += 0.2 # Weak pollution penalty - except (OSError, PermissionError): - # Skip directories we can't read - pass - - return pollution_score - - def _calculate_distribution_score(self, matching_directories: builtins.set[Path]) -> float: - """Calculate distribution score with diversity factor. - - Args: - matching_directories: Set of directories with pattern matches. - - Returns: - float: Distribution score accounting for spread and depth diversity. 
- """ - total_dirs_with_files = len( - [d for d in self._directory_cache.values() if d.total_files > 0] - ) - if total_dirs_with_files == 0: - return 0.0 - - base_ratio = len(matching_directories) / total_dirs_with_files - - # Calculate diversity factor based on depth distribution - depths = [self._directory_cache[d].depth for d in matching_directories] - if not depths: - return base_ratio - - depth_variance = sum((d - sum(depths) / len(depths)) ** 2 for d in depths) / len(depths) - diversity_factor = 1.0 + (depth_variance * self.DIVERSITY_FACTOR_BASE) - - return base_ratio * diversity_factor - - def _optimize_single_point_placement( - self, - matching_directories: builtins.set[Path], - instruction: Instruction, - verbose: bool = False, - ) -> builtins.list[Path]: - """Optimize placement for low distribution patterns (< 0.3 ratio). - - Strategy: Ensure mandatory coverage constraint first, then optimize for minimal pollution. - Coverage guarantee takes priority over efficiency optimization. 
- """ - candidates = self._generate_all_candidates(matching_directories, instruction) - - if not candidates: - return [self.base_dir] - - # CRITICAL: Mandatory coverage constraint - filter candidates that provide complete coverage - coverage_candidates = [] - for candidate in candidates: - # Verify this placement can provide hierarchical coverage for ALL matching directories - covered_directories = self._calculate_hierarchical_coverage( - [candidate.directory], matching_directories - ) - if covered_directories == matching_directories: - # This candidate satisfies the mandatory coverage constraint - coverage_candidates.append(candidate) - - # If no single candidate provides complete coverage, find minimal coverage placement - if not coverage_candidates: - minimal_coverage = self._find_minimal_coverage_placement(matching_directories) - if minimal_coverage: - return [minimal_coverage] - else: - # Ultimate fallback to root to guarantee coverage - return [self.base_dir] - - # Among coverage-compliant candidates, select the one with best efficiency/pollution ratio - best_candidate = max( - coverage_candidates, key=lambda c: c.coverage_efficiency - c.pollution_score - ) - - return [best_candidate.directory] - - def _optimize_distributed_placement( - self, - matching_directories: builtins.set[Path], - instruction: Instruction, - verbose: bool = False, - ) -> builtins.list[Path]: - """Optimize placement for high distribution patterns (> 0.7 ratio). - - Strategy: Place at root to minimize duplication while maintaining accessibility. - """ - return [self.base_dir] - - def _optimize_selective_placement( - self, - matching_directories: builtins.set[Path], - instruction: Instruction, - verbose: bool = False, - ) -> builtins.list[Path]: - """Optimize placement for medium distribution patterns (0.3-0.7 ratio). - - Strategy: Ensure hierarchical coverage - all matching files must be able - to inherit the instruction through the hierarchical AGENTS.md system. 
- """ - # First check if we can achieve complete coverage with a single high-level placement - coverage_placement = self._find_minimal_coverage_placement(matching_directories) - if coverage_placement: - return [coverage_placement] - - # If single placement doesn't work, use multi-placement strategy - candidates = self._generate_all_candidates(matching_directories, instruction) - - if not candidates: - return [self.base_dir] - - # Filter for high-relevance candidates (top 20% or relevance > 0.8) - high_relevance_threshold = max( - 0.8, - sorted([c.coverage_efficiency for c in candidates], reverse=True)[ - max(0, len(candidates) // 5) - ], - ) - - high_relevance_candidates = [ - c for c in candidates if c.coverage_efficiency >= high_relevance_threshold - ] - - if not high_relevance_candidates: - # Fallback: use best candidate - high_relevance_candidates = [max(candidates, key=lambda c: c.total_score)] - - optimal_placements = [c.directory for c in high_relevance_candidates] - - # CRITICAL: Verify hierarchical coverage - covered_directories = self._calculate_hierarchical_coverage( - optimal_placements, matching_directories - ) - uncovered_directories = matching_directories - covered_directories - - if uncovered_directories: - # Coverage violation! Find minimal placement that covers everything - minimal_coverage = self._find_minimal_coverage_placement(matching_directories) - if minimal_coverage: - return [minimal_coverage] - else: - # Fallback to root to ensure no coverage gaps - return [self.base_dir] - - return optimal_placements - - def _generate_all_candidates( - self, matching_directories: builtins.set[Path], instruction: Instruction - ) -> builtins.list[PlacementCandidate]: - """Generate all placement candidates with optimization scores. - - This includes both matching directories AND their common ancestors to ensure - the mandatory coverage constraint can be satisfied. 
- """ - candidates = [] - pattern = instruction.apply_to - - # Collect all potential placement directories: - # 1. The matching directories themselves - # 2. Their common ancestors (for coverage guarantee) - potential_directories = set(matching_directories) - - # Add common ancestor directories to ensure coverage options exist - if len(matching_directories) > 1: - # Find common ancestors that could provide coverage - common_ancestor = self._find_minimal_coverage_placement(matching_directories) - if common_ancestor: - potential_directories.add(common_ancestor) - - # Also add any intermediate directories in the inheritance chains - for directory in matching_directories: - chain = self._get_inheritance_chain(directory) - # Add intermediate directories that could provide coverage - for intermediate in chain: - if intermediate != directory and intermediate in self._directory_cache: - potential_directories.add(intermediate) - - # Generate candidates for all potential directories - for directory in sorted(potential_directories): - if directory not in self._directory_cache: - continue - - analysis = self._directory_cache[directory] - - # Calculate the three optimization objectives - coverage_efficiency = self._calculate_coverage_efficiency(directory, pattern) - pollution_score = self._calculate_pollution_minimization(directory, pattern) - maintenance_locality = self._calculate_maintenance_locality(directory, pattern) - - # Apply depth penalty for excessive nesting - depth_penalty = max(0, (analysis.depth - 3) * self.DEPTH_PENALTY_FACTOR) - - # Calculate total objective function score - total_score = ( - coverage_efficiency * self.COVERAGE_EFFICIENCY_WEIGHT - + (1.0 - pollution_score) * self.POLLUTION_MINIMIZATION_WEIGHT - + maintenance_locality * self.MAINTENANCE_LOCALITY_WEIGHT - - depth_penalty - ) - - candidate = PlacementCandidate( - instruction=instruction, - directory=directory, - direct_relevance=coverage_efficiency, # Legacy field - 
inheritance_pollution=pollution_score, # Legacy field - depth_specificity=analysis.depth * 0.1, # Legacy field - total_score=0.0, # Temporary value, will be overwritten - ) - - # Add new optimization fields - candidate.coverage_efficiency = coverage_efficiency - candidate.pollution_score = pollution_score - candidate.maintenance_locality = maintenance_locality - - # Set the mathematical optimization score (after __post_init__ has run) - candidate.total_score = total_score - - candidates.append(candidate) - - return candidates - - def _find_minimal_coverage_placement( - self, matching_directories: builtins.set[Path] - ) -> Path | None: - """Find the highest directory that can provide hierarchical coverage for all matching directories. - - Args: - matching_directories: Directories that contain files matching the pattern - - Returns: - Path to the minimal covering directory, or None if no single placement works - """ - if not matching_directories: - return None - - # Convert to relative paths for easier analysis - relative_dirs = [ - d.resolve().relative_to(self.base_dir.resolve()) for d in matching_directories - ] - - # Find the lowest common ancestor that covers all directories - if len(relative_dirs) == 1: - # Single directory - we can place instruction in that directory or any parent - return list(matching_directories)[0] - - # Find common path prefix for all directories - common_parts = [] - min_depth = min(len(d.parts) for d in relative_dirs) - - for i in range(min_depth): - parts_at_level = [d.parts[i] for d in relative_dirs] - if len(set(parts_at_level)) == 1: - # All directories share this path component - common_parts.append(parts_at_level[0]) - else: - break - - if common_parts: - # Found common ancestor - common_ancestor = self.base_dir / Path(*common_parts) - return common_ancestor - else: - # No common ancestor beyond root - place at root - return self.base_dir - - def _calculate_hierarchical_coverage( - self, placements: builtins.list[Path], 
target_directories: builtins.set[Path] - ) -> builtins.set[Path]: - """Calculate which target directories are covered by the given placements through hierarchical inheritance. - - Args: - placements: List of directories where AGENTS.md files will be placed - target_directories: Directories that need to be covered - - Returns: - Set of target directories that are covered by the placements - """ - covered = set() - - for target in target_directories: - for placement in placements: - if self._is_hierarchically_covered(target, placement): - covered.add(target) - break - - return covered - - def _is_hierarchically_covered(self, target_dir: Path, placement_dir: Path) -> bool: - """Check if target_dir can inherit instructions from placement_dir through hierarchy. - - This is true if placement_dir is target_dir itself or any parent of target_dir. - """ - try: - # Check if target is the same as placement or is a subdirectory of placement - target_dir.resolve().relative_to(placement_dir.resolve()) - return True - except ValueError: - # target_dir is not under placement_dir - return False - - def _calculate_coverage_efficiency(self, directory: Path, pattern: str) -> float: - """Calculate how well placement covers actual usage.""" - analysis = self._directory_cache[directory] - return analysis.get_relevance_score(pattern) - - def _calculate_pollution_minimization(self, directory: Path, pattern: str) -> float: - """Calculate pollution score (higher = more pollution).""" - return self._calculate_inheritance_pollution(directory, pattern) - - def _calculate_maintenance_locality(self, directory: Path, pattern: str) -> float: - """Calculate maintenance locality score.""" - # Simple heuristic: prefer directories with more related files - analysis = self._directory_cache[directory] - pattern_matches = analysis.pattern_matches.get(pattern, 0) - - if analysis.total_files == 0: - return 0.0 - - return min(1.0, pattern_matches / analysis.total_files) - - def 
_select_clean_separation_placements( - self, candidates: builtins.list[PlacementCandidate], pattern: str - ) -> builtins.list[Path]: - """Select placements that provide clean separation of concerns. - - Args: - candidates (List[PlacementCandidate]): Sorted placement candidates. - pattern (str): Instruction pattern. - - Returns: - List[Path]: List of directories for clean separation. - """ - # Look for distinct clusters of files - clusters = [] - - for candidate in candidates: - # Check if this directory is isolated (not a parent/child of others) - is_isolated = True - - for other in candidates: - if candidate.directory == other.directory: - continue - - if self._is_child_directory( - candidate.directory, other.directory - ) or self._is_child_directory(other.directory, candidate.directory): - is_isolated = False - break - - if is_isolated and candidate.direct_relevance >= 0.1: # Use fixed threshold - clusters.append(candidate.directory) - - # If we found clean clusters, use them - if len(clusters) > 1: - return clusters - - # Otherwise, return single best placement - return [] - - def _get_inheritance_chain(self, working_directory: Path) -> builtins.list[Path]: - """Get inheritance chain from working directory to root. - - Args: - working_directory (Path): Starting directory. - - Returns: - List[Path]: Inheritance chain (most specific to root). 
- """ - cached = self._inheritance_cache.get(working_directory) - if cached is not None: - return cached - - chain = [] - # Resolve the starting directory to ensure consistent path comparison - try: - current = working_directory.resolve() - except (OSError, ValueError): - current = working_directory.absolute() - - seen_paths = set() # Track visited paths to prevent infinite loops - - # Build chain from working directory up to (and including) base_dir - while current not in seen_paths: - seen_paths.add(current) - chain.append(current) - - # Stop at base_dir - if current == self.base_dir: - break - - # Stop if we can't go higher or hit filesystem root - try: - parent = current.parent - if parent == current: # We've hit filesystem root - break - current = parent - except (OSError, ValueError): - break - - self._inheritance_cache[working_directory] = chain - return chain - - def _is_child_directory(self, child: Path, parent: Path) -> bool: - """Check if child is a subdirectory of parent. - - Args: - child (Path): Potential child directory. - parent (Path): Potential parent directory. - - Returns: - bool: True if child is subdirectory of parent. - """ - try: - child.resolve().relative_to(parent.resolve()) - return child.resolve() != parent.resolve() - except ValueError: - return False - - def _is_instruction_relevant(self, instruction: Instruction, working_directory: Path) -> bool: - """Check if instruction is relevant for the working directory. - - Args: - instruction (Instruction): Instruction to check. - working_directory (Path): Directory where agent is working. - - Returns: - bool: True if instruction is relevant. 
- """ - if not instruction.apply_to: - return True # Global instructions are always relevant - - pattern = instruction.apply_to - - # Resolve working directory to handle path inconsistencies - try: - resolved_working_dir = working_directory.resolve() - except (OSError, ValueError): - resolved_working_dir = working_directory.absolute() - - # Check if working directory has files matching the pattern - analysis = self._directory_cache.get(resolved_working_dir) - if not analysis: - return False - - # If pattern already analyzed, use cached result - if pattern in analysis.pattern_matches: - return analysis.pattern_matches[pattern] > 0 - - # Otherwise, analyze this specific directory for the pattern - # Only check direct files in this directory (not subdirectories for simplicity) - matching_files = 0 - - try: - for file in os.listdir(resolved_working_dir): - if file.startswith("."): - continue - - file_path = resolved_working_dir / file - if file_path.is_file(): - if self._file_matches_pattern(file_path, pattern): - matching_files += 1 - except (OSError, PermissionError): - # Handle case where directory doesn't exist or can't be read - pass - - # Cache the result - analysis.pattern_matches[pattern] = matching_files - - return matching_files > 0 - - # Debug print methods removed - replaced by structured data collection - # for professional output formatting via CompilationResults diff --git a/src/apm_cli/compilation/distributed_compiler.py b/src/apm_cli/compilation/distributed_compiler.py index 886efad4a..196a1386b 100644 --- a/src/apm_cli/compilation/distributed_compiler.py +++ b/src/apm_cli/compilation/distributed_compiler.py @@ -13,8 +13,9 @@ from ..output.formatters import CompilationFormatter from ..output.models import CompilationResults from ..primitives.models import Instruction, PrimitiveCollection -from ..utils.paths import portable_relpath, resolve_base_and_source_dirs +from ..utils.paths import resolve_base_and_source_dirs from ..version import get_version 
+from ._distributed_orphans import _DistributedOrphansMixin from .constants import BUILD_ID_PLACEHOLDER from .context_optimizer import ContextOptimizer from .link_resolver import UnifiedLinkResolver @@ -68,7 +69,7 @@ class CompilationResult: stats: builtins.dict[str, float] = field(default_factory=dict) # Support optimization metrics -class DistributedAgentsCompiler: +class DistributedAgentsCompiler(_DistributedOrphansMixin): """Main compiler for generating distributed AGENTS.md files.""" def __init__( @@ -610,153 +611,6 @@ def _generate_agents_content( return content - def _validate_coverage( - self, - placements: builtins.list[PlacementResult], - all_instructions: builtins.list[Instruction], - ) -> builtins.list[str]: - """Validate that all instructions are covered by placements. - - Args: - placements (List[PlacementResult]): Generated placements. - all_instructions (List[Instruction]): All available instructions. - - Returns: - List[str]: List of coverage warnings. - """ - warnings = [] - placed_instructions = set() - - for placement in placements: - placed_instructions.update(str(inst.file_path) for inst in placement.instructions) - - all_instruction_paths = set(str(inst.file_path) for inst in all_instructions) - - missing_instructions = all_instruction_paths - placed_instructions - if missing_instructions: - warnings.append( - f"Instructions not placed in any AGENTS.md: {', '.join(missing_instructions)}" - ) - - return warnings - - def _find_orphaned_agents_files( - self, generated_paths: builtins.list[Path] - ) -> builtins.list[Path]: - """Find existing AGENTS.md files that weren't generated in the current compilation. - - Args: - generated_paths (List[Path]): List of AGENTS.md files generated in current run. - - Returns: - List[Path]: List of orphaned AGENTS.md files that should be cleaned up. 
- """ - orphaned_files = [] - generated_set = set(generated_paths) - - # Find all existing AGENTS.md files in the project - for agents_file in self.base_dir.rglob("AGENTS.md"): - # Skip files that are outside our project or in special directories - try: - relative_path = agents_file.resolve().relative_to(self.base_dir.resolve()) - - # Skip files in certain directories that shouldn't be cleaned - skip_dirs = { - ".git", - ".apm", - "node_modules", - "__pycache__", - ".pytest_cache", - "apm_modules", - } - if any(part in skip_dirs for part in relative_path.parts): - continue - - # If this existing file wasn't generated in current run, it's orphaned - if agents_file not in generated_set: - orphaned_files.append(agents_file) - - except ValueError: - # File is outside base_dir, skip it - continue - - return orphaned_files - - def _generate_orphan_warnings(self, orphaned_files: builtins.list[Path]) -> builtins.list[str]: - """Generate warning messages for orphaned AGENTS.md files. - - Args: - orphaned_files (List[Path]): List of orphaned files to warn about. - - Returns: - List[str]: List of warning messages. 
- """ - warning_messages = [] - - if not orphaned_files: - return warning_messages - - # Professional warning format with readable list for multiple files - if len(orphaned_files) == 1: - rel_path = portable_relpath(orphaned_files[0], self.base_dir) - warning_messages.append( - f"Orphaned AGENTS.md found: {rel_path} - run 'apm compile --clean' to remove" - ) - else: - # For multiple files, create a single multi-line warning message - file_list = [] - for file_path in orphaned_files[:5]: # Show first 5 - rel_path = portable_relpath(file_path, self.base_dir) - file_list.append(f" * {rel_path}") - if len(orphaned_files) > 5: - file_list.append(f" * ...and {len(orphaned_files) - 5} more") - - # Create one cohesive warning message - files_text = "\n".join(file_list) - warning_messages.append( - f"Found {len(orphaned_files)} orphaned AGENTS.md files:\n{files_text}\n Run 'apm compile --clean' to remove orphaned files" - ) - - return warning_messages - - def _cleanup_orphaned_files( - self, orphaned_files: builtins.list[Path], dry_run: bool = False - ) -> builtins.list[str]: - """Actually remove orphaned AGENTS.md files. - - Args: - orphaned_files (List[Path]): List of orphaned files to remove. - dry_run (bool): If True, don't actually remove files, just report what would be removed. - - Returns: - List[str]: List of cleanup status messages. 
- """ - cleanup_messages = [] - - if not orphaned_files: - return cleanup_messages - - if dry_run: - # In dry-run mode, just report what would be cleaned - cleanup_messages.append( - f"Would clean up {len(orphaned_files)} orphaned AGENTS.md files" - ) - for file_path in orphaned_files: - rel_path = portable_relpath(file_path, self.base_dir) - cleanup_messages.append(f" * {rel_path}") - else: - # Actually perform the cleanup - cleanup_messages.append(f"Cleaning up {len(orphaned_files)} orphaned AGENTS.md files") - for file_path in orphaned_files: - try: - rel_path = portable_relpath(file_path, self.base_dir) - file_path.unlink() - cleanup_messages.append(f" + Removed {rel_path}") - except Exception as e: - cleanup_messages.append(f" x Failed to remove {rel_path}: {e!s}") - - return cleanup_messages - def _compile_distributed_stats( self, placements: builtins.list[PlacementResult], primitives: PrimitiveCollection ) -> builtins.dict[str, float]: diff --git a/src/apm_cli/compilation/link_resolver.py b/src/apm_cli/compilation/link_resolver.py index 6e57b8281..ff21647f2 100644 --- a/src/apm_cli/compilation/link_resolver.py +++ b/src/apm_cli/compilation/link_resolver.py @@ -479,9 +479,7 @@ def _resolve_in_package_asset_link( ``..`` chains, etc.). * Path computation raises (broken filesystem, encoding, ...). 
""" - if ctx.package_root is None: - return None - if not ctx.package_root.is_dir(): + if ctx.package_root is None or not ctx.package_root.is_dir(): return None path_part, suffix = self._split_link_target(link_path) @@ -492,10 +490,6 @@ def _resolve_in_package_asset_link( source_dir = ( ctx.source_file.parent if ctx.source_file.is_file() else ctx.source_location ) - except OSError: - return None - - try: candidate = (source_dir / path_part).resolve() except (OSError, ValueError): return None diff --git a/src/apm_cli/models/dependency/_reference_parse.py b/src/apm_cli/models/dependency/_reference_parse.py new file mode 100644 index 000000000..ea1b14e0c --- /dev/null +++ b/src/apm_cli/models/dependency/_reference_parse.py @@ -0,0 +1,504 @@ +"""Object-entry + top-level string parsing mixin for ``DependencyReference``. + +Composed onto :class:`~apm_cli.models.dependency.reference.DependencyReference` +via mixin inheritance; ``cls`` binds to ``DependencyReference`` at call time and +cross-method calls (``cls.parse``, ``cls._detect_virtual_package``, ...) resolve +through the MRO. Nothing here imports the composed class, so the package stays +free of import cycles. +""" + +import re +import urllib.parse +from pathlib import Path +from typing import TYPE_CHECKING + +from ...utils.github_host import ( + default_host, + is_azure_devops_hostname, + maybe_raise_bare_fqdn_github_gitlab_conflict, + unsupported_host_error, +) +from ...utils.path_security import validate_path_segments + +if TYPE_CHECKING: + from .reference import DependencyReference + +_MARKETPLACE_KEYS = {"name", "marketplace", "version"} + + +class _ReferenceParseMixin: + """``parse`` / ``parse_from_dict`` and their object-entry sub-parsers.""" + + @staticmethod + def _validate_object_alias(alias_override: object) -> str: + """Strip and validate an object-form ``alias`` value. + + Shared by every object entry shape (git, parent, registry) so the + alias grammar stays defined in exactly one place. 
+ """ + if not isinstance(alias_override, str) or not alias_override.strip(): + raise ValueError("'alias' field must be a non-empty string") + alias_override = alias_override.strip() + if not re.match(r"^[a-zA-Z0-9._-]+$", alias_override): + raise ValueError( + f"Invalid alias: {alias_override}. Aliases can only contain letters, " + f"numbers, dots, underscores, and hyphens" + ) + return alias_override + + @staticmethod + def _validate_object_ref(ref_override: object) -> str: + """Strip and validate an object-form ``ref`` value.""" + if not isinstance(ref_override, str) or not ref_override.strip(): + raise ValueError("'ref' field must be a non-empty string") + return ref_override.strip() + + @classmethod + def parse_from_dict(cls, entry: dict) -> "DependencyReference": + """Parse an object-style dependency entry from apm.yml. + + Supports the Cargo-inspired object format: + + - git: https://gitlab.com/acme/coding-standards.git + path: instructions/security + ref: v2.0 + + - git: git@bitbucket.org:team/rules.git + path: prompts/review.prompt.md + + Also supports local path entries: + + - path: ./packages/my-shared-skills + + And marketplace dependency entries: + + - name: gopls-lsp + marketplace: claude-plugins-official + + - name: secrets-vault + marketplace: acme-tools + version: "~2.1.0" + + Args: + entry: Dictionary with 'git', 'path', or 'marketplace' key. + Marketplace entries support 'name', 'marketplace', and + optional 'version' (semver range) fields. + + Returns: + DependencyReference: Parsed dependency reference + + Raises: + ValueError: If the entry is missing required fields or has invalid format + """ + # Support marketplace dependencies: { name: X, marketplace: Y, version: Z } + if "marketplace" in entry: + return cls._parse_marketplace_object_entry(entry) + + # Object-form registry package — design §3.2. + # Discriminated by the ``registry:`` or ``id:`` key (``registry:`` is + # optional when a ``registries.default:`` is configured). 
Mutually + # exclusive with ``git:``. + if "registry" in entry or "id" in entry: + if "git" in entry: + raise ValueError( + "Object-style dependency cannot mix 'registry:'/'id:' and 'git:' " + "keys — choose one resolver." + ) + return cls._parse_registry_object_entry(entry) + + # Support dict-form local path: { path: ./local/dir } + if "path" in entry and "git" not in entry: + return cls._parse_local_path_object_entry(entry) + + if "git" not in entry: + raise ValueError( + "Object-style dependency must have a 'git', 'path', or 'registry' field" + ) + + git_url = entry["git"] + if not isinstance(git_url, str) or not git_url.strip(): + raise ValueError("'git' field must be a non-empty string") + + # Monorepo parent inheritance (literal ``git: parent`` only; resolver expands) + if git_url == "parent": + return cls._parse_parent_inheritance_entry(entry) + + return cls._parse_git_object_entry(entry, git_url) + + @classmethod + def _parse_marketplace_object_entry(cls, entry: dict) -> "DependencyReference": + """Parse a ``{ name, marketplace, version }`` marketplace entry.""" + source_keys = {"git", "path", "registry", "id"}.intersection(entry) + if source_keys: + joined = "', '".join(sorted(source_keys)) + raise ValueError( + f"Ambiguous dependency: 'marketplace' cannot be combined with '{joined}'" + ) + unknown = set(entry.keys()) - _MARKETPLACE_KEYS + if unknown: + raise ValueError( + f"Unknown keys in marketplace dependency: {sorted(unknown)}. 
" + f"Allowed keys: {sorted(_MARKETPLACE_KEYS)}" + ) + name = entry.get("name") + marketplace = entry["marketplace"] + if not isinstance(name, str) or not name.strip(): + raise ValueError("Marketplace dependency must have a non-empty 'name' field") + if not isinstance(marketplace, str) or not marketplace.strip(): + raise ValueError("'marketplace' field must be a non-empty string") + name = name.strip() + marketplace = marketplace.strip() + if not re.match(r"^[a-zA-Z0-9._-]+$", name): + raise ValueError( + f"Invalid marketplace plugin name: '{name}'. " + "Names can only contain letters, numbers, dots, underscores, and hyphens" + ) + if not re.match(r"^[a-zA-Z0-9._-]+$", marketplace): + raise ValueError( + f"Invalid marketplace name: '{marketplace}'. " + "Names can only contain letters, numbers, dots, underscores, and hyphens" + ) + version_spec = entry.get("version") + if version_spec is not None: + if not isinstance(version_spec, str) or not version_spec.strip(): + raise ValueError("'version' field must be a non-empty string") + version_spec = version_spec.strip() + return cls( + repo_url=f"_marketplace/{marketplace}/{name}", + is_marketplace=True, + marketplace_name=marketplace, + marketplace_plugin_name=name, + marketplace_version_spec=version_spec, + ) + + @classmethod + def _parse_local_path_object_entry(cls, entry: dict) -> "DependencyReference": + """Parse a ``{ path: ./local/dir }`` dict-form local path entry.""" + local = entry["path"] + if not isinstance(local, str) or not local.strip(): + raise ValueError("'path' field must be a non-empty string") + local = local.strip() + if not cls.is_local_path(local): + raise ValueError( + "Object-style dependency must have a 'git' field, " + "or 'path' must be a local filesystem path " + "(starting with './', '../', '/', or '~')" + ) + return cls.parse(local) + + @classmethod + def _parse_parent_inheritance_entry(cls, entry: dict) -> "DependencyReference": + """Parse a ``git: parent`` monorepo-inheritance object 
entry.""" + path_raw = entry.get("path") + if path_raw is None: + raise ValueError("Object-style dependency with git: 'parent' requires a 'path' field") + if not isinstance(path_raw, str) or not path_raw.strip(): + raise ValueError("'path' field must be a non-empty string") + normalized_path = cls._normalize_parent_repo_decl_path(path_raw) + + ref_override = entry.get("ref") + reference: str | None = None + if ref_override is not None: + reference = cls._validate_object_ref(ref_override) + + alias_override = entry.get("alias") + alias_val: str | None = None + if alias_override is not None: + alias_val = cls._validate_object_alias(alias_override) + + return cls( + repo_url="_parent", + host=None, + reference=reference, + alias=alias_val, + virtual_path=normalized_path, + is_virtual=True, + is_parent_repo_inheritance=True, + ) + + @classmethod + def _parse_git_object_entry(cls, entry: dict, git_url: str) -> "DependencyReference": + """Parse a standard ``git:`` object entry and apply its overrides.""" + sub_path = entry.get("path") + allow_insecure = entry.get("allow_insecure", False) + if not isinstance(allow_insecure, bool): + raise ValueError("'allow_insecure' field must be a boolean") + + # Validate sub_path if provided + if sub_path is not None: + if not isinstance(sub_path, str) or not sub_path.strip(): + raise ValueError("'path' field must be a non-empty string") + sub_path = sub_path.strip().strip("/") + # Normalize backslashes to forward slashes for cross-platform safety + sub_path = sub_path.replace("\\", "/").strip().strip("/") + # Security: reject path traversal + validate_path_segments(sub_path, context="path") + + # Parse the git URL using the standard parser + dep = cls.parse(git_url) + dep.allow_insecure = allow_insecure + # Object-form ``- git:`` is an explicit Git resolver pin, even when + # a top-level ``registries.default`` is set. Mark source so the + # default-routing pass in apm_package.py leaves it alone. 
+ dep.source = "git" + + # Apply overrides from the object fields + ref_override = entry.get("ref") + if ref_override is not None: + dep.reference = cls._validate_object_ref(ref_override) + + alias_override = entry.get("alias") + if alias_override is not None: + dep.alias = cls._validate_object_alias(alias_override) + + # Apply sub-path as virtual package + if sub_path: + dep.virtual_path = sub_path + dep.is_virtual = True + + # Parse skills: field (SKILL_BUNDLE subset selection) + skills_raw = entry.get("skills") + if skills_raw is not None: + dep.skill_subset = cls._parse_skill_subset(skills_raw) + + return dep + + @staticmethod + def _parse_skill_subset(skills_raw: object) -> list[str]: + """Validate and de-duplicate the ``skills:`` subset list.""" + if not isinstance(skills_raw, (list,)): + raise ValueError("'skills' field must be a list of skill names") + if len(skills_raw) == 0: + raise ValueError( + "skills: must contain at least one name; " + "remove the field to install all skills in the bundle." + ) + seen: set = set() + validated: list = [] + for name in skills_raw: + if not isinstance(name, str) or not name.strip(): + raise ValueError("Each entry in 'skills' must be a non-empty string") + name = name.strip() + # Path safety: reject traversal sequences + validate_path_segments(name, context="skills/") + if name not in seen: + seen.add(name) + validated.append(name) + return sorted(validated) + + @classmethod + def _parse_registry_object_entry(cls, entry: dict) -> "DependencyReference": + """Parse the object-form registry entry per §3.2. 
+ + Required keys: + id: / # package identity at the registry + version: # opaque version string; registry resolves it + + Optional: + registry: # routes to named registry; omit to use default + path: prompts/foo.md # virtual sub-path; omit to install the whole package + alias: # same meaning as in other object forms + """ + from ...deps.registry.feature_gate import require_package_registry_enabled + + require_package_registry_enabled("Object-form registry dependencies") + + _registry_raw = entry.get("registry") + registry_name: str | None = None + if _registry_raw is not None: + if not isinstance(_registry_raw, str) or not _registry_raw.strip(): + raise ValueError( + "Object-form registry entry: 'registry' must be a non-empty " + "string (the name of an entry in the apm.yml registries: block)" + ) + registry_name = _registry_raw.strip() + + pkg_id = entry.get("id") + if not isinstance(pkg_id, str) or not pkg_id.strip(): + raise ValueError( + "Object-form registry entry: 'id' is required and must be a " + "non-empty 'owner/repo' string" + ) + pkg_id = pkg_id.strip() + if "/" not in pkg_id: + raise ValueError( + f"Object-form registry entry: 'id' must be 'owner/repo', got {pkg_id!r}" + ) + + raw_path = entry.get("path") + sub_path: str | None = None + if raw_path is not None: + if not isinstance(raw_path, str) or not raw_path.strip(): + raise ValueError( + "Object-form registry entry: 'path' must be a non-empty string " + "when provided (e.g. 'prompts/review.prompt.md')" + ) + sub_path = raw_path.strip().strip("/").replace("\\", "/").strip("/") + validate_path_segments(sub_path, context="path") + + version = entry.get("version") + if not isinstance(version, str) or not version.strip(): + raise ValueError("Object-form registry entry: 'version' is required") + version = version.strip() + + alias = entry.get("alias") + if alias is not None: + alias = cls._validate_object_alias(alias) + + # Reject any unknown keys to catch typos early. 
@classmethod
def parse(cls, dependency_str: str) -> "DependencyReference":
    """Parse a dependency string into a DependencyReference.

    Supports formats:
    - user/repo
    - user/repo#branch
    - user/repo#v1.0.0
    - user/repo#commit_sha
    - github.com/user/repo#ref
    - user/repo/path/to/file.prompt.md (virtual file package)
    - user/repo/skills/foo (virtual subdirectory package)
    - user/repo/collections/foo (virtual subdirectory package)
    - https://gitlab.com/owner/repo.git (generic HTTPS git URL)
    - git@gitlab.com:owner/repo.git (SSH git URL)
    - ssh://git@gitlab.com/owner/repo.git (SSH protocol URL)

    Ambiguous GitLab nested-group shorthand cannot cover every depth; use
    object form (``git:`` + ``path:`` in ``apm.yml``) as the supported
    escape hatch.

    - ./local/path (local filesystem path)
    - /absolute/path (local filesystem path)
    - ../relative/path (local filesystem path)

    Any valid FQDN is accepted as a git host (GitHub, GitLab, Bitbucket,
    self-hosted instances, etc.).

    Args:
        dependency_str: The dependency string to parse

    Returns:
        DependencyReference: Parsed dependency reference

    Raises:
        ValueError: If the dependency string format is invalid
    """
    if not dependency_str.strip():
        raise ValueError("Empty dependency string")

    # Decode percent-escapes first so the control-character check below
    # also sees characters that were smuggled in as ``%0a`` and friends.
    dependency_str = urllib.parse.unquote(dependency_str)

    if any(ord(c) < 32 for c in dependency_str):
        raise ValueError("Dependency string contains invalid control characters")

    # --- Local path detection (must run before URL/host parsing) ---
    if cls.is_local_path(dependency_str):
        local = dependency_str.strip()
        pkg_name = Path(local).name
        if not pkg_name or pkg_name in (".", ".."):
            raise ValueError(
                f"Local path '{local}' does not resolve to a named directory. "
                f"Use a path that ends with a directory name "
                f"(e.g., './my-package' instead of './')."
            )
        return cls(
            repo_url=f"_local/{pkg_name}",
            is_local=True,
            local_path=local,
            source="local",
        )

    if dependency_str.startswith("//"):
        raise ValueError(
            unsupported_host_error("//...", context="Protocol-relative URLs are not supported")
        )

    cls._reject_shorthand_alias(dependency_str)

    maybe_raise_bare_fqdn_github_gitlab_conflict(dependency_str)

    # Phase 1: detect virtual packages
    is_virtual_package, virtual_path, validated_host = cls._detect_virtual_package(
        dependency_str
    )

    # Phase 2: parse SSH (ssh:// URL first -- it preserves port; then SCP
    # shorthand), otherwise fall back to HTTPS/shorthand parsing.
    explicit_scheme: str | None = None
    ssh_user: str | None = None
    ssh_proto_result = cls._parse_ssh_protocol_url(dependency_str)
    if ssh_proto_result:
        host, port, repo_url, reference, alias, ssh_user = ssh_proto_result
        explicit_scheme = "ssh"
    else:
        scp_result = cls._parse_ssh_url(dependency_str)
        if scp_result:
            host, port, repo_url, reference, alias, ssh_user = scp_result
            explicit_scheme = "ssh"
        else:
            host, port, repo_url, reference, alias, is_virtual_package, virtual_path = (
                cls._parse_standard_url(
                    dependency_str, is_virtual_package, virtual_path, validated_host
                )
            )
            _stripped = dependency_str.strip().lower()
            if _stripped.startswith("https://"):
                explicit_scheme = "https"
            elif _stripped.startswith("http://"):
                explicit_scheme = "http"

    # Phase 3: final validation and ADO field extraction
    ado_organization, ado_project, ado_repo = cls._validate_final_repo_fields(host, repo_url)

    if alias and not re.match(r"^[a-zA-Z0-9._-]+$", alias):
        raise ValueError(
            f"Invalid alias: {alias}. Aliases can only contain letters, numbers, dots, underscores, and hyphens"
        )

    # Extract Artifactory prefix from the original path if applicable
    is_ado_final = host and is_azure_devops_hostname(host)
    artifactory_prefix = None
    if host and not is_ado_final:
        artifactory_prefix = cls._extract_artifactory_prefix(dependency_str, host)

    # Strip before scheme detection so the insecure flag agrees with
    # ``explicit_scheme`` above (which is computed from the stripped
    # string). Without this, a leading space hides the ``http`` scheme
    # from ``urlparse`` on interpreters that do not strip it themselves.
    url_scheme = urllib.parse.urlparse(dependency_str.strip()).scheme.lower()

    return cls(
        repo_url=repo_url,
        host=host,
        port=port,
        explicit_scheme=explicit_scheme,
        reference=reference,
        alias=alias,
        virtual_path=virtual_path,
        is_virtual=is_virtual_package,
        ado_organization=ado_organization,
        ado_project=ado_project,
        ado_repo=ado_repo,
        artifactory_prefix=artifactory_prefix,
        is_insecure=url_scheme == "http",
        ssh_user=ssh_user,
    )
@classmethod
def virtual_suffix_is_installable_shape(cls, virtual_path: str) -> bool:
    """Tell whether *virtual_path* has the shape of an APM virtual package.

    A suffix qualifies when it is a collection path, ends in one of
    ``cls.VIRTUAL_FILE_EXTENSIONS``, or has a final segment without a dot
    (an extension-less subdirectory). Blank input and suffixes rejected by
    ``validate_path_segments`` with ``PathTraversalError`` do not qualify.
    """
    trimmed = virtual_path.strip() if virtual_path else ""
    if not trimmed:
        return False

    suffix = trimmed.strip("/")
    try:
        validate_path_segments(suffix, context="virtual path")
    except PathTraversalError:
        return False

    is_collection = "/collections/" in suffix or suffix.startswith("collections/")
    if is_collection:
        return True

    for extension in cls.VIRTUAL_FILE_EXTENSIONS:
        if suffix.endswith(extension):
            return True

    # Anything else is only installable as a directory, i.e. when the
    # final segment carries no dot.
    final_segment = suffix.rpartition("/")[2]
    return "." not in final_segment
@classmethod
def iter_gitlab_direct_shorthand_boundary_candidates(cls, path_segments: list[str]):
    """Yield ``(repo_url, virtual_suffix)`` pairs, shallowest repo boundary first.

    Boundaries ``k`` run from 2 to ``len(path_segments) - 1``: the first
    ``k`` segments form the repo path and the rest form the suffix. A pair
    is yielded only when ``virtual_suffix_is_installable_shape`` accepts
    the suffix. Fewer than three segments yield nothing.
    """
    # range(2, n) is empty for n < 3, so short paths produce no candidates.
    for boundary in range(2, len(path_segments)):
        repo_path = "/".join(path_segments[:boundary])
        tail = "/".join(path_segments[boundary:])
        if not cls.virtual_suffix_is_installable_shape(tail):
            continue
        yield repo_path, tail
"DependencyReference": + """Build a dependency ref for a resolved Artifactory boundary probe.""" + return cls( + repo_url=f"{owner}/{repo}", + host=host, + reference=reference, + virtual_path=virtual_path, + is_virtual=bool(virtual_path), + artifactory_prefix=prefix, + ) + + @classmethod + def _gitlab_shorthand_repo_segment_count( + cls, + path_segments: list[str], + has_virtual_ext: bool, + has_collection: bool, + ) -> int: + """Return how many segments after the host belong to the GitLab project path. + + GitLab allows nested groups; unlike GitHub's fixed ``owner/repo``, the + project slug may span 3+ segments. Virtual package shorthand must not + chop a nested group path after two segments. + + Shorthand cannot disambiguate every deep namespace; ambiguous cases use + object form with ``git:`` + ``path:`` in ``apm.yml``. + + This does **not** split extension-less paths (e.g. ``.../registry/pkg``) + into repo + virtual: that would mis-parse valid 5+ segment project + paths; use ``parse_from_dict`` with an explicit ``path`` for those. + """ + n = len(path_segments) + if n < 2: + return n + + if has_collection and "collections" in path_segments: + coll_idx = path_segments.index("collections") + if coll_idx >= 2: + return coll_idx + return n + + if has_virtual_ext: + for idx, seg in enumerate(path_segments): + if idx >= 2 and seg in cls._GITLAB_VIRTUAL_ROOT_SEGMENTS: + return idx + # 3-segment paths keep owner/repo; 4+ segment paths reserve the + # first three for the (possibly nested-group) project slug. + return 3 if n >= 4 else 2 + + return n + + @classmethod + def _bare_shorthand_repo_segment_count(cls, path_segments: list[str]) -> int: + """Return how many leading segments belong to the repo path for bare shorthand. + + For ``owner/repo[/...]`` shorthand without an FQDN, the default is 2 + segments (GitHub convention). 
@classmethod
def _detect_virtual_package(cls, dependency_str: str) -> tuple[bool, str | None, str | None]:
    """Detect whether *dependency_str* refers to a virtual package.

    Explicit URL forms (``git@``, ``https://``, ``http://``, ``ssh://``) are
    never classified here; they return ``(False, None, None)`` immediately.
    For shorthand, a first segment containing a dot is treated as a host
    candidate and must pass ``is_supported_git_host``.

    Returns:
        (is_virtual_package, virtual_path, validated_host)

    Raises:
        ValueError: when a dotted first segment is not a supported git host.
            Errors raised by ``_validate_detected_virtual_path`` propagate
            unchanged.
    """
    # Temporarily remove reference for path segment counting
    temp_str = dependency_str
    if "#" in temp_str:
        temp_str = temp_str.rsplit("#", 1)[0]

    is_virtual_package = False
    virtual_path = None
    validated_host = None

    # Full URLs and git@ SCP forms are handled by the URL parsers, not here.
    if temp_str.lower().startswith(("git@", "https://", "http://", "ssh://")):
        return is_virtual_package, virtual_path, validated_host

    check_str = temp_str

    if "/" in check_str:
        first_segment = check_str.split("/")[0]

        if "." in first_segment:
            # A dotted first segment looks like a hostname: parse it as if
            # it were an https URL so urllib extracts the host for us.
            test_url = f"https://{check_str}"
            try:
                parsed = urllib.parse.urlparse(test_url)
                hostname = parsed.hostname

                if hostname and is_supported_git_host(hostname):
                    validated_host = hostname
                    path_parts = parsed.path.lstrip("/").split("/")
                    if len(path_parts) >= 2:
                        # Drop the host so only the repo path is counted below.
                        check_str = "/".join(check_str.split("/")[1:])
                else:
                    raise ValueError(unsupported_host_error(hostname or first_segment))
            except (ValueError, AttributeError) as e:
                # Re-raise errors whose text contains "Invalid Git host"
                # untouched; anything else is rewrapped in the standard
                # unsupported-host message.
                if isinstance(e, ValueError) and "Invalid Git host" in str(e):
                    raise
                raise ValueError(unsupported_host_error(first_segment)) from e
        elif check_str.startswith("gh/"):
            # ``gh/`` is a prefix, not part of the repo path.
            check_str = "/".join(check_str.split("/")[1:])

    # Empty segments (doubled or trailing slashes) are ignored.
    path_segments = [seg for seg in check_str.split("/") if seg]

    # Azure DevOps ``_git`` segment is a URL marker, not part of the
    # org/project/repo path -- strip it before counting and slicing so
    # both the base-segment count and the virtual suffix are computed
    # against the real path.
    is_ado = validated_host is not None and is_azure_devops_hostname(validated_host)
    if is_ado and "_git" in path_segments:
        git_idx = path_segments.index("_git")
        path_segments = path_segments[:git_idx] + path_segments[git_idx + 1 :]

    min_base_segments = cls._virtual_min_base_segments(path_segments, validated_host)
    # At least one segment beyond the base repo path is needed for a virtual package.
    min_virtual_segments = min_base_segments + 1

    if len(path_segments) >= min_virtual_segments:
        is_virtual_package = True
        virtual_path = "/".join(path_segments[min_base_segments:])
        cls._validate_detected_virtual_path(virtual_path)

    return is_virtual_package, virtual_path, validated_host
``path_segments`` must already have any + Azure DevOps ``_git`` marker stripped by the caller. + """ + is_ado = validated_host is not None and is_azure_devops_hostname(validated_host) + is_generic_host = ( + validated_host is not None + and not is_github_hostname(validated_host) + and not is_azure_devops_hostname(validated_host) + ) + is_gitlab_host = validated_host is not None and is_gitlab_hostname(validated_host) + + # Detect Artifactory VCS paths (artifactory/{repo-key}/{owner}/{repo}) + is_artifactory = is_generic_host and is_artifactory_path(path_segments) + + if is_ado: + from ...utils.github_host import is_visualstudio_legacy_hostname + + # *.visualstudio.com encodes org in the subdomain; path is proj/repo (2 parts). + # dev.azure.com encodes org as the first path segment; path is org/proj/repo (3 parts). + if validated_host and is_visualstudio_legacy_hostname(validated_host): + return 2 + return 3 + if is_artifactory: + # Artifactory: artifactory/{repo-key}/{owner}/{repo} + return 4 + if is_generic_host: + has_virtual_ext = any( + any(seg.endswith(ext) for ext in cls.VIRTUAL_FILE_EXTENSIONS) + for seg in path_segments + ) + has_collection = "collections" in path_segments + if is_gitlab_host: + return cls._gitlab_shorthand_repo_segment_count( + path_segments, has_virtual_ext, has_collection + ) + if has_virtual_ext or has_collection: + return 2 + return len(path_segments) + + # Bare shorthand (no FQDN). Default GitHub-style: owner/repo plus + # any tail is treated as a virtual sub-path. But when registry-only + # mode is active, the proxy may be fronting a GitLab instance where + # the project lives at an arbitrary subgroup depth -- fold non-marker + # segments into the repo path instead of mis-classifying them as + # virtual sub-paths (see issue: nested GitLab subgroup support). 
@classmethod
def _validate_detected_virtual_path(cls, virtual_path: str) -> None:
    """Check a detected virtual sub-path for traversal and extension shape.

    Raises:
        ValueError: when the path ends in one of
            ``cls.REMOVED_COLLECTION_EXTENSIONS`` (see #1094).
        InvalidVirtualPackageExtensionError: when the final segment has a
            dot but the path does not end in one of
            ``cls.VIRTUAL_FILE_EXTENSIONS``.
    """

    def _ends_with_any(extensions) -> bool:
        return any(virtual_path.endswith(ext) for ext in extensions)

    # Security: traversal sequences are rejected before any shape checks.
    validate_path_segments(virtual_path, context="virtual path")

    # Removed `.collection.yml` aggregators get a migration message:
    # they are now expressed as `apm.yml` with a `dependencies` block.
    if _ends_with_any(cls.REMOVED_COLLECTION_EXTENSIONS):
        raise ValueError(
            f".collection.yml is no longer supported. "
            f"Convert '{virtual_path}' to an apm.yml with a "
            f"'dependencies' section. "
            f"See: https://microsoft.github.io/apm/guides/dependencies/"
        )

    if _ends_with_any(cls.VIRTUAL_FILE_EXTENSIONS):
        return

    # Any other dotted final segment (e.g. `prompts/file.txt`) fails fast
    # instead of being silently treated as a subdirectory.
    final_segment = virtual_path.rpartition("/")[2]
    if "." not in final_segment:
        return
    raise InvalidVirtualPackageExtensionError(
        f"Invalid virtual package path '{virtual_path}'. "
        f"Individual files must end with one of: {', '.join(cls.VIRTUAL_FILE_EXTENSIONS)}. "
        f"For subdirectory packages, the path should not have a file extension."
    )
@staticmethod
def _parse_ssh_url(dependency_str: str):
    """Parse an SCP-shorthand SSH URL (``user@host:owner/repo``).

    Accepts any SSH username (not just ``git``), so EMU and custom GHE
    SSH accounts (e.g. ``enterprise-user@ghe.corp.com:org/repo``) parse
    correctly. SCP shorthand cannot carry a port (``:`` is the path
    separator), so the returned port is always ``None``. For custom SSH
    ports, use the ``ssh://`` URL form which is handled by
    ``_parse_ssh_protocol_url``.

    A trailing ``@alias`` is split off first, then a ``#reference``, then a
    ``.git`` suffix.

    Returns:
        ``(host, None, repo_url, reference, alias, ssh_user)`` -- a 6-tuple
        whose second item (the port) is always ``None`` -- or *None* if
        *dependency_str* does not match ``SCP_LIKE_RE``.

    Raises:
        ValueError: when the first path segment is a number in the TCP port
            range (1-65535), which indicates a port was written into an
            SCP-style URL. ``validate_path_segments`` and
            ``validate_ssh_user`` may raise as well.
    """
    ssh_match = SCP_LIKE_RE.match(dependency_str)
    if not ssh_match:
        return None

    user = ssh_match.group("user")
    host = ssh_match.group("host")
    ssh_repo_part = ssh_match.group("path")

    reference = None
    alias = None

    # The alias is split on the LAST '@' of the path part, before the
    # reference is extracted, so the accepted order is ``path#ref@alias``.
    if "@" in ssh_repo_part:
        ssh_repo_part, alias = ssh_repo_part.rsplit("@", 1)
        alias = alias.strip()

    if "#" in ssh_repo_part:
        repo_part, reference = ssh_repo_part.rsplit("#", 1)
        reference = reference.strip()
    else:
        repo_part = ssh_repo_part

    # Remember the suffix so the suggested ssh:// URL below can restore it.
    had_git_suffix = repo_part.endswith(".git")
    if had_git_suffix:
        repo_part = repo_part[:-4]

    repo_url = repo_part.strip()

    # SCP syntax (git@host:path) uses ':' as the path separator, so it
    # cannot carry a port. Detect when the first segment is a valid TCP
    # port number (1-65535) and raise an actionable error instead of
    # silently misparsing the port as part of the repo path.
    segments = repo_url.split("/", 1)
    first_segment = segments[0]
    if re.fullmatch(r"[0-9]+", first_segment):
        port_candidate = int(first_segment)
        if 1 <= port_candidate <= 65535:
            remaining_path = segments[1] if len(segments) > 1 else ""
            if remaining_path:
                # Rebuild the user's input in ssh:// form, carrying over
                # the .git suffix, reference and alias they supplied.
                git_suffix = ".git" if had_git_suffix else ""
                ref_suffix = f"#{reference}" if reference else ""
                alias_suffix = f"@{alias}" if alias else ""
                suggested = f"ssh://{user}@{host}:{port_candidate}/{remaining_path}{git_suffix}{ref_suffix}{alias_suffix}"
                raise ValueError(
                    f"It looks like '{first_segment}' in '{user}@{host}:{repo_url}' "
                    f"is a port number, but SCP-style URLs (@host:path) cannot "
                    f"carry a port. Use the ssh:// URL form instead:\n"
                    f" {suggested}"
                )
            else:
                # NOTE(review): the example URL in this message ends in
                # '//.git', which looks like placeholder text was lost --
                # confirm the intended wording.
                raise ValueError(
                    f"It looks like '{first_segment}' in '{user}@{host}:{first_segment}' "
                    f"is a port number, but no repository path follows it. "
                    f"SCP-style URLs (@host:path) cannot carry a port. "
                    f"Use the ssh:// URL form: ssh://{user}@{host}:{port_candidate}//.git"
                )

    # Security: reject traversal sequences in SSH repo paths
    validate_path_segments(repo_url, context="SSH repository path", reject_empty=True)

    ssh_user = validate_ssh_user(user)
    return host, None, repo_url, reference, alias, ssh_user
@classmethod
def _resolve_shorthand_to_parsed_url(cls, repo_url, host):
    """Resolve a non-URL shorthand path into a ``urllib``-parsed URL.

    Handles ``user/repo``, ``github.com/user/repo``,
    ``dev.azure.com/org/project/repo``, and Artifactory VCS paths.
    Validates path components before returning.

    Args:
        repo_url: Shorthand path, optionally led by a supported git host.
        host: Host chosen by the caller, or a falsy value; replaced by the
            leading host segment when one is present, and by
            ``default_host()`` for bare shorthand without a host.

    Returns:
        ``(parsed_url, host)``

    Raises:
        ValueError: when the shorthand does not match a supported shape,
            has too few segments for its host type, or contains a segment
            with characters outside the host's allowed set.
    """
    parts = repo_url.split("/")

    # Azure DevOps ``_git`` is a URL marker, not a path component.
    if "_git" in parts:
        git_idx = parts.index("_git")
        parts = parts[:git_idx] + parts[git_idx + 1 :]

    if len(parts) >= 3 and is_supported_git_host(parts[0]):
        host = parts[0]
        if is_visualstudio_legacy_hostname(host):
            # *.visualstudio.com/proj/repo: org is in the subdomain, path is proj/repo only
            user_repo = "/".join(parts[1:3])
        elif is_azure_devops_hostname(host) and len(parts) >= 4:
            # dev.azure.com/org/proj/repo: org is the first path segment
            user_repo = "/".join(parts[1:4])
        elif not is_github_hostname(host) and not is_azure_devops_hostname(host):
            if is_artifactory_path(parts[1:]):
                art_result = parse_artifactory_path(parts[1:])
                if art_result:
                    user_repo = f"{art_result[1]}/{art_result[2]}"
                else:
                    user_repo = "/".join(parts[1:])
            else:
                user_repo = "/".join(parts[1:])
        else:
            user_repo = "/".join(parts[1:])
    elif len(parts) >= 2 and "." not in parts[0]:
        if not host:
            host = default_host()
        if is_azure_devops_hostname(host) and len(parts) >= 3:
            user_repo = "/".join(parts[:3])
        elif host and not is_github_hostname(host) and not is_azure_devops_hostname(host):
            user_repo = "/".join(parts)
        elif (
            len(parts) >= 3
            and (repo_segments := cls._bare_shorthand_repo_segment_count(parts)) > 2
        ):
            # Registry-only mode allows nested-group repo paths
            # (GitLab via proxy). Keep the full multi-segment path.
            user_repo = "/".join(parts[:repo_segments])
        else:
            user_repo = "/".join(parts[:2])
    else:
        raise ValueError(
            "Use 'user/repo' or 'github.com/user/repo' or 'dev.azure.com/org/project/repo' format"
        )

    if not user_repo or "/" not in user_repo:
        raise ValueError(
            f"Invalid repository format: {repo_url}. Expected 'user/repo' or 'org/project/repo'"
        )

    uparts = user_repo.split("/")
    is_ado_host = host and is_azure_devops_hostname(host)

    if is_ado_host:
        # *.visualstudio.com encodes org in subdomain -> proj/repo is sufficient (2 parts).
        # dev.azure.com encodes org in path -> org/proj/repo required (3 parts).
        min_ado_parts = 2 if is_visualstudio_legacy_hostname(host) else 3
        if len(uparts) < min_ado_parts:
            raise ValueError(
                f"Invalid Azure DevOps repository format: {repo_url}. Expected 'org/project/repo'"
            )
    elif len(uparts) < 2:
        raise ValueError(f"Invalid repository format: {repo_url}. Expected 'user/repo'")

    allowed_pattern = _path_segment_pattern(is_ado_host)
    validate_path_segments("/".join(uparts), context="repository path")
    for part in uparts:
        # ``removesuffix`` drops only a literal trailing ".git".
        # ``rstrip(".git")`` would strip any trailing run of the characters
        # '.', 'g', 'i', 't' and turn a segment such as "git" into "",
        # rejecting a legitimate repository name.
        if not re.match(allowed_pattern, part.removesuffix(".git")):
            raise ValueError(f"Invalid repository path component: {part}")

    # Quote each segment separately so '/' stays the only path separator.
    quoted_repo = "/".join(urllib.parse.quote(p, safe="") for p in uparts)
    github_url = urllib.parse.urljoin(f"https://{host}/", quoted_repo)
    parsed_url = urllib.parse.urlparse(github_url)

    return parsed_url, host
@classmethod
def _validate_ado_url_repo_path(
    cls, hostname: str, path: str, path_parts: list[str]
) -> tuple[str, str | None]:
    """Validate an Azure DevOps URL path and peel off any virtual sub-path.

    For ``*.visualstudio.com`` the org is taken from the first label of
    *hostname* and prepended, so the returned repo path has the same
    ``org/project/repo`` shape as for ``dev.azure.com``.

    Returns:
        ``(repo_url, virtual_path)``; *virtual_path* is ``None`` when the
        URL has no segments beyond the base repo path.

    Raises:
        ValueError: when the path has fewer segments than the host requires.
    """
    legacy_host = is_visualstudio_legacy_hostname(hostname)
    # *.visualstudio.com paths are proj/repo (org lives in the subdomain);
    # dev.azure.com paths are org/proj/repo.
    base_len = 2 if legacy_host else 3
    if len(path_parts) < base_len:
        raise ValueError(
            f"Invalid Azure DevOps repository path: expected 'org/project/repo', got '{path}'"
        )

    base_parts = path_parts[:base_len]
    extra_parts = path_parts[base_len:]

    # Segments past the repo are a virtual sub-path, e.g. ``sub/path`` in
    # https://dev.azure.com/org/proj/_git/repo/sub/path.
    sub_path: str | None = None
    if extra_parts:
        sub_path = "/".join(extra_parts)
        cls._validate_ado_virtual_suffix(sub_path)

    if legacy_host:
        org_from_subdomain = hostname.split(".")[0]
        base_parts = [org_from_subdomain, *base_parts]

    cls._validate_repo_path_segments(base_parts, is_ado_host=True)
    return "/".join(base_parts), sub_path
+ if any(ado_virtual.endswith(ext) for ext in cls.VIRTUAL_FILE_EXTENSIONS): + return + last_segment = ado_virtual.split("/")[-1] + if "." in last_segment: + raise InvalidVirtualPackageExtensionError( + f"Invalid virtual package path '{ado_virtual}'. " + f"Individual files must end with one of: " + f"{', '.join(cls.VIRTUAL_FILE_EXTENSIONS)}. " + f"For subdirectory packages, the path should not have a file extension." + ) + + @staticmethod + def _validate_repo_path_segments(path_parts: list[str], *, is_ado_host: bool) -> None: + """Validate each repo path segment against the host's allowed char set.""" + allowed_pattern = _path_segment_pattern(is_ado_host) + validate_path_segments( + "/".join(path_parts), + context="repository URL path", + reject_empty=True, + ) + for part in path_parts: + if not re.match(allowed_pattern, part): + raise ValueError(f"Invalid repository path component: {part}") + + @classmethod + def _parse_standard_url( + cls, + dependency_str: str, + is_virtual_package: bool, + virtual_path: str | None, + validated_host: str | None, + ) -> tuple[str, int | None, str, str | None, str | None, bool, str | None]: + """Parse a non-SSH dependency string (HTTPS, FQDN, or shorthand). + + Detects scheme vs shorthand, delegates host-specific resolution to + helpers, then validates the resulting URL path. + + Returns: + ``(host, port, repo_url, reference, alias, effective_is_virtual, + effective_virtual_path)`` -- the last two reflect any ADO sub-path + segments embedded in the URL itself (issue #1128). + """ + host = None + port = None + alias = None + + reference = None + if "#" in dependency_str: + repo_part, reference = dependency_str.rsplit("#", 1) + reference = reference.strip() + else: + repo_part = dependency_str + + repo_url = repo_part.strip() + + # Lowercase copy for scheme detection -- kept from the original + # repo_url so the URL-vs-shorthand check below still works after + # the virtual shorthand resolver has narrowed repo_url. 
+ repo_url_lower = repo_url.lower() + + # For virtual packages without a URL scheme, narrow to just owner/repo + if is_virtual_package and not repo_url_lower.startswith(("https://", "http://")): + host, repo_url = cls._resolve_virtual_shorthand_repo( + repo_url, validated_host, virtual_path + ) + + # Normalize to URL format for secure parsing + if repo_url_lower.startswith(("https://", "http://")): + parsed_url = urllib.parse.urlparse(repo_url) + host = parsed_url.hostname or "" + port = parsed_url.port # capture :PORT from https://host:8443/... + # Normalise default-scheme ports (443 for HTTPS, 80 for HTTP) + # so lockfile keys are consistent regardless of URL spelling. + scheme = (parsed_url.scheme or "").lower() + if port == _DEFAULT_SCHEME_PORTS.get(scheme): + port = None + else: + parsed_url, host = cls._resolve_shorthand_to_parsed_url(repo_url, host) + + repo_url, url_virtual_path = cls._validate_url_repo_path(parsed_url) + + # If URL contained extra ADO sub-path segments, they become the virtual + # path (overriding the _detect_virtual_package result which returns + # early for https:// URLs). + effective_is_virtual = is_virtual_package + effective_virtual_path = virtual_path + if url_virtual_path is not None: + effective_is_virtual = True + effective_virtual_path = url_virtual_path + + if not host: + host = default_host() + + return host, port, repo_url, reference, alias, effective_is_virtual, effective_virtual_path + + @classmethod + def _validate_final_repo_fields(cls, host, repo_url): + """Validate the final repo_url and extract ADO organisation fields. + + Performs character-set and segment-count validation appropriate for + the detected host type (Azure DevOps vs generic git host). + + Returns: + ``(ado_organization, ado_project, ado_repo)`` -- all ``None`` + for non-ADO hosts. 
+ """ + is_ado_final = host and is_azure_devops_hostname(host) + if is_ado_final: + if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._\- ]+/[a-zA-Z0-9._\- ]+$", repo_url): + raise ValueError( + f"Invalid Azure DevOps repository format: {repo_url}. Expected 'org/project/repo'" + ) + ado_parts = repo_url.split("/") + validate_path_segments(repo_url, context="Azure DevOps repository path") + return ado_parts[0], ado_parts[1], ado_parts[2] + + segments = repo_url.split("/") + if len(segments) < 2: + raise ValueError(f"Invalid repository format: {repo_url}. Expected 'user/repo'") + if not all(re.match(_NON_ADO_PATH_SEGMENT_RE, s) for s in segments): + raise ValueError(f"Invalid repository format: {repo_url}. Contains invalid characters") + validate_path_segments(repo_url, context="repository path") + for seg in segments: + if any(seg.endswith(ext) for ext in cls.VIRTUAL_FILE_EXTENSIONS): + raise ValueError( + f"Invalid repository format: '{repo_url}' contains a virtual file extension. " + f"Use the dict format with 'path:' for virtual packages in SSH/HTTPS URLs" + ) + return None, None, None + + @staticmethod + def _extract_artifactory_prefix(dependency_str, host): + """Extract the Artifactory VCS prefix from the original dependency string. + + Returns: + The prefix string (e.g. ``"artifactory/github"``) or ``None``. + """ + _art_str = dependency_str.split("#")[0].split("@")[0] + # Strip scheme if present (e.g., https://host/artifactory/...) 
+ if "://" in _art_str: + _art_str = _art_str.split("://", 1)[1] + _art_segs = _art_str.replace(f"{host}/", "", 1).split("/") + if is_artifactory_path(_art_segs): + art_result = parse_artifactory_path(_art_segs) + if art_result: + return art_result[0] + return None diff --git a/src/apm_cli/models/dependency/_reference_util.py b/src/apm_cli/models/dependency/_reference_util.py new file mode 100644 index 000000000..2ac361a80 --- /dev/null +++ b/src/apm_cli/models/dependency/_reference_util.py @@ -0,0 +1,53 @@ +"""Leaf helpers shared by :mod:`reference` and its parse/url/shorthand mixins. + +These helpers depend on nothing inside ``reference.py``, so importing them +here (instead of from ``reference``) keeps the mixin modules free of a circular +import back to the composed :class:`DependencyReference`. ``reference.py`` +re-exports the public-ish names so existing +``apm_cli.models.dependency.reference.NAME`` references keep resolving. +""" + +import re + +# Default ports per URI scheme -- used to normalise away redundant +# explicit ports (e.g. https://host:443/...) so that lockfile keys +# and error messages stay consistent regardless of how the user +# spelled the URL. +_DEFAULT_SCHEME_PORTS: dict[str, int] = {"https": 443, "http": 80, "ssh": 22} + +# Allowed character set for a single repository path segment. +# +# ADO accepts spaces (project / repo names can contain them) but NOT tilde -- +# tilde has no meaning on Azure DevOps URLs and keeping it out preserves the +# asymmetry that protects the ADO surface from inadvertent regressions. +# +# Non-ADO hosts accept tilde because Bitbucket Data Center / Server (and +# Sourcehut) use ``~username`` path segments for personal repositories +# (e.g. ``/scm/~jdoe/repo.git``). ``~`` is RFC 3986 unreserved, has no +# POSIX path-traversal meaning, and all subprocess calls in APM use +# list-form ``argv`` so there is no shell-expansion vector. 
+_ADO_PATH_SEGMENT_RE = r"^[a-zA-Z0-9._\- ]+$" +_NON_ADO_PATH_SEGMENT_RE = r"^[a-zA-Z0-9._~-]+$" + +_RANGE_PREFIX_RE = re.compile(r"^(>=|<=|>|<|\^|~|=)") + + +def _path_segment_pattern(is_ado_host: bool) -> str: + """Return the allowed-character regex for a single repo path segment.""" + return _ADO_PATH_SEGMENT_RE if is_ado_host else _NON_ADO_PATH_SEGMENT_RE + + +def _is_valid_registry_semver_range(spec: str) -> bool: + """Defer importing ``deps.registry`` until call time (avoids import cycles).""" + from ...deps.registry.semver import is_semver_range + + return is_semver_range(spec) + + +class InvalidSemverRangeError(ValueError): + """Raised when a ref starts like a semver range but is invalid.""" + + +def _looks_like_invalid_semver_range(spec: str) -> bool: + """Return whether *spec* starts like a semver range but is invalid.""" + return bool(_RANGE_PREFIX_RE.match(spec.strip())) diff --git a/src/apm_cli/models/dependency/reference.py b/src/apm_cli/models/dependency/reference.py index 4c780551a..8a0f2a919 100644 --- a/src/apm_cli/models/dependency/reference.py +++ b/src/apm_cli/models/dependency/reference.py @@ -1,4 +1,17 @@ -"""DependencyReference model -- core dependency representation and parsing.""" +"""DependencyReference model -- core dependency representation and parsing. + +The class body is split across cohesive mixins to keep each module small and +focused while preserving a single patchable ``DependencyReference`` symbol: + +* :class:`_ReferenceParseMixin` -- ``parse`` / ``parse_from_dict`` + object entries +* :class:`_ReferenceUrlMixin` -- HTTPS / SSH / shorthand URL resolution +* :class:`_ReferenceShorthandMixin` -- virtual-package + boundary heuristics + +Mixin classmethods inherit onto the composed class, so ``cls`` binds to +``DependencyReference`` and inter-method calls resolve via the MRO. Leaf +helpers and constants live in :mod:`._reference_util` (re-exported below) so +the mixin modules never import this module back -- no circular import. 
+""" import re import urllib.parse @@ -8,72 +21,49 @@ from ...cache.url_normalize import SCP_LIKE_RE from ...utils.github_host import ( default_host, - is_artifactory_path, - is_azure_devops_hostname, - is_github_hostname, - is_gitlab_hostname, - is_supported_git_host, - is_visualstudio_legacy_hostname, - maybe_raise_bare_fqdn_github_gitlab_conflict, - parse_artifactory_path, - unsupported_host_error, validate_ssh_user, ) from ...utils.path_security import ( - PathTraversalError, ensure_path_within, validate_path_segments, ) -from ..validation import InvalidVirtualPackageExtensionError +from ._reference_parse import _ReferenceParseMixin +from ._reference_shorthand import _ReferenceShorthandMixin +from ._reference_url import _ReferenceUrlMixin + +# Re-export relocated leaf helpers/constants so that +# ``apm_cli.models.dependency.reference.NAME`` keeps resolving for any code +# (and tests) importing them from here. The helpers themselves live in +# ``_reference_util`` to break the mixin <-> reference import cycle. +from ._reference_util import ( + _ADO_PATH_SEGMENT_RE as _ADO_PATH_SEGMENT_RE, +) +from ._reference_util import ( + _DEFAULT_SCHEME_PORTS as _DEFAULT_SCHEME_PORTS, +) +from ._reference_util import ( + _NON_ADO_PATH_SEGMENT_RE as _NON_ADO_PATH_SEGMENT_RE, +) +from ._reference_util import ( + _RANGE_PREFIX_RE as _RANGE_PREFIX_RE, +) +from ._reference_util import ( + InvalidSemverRangeError as InvalidSemverRangeError, +) +from ._reference_util import ( + _is_valid_registry_semver_range as _is_valid_registry_semver_range, +) +from ._reference_util import ( + _looks_like_invalid_semver_range as _looks_like_invalid_semver_range, +) +from ._reference_util import ( + _path_segment_pattern as _path_segment_pattern, +) from .types import VirtualPackageType -# Default ports per URI scheme -- used to normalise away redundant -# explicit ports (e.g. https://host:443/...) 
so that lockfile keys -# and error messages stay consistent regardless of how the user -# spelled the URL. -_DEFAULT_SCHEME_PORTS: dict[str, int] = {"https": 443, "http": 80, "ssh": 22} - -# Allowed character set for a single repository path segment. -# -# ADO accepts spaces (project / repo names can contain them) but NOT tilde -- -# tilde has no meaning on Azure DevOps URLs and keeping it out preserves the -# asymmetry that protects the ADO surface from inadvertent regressions. -# -# Non-ADO hosts accept tilde because Bitbucket Data Center / Server (and -# Sourcehut) use ``~username`` path segments for personal repositories -# (e.g. ``/scm/~jdoe/repo.git``). ``~`` is RFC 3986 unreserved, has no -# POSIX path-traversal meaning, and all subprocess calls in APM use -# list-form ``argv`` so there is no shell-expansion vector. -_ADO_PATH_SEGMENT_RE = r"^[a-zA-Z0-9._\- ]+$" -_NON_ADO_PATH_SEGMENT_RE = r"^[a-zA-Z0-9._~-]+$" - - -def _path_segment_pattern(is_ado_host: bool) -> str: - """Return the allowed-character regex for a single repo path segment.""" - return _ADO_PATH_SEGMENT_RE if is_ado_host else _NON_ADO_PATH_SEGMENT_RE - - -def _is_valid_registry_semver_range(spec: str) -> bool: - """Defer importing ``deps.registry`` until call time (avoids import cycles).""" - from ...deps.registry.semver import is_semver_range - - return is_semver_range(spec) - - -_RANGE_PREFIX_RE = re.compile(r"^(>=|<=|>|<|\^|~|=)") - - -class InvalidSemverRangeError(ValueError): - """Raised when a ref starts like a semver range but is invalid.""" - - -def _looks_like_invalid_semver_range(spec: str) -> bool: - """Return whether *spec* starts like a semver range but is invalid.""" - return bool(_RANGE_PREFIX_RE.match(spec.strip())) - @dataclass -class DependencyReference: +class DependencyReference(_ReferenceParseMixin, _ReferenceUrlMixin, _ReferenceShorthandMixin): """Represents a reference to an APM dependency.""" repo_url: str # e.g., "user/repo" for GitHub or "org/project/repo" for Azure 
DevOps @@ -612,1279 +602,6 @@ def _parse_ssh_protocol_url(url: str): return host, port, repo_url, reference, alias, ssh_user - @staticmethod - def _normalize_parent_repo_decl_path(raw: str) -> str: - """Normalize ``path`` for ``git: parent`` to a single canonical relative path.""" - s = raw.strip().replace("\\", "/").strip() - s = s.strip("/") - segments = [seg for seg in s.split("/") if seg] - if not segments: - raise ValueError("'path' field must be a non-empty string") - normalized = "/".join(segments) - validate_path_segments(normalized, context="path") - return normalized - - @classmethod - def parse_from_dict(cls, entry: dict) -> "DependencyReference": - """Parse an object-style dependency entry from apm.yml. - - Supports the Cargo-inspired object format: - - - git: https://gitlab.com/acme/coding-standards.git - path: instructions/security - ref: v2.0 - - - git: git@bitbucket.org:team/rules.git - path: prompts/review.prompt.md - - Also supports local path entries: - - - path: ./packages/my-shared-skills - - And marketplace dependency entries: - - - name: gopls-lsp - marketplace: claude-plugins-official - - - name: secrets-vault - marketplace: acme-tools - version: "~2.1.0" - - Args: - entry: Dictionary with 'git', 'path', or 'marketplace' key. - Marketplace entries support 'name', 'marketplace', and - optional 'version' (semver range) fields. 
- - Returns: - DependencyReference: Parsed dependency reference - - Raises: - ValueError: If the entry is missing required fields or has invalid format - """ - # Support marketplace dependencies: { name: X, marketplace: Y, version: Z } - if "marketplace" in entry: - source_keys = {"git", "path", "registry", "id"}.intersection(entry) - if source_keys: - joined = "', '".join(sorted(source_keys)) - raise ValueError( - f"Ambiguous dependency: 'marketplace' cannot be combined with '{joined}'" - ) - _MARKETPLACE_KEYS = {"name", "marketplace", "version"} - unknown = set(entry.keys()) - _MARKETPLACE_KEYS - if unknown: - raise ValueError( - f"Unknown keys in marketplace dependency: {sorted(unknown)}. " - f"Allowed keys: {sorted(_MARKETPLACE_KEYS)}" - ) - name = entry.get("name") - marketplace = entry["marketplace"] - if not isinstance(name, str) or not name.strip(): - raise ValueError("Marketplace dependency must have a non-empty 'name' field") - if not isinstance(marketplace, str) or not marketplace.strip(): - raise ValueError("'marketplace' field must be a non-empty string") - name = name.strip() - marketplace = marketplace.strip() - if not re.match(r"^[a-zA-Z0-9._-]+$", name): - raise ValueError( - f"Invalid marketplace plugin name: '{name}'. " - "Names can only contain letters, numbers, dots, underscores, and hyphens" - ) - if not re.match(r"^[a-zA-Z0-9._-]+$", marketplace): - raise ValueError( - f"Invalid marketplace name: '{marketplace}'. 
" - "Names can only contain letters, numbers, dots, underscores, and hyphens" - ) - version_spec = entry.get("version") - if version_spec is not None: - if not isinstance(version_spec, str) or not version_spec.strip(): - raise ValueError("'version' field must be a non-empty string") - version_spec = version_spec.strip() - return cls( - repo_url=f"_marketplace/{marketplace}/{name}", - is_marketplace=True, - marketplace_name=marketplace, - marketplace_plugin_name=name, - marketplace_version_spec=version_spec, - ) - - # Object-form registry package — design §3.2. - # Discriminated by the ``registry:`` or ``id:`` key (``registry:`` is - # optional when a ``registries.default:`` is configured). Mutually - # exclusive with ``git:``. - if "registry" in entry or "id" in entry: - if "git" in entry: - raise ValueError( - "Object-style dependency cannot mix 'registry:'/'id:' and 'git:' " - "keys — choose one resolver." - ) - return cls._parse_registry_object_entry(entry) - - # Support dict-form local path: { path: ./local/dir } - if "path" in entry and "git" not in entry: - local = entry["path"] - if not isinstance(local, str) or not local.strip(): - raise ValueError("'path' field must be a non-empty string") - local = local.strip() - if not cls.is_local_path(local): - raise ValueError( - "Object-style dependency must have a 'git' field, " - "or 'path' must be a local filesystem path " - "(starting with './', '../', '/', or '~')" - ) - return cls.parse(local) - - if "git" not in entry: - raise ValueError( - "Object-style dependency must have a 'git', 'path', or 'registry' field" - ) - - git_url = entry["git"] - if not isinstance(git_url, str) or not git_url.strip(): - raise ValueError("'git' field must be a non-empty string") - - # Monorepo parent inheritance (literal ``git: parent`` only; resolver expands) - if git_url == "parent": - path_raw = entry.get("path") - if path_raw is None: - raise ValueError( - "Object-style dependency with git: 'parent' requires a 'path' field" 
- ) - if not isinstance(path_raw, str) or not path_raw.strip(): - raise ValueError("'path' field must be a non-empty string") - normalized_path = cls._normalize_parent_repo_decl_path(path_raw) - - ref_override = entry.get("ref") - alias_override = entry.get("alias") - reference: str | None = None - if ref_override is not None: - if not isinstance(ref_override, str) or not ref_override.strip(): - raise ValueError("'ref' field must be a non-empty string") - reference = ref_override.strip() - - alias_val: str | None = None - if alias_override is not None: - if not isinstance(alias_override, str) or not alias_override.strip(): - raise ValueError("'alias' field must be a non-empty string") - alias_override = alias_override.strip() - if not re.match(r"^[a-zA-Z0-9._-]+$", alias_override): - raise ValueError( - f"Invalid alias: {alias_override}. Aliases can only contain letters, numbers, dots, underscores, and hyphens" - ) - alias_val = alias_override - - return cls( - repo_url="_parent", - host=None, - reference=reference, - alias=alias_val, - virtual_path=normalized_path, - is_virtual=True, - is_parent_repo_inheritance=True, - ) - - sub_path = entry.get("path") - ref_override = entry.get("ref") - alias_override = entry.get("alias") - allow_insecure = entry.get("allow_insecure", False) - if not isinstance(allow_insecure, bool): - raise ValueError("'allow_insecure' field must be a boolean") - - # Validate sub_path if provided - if sub_path is not None: - if not isinstance(sub_path, str) or not sub_path.strip(): - raise ValueError("'path' field must be a non-empty string") - sub_path = sub_path.strip().strip("/") - # Normalize backslashes to forward slashes for cross-platform safety - sub_path = sub_path.replace("\\", "/").strip().strip("/") - # Security: reject path traversal - validate_path_segments(sub_path, context="path") - - # Parse the git URL using the standard parser - dep = cls.parse(git_url) - dep.allow_insecure = allow_insecure - # Object-form ``- git:`` is an 
explicit Git resolver pin, even when - # a top-level ``registries.default`` is set. Mark source so the - # default-routing pass in apm_package.py leaves it alone. - dep.source = "git" - - # Apply overrides from the object fields - if ref_override is not None: - if not isinstance(ref_override, str) or not ref_override.strip(): - raise ValueError("'ref' field must be a non-empty string") - dep.reference = ref_override.strip() - - if alias_override is not None: - if not isinstance(alias_override, str) or not alias_override.strip(): - raise ValueError("'alias' field must be a non-empty string") - alias_override = alias_override.strip() - if not re.match(r"^[a-zA-Z0-9._-]+$", alias_override): - raise ValueError( - f"Invalid alias: {alias_override}. Aliases can only contain letters, numbers, dots, underscores, and hyphens" - ) - dep.alias = alias_override - - # Apply sub-path as virtual package - if sub_path: - dep.virtual_path = sub_path - dep.is_virtual = True - - # Parse skills: field (SKILL_BUNDLE subset selection) - skills_raw = entry.get("skills") - if skills_raw is not None: - if not isinstance(skills_raw, (list,)): - raise ValueError("'skills' field must be a list of skill names") - if len(skills_raw) == 0: - raise ValueError( - "skills: must contain at least one name; " - "remove the field to install all skills in the bundle." - ) - seen: set = set() - validated: list = [] - for name in skills_raw: - if not isinstance(name, str) or not name.strip(): - raise ValueError("Each entry in 'skills' must be a non-empty string") - name = name.strip() - # Path safety: reject traversal sequences - validate_path_segments(name, context="skills/") - if name not in seen: - seen.add(name) - validated.append(name) - dep.skill_subset = sorted(validated) - - return dep - - @classmethod - def virtual_suffix_is_installable_shape(cls, virtual_path: str) -> bool: - """Return whether *virtual_path* matches APM virtual package shape rules. 
- - Used for GitLab direct host/path shorthand: a repo boundary is accepted - only when the remaining suffix would be a valid virtual path (file, - collection, or extension-less subdirectory), matching the rules applied - in :meth:`_detect_virtual_package` for the tail segments. - """ - if not virtual_path or not virtual_path.strip(): - return False - v = virtual_path.strip().strip("/") - try: - validate_path_segments(v, context="virtual path") - except PathTraversalError: - return False - if "/collections/" in v or v.startswith("collections/"): - return True - if any(v.endswith(ext) for ext in cls.VIRTUAL_FILE_EXTENSIONS): - return True - last = v.split("/")[-1] - return "." not in last - - @classmethod - def split_gitlab_direct_shorthand_parts( - cls, package: str - ) -> tuple[str, list[str], str | None] | None: - """If *package* is bare host/path shorthand, return (host, path_segments, ref_str). - - Returns ``None`` for ``https://``, ``git@``, or non-GitLab-class hosts. - """ - s = package.strip() - ref_out: str | None = None - if "#" in s: - s, r = s.rsplit("#", 1) - s = s.strip() - r = r.strip() - ref_out = r if r else None - maybe_raise_bare_fqdn_github_gitlab_conflict(package) - if s.startswith(("git@", "https://", "http://", "ssh://", "//")): - return None - if "/" not in s: - return None - parts = s.split("/") - host_cand = parts[0] - if "." 
not in host_cand: - return None - segs = [p for p in parts[1:] if p] - if len(segs) < 1: - return None - if not is_supported_git_host(host_cand) or not is_gitlab_hostname(host_cand): - return None - return (host_cand, segs, ref_out) - - @classmethod - def needs_gitlab_direct_shorthand_probing( - cls, package: str, dep_ref: "DependencyReference" - ) -> bool: - """True when install should probe left-to-right repo boundaries (GitLab only).""" - if dep_ref.is_local: - return False - if dep_ref.is_virtual: - return False - sp = cls.split_gitlab_direct_shorthand_parts(package) - if not sp: - return False - _host, segs, _ref = sp - return len(segs) >= 3 - - @classmethod - def iter_gitlab_direct_shorthand_boundary_candidates(cls, path_segments: list[str]): - """Yield (repo_url, virtual_suffix) for k=2..n-1 (earliest k first).""" - n = len(path_segments) - if n < 3: - return - for k in range(2, n): - repo = "/".join(path_segments[:k]) - suffix = "/".join(path_segments[k:]) - if cls.virtual_suffix_is_installable_shape(suffix): - yield repo, suffix - - @classmethod - def from_gitlab_shorthand_probe( - cls, - host: str, - repo_url: str, - virtual_path: str, - reference: str | None, - ) -> "DependencyReference": - """Build a virtual dependency ref for a resolved GitLab shorthand probe.""" - return cls( - repo_url=repo_url, - host=host, - reference=reference, - virtual_path=virtual_path, - is_virtual=True, - ) - - @classmethod - def from_artifactory_boundary_probe( - cls, - host: str, - prefix: str, - owner: str, - repo: str, - virtual_path: str | None, - reference: str | None, - ) -> "DependencyReference": - """Build a dependency ref for a resolved Artifactory boundary probe.""" - return cls( - repo_url=f"{owner}/{repo}", - host=host, - reference=reference, - virtual_path=virtual_path, - is_virtual=bool(virtual_path), - artifactory_prefix=prefix, - ) - - @classmethod - def _gitlab_shorthand_repo_segment_count( - cls, - path_segments: list[str], - has_virtual_ext: bool, - 
has_collection: bool, - ) -> int: - """Return how many segments after the host belong to the GitLab project path. - - GitLab allows nested groups; unlike GitHub's fixed ``owner/repo``, the - project slug may span 3+ segments. Virtual package shorthand must not - chop a nested group path after two segments. - - Shorthand cannot disambiguate every deep namespace; ambiguous cases use - object form with ``git:`` + ``path:`` in ``apm.yml``. - - This does **not** split extension-less paths (e.g. ``.../registry/pkg``) - into repo + virtual: that would mis-parse valid 5+ segment project - paths; use ``parse_from_dict`` with an explicit ``path`` for those. - """ - n = len(path_segments) - if n < 2: - return n - - if has_collection and "collections" in path_segments: - coll_idx = path_segments.index("collections") - if coll_idx >= 2: - return coll_idx - return n - - if has_virtual_ext: - for idx, seg in enumerate(path_segments): - if idx >= 2 and seg in cls._GITLAB_VIRTUAL_ROOT_SEGMENTS: - return idx - if n == 3: - return 2 - if n == 4: - return 3 - if n >= 5: - return 3 - return 2 - - return n - - @classmethod - def _bare_shorthand_repo_segment_count(cls, path_segments: list[str]) -> int: - """Return how many leading segments belong to the repo path for bare shorthand. - - For ``owner/repo[/...]`` shorthand without an FQDN, the default is 2 - segments (GitHub convention). When registry-only mode is active, the - proxy may be fronting a host that allows nested namespaces (GitLab - subgroups) -- parse defaults to **all-as-repo** so the deterministic - boundary probe in :mod:`apm_cli.install.artifactory_resolver` can - rebuild the dependency reference at the proxy-verified split. - - The only parse-time inference kept is **structural**: a path whose - last segment ends in a virtual file extension - (``.prompt.md``/``.instructions.md``/``.chatmode.md``/``.agent.md``) - is by shape a virtual file dep -- the file is the last segment and - the repo is everything before it. 
This is not a directory-marker - heuristic; the file extension is the type. The shallower boundary - (when the file lives under a known directory like ``prompts/``) is - settled by the probe, not by a marker list. - """ - n = len(path_segments) - if n < 3: - return 2 - - from ...deps.registry_proxy import is_enforce_only - - if not is_enforce_only(): - return 2 - - if any(path_segments[-1].endswith(ext) for ext in cls.VIRTUAL_FILE_EXTENSIONS): - return n - 1 - return n - - @classmethod - def _parse_registry_object_entry(cls, entry: dict) -> "DependencyReference": - """Parse the object-form registry entry per §3.2. - - Required keys: - id: / # package identity at the registry - version: # opaque version string; registry resolves it - - Optional: - registry: # routes to named registry; omit to use default - path: prompts/foo.md # virtual sub-path; omit to install the whole package - alias: # same meaning as in other object forms - """ - from ...deps.registry.feature_gate import require_package_registry_enabled - - require_package_registry_enabled("Object-form registry dependencies") - - _registry_raw = entry.get("registry") - registry_name: str | None = None - if _registry_raw is not None: - if not isinstance(_registry_raw, str) or not _registry_raw.strip(): - raise ValueError( - "Object-form registry entry: 'registry' must be a non-empty " - "string (the name of an entry in the apm.yml registries: block)" - ) - registry_name = _registry_raw.strip() - - pkg_id = entry.get("id") - if not isinstance(pkg_id, str) or not pkg_id.strip(): - raise ValueError( - "Object-form registry entry: 'id' is required and must be a " - "non-empty 'owner/repo' string" - ) - pkg_id = pkg_id.strip() - if "/" not in pkg_id: - raise ValueError( - f"Object-form registry entry: 'id' must be 'owner/repo', got {pkg_id!r}" - ) - - raw_path = entry.get("path") - sub_path: str | None = None - if raw_path is not None: - if not isinstance(raw_path, str) or not raw_path.strip(): - raise ValueError( - 
"Object-form registry entry: 'path' must be a non-empty string " - "when provided (e.g. 'prompts/review.prompt.md')" - ) - sub_path = raw_path.strip().strip("/").replace("\\", "/").strip("/") - validate_path_segments(sub_path, context="path") - - version = entry.get("version") - if not isinstance(version, str) or not version.strip(): - raise ValueError("Object-form registry entry: 'version' is required") - version = version.strip() - - alias = entry.get("alias") - if alias is not None: - if not isinstance(alias, str) or not alias.strip(): - raise ValueError("'alias' field must be a non-empty string") - alias = alias.strip() - if not re.match(r"^[a-zA-Z0-9._-]+$", alias): - raise ValueError( - f"Invalid alias: {alias}. Aliases can only contain " - f"letters, numbers, dots, underscores, and hyphens" - ) - - # Reject any unknown keys to catch typos early. - known = {"registry", "id", "path", "version", "alias"} - unknown = set(entry.keys()) - known - if unknown: - raise ValueError( - f"Object-form registry entry has unknown fields: " - f"{sorted(unknown)}. Known fields: {sorted(known)}" - ) - - owner_segments = pkg_id.split("/") - validate_path_segments(pkg_id, context="registry id") - for seg in owner_segments: - if not re.match(r"^[a-zA-Z0-9._-]+$", seg): - raise ValueError(f"Invalid registry id segment: {seg!r} in {pkg_id!r}") - - return cls( - repo_url=pkg_id, - host=default_host(), - reference=version, - virtual_path=sub_path, - is_virtual=sub_path is not None, - alias=alias, - source="registry", - registry_name=registry_name, - ) - - @classmethod - def _detect_virtual_package(cls, dependency_str: str): - """Detect whether *dependency_str* refers to a virtual package. 
- - Returns: - (is_virtual_package, virtual_path, validated_host) - """ - # Temporarily remove reference for path segment counting - temp_str = dependency_str - if "#" in temp_str: - temp_str = temp_str.rsplit("#", 1)[0] - - is_virtual_package = False - virtual_path = None - validated_host = None - - if temp_str.lower().startswith(("git@", "https://", "http://", "ssh://")): - return is_virtual_package, virtual_path, validated_host - - check_str = temp_str - - if "/" in check_str: - first_segment = check_str.split("/")[0] - - if "." in first_segment: - test_url = f"https://{check_str}" - try: - parsed = urllib.parse.urlparse(test_url) - hostname = parsed.hostname - - if hostname and is_supported_git_host(hostname): - validated_host = hostname - path_parts = parsed.path.lstrip("/").split("/") - if len(path_parts) >= 2: - check_str = "/".join(check_str.split("/")[1:]) - else: - raise ValueError(unsupported_host_error(hostname or first_segment)) - except (ValueError, AttributeError) as e: - if isinstance(e, ValueError) and "Invalid Git host" in str(e): - raise - raise ValueError(unsupported_host_error(first_segment)) from e - elif check_str.startswith("gh/"): - check_str = "/".join(check_str.split("/")[1:]) - - path_segments = [seg for seg in check_str.split("/") if seg] - - is_ado = validated_host is not None and is_azure_devops_hostname(validated_host) - is_generic_host = ( - validated_host is not None - and not is_github_hostname(validated_host) - and not is_azure_devops_hostname(validated_host) - ) - is_gitlab_host = validated_host is not None and is_gitlab_hostname(validated_host) - - if is_ado and "_git" in path_segments: - git_idx = path_segments.index("_git") - path_segments = path_segments[:git_idx] + path_segments[git_idx + 1 :] - - # Detect Artifactory VCS paths (artifactory/{repo-key}/{owner}/{repo}) - is_artifactory = is_generic_host and is_artifactory_path(path_segments) - - if is_ado: - # *.visualstudio.com encodes org in the subdomain; path is proj/repo 
(2 parts). - # dev.azure.com encodes org as the first path segment; path is org/proj/repo (3 parts). - if validated_host and is_visualstudio_legacy_hostname(validated_host): - min_base_segments = 2 - else: - min_base_segments = 3 - elif is_artifactory: - # Artifactory: artifactory/{repo-key}/{owner}/{repo} - min_base_segments = 4 - elif is_generic_host: - has_virtual_ext = any( - any(seg.endswith(ext) for ext in cls.VIRTUAL_FILE_EXTENSIONS) - for seg in path_segments - ) - has_collection = "collections" in path_segments - if is_gitlab_host: - min_base_segments = cls._gitlab_shorthand_repo_segment_count( - path_segments, has_virtual_ext, has_collection - ) - elif has_virtual_ext or has_collection: - min_base_segments = 2 - else: - min_base_segments = len(path_segments) - else: - # Bare shorthand (no FQDN). Default GitHub-style: owner/repo plus - # any tail is treated as a virtual sub-path. But when registry-only - # mode is active, the proxy may be fronting a GitLab instance where - # the project lives at an arbitrary subgroup depth -- fold non-marker - # segments into the repo path instead of mis-classifying them as - # virtual sub-paths (see issue: nested GitLab subgroup support). - min_base_segments = cls._bare_shorthand_repo_segment_count(path_segments) - - min_virtual_segments = min_base_segments + 1 - - if len(path_segments) >= min_virtual_segments: - is_virtual_package = True - virtual_path = "/".join(path_segments[min_base_segments:]) - - # Security: reject path traversal in virtual path - validate_path_segments(virtual_path, context="virtual path") - - # Reject removed `.collection.yml` extensions with a clear - # migration message (#1094). Curated dependency aggregators - # are now expressed as `apm.yml` with a `dependencies` block. - if any(virtual_path.endswith(ext) for ext in cls.REMOVED_COLLECTION_EXTENSIONS): - raise ValueError( - f".collection.yml is no longer supported. 
" - f"Convert '{virtual_path}' to an apm.yml with a " - f"'dependencies' section. " - f"See: https://microsoft.github.io/apm/guides/dependencies/" - ) - - # Accept any path ending in a recognised virtual file - # extension. Reject other dotted final segments so typos like - # `prompts/file.txt` fail fast instead of silently - # mis-classifying as a subdirectory. - if any(virtual_path.endswith(ext) for ext in cls.VIRTUAL_FILE_EXTENSIONS): - pass - else: - last_segment = virtual_path.split("/")[-1] - if "." in last_segment: - raise InvalidVirtualPackageExtensionError( - f"Invalid virtual package path '{virtual_path}'. " - f"Individual files must end with one of: {', '.join(cls.VIRTUAL_FILE_EXTENSIONS)}. " - f"For subdirectory packages, the path should not have a file extension." - ) - - return is_virtual_package, virtual_path, validated_host - - @staticmethod - def _parse_ssh_url(dependency_str: str): - """Parse an SCP-shorthand SSH URL (``@host:owner/repo``). - - Accepts any SSH username (not just ``git``), so EMU and custom GHE - SSH accounts (e.g. ``enterprise-user@ghe.corp.com:org/repo``) parse - correctly. SCP shorthand cannot carry a port (``:`` is the path - separator), so the returned port is always ``None``. For custom SSH - ports, use the ``ssh://`` URL form which is handled by - ``_parse_ssh_protocol_url``. - - Returns: - ``(host, port, repo_url, reference, alias)`` or *None* if not an SCP URL. 
- """ - ssh_match = SCP_LIKE_RE.match(dependency_str) - if not ssh_match: - return None - - user = ssh_match.group("user") - host = ssh_match.group("host") - ssh_repo_part = ssh_match.group("path") - - reference = None - alias = None - - if "@" in ssh_repo_part: - ssh_repo_part, alias = ssh_repo_part.rsplit("@", 1) - alias = alias.strip() - - if "#" in ssh_repo_part: - repo_part, reference = ssh_repo_part.rsplit("#", 1) - reference = reference.strip() - else: - repo_part = ssh_repo_part - - had_git_suffix = repo_part.endswith(".git") - if had_git_suffix: - repo_part = repo_part[:-4] - - repo_url = repo_part.strip() - - # SCP syntax (git@host:path) uses ':' as the path separator, so it - # cannot carry a port. Detect when the first segment is a valid TCP - # port number (1-65535) and raise an actionable error instead of - # silently misparsing the port as part of the repo path. - segments = repo_url.split("/", 1) - first_segment = segments[0] - if re.fullmatch(r"[0-9]+", first_segment): - port_candidate = int(first_segment) - if 1 <= port_candidate <= 65535: - remaining_path = segments[1] if len(segments) > 1 else "" - if remaining_path: - git_suffix = ".git" if had_git_suffix else "" - ref_suffix = f"#{reference}" if reference else "" - alias_suffix = f"@{alias}" if alias else "" - suggested = f"ssh://{user}@{host}:{port_candidate}/{remaining_path}{git_suffix}{ref_suffix}{alias_suffix}" - raise ValueError( - f"It looks like '{first_segment}' in '{user}@{host}:{repo_url}' " - f"is a port number, but SCP-style URLs (@host:path) cannot " - f"carry a port. Use the ssh:// URL form instead:\n" - f" {suggested}" - ) - else: - raise ValueError( - f"It looks like '{first_segment}' in '{user}@{host}:{first_segment}' " - f"is a port number, but no repository path follows it. " - f"SCP-style URLs (@host:path) cannot carry a port. 
" - f"Use the ssh:// URL form: ssh://{user}@{host}:{port_candidate}//.git" - ) - - # Security: reject traversal sequences in SSH repo paths - validate_path_segments(repo_url, context="SSH repository path", reject_empty=True) - - ssh_user = validate_ssh_user(user) - return host, None, repo_url, reference, alias, ssh_user - - @classmethod - def _resolve_virtual_shorthand_repo(cls, repo_url, validated_host, virtual_path=None): - """Narrow a virtual-package shorthand to just the base repo path. - - When a virtual package is given without a URL scheme - (e.g. ``github.com/owner/repo/path/file.prompt.md``), this strips - the virtual suffix so the downstream shorthand resolver only sees - the ``owner/repo`` (or ``org/project/repo`` for ADO) portion. - - Returns: - ``(host, repo_url)`` where *host* may be ``None``. - """ - parts = repo_url.split("/") - - if "_git" in parts: - git_idx = parts.index("_git") - parts = parts[:git_idx] + parts[git_idx + 1 :] - - host = None - if len(parts) >= 3 and is_supported_git_host(parts[0]): - host = parts[0] - if is_azure_devops_hostname(parts[0]): - if is_visualstudio_legacy_hostname(parts[0]): - # myorg.visualstudio.com/proj/repo/path: org in subdomain, - # need at least host + proj + repo + 1 virtual segment. 
- if len(parts) < 4: - raise ValueError( - "Invalid Azure DevOps virtual package format: must be " - "myorg.visualstudio.com/project/repo/path" - ) - repo_url = "/".join(parts[1:3]) - else: - # dev.azure.com/org/proj/repo/path: org in path - if len(parts) < 5: - raise ValueError( - "Invalid Azure DevOps virtual package format: must be dev.azure.com/org/project/repo/path" - ) - repo_url = "/".join(parts[1:4]) - elif is_artifactory_path(parts[1:]): - art_result = parse_artifactory_path(parts[1:]) - if art_result: - repo_url = f"{art_result[1]}/{art_result[2]}" - elif is_gitlab_hostname(parts[0]) and virtual_path: - vparts = [p for p in virtual_path.split("/") if p] - tail = len(vparts) - if tail > 0 and len(parts) > 1 + tail: - repo_url = "/".join(parts[1 : len(parts) - tail]) - else: - repo_url = "/".join(parts[1:]) - else: - repo_url = "/".join(parts[1:3]) - elif len(parts) >= 2: - if not host: - host = default_host() - if validated_host and is_azure_devops_hostname(validated_host): - if len(parts) < 4: - raise ValueError( - "Invalid Azure DevOps virtual package format: expected at least org/project/repo/path" - ) - repo_url = "/".join(parts[:3]) - elif validated_host is None and virtual_path: - # Bare shorthand under registry-only mode may carry a nested - # repo path (GitLab subgroup via proxy). Trust the boundary - # already chosen by ``_bare_shorthand_repo_segment_count`` -- - # everything before the virtual tail belongs to the repo. - vparts = [p for p in virtual_path.split("/") if p] - tail = len(vparts) - if tail > 0 and len(parts) > tail + 1: - repo_url = "/".join(parts[: len(parts) - tail]) - else: - repo_url = "/".join(parts[:2]) - else: - repo_url = "/".join(parts[:2]) - - return host, repo_url - - @classmethod - def _resolve_shorthand_to_parsed_url(cls, repo_url, host): - """Resolve a non-URL shorthand path into a ``urllib``-parsed URL. - - Handles ``user/repo``, ``github.com/user/repo``, - ``dev.azure.com/org/project/repo``, and Artifactory VCS paths. 
- Validates path components before returning. - - Returns: - ``(parsed_url, host)`` - """ - parts = repo_url.split("/") - - if "_git" in parts: - git_idx = parts.index("_git") - parts = parts[:git_idx] + parts[git_idx + 1 :] - - if len(parts) >= 3 and is_supported_git_host(parts[0]): - host = parts[0] - if is_visualstudio_legacy_hostname(host) and len(parts) >= 3: - # *.visualstudio.com/proj/repo: org is in the subdomain, path is proj/repo only - user_repo = "/".join(parts[1:3]) - elif is_azure_devops_hostname(host) and len(parts) >= 4: - # dev.azure.com/org/proj/repo: org is the first path segment - user_repo = "/".join(parts[1:4]) - elif not is_github_hostname(host) and not is_azure_devops_hostname(host): - if is_artifactory_path(parts[1:]): - art_result = parse_artifactory_path(parts[1:]) - if art_result: - user_repo = f"{art_result[1]}/{art_result[2]}" - else: - user_repo = "/".join(parts[1:]) - else: - user_repo = "/".join(parts[1:]) - else: - user_repo = "/".join(parts[1:]) - elif len(parts) >= 2 and "." not in parts[0]: - if not host: - host = default_host() - if is_azure_devops_hostname(host) and len(parts) >= 3: - user_repo = "/".join(parts[:3]) - elif host and not is_github_hostname(host) and not is_azure_devops_hostname(host): - user_repo = "/".join(parts) - elif len(parts) >= 3 and cls._bare_shorthand_repo_segment_count(parts) > 2: - # Registry-only mode allows nested-group repo paths - # (GitLab via proxy). Keep the full multi-segment path. - user_repo = "/".join(parts[: cls._bare_shorthand_repo_segment_count(parts)]) - else: - user_repo = "/".join(parts[:2]) - else: - raise ValueError( - "Use 'user/repo' or 'github.com/user/repo' or 'dev.azure.com/org/project/repo' format" - ) - - if not user_repo or "/" not in user_repo: - raise ValueError( - f"Invalid repository format: {repo_url}. 
Expected 'user/repo' or 'org/project/repo'" - ) - - uparts = user_repo.split("/") - is_ado_host = host and is_azure_devops_hostname(host) - - if is_ado_host: - # *.visualstudio.com encodes org in subdomain -> proj/repo is sufficient (2 parts). - # dev.azure.com encodes org in path -> org/proj/repo required (3 parts). - min_ado_parts = 2 if is_visualstudio_legacy_hostname(host) else 3 - if len(uparts) < min_ado_parts: - raise ValueError( - f"Invalid Azure DevOps repository format: {repo_url}. Expected 'org/project/repo'" - ) - elif len(uparts) < 2: - raise ValueError(f"Invalid repository format: {repo_url}. Expected 'user/repo'") - - allowed_pattern = _path_segment_pattern(is_ado_host) - validate_path_segments("/".join(uparts), context="repository path") - for part in uparts: - if not re.match(allowed_pattern, part.rstrip(".git")): - raise ValueError(f"Invalid repository path component: {part}") - - quoted_repo = "/".join(urllib.parse.quote(p, safe="") for p in uparts) - github_url = urllib.parse.urljoin(f"https://{host}/", quoted_repo) - parsed_url = urllib.parse.urlparse(github_url) - - return parsed_url, host - - @classmethod - def _validate_url_repo_path(cls, parsed_url) -> tuple[str, str | None]: - """Validate and normalise the repository path from a parsed URL. - - Checks host support, strips ``.git`` suffixes, removes ``_git`` - segments, and validates each path component against the allowed - character set for the detected host type. - - For Azure DevOps URLs with extra path segments beyond - ``org/project/repo`` (e.g. - ``https://dev.azure.com/org/proj/_git/repo/sub/path``), the extra - segments are extracted as a virtual package path and validated with - the same rules as the shorthand virtual-path detector. - - Returns: - ``(repo_url, virtual_path)`` where *repo_url* is the normalised - base repository path (e.g. ``owner/repo`` or - ``org/project/repo``) and *virtual_path* is ``None`` unless - extra ADO sub-path segments were detected. 
- """ - hostname = parsed_url.hostname or "" - if not is_supported_git_host(hostname): - raise ValueError(unsupported_host_error(hostname or parsed_url.netloc)) - - path = parsed_url.path.strip("/") - if not path: - raise ValueError("Repository path cannot be empty") - - if path.endswith(".git"): - path = path[:-4] - - path_parts = [urllib.parse.unquote(p) for p in path.split("/")] - if "_git" in path_parts: - git_idx = path_parts.index("_git") - path_parts = path_parts[:git_idx] + path_parts[git_idx + 1 :] - - is_ado_host = is_azure_devops_hostname(hostname) - - url_virtual_path: str | None = None - - if is_ado_host: - # *.visualstudio.com encodes org in the subdomain; URL path is proj/repo (2 parts). - # dev.azure.com encodes org as the first path segment; URL path is org/proj/repo (3 parts). - is_vs_legacy = is_visualstudio_legacy_hostname(hostname) - min_ado_parts = 2 if is_vs_legacy else 3 - if len(path_parts) < min_ado_parts: - raise ValueError( - f"Invalid Azure DevOps repository path: expected 'org/project/repo', got '{path}'" - ) - if len(path_parts) > min_ado_parts: - # Extra segments are a virtual sub-path (e.g. sub/path in - # https://dev.azure.com/org/proj/_git/repo/sub/path or - # https://myorg.visualstudio.com/proj/_git/repo/sub/path). - ado_virtual = "/".join(path_parts[min_ado_parts:]) - - # Security: reject path traversal in virtual path. - validate_path_segments(ado_virtual, context="virtual path") - - # Reject removed .collection.yml extensions. - if any(ado_virtual.endswith(ext) for ext in cls.REMOVED_COLLECTION_EXTENSIONS): - raise ValueError( - f".collection.yml is no longer supported. " - f"Convert '{ado_virtual}' to an apm.yml with a " - f"'dependencies' section. " - f"See: https://microsoft.github.io/apm/guides/dependencies/" - ) - - # Accept any recognised virtual file extension; reject other - # dotted final segments (mirrors shorthand virtual detection). 
- if any(ado_virtual.endswith(ext) for ext in cls.VIRTUAL_FILE_EXTENSIONS): - pass - else: - last_segment = ado_virtual.split("/")[-1] - if "." in last_segment: - raise InvalidVirtualPackageExtensionError( - f"Invalid virtual package path '{ado_virtual}'. " - f"Individual files must end with one of: " - f"{', '.join(cls.VIRTUAL_FILE_EXTENSIONS)}. " - f"For subdirectory packages, the path should not have a file extension." - ) - - url_virtual_path = ado_virtual - path_parts = path_parts[:min_ado_parts] - - # For *.visualstudio.com, inject the org from the subdomain so that the - # normalised repo_url is always org/project/repo (matching dev.azure.com). - if is_vs_legacy: - vs_org = hostname.split(".")[0] - path_parts = [vs_org, *path_parts] - else: - if len(path_parts) < 2: - raise ValueError( - f"Invalid repository path: expected at least 'user/repo', got '{path}'" - ) - # Strip the Artifactory VCS prefix so ``repo_url`` is the bare - # ``owner/repo`` -- otherwise URL round-trip through - # ``to_github_url`` -> ``parse`` would carry the prefix in the - # repo_url and the orchestrator would double-prefix download URLs. - # The prefix itself is recovered separately via - # :meth:`_extract_artifactory_prefix`. - if is_artifactory_path(path_parts): - path_parts = path_parts[2:] - for pp in path_parts: - if any(pp.endswith(ext) for ext in cls.VIRTUAL_FILE_EXTENSIONS): - raise ValueError( - f"Invalid repository path: '{path}' contains a virtual file extension. 
" - f"Use the dict format with 'path:' for virtual packages in HTTPS URLs" - ) - - allowed_pattern = _path_segment_pattern(is_ado_host) - validate_path_segments( - "/".join(path_parts), - context="repository URL path", - reject_empty=True, - ) - for part in path_parts: - if not re.match(allowed_pattern, part): - raise ValueError(f"Invalid repository path component: {part}") - - return "/".join(path_parts), url_virtual_path - - @classmethod - def _parse_standard_url( - cls, - dependency_str: str, - is_virtual_package: bool, - virtual_path: str | None, - validated_host: str | None, - ) -> tuple[str, int | None, str, str | None, str | None, bool, str | None]: - """Parse a non-SSH dependency string (HTTPS, FQDN, or shorthand). - - Detects scheme vs shorthand, delegates host-specific resolution to - helpers, then validates the resulting URL path. - - Returns: - ``(host, port, repo_url, reference, alias, effective_is_virtual, - effective_virtual_path)`` -- the last two reflect any ADO sub-path - segments embedded in the URL itself (issue #1128). - """ - host = None - port = None - alias = None - - reference = None - if "#" in dependency_str: - repo_part, reference = dependency_str.rsplit("#", 1) - reference = reference.strip() - else: - repo_part = dependency_str - - repo_url = repo_part.strip() - - # Lowercase copy for scheme detection -- kept from the original - # repo_url so the URL-vs-shorthand check below still works after - # the virtual shorthand resolver has narrowed repo_url. 
- repo_url_lower = repo_url.lower() - - # For virtual packages without a URL scheme, narrow to just owner/repo - if is_virtual_package and not repo_url_lower.startswith(("https://", "http://")): - host, repo_url = cls._resolve_virtual_shorthand_repo( - repo_url, validated_host, virtual_path - ) - - # Normalize to URL format for secure parsing - if repo_url_lower.startswith(("https://", "http://")): - parsed_url = urllib.parse.urlparse(repo_url) - host = parsed_url.hostname or "" - port = parsed_url.port # capture :PORT from https://host:8443/... - # Normalise default-scheme ports (443 for HTTPS, 80 for HTTP) - # so lockfile keys are consistent regardless of URL spelling. - scheme = (parsed_url.scheme or "").lower() - if port == _DEFAULT_SCHEME_PORTS.get(scheme): - port = None - else: - parsed_url, host = cls._resolve_shorthand_to_parsed_url(repo_url, host) - - repo_url, url_virtual_path = cls._validate_url_repo_path(parsed_url) - - # If URL contained extra ADO sub-path segments, they become the virtual - # path (overriding the _detect_virtual_package result which returns - # early for https:// URLs). - effective_is_virtual = is_virtual_package - effective_virtual_path = virtual_path - if url_virtual_path is not None: - effective_is_virtual = True - effective_virtual_path = url_virtual_path - - if not host: - host = default_host() - - return host, port, repo_url, reference, alias, effective_is_virtual, effective_virtual_path - - @classmethod - def _validate_final_repo_fields(cls, host, repo_url): - """Validate the final repo_url and extract ADO organisation fields. - - Performs character-set and segment-count validation appropriate for - the detected host type (Azure DevOps vs generic git host). - - Returns: - ``(ado_organization, ado_project, ado_repo)`` -- all ``None`` - for non-ADO hosts. 
- """ - is_ado_final = host and is_azure_devops_hostname(host) - if is_ado_final: - if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._\- ]+/[a-zA-Z0-9._\- ]+$", repo_url): - raise ValueError( - f"Invalid Azure DevOps repository format: {repo_url}. Expected 'org/project/repo'" - ) - ado_parts = repo_url.split("/") - validate_path_segments(repo_url, context="Azure DevOps repository path") - return ado_parts[0], ado_parts[1], ado_parts[2] - - segments = repo_url.split("/") - if len(segments) < 2: - raise ValueError(f"Invalid repository format: {repo_url}. Expected 'user/repo'") - if not all(re.match(_NON_ADO_PATH_SEGMENT_RE, s) for s in segments): - raise ValueError(f"Invalid repository format: {repo_url}. Contains invalid characters") - validate_path_segments(repo_url, context="repository path") - for seg in segments: - if any(seg.endswith(ext) for ext in cls.VIRTUAL_FILE_EXTENSIONS): - raise ValueError( - f"Invalid repository format: '{repo_url}' contains a virtual file extension. " - f"Use the dict format with 'path:' for virtual packages in SSH/HTTPS URLs" - ) - return None, None, None - - @staticmethod - def _extract_artifactory_prefix(dependency_str, host): - """Extract the Artifactory VCS prefix from the original dependency string. - - Returns: - The prefix string (e.g. ``"artifactory/github"``) or ``None``. - """ - _art_str = dependency_str.split("#")[0].split("@")[0] - # Strip scheme if present (e.g., https://host/artifactory/...) - if "://" in _art_str: - _art_str = _art_str.split("://", 1)[1] - _art_segs = _art_str.replace(f"{host}/", "", 1).split("/") - if is_artifactory_path(_art_segs): - art_result = parse_artifactory_path(_art_segs) - if art_result: - return art_result[0] - return None - - @classmethod - def parse(cls, dependency_str: str) -> "DependencyReference": - """Parse a dependency string into a DependencyReference. 
- - Supports formats: - - user/repo - - user/repo#branch - - user/repo#v1.0.0 - - user/repo#commit_sha - - github.com/user/repo#ref - - user/repo/path/to/file.prompt.md (virtual file package) - - user/repo/skills/foo (virtual subdirectory package) - - user/repo/collections/foo (virtual subdirectory package) - - https://gitlab.com/owner/repo.git (generic HTTPS git URL) - - git@gitlab.com:owner/repo.git (SSH git URL) - - ssh://git@gitlab.com/owner/repo.git (SSH protocol URL) - - Ambiguous GitLab nested-group shorthand cannot cover every depth; use - object form (``git:`` + ``path:`` in ``apm.yml``) as the supported - escape hatch. - - - ./local/path (local filesystem path) - - /absolute/path (local filesystem path) - - ../relative/path (local filesystem path) - - Any valid FQDN is accepted as a git host (GitHub, GitLab, Bitbucket, - self-hosted instances, etc.). - - Args: - dependency_str: The dependency string to parse - - Returns: - DependencyReference: Parsed dependency reference - - Raises: - ValueError: If the dependency string format is invalid - """ - if not dependency_str.strip(): - raise ValueError("Empty dependency string") - - dependency_str = urllib.parse.unquote(dependency_str) - - if any(ord(c) < 32 for c in dependency_str): - raise ValueError("Dependency string contains invalid control characters") - - # --- Local path detection (must run before URL/host parsing) --- - if cls.is_local_path(dependency_str): - local = dependency_str.strip() - pkg_name = Path(local).name - if not pkg_name or pkg_name in (".", ".."): - raise ValueError( - f"Local path '{local}' does not resolve to a named directory. " - f"Use a path that ends with a directory name " - f"(e.g., './my-package' instead of './')." 
- ) - return cls( - repo_url=f"_local/{pkg_name}", - is_local=True, - local_path=local, - source="local", - ) - - if dependency_str.startswith("//"): - raise ValueError( - unsupported_host_error("//...", context="Protocol-relative URLs are not supported") - ) - - cls._reject_shorthand_alias(dependency_str) - - maybe_raise_bare_fqdn_github_gitlab_conflict(dependency_str) - - # Phase 1: detect virtual packages - is_virtual_package, virtual_path, validated_host = cls._detect_virtual_package( - dependency_str - ) - - # Phase 2: parse SSH (ssh:// URL first -- it preserves port; then SCP - # shorthand), otherwise fall back to HTTPS/shorthand parsing. - explicit_scheme: str | None = None - ssh_user: str | None = None - ssh_proto_result = cls._parse_ssh_protocol_url(dependency_str) - if ssh_proto_result: - host, port, repo_url, reference, alias, ssh_user = ssh_proto_result - explicit_scheme = "ssh" - else: - scp_result = cls._parse_ssh_url(dependency_str) - if scp_result: - host, port, repo_url, reference, alias, ssh_user = scp_result - explicit_scheme = "ssh" - else: - host, port, repo_url, reference, alias, is_virtual_package, virtual_path = ( - cls._parse_standard_url( - dependency_str, is_virtual_package, virtual_path, validated_host - ) - ) - _stripped = dependency_str.strip().lower() - if _stripped.startswith("https://"): - explicit_scheme = "https" - elif _stripped.startswith("http://"): - explicit_scheme = "http" - - # Phase 3: final validation and ADO field extraction - ado_organization, ado_project, ado_repo = cls._validate_final_repo_fields(host, repo_url) - - if alias and not re.match(r"^[a-zA-Z0-9._-]+$", alias): - raise ValueError( - f"Invalid alias: {alias}. 
Aliases can only contain letters, numbers, dots, underscores, and hyphens" - ) - - # Extract Artifactory prefix from the original path if applicable - is_ado_final = host and is_azure_devops_hostname(host) - artifactory_prefix = None - if host and not is_ado_final: - artifactory_prefix = cls._extract_artifactory_prefix(dependency_str, host) - - return cls( - repo_url=repo_url, - host=host, - port=port, - explicit_scheme=explicit_scheme, - reference=reference, - alias=alias, - virtual_path=virtual_path, - is_virtual=is_virtual_package, - ado_organization=ado_organization, - ado_project=ado_project, - ado_repo=ado_repo, - artifactory_prefix=artifactory_prefix, - is_insecure=urllib.parse.urlparse(dependency_str).scheme.lower() == "http", - ssh_user=ssh_user, - ) - def to_apm_yml_entry(self): """Return the entry to store in apm.yml. diff --git a/src/apm_cli/models/validation.py b/src/apm_cli/models/validation.py index 668a9710e..d2cb6b5cb 100644 --- a/src/apm_cli/models/validation.py +++ b/src/apm_cli/models/validation.py @@ -336,32 +336,46 @@ def validate_apm_package(package_path: Path) -> ValidationResult: result.package_type = pkg_type if pkg_type == PackageType.INVALID: - # Two sub-cases of INVALID: - # 1. apm.yml present but no .apm/ directory (or .apm is a file) - # 2. Nothing recognizable at all - apm_yml_path = package_path / APM_YML_FILENAME - if apm_yml_path.exists(): - apm_path = package_path / APM_DIR - if apm_path.exists() and not apm_path.is_dir(): - result.add_error(".apm must be a directory") - else: - result.add_error( - f"Not a valid APM package: {package_path.name} has apm.yml but " - "is missing the required .apm/ directory. " - "Add .apm/ with primitives (instructions, skills, etc.), " - "declare dependencies in apm.yml (curated aggregator), " - "or add skills//SKILL.md for a skill bundle." 
- ) + _add_invalid_package_error(package_path, result) + return result + + return _dispatch_package_validation(package_path, plugin_json_path, result) + + +def _add_invalid_package_error(package_path: Path, result: ValidationResult) -> None: + """Record the appropriate error for an INVALID package directory. + + Two sub-cases of INVALID: + 1. apm.yml present but no .apm/ directory (or .apm is a file) + 2. Nothing recognizable at all + """ + apm_yml_path = package_path / APM_YML_FILENAME + if apm_yml_path.exists(): + apm_path = package_path / APM_DIR + if apm_path.exists() and not apm_path.is_dir(): + result.add_error(".apm must be a directory") else: result.add_error( - f"Not a valid APM package: no apm.yml, SKILL.md, hooks, or " - f"plugin structure found in {package_path.name}. " - "Ensure the package has SKILL.md (skill bundle), " - "apm.yml + .apm/ (APM package), or plugin.json (Claude plugin) " - "at its root." + f"Not a valid APM package: {package_path.name} has apm.yml but " + "is missing the required .apm/ directory. " + "Add .apm/ with primitives (instructions, skills, etc.), " + "declare dependencies in apm.yml (curated aggregator), " + "or add skills//SKILL.md for a skill bundle." ) - return result + else: + result.add_error( + f"Not a valid APM package: no apm.yml, SKILL.md, hooks, or " + f"plugin structure found in {package_path.name}. " + "Ensure the package has SKILL.md (skill bundle), " + "apm.yml + .apm/ (APM package), or plugin.json (Claude plugin) " + "at its root." 
+ ) + +def _dispatch_package_validation( + package_path: Path, plugin_json_path: Path | None, result: ValidationResult +) -> ValidationResult: + """Route a non-INVALID package to its type-specific validator.""" # Handle hook-only packages (no apm.yml or SKILL.md) if result.package_type == PackageType.HOOK_PACKAGE: return _validate_hook_package(package_path, result) From 9eb251b131f77c5f8086203c35ff146e217d18ed Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 07:48:25 +0200 Subject: [PATCH 18/21] refactor(#1078): Stage 2 commit 6 - core/adapters/cache/utils/output splits Strangler Stage 2 second-tier complexity tightening, subsystem 6 of 7. Clears the 8 remaining >800-line offenders in core/, adapters/, cache/, utils/ and output/ via mixin/leaf-module splits, plus in-place complexity fixes - net -144 lines (genuine reduction despite the mandated splits). Length splits (all parents + new siblings <800): - core/auth.py 1142->670 (+_auth_support.py 532 mixin). Auth boundary preserved: get_bearer_provider/ls-remote methods stay in auth.py; the sibling reaches the ADO bearer provider via a new AuthResolver._ado_bearer_provider() accessor so Rule A stays clean. try_with_fallback de-duplicated (PLR0911 12->8, was failing on main). - core/script_runner.py 1135->686 (+_prompt_compiler.py 183, +_runtime_commands.py 302 mixin). - core/command_logger.py 836->221 (+_install_logger.py 599). InstallLogger re-exported lazily via PEP 562 __getattr__ to break the subclass<->base import cycle while preserving the command_logger._rich_*/CommandLogger patch surface. - cache/git_cache.py 830->599 (+_git_cache_bare.py 279 mixin). - utils/github_host.py 806->659 (+_github_host_artifactory.py 173). - adapters/client/base.py 866->434 (+_base_env.py 483 mixin). - adapters/client/copilot.py 1056->697 (+_copilot_env.py 388 mixin); CopilotClientAdapter override MRO preserved (_CopilotEnvMixin first). 
- output/formatters.py 994->484 (+_formatters_detail.py 465 mixin); removed _format_final_summary duplicate (byte-identical to _format_results_summary) + its duplicate test classes in two test files. Complexity-only (in place, no split): target_detection.py C901 via dict dispatch (+list-arg guards); primitives/parser.py PLR0911 12->3; cache/integrity.py PLR0911 9->8. Rule B routing added for every moved reference to a patched module global. No linter-gaming (no **kwargs, no complexity noqa; one obsolete noqa removed). Verification: ruff + format clean; complexity gate clean at final Stage-2 thresholds; pylint R0801 10.00/10; backlog 15->7; full unit+acceptance 16605 passed; targeted integration 3119 passed; auth-signals lint clean. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/adapters/client/_base_env.py | 483 +++++++++++++ src/apm_cli/adapters/client/_copilot_env.py | 388 ++++++++++ src/apm_cli/adapters/client/base.py | 520 ++------------ src/apm_cli/adapters/client/copilot.py | 369 +--------- src/apm_cli/cache/_git_cache_bare.py | 279 ++++++++ src/apm_cli/cache/git_cache.py | 235 +------ src/apm_cli/cache/integrity.py | 26 +- src/apm_cli/core/_auth_support.py | 532 ++++++++++++++ src/apm_cli/core/_install_logger.py | 599 ++++++++++++++++ src/apm_cli/core/_prompt_compiler.py | 183 +++++ src/apm_cli/core/_runtime_commands.py | 302 ++++++++ src/apm_cli/core/auth.py | 566 ++------------- src/apm_cli/core/command_logger.py | 663 +----------------- src/apm_cli/core/script_runner.py | 455 +----------- src/apm_cli/core/target_detection.py | 157 ++--- src/apm_cli/output/_formatters_detail.py | 465 ++++++++++++ src/apm_cli/output/formatters.py | 516 +------------- src/apm_cli/primitives/parser.py | 78 +-- src/apm_cli/utils/_github_host_artifactory.py | 173 +++++ src/apm_cli/utils/github_host.py | 173 +---- tests/unit/test_output_formatters_phase3.py | 170 ----- .../unit/test_output_formatters_rendering.py | 170 ----- 22 files 
changed, 3679 insertions(+), 3823 deletions(-) create mode 100644 src/apm_cli/adapters/client/_base_env.py create mode 100644 src/apm_cli/adapters/client/_copilot_env.py create mode 100644 src/apm_cli/cache/_git_cache_bare.py create mode 100644 src/apm_cli/core/_auth_support.py create mode 100644 src/apm_cli/core/_install_logger.py create mode 100644 src/apm_cli/core/_prompt_compiler.py create mode 100644 src/apm_cli/core/_runtime_commands.py create mode 100644 src/apm_cli/output/_formatters_detail.py create mode 100644 src/apm_cli/utils/_github_host_artifactory.py diff --git a/src/apm_cli/adapters/client/_base_env.py b/src/apm_cli/adapters/client/_base_env.py new file mode 100644 index 000000000..d9c1566c7 --- /dev/null +++ b/src/apm_cli/adapters/client/_base_env.py @@ -0,0 +1,483 @@ +"""Env-resolution mixin and module-level helpers for MCPClientAdapter. + +Contains the pure-helper functions and the _BaseEnvMixin class that +MCPClientAdapter composes in. Kept in a sibling module so base.py stays +under 800 lines while all helpers remain importable from base.py via +re-exports. +""" + +import os +import re +from typing import ClassVar + +_INPUT_VAR_RE = re.compile(r"\$\{input:([^}]+)\}") + +# Matches ${VAR} and ${env:VAR}, capturing VAR. Intentionally does NOT match +# ${input:VAR} (the optional ``env:`` group cannot also satisfy ``input:``), +# nor GitHub Actions ``${{ ... }}`` templates (the second ``{`` fails the +# identifier class). This keeps env-var handling fully disjoint from input +# variable handling, so existing _INPUT_VAR_RE call sites are unaffected. +_ENV_VAR_RE = re.compile(r"\$\{(?:env:)?([A-Za-z_][A-Za-z0-9_]*)\}") + +# Superset of _ENV_VAR_RE that also matches the legacy ```` syntax +# (uppercase identifier only). Used as the single-pass translation target so +# resolved values are NOT re-scanned -- a literal value whose text happens to +# contain ``${...}`` does not get recursively expanded. 
``${input:...}`` is +# intentionally not matched here so input-variable handling stays disjoint. +_ENV_PLACEHOLDER_RE = re.compile(r"<([A-Z_][A-Z0-9_]*)>|" + _ENV_VAR_RE.pattern) + +# Detects the legacy ``<VAR>`` placeholder syntax only. Used to aggregate +# deprecation warnings across all servers in a single install run. +_LEGACY_ANGLE_VAR_RE = re.compile(r"<([A-Z_][A-Z0-9_]*)>") + + +def _translate_env_placeholder(value): + """Pure-textual translation of env-var placeholders to the canonical + ``${VAR}`` runtime-substitution syntax. + + Security-critical helper for issue #1152: MUST NOT read ``os.environ`` + and MUST NOT resolve placeholders to literal values. Runtimes that + support runtime substitution (Copilot CLI) resolve ``${VAR}`` from the + host environment at server-start, so APM emits placeholders verbatim + rather than baking secrets to disk. + + Translations: + ``${env:VAR}`` -> ``${VAR}`` (strip ``env:`` prefix) + ``${VAR}`` -> ``${VAR}`` (no-op) + ``<VAR>`` -> ``${VAR}`` (legacy syntax migration) + ``${VAR:-default}``-> passthrough (regex doesn't match) + ``$VAR`` (bare) -> passthrough (regex doesn't match) + ``${input:foo}`` -> passthrough (regex doesn't match) + non-string -> passthrough + + Idempotent: applying twice yields the same result as applying once. + """ + if not isinstance(value, str): + return value + + def _to_brace(match): + # group(1) = legacy <VAR>; group(2) = ${VAR} / ${env:VAR} + var_name = match.group(1) or match.group(2) + return "${" + var_name + "}" + + return _ENV_PLACEHOLDER_RE.sub(_to_brace, value) + + +def _extract_legacy_angle_vars(value): + """Return the set of legacy ``<VAR>`` names present in *value*. + + Used to aggregate deprecation warnings across all servers in a single + install run, so authors see one helpful list instead of one warning per + occurrence. 
+ """ + if not isinstance(value, str): + return set() + return set(_LEGACY_ANGLE_VAR_RE.findall(value)) + + +def _has_env_placeholder(value): + """True if *value* is a string containing any recognised env-var + placeholder syntax (``${VAR}``, ``${env:VAR}``, or legacy ````). + + Used to distinguish placeholder-sourced env values (which translate) + from hardcoded literal defaults (which stay literal). + """ + if not isinstance(value, str): + return False + return bool(_ENV_PLACEHOLDER_RE.search(value)) + + +def _stringify_env_literal(value): + """Return MCP env literal values in the manifest ``map`` shape.""" + if isinstance(value, bool): + return str(value).lower() + return str(value) + + +class _BaseEnvMixin: + """Env-resolution logic composed into MCPClientAdapter. + + All methods access instance state (``_last_env_placeholder_keys``, + ``_last_legacy_angle_vars``) and adapter helpers + (``_format_runtime_env_placeholder``, ``_translate_env_placeholder_for_runtime``) + that are defined on ``MCPClientAdapter``. This is the standard mixin + pattern: the mixin trusts the final class to provide those attributes. + """ + + # GitHub MCP server defaults: not secrets, preserved literal in translate + # mode and used as fallbacks in legacy mode. The defaults apply regardless + # of which client CLI runs the server, so they live on the base. + _DEFAULT_GITHUB_ENV: ClassVar[dict[str, str]] = { + "GITHUB_TOOLSETS": "context", + "GITHUB_DYNAMIC_TOOLSETS": "1", + } + + @staticmethod + def _should_skip_env_prompts(env_overrides): + """True when the caller has already collected env vars (managed mode), + when APM_E2E_TESTS is set, or when stdin/stdout is not a TTY. + + Centralising this policy keeps the resolver paths consistent and + avoids subtle drift between ``_resolve_environment_variables`` and + ``_resolve_env_variable``. 
+ """ + import sys + + if env_overrides: + return True + if os.getenv("APM_E2E_TESTS") == "1": + return True + return not (sys.stdin.isatty() and sys.stdout.isatty()) + + def _resolve_environment_variables(self, env_vars, env_overrides=None): + """Resolve (or translate) declared environment variables. + + Behaviour follows ``self._supports_runtime_env_substitution``: + translate-mode (Copilot CLI) emits ``${VAR}`` placeholders verbatim + so the runtime resolves them at server-start (see issue #1152); + legacy-mode resolves placeholders to literal values via env_overrides + -> os.environ -> optional interactive prompt. + + Args: + env_vars: Either a ``dict[name, value-or-placeholder]`` from a + self-defined stdio dep (``_raw_stdio["env"]``), or a + ``list[{name, description, required}]`` from the registry. + env_overrides: Pre-collected env-var overrides (ignored in + translate mode). + + Returns: + dict: ``{name: value}`` -- placeholder string in translate + mode, literal value in legacy mode. + """ + # ---- translate mode, dict shape (self-defined stdio in apm.yml) ---- + if isinstance(env_vars, dict) and self._supports_runtime_env_substitution: + # Value type is intentionally untyped: most entries are translated + # placeholder strings, but non-string values (e.g. an int/bool + # YAML scalar) are passed through verbatim and serialised by the + # adapter's config writer (JSON/TOML). 
+ translated: dict = {} + placeholder_keys: list[str] = [] + for name, raw_value in env_vars.items(): + if not name: + continue + if not isinstance(raw_value, str): + translated[name] = raw_value + continue + if _has_env_placeholder(raw_value): + self._last_legacy_angle_vars.update(_extract_legacy_angle_vars(raw_value)) + translated[name] = self._translate_env_placeholder_for_runtime(raw_value) + placeholder_keys.extend( + m.group(1) for m in _ENV_VAR_RE.finditer(translated[name]) + ) + elif ( + name in self._DEFAULT_GITHUB_ENV and raw_value == self._DEFAULT_GITHUB_ENV[name] + ): + translated[name] = raw_value + else: + # Literal value present in apm.yml -- replace with a + # runtime placeholder so the secret never touches disk. + translated[name] = self._format_runtime_env_placeholder(name) + placeholder_keys.append(name) + self._last_env_placeholder_keys = set(placeholder_keys) + return translated + + # ---- translate mode, registry list shape ---- + if self._supports_runtime_env_substitution: + resolved: dict[str, str] = {} + placeholder_keys: list[str] = [] + for env_var in env_vars: + if not isinstance(env_var, dict): + continue + name = env_var.get("name", "") + if not name: + continue + if name in self._DEFAULT_GITHUB_ENV: + resolved[name] = self._DEFAULT_GITHUB_ENV[name] + else: + resolved[name] = self._format_runtime_env_placeholder(name) + placeholder_keys.append(name) + self._last_env_placeholder_keys = set(placeholder_keys) + return resolved + + # ---- legacy mode, dict shape (self-defined stdio in apm.yml) ---- + # Issue #1266 / #1222: ``_raw_stdio["env"]`` is a plain dict. Each + # value is resolved via the same single-value pipeline used for + # header values so all three placeholder syntaxes (``<VAR>``, + # ``${VAR}``, ``${env:VAR}``) behave consistently across adapters. 
+ # + # Note the deliberate semantic divergence from the legacy-list branch + # below: empty strings authored in apm.yml are preserved as-is and + # ``_DEFAULT_GITHUB_ENV`` fallbacks are NOT applied, because a value + # explicitly written by the user expresses intent, whereas an empty + # value coming from ``env_overrides`` / ``os.environ`` for a + # registry-declared schema entry means "no value supplied, use the + # default if one exists". + if isinstance(env_vars, dict): + resolved = {} + for name, value in env_vars.items(): + if not name: + continue + if isinstance(value, str): + resolved[name] = self._resolve_env_variable( + name, value, env_overrides=env_overrides + ) + elif value is not None: + resolved[name] = str(value) + return resolved + + # ---- legacy mode, registry list shape ---- + from rich.prompt import Prompt + + env_overrides = env_overrides or {} + skip_prompting = self._should_skip_env_prompts(env_overrides) + + # Variables explicitly provided with empty values mean "use the default". 
+ empty_value_vars = {k for k, v in env_overrides.items() if not v or not v.strip()} + + resolved = {} + for env_var in env_vars: + if not isinstance(env_var, dict): + continue + name = env_var.get("name", "") + if not name: + continue + required = env_var.get("required", True) + + value = env_overrides.get(name) or os.getenv(name) + if not value and required and not skip_prompting: + prompt_text = f"Enter value for {name}" + if description := env_var.get("description", ""): + prompt_text += f" ({description})" + value = Prompt.ask( + prompt_text, + password="token" in name.lower() or "key" in name.lower(), + ) + + if value and value.strip(): + resolved[name] = value + elif name in self._DEFAULT_GITHUB_ENV and ( + name in empty_value_vars or not required or skip_prompting + ): + resolved[name] = self._DEFAULT_GITHUB_ENV[name] + + return resolved + + def _resolve_env_variable(self, name, value, env_overrides=None): + """Resolve (or translate) a single env-var value. + + Used for header values and for individual entries in dict-shape + env blocks. The ``name`` parameter is currently unused by the + method body but kept in the signature because every call site + (headers, dict iteration) already has the name in hand, and + passing it preserves call-site symmetry with future hooks that + may want to dispatch on it. + + Args: + name: Env-var name (currently unused, see above). + value: Env-var value possibly containing placeholders. + env_overrides: Pre-collected overrides (ignored in translate mode). 
+ """ + if self._supports_runtime_env_substitution: + legacy_keys = _extract_legacy_angle_vars(value) + self._last_legacy_angle_vars.update(legacy_keys) + self._last_env_placeholder_keys.update(legacy_keys) + for match in _ENV_VAR_RE.finditer(value): + self._last_env_placeholder_keys.add(match.group(1)) + return self._translate_env_placeholder_for_runtime(value) + + from rich.prompt import Prompt + + env_overrides = env_overrides or {} + skip_prompting = self._should_skip_env_prompts(env_overrides) + + # Three accepted placeholder syntaxes resolved against + # env_overrides -> os.environ -> optional interactive prompt. + # Single-pass substitution preserves the legacy ```` semantics: + # resolved values are NOT re-scanned for further expansion. + def _replace(match): + env_name = match.group(1) or match.group(2) + env_value = env_overrides.get(env_name) or os.getenv(env_name) + if not env_value and not skip_prompting: + env_value = Prompt.ask( + f"Enter value for {env_name}", + password="token" in env_name.lower() or "key" in env_name.lower(), + ) + return env_value if env_value else match.group(0) + + return _ENV_PLACEHOLDER_RE.sub(_replace, value) + + def _resolve_variable_placeholders(self, value, resolved_env, runtime_vars): + """Resolve env-var and APM template placeholders in argument strings. + + Translate mode rewrites all three env-var placeholder syntaxes to + ``${VAR}`` (so the runtime can resolve them at server-start); legacy + mode resolves only the legacy ```` form against ``resolved_env`` + and leaves the newer ``${VAR}`` / ``${env:VAR}`` syntaxes untouched + for backward compatibility. APM template variables (``{runtime_var}``) + are always resolved at install time because they are an APM-internal + concept the target runtime cannot interpret. + + Args: + value: String possibly containing placeholders. + resolved_env: Resolved env-var literals (legacy mode) or + placeholder strings (translate mode). + runtime_vars: Resolved APM template variables. 
+ + Returns: + str: ``value`` with placeholders translated or resolved. + """ + if not value: + return value + + processed = str(value) + + if self._supports_runtime_env_substitution: + self._last_legacy_angle_vars.update(_extract_legacy_angle_vars(processed)) + processed = self._translate_env_placeholder_for_runtime(processed) + else: + # Resolve only the legacy ```` form; newer syntaxes are + # preserved verbatim for backward compatibility. + def _replace_legacy_angle(match): + return resolved_env.get(match.group(1), match.group(0)) + + processed = _LEGACY_ANGLE_VAR_RE.sub(_replace_legacy_angle, processed) + + # Resolve APM ``{runtime_var}`` template variables. The negative + # lookbehind on ``$`` ensures we never accidentally match the brace + # of an already-translated ``${VAR}`` env placeholder. + if runtime_vars: + runtime_pattern = re.compile(r"(? dict: + """Resolve *env_vars* from overrides, environment, or interactive prompts. + + Identical logic shared between + :meth:`CopilotClientAdapter._process_environment_variables` and + :meth:`CodexClientAdapter._process_environment_variables`. + + All imports are deferred so that ``rich.prompt`` (an optional + dependency) is never imported at module load time. + + Args: + env_vars: List of env-var descriptor dicts from the registry. + env_overrides: Pre-collected ``{name: value}`` overrides (empty + dict when none). + default_github_env: Mapping of well-known GitHub variable names + to their preferred environment-variable lookup names. + + Returns: + ``resolved`` dict mapping each env-var name to its resolved value + (empty string when unresolvable). + """ + import sys + + # Rule B: route through base module so tests patching + # apm_cli.adapters.client.base._rich_warning are intercepted. + from apm_cli.adapters.client import base as _b + + env_overrides = env_overrides or {} + resolved: dict = {} + + # Determine whether interactive prompting is available. 
+ # If env_overrides is provided the CLI has already collected variables -- never prompt again. + skip_prompting = ( + bool(env_overrides) + or bool(os.getenv("CI")) + or bool(os.getenv("APM_E2E_TESTS")) + or not sys.stdout.isatty() + or not sys.stdin.isatty() + ) + + # First pass: identify variables with empty values to warn the user. + empty_value_vars = [ev for ev in env_vars if ev.get("required") and not ev.get("value")] + if empty_value_vars and skip_prompting: + var_names = [ev.get("name") for ev in empty_value_vars] + _b._rich_warning( + f"Warning: The following required environment variables have no default " + f"value and cannot be prompted in non-interactive mode: {var_names}" + ) + + for env_var in env_vars: + name = env_var.get("name", "") + if not name: + continue + + # Priority 1: caller-supplied override. + # An explicit empty (or whitespace-only) value is treated as + # "user cleared this". For names with a GitHub-style default the + # logic falls through so the literal default wins; for names + # without a default the entry is dropped from the resolved map. 
+ if name in env_overrides: + override_value = env_overrides[name] + if isinstance(override_value, str) and not override_value.strip(): + if name not in default_github_env: + continue + else: + resolved[name] = override_value + continue + + # Priority 2: check GitHub-specific defaults (values are literal defaults, not env-var names) + if name in default_github_env: + resolved[name] = os.getenv(name) or default_github_env[name] + continue + + # Priority 3: environment variable with the same name + env_val = os.getenv(name, "") + if env_val: + resolved[name] = env_val + continue + + # Priority 4: interactive prompt + default_value = env_var.get("value", "") + required = env_var.get("required", False) + + if not skip_prompting: + from rich.prompt import Prompt + + description = env_var.get("description", "") + prompt_text = f"Enter value for {name}" + if description: + prompt_text += f" ({description})" + is_secret = "token" in name.lower() or "key" in name.lower() + user_input = Prompt.ask( + prompt_text, + default=default_value, + password=True # noqa: SIM210 + if is_secret + else False, + ) + resolved[name] = user_input + elif default_value: + resolved[name] = default_value + elif required: + _b._rich_warning( + f"Warning: Required environment variable '{name}' could not be resolved. " + f"The MCP server may not function correctly." + ) + resolved[name] = "" + else: + resolved[name] = default_value + + return resolved diff --git a/src/apm_cli/adapters/client/_copilot_env.py b/src/apm_cli/adapters/client/_copilot_env.py new file mode 100644 index 000000000..d4133c0ef --- /dev/null +++ b/src/apm_cli/adapters/client/_copilot_env.py @@ -0,0 +1,388 @@ +"""Env-resolution and docker-args mixin for CopilotClientAdapter. 
+ +Extracted from copilot.py to keep that module under 800 lines while +preserving full MRO override semantics: ``CopilotClientAdapter`` lists +``_CopilotEnvMixin`` before ``MCPClientAdapter`` so these methods shadow +the base implementations for every ``CopilotClientAdapter`` instance. + +Rule B: none of these methods reference patched module-level names from +copilot.py (``_rich_warning``, ``SimpleRegistryClient``, etc.), so no +function-level late imports are required here. +""" + +import os + +from ._mcp_runtime_args import process_v01_value_hint_arg +from .base import ( + _ENV_PLACEHOLDER_RE, + _ENV_VAR_RE, + _extract_legacy_angle_vars, + _has_env_placeholder, + _stringify_env_literal, +) + + +class _CopilotEnvMixin: + """Env-resolution and docker-args helpers composed into CopilotClientAdapter. + + Overrides the corresponding base-class methods so that Copilot CLI's + translate-mode behaviour (emit ``${VAR}`` placeholders, never read secrets + at install time) takes effect for every ``CopilotClientAdapter`` instance + while sibling adapters (Cursor, Claude, etc.) keep the legacy-resolve path. + """ + + def _resolve_environment_variables(self, env_vars, env_overrides=None): + """Resolve (or translate) declared environment variables. + + Behaviour depends on ``self._supports_runtime_env_substitution``: + + - True (Copilot CLI default): each declared env var ``NAME`` gets a + ``${NAME}`` placeholder that Copilot CLI resolves at server-start + from the host environment. Hardcoded literal defaults + (``GITHUB_TOOLSETS``, ``GITHUB_DYNAMIC_TOOLSETS``) stay literal + because they are not secrets and provide essential server + configuration. The host environment is NOT read; secrets never + touch disk. See issue #1152 for context. + + - False (legacy / sibling-adapter behaviour): resolve each variable + to its literal value via ``env_overrides`` -> ``os.environ`` -> + optional interactive prompt, baking the result into the config. 
+ + Args: + env_vars (list): List of environment variable definitions from + server info (each item is ``{name, description, required}``). + env_overrides (dict, optional): Pre-collected environment + variable overrides. Ignored in translate mode. + + Returns: + dict: ``{name: value}`` -- placeholder string in translate mode, + literal value in legacy mode. + """ + # Hardcoded literal defaults that supply essential server behaviour + # rather than secrets. These stay literal in translate mode so that + # tool-selection still works without a user export step. + default_github_env = {"GITHUB_TOOLSETS": "context", "GITHUB_DYNAMIC_TOOLSETS": "1"} + + # Self-defined stdio deps pass ``env`` as a plain dict + # ({NAME: value-or-placeholder}); registry-sourced deps pass a list + # of {name, description, required} dicts. Translate-mode handling + # for the dict shape: each value is either already a placeholder + # (translate it to the adapter's runtime form) or a literal + # (record the key as a placeholder reference and emit a runtime + # placeholder so the value never lands on disk). See issue #1152. + if isinstance(env_vars, dict) and self._supports_runtime_env_substitution: + translated = {} + placeholder_keys = [] + for name, raw_value in env_vars.items(): + if not name: + continue + if raw_value is None: + continue + if not isinstance(raw_value, str): + translated[name] = _stringify_env_literal(raw_value) + continue + if _has_env_placeholder(raw_value): + self._last_legacy_angle_vars.update(_extract_legacy_angle_vars(raw_value)) + translated[name] = self._translate_env_placeholder_for_runtime(raw_value) + for match in _ENV_VAR_RE.finditer(translated[name]): + placeholder_keys.append(match.group(1)) + elif name in default_github_env and raw_value == default_github_env[name]: + translated[name] = raw_value + else: + # Literal value present in apm.yml -- replace with a + # runtime placeholder so the secret never touches disk. 
+ translated[name] = self._format_runtime_env_placeholder(name) + placeholder_keys.append(name) + self._last_env_placeholder_keys = set(placeholder_keys) + return translated + + if self._supports_runtime_env_substitution: + resolved = {} + placeholder_keys = [] + for env_var in env_vars: + if not isinstance(env_var, dict): + continue + name = env_var.get("name", "") + if not name: + continue + if name in default_github_env: + # Non-secret literal default -- preserve as-is. + resolved[name] = default_github_env[name] + else: + # Emit a runtime-substitution placeholder; APM never reads + # or stores the value. + resolved[name] = self._format_runtime_env_placeholder(name) + placeholder_keys.append(name) + # Record for the post-install summary line and the + # security-improvement notice. + self._last_env_placeholder_keys = set(placeholder_keys) + return resolved + + if isinstance(env_vars, dict): + # Mirror the base-class dict-shape branch but coerce non-string + # scalars through Copilot's hardened ``_stringify_env_literal`` + # helper so booleans/ints land as the strings Copilot CLI expects. + return { + name: ( + self._resolve_env_variable(name, value, env_overrides=env_overrides) + if isinstance(value, str) + else _stringify_env_literal(value) + ) + for name, value in env_vars.items() + if name and value is not None + } + + return self._resolve_env_vars_with_prompting(env_vars, env_overrides, default_github_env) + + def _resolve_env_variable(self, name, value, env_overrides=None): + """Resolve (or translate) a single environment variable value. + + Behaviour depends on ``self._supports_runtime_env_substitution``: + + - True (Copilot CLI default): translate placeholders to Copilot CLI's + native runtime substitution syntax (``${VAR}``). The host + environment is NOT read; the secret never touches disk. See issue + #1152 for context. Legacy ```` offenders are tracked for the + aggregated deprecation warning emitted by + ``configure_mcp_server``. 
+ + - False (legacy / sibling-adapter behaviour): resolve placeholders + to literal values via ``env_overrides`` -> ``os.environ`` -> + optional interactive prompt, baking the result into the config. + + Args: + name (str): Environment variable name. + value (str): Environment variable value or placeholder. + env_overrides (dict, optional): Pre-collected environment + variable overrides. Ignored in translate mode. + + Returns: + str: Translated placeholder (translate mode) or resolved + literal value (legacy mode). + """ + if self._supports_runtime_env_substitution: + # Track legacy offenders for the aggregated deprecation + # warning. Translation itself is a pure-textual rewrite. + self._last_legacy_angle_vars.update(_extract_legacy_angle_vars(value)) + # Track env-var names referenced via this header/value so the + # security-upgrade detector and per-server summary can see + # them (the env-block path tracks via _resolve_environment_variables). + for match in _ENV_VAR_RE.finditer(value): + self._last_env_placeholder_keys.add(match.group(1)) + return self._translate_env_placeholder_for_runtime(value) + + import sys + + from rich.prompt import Prompt + + env_overrides = env_overrides or {} + # If env_overrides is provided, it means we're in managed environment collection mode + skip_prompting = bool(env_overrides) + + # Check for CI/automated environment via APM_E2E_TESTS flag (more reliable than TTY detection) + if os.getenv("APM_E2E_TESTS") == "1": + skip_prompting = True + + # Also skip prompting if we're in a non-interactive environment (fallback) + is_interactive = sys.stdin.isatty() and sys.stdout.isatty() + if not is_interactive: + skip_prompting = True + + # Three accepted placeholder syntaxes (see _COPILOT_ENV_RE at module + # top), all resolved against env_overrides -> os.environ -> optional + # interactive prompt. Single-pass substitution preserves the legacy + # ```` semantics: resolved values are not re-scanned for further + # placeholder expansion. 
+ def _replace(match): + # Group 1 = legacy <VAR>; group 2 = ${VAR} / ${env:VAR}. + env_name = match.group(1) or match.group(2) + env_value = env_overrides.get(env_name) or os.getenv(env_name) + if not env_value and not skip_prompting: + prompt_text = f"Enter value for {env_name}" + env_value = Prompt.ask( + prompt_text, + password=True # noqa: SIM210 + if "token" in env_name.lower() or "key" in env_name.lower() + else False, + ) + return env_value if env_value else match.group(0) + + return _ENV_PLACEHOLDER_RE.sub(_replace, value) + + def _inject_env_vars_into_docker_args(self, docker_args, env_vars): + """Inject environment variables into Docker arguments following registry template. + + The registry provides a complete Docker command template in runtime_arguments. + We need to inject actual environment variable values while respecting the template structure. + Also ensures required Docker flags (-i, --rm) are present. + + Args: + docker_args (list): Docker arguments from registry runtime_arguments. + env_vars (dict): Resolved environment variables. + + Returns: + list: Docker arguments with environment variables properly injected and required flags. 
+ """ + if not env_vars: + env_vars = {} + + result = [] + i = 0 + has_interactive = False + has_rm = False + + # Check for existing -i and --rm flags + for arg in docker_args: + if arg == "-i" or arg == "--interactive": # noqa: PLR1714 + has_interactive = True + elif arg == "--rm": + has_rm = True + + while i < len(docker_args): + arg = docker_args[i] + result.append(arg) + + # When we encounter "run", inject required flags first + if arg == "run": + # Add -i flag if not present + if not has_interactive: + result.append("-i") + + # Add --rm flag if not present + if not has_rm: + result.append("--rm") + + # If this is an environment variable name placeholder, replace with actual env var + if arg in env_vars: + # This is an environment variable name that should be replaced with -e VAR=value + result.pop() # Remove the env var name + result.extend(["-e", f"{arg}={env_vars[arg]}"]) + elif arg == "-e" and i + 1 < len(docker_args): + # Handle -e flag followed by env var name + next_arg = docker_args[i + 1] + if next_arg in env_vars: + result.append(f"{next_arg}={env_vars[next_arg]}") + i += 1 # Skip the next argument as we've processed it + else: + # Keep the original argument structure + result.append(next_arg) + i += 1 + + i += 1 + + # Add any remaining environment variables that weren't in the template + template_env_vars = set() + for arg in docker_args: + if arg in env_vars: + template_env_vars.add(arg) + + for env_name, env_value in env_vars.items(): + if env_name not in template_env_vars: + # Find a good place to insert additional env vars (after "run" but before image name) + insert_pos = len(result) + for idx, arg in enumerate(result): + if arg == "run": + # Insert after run command but before image name (usually last arg) + insert_pos = min(len(result) - 1, idx + 1) + break + + result.insert(insert_pos, "-e") + result.insert(insert_pos + 1, f"{env_name}={env_value}") + + # Add default GitHub MCP server environment variables if not already present + # Only add 
defaults for variables that were NOT explicitly provided (even if empty) + default_github_env = {"GITHUB_TOOLSETS": "context", "GITHUB_DYNAMIC_TOOLSETS": "1"} # noqa: F841 + + existing_env_vars = set() + for i, arg in enumerate(result): + if arg == "-e" and i + 1 < len(result): + env_spec = result[i + 1] + if "=" in env_spec: + env_name = env_spec.split("=", 1)[0] + existing_env_vars.add(env_name) + + # For Copilot, defaults are already added during environment resolution + # This section is kept for compatibility but shouldn't add duplicates + + return result + + def _inject_docker_env_vars(self, args, env_vars): + """Inject environment variables into Docker arguments. + + Args: + args (list): Original Docker arguments. + env_vars (dict): Environment variables to inject. + + Returns: + list: Updated arguments with environment variables injected. + """ + result = [] + + for arg in args: + result.append(arg) + # If this is a docker run command, inject environment variables after "run" + if arg == "run" and env_vars: + for env_name, env_value in env_vars.items(): + result.extend(["-e", f"{env_name}={env_value}"]) + + return result + + def _process_arguments(self, arguments, resolved_env=None, runtime_vars=None): + """Process argument objects to extract simple string values with environment and runtime variable resolution. + + Args: + arguments (list): List of argument objects from registry. + resolved_env (dict): Resolved environment variables. + runtime_vars (dict): Resolved runtime variables. + + Returns: + list: List of processed argument strings. 
+ """ + if resolved_env is None: + resolved_env = {} + if runtime_vars is None: + runtime_vars = {} + + processed = [] + + for arg in arguments: + if isinstance(arg, dict): + # Extract value from argument object + arg_type = arg.get("type", "") + if arg_type == "positional": + value = arg.get("value", arg.get("default", "")) + if value: + # Resolve both environment and runtime variable placeholders with actual values + processed_value = self._resolve_variable_placeholders( + str(value), resolved_env, runtime_vars + ) + processed.append(processed_value) + elif arg_type == "named": + name = arg.get("name", "") + value = arg.get("value", arg.get("default", "")) + if name: + processed.append(name) + # For named arguments, only add value if it's different from the flag name + # and not empty + if value and value != name and not value.startswith("-"): + processed_value = self._resolve_variable_placeholders( + str(value), resolved_env, runtime_vars + ) + processed.append(processed_value) + elif not arg_type and "value_hint" in arg: + # v0.1 registry format: shared helper handles is_required + # guard and {var_name} placeholder substitution. 
+ value = process_v01_value_hint_arg(arg, runtime_vars) + if value: + processed_value = self._resolve_variable_placeholders( + value, resolved_env, runtime_vars + ) + processed.append(processed_value) + elif isinstance(arg, str): + # Already a string, use as-is but resolve variable placeholders + processed_value = self._resolve_variable_placeholders( + arg, resolved_env, runtime_vars + ) + processed.append(processed_value) + + return processed diff --git a/src/apm_cli/adapters/client/base.py b/src/apm_cli/adapters/client/base.py index bad48e433..935722be9 100644 --- a/src/apm_cli/adapters/client/base.py +++ b/src/apm_cli/adapters/client/base.py @@ -1,98 +1,37 @@ """Base adapter interface for MCP clients.""" import os -import re from abc import ABC, abstractmethod from pathlib import Path -from typing import ClassVar from ...utils.console import _rich_error, _rich_warning - -_INPUT_VAR_RE = re.compile(r"\$\{input:([^}]+)\}") - -# Matches ${VAR} and ${env:VAR}, capturing VAR. Intentionally does NOT match -# ${input:VAR} (the optional ``env:`` group cannot also satisfy ``input:``), -# nor GitHub Actions ``${{ ... }}`` templates (the second ``{`` fails the -# identifier class). This keeps env-var handling fully disjoint from input -# variable handling, so existing _INPUT_VAR_RE call sites are unaffected. -_ENV_VAR_RE = re.compile(r"\$\{(?:env:)?([A-Za-z_][A-Za-z0-9_]*)\}") - -# Superset of _ENV_VAR_RE that also matches the legacy ```` syntax -# (uppercase identifier only). Used as the single-pass translation target so -# resolved values are NOT re-scanned -- a literal value whose text happens to -# contain ``${...}`` does not get recursively expanded. ``${input:...}`` is -# intentionally not matched here so input-variable handling stays disjoint. -_ENV_PLACEHOLDER_RE = re.compile(r"<([A-Z_][A-Z0-9_]*)>|" + _ENV_VAR_RE.pattern) - -# Detects the legacy ```` placeholder syntax only. Used to aggregate -# deprecation warnings across all servers in a single install run. 
-_LEGACY_ANGLE_VAR_RE = re.compile(r"<([A-Z_][A-Z0-9_]*)>") - - -def _translate_env_placeholder(value): - """Pure-textual translation of env-var placeholders to the canonical - ``${VAR}`` runtime-substitution syntax. - - Security-critical helper for issue #1152: MUST NOT read ``os.environ`` - and MUST NOT resolve placeholders to literal values. Runtimes that - support runtime substitution (Copilot CLI) resolve ``${VAR}`` from the - host environment at server-start, so APM emits placeholders verbatim - rather than baking secrets to disk. - - Translations: - ``${env:VAR}`` -> ``${VAR}`` (strip ``env:`` prefix) - ``${VAR}`` -> ``${VAR}`` (no-op) - ```` -> ``${VAR}`` (legacy syntax migration) - ``${VAR:-default}``-> passthrough (regex doesn't match) - ``$VAR`` (bare) -> passthrough (regex doesn't match) - ``${input:foo}`` -> passthrough (regex doesn't match) - non-string -> passthrough - - Idempotent: applying twice yields the same result as applying once. - """ - if not isinstance(value, str): - return value - - def _to_brace(match): - # group(1) = legacy ; group(2) = ${VAR} / ${env:VAR} - var_name = match.group(1) or match.group(2) - return "${" + var_name + "}" - - return _ENV_PLACEHOLDER_RE.sub(_to_brace, value) - - -def _extract_legacy_angle_vars(value): - """Return the set of legacy ```` names present in *value*. - - Used to aggregate deprecation warnings across all servers in a single - install run, so authors see one helpful list instead of one warning per - occurrence. - """ - if not isinstance(value, str): - return set() - return set(_LEGACY_ANGLE_VAR_RE.findall(value)) - - -def _has_env_placeholder(value): - """True if *value* is a string containing any recognised env-var - placeholder syntax (``${VAR}``, ``${env:VAR}``, or legacy ````). - - Used to distinguish placeholder-sourced env values (which translate) - from hardcoded literal defaults (which stay literal). 
- """ - if not isinstance(value, str): - return False - return bool(_ENV_PLACEHOLDER_RE.search(value)) - - -def _stringify_env_literal(value): - """Return MCP env literal values in the manifest ``map`` shape.""" - if isinstance(value, bool): - return str(value).lower() - return str(value) - - -class MCPClientAdapter(ABC): +from ._base_env import ( + _ENV_PLACEHOLDER_RE, + _ENV_VAR_RE, + _INPUT_VAR_RE, + _LEGACY_ANGLE_VAR_RE, + _BaseEnvMixin, + _extract_legacy_angle_vars, + _has_env_placeholder, + _stringify_env_literal, + _translate_env_placeholder, +) + +# Re-export so existing ``from .base import _translate_env_placeholder`` etc. +# in sibling modules keep working unchanged. +__all__ = [ + "_ENV_PLACEHOLDER_RE", + "_ENV_VAR_RE", + "_INPUT_VAR_RE", + "_LEGACY_ANGLE_VAR_RE", + "_extract_legacy_angle_vars", + "_has_env_placeholder", + "_stringify_env_literal", + "_translate_env_placeholder", +] + + +class MCPClientAdapter(_BaseEnvMixin, ABC): """Base adapter for MCP clients.""" # Identifier matching the corresponding ``KNOWN_TARGETS`` entry name. @@ -240,21 +179,25 @@ def _infer_registry_name(package): name = package.get("name", "") runtime_hint = package.get("runtime_hint", "") - # Infer from runtime_hint - if runtime_hint in ("npx", "npm"): - return "npm" - if runtime_hint in ("uvx", "pip", "pipx"): - return "pypi" - if runtime_hint == "docker": - return "docker" - if runtime_hint in ("dotnet", "dnx"): - return "nuget" + # Lookup tables replace per-value if/return chains. + _hint_map = { + "npx": "npm", + "npm": "npm", + "uvx": "pypi", + "pip": "pypi", + "pipx": "pypi", + "docker": "docker", + "dotnet": "nuget", + "dnx": "nuget", + } + if runtime_hint in _hint_map: + return _hint_map[runtime_hint] # Infer from package name patterns - if name.startswith("@") and "/" in name: - return "npm" # scoped npm package, e.g. 
@azure/mcp - if name.startswith(("ghcr.io/", "mcr.microsoft.com/", "docker.io/")): - return "docker" + if (name.startswith("@") and "/" in name) or name.startswith( + ("ghcr.io/", "mcr.microsoft.com/", "docker.io/") + ): + return "npm" if name.startswith("@") else "docker" if name.startswith("https://") and name.endswith(".mcpb"): return "mcpb" # PascalCase with dots usually means nuget (e.g. Azure.Mcp) @@ -343,268 +286,6 @@ def normalize_project_arg(self, value): return "." return value - # -- Env-var placeholder resolution ------------------------------------- - # GitHub MCP server defaults: not secrets, preserved literal in translate - # mode and used as fallbacks in legacy mode. The defaults apply regardless - # of which client CLI runs the server, so they live on the base. - _DEFAULT_GITHUB_ENV: ClassVar[dict[str, str]] = { - "GITHUB_TOOLSETS": "context", - "GITHUB_DYNAMIC_TOOLSETS": "1", - } - - @staticmethod - def _should_skip_env_prompts(env_overrides): - """True when the caller has already collected env vars (managed mode), - when APM_E2E_TESTS is set, or when stdin/stdout is not a TTY. - - Centralising this policy keeps the resolver paths consistent and - avoids subtle drift between ``_resolve_environment_variables`` and - ``_resolve_env_variable``. - """ - import sys - - if env_overrides: - return True - if os.getenv("APM_E2E_TESTS") == "1": - return True - return not (sys.stdin.isatty() and sys.stdout.isatty()) - - def _resolve_environment_variables(self, env_vars, env_overrides=None): - """Resolve (or translate) declared environment variables. - - Behaviour follows ``self._supports_runtime_env_substitution``: - translate-mode (Copilot CLI) emits ``${VAR}`` placeholders verbatim - so the runtime resolves them at server-start (see issue #1152); - legacy-mode resolves placeholders to literal values via env_overrides - -> os.environ -> optional interactive prompt. 
- - Args: - env_vars: Either a ``dict[name, value-or-placeholder]`` from a - self-defined stdio dep (``_raw_stdio["env"]``), or a - ``list[{name, description, required}]`` from the registry. - env_overrides: Pre-collected env-var overrides (ignored in - translate mode). - - Returns: - dict: ``{name: value}`` -- placeholder string in translate - mode, literal value in legacy mode. - """ - # ---- translate mode, dict shape (self-defined stdio in apm.yml) ---- - if isinstance(env_vars, dict) and self._supports_runtime_env_substitution: - # Value type is intentionally untyped: most entries are translated - # placeholder strings, but non-string values (e.g. an int/bool - # YAML scalar) are passed through verbatim and serialised by the - # adapter's config writer (JSON/TOML). - translated: dict = {} - placeholder_keys: list[str] = [] - for name, raw_value in env_vars.items(): - if not name: - continue - if not isinstance(raw_value, str): - translated[name] = raw_value - continue - if _has_env_placeholder(raw_value): - self._last_legacy_angle_vars.update(_extract_legacy_angle_vars(raw_value)) - translated[name] = self._translate_env_placeholder_for_runtime(raw_value) - placeholder_keys.extend( - m.group(1) for m in _ENV_VAR_RE.finditer(translated[name]) - ) - elif ( - name in self._DEFAULT_GITHUB_ENV and raw_value == self._DEFAULT_GITHUB_ENV[name] - ): - translated[name] = raw_value - else: - # Literal value present in apm.yml -- replace with a - # runtime placeholder so the secret never touches disk. 
- translated[name] = self._format_runtime_env_placeholder(name) - placeholder_keys.append(name) - self._last_env_placeholder_keys = set(placeholder_keys) - return translated - - # ---- translate mode, registry list shape ---- - if self._supports_runtime_env_substitution: - resolved: dict[str, str] = {} - placeholder_keys: list[str] = [] - for env_var in env_vars: - if not isinstance(env_var, dict): - continue - name = env_var.get("name", "") - if not name: - continue - if name in self._DEFAULT_GITHUB_ENV: - resolved[name] = self._DEFAULT_GITHUB_ENV[name] - else: - resolved[name] = self._format_runtime_env_placeholder(name) - placeholder_keys.append(name) - self._last_env_placeholder_keys = set(placeholder_keys) - return resolved - - # ---- legacy mode, dict shape (self-defined stdio in apm.yml) ---- - # Issue #1266 / #1222: ``_raw_stdio["env"]`` is a plain dict. Each - # value is resolved via the same single-value pipeline used for - # header values so all three placeholder syntaxes (````, - # ``${VAR}``, ``${env:VAR}``) behave consistently across adapters. - # - # Note the deliberate semantic divergence from the legacy-list branch - # below: empty strings authored in apm.yml are preserved as-is and - # ``_DEFAULT_GITHUB_ENV`` fallbacks are NOT applied, because a value - # explicitly written by the user expresses intent, whereas an empty - # value coming from ``env_overrides`` / ``os.environ`` for a - # registry-declared schema entry means "no value supplied, use the - # default if one exists". 
- if isinstance(env_vars, dict): - resolved = {} - for name, value in env_vars.items(): - if not name: - continue - if isinstance(value, str): - resolved[name] = self._resolve_env_variable( - name, value, env_overrides=env_overrides - ) - elif value is not None: - resolved[name] = str(value) - return resolved - - # ---- legacy mode, registry list shape ---- - from rich.prompt import Prompt - - env_overrides = env_overrides or {} - skip_prompting = self._should_skip_env_prompts(env_overrides) - - # Variables explicitly provided with empty values mean "use the default". - empty_value_vars = {k for k, v in env_overrides.items() if not v or not v.strip()} - - resolved = {} - for env_var in env_vars: - if not isinstance(env_var, dict): - continue - name = env_var.get("name", "") - if not name: - continue - required = env_var.get("required", True) - - value = env_overrides.get(name) or os.getenv(name) - if not value and required and not skip_prompting: - prompt_text = f"Enter value for {name}" - if description := env_var.get("description", ""): - prompt_text += f" ({description})" - value = Prompt.ask( - prompt_text, - password="token" in name.lower() or "key" in name.lower(), - ) - - if value and value.strip(): - resolved[name] = value - elif name in self._DEFAULT_GITHUB_ENV and ( - name in empty_value_vars or not required or skip_prompting - ): - resolved[name] = self._DEFAULT_GITHUB_ENV[name] - - return resolved - - def _resolve_env_variable(self, name, value, env_overrides=None): - """Resolve (or translate) a single env-var value. - - Used for header values and for individual entries in dict-shape - env blocks. The ``name`` parameter is currently unused by the - method body but kept in the signature because every call site - (headers, dict iteration) already has the name in hand, and - passing it preserves call-site symmetry with future hooks that - may want to dispatch on it. - - Args: - name: Env-var name (currently unused, see above). 
- value: Env-var value possibly containing placeholders. - env_overrides: Pre-collected overrides (ignored in translate mode). - """ - if self._supports_runtime_env_substitution: - legacy_keys = _extract_legacy_angle_vars(value) - self._last_legacy_angle_vars.update(legacy_keys) - self._last_env_placeholder_keys.update(legacy_keys) - for match in _ENV_VAR_RE.finditer(value): - self._last_env_placeholder_keys.add(match.group(1)) - return self._translate_env_placeholder_for_runtime(value) - - from rich.prompt import Prompt - - env_overrides = env_overrides or {} - skip_prompting = self._should_skip_env_prompts(env_overrides) - - # Three accepted placeholder syntaxes resolved against - # env_overrides -> os.environ -> optional interactive prompt. - # Single-pass substitution preserves the legacy ```` semantics: - # resolved values are NOT re-scanned for further expansion. - def _replace(match): - env_name = match.group(1) or match.group(2) - env_value = env_overrides.get(env_name) or os.getenv(env_name) - if not env_value and not skip_prompting: - env_value = Prompt.ask( - f"Enter value for {env_name}", - password="token" in env_name.lower() or "key" in env_name.lower(), - ) - return env_value if env_value else match.group(0) - - return _ENV_PLACEHOLDER_RE.sub(_replace, value) - - def _resolve_variable_placeholders(self, value, resolved_env, runtime_vars): - """Resolve env-var and APM template placeholders in argument strings. - - Translate mode rewrites all three env-var placeholder syntaxes to - ``${VAR}`` (so the runtime can resolve them at server-start); legacy - mode resolves only the legacy ```` form against ``resolved_env`` - and leaves the newer ``${VAR}`` / ``${env:VAR}`` syntaxes untouched - for backward compatibility. APM template variables (``{runtime_var}``) - are always resolved at install time because they are an APM-internal - concept the target runtime cannot interpret. - - Args: - value: String possibly containing placeholders. 
- resolved_env: Resolved env-var literals (legacy mode) or - placeholder strings (translate mode). - runtime_vars: Resolved APM template variables. - - Returns: - str: ``value`` with placeholders translated or resolved. - """ - if not value: - return value - - processed = str(value) - - if self._supports_runtime_env_substitution: - self._last_legacy_angle_vars.update(_extract_legacy_angle_vars(processed)) - processed = self._translate_env_placeholder_for_runtime(processed) - else: - # Resolve only the legacy ```` form; newer syntaxes are - # preserved verbatim for backward compatibility. - def _replace_legacy_angle(match): - return resolved_env.get(match.group(1), match.group(0)) - - processed = _LEGACY_ANGLE_VAR_RE.sub(_replace_legacy_angle, processed) - - # Resolve APM ``{runtime_var}`` template variables. The negative - # lookbehind on ``$`` ensures we never accidentally match the brace - # of an already-translated ``${VAR}`` env placeholder. - if runtime_vars: - runtime_pattern = re.compile(r"(? dict: - """Resolve *env_vars* from overrides, environment, or interactive prompts. - - Identical logic shared between - :meth:`CopilotClientAdapter._process_environment_variables` and - :meth:`CodexClientAdapter._process_environment_variables`. - - All imports are deferred so that ``rich.prompt`` (an optional - dependency) is never imported at module load time. - - Args: - env_vars: List of env-var descriptor dicts from the registry. - env_overrides: Pre-collected ``{name: value}`` overrides (empty - dict when none). - default_github_env: Mapping of well-known GitHub variable names - to their preferred environment-variable lookup names. - - Returns: - ``resolved`` dict mapping each env-var name to its resolved value - (empty string when unresolvable). - """ - import sys - - env_overrides = env_overrides or {} - resolved: dict = {} - - # Determine whether interactive prompting is available. 
- # If env_overrides is provided the CLI has already collected variables -- never prompt again. - skip_prompting = ( - bool(env_overrides) - or bool(os.getenv("CI")) - or bool(os.getenv("APM_E2E_TESTS")) - or not sys.stdout.isatty() - or not sys.stdin.isatty() - ) - - # First pass: identify variables with empty values to warn the user. - empty_value_vars = [ev for ev in env_vars if ev.get("required") and not ev.get("value")] - if empty_value_vars and skip_prompting: - var_names = [ev.get("name") for ev in empty_value_vars] - _rich_warning( - f"Warning: The following required environment variables have no default " - f"value and cannot be prompted in non-interactive mode: {var_names}" - ) - - for env_var in env_vars: - name = env_var.get("name", "") - if not name: - continue - - # Priority 1: caller-supplied override. - # An explicit empty (or whitespace-only) value is treated as - # "user cleared this". For names with a GitHub-style default the - # logic falls through so the literal default wins; for names - # without a default the entry is dropped from the resolved map. 
- if name in env_overrides: - override_value = env_overrides[name] - if isinstance(override_value, str) and not override_value.strip(): - if name not in default_github_env: - continue - else: - resolved[name] = override_value - continue - - # Priority 2: check GitHub-specific defaults (values are literal defaults, not env-var names) - if name in default_github_env: - resolved[name] = os.getenv(name) or default_github_env[name] - continue - - # Priority 3: environment variable with the same name - env_val = os.getenv(name, "") - if env_val: - resolved[name] = env_val - continue - - # Priority 4: interactive prompt - default_value = env_var.get("value", "") - required = env_var.get("required", False) - - if not skip_prompting: - from rich.prompt import Prompt - - description = env_var.get("description", "") - prompt_text = f"Enter value for {name}" - if description: - prompt_text += f" ({description})" - is_secret = "token" in name.lower() or "key" in name.lower() - user_input = Prompt.ask( - prompt_text, - default=default_value, - password=True # noqa: SIM210 - if is_secret - else False, - ) - resolved[name] = user_input - elif default_value: - resolved[name] = default_value - elif required: - _rich_warning( - f"Warning: Required environment variable '{name}' could not be resolved. " - f"The MCP server may not function correctly." 
- ) - resolved[name] = "" - else: - resolved[name] = default_value - - return resolved diff --git a/src/apm_cli/adapters/client/copilot.py b/src/apm_cli/adapters/client/copilot.py index dc689b425..689288ab7 100644 --- a/src/apm_cli/adapters/client/copilot.py +++ b/src/apm_cli/adapters/client/copilot.py @@ -18,21 +18,24 @@ from ...registry.integration import RegistryIntegration from ...utils.console import _rich_warning from ...utils.github_host import is_github_hostname -from ._mcp_runtime_args import process_v01_value_hint_arg +from ._copilot_env import _CopilotEnvMixin from .base import ( - _ENV_PLACEHOLDER_RE, _ENV_VAR_RE, MCPClientAdapter, - _extract_legacy_angle_vars, _has_env_placeholder, - _stringify_env_literal, +) +from .base import ( + _extract_legacy_angle_vars as _extract_legacy_angle_vars, +) +from .base import ( + _stringify_env_literal as _stringify_env_literal, ) from .base import ( _translate_env_placeholder as _translate_env_placeholder, ) -class CopilotClientAdapter(MCPClientAdapter): +class CopilotClientAdapter(_CopilotEnvMixin, MCPClientAdapter): """Copilot CLI implementation of MCP client adapter. This adapter handles Copilot CLI-specific configuration for MCP servers using @@ -644,362 +647,6 @@ def _select_and_dispatch_best_package( ) return package - def _resolve_environment_variables(self, env_vars, env_overrides=None): - """Resolve (or translate) declared environment variables. - - Behaviour depends on ``self._supports_runtime_env_substitution``: - - - True (Copilot CLI default): each declared env var ``NAME`` gets a - ``${NAME}`` placeholder that Copilot CLI resolves at server-start - from the host environment. Hardcoded literal defaults - (``GITHUB_TOOLSETS``, ``GITHUB_DYNAMIC_TOOLSETS``) stay literal - because they are not secrets and provide essential server - configuration. The host environment is NOT read; secrets never - touch disk. See issue #1152 for context. 
- - - False (legacy / sibling-adapter behaviour): resolve each variable - to its literal value via ``env_overrides`` -> ``os.environ`` -> - optional interactive prompt, baking the result into the config. - - Args: - env_vars (list): List of environment variable definitions from - server info (each item is ``{name, description, required}``). - env_overrides (dict, optional): Pre-collected environment - variable overrides. Ignored in translate mode. - - Returns: - dict: ``{name: value}`` -- placeholder string in translate mode, - literal value in legacy mode. - """ - # Hardcoded literal defaults that supply essential server behaviour - # rather than secrets. These stay literal in translate mode so that - # tool-selection still works without a user export step. - default_github_env = {"GITHUB_TOOLSETS": "context", "GITHUB_DYNAMIC_TOOLSETS": "1"} - - # Self-defined stdio deps pass ``env`` as a plain dict - # ({NAME: value-or-placeholder}); registry-sourced deps pass a list - # of {name, description, required} dicts. Translate-mode handling - # for the dict shape: each value is either already a placeholder - # (translate it to the adapter's runtime form) or a literal - # (record the key as a placeholder reference and emit a runtime - # placeholder so the value never lands on disk). See issue #1152. 
- if isinstance(env_vars, dict) and self._supports_runtime_env_substitution: - translated = {} - placeholder_keys = [] - for name, raw_value in env_vars.items(): - if not name: - continue - if raw_value is None: - continue - if not isinstance(raw_value, str): - translated[name] = _stringify_env_literal(raw_value) - continue - if _has_env_placeholder(raw_value): - self._last_legacy_angle_vars.update(_extract_legacy_angle_vars(raw_value)) - translated[name] = self._translate_env_placeholder_for_runtime(raw_value) - for match in _ENV_VAR_RE.finditer(translated[name]): - placeholder_keys.append(match.group(1)) - elif name in default_github_env and raw_value == default_github_env[name]: - translated[name] = raw_value - else: - # Literal value present in apm.yml -- replace with a - # runtime placeholder so the secret never touches disk. - translated[name] = self._format_runtime_env_placeholder(name) - placeholder_keys.append(name) - self._last_env_placeholder_keys = set(placeholder_keys) - return translated - - if self._supports_runtime_env_substitution: - resolved = {} - placeholder_keys = [] - for env_var in env_vars: - if not isinstance(env_var, dict): - continue - name = env_var.get("name", "") - if not name: - continue - if name in default_github_env: - # Non-secret literal default -- preserve as-is. - resolved[name] = default_github_env[name] - else: - # Emit a runtime-substitution placeholder; APM never reads - # or stores the value. - resolved[name] = self._format_runtime_env_placeholder(name) - placeholder_keys.append(name) - # Record for the post-install summary line and the - # security-improvement notice. - self._last_env_placeholder_keys = set(placeholder_keys) - return resolved - - if isinstance(env_vars, dict): - # Mirror the base-class dict-shape branch but coerce non-string - # scalars through Copilot's hardened ``_stringify_env_literal`` - # helper so booleans/ints land as the strings Copilot CLI expects. 
- return { - name: ( - self._resolve_env_variable(name, value, env_overrides=env_overrides) - if isinstance(value, str) - else _stringify_env_literal(value) - ) - for name, value in env_vars.items() - if name and value is not None - } - - return self._resolve_env_vars_with_prompting(env_vars, env_overrides, default_github_env) - - def _resolve_env_variable(self, name, value, env_overrides=None): - """Resolve (or translate) a single environment variable value. - - Behaviour depends on ``self._supports_runtime_env_substitution``: - - - True (Copilot CLI default): translate placeholders to Copilot CLI's - native runtime substitution syntax (``${VAR}``). The host - environment is NOT read; the secret never touches disk. See issue - #1152 for context. Legacy ```` offenders are tracked for the - aggregated deprecation warning emitted by - ``configure_mcp_server``. - - - False (legacy / sibling-adapter behaviour): resolve placeholders - to literal values via ``env_overrides`` -> ``os.environ`` -> - optional interactive prompt, baking the result into the config. - - Args: - name (str): Environment variable name. - value (str): Environment variable value or placeholder. - env_overrides (dict, optional): Pre-collected environment - variable overrides. Ignored in translate mode. - - Returns: - str: Translated placeholder (translate mode) or resolved - literal value (legacy mode). - """ - if self._supports_runtime_env_substitution: - # Track legacy offenders for the aggregated deprecation - # warning. Translation itself is a pure-textual rewrite. - self._last_legacy_angle_vars.update(_extract_legacy_angle_vars(value)) - # Track env-var names referenced via this header/value so the - # security-upgrade detector and per-server summary can see - # them (the env-block path tracks via _resolve_environment_variables). 
- for match in _ENV_VAR_RE.finditer(value): - self._last_env_placeholder_keys.add(match.group(1)) - return self._translate_env_placeholder_for_runtime(value) - - import sys - - from rich.prompt import Prompt - - env_overrides = env_overrides or {} - # If env_overrides is provided, it means we're in managed environment collection mode - skip_prompting = bool(env_overrides) - - # Check for CI/automated environment via APM_E2E_TESTS flag (more reliable than TTY detection) - if os.getenv("APM_E2E_TESTS") == "1": - skip_prompting = True - - # Also skip prompting if we're in a non-interactive environment (fallback) - is_interactive = sys.stdin.isatty() and sys.stdout.isatty() - if not is_interactive: - skip_prompting = True - - # Three accepted placeholder syntaxes (see _COPILOT_ENV_RE at module - # top), all resolved against env_overrides -> os.environ -> optional - # interactive prompt. Single-pass substitution preserves the legacy - # ```` semantics: resolved values are not re-scanned for further - # placeholder expansion. - def _replace(match): - # Group 1 = legacy ; group 2 = ${VAR} / ${env:VAR}. - env_name = match.group(1) or match.group(2) - env_value = env_overrides.get(env_name) or os.getenv(env_name) - if not env_value and not skip_prompting: - prompt_text = f"Enter value for {env_name}" - env_value = Prompt.ask( - prompt_text, - password=True # noqa: SIM210 - if "token" in env_name.lower() or "key" in env_name.lower() - else False, - ) - return env_value if env_value else match.group(0) - - return _ENV_PLACEHOLDER_RE.sub(_replace, value) - - def _inject_env_vars_into_docker_args(self, docker_args, env_vars): - """Inject environment variables into Docker arguments following registry template. - - The registry provides a complete Docker command template in runtime_arguments. - We need to inject actual environment variable values while respecting the template structure. - Also ensures required Docker flags (-i, --rm) are present. 
- - Args: - docker_args (list): Docker arguments from registry runtime_arguments. - env_vars (dict): Resolved environment variables. - - Returns: - list: Docker arguments with environment variables properly injected and required flags. - """ - if not env_vars: - env_vars = {} - - result = [] - i = 0 - has_interactive = False - has_rm = False - - # Check for existing -i and --rm flags - for arg in docker_args: - if arg == "-i" or arg == "--interactive": # noqa: PLR1714 - has_interactive = True - elif arg == "--rm": - has_rm = True - - while i < len(docker_args): - arg = docker_args[i] - result.append(arg) - - # When we encounter "run", inject required flags first - if arg == "run": - # Add -i flag if not present - if not has_interactive: - result.append("-i") - - # Add --rm flag if not present - if not has_rm: - result.append("--rm") - - # If this is an environment variable name placeholder, replace with actual env var - if arg in env_vars: - # This is an environment variable name that should be replaced with -e VAR=value - result.pop() # Remove the env var name - result.extend(["-e", f"{arg}={env_vars[arg]}"]) - elif arg == "-e" and i + 1 < len(docker_args): - # Handle -e flag followed by env var name - next_arg = docker_args[i + 1] - if next_arg in env_vars: - result.append(f"{next_arg}={env_vars[next_arg]}") - i += 1 # Skip the next argument as we've processed it - else: - # Keep the original argument structure - result.append(next_arg) - i += 1 - - i += 1 - - # Add any remaining environment variables that weren't in the template - template_env_vars = set() - for arg in docker_args: - if arg in env_vars: - template_env_vars.add(arg) - - for env_name, env_value in env_vars.items(): - if env_name not in template_env_vars: - # Find a good place to insert additional env vars (after "run" but before image name) - insert_pos = len(result) - for idx, arg in enumerate(result): - if arg == "run": - # Insert after run command but before image name (usually last arg) - 
insert_pos = min(len(result) - 1, idx + 1) - break - - result.insert(insert_pos, "-e") - result.insert(insert_pos + 1, f"{env_name}={env_value}") - - # Add default GitHub MCP server environment variables if not already present - # Only add defaults for variables that were NOT explicitly provided (even if empty) - default_github_env = {"GITHUB_TOOLSETS": "context", "GITHUB_DYNAMIC_TOOLSETS": "1"} # noqa: F841 - - existing_env_vars = set() - for i, arg in enumerate(result): - if arg == "-e" and i + 1 < len(result): - env_spec = result[i + 1] - if "=" in env_spec: - env_name = env_spec.split("=", 1)[0] - existing_env_vars.add(env_name) - - # For Copilot, defaults are already added during environment resolution - # This section is kept for compatibility but shouldn't add duplicates - - return result - - def _inject_docker_env_vars(self, args, env_vars): - """Inject environment variables into Docker arguments. - - Args: - args (list): Original Docker arguments. - env_vars (dict): Environment variables to inject. - - Returns: - list: Updated arguments with environment variables injected. - """ - result = [] - - for arg in args: - result.append(arg) - # If this is a docker run command, inject environment variables after "run" - if arg == "run" and env_vars: - for env_name, env_value in env_vars.items(): - result.extend(["-e", f"{env_name}={env_value}"]) - - return result - - def _process_arguments(self, arguments, resolved_env=None, runtime_vars=None): - """Process argument objects to extract simple string values with environment and runtime variable resolution. - - Args: - arguments (list): List of argument objects from registry. - resolved_env (dict): Resolved environment variables. - runtime_vars (dict): Resolved runtime variables. - - Returns: - list: List of processed argument strings. 
- """ - if resolved_env is None: - resolved_env = {} - if runtime_vars is None: - runtime_vars = {} - - processed = [] - - for arg in arguments: - if isinstance(arg, dict): - # Extract value from argument object - arg_type = arg.get("type", "") - if arg_type == "positional": - value = arg.get("value", arg.get("default", "")) - if value: - # Resolve both environment and runtime variable placeholders with actual values - processed_value = self._resolve_variable_placeholders( - str(value), resolved_env, runtime_vars - ) - processed.append(processed_value) - elif arg_type == "named": - name = arg.get("name", "") - value = arg.get("value", arg.get("default", "")) - if name: - processed.append(name) - # For named arguments, only add value if it's different from the flag name - # and not empty - if value and value != name and not value.startswith("-"): - processed_value = self._resolve_variable_placeholders( - str(value), resolved_env, runtime_vars - ) - processed.append(processed_value) - elif not arg_type and "value_hint" in arg: - # v0.1 registry format: shared helper handles is_required - # guard and {var_name} placeholder substitution. - value = process_v01_value_hint_arg(arg, runtime_vars) - if value: - processed_value = self._resolve_variable_placeholders( - value, resolved_env, runtime_vars - ) - processed.append(processed_value) - elif isinstance(arg, str): - # Already a string, use as-is but resolve variable placeholders - processed_value = self._resolve_variable_placeholders( - arg, resolved_env, runtime_vars - ) - processed.append(processed_value) - - return processed - def _is_github_server(self, server_name, url): """Securely determine if a server is a GitHub MCP server. diff --git a/src/apm_cli/cache/_git_cache_bare.py b/src/apm_cli/cache/_git_cache_bare.py new file mode 100644 index 000000000..e6f2479f8 --- /dev/null +++ b/src/apm_cli/cache/_git_cache_bare.py @@ -0,0 +1,279 @@ +"""Bare-repo lifecycle mixin for :class:`~apm_cli.cache.git_cache.GitCache`. 
+ +Extracted to keep ``git_cache.py`` under the 800-line threshold while +preserving 100% behavioural equivalence. This module is private +(``_`` prefix) and must NOT be imported directly by callers outside the +``cache`` package; the public surface lives in ``git_cache.GitCache``. + +Rule B routing +-------------- +Unit tests patch these names at the ``git_cache`` *module* level: +``shard_lock`` (37x), ``os`` (26x), ``atomic_land`` (19x). +Every method in this mixin that references those names resolves them +via a late import of the origin module:: + + from apm_cli.cache import git_cache as _gc + _gc.shard_lock(...) # routes to the (possibly patched) module attr + _gc.os.chmod(...) + _gc.atomic_land(...) + +The private helpers ``_safe_git_args``, ``_sanitize_url``, and the +constant ``_PARTIAL_BARE_SUFFIX`` remain in ``git_cache.py`` and are +similarly accessed via ``_gc.*`` so that any test-side patches on those +names also take effect here. +""" + +from __future__ import annotations + +import subprocess +from pathlib import Path + + +class _GitCacheBareMixin: + """Mixin providing bare-repo clone/fetch lifecycle for GitCache. + + Requires the host class to expose: + ``self._db_root`` -- Path to the git bare-repo database root. + """ + + def _ensure_bare_repo( + self, + url: str, + shard_key: str, + sha: str, + *, + env: dict[str, str] | None = None, + partial: bool = False, + ) -> Path: + """Ensure a bare repo clone exists for the given shard, fetching if needed. + + Args: + partial: If True, clone with ``--filter=blob:none`` into a + separate ``__p`` directory so the bare downloads + commits + trees only (~5% of full repo size) and acts + as a promisor remote for consumer lazy-fetch. Falls + back to a full clone in the same directory if the + server rejects the filter (older Gerrit / pre-2.20 + GHE). 
Falling back leaves the partial-flavor dir with + full content; future sparse consumers will simply not + trigger any lazy fetch (all blobs already present), so + behavior degrades to today's baseline. + + Returns the path to the bare repo directory. + """ + # Late imports: Rule B for shard_lock / os / atomic_land + + # private helpers that stay in git_cache.py. + from apm_cli.cache import git_cache as _gc + from apm_cli.utils.path_security import ensure_path_within + + bare_shard = shard_key + (_gc._PARTIAL_BARE_SUFFIX if partial else "") + bare_dir = self._db_root / bare_shard + # Containment guard: defends against pathological shard_key + # values bypassing the cache root. + ensure_path_within(bare_dir, self._db_root) + lock = _gc.shard_lock(bare_dir) + + # Acquire the shard lock BEFORE the existence probe so that two + # concurrent processes hitting a cold shard cannot both perform + # a full network clone (one would lose the atomic_land race + # later, but only after wasting bandwidth + wall time). + with lock: + if bare_dir.is_dir(): + # Repo exists -- check if we have the required SHA + if self._bare_has_sha(bare_dir, sha, env=env): + return bare_dir + # Need to fetch the SHA (lock already held; call the + # inner helper that does NOT re-acquire). + self._fetch_into_bare_locked(bare_dir, url, sha, env=env) + return bare_dir + + # Cold miss: clone bare repo + from apm_cli.cache.locking import stage_path + from apm_cli.utils.git_env import get_git_executable, git_subprocess_env + + git_exe = get_git_executable() + staged = stage_path(bare_dir) + ensure_path_within(staged, self._db_root) + staged.mkdir(parents=True, exist_ok=True) + _gc.os.chmod(str(staged), 0o700) + + subprocess_env = env if env is not None else git_subprocess_env() + clone_args = [ + git_exe, + *_gc._safe_git_args(), + "clone", + "--bare", + "--no-tags", + "--no-recurse-submodules", + ] + if partial: + # Promisor partial clone: trees + commits only. 
Blobs + # arrive lazily via the remote when the consumer needs + # them. Github / modern GHES / ADO support this; older + # servers reject it and we retry without --filter. + # --no-tags above skips fetching tag objects (release + # tags can sum to MBs on monorepos); the cache is + # SHA-keyed and never resolves via tags. + clone_args += ["--filter=blob:none"] + clone_args += [url, str(staged)] + try: + # Full bare clone (or partial when requested above). The + # full path extracts file contents at checkout time, so + # all blobs must be present locally. The partial path + # relies on the consumer being configured as a promisor + # so missing blobs trigger an on-demand fetch. + subprocess.run( + clone_args, + capture_output=True, + text=True, + timeout=300, + env=subprocess_env, + check=True, + ) + except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError) as exc: + # Partial clone fallback: some servers reject --filter + # (old Gerrit / pre-2.20 GHE). Retry once without it so + # we never block on this optimization. The resulting + # bare is full; future sparse consumers find all blobs + # locally and skip lazy fetch (degrades to baseline, + # no behavior change for the user). + fallback_done = False + if partial and isinstance(exc, subprocess.CalledProcessError): + from ..utils.console import _rich_warning + + _rich_warning( + f"Partial clone (--filter=blob:none) failed for " + f"{_gc._sanitize_url(url)}; retrying with full bare clone. " + f"Server may not support filter v2." 
+ ) + from ..utils.file_ops import robust_rmtree + + robust_rmtree(staged, ignore_errors=True) + staged.mkdir(parents=True, exist_ok=True) + _gc.os.chmod(str(staged), 0o700) + try: + subprocess.run( + [ + git_exe, + *_gc._safe_git_args(), + "clone", + "--bare", + "--no-tags", + "--no-recurse-submodules", + url, + str(staged), + ], + capture_output=True, + text=True, + timeout=300, + env=subprocess_env, + check=True, + ) + fallback_done = True + except ( + subprocess.CalledProcessError, + subprocess.TimeoutExpired, + OSError, + ) as exc2: + from ..utils.file_ops import robust_rmtree + + robust_rmtree(staged, ignore_errors=True) + raise RuntimeError( + f"Failed to clone {_gc._sanitize_url(url)} " + f"(partial fallback also failed): {exc2}" + ) from exc2 + if not fallback_done: + # Clean up staged on failure + from ..utils.file_ops import robust_rmtree + + robust_rmtree(staged, ignore_errors=True) + raise RuntimeError(f"Failed to clone {_gc._sanitize_url(url)}: {exc}") from exc + + # Atomic land (lock is already held; pass it through so the + # rename completes under the same critical section). + if not _gc.atomic_land(staged, bare_dir, lock): + # Another process won between our staging and rename + # (possible only on lock-acquisition timeout fallthrough); + # verify it has our SHA. 
+ if not self._bare_has_sha(bare_dir, sha, env=env): + self._fetch_into_bare_locked(bare_dir, url, sha, env=env) + + return bare_dir + + def _bare_has_sha(self, bare_dir: Path, sha: str, *, env: dict[str, str] | None = None) -> bool: + """Check if the bare repo contains the specified commit.""" + from apm_cli.cache import git_cache as _gc + from apm_cli.utils.git_env import get_git_executable, git_subprocess_env + + git_exe = get_git_executable() + subprocess_env = env if env is not None else git_subprocess_env() + try: + result = subprocess.run( + [git_exe, *_gc._safe_git_args(), "-C", str(bare_dir), "cat-file", "-t", sha], + capture_output=True, + text=True, + timeout=10, + env=subprocess_env, + ) + return result.returncode == 0 and "commit" in result.stdout.strip() + except (subprocess.TimeoutExpired, OSError): + return False + + def _fetch_into_bare( + self, + bare_dir: Path, + url: str, + sha: str, + *, + env: dict[str, str] | None = None, + ) -> None: + """Fetch a specific SHA into an existing bare repo (acquires lock).""" + from apm_cli.cache import git_cache as _gc + + lock = _gc.shard_lock(bare_dir) + with lock: + if self._bare_has_sha(bare_dir, sha, env=env): + return + self._fetch_into_bare_locked(bare_dir, url, sha, env=env) + + def _fetch_into_bare_locked( + self, + bare_dir: Path, + url: str, + sha: str, + *, + env: dict[str, str] | None = None, + ) -> None: + """Fetch a specific SHA into a bare repo. Caller MUST hold the shard lock.""" + from apm_cli.cache import git_cache as _gc + from apm_cli.utils.git_env import get_git_executable, git_subprocess_env + + git_exe = get_git_executable() + subprocess_env = env if env is not None else git_subprocess_env() + # If this is a partial-flavor bare, preserve the filter on fetch + # so we don't pull all blobs reachable from the new SHA. Detected + # via shard-suffix naming convention (cheap, no git config probe). 
+ is_partial = bare_dir.name.endswith(_gc._PARTIAL_BARE_SUFFIX) + fetch_args = [git_exe, *_gc._safe_git_args(), "-C", str(bare_dir), "fetch"] + if is_partial: + fetch_args += ["--filter=blob:none"] + fetch_args += [url, sha] + try: + subprocess.run( + fetch_args, + capture_output=True, + text=True, + timeout=120, + env=subprocess_env, + check=True, + ) + except subprocess.CalledProcessError: + # Some servers don't allow fetching by SHA -- fetch all refs + subprocess.run( + [git_exe, *_gc._safe_git_args(), "-C", str(bare_dir), "fetch", "--all"], + capture_output=True, + text=True, + timeout=120, + env=subprocess_env, + check=True, + ) diff --git a/src/apm_cli/cache/git_cache.py b/src/apm_cli/cache/git_cache.py index bc3fa8608..84b049f3c 100644 --- a/src/apm_cli/cache/git_cache.py +++ b/src/apm_cli/cache/git_cache.py @@ -35,6 +35,7 @@ from ..utils.git_sparse import apply_sparse_cone from ..utils.path_security import ensure_path_within +from ._git_cache_bare import _GitCacheBareMixin from .integrity import verify_checkout_sha from .locking import atomic_land, cleanup_incomplete, shard_lock, stage_path from .paths import get_git_checkouts_path, get_git_db_path @@ -104,7 +105,7 @@ def _variant_key(sparse_paths: list[str] | None) -> str: return f"sparse-{digest}" -class GitCache: +class GitCache(_GitCacheBareMixin): """Content-addressable git cache with integrity verification. Args: @@ -288,164 +289,6 @@ def _ls_remote_resolve( raise RuntimeError(f"Could not resolve ref '{ref}' for {_sanitize_url(url)}") - def _ensure_bare_repo( - self, - url: str, - shard_key: str, - sha: str, - *, - env: dict[str, str] | None = None, - partial: bool = False, - ) -> Path: - """Ensure a bare repo clone exists for the given shard, fetching if needed. - - Args: - partial: If True, clone with ``--filter=blob:none`` into a - separate ``__p`` directory so the bare downloads - commits + trees only (~5% of full repo size) and acts - as a promisor remote for consumer lazy-fetch. 
Falls - back to a full clone in the same directory if the - server rejects the filter (older Gerrit / pre-2.20 - GHE). Falling back leaves the partial-flavor dir with - full content; future sparse consumers will simply not - trigger any lazy fetch (all blobs already present), so - behavior degrades to today's baseline. - - Returns the path to the bare repo directory. - """ - from ..utils.git_env import get_git_executable, git_subprocess_env - - bare_shard = shard_key + (_PARTIAL_BARE_SUFFIX if partial else "") - bare_dir = self._db_root / bare_shard - # Containment guard: defends against pathological shard_key - # values bypassing the cache root. - ensure_path_within(bare_dir, self._db_root) - lock = shard_lock(bare_dir) - - # Acquire the shard lock BEFORE the existence probe so that two - # concurrent processes hitting a cold shard cannot both perform - # a full network clone (one would lose the atomic_land race - # later, but only after wasting bandwidth + wall time). - with lock: - if bare_dir.is_dir(): - # Repo exists -- check if we have the required SHA - if self._bare_has_sha(bare_dir, sha, env=env): - return bare_dir - # Need to fetch the SHA (lock already held; call the - # inner helper that does NOT re-acquire). - self._fetch_into_bare_locked(bare_dir, url, sha, env=env) - return bare_dir - - # Cold miss: clone bare repo - git_exe = get_git_executable() - staged = stage_path(bare_dir) - ensure_path_within(staged, self._db_root) - staged.mkdir(parents=True, exist_ok=True) - os.chmod(str(staged), 0o700) - - subprocess_env = env if env is not None else git_subprocess_env() - clone_args = [ - git_exe, - *_safe_git_args(), - "clone", - "--bare", - "--no-tags", - "--no-recurse-submodules", - ] - if partial: - # Promisor partial clone: trees + commits only. Blobs - # arrive lazily via the remote when the consumer needs - # them. Github / modern GHES / ADO support this; older - # servers reject it and we retry without --filter. 
- # --no-tags above skips fetching tag objects (release - # tags can sum to MBs on monorepos); the cache is - # SHA-keyed and never resolves via tags. - clone_args += ["--filter=blob:none"] - clone_args += [url, str(staged)] - try: - # Full bare clone (or partial when requested above). The - # full path extracts file contents at checkout time, so - # all blobs must be present locally. The partial path - # relies on the consumer being configured as a promisor - # so missing blobs trigger an on-demand fetch. - subprocess.run( - clone_args, - capture_output=True, - text=True, - timeout=300, - env=subprocess_env, - check=True, - ) - except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError) as exc: - # Partial clone fallback: some servers reject --filter - # (old Gerrit / pre-2.20 GHE). Retry once without it so - # we never block on this optimization. The resulting - # bare is full; future sparse consumers find all blobs - # locally and skip lazy fetch (degrades to baseline, - # no behavior change for the user). - fallback_done = False - if partial and isinstance(exc, subprocess.CalledProcessError): - from ..utils.console import _rich_warning - - _rich_warning( - f"Partial clone (--filter=blob:none) failed for " - f"{_sanitize_url(url)}; retrying with full bare clone. " - f"Server may not support filter v2." 
- ) - from ..utils.file_ops import robust_rmtree - - robust_rmtree(staged, ignore_errors=True) - staged.mkdir(parents=True, exist_ok=True) - os.chmod(str(staged), 0o700) - try: - subprocess.run( - [ - git_exe, - *_safe_git_args(), - "clone", - "--bare", - "--no-tags", - "--no-recurse-submodules", - url, - str(staged), - ], - capture_output=True, - text=True, - timeout=300, - env=subprocess_env, - check=True, - ) - fallback_done = True - except ( - subprocess.CalledProcessError, - subprocess.TimeoutExpired, - OSError, - ) as exc2: - from ..utils.file_ops import robust_rmtree - - robust_rmtree(staged, ignore_errors=True) - raise RuntimeError( - f"Failed to clone {_sanitize_url(url)} " - f"(partial fallback also failed): {exc2}" - ) from exc2 - if not fallback_done: - # Clean up staged on failure - from ..utils.file_ops import robust_rmtree - - robust_rmtree(staged, ignore_errors=True) - raise RuntimeError(f"Failed to clone {_sanitize_url(url)}: {exc}") from exc - - # Atomic land (lock is already held; pass it through so the - # rename completes under the same critical section). - if not atomic_land(staged, bare_dir, lock): - # Another process won between our staging and rename - # (possible only on lock-acquisition timeout fallthrough); - # verify it has our SHA. 
- if not self._bare_has_sha(bare_dir, sha, env=env): - self._fetch_into_bare_locked(bare_dir, url, sha, env=env) - - return bare_dir - def _create_checkout( self, url: str, @@ -633,80 +476,6 @@ def _create_checkout( ) return final_dir - def _bare_has_sha(self, bare_dir: Path, sha: str, *, env: dict[str, str] | None = None) -> bool: - """Check if the bare repo contains the specified commit.""" - from ..utils.git_env import get_git_executable, git_subprocess_env - - git_exe = get_git_executable() - subprocess_env = env if env is not None else git_subprocess_env() - try: - result = subprocess.run( - [git_exe, *_safe_git_args(), "-C", str(bare_dir), "cat-file", "-t", sha], - capture_output=True, - text=True, - timeout=10, - env=subprocess_env, - ) - return result.returncode == 0 and "commit" in result.stdout.strip() - except (subprocess.TimeoutExpired, OSError): - return False - - def _fetch_into_bare( - self, - bare_dir: Path, - url: str, - sha: str, - *, - env: dict[str, str] | None = None, - ) -> None: - """Fetch a specific SHA into an existing bare repo (acquires lock).""" - lock = shard_lock(bare_dir) - with lock: - if self._bare_has_sha(bare_dir, sha, env=env): - return - self._fetch_into_bare_locked(bare_dir, url, sha, env=env) - - def _fetch_into_bare_locked( - self, - bare_dir: Path, - url: str, - sha: str, - *, - env: dict[str, str] | None = None, - ) -> None: - """Fetch a specific SHA into a bare repo. Caller MUST hold the shard lock.""" - from ..utils.git_env import get_git_executable, git_subprocess_env - - git_exe = get_git_executable() - subprocess_env = env if env is not None else git_subprocess_env() - # If this is a partial-flavor bare, preserve the filter on fetch - # so we don't pull all blobs reachable from the new SHA. Detected - # via shard-suffix naming convention (cheap, no git config probe). 
- is_partial = bare_dir.name.endswith(_PARTIAL_BARE_SUFFIX) - fetch_args = [git_exe, *_safe_git_args(), "-C", str(bare_dir), "fetch"] - if is_partial: - fetch_args += ["--filter=blob:none"] - fetch_args += [url, sha] - try: - subprocess.run( - fetch_args, - capture_output=True, - text=True, - timeout=120, - env=subprocess_env, - check=True, - ) - except subprocess.CalledProcessError: - # Some servers don't allow fetching by SHA -- fetch all refs - subprocess.run( - [git_exe, *_safe_git_args(), "-C", str(bare_dir), "fetch", "--all"], - capture_output=True, - text=True, - timeout=120, - env=subprocess_env, - check=True, - ) - def _evict_checkout(self, checkout_dir: Path) -> None: """Safely remove a corrupt checkout shard.""" from ..utils.file_ops import robust_rmtree diff --git a/src/apm_cli/cache/integrity.py b/src/apm_cli/cache/integrity.py index 9708839e4..ade57c736 100644 --- a/src/apm_cli/cache/integrity.py +++ b/src/apm_cli/cache/integrity.py @@ -20,6 +20,21 @@ _log = logging.getLogger(__name__) +def _resolve_packed_ref(git_dir: Path, ref_target: str) -> str | None: + """Return the SHA for *ref_target* from packed-refs, or ``None``.""" + packed = git_dir / "packed-refs" + if not packed.is_file(): + return None + for raw in packed.read_text(encoding="utf-8").splitlines(): + line = raw.strip() + if not line or line.startswith(("#", "^")): + continue + parts = line.split(maxsplit=1) + if len(parts) == 2 and parts[1] == ref_target: + return parts[0].lower() + return None + + def _read_head_sha(checkout_dir: Path) -> str | None: """Return the resolved 40-char SHA at HEAD, or None on any failure. 
@@ -53,16 +68,7 @@ def _read_head_sha(checkout_dir: Path) -> str | None: ref_path = git_dir / ref_target if ref_path.is_file(): return ref_path.read_text(encoding="utf-8").strip().lower() - packed = git_dir / "packed-refs" - if packed.is_file(): - for raw in packed.read_text(encoding="utf-8").splitlines(): - line = raw.strip() - if not line or line.startswith(("#", "^")): - continue - parts = line.split(maxsplit=1) - if len(parts) == 2 and parts[1] == ref_target: - return parts[0].lower() - return None + return _resolve_packed_ref(git_dir, ref_target) if len(head_content) == 40 and all(c in "0123456789abcdef" for c in head_content.lower()): return head_content.lower() return None diff --git a/src/apm_cli/core/_auth_support.py b/src/apm_cli/core/_auth_support.py new file mode 100644 index 000000000..ce1d1d70b --- /dev/null +++ b/src/apm_cli/core/_auth_support.py @@ -0,0 +1,532 @@ +"""Boundary-free support methods for :class:`apm_cli.core.auth.AuthResolver`. + +This module exists purely to keep ``core/auth.py`` under the 800-line ceiling +(issue #1078, Strangler Stage 2) WITHOUT changing behaviour or the public +import/monkeypatch surface. Everything here is composed into ``AuthResolver`` +via :class:`_AuthSupportMixin`, so callers and tests continue to patch +``apm_cli.core.auth.AuthResolver`` exactly as before. + +AUTH-BOUNDARY INVARIANT (see ``scripts/lint-auth-signals.sh``): nothing in this +module names the Azure bearer-provider lookup symbol or issues a ``git +ls-remote`` against ADO. The PAT->AAD bearer protocol stays wholly inside +``core/auth.py``. The only seam back into the boundary is +``self._ado_bearer_provider()`` (defined on ``AuthResolver``), through which +:meth:`_AuthSupportMixin.build_error_context` obtains the provider without +naming the boundary symbol here. + +To avoid an import cycle, this module never imports ``apm_cli.core.auth`` at +module scope -- ``HostInfo`` is imported lazily inside :meth:`classify_host`. 
+""" + +from __future__ import annotations + +import logging +import os +import sys + +from apm_cli.utils.github_host import ( + is_azure_devops_hostname, + is_gitlab_hostname, + is_valid_fqdn, +) + +logger = logging.getLogger(__name__) + + +_PORT_CREDENTIAL_DOCS_URL = ( + "https://microsoft.github.io/apm/getting-started/authentication/" + "#custom-port-hosts-and-per-port-credentials" +) + + +def _org_to_env_suffix(org: str) -> str: + """Convert an org name to an env-var suffix (upper-case, hyphens → underscores).""" + return org.upper().replace("-", "_") + + +class _AuthSupportMixin: + """Boundary-free helpers mixed into :class:`AuthResolver`. + + These methods carry no Azure bearer-provider or ADO ls-remote references, + so they live outside ``auth.py`` to keep that file under the size ceiling. + They are invoked through ``self`` on an ``AuthResolver`` instance and rely + on attributes/methods that ``AuthResolver`` itself defines (``self.resolve``, + ``self._token_manager``, ``self._lock``, ``self._ado_bearer_provider`` …). + """ + + # -- host classification ------------------------------------------------ + + @staticmethod + def classify_host(host: str, port: int | None = None) -> object: + """Return a ``HostInfo`` describing *host*. + + ``port`` is carried through onto the returned ``HostInfo`` so that + downstream code (cache keys, credential-helper input, error text) + can discriminate between the same hostname on different ports. + Host-kind classification itself is transport-agnostic -- the port + never influences whether a host is GitHub/GHES/ADO/generic. + """ + # Lazy import keeps this module free of an auth.py module-scope cycle. 
+ from apm_cli.core.auth import HostInfo + + h = host.lower() + + if h == "github.com": + return HostInfo( + host=host, + kind="github", + has_public_repos=True, + api_base="https://api.github.com", + port=port, + ) + + if h.endswith(".ghe.com"): + return HostInfo( + host=host, + kind="ghe_cloud", + has_public_repos=False, + api_base=f"https://{host}/api/v3", + port=port, + ) + + if is_azure_devops_hostname(host): + return HostInfo( + host=host, + kind="ado", + has_public_repos=True, + api_base="https://dev.azure.com", + port=port, + ) + + # GHES: GITHUB_HOST is set to a non-github.com, non-ghe.com FQDN + ghes_host = os.environ.get("GITHUB_HOST", "").lower() + if ( + ghes_host + and ghes_host == h + and ghes_host not in {"github.com", "gitlab.com"} + and not ghes_host.endswith(".ghe.com") + ): + if is_valid_fqdn(ghes_host): + return HostInfo( + host=host, + kind="ghes", + has_public_repos=True, + api_base=f"https://{host}/api/v3", + port=port, + ) + + # GitLab (SaaS + env-configured self-managed) — after GHES per spec (no silent GHES → GitLab) + if is_gitlab_hostname(host): + if h == "gitlab.com": + api_base = "https://gitlab.com/api/v4" + else: + api_base = f"https://{host}/api/v4" + return HostInfo( + host=host, + kind="gitlab", + has_public_repos=True, + api_base=api_base, + port=port, + ) + + # Generic FQDN (Bitbucket, self-hosted non-GitLab, etc.) + return HostInfo( + host=host, + kind="generic", + has_public_repos=True, + api_base=f"https://{host}/api/v3", + port=port, + ) + + # -- token type detection ----------------------------------------------- + + @staticmethod + def detect_token_type(token: str) -> str: + """Classify a token string by its prefix. + + Note: EMU (Enterprise Managed Users) tokens use standard PAT + prefixes (``ghp_`` or ``github_pat_``). There is no prefix that + identifies a token as EMU-scoped — that's a property of the + account, not the token format. 
+ + Prefix reference (docs.github.com): + - ``github_pat_`` → fine-grained PAT + - ``ghp_`` → classic PAT + - ``ghu_`` → OAuth user-to-server (e.g. ``gh auth login``) + - ``gho_`` → OAuth app token + - ``ghs_`` → GitHub App installation (server-to-server) + - ``ghr_`` → GitHub App refresh token + """ + if token.startswith("github_pat_"): + return "fine-grained" + if token.startswith("ghp_"): + return "classic" + if token.startswith("ghu_"): + return "oauth" + if token.startswith("gho_"): + return "oauth" + if token.startswith("ghs_"): + return "github-app" + if token.startswith("ghr_"): + return "github-app" + return "unknown" + + @staticmethod + def gitlab_rest_headers( + token: str | None, + *, + oauth_bearer: bool = False, + ) -> dict[str, str]: + """Build HTTP headers for GitLab REST API v4 calls. + + Personal access tokens use ``PRIVATE-TOKEN``. OAuth2 access tokens + typically use ``Authorization: Bearer ``; set *oauth_bearer* + to use that style. + + Does not log or print *token*. Callers must not log the returned dict. + """ + if not token: + return {} + if oauth_bearer: + return {"Authorization": f"Bearer {token}"} + return {"PRIVATE-TOKEN": token} + + # -- error context ------------------------------------------------------ + + def build_error_context( + self, + host: str, + operation: str, + org: str | None = None, + *, + port: int | None = None, + dep_url: str | None = None, + bearer_also_failed: bool = False, + ) -> str: + """Build an actionable error message for auth failures. + + ``bearer_also_failed=True`` prepends a single line to the Case 4 + block (PAT set, az available, both attempts failed) clarifying + that ADO_APM_PAT was tried first and rejected before the bearer + attempt -- so the user understands why both halves of the + protocol failed without having to read the full diagnostic + context. Callers MUST only set this when the bearer attempt + actually ran (see :class:`BearerFallbackOutcome.bearer_attempted`). 
+ """ + auth_ctx = self.resolve(host, org, port=port) + host_info = auth_ctx.host_info + display = host_info.display_name + + # --- ADO-specific error cases --- + if host_info.kind == "ado": + # Provider access is routed through the auth-boundary accessor on + # AuthResolver so the bearer-provider symbol stays inside + # core/auth.py (see scripts/lint-auth-signals.sh, Rule A). + provider = self._ado_bearer_provider() + az_available = provider.is_available() + pat_set = bool(os.environ.get("ADO_APM_PAT")) + + org_part = org or "" + if not org_part: + source_url = dep_url or "" + if source_url: + parts = source_url.replace("https://", "").split("/") + if len(parts) >= 2 and ( + parts[0] in ("dev.azure.com",) or parts[0].endswith(".visualstudio.com") + ): + org_part = parts[1] if len(parts) > 1 else "" + + token_url = ( + f"https://dev.azure.com/{org_part}/_usersSettings/tokens" + if org_part + else "https://dev.azure.com//_usersSettings/tokens" + ) + + if pat_set: + if az_available: + # Case 4: PAT and bearer were both available; both attempts + # failed. We may not have observed an explicit 401 (could be + # a 404, a network error, etc.) so the wording stays + # tentative -- see #856 review C6. + prefix = ( + " ADO_APM_PAT was rejected; az cli bearer was also rejected.\n\n" + if bearer_also_failed + else "" + ) + return ( + f"\n{prefix}" + f" ADO_APM_PAT is set, and Azure CLI credentials may also be available,\n" + f" but the Azure DevOps request still failed.\n\n" + f" If this is an authentication failure, the PAT may be expired, revoked,\n" + f" or scoped to a different org, and Azure CLI credentials may need to\n" + f" be refreshed.\n\n" + f" To fix:\n" + f" 1. Unset the PAT to test Azure CLI auth only: unset ADO_APM_PAT\n" + f" 2. Re-authenticate Azure CLI if needed: az login\n" + f" 3. 
Retry: apm install\n\n" + f" Docs: https://microsoft.github.io/apm/getting-started/authentication/#azure-devops" + ) + # PAT set but rejected, no az -> bare PAT failure + return ( + f"\n ADO_APM_PAT is set, but the Azure DevOps request failed.\n" + f" If this is an authentication failure, the token may be expired,\n" + f" revoked, or scoped to a different org.\n\n" + f" Generate a new PAT at {token_url}\n" + f" with Code (Read) scope.\n\n" + f" Docs: https://microsoft.github.io/apm/getting-started/authentication/#azure-devops" + ) + + # No PAT set + if not az_available: + # Case 1: no az, no PAT + return ( + f"\n Azure DevOps requires authentication. You have two options:\n\n" + f" 1. Install Azure CLI and sign in (recommended for Entra ID users):\n" + f" brew install azure-cli # macOS\n" + f" winget install Microsoft.AzureCLI # Windows\n" + f" apt-get install azure-cli # Debian/Ubuntu\n" + f" dnf install azure-cli # Fedora/RHEL\n" + f" (full guide: https://aka.ms/InstallAzureCli)\n" + f" az login\n" + f" apm install # retry -- no env var needed\n\n" + f" 2. Use a Personal Access Token:\n" + f" export ADO_APM_PAT=your_token\n" + f" (Create one at {token_url} with Code (Read) scope.)\n\n" + f" Docs: https://microsoft.github.io/apm/getting-started/authentication/#azure-devops" + ) + + # az is available; check if logged in by trying to get tenant + tenant = provider.get_current_tenant_id() + if tenant is None: + # Case 3: az present, not logged in + return ( + "\n Azure DevOps requires authentication. You have two options:\n\n" + " 1. Sign in with Azure CLI (recommended for Entra ID users):\n" + " az login\n" + " apm install # retry -- no env var needed\n\n" + " 2. Use a Personal Access Token:\n" + " export ADO_APM_PAT=your_token\n\n" + " Docs: https://microsoft.github.io/apm/getting-started/authentication/#azure-devops" + ) + + # Case 2: az returned token (tenant known) but ADO rejected it. 
+ # Note: bearer_also_failed=True is structurally unreachable here -- + # callers only set it when source == "ADO_APM_PAT" (i.e. pat_set + # is True), and Case 2 lives in the `not pat_set` branch. We do + # not render a "PAT was also rejected" prefix in this case + # because no PAT was tried. + return ( + f"\n Your az cli session (tenant: {tenant}) returned a bearer token,\n" + f" but Azure DevOps rejected it (HTTP 401).\n\n" + f" Check that you are signed into the correct tenant:\n" + f" az account show\n" + f" az login --tenant \n\n" + f" Docs: https://microsoft.github.io/apm/getting-started/authentication/#azure-devops" + ) + + # --- Non-ADO error paths --- + lines: list[str] = [f"Authentication failed for {operation} on {display}."] + + if auth_ctx.token: + lines.append( + f"Token was provided (source: {auth_ctx.source}, type: {auth_ctx.token_type})." + ) + if host_info.kind == "ghe_cloud": + lines.append( + "GHE Cloud Data Residency hosts (*.ghe.com) require " + "enterprise-scoped tokens. Ensure your PAT is authorized " + "for this enterprise." + ) + elif host_info.kind == "gitlab": + lines.append( + "Ensure your GitLab personal or project access token meets the " + "API read requirements for your instance policy." + ) + elif host.lower() == "github.com": + lines.append( + "If your organization uses SAML SSO or is an EMU org, " + "ensure your PAT is authorized at " + "https://github.com/settings/tokens" + ) + elif host_info.kind == "generic": + lines.append("Verify credentials for this host in your git credential helper.") + else: + lines.append( + "If your organization uses SAML SSO, you may need to " + "authorize your token at https://github.com/settings/tokens" + ) + else: + lines.append("No token available.") + if host_info.kind == "gitlab": + lines.append( + "Set GITLAB_APM_PAT or GITLAB_TOKEN, or configure git credential fill " + f"for {display}." 
+ ) + elif host_info.kind == "generic": + lines.append( + "APM does not apply GitHub PAT environment variables to generic git " + f"hosts; configure git credential fill for {display} or use a " + "public repository if available." + ) + else: + lines.append("Set GITHUB_APM_PAT or GITHUB_TOKEN, or run 'gh auth login'.") + + if org and host_info.kind not in ("ado", "gitlab", "generic"): + lines.append( + f"If packages span multiple organizations, set per-org tokens: " + f"GITHUB_APM_PAT_{_org_to_env_suffix(org)}" + ) + + # When a custom port is in play, helpers that key by hostname alone + # (some `gh` integrations, older keychain backends) can silently + # return the wrong credential. Point the user at the concrete fix. + if host_info.port is not None: + lines.append( + f"[i] Host '{display}' -- this helper may key by host only.\n" + f" Verify with: printf 'protocol=https\\nhost={display}\\n\\n'" + f" | git credential fill\n" + f" Docs: {_PORT_CREDENTIAL_DOCS_URL}" + ) + + lines.append("Run with --verbose for detailed auth diagnostics.") + return "\n".join(lines) + + # -- internals ---------------------------------------------------------- + + @staticmethod + def _purpose_for_host(host_info) -> str: + if host_info.kind == "ado": + return "ado_modules" + if host_info.kind == "gitlab": + return "gitlab_modules" + if host_info.kind == "generic": + return "generic_modules" + return "modules" + + def _identify_env_source(self, purpose: str) -> str: + """Return the name of the first env var that matched for *purpose*.""" + for var in self._token_manager.TOKEN_PRECEDENCE.get(purpose, []): + if os.environ.get(var): + return var + return "env" + + @staticmethod + def _build_git_env( + token: str | None = None, + *, + scheme: str = "basic", + host_kind: str = "github", + ) -> dict: + """Pre-built env dict for subprocess git calls. 
+ + For ADO bearer tokens (scheme='bearer'), injects an Authorization header + via GIT_CONFIG_COUNT/KEY/VALUE env vars (see github_host.build_ado_bearer_git_env). + For all other cases, behavior is unchanged. + """ + env = os.environ.copy() + env["GIT_TERMINAL_PROMPT"] = "0" + env["GIT_ASKPASS"] = "echo" + if scheme == "bearer" and token and host_kind == "ado": + # B2 #852: skip GIT_TOKEN for bearer scheme -- the JWT is injected via + # GIT_CONFIG_VALUE_0 only; GIT_TOKEN here would leak it into every + # child-process env (visible in /proc//environ, ps eww). + # + # #1214 follow-up: a stale GIT_TOKEN already in the parent env + # (set by a prior shell, CI step, or another tool) would survive + # the os.environ.copy() above and defeat the isolation guarantee. + # Drop it explicitly so the bearer env is clean by construction. + env.pop("GIT_TOKEN", None) + from apm_cli.utils.github_host import build_ado_bearer_git_env + + env.update(build_ado_bearer_git_env(token)) + elif token: + env["GIT_TOKEN"] = token + return env + + def emit_stale_pat_diagnostic(self, host_display: str) -> None: + """Emit a [!] warning when PAT was rejected but bearer succeeded. + + F3 #852: when an InstallLogger is wired via :meth:`set_logger`, the + warning is collected by its DiagnosticCollector so it appears in the + install summary. Without a logger (e.g. unit tests) we fall back to + the inline ``_rich_warning`` emission for backwards compatibility. + + #1212 follow-up: dedup per host_display so the user sees ONE warning + per ADO host even when preflight, list_remote_refs, and the clone + path each trigger the bearer-fallback path against the same host. + + Naming: previously ``_emit_stale_pat_diagnostic`` (private). Public + now (#856 follow-up C9) so external modules (validation.py, + github_downloader.py) do not reach into the underscore API. 
+ + #1214 follow-up: guard the check-then-add under self._lock so two + threads (parallel install) racing on the same ADO host cannot both + pass the membership check before either calls add(); without the + lock the dedup set defeats its own purpose. + """ + with self._lock: + if host_display in self._stale_pat_warned_hosts: + return + self._stale_pat_warned_hosts.add(host_display) + msg = f"ADO_APM_PAT was rejected for {host_display}; fell back to az cli bearer." + detail = "Consider unsetting the stale variable." + diagnostics = self._diagnostics_or_none() + if diagnostics is not None: + diagnostics.warn(msg, detail=detail) + return + try: + from apm_cli.utils.console import _rich_warning + + _rich_warning(msg, symbol="warning") + _rich_warning(f" {detail}", symbol="warning") + except ImportError as exc: + logger.debug("Console module unavailable for stale-PAT warning; skipping: %s", exc) + + # Backwards-compat alias for any in-tree caller still importing the + # private name. Safe to remove once all callers move to the public name. + _emit_stale_pat_diagnostic = emit_stale_pat_diagnostic + + def _diagnostics_or_none(self): + """Return the wired logger's DiagnosticCollector, or None.""" + if self._logger is None: + return None + try: + return self._logger.diagnostics + except AttributeError: + return None + + def notify_auth_source(self, host_display: str, ctx) -> None: + """Emit the verbose auth-source line for ``host_display`` exactly once. + + F2 #852: routes through CommandLogger when wired (so the line obeys + the same verbose channel as every other diagnostic), and falls back + to a direct stderr write when no logger is set so the existing + bearer e2e tests keep working. 
+ """ + host_key = (host_display or "").lower() + if not host_key or host_key in self._verbose_auth_logged_hosts: + return + self._verbose_auth_logged_hosts.add(host_key) + if ctx is None or getattr(ctx, "source", "none") == "none": + return + if getattr(ctx, "auth_scheme", None) == "bearer": + line = f" [i] {host_key} -- using bearer from az cli (source: {ctx.source})" + else: + line = f" [i] {host_key} -- token from {ctx.source}" + if self._logger is not None and getattr(self._logger, "verbose", False): + try: + from apm_cli.utils.console import _rich_echo + + _rich_echo(line, color="dim") + return + except ImportError as exc: + logger.debug( + "Console module unavailable for auth-source logging; skipping: %s", exc + ) + # No logger wired -- the install path always wires one in the + # bearer branch, so this fallback only fires in unit-test contexts + # that opt-in via APM_VERBOSE=1. + sys.stderr.write(line + "\n") + + +__all__ = ["_PORT_CREDENTIAL_DOCS_URL", "_AuthSupportMixin", "_org_to_env_suffix"] diff --git a/src/apm_cli/core/_install_logger.py b/src/apm_cli/core/_install_logger.py new file mode 100644 index 000000000..9682b3fdf --- /dev/null +++ b/src/apm_cli/core/_install_logger.py @@ -0,0 +1,599 @@ +"""InstallLogger — install-specific phased logging for the APM CLI. + +Extracted from command_logger.py (Strangler Stage 2, #1078). +Re-exported from apm_cli.core.command_logger as ``InstallLogger``. + +Rule B: All output is routed through ``CommandLogger`` base-class methods +(``self.info``, ``self.warning``, ``self.error``, ``self.success``, +``self.verbose_detail``, ``self.tree_item``, ``self.dim_check_item``, +``self.dim``) so that test patches on ``apm_cli.core.command_logger._rich_*`` +are correctly intercepted. +""" + +from apm_cli.core.command_logger import CommandLogger, _strip_source_prefix, _ValidationOutcome + + +class InstallLogger(CommandLogger): + """Install-specific logger with validation, resolution, and download phases. 
+ + Knows whether this is a partial install (specific packages requested) or + full install (all deps from apm.yml). Adjusts messages accordingly. + """ + + def __init__(self, verbose: bool = False, dry_run: bool = False, partial: bool = False): + super().__init__("install", verbose=verbose, dry_run=dry_run) + self.partial = partial # True when specific packages are passed to `apm install` + self._stale_cleaned_total = 0 # Accumulated by stale_cleanup / orphan_cleanup + + # --- Validation phase --- + + def validation_start(self, count: int): + """Log start of package validation.""" + noun = "package" if count == 1 else "packages" + self.info(f"Validating {count} {noun}...", symbol="gear") + + def validation_pass(self, canonical: str, already_present: bool, updated: bool = False): + """Log a package that passed validation.""" + if updated: + self.dim_check_item(f"{canonical} (updated ref in apm.yml)") + elif already_present: + self.dim_check_item(f"{canonical} (already in apm.yml)") + else: + self.success(canonical, symbol="check") + + def validation_fail(self, package: str, reason: str): + """Log a package that failed validation.""" + self.error(f"{package} -- {reason}") + + def validation_summary(self, outcome: _ValidationOutcome): + """Log validation summary and decide whether to continue. + + Returns True if install should continue, False if all packages failed. + """ + if outcome.all_failed: + self.error("All packages failed validation. 
Nothing to install.") + return False + + if outcome.has_failures: + failed_count = len(outcome.invalid) + noun = "package" if failed_count == 1 else "packages" + self.warning(f"{failed_count} {noun} failed validation and will be skipped.") + + return True + + # --- Resolution phase --- + + def resolution_start(self, to_install_count: int, lockfile_count: int): + """Log start of dependency resolution.""" + if self.partial: + noun = "package" if to_install_count == 1 else "packages" + self.start(f"Installing {to_install_count} new {noun}...") + if lockfile_count > 0 and self.verbose: + self.verbose_detail(f" ({lockfile_count} existing dependencies in lockfile)") + else: + self.start("Installing dependencies from apm.yml...") + if lockfile_count > 0: + self.info(f"Using apm.lock.yaml ({lockfile_count} locked dependencies)") + + def nothing_to_install( + self, + lockfile_present: bool = False, + update_mode: bool = False, + ): + """Log when there's nothing to install -- context-aware message. + + Args: + lockfile_present: True when apm.lock.yaml exists on disk at + the time of the no-op. When True (and we're not in + update mode) we append the standard hint pointing at + ``apm update`` -- this is the #1203 nudge that keeps + users from believing ``apm install`` checks for newer + versions. + update_mode: True when this run was invoked with + ``--update`` or via ``apm update``. Suppresses the + hint -- the user already asked to refresh. 
+ """ + if self.partial: + self.info("Requested packages are already installed.", symbol="check") + else: + self.success("All dependencies are up to date.", symbol="check") + if lockfile_present and not update_mode: + self.info("Lockfile already satisfied -- run 'apm update' to resolve latest refs.") + + # --- Download phase --- + + def download_start(self, dep_name: str, cached: bool): + """Log start of a package download.""" + if cached: + self.verbose_detail(f" Using cached: {dep_name}") + elif self.verbose: + self.info(f" Downloading: {dep_name}", symbol="download") + + def resolving_heartbeat(self, dep_name: str): + """Emit a per-dependency progress heartbeat during BFS resolve. + + Surfaces an immediate ``[>] Resolving ...`` line so the + user sees the install moving forward instead of staring at + silence while transitive lookups happen behind the scenes + (F1, microsoft/apm#1116). The line is static (not a Rich + transient progress bar) so it survives in CI logs and behind + ``2>&1 | tee`` pipelines, which the duck critique flagged as + the must-survive surface. + + Called from the MAIN thread by the resolver/download callback + BEFORE network work begins; F7's parallel BFS keeps emission + on the main thread so output ordering is deterministic even + when downloads are dispatched to a worker pool. + """ + self.start(f"Resolving {dep_name}...") + + def download_complete( + self, + dep_name: str, + ref: str = "", + sha: str = "", + cached: bool = False, + # Legacy compat: if callers pass ref_suffix= we handle it + ref_suffix: str = "", + ): + """Log completion of a package download. + + Args: + dep_name: Package display name (repo_url or virtual path). + ref: Git reference (tag name, branch) if any. + sha: Short commit SHA (8 chars) if any. + cached: Whether this was a cache hit. + ref_suffix: DEPRECATED — legacy callers still pass this. 
+ """ + msg = f" [+] {dep_name}" + if ref_suffix: + # Legacy path — pass-through until all callers are migrated + msg += f" ({ref_suffix})" + else: + if ref and sha: + msg += f" #{ref} @{sha}" + elif ref: + msg += f" #{ref}" + elif sha: + msg += f" @{sha}" + if cached: + msg += " (cached)" + self.tree_item(msg) + + def download_failed(self, dep_name: str, error: str): + """Log a download failure.""" + self.error(f" [x] {dep_name} -- {error}") + + # --- Verbose sub-item methods (install-specific) --- + + def lockfile_entry(self, key: str, ref: str = "", sha: str = ""): + """Log a lockfile entry in verbose mode. + + Omits the line entirely for unpinned deps (no ref, no sha). + """ + if not self.verbose: + return + if sha: + self.verbose_detail(f" {key}: locked at {sha}") + elif ref: + self.verbose_detail(f" {key}: pinned to {ref}") + # Unpinned → omit entirely (nothing useful to show) + + def package_auth(self, source: str, token_type: str = ""): + """Log auth source for a package (verbose only). 4-space indent.""" + if not self.verbose: + return + type_str = f" ({token_type})" if token_type else "" + self.verbose_detail(f" Auth: {source}{type_str}") + + def package_type_info(self, type_label: str): + """Log detected package type (verbose only). 4-space indent.""" + if not self.verbose: + return + self.verbose_detail(f" Package type: {type_label}") + + # --- Performance diagnostics (perf #1433) --- + + def subdir_download_start( + self, + dep_name: str, + cache_state: str, + sha_short: str = "", + sparse_paths: list[str] | None = None, + ): + """Log the start of a subdirectory dep download (verbose only). + + Names the dep, the bare-cache state (e.g. ``cold`` / ``warm`` / + ``persistent`` / ``shared-bare``), the resolved SHA (short), + and the sparse paths being requested. Surfaces enough state to + diagnose a perf regression from one log line. 
+ """ + if not self.verbose: + return + sha_part = f" @{sha_short}" if sha_short else "" + paths_part = f" sparse={','.join(sparse_paths)}" if sparse_paths else " sparse=" + self.verbose_detail( + f" [i] perf: subdir {dep_name}{sha_part} cache={cache_state}{paths_part}" + ) + + def bare_clone_strategy(self, strategy: str, elapsed_ms: int): + """Log the bare-clone strategy and wall time (verbose only). + + ``strategy`` is the human-readable command shape, e.g. + ``--depth=1 --branch main`` or ``init+fetch --depth=1 ``. + ``elapsed_ms`` lets readers spot a network-bound regression + without re-running with a profiler. + """ + if not self.verbose: + return + self.verbose_detail(f" [i] perf: bare clone strategy={strategy} took={elapsed_ms}ms") + + def materialize_result(self, sparse_applied: bool, consumer_size_bytes: int): + """Log materialization outcome and consumer dir size (verbose only). + + ``sparse_applied`` tells the reader whether sparse-cone fired + on this consumer dir (sparse_paths were passed and accepted by + git). ``consumer_size_bytes`` is the on-disk size of the + working tree handed off to the integrator; a regression here + is the leading indicator that sparse-cone silently fell back. + """ + if not self.verbose: + return + size_mb = consumer_size_bytes / (1024 * 1024) + applied = "yes" if sparse_applied else "no" + self.verbose_detail(f" [i] perf: materialize sparse={applied} size={size_mb:.2f} MB") + + def tier_summary(self, stats: dict[str, int]): + """Log the tiered ref resolver hit counts (verbose only). + + Emitted at the end of the resolve phase so the reader can see + how many ref->SHA lookups hit each tier (L0 per-run cache, + L1 commits API, L2 bare rev-parse, L3 legacy clone) without + wiring a debugger. A run dominated by L3 is the canonical + signal that ref-resolution is paying full clone cost. 
+ """ + if not self.verbose or not stats: + return + non_zero = {k: v for k, v in stats.items() if v} + if not non_zero: + return + parts = " ".join(f"{k}={v}" for k, v in non_zero.items()) + self.verbose_detail(f" [i] perf: ref-resolver tiers: {parts}") + + # --- Cleanup phase (stale and orphan file removal) --- + + def stale_cleanup(self, dep_key: str, count: int): + """Log per-package stale-file cleanup outcome at default verbosity. + + Stale-file deletion is a destructive operation in the user's + tracked workspace (unlike npm's ``node_modules``); it must be + visible without ``--verbose``. Rendered as an info line so it + groups visually with other phase messages, not as a tree item + (the originating package line was emitted earlier in the install + sequence and is no longer adjacent). + """ + if count <= 0: + return + self._stale_cleaned_total += count + noun = "file" if count == 1 else "files" + self.info(f"Cleaned {count} stale {noun} from {dep_key}") + + def orphan_cleanup(self, count: int): + """Log post-install orphan-file cleanup outcome at default verbosity. + + Same visibility rationale as :meth:`stale_cleanup`: file deletion + in the user's workspace must be visible by default. + """ + if count <= 0: + return + self._stale_cleaned_total += count + noun = "file" if count == 1 else "files" + self.info(f"Cleaned {count} {noun} from packages no longer in apm.yml") + + @property + def stale_cleaned_total(self) -> int: + """Total files removed by stale + orphan cleanup during this install.""" + return self._stale_cleaned_total + + def cleanup_skipped_user_edit(self, rel_path: str, dep_key: str): + """Log a stale-file deletion that was skipped because the user + edited the file after APM deployed it. + + Yellow inline at default verbosity -- the user needs to know APM + kept the file and a manual decision is pending. 
+ """ + self.warning( + f" Kept user-edited file {rel_path} (from {dep_key}); " + "delete manually if no longer needed" + ) + + # --- Policy phase --- + + def policy_resolved( + self, + source: str, + cached: bool, + enforcement: str, + age_seconds: int | None = None, + ): + """Log policy discovery outcome. + + Verbose by default; always shown when ``enforcement == "block"`` + (users must know blocking is active). + + Format: ``[i] Policy: (cached, fetched 5m ago) -- enforcement=block`` + """ + parts = [f"Policy: {source}"] + + if cached: + cache_detail = "cached" + if age_seconds is not None: + if age_seconds < 60: + cache_detail += f", fetched {age_seconds}s ago" + else: + minutes = age_seconds // 60 + unit = "m" if minutes < 60 else "h" + value = minutes if minutes < 60 else minutes // 60 + cache_detail += f", fetched {value}{unit} ago" + parts.append(f"({cache_detail})") + parts.append(f"-- enforcement={enforcement}") + + message = " ".join(parts) + + if enforcement == "block": + # Always visible — blocking installs is a big deal + self.warning(message) + elif self.verbose: + self.info(message) + # Non-verbose + non-block: silent (no noise for warn/off) + + def policy_discovery_miss( + self, + outcome: str, + source: str = "", + error: str | None = None, + host_org: str | None = None, + ): + """Log a policy-discovery non-success outcome. + + Single canonical helper that routes all 7 non-found / non-disabled + outcomes through one wording table. Replaces the per-call-site + ``_rich_info`` / ``_rich_warning`` invocations in ``policy_gate`` + and ``install_preflight`` (Logging C1 / C2, UX F1 / F2 / F4 / F5). + + Args: + outcome: One of ``"absent"``, ``"no_git_remote"``, ``"empty"``, + ``"malformed"``, ``"cache_miss_fetch_fail"``, + ``"garbage_response"``, ``"cached_stale"``. + source: Policy source string (e.g. ``"org:acme/.github"``). + error: Optional error string (used for malformed, + cache_miss_fetch_fail, garbage_response, cached_stale). 
+ host_org: Optional org slug for ``absent`` outcome (verbose + hint). Auto-derived from ``source`` when not provided. + """ + err_text = error or "unknown" + + # Merge the two verbose-only early-exit outcomes to stay within the + # PLR0911 return-statement budget (absent + no_git_remote share the + # same "silent when not verbose" guard). + if outcome in ("absent", "no_git_remote"): + if not self.verbose: + return + if outcome == "absent": + org = host_org or _strip_source_prefix(source) or "this project" + self.info(f"No org policy found for {org}") + else: + # UX F2: normal state for fresh `git init`, unpacked bundles, etc. + self.info("Could not determine org from git remote; policy auto-discovery skipped") + return + + if outcome == "empty": + src = source or "this project" + self.warning(f"Org policy at {src} is present but empty; no enforcement applied") + return + + if outcome == "malformed": + self.warning( + f"Policy at {source} is malformed: {err_text}. " + "Contact your org admin to fix the policy file." + ) + return + + if outcome == "cache_miss_fetch_fail": + # UX F5: explicit posture -- enforcement skipped. + self.warning( + f"Could not fetch org policy from {source} ({err_text}); " + "proceeding without policy enforcement. " + "Retry, check connectivity, or use --no-policy to bypass." + ) + return + + if outcome == "garbage_response": + # UX F4: server IS reachable; "check VPN/firewall" is wrong advice. + self.warning( + f"Policy response from {source} is not valid YAML " + f"({err_text}); proceeding without policy enforcement. " + "Contact your org admin or use --no-policy." + ) + return + + if outcome == "cached_stale": + # UX F5: explicit posture -- enforcement still applies. + self.warning( + f"Using stale cached policy (refresh failed: {err_text}); " + "enforcement still applies from cached policy." + ) + return + + if outcome == "hash_mismatch": + # #827: always-error posture -- pinned policy.hash does not match. 
+ self.error( + f"Policy hash mismatch: pinned hash does not match fetched " + f"policy ({err_text}). Update apm.yml policy.hash or " + "contact your org admin." + ) + return + + # Defensive: unknown outcome -- emit a conservative warning + if error: + self.warning(f"Policy discovery issue: {err_text}") + + def policy_violation( + self, + dep_ref: str, + reason: str, + severity: str, + source: str | None = None, + ): + """Record a policy violation for a dependency. + + Pushes to ``DiagnosticCollector`` under ``CATEGORY_POLICY`` for + the end-of-install summary. When ``severity == "block"``, also + prints an inline error so the user sees the failure immediately + (before the summary), followed by a dim secondary line with the + actionable next-step (CLI logging C3). + + Args: + dep_ref: Dependency reference (e.g. ``"acme/evil-pkg"``). + reason: Actionable reason text per rubber-duck I9. + severity: ``"block"`` or ``"warn"``. + source: Optional policy source (used for block-mode next-step + hint). When provided, a dim secondary line with + remediation guidance is rendered under the inline error. + """ + + # F9 dedupe: some callers pass reason with a "{dep_ref}: " prefix + # (the detail strings produced by policy_checks.py do this). + # Strip it defensively so the inline error reads cleanly. + prefix = f"{dep_ref}: " + if reason.startswith(prefix): + reason = reason[len(prefix) :] + + self.diagnostics.policy( + message=reason, + package=dep_ref, + severity=severity, + ) + + if severity == "block": + self.error(f"Policy violation: {dep_ref} -- {reason}") + if source: + self.dim(f" {self._policy_reason_blocked(dep_ref, source)}") + + def policy_disabled(self, reason: str): + """Log a loud warning that policy enforcement is disabled. + + Emitted when ``--no-policy`` or ``APM_POLICY_DISABLE=1`` is + active. Always visible (never silenceable) -- matches the + ``--allow-insecure`` pattern. 
+ """ + self.warning( + f"Policy enforcement disabled by {reason} for this invocation. " + "This does NOT bypass apm audit --ci. " + "CI will still fail the PR for the same policy violation." + ) + + # --- Policy violation reason helpers --- + + @staticmethod + def _policy_reason_auth(source: str) -> str: + """Actionable reason for auth failure during policy fetch.""" + return ( + f"Could not authenticate to fetch policy from {source} " + "-- check `gh auth status` and `GITHUB_APM_PAT`" + ) + + @staticmethod + def _policy_reason_unreachable(source: str) -> str: + """Actionable reason for unreachable policy source.""" + return ( + f"Policy source {source} is unreachable " + "-- retry, check VPN/firewall, or use `--no-policy` to bypass" + ) + + @staticmethod + def _policy_reason_malformed(source: str) -> str: + """Actionable reason for malformed policy file.""" + return f"Policy at {source} is malformed -- contact your org admin to fix the policy file" + + @staticmethod + def _policy_reason_blocked(dep_ref: str, source: str) -> str: + """Actionable reason for a blocked dependency.""" + return ( + f"Blocked by org policy at {source} " + f"-- remove `{dep_ref}` from apm.yml, contact admin to update policy, " + "or use `--no-policy` for one-off bypass" + ) + + # --- Install summary --- + + def install_summary( + self, + apm_count: int, + mcp_count: int, + lsp_count: int = 0, + errors: int = 0, + stale_cleaned: int = 0, + elapsed_seconds: float | None = None, + ): + """Log final install summary. + + Args: + apm_count: Number of APM dependencies installed. + mcp_count: Number of MCP servers installed. + lsp_count: Number of LSP servers installed. + errors: Number of errors collected during install. + stale_cleaned: Total stale + orphan files removed during + this install. Reported as a parenthetical so existing + callers and assertion patterns continue to work. + elapsed_seconds: Wall-clock duration of the install command. 
+ When provided, appended as `` in {x:.1f}s`` before the + terminating period so the user can see how long the + whole command took (F5, microsoft/apm#1116). + """ + parts = [] + if apm_count > 0: + noun = "dependency" if apm_count == 1 else "dependencies" + parts.append(f"{apm_count} APM {noun}") + if mcp_count > 0: + noun = "server" if mcp_count == 1 else "servers" + parts.append(f"{mcp_count} MCP {noun}") + if lsp_count > 0: + noun = "server" if lsp_count == 1 else "servers" + parts.append(f"{lsp_count} LSP {noun}") + + cleanup_suffix = "" + if stale_cleaned > 0: + file_noun = "file" if stale_cleaned == 1 else "files" + cleanup_suffix = f" ({stale_cleaned} stale {file_noun} cleaned)" + + timing_suffix = "" + if elapsed_seconds is not None: + timing_suffix = f" in {elapsed_seconds:.1f}s" + + if parts: + summary = " and ".join(parts) + if errors > 0: + self.warning( + f"Installed {summary}{cleanup_suffix}{timing_suffix} with {errors} error(s)." + ) + else: + self.success( + f"Installed {summary}{cleanup_suffix}{timing_suffix}.", + symbol="sparkles", + ) + elif errors > 0: + self.error(f"Installation failed with {errors} error(s){timing_suffix}.") + else: + self.info(f"No changes -- install state already up to date{timing_suffix}.") + + def install_interrupted(self, elapsed_seconds: float): + """Log a minimal elapsed-time line when the normal summary did + not render (errors, KeyboardInterrupt, click.UsageError). + + Emitted from the outer ``finally`` in ``commands.install.install`` + so users always see how long the failed/interrupted command ran + (F5, microsoft/apm#1116). Best-effort: callers swallow any + exception so a render failure cannot mask the original error. 
+ """ + self.warning(f"Install interrupted after {elapsed_seconds:.1f}s.") diff --git a/src/apm_cli/core/_prompt_compiler.py b/src/apm_cli/core/_prompt_compiler.py new file mode 100644 index 000000000..3975939ba --- /dev/null +++ b/src/apm_cli/core/_prompt_compiler.py @@ -0,0 +1,183 @@ +"""PromptCompiler — compile .prompt.md files with parameter substitution. + +Extracted from script_runner.py (Strangler Stage 2, #1078). +Re-exported from apm_cli.core.script_runner as ``PromptCompiler``. +""" + +from pathlib import Path + + +class PromptCompiler: + """Compiles .prompt.md files with parameter substitution.""" + + DEFAULT_COMPILED_DIR = Path(".apm/compiled") + + def __init__(self): + """Initialize compiler.""" + self.compiled_dir = self.DEFAULT_COMPILED_DIR + + def compile(self, prompt_file: str, params: dict[str, str]) -> str: + """Compile a .prompt.md file with parameter substitution. + + Args: + prompt_file: Path to the .prompt.md file + params: Parameters to substitute + + Returns: + Path to the compiled file + """ + # Resolve the prompt file path - check local first, then dependencies + prompt_path = self._resolve_prompt_file(prompt_file) + + # Now ensure compiled directory exists + self.compiled_dir.mkdir(parents=True, exist_ok=True) + + with open(prompt_path, encoding="utf-8") as f: + content = f.read() + + # Parse frontmatter and content + if content.startswith("---"): + # Split frontmatter and content + parts = content.split("---", 2) + if len(parts) >= 3: + frontmatter = parts[1].strip() # noqa: F841 + main_content = parts[2].strip() + else: + main_content = content + else: + main_content = content + + # Substitute parameters in content + compiled_content = self._substitute_parameters(main_content, params) + + # Generate output file path + output_name = prompt_path.stem.replace(".prompt", "") + ".txt" + output_path = self.compiled_dir / output_name + + # Write compiled content + with open(output_path, "w", encoding="utf-8") as f: + f.write(compiled_content) 
+ + return str(output_path) + + def _resolve_prompt_file(self, prompt_file: str) -> Path: + """Resolve prompt file path, checking local directory first, then common directories, then dependencies. + + Symlinks are rejected outright to prevent traversal attacks. + + Args: + prompt_file: Relative path to the .prompt.md file + + Returns: + Path: Resolved path to the prompt file + + Raises: + FileNotFoundError: If prompt file is not found or is a symlink + """ + prompt_path = Path(prompt_file) + + # First check if it exists in current directory (local) + if prompt_path.exists(): + if prompt_path.is_symlink(): + raise FileNotFoundError( + f"Prompt file '{prompt_file}' is a symlink. " + f"Symlinks are not allowed for security reasons." + ) + return prompt_path + + # Check in common project directories + common_dirs = [".github/prompts", ".apm/prompts"] + for common_dir in common_dirs: + common_path = Path(common_dir) / prompt_file + if common_path.exists() and not common_path.is_symlink(): + return common_path + + # Search dependencies — scan directory tree once to avoid double walk + apm_modules_dir = Path("apm_modules") + dep_dirs = self._collect_dependency_dirs(apm_modules_dir) + + for _org_name, _repo_name, repo_dir in dep_dirs: + dep_prompt_path = repo_dir / prompt_file + if dep_prompt_path.exists() and not dep_prompt_path.is_symlink(): + return dep_prompt_path + + for subdir in ["prompts", ".", "workflows"]: + sub_prompt_path = repo_dir / subdir / prompt_file + if sub_prompt_path.exists() and not sub_prompt_path.is_symlink(): + return sub_prompt_path + + # Build error using already-collected directories (no second walk) + self._raise_prompt_not_found(prompt_file, prompt_path, dep_dirs) + + def _collect_dependency_dirs(self, apm_modules_dir: Path) -> list: + """Collect (org_name, repo_name, repo_dir) tuples from apm_modules. + + Walks the two-level directory tree once so callers can iterate + without repeated filesystem scans. 
+ + Args: + apm_modules_dir: Path to the apm_modules directory + + Returns: + List of (org_name, repo_name, repo_dir) tuples + """ + if not apm_modules_dir.exists(): + return [] + result = [] + for org_dir in apm_modules_dir.iterdir(): + if org_dir.is_dir() and not org_dir.name.startswith("."): + for repo_dir in org_dir.iterdir(): + if repo_dir.is_dir() and not repo_dir.name.startswith("."): + result.append((org_dir.name, repo_dir.name, repo_dir)) + return result + + def _raise_prompt_not_found( + self, + prompt_file: str, + prompt_path: Path, + dep_dirs: list, + ) -> None: + """Build and raise a helpful FileNotFoundError for a missing prompt. + + Args: + prompt_file: Original prompt file reference + prompt_path: Local Path that was checked + dep_dirs: Pre-collected dependency directory tuples + + Raises: + FileNotFoundError: Always — with a message listing searched locations + """ + searched_locations = [ + f"Local: {prompt_path}", + f"GitHub prompts: .github/prompts/{prompt_file}", + f"APM prompts: .apm/prompts/{prompt_file}", + ] + + if dep_dirs: + searched_locations.append("Dependencies:") + for org_name, repo_name, _repo_dir in dep_dirs: + searched_locations.append(f" - {org_name}/{repo_name}/{prompt_file}") + + raise FileNotFoundError( + f"Prompt file '{prompt_file}' not found.\n" + f"Searched in:\n" + + "\n".join(searched_locations) + + f"\n\nTip: Run 'apm install' to ensure dependencies are installed." # noqa: F541 + ) + + def _substitute_parameters(self, content: str, params: dict[str, str]) -> str: + """Substitute parameters in content. 
+ + Args: + content: Content to process + params: Parameters to substitute + + Returns: + Content with parameters substituted + """ + result = content + for key, value in params.items(): + # Replace ${input:key} placeholders + placeholder = f"${{input:{key}}}" + result = result.replace(placeholder, str(value)) + return result diff --git a/src/apm_cli/core/_runtime_commands.py b/src/apm_cli/core/_runtime_commands.py new file mode 100644 index 000000000..7e64d4097 --- /dev/null +++ b/src/apm_cli/core/_runtime_commands.py @@ -0,0 +1,302 @@ +"""_RuntimeCommandsMixin — runtime-command builder methods for ScriptRunner. + +Extracted from script_runner.py (Strangler Stage 2, #1078). +Composed into ScriptRunner via ``class ScriptRunner(_RuntimeCommandsMixin)``. + +Rule B: ``_detect_installed_runtime`` references ``find_runtime_binary`` which +tests patch at ``apm_cli.core.script_runner.find_runtime_binary``. The method +uses a function-level late import to route through the origin module so patches +are intercepted correctly. +""" + +import re +from pathlib import Path + + +class _RuntimeCommandsMixin: + """Mixin carrying the runtime-command builder cluster for ScriptRunner.""" + + # ------------------------------------------------------------------ + # Command transformation helpers + # ------------------------------------------------------------------ + + def _transform_runtime_command( + self, command: str, prompt_file: str, compiled_content: str, compiled_path: str + ) -> str: + """Transform runtime commands to their proper execution format. + + Dispatches to per-runtime builders after extracting arguments + around the prompt file reference. 
+ + Args: + command: Original command + prompt_file: Original .prompt.md file path + compiled_content: Compiled prompt content as string + compiled_path: Path to compiled .txt file + + Returns: + Transformed command for proper runtime execution + """ + # Handle environment variables prefix (e.g., "ENV1=val1 ENV2=val2 codex [args] file.prompt.md") + # More robust approach: split by runtime commands to separate env vars from command + runtime_commands = ["codex", "copilot", "llm", "gemini"] + + # Try matching with env-var prefix (e.g. "ENV=val codex args file.prompt.md") + for runtime_cmd in runtime_commands: + runtime_pattern = f" {runtime_cmd} " + if runtime_pattern in command and re.search(re.escape(prompt_file), command): + parts = command.split(runtime_pattern, 1) + potential_env_part = parts[0] + runtime_part = runtime_cmd + " " + parts[1] + + if "=" in potential_env_part and not potential_env_part.startswith(runtime_cmd): + result = self._parse_and_build_runtime_command( + runtime_cmd, + runtime_part, + prompt_file, + env_prefix=potential_env_part, + ) + if result is not None: + return result + + # Try individual runtime patterns without environment variables + for runtime_cmd in runtime_commands: + if re.search(r"^" + runtime_cmd + r"\s+.*" + re.escape(prompt_file), command): + result = self._parse_and_build_runtime_command( + runtime_cmd, + command, + prompt_file, + ) + if result is not None: + return result + + # Handle bare "file.prompt.md" -> "codex exec" (default to codex) + if command.strip() == prompt_file: + return "codex exec" + + # Fallback: just replace file path with compiled path (for non-runtime commands) + return command.replace(prompt_file, compiled_path) + + def _parse_and_build_runtime_command( + self, + runtime_cmd: str, + command_part: str, + prompt_file: str, + env_prefix: str = None, # noqa: RUF013 + ) -> str | None: + """Parse arguments around the prompt file and delegate to a per-runtime builder. 
+ + Args: + runtime_cmd: Runtime name (codex, copilot, llm, or gemini) + command_part: The command portion containing the runtime invocation + prompt_file: The .prompt.md filename to strip + env_prefix: Optional environment variable prefix (e.g. "DEBUG=1") + + Returns: + Transformed command string, or None if the pattern does not match + """ + match = re.search( + f"{runtime_cmd}\\s+(.*?)(" + re.escape(prompt_file) + r")(.*?)$", + command_part, + ) + if not match: + return None + + args_before = match.group(1).strip() + args_after = match.group(3).strip() + + # In the env-var path, non-codex runtimes strip -p flags (matches + # original behaviour where copilot and llm shared an else branch). + if env_prefix is not None and runtime_cmd != "codex": + args_before = args_before.replace("-p", "").strip() + + builders = { + "codex": self._build_codex_command, + "copilot": self._build_copilot_command, + "llm": self._build_llm_command, + "gemini": self._build_gemini_command, + } + builder = builders.get(runtime_cmd) + if builder: + return builder(args_before, args_after, env_prefix) + return None + + def _build_codex_command( + self, + args_before: str, + args_after: str, + env_prefix: str | None = None, + ) -> str: + """Build a codex command from parsed arguments. + + Args: + args_before: Arguments that appeared before the prompt file + args_after: Arguments that appeared after the prompt file + env_prefix: Optional environment variable prefix + + Returns: + Assembled codex command string + """ + prefix = f"{env_prefix} " if env_prefix else "" + result = f"{prefix}codex exec" + if args_before: + result += f" {args_before}" + if args_after: + result += f" {args_after}" + return result + + def _build_copilot_command( + self, + args_before: str, + args_after: str, + env_prefix: str | None = None, + ) -> str: + """Build a copilot command from parsed arguments. + + Removes any existing -p flag since content is passed separately + during execution. 
+ + Args: + args_before: Arguments that appeared before the prompt file + args_after: Arguments that appeared after the prompt file + env_prefix: Optional environment variable prefix + + Returns: + Assembled copilot command string + """ + prefix = f"{env_prefix} " if env_prefix else "" + result = f"{prefix}copilot" + if args_before: + # Remove any existing -p flag since we handle it in execution + cleaned_args = args_before.replace("-p", "").strip() + if cleaned_args: + result += f" {cleaned_args}" + if args_after: + result += f" {args_after}" + return result + + def _build_llm_command( + self, + args_before: str, + args_after: str, + env_prefix: str | None = None, + ) -> str: + """Build an llm command from parsed arguments. + + Args: + args_before: Arguments that appeared before the prompt file + args_after: Arguments that appeared after the prompt file + env_prefix: Optional environment variable prefix + + Returns: + Assembled llm command string + """ + prefix = f"{env_prefix} " if env_prefix else "" + result = f"{prefix}llm" + if args_before: + result += f" {args_before}" + if args_after: + result += f" {args_after}" + return result + + def _build_gemini_command( + self, + args_before: str, + args_after: str, + env_prefix: str | None = None, + ) -> str: + """Build a gemini command from parsed arguments. + + Args: + args_before: Arguments that appeared before the prompt file + args_after: Arguments that appeared after the prompt file + env_prefix: Optional environment variable prefix + + Returns: + Assembled gemini command string + """ + prefix = f"{env_prefix} " if env_prefix else "" + result = f"{prefix}gemini" + if args_before: + cleaned_args = re.sub(r"(^|\s)-p(?=\s|$)", "", args_before).strip() + if cleaned_args: + result += f" {cleaned_args}" + if args_after: + result += f" {args_after}" + return result + + def _detect_runtime(self, command: str) -> str: + """Detect which runtime is being used in the command. 
+ + Args: + command: The command to analyze + + Returns: + Name of the detected runtime (copilot, codex, llm, gemini, or unknown) + """ + command_lower = command.lower().strip() + if re.search(r"(?:^|\s)copilot(?:\s|$)", command_lower): + return "copilot" + elif re.search(r"(?:^|\s)codex(?:\s|$)", command_lower): + return "codex" + elif re.search(r"(?:^|\s)llm(?:\s|$)", command_lower): + return "llm" + elif re.search(r"(?:^|\s)gemini(?:\s|$)", command_lower): + return "gemini" + else: + return "unknown" + + def _generate_runtime_command(self, runtime: str, prompt_file: Path) -> str: + """Generate appropriate runtime command with proper defaults. + + Args: + runtime: Name of runtime (copilot or codex) + prompt_file: Path to the prompt file + + Returns: + Full command string with runtime-specific defaults + """ + if runtime == "copilot": + return ( + f"copilot --log-level all --log-dir copilot-logs --allow-all-tools -p {prompt_file}" + ) + elif runtime == "codex": + return f"codex -s workspace-write --skip-git-repo-check {prompt_file}" + elif runtime == "gemini": + return f"gemini -p {prompt_file}" + else: + raise ValueError(f"Unsupported runtime: {runtime}") + + def _detect_installed_runtime(self) -> str: + """Detect installed runtime with priority order. + + Priority: copilot > codex > gemini > error + + Rule B: ``find_runtime_binary`` is patched at + ``apm_cli.core.script_runner.find_runtime_binary`` in tests; + route through the origin module via a function-level import. 
+ + Returns: + Name of detected runtime + + Raises: + RuntimeError: If no compatible runtime is found + """ + import apm_cli.core.script_runner as _sr + + if _sr.find_runtime_binary("copilot"): + return "copilot" + elif _sr.find_runtime_binary("codex"): + return "codex" + elif _sr.find_runtime_binary("gemini"): + return "gemini" + else: + raise RuntimeError( + "No compatible runtime found.\n" + "Install GitHub Copilot CLI with:\n" + " apm runtime setup copilot\n" + "Or install Codex CLI with:\n" + " apm runtime setup codex\n" + "Or install Gemini CLI with:\n" + " apm runtime setup gemini" + ) diff --git a/src/apm_cli/core/auth.py b/src/apm_cli/core/auth.py index f75e6df5e..16cbbeaa6 100644 --- a/src/apm_cli/core/auth.py +++ b/src/apm_cli/core/auth.py @@ -31,19 +31,14 @@ import logging import os import re -import sys import threading from collections.abc import Callable from dataclasses import dataclass, field from typing import TYPE_CHECKING, NamedTuple, TypeVar +from apm_cli.core._auth_support import _AuthSupportMixin, _org_to_env_suffix from apm_cli.core.token_manager import GitHubTokenManager -from apm_cli.utils.github_host import ( - default_host, - is_azure_devops_hostname, - is_gitlab_hostname, - is_valid_fqdn, -) +from apm_cli.utils.github_host import default_host if TYPE_CHECKING: from apm_cli.models.dependency.reference import DependencyReference @@ -106,12 +101,6 @@ def filter(self, record: logging.LogRecord) -> bool: return True -_PORT_CREDENTIAL_DOCS_URL = ( - "https://microsoft.github.io/apm/getting-started/authentication/" - "#custom-port-hosts-and-per-port-credentials" -) - - # --------------------------------------------------------------------------- # Data classes # --------------------------------------------------------------------------- @@ -181,7 +170,7 @@ class BearerFallbackOutcome(NamedTuple): bearer_attempted: bool -class AuthResolver: +class AuthResolver(_AuthSupportMixin): """Single source of truth for auth resolution. 
Every APM operation that touches a remote host MUST use this class. @@ -217,140 +206,6 @@ def set_logger(self, logger: object) -> None: the logger before it knows it needs an AuthResolver elsewhere.""" self._logger = logger - # -- host classification ------------------------------------------------ - - @staticmethod - def classify_host(host: str, port: int | None = None) -> HostInfo: - """Return a ``HostInfo`` describing *host*. - - ``port`` is carried through onto the returned ``HostInfo`` so that - downstream code (cache keys, credential-helper input, error text) - can discriminate between the same hostname on different ports. - Host-kind classification itself is transport-agnostic -- the port - never influences whether a host is GitHub/GHES/ADO/generic. - """ - h = host.lower() - - if h == "github.com": - return HostInfo( - host=host, - kind="github", - has_public_repos=True, - api_base="https://api.github.com", - port=port, - ) - - if h.endswith(".ghe.com"): - return HostInfo( - host=host, - kind="ghe_cloud", - has_public_repos=False, - api_base=f"https://{host}/api/v3", - port=port, - ) - - if is_azure_devops_hostname(host): - return HostInfo( - host=host, - kind="ado", - has_public_repos=True, - api_base="https://dev.azure.com", - port=port, - ) - - # GHES: GITHUB_HOST is set to a non-github.com, non-ghe.com FQDN - ghes_host = os.environ.get("GITHUB_HOST", "").lower() - if ( - ghes_host - and ghes_host == h - and ghes_host not in {"github.com", "gitlab.com"} - and not ghes_host.endswith(".ghe.com") - ): - if is_valid_fqdn(ghes_host): - return HostInfo( - host=host, - kind="ghes", - has_public_repos=True, - api_base=f"https://{host}/api/v3", - port=port, - ) - - # GitLab (SaaS + env-configured self-managed) — after GHES per spec (no silent GHES → GitLab) - if is_gitlab_hostname(host): - if h == "gitlab.com": - api_base = "https://gitlab.com/api/v4" - else: - api_base = f"https://{host}/api/v4" - return HostInfo( - host=host, - kind="gitlab", - 
has_public_repos=True, - api_base=api_base, - port=port, - ) - - # Generic FQDN (Bitbucket, self-hosted non-GitLab, etc.) - return HostInfo( - host=host, - kind="generic", - has_public_repos=True, - api_base=f"https://{host}/api/v3", - port=port, - ) - - # -- token type detection ----------------------------------------------- - - @staticmethod - def detect_token_type(token: str) -> str: - """Classify a token string by its prefix. - - Note: EMU (Enterprise Managed Users) tokens use standard PAT - prefixes (``ghp_`` or ``github_pat_``). There is no prefix that - identifies a token as EMU-scoped — that's a property of the - account, not the token format. - - Prefix reference (docs.github.com): - - ``github_pat_`` → fine-grained PAT - - ``ghp_`` → classic PAT - - ``ghu_`` → OAuth user-to-server (e.g. ``gh auth login``) - - ``gho_`` → OAuth app token - - ``ghs_`` → GitHub App installation (server-to-server) - - ``ghr_`` → GitHub App refresh token - """ - if token.startswith("github_pat_"): - return "fine-grained" - if token.startswith("ghp_"): - return "classic" - if token.startswith("ghu_"): - return "oauth" - if token.startswith("gho_"): - return "oauth" - if token.startswith("ghs_"): - return "github-app" - if token.startswith("ghr_"): - return "github-app" - return "unknown" - - @staticmethod - def gitlab_rest_headers( - token: str | None, - *, - oauth_bearer: bool = False, - ) -> dict[str, str]: - """Build HTTP headers for GitLab REST API v4 calls. - - Personal access tokens use ``PRIVATE-TOKEN``. OAuth2 access tokens - typically use ``Authorization: Bearer ``; set *oauth_bearer* - to use that style. - - Does not log or print *token*. Callers must not log the returned dict. 
- """ - if not token: - return {} - if oauth_bearer: - return {"Authorization": f"Bearer {token}"} - return {"PRIVATE-TOKEN": token} - # -- core resolution ---------------------------------------------------- def resolve( @@ -460,6 +315,28 @@ def _log(msg: str) -> None: if verbose_callback: verbose_callback(msg) + def _run_auth_only(fallback: Callable[[Exception], T]) -> T: + """Auth-only strategy shared by ghe_cloud and ado hosts. + + Attempt *operation* with the resolved token; on any failure + delegate to *fallback* (credential chain for ghe_cloud, AAD + bearer for ado). Collapses what used to be two near-identical + try/except branches in the outer dispatch. + """ + _log(f"Auth-only attempt for {host_info.kind} host {host_info.display_name}") + try: + return operation(auth_ctx.token, git_env) + except Exception as exc: + # operation is caller-provided; broad catch required -- cannot narrow + # without restricting the caller API. Use %r so the type is visible. + logger.debug( + "Auth-only operation failed for %s host %s: %r", + host_info.kind, + host_info.display_name, + exc, + ) + return fallback(exc) + def _try_credential_fallback(exc: Exception) -> T: """Retry the operation when the originally-resolved token fails. @@ -548,35 +425,13 @@ def _try_ado_bearer_fallback(exc: Exception) -> T: ) raise exc - # Hosts that never have public repos -> auth-only - if host_info.kind == "ghe_cloud": - _log(f"Auth-only attempt for {host_info.kind} host {host_info.display_name}") - try: - return operation(auth_ctx.token, git_env) - except Exception as exc: - # operation is caller-provided; broad catch required -- cannot narrow - # without restricting the caller API. Use %r so the type is visible. 
- logger.debug( - "Auth-only operation failed for ghe_cloud host %s: %r", - host_info.display_name, - exc, - ) - return _try_credential_fallback(exc) - - # ADO: auth-first with bearer fallback when PAT fails - if host_info.kind == "ado": - _log(f"Auth-only attempt for {host_info.kind} host {host_info.display_name}") - try: - return operation(auth_ctx.token, git_env) - except Exception as exc: - # operation is caller-provided; broad catch required -- cannot narrow - # without restricting the caller API. Use %r so the type is visible. - logger.debug( - "Auth-only operation failed for ado host %s; trying bearer fallback: %r", - host_info.display_name, - exc, - ) - return _try_ado_bearer_fallback(exc) + # ghe_cloud (never public) and ado (PAT then AAD bearer) share the + # auth-only shape; a single dispatch picks the host-kind fallback. + if host_info.kind in ("ghe_cloud", "ado"): + fallback = ( + _try_ado_bearer_fallback if host_info.kind == "ado" else _try_credential_fallback + ) + return _run_auth_only(fallback) if unauth_first: # Validation path: save rate limits, EMU-safe @@ -631,211 +486,13 @@ def _try_ado_bearer_fallback(exc: Exception) -> T: host_info.display_name, unauth_exc, ) - return _try_credential_fallback(exc) + # Both the unauth retry failure and the no-public-repos case + # converge on the secondary credential chain. return _try_credential_fallback(exc) else: _log(f"No token available, trying unauthenticated access to {host_info.display_name}") return operation(None, git_env) - # -- error context ------------------------------------------------------ - - def build_error_context( - self, - host: str, - operation: str, - org: str | None = None, - *, - port: int | None = None, - dep_url: str | None = None, - bearer_also_failed: bool = False, - ) -> str: - """Build an actionable error message for auth failures. 
- - ``bearer_also_failed=True`` prepends a single line to the Case 4 - block (PAT set, az available, both attempts failed) clarifying - that ADO_APM_PAT was tried first and rejected before the bearer - attempt -- so the user understands why both halves of the - protocol failed without having to read the full diagnostic - context. Callers MUST only set this when the bearer attempt - actually ran (see :class:`BearerFallbackOutcome.bearer_attempted`). - """ - auth_ctx = self.resolve(host, org, port=port) - host_info = auth_ctx.host_info - display = host_info.display_name - - # --- ADO-specific error cases --- - if host_info.kind == "ado": - from apm_cli.core.azure_cli import get_bearer_provider - - provider = get_bearer_provider() - az_available = provider.is_available() - pat_set = bool(os.environ.get("ADO_APM_PAT")) - - org_part = org or "" - if not org_part: - source_url = dep_url or "" - if source_url: - parts = source_url.replace("https://", "").split("/") - if len(parts) >= 2 and ( - parts[0] in ("dev.azure.com",) or parts[0].endswith(".visualstudio.com") - ): - org_part = parts[1] if len(parts) > 1 else "" - - token_url = ( - f"https://dev.azure.com/{org_part}/_usersSettings/tokens" - if org_part - else "https://dev.azure.com//_usersSettings/tokens" - ) - - if pat_set: - if az_available: - # Case 4: PAT and bearer were both available; both attempts - # failed. We may not have observed an explicit 401 (could be - # a 404, a network error, etc.) so the wording stays - # tentative -- see #856 review C6. 
- prefix = ( - " ADO_APM_PAT was rejected; az cli bearer was also rejected.\n\n" - if bearer_also_failed - else "" - ) - return ( - f"\n{prefix}" - f" ADO_APM_PAT is set, and Azure CLI credentials may also be available,\n" - f" but the Azure DevOps request still failed.\n\n" - f" If this is an authentication failure, the PAT may be expired, revoked,\n" - f" or scoped to a different org, and Azure CLI credentials may need to\n" - f" be refreshed.\n\n" - f" To fix:\n" - f" 1. Unset the PAT to test Azure CLI auth only: unset ADO_APM_PAT\n" - f" 2. Re-authenticate Azure CLI if needed: az login\n" - f" 3. Retry: apm install\n\n" - f" Docs: https://microsoft.github.io/apm/getting-started/authentication/#azure-devops" - ) - # PAT set but rejected, no az -> bare PAT failure - return ( - f"\n ADO_APM_PAT is set, but the Azure DevOps request failed.\n" - f" If this is an authentication failure, the token may be expired,\n" - f" revoked, or scoped to a different org.\n\n" - f" Generate a new PAT at {token_url}\n" - f" with Code (Read) scope.\n\n" - f" Docs: https://microsoft.github.io/apm/getting-started/authentication/#azure-devops" - ) - - # No PAT set - if not az_available: - # Case 1: no az, no PAT - return ( - f"\n Azure DevOps requires authentication. You have two options:\n\n" - f" 1. Install Azure CLI and sign in (recommended for Entra ID users):\n" - f" brew install azure-cli # macOS\n" - f" winget install Microsoft.AzureCLI # Windows\n" - f" apt-get install azure-cli # Debian/Ubuntu\n" - f" dnf install azure-cli # Fedora/RHEL\n" - f" (full guide: https://aka.ms/InstallAzureCli)\n" - f" az login\n" - f" apm install # retry -- no env var needed\n\n" - f" 2. 
Use a Personal Access Token:\n" - f" export ADO_APM_PAT=your_token\n" - f" (Create one at {token_url} with Code (Read) scope.)\n\n" - f" Docs: https://microsoft.github.io/apm/getting-started/authentication/#azure-devops" - ) - - # az is available; check if logged in by trying to get tenant - tenant = provider.get_current_tenant_id() - if tenant is None: - # Case 3: az present, not logged in - return ( - "\n Azure DevOps requires authentication. You have two options:\n\n" - " 1. Sign in with Azure CLI (recommended for Entra ID users):\n" - " az login\n" - " apm install # retry -- no env var needed\n\n" - " 2. Use a Personal Access Token:\n" - " export ADO_APM_PAT=your_token\n\n" - " Docs: https://microsoft.github.io/apm/getting-started/authentication/#azure-devops" - ) - - # Case 2: az returned token (tenant known) but ADO rejected it. - # Note: bearer_also_failed=True is structurally unreachable here -- - # callers only set it when source == "ADO_APM_PAT" (i.e. pat_set - # is True), and Case 2 lives in the `not pat_set` branch. We do - # not render a "PAT was also rejected" prefix in this case - # because no PAT was tried. - return ( - f"\n Your az cli session (tenant: {tenant}) returned a bearer token,\n" - f" but Azure DevOps rejected it (HTTP 401).\n\n" - f" Check that you are signed into the correct tenant:\n" - f" az account show\n" - f" az login --tenant \n\n" - f" Docs: https://microsoft.github.io/apm/getting-started/authentication/#azure-devops" - ) - - # --- Non-ADO error paths --- - lines: list[str] = [f"Authentication failed for {operation} on {display}."] - - if auth_ctx.token: - lines.append( - f"Token was provided (source: {auth_ctx.source}, type: {auth_ctx.token_type})." - ) - if host_info.kind == "ghe_cloud": - lines.append( - "GHE Cloud Data Residency hosts (*.ghe.com) require " - "enterprise-scoped tokens. Ensure your PAT is authorized " - "for this enterprise." 
- ) - elif host_info.kind == "gitlab": - lines.append( - "Ensure your GitLab personal or project access token meets the " - "API read requirements for your instance policy." - ) - elif host.lower() == "github.com": - lines.append( - "If your organization uses SAML SSO or is an EMU org, " - "ensure your PAT is authorized at " - "https://github.com/settings/tokens" - ) - elif host_info.kind == "generic": - lines.append("Verify credentials for this host in your git credential helper.") - else: - lines.append( - "If your organization uses SAML SSO, you may need to " - "authorize your token at https://github.com/settings/tokens" - ) - else: - lines.append("No token available.") - if host_info.kind == "gitlab": - lines.append( - "Set GITLAB_APM_PAT or GITLAB_TOKEN, or configure git credential fill " - f"for {display}." - ) - elif host_info.kind == "generic": - lines.append( - "APM does not apply GitHub PAT environment variables to generic git " - f"hosts; configure git credential fill for {display} or use a " - "public repository if available." - ) - else: - lines.append("Set GITHUB_APM_PAT or GITHUB_TOKEN, or run 'gh auth login'.") - - if org and host_info.kind not in ("ado", "gitlab", "generic"): - lines.append( - f"If packages span multiple organizations, set per-org tokens: " - f"GITHUB_APM_PAT_{_org_to_env_suffix(org)}" - ) - - # When a custom port is in play, helpers that key by hostname alone - # (some `gh` integrations, older keychain backends) can silently - # return the wrong credential. Point the user at the concrete fix. 
- if host_info.port is not None: - lines.append( - f"[i] Host '{display}' -- this helper may key by host only.\n" - f" Verify with: printf 'protocol=https\\nhost={display}\\n\\n'" - f" | git credential fill\n" - f" Docs: {_PORT_CREDENTIAL_DOCS_URL}" - ) - - lines.append("Run with --verbose for detailed auth diagnostics.") - return "\n".join(lines) - # -- internals ---------------------------------------------------------- def _resolve_token(self, host_info: HostInfo, org: str | None) -> tuple[str | None, str, str]: @@ -922,140 +579,21 @@ def _resolve_token(self, host_info: HostInfo, org: str | None) -> tuple[str | No return None, "none", "basic" - @staticmethod - def _purpose_for_host(host_info: HostInfo) -> str: - if host_info.kind == "ado": - return "ado_modules" - if host_info.kind == "gitlab": - return "gitlab_modules" - if host_info.kind == "generic": - return "generic_modules" - return "modules" - - def _identify_env_source(self, purpose: str) -> str: - """Return the name of the first env var that matched for *purpose*.""" - for var in self._token_manager.TOKEN_PRECEDENCE.get(purpose, []): - if os.environ.get(var): - return var - return "env" - - @staticmethod - def _build_git_env( - token: str | None = None, - *, - scheme: str = "basic", - host_kind: str = "github", - ) -> dict: - """Pre-built env dict for subprocess git calls. - - For ADO bearer tokens (scheme='bearer'), injects an Authorization header - via GIT_CONFIG_COUNT/KEY/VALUE env vars (see github_host.build_ado_bearer_git_env). - For all other cases, behavior is unchanged. - """ - env = os.environ.copy() - env["GIT_TERMINAL_PROMPT"] = "0" - env["GIT_ASKPASS"] = "echo" - if scheme == "bearer" and token and host_kind == "ado": - # B2 #852: skip GIT_TOKEN for bearer scheme -- the JWT is injected via - # GIT_CONFIG_VALUE_0 only; GIT_TOKEN here would leak it into every - # child-process env (visible in /proc//environ, ps eww). 
- # - # #1214 follow-up: a stale GIT_TOKEN already in the parent env - # (set by a prior shell, CI step, or another tool) would survive - # the os.environ.copy() above and defeat the isolation guarantee. - # Drop it explicitly so the bearer env is clean by construction. - env.pop("GIT_TOKEN", None) - from apm_cli.utils.github_host import build_ado_bearer_git_env - - env.update(build_ado_bearer_git_env(token)) - elif token: - env["GIT_TOKEN"] = token - return env - - def emit_stale_pat_diagnostic(self, host_display: str) -> None: - """Emit a [!] warning when PAT was rejected but bearer succeeded. - - F3 #852: when an InstallLogger is wired via :meth:`set_logger`, the - warning is collected by its DiagnosticCollector so it appears in the - install summary. Without a logger (e.g. unit tests) we fall back to - the inline ``_rich_warning`` emission for backwards compatibility. - - #1212 follow-up: dedup per host_display so the user sees ONE warning - per ADO host even when preflight, list_remote_refs, and the clone - path each trigger the bearer-fallback path against the same host. - - Naming: previously ``_emit_stale_pat_diagnostic`` (private). Public - now (#856 follow-up C9) so external modules (validation.py, - github_downloader.py) do not reach into the underscore API. - - #1214 follow-up: guard the check-then-add under self._lock so two - threads (parallel install) racing on the same ADO host cannot both - pass the membership check before either calls add(); without the - lock the dedup set defeats its own purpose. - """ - with self._lock: - if host_display in self._stale_pat_warned_hosts: - return - self._stale_pat_warned_hosts.add(host_display) - msg = f"ADO_APM_PAT was rejected for {host_display}; fell back to az cli bearer." - detail = "Consider unsetting the stale variable." 
- diagnostics = self._diagnostics_or_none() - if diagnostics is not None: - diagnostics.warn(msg, detail=detail) - return - try: - from apm_cli.utils.console import _rich_warning - - _rich_warning(msg, symbol="warning") - _rich_warning(f" {detail}", symbol="warning") - except ImportError as exc: - logger.debug("Console module unavailable for stale-PAT warning; skipping: %s", exc) - - # Backwards-compat alias for any in-tree caller still importing the - # private name. Safe to remove once all callers move to the public name. - _emit_stale_pat_diagnostic = emit_stale_pat_diagnostic - - def _diagnostics_or_none(self): - """Return the wired logger's DiagnosticCollector, or None.""" - if self._logger is None: - return None - try: - return self._logger.diagnostics - except AttributeError: - return None - - def notify_auth_source(self, host_display: str, ctx) -> None: - """Emit the verbose auth-source line for ``host_display`` exactly once. - - F2 #852: routes through CommandLogger when wired (so the line obeys - the same verbose channel as every other diagnostic), and falls back - to a direct stderr write when no logger is set so the existing - bearer e2e tests keep working. + def _ado_bearer_provider(self): + """Return the Azure CLI AAD bearer provider (auth-boundary accessor). + + The sole seam through which boundary-free helpers (notably + :meth:`build_error_context` in ``_auth_support``) reach the ADO + bearer provider. Keeping the ``get_bearer_provider`` symbol behind + this accessor preserves the invariant enforced by + ``scripts/lint-auth-signals.sh`` (Rule A): every reference to the + provider lives inside ``core/auth.py``. The import stays + function-local so non-ADO paths never pay the azure_cli load cost + and tests can patch ``apm_cli.core.azure_cli.get_bearer_provider``. 
""" - host_key = (host_display or "").lower() - if not host_key or host_key in self._verbose_auth_logged_hosts: - return - self._verbose_auth_logged_hosts.add(host_key) - if ctx is None or getattr(ctx, "source", "none") == "none": - return - if getattr(ctx, "auth_scheme", None) == "bearer": - line = f" [i] {host_key} -- using bearer from az cli (source: {ctx.source})" - else: - line = f" [i] {host_key} -- token from {ctx.source}" - if self._logger is not None and getattr(self._logger, "verbose", False): - try: - from apm_cli.utils.console import _rich_echo + from apm_cli.core.azure_cli import get_bearer_provider - _rich_echo(line, color="dim") - return - except ImportError as exc: - logger.debug( - "Console module unavailable for auth-source logging; skipping: %s", exc - ) - # No logger wired -- the install path always wires one in the - # bearer branch, so this fallback only fires in unit-test contexts - # that opt-in via APM_VERBOSE=1. - sys.stderr.write(line + "\n") + return get_bearer_provider() def execute_with_bearer_fallback( self, @@ -1130,13 +668,3 @@ def execute_with_bearer_fallback( host_display = getattr(dep_ref, "host", None) or "dev.azure.com" self.emit_stale_pat_diagnostic(host_display) return BearerFallbackOutcome(fallback, True) - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _org_to_env_suffix(org: str) -> str: - """Convert an org name to an env-var suffix (upper-case, hyphens → underscores).""" - return org.upper().replace("-", "_") diff --git a/src/apm_cli/core/command_logger.py b/src/apm_cli/core/command_logger.py index f45a9c8de..36881a400 100644 --- a/src/apm_cli/core/command_logger.py +++ b/src/apm_cli/core/command_logger.py @@ -141,6 +141,24 @@ def tree_item(self, message: str): """ _rich_echo(message, color="green") + def dim(self, message: str): + """Log a dim (grey) line unconditionally — no verbose gate. 
+ + Use for secondary/contextual lines that must always appear + (e.g. inline remediation hints under an error), where + :meth:`verbose_detail` would suppress them for non-verbose runs. + """ + _rich_echo(message, color="dim") + + def dim_check_item(self, message: str): + """Log a dim check-marked item (used by validation_pass). + + Renders the message in dim colour with a check symbol — these + are "already present / already updated" confirmations that sit + visually alongside :meth:`tree_item` lines. + """ + _rich_echo(message, color="dim", symbol="check") + def blank_line(self): """Log a blank line through the shared console output path.""" _rich_echo("") @@ -198,639 +216,18 @@ def render_summary(self): self._diagnostics.render_summary() -class InstallLogger(CommandLogger): - """Install-specific logger with validation, resolution, and download phases. +def __getattr__(name: str): + """Lazily re-export ``InstallLogger`` (PEP 562). - Knows whether this is a partial install (specific packages requested) or - full install (all deps from apm.yml). Adjusts messages accordingly. + ``_install_logger`` imports ``CommandLogger`` (its base class) from this + module at module scope. Re-exporting ``InstallLogger`` eagerly here would + create a circular import that fails whenever ``_install_logger`` is imported + first (partially-initialised module). Deferring the import until + ``command_logger.InstallLogger`` is first accessed breaks the cycle while + preserving the public ``apm_cli.core.command_logger.InstallLogger`` surface. 
""" + if name == "InstallLogger": + from ._install_logger import InstallLogger - def __init__(self, verbose: bool = False, dry_run: bool = False, partial: bool = False): - super().__init__("install", verbose=verbose, dry_run=dry_run) - self.partial = partial # True when specific packages are passed to `apm install` - self._stale_cleaned_total = 0 # Accumulated by stale_cleanup / orphan_cleanup - - # --- Validation phase --- - - def validation_start(self, count: int): - """Log start of package validation.""" - noun = "package" if count == 1 else "packages" - _rich_info(f"Validating {count} {noun}...", symbol="gear") - - def validation_pass(self, canonical: str, already_present: bool, updated: bool = False): - """Log a package that passed validation.""" - if updated: - _rich_echo(f"{canonical} (updated ref in apm.yml)", color="dim", symbol="check") - elif already_present: - _rich_echo(f"{canonical} (already in apm.yml)", color="dim", symbol="check") - else: - _rich_success(canonical, symbol="check") - - def validation_fail(self, package: str, reason: str): - """Log a package that failed validation.""" - _rich_error(f"{package} -- {reason}", symbol="error") - - def validation_summary(self, outcome: _ValidationOutcome): - """Log validation summary and decide whether to continue. - - Returns True if install should continue, False if all packages failed. - """ - if outcome.all_failed: - _rich_error("All packages failed validation. 
Nothing to install.") - return False - - if outcome.has_failures: - failed_count = len(outcome.invalid) - noun = "package" if failed_count == 1 else "packages" - _rich_warning(f"{failed_count} {noun} failed validation and will be skipped.") - - return True - - # --- Resolution phase --- - - def resolution_start(self, to_install_count: int, lockfile_count: int): - """Log start of dependency resolution.""" - if self.partial: - noun = "package" if to_install_count == 1 else "packages" - _rich_info(f"Installing {to_install_count} new {noun}...", symbol="running") - if lockfile_count > 0 and self.verbose: - _rich_echo( - f" ({lockfile_count} existing dependencies in lockfile)", - color="dim", - ) - else: - _rich_info("Installing dependencies from apm.yml...", symbol="running") - if lockfile_count > 0: - _rich_info(f"Using apm.lock.yaml ({lockfile_count} locked dependencies)") - - def nothing_to_install( - self, - lockfile_present: bool = False, - update_mode: bool = False, - ): - """Log when there's nothing to install -- context-aware message. - - Args: - lockfile_present: True when apm.lock.yaml exists on disk at - the time of the no-op. When True (and we're not in - update mode) we append the standard hint pointing at - ``apm update`` -- this is the #1203 nudge that keeps - users from believing ``apm install`` checks for newer - versions. - update_mode: True when this run was invoked with - ``--update`` or via ``apm update``. Suppresses the - hint -- the user already asked to refresh. 
- """ - if self.partial: - _rich_info("Requested packages are already installed.", symbol="check") - else: - _rich_success("All dependencies are up to date.", symbol="check") - if lockfile_present and not update_mode: - _rich_info("Lockfile already satisfied -- run 'apm update' to resolve latest refs.") - - # --- Download phase --- - - def download_start(self, dep_name: str, cached: bool): - """Log start of a package download.""" - if cached: - self.verbose_detail(f" Using cached: {dep_name}") - elif self.verbose: - _rich_info(f" Downloading: {dep_name}", symbol="download") - - def resolving_heartbeat(self, dep_name: str): - """Emit a per-dependency progress heartbeat during BFS resolve. - - Surfaces an immediate ``[>] Resolving ...`` line so the - user sees the install moving forward instead of staring at - silence while transitive lookups happen behind the scenes - (F1, microsoft/apm#1116). The line is static (not a Rich - transient progress bar) so it survives in CI logs and behind - ``2>&1 | tee`` pipelines, which the duck critique flagged as - the must-survive surface. - - Called from the MAIN thread by the resolver/download callback - BEFORE network work begins; F7's parallel BFS keeps emission - on the main thread so output ordering is deterministic even - when downloads are dispatched to a worker pool. - """ - _rich_info(f"Resolving {dep_name}...", symbol="running") - - def download_complete( - self, - dep_name: str, - ref: str = "", - sha: str = "", - cached: bool = False, - # Legacy compat: if callers pass ref_suffix= we handle it - ref_suffix: str = "", - ): - """Log completion of a package download. - - Args: - dep_name: Package display name (repo_url or virtual path). - ref: Git reference (tag name, branch) if any. - sha: Short commit SHA (8 chars) if any. - cached: Whether this was a cache hit. - ref_suffix: DEPRECATED — legacy callers still pass this. 
- """ - msg = f" [+] {dep_name}" - if ref_suffix: - # Legacy path — pass-through until all callers are migrated - msg += f" ({ref_suffix})" - else: - if ref and sha: - msg += f" #{ref} @{sha}" - elif ref: - msg += f" #{ref}" - elif sha: - msg += f" @{sha}" - if cached: - msg += " (cached)" - _rich_echo(msg, color="green") - - def download_failed(self, dep_name: str, error: str): - """Log a download failure.""" - _rich_error(f" [x] {dep_name} -- {error}") - - # --- Verbose sub-item methods (install-specific) --- - - def lockfile_entry(self, key: str, ref: str = "", sha: str = ""): - """Log a lockfile entry in verbose mode. - - Omits the line entirely for unpinned deps (no ref, no sha). - """ - if not self.verbose: - return - if sha: - _rich_echo(f" {key}: locked at {sha}", color="dim") - elif ref: - _rich_echo(f" {key}: pinned to {ref}", color="dim") - # Unpinned → omit entirely (nothing useful to show) - - def package_auth(self, source: str, token_type: str = ""): - """Log auth source for a package (verbose only). 4-space indent.""" - if not self.verbose: - return - type_str = f" ({token_type})" if token_type else "" - _rich_echo(f" Auth: {source}{type_str}", color="dim") - - def package_type_info(self, type_label: str): - """Log detected package type (verbose only). 4-space indent.""" - if not self.verbose: - return - _rich_echo(f" Package type: {type_label}", color="dim") - - # --- Performance diagnostics (perf #1433) --- - - def subdir_download_start( - self, - dep_name: str, - cache_state: str, - sha_short: str = "", - sparse_paths: list[str] | None = None, - ): - """Log the start of a subdirectory dep download (verbose only). - - Names the dep, the bare-cache state (e.g. ``cold`` / ``warm`` / - ``persistent`` / ``shared-bare``), the resolved SHA (short), - and the sparse paths being requested. Surfaces enough state to - diagnose a perf regression from one log line. 
- """ - if not self.verbose: - return - sha_part = f" @{sha_short}" if sha_short else "" - paths_part = f" sparse={','.join(sparse_paths)}" if sparse_paths else " sparse=" - _rich_echo( - f" [i] perf: subdir {dep_name}{sha_part} cache={cache_state}{paths_part}", - color="dim", - ) - - def bare_clone_strategy(self, strategy: str, elapsed_ms: int): - """Log the bare-clone strategy and wall time (verbose only). - - ``strategy`` is the human-readable command shape, e.g. - ``--depth=1 --branch main`` or ``init+fetch --depth=1 ``. - ``elapsed_ms`` lets readers spot a network-bound regression - without re-running with a profiler. - """ - if not self.verbose: - return - _rich_echo( - f" [i] perf: bare clone strategy={strategy} took={elapsed_ms}ms", - color="dim", - ) - - def materialize_result(self, sparse_applied: bool, consumer_size_bytes: int): - """Log materialization outcome and consumer dir size (verbose only). - - ``sparse_applied`` tells the reader whether sparse-cone fired - on this consumer dir (sparse_paths were passed and accepted by - git). ``consumer_size_bytes`` is the on-disk size of the - working tree handed off to the integrator; a regression here - is the leading indicator that sparse-cone silently fell back. - """ - if not self.verbose: - return - size_mb = consumer_size_bytes / (1024 * 1024) - applied = "yes" if sparse_applied else "no" - _rich_echo( - f" [i] perf: materialize sparse={applied} size={size_mb:.2f} MB", - color="dim", - ) - - def tier_summary(self, stats: dict[str, int]): - """Log the tiered ref resolver hit counts (verbose only). - - Emitted at the end of the resolve phase so the reader can see - how many ref->SHA lookups hit each tier (L0 per-run cache, - L1 commits API, L2 bare rev-parse, L3 legacy clone) without - wiring a debugger. A run dominated by L3 is the canonical - signal that ref-resolution is paying full clone cost. 
- """ - if not self.verbose or not stats: - return - non_zero = {k: v for k, v in stats.items() if v} - if not non_zero: - return - parts = " ".join(f"{k}={v}" for k, v in non_zero.items()) - _rich_echo(f" [i] perf: ref-resolver tiers: {parts}", color="dim") - - # --- Cleanup phase (stale and orphan file removal) --- - - def stale_cleanup(self, dep_key: str, count: int): - """Log per-package stale-file cleanup outcome at default verbosity. - - Stale-file deletion is a destructive operation in the user's - tracked workspace (unlike npm's ``node_modules``); it must be - visible without ``--verbose``. Rendered as an info line so it - groups visually with other phase messages, not as a tree item - (the originating package line was emitted earlier in the install - sequence and is no longer adjacent). - """ - if count <= 0: - return - self._stale_cleaned_total += count - noun = "file" if count == 1 else "files" - _rich_info(f"Cleaned {count} stale {noun} from {dep_key}", symbol="info") - - def orphan_cleanup(self, count: int): - """Log post-install orphan-file cleanup outcome at default verbosity. - - Same visibility rationale as :meth:`stale_cleanup`: file deletion - in the user's workspace must be visible by default. - """ - if count <= 0: - return - self._stale_cleaned_total += count - noun = "file" if count == 1 else "files" - _rich_info( - f"Cleaned {count} {noun} from packages no longer in apm.yml", - symbol="info", - ) - - @property - def stale_cleaned_total(self) -> int: - """Total files removed by stale + orphan cleanup during this install.""" - return self._stale_cleaned_total - - def cleanup_skipped_user_edit(self, rel_path: str, dep_key: str): - """Log a stale-file deletion that was skipped because the user - edited the file after APM deployed it. - - Yellow inline at default verbosity -- the user needs to know APM - kept the file and a manual decision is pending. 
- """ - _rich_warning( - f" Kept user-edited file {rel_path} (from {dep_key}); " - "delete manually if no longer needed", - symbol="warning", - ) - - # --- Policy phase --- - - def policy_resolved( - self, - source: str, - cached: bool, - enforcement: str, - age_seconds: int | None = None, - ): - """Log policy discovery outcome. - - Verbose by default; always shown when ``enforcement == "block"`` - (users must know blocking is active). - - Format: ``[i] Policy: (cached, fetched 5m ago) -- enforcement=block`` - """ - parts = [f"Policy: {source}"] - - if cached: - cache_detail = "cached" - if age_seconds is not None: - if age_seconds < 60: - cache_detail += f", fetched {age_seconds}s ago" - else: - minutes = age_seconds // 60 - unit = "m" if minutes < 60 else "h" - value = minutes if minutes < 60 else minutes // 60 - cache_detail += f", fetched {value}{unit} ago" - parts.append(f"({cache_detail})") - parts.append(f"-- enforcement={enforcement}") - - message = " ".join(parts) - - if enforcement == "block": - # Always visible — blocking installs is a big deal - _rich_warning(message, symbol="warning") - elif self.verbose: - _rich_info(message, symbol="info") - # Non-verbose + non-block: silent (no noise for warn/off) - - def policy_discovery_miss( - self, - outcome: str, - source: str = "", - error: str | None = None, - host_org: str | None = None, - ): - """Log a policy-discovery non-success outcome. - - Single canonical helper that routes all 7 non-found / non-disabled - outcomes through one wording table. Replaces the per-call-site - ``_rich_info`` / ``_rich_warning`` invocations in ``policy_gate`` - and ``install_preflight`` (Logging C1 / C2, UX F1 / F2 / F4 / F5). - - Args: - outcome: One of ``"absent"``, ``"no_git_remote"``, ``"empty"``, - ``"malformed"``, ``"cache_miss_fetch_fail"``, - ``"garbage_response"``, ``"cached_stale"``. - source: Policy source string (e.g. ``"org:acme/.github"``). 
- error: Optional error string (used for malformed, - cache_miss_fetch_fail, garbage_response, cached_stale). - host_org: Optional org slug for ``absent`` outcome (verbose - hint). Auto-derived from ``source`` when not provided. - """ - err_text = error or "unknown" - - if outcome == "absent": - # Verbose-only: the vast majority of users have no org policy - # and don't need to see a line for it on every install (UX F1). - if not self.verbose: - return - org = host_org or _strip_source_prefix(source) or "this project" - _rich_info(f"No org policy found for {org}", symbol="info") - return - - if outcome == "no_git_remote": - # UX F2: this is a normal state for fresh `git init`, unpacked - # bundles, or temp dirs -- info, not a warning. Verbose-gated - # for the same reason as ``absent`` (#832): the vast majority - # of users have no org policy configured and don't need to - # see a line for it on every install (fresh checkouts, CI - # environments, unpacked tarballs). - if not self.verbose: - return - _rich_info( - "Could not determine org from git remote; policy auto-discovery skipped", - symbol="info", - ) - return - - if outcome == "empty": - src = source or "this project" - _rich_warning( - f"Org policy at {src} is present but empty; no enforcement applied", - symbol="warning", - ) - return - - if outcome == "malformed": - _rich_warning( - f"Policy at {source} is malformed: {err_text}. " - "Contact your org admin to fix the policy file.", - symbol="warning", - ) - return - - if outcome == "cache_miss_fetch_fail": - # UX F5: explicit posture -- enforcement skipped. - _rich_warning( - f"Could not fetch org policy from {source} ({err_text}); " - "proceeding without policy enforcement. " - "Retry, check connectivity, or use --no-policy to bypass.", - symbol="warning", - ) - return - - if outcome == "garbage_response": - # UX F4: server IS reachable; "check VPN/firewall" is wrong - # advice. Point at the org admin instead. 
- _rich_warning( - f"Policy response from {source} is not valid YAML " - f"({err_text}); proceeding without policy enforcement. " - "Contact your org admin or use --no-policy.", - symbol="warning", - ) - return - - if outcome == "cached_stale": - # UX F5: explicit posture -- enforcement still applies. - _rich_warning( - f"Using stale cached policy (refresh failed: {err_text}); " - "enforcement still applies from cached policy.", - symbol="warning", - ) - return - - if outcome == "hash_mismatch": - # #827: always-error posture -- pinned policy.hash does not - # match fetched bytes. Show both expected and actual via the - # error message so the admin can compare without re-fetching. - _rich_error( - f"Policy hash mismatch: pinned hash does not match fetched " - f"policy ({err_text}). Update apm.yml policy.hash or " - "contact your org admin.", - symbol="error", - ) - return - - # Defensive: unknown outcome -- emit a conservative warning - if error: - _rich_warning( - f"Policy discovery issue: {err_text}", - symbol="warning", - ) - - def policy_violation( - self, - dep_ref: str, - reason: str, - severity: str, - source: str | None = None, - ): - """Record a policy violation for a dependency. - - Pushes to ``DiagnosticCollector`` under ``CATEGORY_POLICY`` for - the end-of-install summary. When ``severity == "block"``, also - prints an inline error so the user sees the failure immediately - (before the summary), followed by a dim secondary line with the - actionable next-step (CLI logging C3). - - Args: - dep_ref: Dependency reference (e.g. ``"acme/evil-pkg"``). - reason: Actionable reason text per rubber-duck I9. - severity: ``"block"`` or ``"warn"``. - source: Optional policy source (used for block-mode next-step - hint). When provided, a dim secondary line with - remediation guidance is rendered under the inline error. - """ - - # F9 dedupe: some callers pass reason with a "{dep_ref}: " prefix - # (the detail strings produced by policy_checks.py do this). 
- # Strip it defensively so the inline error reads cleanly. - prefix = f"{dep_ref}: " - if reason.startswith(prefix): - reason = reason[len(prefix) :] - - self.diagnostics.policy( - message=reason, - package=dep_ref, - severity=severity, - ) - - if severity == "block": - _rich_error(f"Policy violation: {dep_ref} -- {reason}", symbol="error") - if source: - _rich_echo( - f" {self._policy_reason_blocked(dep_ref, source)}", - color="dim", - ) - - def policy_disabled(self, reason: str): - """Log a loud warning that policy enforcement is disabled. - - Emitted when ``--no-policy`` or ``APM_POLICY_DISABLE=1`` is - active. Always visible (never silenceable) -- matches the - ``--allow-insecure`` pattern. - """ - _rich_warning( - f"Policy enforcement disabled by {reason} for this invocation. " - "This does NOT bypass apm audit --ci. " - "CI will still fail the PR for the same policy violation.", - symbol="warning", - ) - - # --- Policy violation reason helpers --- - - @staticmethod - def _policy_reason_auth(source: str) -> str: - """Actionable reason for auth failure during policy fetch.""" - return ( - f"Could not authenticate to fetch policy from {source} " - "-- check `gh auth status` and `GITHUB_APM_PAT`" - ) - - @staticmethod - def _policy_reason_unreachable(source: str) -> str: - """Actionable reason for unreachable policy source.""" - return ( - f"Policy source {source} is unreachable " - "-- retry, check VPN/firewall, or use `--no-policy` to bypass" - ) - - @staticmethod - def _policy_reason_malformed(source: str) -> str: - """Actionable reason for malformed policy file.""" - return f"Policy at {source} is malformed -- contact your org admin to fix the policy file" - - @staticmethod - def _policy_reason_blocked(dep_ref: str, source: str) -> str: - """Actionable reason for a blocked dependency.""" - return ( - f"Blocked by org policy at {source} " - f"-- remove `{dep_ref}` from apm.yml, contact admin to update policy, " - "or use `--no-policy` for one-off bypass" - ) 
- - # --- Install summary --- - - def install_summary( - self, - apm_count: int, - mcp_count: int, - lsp_count: int = 0, - errors: int = 0, - stale_cleaned: int = 0, - elapsed_seconds: float | None = None, - ): - """Log final install summary. - - Args: - apm_count: Number of APM dependencies installed. - mcp_count: Number of MCP servers installed. - lsp_count: Number of LSP servers installed. - errors: Number of errors collected during install. - stale_cleaned: Total stale + orphan files removed during - this install. Reported as a parenthetical so existing - callers and assertion patterns continue to work. - elapsed_seconds: Wall-clock duration of the install command. - When provided, appended as `` in {x:.1f}s`` before the - terminating period so the user can see how long the - whole command took (F5, microsoft/apm#1116). - """ - parts = [] - if apm_count > 0: - noun = "dependency" if apm_count == 1 else "dependencies" - parts.append(f"{apm_count} APM {noun}") - if mcp_count > 0: - noun = "server" if mcp_count == 1 else "servers" - parts.append(f"{mcp_count} MCP {noun}") - if lsp_count > 0: - noun = "server" if lsp_count == 1 else "servers" - parts.append(f"{lsp_count} LSP {noun}") - - cleanup_suffix = "" - if stale_cleaned > 0: - file_noun = "file" if stale_cleaned == 1 else "files" - cleanup_suffix = f" ({stale_cleaned} stale {file_noun} cleaned)" - - timing_suffix = "" - if elapsed_seconds is not None: - timing_suffix = f" in {elapsed_seconds:.1f}s" - - if parts: - summary = " and ".join(parts) - if errors > 0: - _rich_warning( - f"Installed {summary}{cleanup_suffix}{timing_suffix} with {errors} error(s).", - symbol="warning", - ) - else: - _rich_success( - f"Installed {summary}{cleanup_suffix}{timing_suffix}.", - symbol="sparkles", - ) - elif errors > 0: - _rich_error( - f"Installation failed with {errors} error(s){timing_suffix}.", - symbol="error", - ) - else: - _rich_info( - f"No changes -- install state already up to date{timing_suffix}.", - 
symbol="info", - ) - - def install_interrupted(self, elapsed_seconds: float): - """Log a minimal elapsed-time line when the normal summary did - not render (errors, KeyboardInterrupt, click.UsageError). - - Emitted from the outer ``finally`` in ``commands.install.install`` - so users always see how long the failed/interrupted command ran - (F5, microsoft/apm#1116). Best-effort: callers swallow any - exception so a render failure cannot mask the original error. - """ - _rich_warning( - f"Install interrupted after {elapsed_seconds:.1f}s.", - symbol="warning", - ) + return InstallLogger + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/apm_cli/core/script_runner.py b/src/apm_cli/core/script_runner.py index 3cf115396..f9a5c4606 100644 --- a/src/apm_cli/core/script_runner.py +++ b/src/apm_cli/core/script_runner.py @@ -9,10 +9,12 @@ from ..output.script_formatters import ScriptExecutionFormatter from ..runtime.utils import find_runtime_binary +from ._prompt_compiler import PromptCompiler as PromptCompiler +from ._runtime_commands import _RuntimeCommandsMixin from .token_manager import setup_runtime_environment -class ScriptRunner: +class ScriptRunner(_RuntimeCommandsMixin): """Executes APM scripts with auto-compilation of .prompt.md files.""" def __init__(self, compiler=None, use_color: bool = True): @@ -265,232 +267,6 @@ def _auto_compile_prompts( return compiled_command, compiled_prompt_files, runtime_content - def _transform_runtime_command( - self, command: str, prompt_file: str, compiled_content: str, compiled_path: str - ) -> str: - """Transform runtime commands to their proper execution format. - - Dispatches to per-runtime builders after extracting arguments - around the prompt file reference. 
- - Args: - command: Original command - prompt_file: Original .prompt.md file path - compiled_content: Compiled prompt content as string - compiled_path: Path to compiled .txt file - - Returns: - Transformed command for proper runtime execution - """ - # Handle environment variables prefix (e.g., "ENV1=val1 ENV2=val2 codex [args] file.prompt.md") - # More robust approach: split by runtime commands to separate env vars from command - runtime_commands = ["codex", "copilot", "llm", "gemini"] - - # Try matching with env-var prefix (e.g. "ENV=val codex args file.prompt.md") - for runtime_cmd in runtime_commands: - runtime_pattern = f" {runtime_cmd} " - if runtime_pattern in command and re.search(re.escape(prompt_file), command): - parts = command.split(runtime_pattern, 1) - potential_env_part = parts[0] - runtime_part = runtime_cmd + " " + parts[1] - - if "=" in potential_env_part and not potential_env_part.startswith(runtime_cmd): - result = self._parse_and_build_runtime_command( - runtime_cmd, - runtime_part, - prompt_file, - env_prefix=potential_env_part, - ) - if result is not None: - return result - - # Try individual runtime patterns without environment variables - for runtime_cmd in runtime_commands: - if re.search(r"^" + runtime_cmd + r"\s+.*" + re.escape(prompt_file), command): - result = self._parse_and_build_runtime_command( - runtime_cmd, - command, - prompt_file, - ) - if result is not None: - return result - - # Handle bare "file.prompt.md" -> "codex exec" (default to codex) - if command.strip() == prompt_file: - return "codex exec" - - # Fallback: just replace file path with compiled path (for non-runtime commands) - return command.replace(prompt_file, compiled_path) - - def _parse_and_build_runtime_command( - self, - runtime_cmd: str, - command_part: str, - prompt_file: str, - env_prefix: str = None, # noqa: RUF013 - ) -> str | None: - """Parse arguments around the prompt file and delegate to a per-runtime builder. 
- - Args: - runtime_cmd: Runtime name (codex, copilot, llm, or gemini) - command_part: The command portion containing the runtime invocation - prompt_file: The .prompt.md filename to strip - env_prefix: Optional environment variable prefix (e.g. "DEBUG=1") - - Returns: - Transformed command string, or None if the pattern does not match - """ - match = re.search( - f"{runtime_cmd}\\s+(.*?)(" + re.escape(prompt_file) + r")(.*?)$", - command_part, - ) - if not match: - return None - - args_before = match.group(1).strip() - args_after = match.group(3).strip() - - # In the env-var path, non-codex runtimes strip -p flags (matches - # original behaviour where copilot and llm shared an else branch). - if env_prefix is not None and runtime_cmd != "codex": - args_before = args_before.replace("-p", "").strip() - - builders = { - "codex": self._build_codex_command, - "copilot": self._build_copilot_command, - "llm": self._build_llm_command, - "gemini": self._build_gemini_command, - } - builder = builders.get(runtime_cmd) - if builder: - return builder(args_before, args_after, env_prefix) - return None - - def _build_codex_command( - self, - args_before: str, - args_after: str, - env_prefix: str | None = None, - ) -> str: - """Build a codex command from parsed arguments. - - Args: - args_before: Arguments that appeared before the prompt file - args_after: Arguments that appeared after the prompt file - env_prefix: Optional environment variable prefix - - Returns: - Assembled codex command string - """ - prefix = f"{env_prefix} " if env_prefix else "" - result = f"{prefix}codex exec" - if args_before: - result += f" {args_before}" - if args_after: - result += f" {args_after}" - return result - - def _build_copilot_command( - self, - args_before: str, - args_after: str, - env_prefix: str | None = None, - ) -> str: - """Build a copilot command from parsed arguments. - - Removes any existing -p flag since content is passed separately - during execution. 
- - Args: - args_before: Arguments that appeared before the prompt file - args_after: Arguments that appeared after the prompt file - env_prefix: Optional environment variable prefix - - Returns: - Assembled copilot command string - """ - prefix = f"{env_prefix} " if env_prefix else "" - result = f"{prefix}copilot" - if args_before: - # Remove any existing -p flag since we handle it in execution - cleaned_args = args_before.replace("-p", "").strip() - if cleaned_args: - result += f" {cleaned_args}" - if args_after: - result += f" {args_after}" - return result - - def _build_llm_command( - self, - args_before: str, - args_after: str, - env_prefix: str | None = None, - ) -> str: - """Build an llm command from parsed arguments. - - Args: - args_before: Arguments that appeared before the prompt file - args_after: Arguments that appeared after the prompt file - env_prefix: Optional environment variable prefix - - Returns: - Assembled llm command string - """ - prefix = f"{env_prefix} " if env_prefix else "" - result = f"{prefix}llm" - if args_before: - result += f" {args_before}" - if args_after: - result += f" {args_after}" - return result - - def _build_gemini_command( - self, - args_before: str, - args_after: str, - env_prefix: str | None = None, - ) -> str: - """Build a gemini command from parsed arguments. - - Args: - args_before: Arguments that appeared before the prompt file - args_after: Arguments that appeared after the prompt file - env_prefix: Optional environment variable prefix - - Returns: - Assembled gemini command string - """ - prefix = f"{env_prefix} " if env_prefix else "" - result = f"{prefix}gemini" - if args_before: - cleaned_args = re.sub(r"(^|\s)-p(?=\s|$)", "", args_before).strip() - if cleaned_args: - result += f" {cleaned_args}" - if args_after: - result += f" {args_after}" - return result - - def _detect_runtime(self, command: str) -> str: - """Detect which runtime is being used in the command. 
- - Args: - command: The command to analyze - - Returns: - Name of the detected runtime (copilot, codex, llm, gemini, or unknown) - """ - command_lower = command.lower().strip() - if re.search(r"(?:^|\s)copilot(?:\s|$)", command_lower): - return "copilot" - elif re.search(r"(?:^|\s)codex(?:\s|$)", command_lower): - return "codex" - elif re.search(r"(?:^|\s)llm(?:\s|$)", command_lower): - return "llm" - elif re.search(r"(?:^|\s)gemini(?:\s|$)", command_lower): - return "gemini" - else: - return "unknown" - def _execute_runtime_command( self, command: str, content: str, env: dict ) -> subprocess.CompletedProcess: @@ -908,228 +684,3 @@ def _create_minimal_config(self) -> None: dump_yaml(minimal_config, "apm.yml") print(f" [i] Created minimal apm.yml for zero-config execution") # noqa: F541 - - def _detect_installed_runtime(self) -> str: - """Detect installed runtime with priority order. - - Priority: copilot > codex > gemini > error - - Returns: - Name of detected runtime - - Raises: - RuntimeError: If no compatible runtime is found - """ - if find_runtime_binary("copilot"): - return "copilot" - elif find_runtime_binary("codex"): - return "codex" - elif find_runtime_binary("gemini"): - return "gemini" - else: - raise RuntimeError( - "No compatible runtime found.\n" - "Install GitHub Copilot CLI with:\n" - " apm runtime setup copilot\n" - "Or install Codex CLI with:\n" - " apm runtime setup codex\n" - "Or install Gemini CLI with:\n" - " apm runtime setup gemini" - ) - - def _generate_runtime_command(self, runtime: str, prompt_file: Path) -> str: - """Generate appropriate runtime command with proper defaults. 
- - Args: - runtime: Name of runtime (copilot or codex) - prompt_file: Path to the prompt file - - Returns: - Full command string with runtime-specific defaults - """ - if runtime == "copilot": - return ( - f"copilot --log-level all --log-dir copilot-logs --allow-all-tools -p {prompt_file}" - ) - elif runtime == "codex": - return f"codex -s workspace-write --skip-git-repo-check {prompt_file}" - elif runtime == "gemini": - return f"gemini -p {prompt_file}" - else: - raise ValueError(f"Unsupported runtime: {runtime}") - - -class PromptCompiler: - """Compiles .prompt.md files with parameter substitution.""" - - DEFAULT_COMPILED_DIR = Path(".apm/compiled") - - def __init__(self): - """Initialize compiler.""" - self.compiled_dir = self.DEFAULT_COMPILED_DIR - - def compile(self, prompt_file: str, params: dict[str, str]) -> str: - """Compile a .prompt.md file with parameter substitution. - - Args: - prompt_file: Path to the .prompt.md file - params: Parameters to substitute - - Returns: - Path to the compiled file - """ - # Resolve the prompt file path - check local first, then dependencies - prompt_path = self._resolve_prompt_file(prompt_file) - - # Now ensure compiled directory exists - self.compiled_dir.mkdir(parents=True, exist_ok=True) - - with open(prompt_path, encoding="utf-8") as f: - content = f.read() - - # Parse frontmatter and content - if content.startswith("---"): - # Split frontmatter and content - parts = content.split("---", 2) - if len(parts) >= 3: - frontmatter = parts[1].strip() # noqa: F841 - main_content = parts[2].strip() - else: - main_content = content - else: - main_content = content - - # Substitute parameters in content - compiled_content = self._substitute_parameters(main_content, params) - - # Generate output file path - output_name = prompt_path.stem.replace(".prompt", "") + ".txt" - output_path = self.compiled_dir / output_name - - # Write compiled content - with open(output_path, "w", encoding="utf-8") as f: - f.write(compiled_content) - - 
return str(output_path) - - def _resolve_prompt_file(self, prompt_file: str) -> Path: - """Resolve prompt file path, checking local directory first, then common directories, then dependencies. - - Symlinks are rejected outright to prevent traversal attacks. - - Args: - prompt_file: Relative path to the .prompt.md file - - Returns: - Path: Resolved path to the prompt file - - Raises: - FileNotFoundError: If prompt file is not found or is a symlink - """ - prompt_path = Path(prompt_file) - - # First check if it exists in current directory (local) - if prompt_path.exists(): - if prompt_path.is_symlink(): - raise FileNotFoundError( - f"Prompt file '{prompt_file}' is a symlink. " - f"Symlinks are not allowed for security reasons." - ) - return prompt_path - - # Check in common project directories - common_dirs = [".github/prompts", ".apm/prompts"] - for common_dir in common_dirs: - common_path = Path(common_dir) / prompt_file - if common_path.exists() and not common_path.is_symlink(): - return common_path - - # Search dependencies — scan directory tree once to avoid double walk - apm_modules_dir = Path("apm_modules") - dep_dirs = self._collect_dependency_dirs(apm_modules_dir) - - for _org_name, _repo_name, repo_dir in dep_dirs: - dep_prompt_path = repo_dir / prompt_file - if dep_prompt_path.exists() and not dep_prompt_path.is_symlink(): - return dep_prompt_path - - for subdir in ["prompts", ".", "workflows"]: - sub_prompt_path = repo_dir / subdir / prompt_file - if sub_prompt_path.exists() and not sub_prompt_path.is_symlink(): - return sub_prompt_path - - # Build error using already-collected directories (no second walk) - self._raise_prompt_not_found(prompt_file, prompt_path, dep_dirs) - - def _collect_dependency_dirs(self, apm_modules_dir: Path) -> list: - """Collect (org_name, repo_name, repo_dir) tuples from apm_modules. - - Walks the two-level directory tree once so callers can iterate - without repeated filesystem scans. 
- - Args: - apm_modules_dir: Path to the apm_modules directory - - Returns: - List of (org_name, repo_name, repo_dir) tuples - """ - if not apm_modules_dir.exists(): - return [] - result = [] - for org_dir in apm_modules_dir.iterdir(): - if org_dir.is_dir() and not org_dir.name.startswith("."): - for repo_dir in org_dir.iterdir(): - if repo_dir.is_dir() and not repo_dir.name.startswith("."): - result.append((org_dir.name, repo_dir.name, repo_dir)) - return result - - def _raise_prompt_not_found( - self, - prompt_file: str, - prompt_path: Path, - dep_dirs: list, - ) -> None: - """Build and raise a helpful FileNotFoundError for a missing prompt. - - Args: - prompt_file: Original prompt file reference - prompt_path: Local Path that was checked - dep_dirs: Pre-collected dependency directory tuples - - Raises: - FileNotFoundError: Always — with a message listing searched locations - """ - searched_locations = [ - f"Local: {prompt_path}", - f"GitHub prompts: .github/prompts/{prompt_file}", - f"APM prompts: .apm/prompts/{prompt_file}", - ] - - if dep_dirs: - searched_locations.append("Dependencies:") - for org_name, repo_name, _repo_dir in dep_dirs: - searched_locations.append(f" - {org_name}/{repo_name}/{prompt_file}") - - raise FileNotFoundError( - f"Prompt file '{prompt_file}' not found.\n" - f"Searched in:\n" - + "\n".join(searched_locations) - + f"\n\nTip: Run 'apm install' to ensure dependencies are installed." # noqa: F541 - ) - - def _substitute_parameters(self, content: str, params: dict[str, str]) -> str: - """Substitute parameters in content. 
- - Args: - content: Content to process - params: Parameters to substitute - - Returns: - Content with parameters substituted - """ - result = content - for key, value in params.items(): - # Replace ${input:key} placeholders - placeholder = f"${{input:{key}}}" - result = result.replace(placeholder, str(value)) - return result diff --git a/src/apm_cli/core/target_detection.py b/src/apm_cli/core/target_detection.py index 617addfc7..3b5c345d6 100644 --- a/src/apm_cli/core/target_detection.py +++ b/src/apm_cli/core/target_detection.py @@ -103,7 +103,63 @@ def agents_alias_was_detected() -> bool: ] -def detect_target( # noqa: PLR0911 +# Maps an explicit/config target name to its internal canonical TargetType. +# copilot/vscode/agents are aliases that all collapse to "vscode". "minimal" +# is intentionally absent: an explicit/config "minimal" falls through to +# folder auto-detection, matching the historical behaviour of the priority +# chains this table replaced. +_NAME_TO_CANONICAL: dict[str, TargetType] = { + "copilot": "vscode", + "vscode": "vscode", + "agents": "vscode", + "claude": "claude", + "cursor": "cursor", + "opencode": "opencode", + "codex": "codex", + "gemini": "gemini", + "windsurf": "windsurf", + "agent-skills": "agent-skills", + "all": "all", +} + +# Ordered folder probes for Priority-3 auto-detection. Each entry pairs an +# existence test with the label used in the reason string and the canonical +# target a lone match resolves to. ".github"/".claude" use exists() (file or +# dir) for backwards compatibility; the newer integrations require a real dir. +_FOLDER_PROBES: tuple[tuple[str, bool, str, TargetType], ...] 
= ( + (".github", False, ".github/", "vscode"), + (".claude", False, ".claude/", "claude"), + (".cursor", True, ".cursor/", "cursor"), + (".opencode", True, ".opencode/", "opencode"), + (".codex", True, ".codex/", "codex"), + (".gemini", True, ".gemini/", "gemini"), + (".windsurf", True, ".windsurf/", "windsurf"), +) + + +def _detect_from_folders(project_root: Path) -> tuple[TargetType, str]: + """Resolve a target from the integration folders present under *project_root*. + + Two or more folders -> ``"all"``; exactly one -> that folder's target; + none -> ``"minimal"``. Folder order in :data:`_FOLDER_PROBES` is the + tie-break that never triggers for the single-match case but fixes the + label ordering in the multi-match reason string. + """ + detected = [ + (label, canonical) + for name, require_dir, label, canonical in _FOLDER_PROBES + if ((project_root / name).is_dir() if require_dir else (project_root / name).exists()) + ] + if len(detected) >= 2: + labels = " and ".join(label for label, _ in detected) + return "all", f"detected {labels} folders" + if detected: + label, canonical = detected[0] + return canonical, f"detected {label} folder" + return "minimal", REASON_NO_TARGET_FOLDER + + +def detect_target( project_root: Path, explicit_target: str | None = None, config_target: str | None = None, @@ -120,90 +176,27 @@ def detect_target( # noqa: PLR0911 - target: The detected target type - reason: Human-readable explanation for the choice """ - # Priority 1: Explicit --target flag + # Normalise: callers may pass a list when apm.yml has ``targets: [...]``. + # _NAME_TO_CANONICAL requires a hashable key; extract the first element. + if isinstance(explicit_target, list): + explicit_target = explicit_target[0] if explicit_target else None + if isinstance(config_target, list): + config_target = config_target[0] if config_target else None + + # Priority 1: explicit --target flag (always wins when recognised). 
if explicit_target: - if explicit_target in ("copilot", "vscode", "agents"): - return "vscode", "explicit --target flag" - elif explicit_target == "claude": - return "claude", "explicit --target flag" - elif explicit_target == "cursor": - return "cursor", "explicit --target flag" - elif explicit_target == "opencode": - return "opencode", "explicit --target flag" - elif explicit_target == "codex": - return "codex", "explicit --target flag" - elif explicit_target == "gemini": - return "gemini", "explicit --target flag" - elif explicit_target == "windsurf": - return "windsurf", "explicit --target flag" - elif explicit_target == "agent-skills": - return "agent-skills", "explicit --target flag" - elif explicit_target == "all": - return "all", "explicit --target flag" - - # Priority 2: apm.yml target setting + canonical = _NAME_TO_CANONICAL.get(explicit_target) + if canonical is not None: + return canonical, "explicit --target flag" + + # Priority 2: apm.yml target setting. if config_target: - if config_target in ("copilot", "vscode", "agents"): - return "vscode", "apm.yml target" - elif config_target == "claude": - return "claude", "apm.yml target" - elif config_target == "cursor": - return "cursor", "apm.yml target" - elif config_target == "opencode": - return "opencode", "apm.yml target" - elif config_target == "codex": - return "codex", "apm.yml target" - elif config_target == "gemini": - return "gemini", "apm.yml target" - elif config_target == "windsurf": - return "windsurf", "apm.yml target" - elif config_target == "agent-skills": - return "agent-skills", "apm.yml target" - elif config_target == "all": - return "all", "apm.yml target" - - # Priority 3: Auto-detect from existing folders - github_exists = (project_root / ".github").exists() - claude_exists = (project_root / ".claude").exists() - cursor_exists = (project_root / ".cursor").is_dir() - opencode_exists = (project_root / ".opencode").is_dir() - codex_exists = (project_root / ".codex").is_dir() - 
gemini_exists = (project_root / ".gemini").is_dir() - windsurf_exists = (project_root / ".windsurf").is_dir() - detected = [] - if github_exists: - detected.append(".github/") - if claude_exists: - detected.append(".claude/") - if cursor_exists: - detected.append(".cursor/") - if opencode_exists: - detected.append(".opencode/") - if codex_exists: - detected.append(".codex/") - if gemini_exists: - detected.append(".gemini/") - if windsurf_exists: - detected.append(".windsurf/") + canonical = _NAME_TO_CANONICAL.get(config_target) + if canonical is not None: + return canonical, "apm.yml target" - if len(detected) >= 2: - return "all", f"detected {' and '.join(detected)} folders" - elif github_exists: - return "vscode", "detected .github/ folder" - elif claude_exists: - return "claude", "detected .claude/ folder" - elif cursor_exists: - return "cursor", "detected .cursor/ folder" - elif opencode_exists: - return "opencode", "detected .opencode/ folder" - elif codex_exists: - return "codex", "detected .codex/ folder" - elif gemini_exists: - return "gemini", "detected .gemini/ folder" - elif windsurf_exists: - return "windsurf", "detected .windsurf/ folder" - else: - return "minimal", REASON_NO_TARGET_FOLDER + # Priority 3: auto-detect from existing integration folders. + return _detect_from_folders(project_root) def should_compile_agents_md(target: CompileTargetType) -> bool: diff --git a/src/apm_cli/output/_formatters_detail.py b/src/apm_cli/output/_formatters_detail.py new file mode 100644 index 000000000..f4de4df0e --- /dev/null +++ b/src/apm_cli/output/_formatters_detail.py @@ -0,0 +1,465 @@ +"""Heavy detail-rendering mixin for CompilationFormatter. + +Extracted from formatters.py to keep that module under 800 lines. +``CompilationFormatter`` composes this mixin in so all method names +remain importable/patchable at their original paths. 
+ +Rule B: moved methods that check ``RICH_AVAILABLE`` fetch it via a +function-level late import from the parent module so tests patching +``apm_cli.output.formatters.RICH_AVAILABLE`` are correctly intercepted. +``Path`` is not used in any of the moved methods, so no Rule B routing +is needed for that name. +""" + +try: + from rich import box + from rich.panel import Panel + from rich.table import Table + from rich.text import Text +except ImportError: + box = None # type: ignore[assignment] + Panel = None # type: ignore[assignment] + Table = None # type: ignore[assignment] + Text = None # type: ignore[assignment] + + +class _FormattersDetailMixin: + """Heavy detail renderers composed into CompilationFormatter. + + Accesses ``self.use_color``, ``self.console``, and ``self._styled`` + which are defined on ``CompilationFormatter``. + """ + + def _format_mathematical_analysis(self, decisions) -> list: + """Format mathematical analysis for verbose mode with coverage-first principles.""" + # Rule B: fetch RICH_AVAILABLE from the parent module at call time + # so tests patching apm_cli.output.formatters.RICH_AVAILABLE work. 
+ from apm_cli.output import formatters as _f + + RICH_AVAILABLE = _f.RICH_AVAILABLE + lines = [] + + if self.use_color: + lines.append(self._styled("Mathematical Optimization Analysis", "cyan bold")) + else: + lines.append("Mathematical Optimization Analysis") + + lines.append("") + + if self.use_color and RICH_AVAILABLE: + # Coverage-First Strategy Table + strategy_table = Table( + title="Three-Tier Coverage-First Strategy", + show_header=True, + header_style="bold cyan", + box=box.SIMPLE_HEAD, + ) + strategy_table.add_column("Pattern", style="white", width=25) + strategy_table.add_column("Source", style="yellow", width=15) + strategy_table.add_column("Distribution", style="yellow", width=12) + strategy_table.add_column("Strategy", style="green", width=15) + strategy_table.add_column("Coverage Guarantee", style="blue", width=20) + + for decision in decisions: + pattern = decision.pattern if decision.pattern else "(global)" + + # Extract source information + source_display = "unknown" + if decision.instruction and hasattr(decision.instruction, "file_path"): + try: + source_display = decision.instruction.file_path.name + except Exception: + source_display = "unknown" + + # Distribution score with threshold classification + score = decision.distribution_score + if score < 0.3: + dist_display = f"{score:.3f} (Low)" + strategy_name = "Single Point" + coverage_status = "[+] Perfect" + elif score > 0.7: + dist_display = f"{score:.3f} (High)" + strategy_name = "Distributed" + coverage_status = "[+] Universal" + else: + dist_display = f"{score:.3f} (Medium)" + strategy_name = "Selective Multi" + # Check if root placement was used (indicates coverage fallback) + if any(str(p) == "." or p.name == "" for p in decision.placement_directories): + coverage_status = "[!] 
Root Fallback" + else: + coverage_status = "[+] Verified" + + strategy_table.add_row( + pattern, source_display, dist_display, strategy_name, coverage_status + ) + + # Render strategy table + if self.console: + with self.console.capture() as capture: + self.console.print(strategy_table) + table_output = capture.get() + if table_output.strip(): + lines.extend(table_output.split("\n")) + + lines.append("") + + # Hierarchical Coverage Analysis Table + coverage_table = Table( + title="Hierarchical Coverage Analysis", + show_header=True, + header_style="bold cyan", + box=box.SIMPLE_HEAD, + ) + coverage_table.add_column("Pattern", style="white", width=25) + coverage_table.add_column("Matching Files", style="yellow", width=15) + coverage_table.add_column("Placement", style="green", width=20) + coverage_table.add_column("Coverage Result", style="blue", width=25) + + for decision in decisions: + pattern = decision.pattern if decision.pattern else "(global)" + matching_files = f"{decision.matching_directories} dirs" + + if len(decision.placement_directories) == 1: + placement = self._get_relative_display_path(decision.placement_directories[0]) + + # Analyze coverage outcome + if str(decision.placement_directories[0]).endswith("."): + coverage_result = "Root -> All files inherit" + elif decision.distribution_score < 0.3: + coverage_result = "Local -> Perfect efficiency" + else: + coverage_result = "Selective -> Coverage verified" + else: + placement = f"{len(decision.placement_directories)} locations" + coverage_result = "Multi-point -> Full coverage" + + coverage_table.add_row(pattern, matching_files, placement, coverage_result) + + # Render coverage table + if self.console: + with self.console.capture() as capture: + self.console.print(coverage_table) + table_output = capture.get() + if table_output.strip(): + lines.extend(table_output.split("\n")) + + lines.append("") + + # Updated Mathematical Foundation Panel + foundation_text = """Objective: minimize 
sum(context_pollution x directory_weight) +Constraints: for_allfile_matching_pattern -> can_inherit_instruction +Variables: placement_matrix in {0,1} +Algorithm: Three-tier strategy with hierarchical coverage verification + +Coverage Guarantee: Every file can access applicable instructions through +hierarchical inheritance. Coverage takes priority over efficiency.""" + + if self.console: + from rich.panel import Panel as _Panel + + try: + panel = _Panel( + foundation_text, + title="Coverage-Constrained Optimization", + border_style="cyan", + ) + with self.console.capture() as capture: + self.console.print(panel) + panel_output = capture.get() + if panel_output.strip(): + lines.extend(panel_output.split("\n")) + except Exception: + # Fallback to simple text + lines.append("Coverage-Constrained Optimization:") + for line in foundation_text.split("\n"): + lines.append(f" {line}") + + else: + # Fallback for non-Rich environments + lines.append("Coverage-First Strategy Analysis:") + for decision in decisions: + pattern = decision.pattern if decision.pattern else "(global)" + score = f"{decision.distribution_score:.3f}" + strategy = decision.strategy.value + coverage = ( + "[+] Verified" if decision.distribution_score < 0.7 else "[!] Root Fallback" + ) + lines.append(f" {pattern:<30} {score:<8} {strategy:<15} {coverage}") + + lines.append("") + lines.append("Mathematical Foundation:") + lines.append(" Objective: minimize sum(context_pollution x directory_weight)") + lines.append(" Constraints: for_allfile_matching_pattern -> can_inherit_instruction") + lines.append(" Algorithm: Three-tier strategy with coverage verification") + lines.append(" Principle: Coverage guarantee takes priority over efficiency") + + return lines + + def _format_detailed_metrics(self, stats) -> list: + """Format detailed performance metrics table with interpretations.""" + # Rule B: fetch RICH_AVAILABLE from the parent module at call time. 
+ from apm_cli.output import formatters as _f + + RICH_AVAILABLE = _f.RICH_AVAILABLE + lines = [] + + if self.use_color: + lines.append(self._styled("Performance Metrics", "cyan bold")) + else: + lines.append("Performance Metrics") + + # Create metrics table + if self.use_color and RICH_AVAILABLE: + table = Table(box=box.SIMPLE) + table.add_column("Metric", style="white", width=20) + table.add_column("Value", style="white", width=12) + table.add_column("Assessment", style="blue", width=35) + + # Context Efficiency with coverage-first interpretation + efficiency = stats.efficiency_percentage + if efficiency >= 80: + assessment = "Excellent - perfect pattern locality" + assessment_color = "bright_green" + value_color = "bright_green" + elif efficiency >= 60: + assessment = "Good - well-optimized with minimal coverage conflicts" + assessment_color = "green" + value_color = "green" + elif efficiency >= 40: + assessment = "Fair - moderate coverage-driven pollution" + assessment_color = "yellow" + value_color = "yellow" + elif efficiency >= 20: + assessment = "Poor - significant coverage constraints" + assessment_color = "orange1" + value_color = "orange1" + else: + assessment = "Very Poor - may be mathematically optimal given coverage" + assessment_color = "red" + value_color = "red" + + table.add_row( + "Context Efficiency", + Text(f"{efficiency:.1f}%", style=value_color), + Text(assessment, style=assessment_color), + ) + + # Calculate pollution level with coverage-aware interpretation + pollution_level = 100 - efficiency + if pollution_level <= 20: + pollution_assessment = "Excellent - perfect pattern locality" + pollution_color = "bright_green" + elif pollution_level <= 40: + pollution_assessment = "Good - minimal coverage conflicts" + pollution_color = "green" + elif pollution_level <= 60: + pollution_assessment = "Fair - acceptable coverage-driven pollution" + pollution_color = "yellow" + elif pollution_level <= 80: + pollution_assessment = "Poor - high coverage 
constraints" + pollution_color = "orange1" + else: + pollution_assessment = "Very Poor - but may guarantee coverage" + pollution_color = "red" + + table.add_row( + "Pollution Level", + Text(f"{pollution_level:.1f}%", style=pollution_color), + Text(pollution_assessment, style=pollution_color), + ) + + if stats.placement_accuracy: + accuracy = stats.placement_accuracy * 100 + if accuracy >= 95: + accuracy_assessment = "Excellent - mathematically optimal" + accuracy_color = "bright_green" + elif accuracy >= 85: + accuracy_assessment = "Good - near optimal" + accuracy_color = "green" + elif accuracy >= 70: + accuracy_assessment = "Fair - reasonably placed" + accuracy_color = "yellow" + else: + accuracy_assessment = "Poor - suboptimal placement" + accuracy_color = "orange1" + + table.add_row( + "Placement Accuracy", + Text(f"{accuracy:.1f}%", style=accuracy_color), + Text(accuracy_assessment, style=accuracy_color), + ) + + # Render table + if self.console: + with self.console.capture() as capture: + self.console.print(table) + table_output = capture.get() + if table_output.strip(): + lines.extend(table_output.split("\n")) + + lines.append("") + + # Add interpretation guide + if self.console: + try: + interpretation_text = """How These Metrics Are Calculated + +Context Efficiency = Average across all directories of (Relevant Instructions / Total Instructions) +* For each directory, APM analyzes what instructions agents would inherit from AGENTS.md files +* Calculates ratio of instructions that apply to files in that directory vs total instructions loaded +* Takes weighted average across all project directories with files + +Pollution Level = 100% - Context Efficiency (inverse relationship) +* High pollution = agents load many irrelevant instructions when working in specific directories +* Low pollution = agents see mostly relevant instructions for their current context + +Interpretation Benchmarks + +Context Efficiency: +* 80-100%: Excellent - Instructions perfectly 
targeted to usage context +* 60-80%: Good - Well-optimized with minimal wasted context +* 40-60%: Fair - Some optimization opportunities exist +* 20-40%: Poor - Significant context pollution, consider restructuring +* 0-20%: Very Poor - High pollution, instructions poorly distributed + +Pollution Level: +* 0-10%: Excellent - Agents see highly relevant instructions only +* 10-25%: Good - Low noise, mostly relevant context +* 25-50%: Fair - Moderate noise, some irrelevant instructions +* 50%+: Poor - High noise, agents see many irrelevant instructions + +Example: 36.7% efficiency means agents working in specific directories see only 36.7% relevant instructions and 63.3% irrelevant context pollution.""" + + panel = Panel( + interpretation_text, + title="Metrics Guide", + border_style="dim", + title_align="left", + ) + with self.console.capture() as capture: + self.console.print(panel) + panel_output = capture.get() + if panel_output.strip(): + lines.extend(panel_output.split("\n")) + except Exception: + # Fallback to simple text + lines.extend( + [ + "Metrics Guide:", + "* Context Efficiency 80-100%: Excellent | 60-80%: Good | 40-60%: Fair | <40%: Poor", + "* Pollution 0-10%: Excellent | 10-25%: Good | 25-50%: Fair | >50%: Poor", + ] + ) + else: + # Fallback for non-Rich environments + efficiency = stats.efficiency_percentage + pollution = 100 - efficiency + + if efficiency >= 80: + efficiency_assessment = "Excellent" + elif efficiency >= 60: + efficiency_assessment = "Good" + elif efficiency >= 40: + efficiency_assessment = "Fair" + elif efficiency >= 20: + efficiency_assessment = "Poor" + else: + efficiency_assessment = "Very Poor" + + if pollution <= 10: + pollution_assessment = "Excellent" + elif pollution <= 25: + pollution_assessment = "Good" + elif pollution <= 50: + pollution_assessment = "Fair" + else: + pollution_assessment = "Poor" + + lines.extend( + [ + f"Context Efficiency: {efficiency:.1f}% ({efficiency_assessment})", + f"Pollution Level: 
{pollution:.1f}% ({pollution_assessment})", + "Guide: 80-100% Excellent | 60-80% Good | 40-60% Fair | 20-40% Poor | <20% Very Poor", + ] + ) + + return lines + + def _format_coverage_explanation(self, stats) -> list: + """Explain the coverage vs. efficiency trade-off.""" + lines = [] + + if self.use_color: + lines.append(self._styled("Coverage vs. Efficiency Analysis", "cyan bold")) + else: + lines.append("Coverage vs. Efficiency Analysis") + + lines.append("") + + efficiency = stats.efficiency_percentage + + if efficiency < 30: + lines.append("[!] Low Efficiency Detected:") + lines.append(" * Coverage guarantee requires some instructions at root level") + lines.append(" * This creates pollution for specialized directories") + lines.append(" * Trade-off: Guaranteed coverage vs. optimal efficiency") + lines.append(" * Alternative: Higher efficiency with coverage violations (data loss)") + lines.append("") + lines.append("This may be mathematically optimal given coverage constraints") + elif efficiency < 60: + lines.append("[+] Moderate Efficiency:") + lines.append(" * Good balance between coverage and efficiency") + lines.append(" * Some coverage-driven pollution is acceptable") + lines.append(" * Most patterns are well-localized") + else: + lines.append("High Efficiency:") + lines.append(" * Excellent pattern locality achieved") + lines.append(" * Minimal coverage conflicts") + lines.append(" * Instructions are optimally placed") + + lines.append("") + lines.append("Why Coverage Takes Priority:") + lines.append(" * Every file must access applicable instructions") + lines.append(" * Hierarchical inheritance prevents data loss") + lines.append(" * Better low efficiency than missing instructions") + + return lines + + def _format_issues(self, warnings: list, errors: list) -> list: + """Format warnings and errors as professional blocks.""" + lines = [] + + # Errors first + for error in errors: + if self.use_color: + lines.append(self._styled(f"x Error: {error}", 
"red")) + else: + lines.append(f"x Error: {error}") + + # Then warnings - handle multi-line warnings as cohesive blocks + for warning in warnings: + if "\n" in warning: + # Multi-line warning - format as a professional block + warning_lines = warning.split("\n") + # First line gets the warning symbol and styling + if self.use_color: + lines.append(self._styled(f"[!] Warning: {warning_lines[0]}", "yellow")) + else: + lines.append(f"[!] Warning: {warning_lines[0]}") + + # Subsequent lines are indented and styled consistently + for line in warning_lines[1:]: + if line.strip(): # Skip empty lines + if self.use_color: + lines.append(self._styled(f" {line}", "yellow")) + else: + lines.append(f" {line}") + else: # noqa: PLR5501 + # Single-line warning - standard format + if self.use_color: + lines.append(self._styled(f"[!] Warning: {warning}", "yellow")) + else: + lines.append(f"[!] Warning: {warning}") + + return lines diff --git a/src/apm_cli/output/formatters.py b/src/apm_cli/output/formatters.py index 4f6f2428d..e56809ca4 100644 --- a/src/apm_cli/output/formatters.py +++ b/src/apm_cli/output/formatters.py @@ -5,7 +5,6 @@ try: from rich import box from rich.console import Console - from rich.panel import Panel from rich.table import Table from rich.text import Text @@ -13,10 +12,11 @@ except ImportError: RICH_AVAILABLE = False +from ._formatters_detail import _FormattersDetailMixin from .models import CompilationResults, OptimizationDecision, PlacementStrategy -class CompilationFormatter: +class CompilationFormatter(_FormattersDetailMixin): """Professional formatter for compilation output with fallback for no-rich environments.""" def __init__(self, use_color: bool = True): @@ -100,7 +100,7 @@ def format_verbose(self, results: CompilationResults) -> str: lines.append("") # Phase 6: Final Summary (Generated X files + placement distribution) - lines.extend(self._format_final_summary(results)) + lines.extend(self._format_results_summary(results)) # Issues 
(warnings/errors) if results.has_issues: @@ -109,91 +109,6 @@ def format_verbose(self, results: CompilationResults) -> str: return "\n".join(lines) - def _format_final_summary(self, results: CompilationResults) -> list[str]: - """Format final summary for verbose mode: Generated files + placement distribution.""" - lines = [] - - # Main result - file_count = len(results.placement_summaries) - target = results.target_name - summary_line = f"Generated {file_count} {target} file{'s' if file_count != 1 else ''}" - - if results.is_dry_run: - summary_line = f"[DRY RUN] Would generate {file_count} {target} file{'s' if file_count != 1 else ''}" - - if self.use_color: - color = "yellow" if results.is_dry_run else "green" - lines.append(self._styled(summary_line, f"{color} bold")) - else: - lines.append(summary_line) - - # Efficiency metrics with improved formatting - stats = results.optimization_stats - efficiency_pct = f"{stats.efficiency_percentage:.1f}%" - - # Build metrics with baselines and improvements when available - metrics_lines = [f"+- Context efficiency: {efficiency_pct}"] - - if stats.efficiency_improvement is not None: - improvement = ( - f"(baseline: {stats.baseline_efficiency * 100:.1f}%, improvement: +{stats.efficiency_improvement:.0f}%)" - if stats.efficiency_improvement > 0 - else f"(baseline: {stats.baseline_efficiency * 100:.1f}%, change: {stats.efficiency_improvement:.0f}%)" - ) - metrics_lines[0] += f" {improvement}" - - if stats.pollution_improvement is not None: - pollution_pct = f"{(1.0 - stats.pollution_improvement) * 100:.1f}%" - improvement_pct = ( - f"-{stats.pollution_improvement * 100:.0f}%" - if stats.pollution_improvement > 0 - else f"+{abs(stats.pollution_improvement) * 100:.0f}%" - ) - metrics_lines.append( - f"|- Average pollution: {pollution_pct} (improvement: {improvement_pct})" - ) - - if stats.placement_accuracy is not None: - accuracy_pct = f"{stats.placement_accuracy * 100:.1f}%" - metrics_lines.append(f"|- Placement accuracy: 
{accuracy_pct} (mathematical optimum)") - - if stats.generation_time_ms is not None: - metrics_lines.append(f"+- Generation time: {stats.generation_time_ms}ms") - else: # noqa: PLR5501 - # Change last |- to +- - if len(metrics_lines) > 1: - metrics_lines[-1] = metrics_lines[-1].replace("|-", "+-") - - for line in metrics_lines: - if self.use_color: - lines.append(self._styled(line, "dim")) - else: - lines.append(line) - - # Add placement distribution summary - lines.append("") - if self.use_color: - lines.append(self._styled("Placement Distribution", "cyan bold")) - else: - lines.append("Placement Distribution") - - # Show distribution of files - for summary in results.placement_summaries: - rel_path = str(summary.get_relative_path(Path.cwd())) - content_text = self._get_placement_description(summary) - source_text = f"{summary.source_count} source{'s' if summary.source_count != 1 else ''}" - - # Use proper tree formatting - prefix = "|-" if summary != results.placement_summaries[-1] else "+-" - line = f"{prefix} {rel_path:<30} {content_text} from {source_text}" - - if self.use_color: - lines.append(self._styled(line, "dim")) - else: - lines.append(line) - - return lines - def format_dry_run(self, results: CompilationResults) -> str: """Format dry run output. 
@@ -503,391 +418,6 @@ def _format_dry_run_summary(self, results: CompilationResults) -> list[str]: return lines - def _format_mathematical_analysis(self, decisions: list[OptimizationDecision]) -> list[str]: - """Format mathematical analysis for verbose mode with coverage-first principles.""" - lines = [] - - if self.use_color: - lines.append(self._styled("Mathematical Optimization Analysis", "cyan bold")) - else: - lines.append("Mathematical Optimization Analysis") - - lines.append("") - - if self.use_color and RICH_AVAILABLE: - # Coverage-First Strategy Table - strategy_table = Table( - title="Three-Tier Coverage-First Strategy", - show_header=True, - header_style="bold cyan", - box=box.SIMPLE_HEAD, - ) - strategy_table.add_column("Pattern", style="white", width=25) - strategy_table.add_column("Source", style="yellow", width=15) - strategy_table.add_column("Distribution", style="yellow", width=12) - strategy_table.add_column("Strategy", style="green", width=15) - strategy_table.add_column("Coverage Guarantee", style="blue", width=20) - - for decision in decisions: - pattern = decision.pattern if decision.pattern else "(global)" - - # Extract source information - source_display = "unknown" - if decision.instruction and hasattr(decision.instruction, "file_path"): - try: - source_display = decision.instruction.file_path.name - except Exception: - source_display = "unknown" - - # Distribution score with threshold classification - score = decision.distribution_score - if score < 0.3: - dist_display = f"{score:.3f} (Low)" - strategy_name = "Single Point" - coverage_status = "[+] Perfect" - elif score > 0.7: - dist_display = f"{score:.3f} (High)" - strategy_name = "Distributed" - coverage_status = "[+] Universal" - else: - dist_display = f"{score:.3f} (Medium)" - strategy_name = "Selective Multi" - # Check if root placement was used (indicates coverage fallback) - if any(str(p) == "." or p.name == "" for p in decision.placement_directories): - coverage_status = "[!] 
Root Fallback" - else: - coverage_status = "[+] Verified" - - strategy_table.add_row( - pattern, source_display, dist_display, strategy_name, coverage_status - ) - - # Render strategy table - if self.console: - with self.console.capture() as capture: - self.console.print(strategy_table) - table_output = capture.get() - if table_output.strip(): - lines.extend(table_output.split("\n")) - - lines.append("") - - # Hierarchical Coverage Analysis Table - coverage_table = Table( - title="Hierarchical Coverage Analysis", - show_header=True, - header_style="bold cyan", - box=box.SIMPLE_HEAD, - ) - coverage_table.add_column("Pattern", style="white", width=25) - coverage_table.add_column("Matching Files", style="yellow", width=15) - coverage_table.add_column("Placement", style="green", width=20) - coverage_table.add_column("Coverage Result", style="blue", width=25) - - for decision in decisions: - pattern = decision.pattern if decision.pattern else "(global)" - matching_files = f"{decision.matching_directories} dirs" - - if len(decision.placement_directories) == 1: - placement = self._get_relative_display_path(decision.placement_directories[0]) - - # Analyze coverage outcome - if str(decision.placement_directories[0]).endswith("."): - coverage_result = "Root -> All files inherit" - elif decision.distribution_score < 0.3: - coverage_result = "Local -> Perfect efficiency" - else: - coverage_result = "Selective -> Coverage verified" - else: - placement = f"{len(decision.placement_directories)} locations" - coverage_result = "Multi-point -> Full coverage" - - coverage_table.add_row(pattern, matching_files, placement, coverage_result) - - # Render coverage table - if self.console: - with self.console.capture() as capture: - self.console.print(coverage_table) - table_output = capture.get() - if table_output.strip(): - lines.extend(table_output.split("\n")) - - lines.append("") - - # Updated Mathematical Foundation Panel - foundation_text = """Objective: minimize 
sum(context_pollution x directory_weight) -Constraints: for_allfile_matching_pattern -> can_inherit_instruction -Variables: placement_matrix in {0,1} -Algorithm: Three-tier strategy with hierarchical coverage verification - -Coverage Guarantee: Every file can access applicable instructions through -hierarchical inheritance. Coverage takes priority over efficiency.""" - - if self.console: - from rich.panel import Panel - - try: - panel = Panel( - foundation_text, - title="Coverage-Constrained Optimization", - border_style="cyan", - ) - with self.console.capture() as capture: - self.console.print(panel) - panel_output = capture.get() - if panel_output.strip(): - lines.extend(panel_output.split("\n")) - except Exception: - # Fallback to simple text - lines.append("Coverage-Constrained Optimization:") - for line in foundation_text.split("\n"): - lines.append(f" {line}") - - else: - # Fallback for non-Rich environments - lines.append("Coverage-First Strategy Analysis:") - for decision in decisions: - pattern = decision.pattern if decision.pattern else "(global)" - score = f"{decision.distribution_score:.3f}" - strategy = decision.strategy.value - coverage = ( - "[+] Verified" if decision.distribution_score < 0.7 else "[!] 
Root Fallback" - ) - lines.append(f" {pattern:<30} {score:<8} {strategy:<15} {coverage}") - - lines.append("") - lines.append("Mathematical Foundation:") - lines.append(" Objective: minimize sum(context_pollution x directory_weight)") - lines.append(" Constraints: for_allfile_matching_pattern -> can_inherit_instruction") - lines.append(" Algorithm: Three-tier strategy with coverage verification") - lines.append(" Principle: Coverage guarantee takes priority over efficiency") - - return lines - - def _format_detailed_metrics(self, stats) -> list[str]: - """Format detailed performance metrics table with interpretations.""" - lines = [] - - if self.use_color: - lines.append(self._styled("Performance Metrics", "cyan bold")) - else: - lines.append("Performance Metrics") - - # Create metrics table - if self.use_color and RICH_AVAILABLE: - table = Table(box=box.SIMPLE) - table.add_column("Metric", style="white", width=20) - table.add_column("Value", style="white", width=12) - table.add_column("Assessment", style="blue", width=35) - - # Context Efficiency with coverage-first interpretation - efficiency = stats.efficiency_percentage - if efficiency >= 80: - assessment = "Excellent - perfect pattern locality" - assessment_color = "bright_green" - value_color = "bright_green" - elif efficiency >= 60: - assessment = "Good - well-optimized with minimal coverage conflicts" - assessment_color = "green" - value_color = "green" - elif efficiency >= 40: - assessment = "Fair - moderate coverage-driven pollution" - assessment_color = "yellow" - value_color = "yellow" - elif efficiency >= 20: - assessment = "Poor - significant coverage constraints" - assessment_color = "orange1" - value_color = "orange1" - else: - assessment = "Very Poor - may be mathematically optimal given coverage" - assessment_color = "red" - value_color = "red" - - table.add_row( - "Context Efficiency", - Text(f"{efficiency:.1f}%", style=value_color), - Text(assessment, style=assessment_color), - ) - - # Calculate 
pollution level with coverage-aware interpretation - pollution_level = 100 - efficiency - if pollution_level <= 20: - pollution_assessment = "Excellent - perfect pattern locality" - pollution_color = "bright_green" - elif pollution_level <= 40: - pollution_assessment = "Good - minimal coverage conflicts" - pollution_color = "green" - elif pollution_level <= 60: - pollution_assessment = "Fair - acceptable coverage-driven pollution" - pollution_color = "yellow" - elif pollution_level <= 80: - pollution_assessment = "Poor - high coverage constraints" - pollution_color = "orange1" - else: - pollution_assessment = "Very Poor - but may guarantee coverage" - pollution_color = "red" - - table.add_row( - "Pollution Level", - Text(f"{pollution_level:.1f}%", style=pollution_color), - Text(pollution_assessment, style=pollution_color), - ) - - if stats.placement_accuracy: - accuracy = stats.placement_accuracy * 100 - if accuracy >= 95: - accuracy_assessment = "Excellent - mathematically optimal" - accuracy_color = "bright_green" - elif accuracy >= 85: - accuracy_assessment = "Good - near optimal" - accuracy_color = "green" - elif accuracy >= 70: - accuracy_assessment = "Fair - reasonably placed" - accuracy_color = "yellow" - else: - accuracy_assessment = "Poor - suboptimal placement" - accuracy_color = "orange1" - - table.add_row( - "Placement Accuracy", - Text(f"{accuracy:.1f}%", style=accuracy_color), - Text(accuracy_assessment, style=accuracy_color), - ) - - # Render table - if self.console: - with self.console.capture() as capture: - self.console.print(table) - table_output = capture.get() - if table_output.strip(): - lines.extend(table_output.split("\n")) - - lines.append("") - - # Add interpretation guide - if self.console: - try: - interpretation_text = """How These Metrics Are Calculated - -Context Efficiency = Average across all directories of (Relevant Instructions / Total Instructions) -* For each directory, APM analyzes what instructions agents would inherit from 
AGENTS.md files -* Calculates ratio of instructions that apply to files in that directory vs total instructions loaded -* Takes weighted average across all project directories with files - -Pollution Level = 100% - Context Efficiency (inverse relationship) -* High pollution = agents load many irrelevant instructions when working in specific directories -* Low pollution = agents see mostly relevant instructions for their current context - -Interpretation Benchmarks - -Context Efficiency: -* 80-100%: Excellent - Instructions perfectly targeted to usage context -* 60-80%: Good - Well-optimized with minimal wasted context -* 40-60%: Fair - Some optimization opportunities exist -* 20-40%: Poor - Significant context pollution, consider restructuring -* 0-20%: Very Poor - High pollution, instructions poorly distributed - -Pollution Level: -* 0-10%: Excellent - Agents see highly relevant instructions only -* 10-25%: Good - Low noise, mostly relevant context -* 25-50%: Fair - Moderate noise, some irrelevant instructions -* 50%+: Poor - High noise, agents see many irrelevant instructions - -Example: 36.7% efficiency means agents working in specific directories see only 36.7% relevant instructions and 63.3% irrelevant context pollution.""" - - panel = Panel( - interpretation_text, - title="Metrics Guide", - border_style="dim", - title_align="left", - ) - with self.console.capture() as capture: - self.console.print(panel) - panel_output = capture.get() - if panel_output.strip(): - lines.extend(panel_output.split("\n")) - except Exception: - # Fallback to simple text - lines.extend( - [ - "Metrics Guide:", - "* Context Efficiency 80-100%: Excellent | 60-80%: Good | 40-60%: Fair | <40%: Poor", - "* Pollution 0-10%: Excellent | 10-25%: Good | 25-50%: Fair | >50%: Poor", - ] - ) - else: - # Fallback for non-Rich environments - efficiency = stats.efficiency_percentage - pollution = 100 - efficiency - - if efficiency >= 80: - efficiency_assessment = "Excellent" - elif efficiency >= 
60: - efficiency_assessment = "Good" - elif efficiency >= 40: - efficiency_assessment = "Fair" - elif efficiency >= 20: - efficiency_assessment = "Poor" - else: - efficiency_assessment = "Very Poor" - - if pollution <= 10: - pollution_assessment = "Excellent" - elif pollution <= 25: - pollution_assessment = "Good" - elif pollution <= 50: - pollution_assessment = "Fair" - else: - pollution_assessment = "Poor" - - lines.extend( - [ - f"Context Efficiency: {efficiency:.1f}% ({efficiency_assessment})", - f"Pollution Level: {pollution:.1f}% ({pollution_assessment})", - "Guide: 80-100% Excellent | 60-80% Good | 40-60% Fair | 20-40% Poor | <20% Very Poor", - ] - ) - - return lines - - def _format_issues(self, warnings: list[str], errors: list[str]) -> list[str]: - """Format warnings and errors as professional blocks.""" - lines = [] - - # Errors first - for error in errors: - if self.use_color: - lines.append(self._styled(f"x Error: {error}", "red")) - else: - lines.append(f"x Error: {error}") - - # Then warnings - handle multi-line warnings as cohesive blocks - for warning in warnings: - if "\n" in warning: - # Multi-line warning - format as a professional block - warning_lines = warning.split("\n") - # First line gets the warning symbol and styling - if self.use_color: - lines.append(self._styled(f"[!] Warning: {warning_lines[0]}", "yellow")) - else: - lines.append(f"[!] Warning: {warning_lines[0]}") - - # Subsequent lines are indented and styled consistently - for line in warning_lines[1:]: - if line.strip(): # Skip empty lines - if self.use_color: - lines.append(self._styled(f" {line}", "yellow")) - else: - lines.append(f" {line}") - else: # noqa: PLR5501 - # Single-line warning - standard format - if self.use_color: - lines.append(self._styled(f"[!] Warning: {warning}", "yellow")) - else: - lines.append(f"[!] 
Warning: {warning}") - - return lines - def _get_strategy_symbol(self, strategy: PlacementStrategy) -> str: """Get symbol for placement strategy.""" symbols = { @@ -916,46 +446,6 @@ def _get_relative_display_path(self, path: Path) -> str: except ValueError: return str(path / self._target_name) - def _format_coverage_explanation(self, stats) -> list[str]: - """Explain the coverage vs. efficiency trade-off.""" - lines = [] - - if self.use_color: - lines.append(self._styled("Coverage vs. Efficiency Analysis", "cyan bold")) - else: - lines.append("Coverage vs. Efficiency Analysis") - - lines.append("") - - efficiency = stats.efficiency_percentage - - if efficiency < 30: - lines.append("[!] Low Efficiency Detected:") - lines.append(" * Coverage guarantee requires some instructions at root level") - lines.append(" * This creates pollution for specialized directories") - lines.append(" * Trade-off: Guaranteed coverage vs. optimal efficiency") - lines.append(" * Alternative: Higher efficiency with coverage violations (data loss)") - lines.append("") - lines.append("This may be mathematically optimal given coverage constraints") - elif efficiency < 60: - lines.append("[+] Moderate Efficiency:") - lines.append(" * Good balance between coverage and efficiency") - lines.append(" * Some coverage-driven pollution is acceptable") - lines.append(" * Most patterns are well-localized") - else: - lines.append("High Efficiency:") - lines.append(" * Excellent pattern locality achieved") - lines.append(" * Minimal coverage conflicts") - lines.append(" * Instructions are optimally placed") - - lines.append("") - lines.append("Why Coverage Takes Priority:") - lines.append(" * Every file must access applicable instructions") - lines.append(" * Hierarchical inheritance prevents data loss") - lines.append(" * Better low efficiency than missing instructions") - - return lines - def _get_placement_description(self, summary) -> str: """Get description of what's included in a placement summary. 
diff --git a/src/apm_cli/primitives/parser.py b/src/apm_cli/primitives/parser.py index d0ff8f2e1..88a9093ae 100644 --- a/src/apm_cli/primitives/parser.py +++ b/src/apm_cli/primitives/parser.py @@ -221,6 +221,30 @@ def _parse_context( ) +_PRIMITIVE_SUFFIXES = ( + ".chatmode.md", + ".instructions.md", + ".context.md", + ".memory.md", + ".agent.md", + ".md", +) + +_STRUCTURED_SUBDIRS = frozenset({"chatmodes", "instructions", "context", "memory", "agents"}) + + +def _strip_file_ext(basename: str) -> str: + """Strip the primitive double-extension from a basename, returning the stem. + + Tries each known suffix in priority order; returns *basename* unchanged when + no suffix matches (so callers can detect "no strip happened"). + """ + for suffix in _PRIMITIVE_SUFFIXES: + if basename.endswith(suffix): + return basename[: -len(suffix)] + return basename + + def _extract_primitive_name(file_path: Path) -> str: """Extract primitive name from file path based on naming conventions. @@ -230,57 +254,23 @@ def _extract_primitive_name(file_path: Path) -> str: Returns: str: Extracted primitive name. """ - # Normalize path path_parts = file_path.parts - # Check if it's in a structured directory (.apm/ or .github/) + # Structured directory (.apm/ or .github/): strip double extension directly. if ".apm" in path_parts or ".github" in path_parts: try: - # Find the base directory index - if ".apm" in path_parts: - base_idx = path_parts.index(".apm") - else: - base_idx = path_parts.index(".github") - - # For structured directories like .apm/chatmodes/name.chatmode.md - if base_idx + 2 < len(path_parts) and path_parts[base_idx + 1] in [ - "chatmodes", - "instructions", - "context", - "memory", - "agents", - ]: - basename = file_path.name - # Remove the double extension (.chatmode.md, .instructions.md, .agent.md, etc.) 
- if basename.endswith(".chatmode.md"): - return basename.replace(".chatmode.md", "") - elif basename.endswith(".instructions.md"): - return basename.replace(".instructions.md", "") - elif basename.endswith(".context.md"): - return basename.replace(".context.md", "") - elif basename.endswith(".memory.md"): - return basename.replace(".memory.md", "") - elif basename.endswith(".agent.md"): - return basename.replace(".agent.md", "") - elif basename.endswith(".md"): - return basename.replace(".md", "") + base_idx = ( + path_parts.index(".apm") if ".apm" in path_parts else path_parts.index(".github") + ) + if base_idx + 2 < len(path_parts) and path_parts[base_idx + 1] in _STRUCTURED_SUBDIRS: + return _strip_file_ext(file_path.name) except (ValueError, IndexError): pass - # Fallback: extract from filename - basename = file_path.name - if basename.endswith(".chatmode.md"): - return basename.replace(".chatmode.md", "") - elif basename.endswith(".instructions.md"): - return basename.replace(".instructions.md", "") - elif basename.endswith(".context.md"): - return basename.replace(".context.md", "") - elif basename.endswith(".memory.md"): - return basename.replace(".memory.md", "") - elif basename.endswith(".md"): - return basename.replace(".md", "") - - # Final fallback: use filename without extension + # Fallback: strip extension if recognised; otherwise use pathlib stem. + stripped = _strip_file_ext(file_path.name) + if stripped != file_path.name: + return stripped return file_path.stem diff --git a/src/apm_cli/utils/_github_host_artifactory.py b/src/apm_cli/utils/_github_host_artifactory.py new file mode 100644 index 000000000..4545bbba5 --- /dev/null +++ b/src/apm_cli/utils/_github_host_artifactory.py @@ -0,0 +1,173 @@ +"""JFrog Artifactory URL helpers extracted from :mod:`apm_cli.utils.github_host`. + +Extracted to keep ``github_host.py`` under the 800-line threshold while +preserving 100% behavioural equivalence. This module is private +(``_`` prefix). 
All public names are re-exported from ``github_host.py`` +so ``apm_cli.utils.github_host.NAME`` continues to resolve correctly. + +Rule B note: none of the functions here reference the patched module-level +names from ``github_host`` (``is_github_hostname``, ``is_azure_devops_hostname``, +etc.), so no late-import routing is needed. +""" + +from __future__ import annotations + + +def is_artifactory_path(path_segments: list) -> bool: + """Return True if path segments indicate a JFrog Artifactory VCS repository. + + Artifactory VCS paths follow the pattern: artifactory/{repo-key}/{owner}/{repo} + Detection: first segment is 'artifactory' and there are at least 4 segments. + """ + return len(path_segments) >= 4 and path_segments[0].lower() == "artifactory" + + +def iter_artifactory_boundary_candidates(path_segments: list, shape_filter=None): + """Yield ``(prefix, owner, repo, virtual_path)`` candidates shallow-first. + + Mirrors :meth:`DependencyReference.iter_gitlab_direct_shorthand_boundary_candidates`: + enumerate every plausible (owner, repo) split and let the caller probe each + one against the Artifactory proxy. The probe (HEAD on the archive URL) + decides the real boundary; this iterator only proposes candidates. + + If *shape_filter* is provided, candidates whose ``virtual_path`` fails the + filter are skipped. The candidate with no virtual path (``k == n``) is + always yielded as the all-as-repo fallback so callers that need a + deterministic answer (no probing) can pick it. + + The ``//`` empty-segment notation explicitly marks the repo / virtual + boundary and short-circuits the iterator to a single candidate. + + Returns nothing for non-Artifactory paths. 
+ """ + if not is_artifactory_path(path_segments): + return + repo_key = path_segments[1] + prefix = f"artifactory/{repo_key}" + remaining = path_segments[2:] + if not remaining: + return + owner = remaining[0] + after_owner = remaining[1:] + n = len(after_owner) + if n == 0: + return + + if "" in after_owner: + empty_idx = after_owner.index("") + repo_parts = after_owner[:empty_idx] + suffix_parts = [s for s in after_owner[empty_idx + 1 :] if s] + if repo_parts: + yield ( + prefix, + owner, + "/".join(repo_parts), + "/".join(suffix_parts) if suffix_parts else None, + ) + return + + for k in range(1, n + 1): + repo = "/".join(after_owner[:k]) + suffix_parts = after_owner[k:] + suffix = "/".join(suffix_parts) if suffix_parts else None + if suffix is not None and shape_filter is not None and not shape_filter(suffix): + continue + yield (prefix, owner, repo, suffix) + + +def parse_artifactory_path(path_segments: list) -> tuple: + """Parse Artifactory path into ``(prefix, owner, repo, virtual_path)``. + + Parse-time output is intentionally simple and unambiguous: ``owner`` is the + first segment after ``artifactory/{key}``, ``repo`` is the next segment, + and any further segments become ``virtual_path``. The authoritative + boundary -- needed for nested GitLab subgroup paths behind the Artifactory + proxy -- is determined by :func:`apm_cli.install.artifactory_resolver.\ +_resolve_artifactory_boundary`, which probes archive URLs and rebuilds the + dependency reference at the verified boundary. + + The ``//`` notation (empty segment) is honored as an explicit, deterministic + boundary marker so users can opt out of probing. + + Returns None if not a valid Artifactory path. 
+ """ + if not is_artifactory_path(path_segments): + return None + repo_key = path_segments[1] + prefix = f"artifactory/{repo_key}" + remaining = path_segments[2:] + if not remaining: + return None + owner = remaining[0] + after_owner = remaining[1:] + if not after_owner: + return None + + if "" in after_owner: + empty_idx = after_owner.index("") + repo_parts = after_owner[:empty_idx] + suffix_parts = [s for s in after_owner[empty_idx + 1 :] if s] + if not repo_parts: + # ``owner//virtual`` has no segments before the explicit boundary, + # so there is no repo to install -- reject as invalid rather than + # falling through and returning ``repo=''``. + return None + return ( + prefix, + owner, + "/".join(repo_parts), + "/".join(suffix_parts) if suffix_parts else None, + ) + + repo = after_owner[0] + virtual_path = "/".join(after_owner[1:]) if len(after_owner) > 1 else None + return (prefix, owner, repo, virtual_path) + + +def build_artifactory_archive_url( + host: str, prefix: str, owner: str, repo: str, ref: str = "main", scheme: str = "https" +) -> tuple: + """Build Artifactory VCS archive download URLs. + + Returns a tuple of URLs to try in order. Because Artifactory proxies + the upstream server's native URL scheme, we attempt GitHub-style, + GitLab-style, and codeload.github.com-style archive paths so the caller + does not need to know what sits behind the Artifactory remote repository. + + Organizations using private GitHub repositories must configure their + Artifactory upstream as ``codeload.github.com`` (instead of ``github.com``) + because Artifactory cannot follow GitHub's cross-host redirect (which + carries short-lived tokens) to codeload. When the upstream is + ``codeload.github.com``, the required archive path is + ``/{owner}/{repo}/zip/refs/heads/{ref}`` (no ``.zip`` extension). 
+ + Args: + host: Artifactory hostname (e.g., 'artifactory.example.com') + prefix: Artifactory path prefix (e.g., 'artifactory/github') + owner: Repository owner + repo: Repository name + ref: Git reference (branch or tag name) + scheme: URL scheme (default 'https'; 'http' for local dev proxies) + + Returns: + Tuple of URLs to try in order + """ + base = f"{scheme}://{host}/{prefix}/{owner}/{repo}" + # GitLab archive filenames use only the project basename, even when the + # project sits inside a subgroup (e.g. ``group/sub/pkg`` becomes + # ``pkg-{ref}.zip``). ``rsplit`` keeps the flat case unchanged. + repo_basename = repo.rsplit("/", 1)[-1] + return ( + # GitHub-style: /archive/refs/heads/{ref}.zip + f"{base}/archive/refs/heads/{ref}.zip", + # GitLab-style: /-/archive/{ref}/{basename}-{ref}.zip + f"{base}/-/archive/{ref}/{repo_basename}-{ref}.zip", + # GitHub-style tags fallback + f"{base}/archive/refs/tags/{ref}.zip", + # codeload.github.com-style: /zip/refs/heads/{ref} + # Required when Artifactory upstream is configured as codeload.github.com + # (workaround for private repos where github.com redirects to codeload with tokens + # that Artifactory cannot follow across hosts) + f"{base}/zip/refs/heads/{ref}", + f"{base}/zip/refs/tags/{ref}", + ) diff --git a/src/apm_cli/utils/github_host.py b/src/apm_cli/utils/github_host.py index c41fed8c2..5e27d8043 100644 --- a/src/apm_cli/utils/github_host.py +++ b/src/apm_cli/utils/github_host.py @@ -4,6 +4,19 @@ import re import urllib.parse +from ._github_host_artifactory import ( + build_artifactory_archive_url as build_artifactory_archive_url, +) +from ._github_host_artifactory import ( + is_artifactory_path as is_artifactory_path, +) +from ._github_host_artifactory import ( + iter_artifactory_boundary_candidates as iter_artifactory_boundary_candidates, +) +from ._github_host_artifactory import ( + parse_artifactory_path as parse_artifactory_path, +) + def default_host() -> str: """Return the default Git host (can be 
overridden via GITHUB_HOST env var).""" @@ -601,166 +614,6 @@ def build_ado_api_url( ) -def is_artifactory_path(path_segments: list) -> bool: - """Return True if path segments indicate a JFrog Artifactory VCS repository. - - Artifactory VCS paths follow the pattern: artifactory/{repo-key}/{owner}/{repo} - Detection: first segment is 'artifactory' and there are at least 4 segments. - """ - return len(path_segments) >= 4 and path_segments[0].lower() == "artifactory" - - -def iter_artifactory_boundary_candidates(path_segments: list, shape_filter=None): - """Yield ``(prefix, owner, repo, virtual_path)`` candidates shallow-first. - - Mirrors :meth:`DependencyReference.iter_gitlab_direct_shorthand_boundary_candidates`: - enumerate every plausible (owner, repo) split and let the caller probe each - one against the Artifactory proxy. The probe (HEAD on the archive URL) - decides the real boundary; this iterator only proposes candidates. - - If *shape_filter* is provided, candidates whose ``virtual_path`` fails the - filter are skipped. The candidate with no virtual path (``k == n``) is - always yielded as the all-as-repo fallback so callers that need a - deterministic answer (no probing) can pick it. - - The ``//`` empty-segment notation explicitly marks the repo / virtual - boundary and short-circuits the iterator to a single candidate. - - Returns nothing for non-Artifactory paths. 
- """ - if not is_artifactory_path(path_segments): - return - repo_key = path_segments[1] - prefix = f"artifactory/{repo_key}" - remaining = path_segments[2:] - if not remaining: - return - owner = remaining[0] - after_owner = remaining[1:] - n = len(after_owner) - if n == 0: - return - - if "" in after_owner: - empty_idx = after_owner.index("") - repo_parts = after_owner[:empty_idx] - suffix_parts = [s for s in after_owner[empty_idx + 1 :] if s] - if repo_parts: - yield ( - prefix, - owner, - "/".join(repo_parts), - "/".join(suffix_parts) if suffix_parts else None, - ) - return - - for k in range(1, n + 1): - repo = "/".join(after_owner[:k]) - suffix_parts = after_owner[k:] - suffix = "/".join(suffix_parts) if suffix_parts else None - if suffix is not None and shape_filter is not None and not shape_filter(suffix): - continue - yield (prefix, owner, repo, suffix) - - -def parse_artifactory_path(path_segments: list) -> tuple: - """Parse Artifactory path into ``(prefix, owner, repo, virtual_path)``. - - Parse-time output is intentionally simple and unambiguous: ``owner`` is the - first segment after ``artifactory/{key}``, ``repo`` is the next segment, - and any further segments become ``virtual_path``. The authoritative - boundary -- needed for nested GitLab subgroup paths behind the Artifactory - proxy -- is determined by :func:`apm_cli.install.artifactory_resolver.\ -_resolve_artifactory_boundary`, which probes archive URLs and rebuilds the - dependency reference at the verified boundary. - - The ``//`` notation (empty segment) is honored as an explicit, deterministic - boundary marker so users can opt out of probing. - - Returns None if not a valid Artifactory path. 
- """ - if not is_artifactory_path(path_segments): - return None - repo_key = path_segments[1] - prefix = f"artifactory/{repo_key}" - remaining = path_segments[2:] - if not remaining: - return None - owner = remaining[0] - after_owner = remaining[1:] - if not after_owner: - return None - - if "" in after_owner: - empty_idx = after_owner.index("") - repo_parts = after_owner[:empty_idx] - suffix_parts = [s for s in after_owner[empty_idx + 1 :] if s] - if not repo_parts: - # ``owner//virtual`` has no segments before the explicit boundary, - # so there is no repo to install -- reject as invalid rather than - # falling through and returning ``repo=''``. - return None - return ( - prefix, - owner, - "/".join(repo_parts), - "/".join(suffix_parts) if suffix_parts else None, - ) - - repo = after_owner[0] - virtual_path = "/".join(after_owner[1:]) if len(after_owner) > 1 else None - return (prefix, owner, repo, virtual_path) - - -def build_artifactory_archive_url( - host: str, prefix: str, owner: str, repo: str, ref: str = "main", scheme: str = "https" -) -> tuple: - """Build Artifactory VCS archive download URLs. - - Returns a tuple of URLs to try in order. Because Artifactory proxies - the upstream server's native URL scheme, we attempt GitHub-style, - GitLab-style, and codeload.github.com-style archive paths so the caller - does not need to know what sits behind the Artifactory remote repository. - - Organizations using private GitHub repositories must configure their - Artifactory upstream as ``codeload.github.com`` (instead of ``github.com``) - because Artifactory cannot follow GitHub's cross-host redirect (which - carries short-lived tokens) to codeload. When the upstream is - ``codeload.github.com``, the required archive path is - ``/{owner}/{repo}/zip/refs/heads/{ref}`` (no ``.zip`` extension). 
- - Args: - host: Artifactory hostname (e.g., 'artifactory.example.com') - prefix: Artifactory path prefix (e.g., 'artifactory/github') - owner: Repository owner - repo: Repository name - ref: Git reference (branch or tag name) - scheme: URL scheme (default 'https'; 'http' for local dev proxies) - - Returns: - Tuple of URLs to try in order - """ - base = f"{scheme}://{host}/{prefix}/{owner}/{repo}" - # GitLab archive filenames use only the project basename, even when the - # project sits inside a subgroup (e.g. ``group/sub/pkg`` becomes - # ``pkg-{ref}.zip``). ``rsplit`` keeps the flat case unchanged. - repo_basename = repo.rsplit("/", 1)[-1] - return ( - # GitHub-style: /archive/refs/heads/{ref}.zip - f"{base}/archive/refs/heads/{ref}.zip", - # GitLab-style: /-/archive/{ref}/{basename}-{ref}.zip - f"{base}/-/archive/{ref}/{repo_basename}-{ref}.zip", - # GitHub-style tags fallback - f"{base}/archive/refs/tags/{ref}.zip", - # codeload.github.com-style: /zip/refs/heads/{ref} - # Required when Artifactory upstream is configured as codeload.github.com - # (workaround for private repos where github.com redirects to codeload with tokens - # that Artifactory cannot follow across hosts) - f"{base}/zip/refs/heads/{ref}", - f"{base}/zip/refs/tags/{ref}", - ) - - def is_valid_fqdn(hostname: str) -> bool: """Validate if a string is a valid Fully Qualified Domain Name (FQDN). 
diff --git a/tests/unit/test_output_formatters_phase3.py b/tests/unit/test_output_formatters_phase3.py index 78477a25b..ffa78cd31 100644 --- a/tests/unit/test_output_formatters_phase3.py +++ b/tests/unit/test_output_formatters_phase3.py @@ -1143,102 +1143,6 @@ def test_formatter_color_off_when_no_rich(self) -> None: self.assertFalse(formatter.use_color) -# =========================================================================== -# Tests: _format_final_summary (verbose path) -# =========================================================================== - - -class TestFormatFinalSummary(unittest.TestCase): - """Tests for _format_final_summary.""" - - def setUp(self) -> None: - self.formatter = CompilationFormatter(use_color=False) - - def test_returns_list(self) -> None: - results = _make_results() - out = self.formatter._format_final_summary(results) - self.assertIsInstance(out, list) - - def test_dry_run_label_present(self) -> None: - results = _make_results(is_dry_run=True) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("[DRY RUN]", text) - - def test_generated_label_when_not_dry_run(self) -> None: - results = _make_results(is_dry_run=False) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("Generated", text) - - def test_efficiency_percentage_in_output(self) -> None: - stats = _make_stats(efficiency=0.55) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("55.0%", text) - - def test_placement_distribution_header(self) -> None: - results = _make_results() - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("Placement Distribution", text) - - def test_efficiency_improvement_positive_in_final(self) -> None: - stats = _make_stats(efficiency=0.80, baseline_efficiency=0.50) - results = _make_results(stats=stats) - out = 
self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("improvement:", text) - - def test_pollution_improvement_in_final(self) -> None: - stats = _make_stats(efficiency=0.70, pollution_improvement=0.20) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("pollution", text.lower()) - - def test_placement_accuracy_in_final(self) -> None: - stats = _make_stats(efficiency=0.70, placement_accuracy=0.92) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("Placement accuracy", text) - - def test_generation_time_in_final(self) -> None: - stats = _make_stats(efficiency=0.70, generation_time_ms=456) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("456ms", text) - - def test_generation_time_none_pipe_change_final(self) -> None: - stats = _make_stats( - efficiency=0.70, - pollution_improvement=0.10, - placement_accuracy=0.90, - generation_time_ms=None, - ) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("+-", text) - - def test_pollution_improvement_negative_in_final(self) -> None: - stats = _make_stats(efficiency=0.70, pollution_improvement=-0.05) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("pollution", text.lower()) - - def test_efficiency_improvement_negative_in_final(self) -> None: - stats = _make_stats(efficiency=0.45, baseline_efficiency=0.60) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("change:", text) - - # =========================================================================== # Integration smoke-tests: round-trip all three public 
format methods # =========================================================================== @@ -1863,80 +1767,6 @@ def test_styled_empty_string(self) -> None: self.assertIsInstance(result, str) -class TestFormatFinalSummaryColorBranches(unittest.TestCase): - """Tests for _format_final_summary with use_color=True.""" - - def setUp(self) -> None: - try: - self.formatter = _color_formatter() - except unittest.SkipTest as exc: - self.skipTest(str(exc)) - - def test_dry_run_label_colored(self) -> None: - results = _make_results(is_dry_run=True) - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertIn("DRY RUN", text) - - def test_generated_label_colored(self) -> None: - results = _make_results(is_dry_run=False) - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertIn("Generated", text) - - def test_efficiency_colored(self) -> None: - stats = _make_stats(efficiency=0.65) - results = _make_results(stats=stats) - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertIn("65.0%", text) - - def test_placement_distribution_colored(self) -> None: - results = _make_results() - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertIn("Placement Distribution", text) - - def test_with_all_metrics_colored(self) -> None: - stats = _make_stats( - efficiency=0.80, - baseline_efficiency=0.55, - pollution_improvement=0.20, - placement_accuracy=0.90, - generation_time_ms=75, - ) - results = _make_results(stats=stats) - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertIn("75ms", text) - - def test_efficiency_improvement_negative_colored(self) -> None: - stats = _make_stats(efficiency=0.40, baseline_efficiency=0.60) - results = _make_results(stats=stats) - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertIn("change:", text) - - def 
test_pollution_improvement_negative_colored(self) -> None: - stats = _make_stats(efficiency=0.70, pollution_improvement=-0.05) - results = _make_results(stats=stats) - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertTrue(len(text) > 0) - - def test_generation_time_none_pipe_change_colored(self) -> None: - stats = _make_stats( - efficiency=0.70, - pollution_improvement=0.10, - placement_accuracy=0.90, - generation_time_ms=None, - ) - results = _make_results(stats=stats) - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertIn("+-", text) - - class TestRichColorIntegrationRoundTrips(unittest.TestCase): """Round-trip integration tests with use_color=True.""" diff --git a/tests/unit/test_output_formatters_rendering.py b/tests/unit/test_output_formatters_rendering.py index 78477a25b..ffa78cd31 100644 --- a/tests/unit/test_output_formatters_rendering.py +++ b/tests/unit/test_output_formatters_rendering.py @@ -1143,102 +1143,6 @@ def test_formatter_color_off_when_no_rich(self) -> None: self.assertFalse(formatter.use_color) -# =========================================================================== -# Tests: _format_final_summary (verbose path) -# =========================================================================== - - -class TestFormatFinalSummary(unittest.TestCase): - """Tests for _format_final_summary.""" - - def setUp(self) -> None: - self.formatter = CompilationFormatter(use_color=False) - - def test_returns_list(self) -> None: - results = _make_results() - out = self.formatter._format_final_summary(results) - self.assertIsInstance(out, list) - - def test_dry_run_label_present(self) -> None: - results = _make_results(is_dry_run=True) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("[DRY RUN]", text) - - def test_generated_label_when_not_dry_run(self) -> None: - results = _make_results(is_dry_run=False) - out = 
self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("Generated", text) - - def test_efficiency_percentage_in_output(self) -> None: - stats = _make_stats(efficiency=0.55) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("55.0%", text) - - def test_placement_distribution_header(self) -> None: - results = _make_results() - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("Placement Distribution", text) - - def test_efficiency_improvement_positive_in_final(self) -> None: - stats = _make_stats(efficiency=0.80, baseline_efficiency=0.50) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("improvement:", text) - - def test_pollution_improvement_in_final(self) -> None: - stats = _make_stats(efficiency=0.70, pollution_improvement=0.20) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("pollution", text.lower()) - - def test_placement_accuracy_in_final(self) -> None: - stats = _make_stats(efficiency=0.70, placement_accuracy=0.92) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("Placement accuracy", text) - - def test_generation_time_in_final(self) -> None: - stats = _make_stats(efficiency=0.70, generation_time_ms=456) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("456ms", text) - - def test_generation_time_none_pipe_change_final(self) -> None: - stats = _make_stats( - efficiency=0.70, - pollution_improvement=0.10, - placement_accuracy=0.90, - generation_time_ms=None, - ) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - 
self.assertIn("+-", text) - - def test_pollution_improvement_negative_in_final(self) -> None: - stats = _make_stats(efficiency=0.70, pollution_improvement=-0.05) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("pollution", text.lower()) - - def test_efficiency_improvement_negative_in_final(self) -> None: - stats = _make_stats(efficiency=0.45, baseline_efficiency=0.60) - results = _make_results(stats=stats) - out = self.formatter._format_final_summary(results) - text = "\n".join(out) - self.assertIn("change:", text) - - # =========================================================================== # Integration smoke-tests: round-trip all three public format methods # =========================================================================== @@ -1863,80 +1767,6 @@ def test_styled_empty_string(self) -> None: self.assertIsInstance(result, str) -class TestFormatFinalSummaryColorBranches(unittest.TestCase): - """Tests for _format_final_summary with use_color=True.""" - - def setUp(self) -> None: - try: - self.formatter = _color_formatter() - except unittest.SkipTest as exc: - self.skipTest(str(exc)) - - def test_dry_run_label_colored(self) -> None: - results = _make_results(is_dry_run=True) - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertIn("DRY RUN", text) - - def test_generated_label_colored(self) -> None: - results = _make_results(is_dry_run=False) - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertIn("Generated", text) - - def test_efficiency_colored(self) -> None: - stats = _make_stats(efficiency=0.65) - results = _make_results(stats=stats) - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertIn("65.0%", text) - - def test_placement_distribution_colored(self) -> None: - results = _make_results() - lines = self.formatter._format_final_summary(results) 
- text = "\n".join(lines) - self.assertIn("Placement Distribution", text) - - def test_with_all_metrics_colored(self) -> None: - stats = _make_stats( - efficiency=0.80, - baseline_efficiency=0.55, - pollution_improvement=0.20, - placement_accuracy=0.90, - generation_time_ms=75, - ) - results = _make_results(stats=stats) - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertIn("75ms", text) - - def test_efficiency_improvement_negative_colored(self) -> None: - stats = _make_stats(efficiency=0.40, baseline_efficiency=0.60) - results = _make_results(stats=stats) - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertIn("change:", text) - - def test_pollution_improvement_negative_colored(self) -> None: - stats = _make_stats(efficiency=0.70, pollution_improvement=-0.05) - results = _make_results(stats=stats) - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertTrue(len(text) > 0) - - def test_generation_time_none_pipe_change_colored(self) -> None: - stats = _make_stats( - efficiency=0.70, - pollution_improvement=0.10, - placement_accuracy=0.90, - generation_time_ms=None, - ) - results = _make_results(stats=stats) - lines = self.formatter._format_final_summary(results) - text = "\n".join(lines) - self.assertIn("+-", text) - - class TestRichColorIntegrationRoundTrips(unittest.TestCase): """Round-trip integration tests with use_color=True.""" From 7c9eec49759f821251c53e7c8f80a14ad0d427bb Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 08:32:31 +0200 Subject: [PATCH 19/21] refactor(policy,marketplace): split second-tier length offenders under 800 (#1078) Strangler Stage 2, Commit 7: drive the remaining seven >800-line files in the policy/ and marketplace/ subsystems under the 800-line guardrail via leaf-module and mixin splits that preserve every patched/public import surface, plus genuine PLR0911 decompositions. 
policy/ - discovery.py 1364->662: extract cache cluster to _discovery_cache.py and chain/host-pin helpers to _discovery_chain.py (Rule B routes the patched requests/subprocess/discover_policy/_write_cache globals). _fetch_github_contents PLR0911 11->2 via _parse_github_repo_ref / _decode_github_content / _call_github_api. - policy_checks.py 1134->693: extract MCP/compilation/manifest checks to _policy_checks_mcp.py; run_policy_checks PLR0911 9->5. _check_unmanaged_files + _MAX_UNMANAGED_SCAN_FILES kept in place to preserve a monkeypatch. - ci_checks.py run_baseline_checks PLR0911 11->6 (lazy lambda loop). - _constraint_pinning.py classify_unbounded_reason PLR0911 10->8. marketplace/ - yml_schema.py 1220->462: dataclasses to _yml_models.py (leaf), parsers to _yml_parsers.py (imports models, no cycle); dedup _parse_outputs. - builder.py 1130->506: report dataclasses to _builder_reports.py, resolve methods to _BuilderResolveMixin in _builder_resolve.py (Rule B for the 20x-patched urllib); dedup the _compute_diff plugin-SHA loop. - resolver.py 953->696: cross-repo-misconfig/matching cluster to _resolver_match.py; two PLR0911 9->5 via guard extraction. Public parse_marketplace_ref/get_marketplace_by_name/resolve_marketplace_plugin stay. - client.py 817->717: cache I/O cluster to _client_cache.py. Public fetch_marketplace/fetch_or_cache/clear_marketplace_cache/search_marketplace stay. - publisher.py 922->759: PublishState + dataclasses to _publish_state.py; _process_single_target PLR0911 9->8. - semver.py PLR0911 11->7 via _satisfies_caret; public API unchanged. All resulting files (parents + 10 new siblings) <800 (<=770). Whole-src >800 backlog 7->0. Full unit+acceptance 16605 passed; targeted integration 2747 passed; ruff/format/complexity(final thresholds)/R0801 10.00/auth-signals all clean. No threshold flip yet (enforcement is the final commit). 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/marketplace/_builder_reports.py | 200 ++++ src/apm_cli/marketplace/_builder_resolve.py | 374 ++++++ src/apm_cli/marketplace/_client_cache.py | 158 +++ src/apm_cli/marketplace/_publish_state.py | 208 ++++ src/apm_cli/marketplace/_resolver_match.py | 342 ++++++ src/apm_cli/marketplace/_yml_models.py | 177 +++ src/apm_cli/marketplace/_yml_parsers.py | 733 ++++++++++++ src/apm_cli/marketplace/builder.py | 742 +----------- src/apm_cli/marketplace/client.py | 150 +-- src/apm_cli/marketplace/publisher.py | 309 ++--- src/apm_cli/marketplace/resolver.py | 373 +----- src/apm_cli/marketplace/semver.py | 52 +- src/apm_cli/marketplace/yml_schema.py | 1130 +++---------------- src/apm_cli/policy/_constraint_pinning.py | 27 +- src/apm_cli/policy/_discovery_cache.py | 553 +++++++++ src/apm_cli/policy/_discovery_chain.py | 272 +++++ src/apm_cli/policy/_policy_checks_mcp.py | 465 ++++++++ src/apm_cli/policy/ci_checks.py | 36 +- src/apm_cli/policy/discovery.py | 972 +++------------- src/apm_cli/policy/policy_checks.py | 749 +++--------- 20 files changed, 4233 insertions(+), 3789 deletions(-) create mode 100644 src/apm_cli/marketplace/_builder_reports.py create mode 100644 src/apm_cli/marketplace/_builder_resolve.py create mode 100644 src/apm_cli/marketplace/_client_cache.py create mode 100644 src/apm_cli/marketplace/_publish_state.py create mode 100644 src/apm_cli/marketplace/_resolver_match.py create mode 100644 src/apm_cli/marketplace/_yml_models.py create mode 100644 src/apm_cli/marketplace/_yml_parsers.py create mode 100644 src/apm_cli/policy/_discovery_cache.py create mode 100644 src/apm_cli/policy/_discovery_chain.py create mode 100644 src/apm_cli/policy/_policy_checks_mcp.py diff --git a/src/apm_cli/marketplace/_builder_reports.py b/src/apm_cli/marketplace/_builder_reports.py new file mode 100644 index 000000000..5ff9adc94 --- /dev/null +++ b/src/apm_cli/marketplace/_builder_reports.py @@ -0,0 
+1,200 @@ +"""Report and options dataclasses for the marketplace builder. + +Leaf module -- no imports from ``builder`` (cycle-safe). All public +symbols are re-exported by ``builder`` so existing import paths such +as ``from apm_cli.marketplace.builder import BuildReport`` continue +to work without changes. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from .diagnostics import BuildDiagnostic + +__all__ = [ + "BuildDiagnostic", + "BuildOptions", + "BuildReport", + "MarketplaceOutputReport", + "ResolveResult", + "ResolvedPackage", +] + + +@dataclass(frozen=True) +class ResolvedPackage: + """A package entry after ref resolution.""" + + name: str + source_repo: str # "owner/repo" only + subdir: str | None # APM-only (used to compose the output ``source`` object) + ref: str # resolved tag name, e.g. "v1.2.0" + sha: str # 40-char git SHA + requested_version: str | None # original APM-only range (for diagnostics) + tags: tuple[str, ...] + is_prerelease: bool # True if the resolved ref was a prerelease semver + host: str | None = None # non-default git host parsed from apm.yml source + + +@dataclass(frozen=True) +class ResolveResult: + """Result of resolving package refs in a marketplace build.""" + + entries: tuple[ResolvedPackage, ...] + errors: tuple[tuple[str, str], ...] # (package name, error message) pairs + + @property + def ok(self) -> bool: + """True when every package resolved without error.""" + return len(self.errors) == 0 + + +@dataclass(frozen=True) +class MarketplaceOutputReport: + """Summary for one generated marketplace output profile.""" + + profile: str + resolved: tuple[ResolvedPackage, ...] + errors: tuple[tuple[str, str], ...] # (package name, error message) pairs + warnings: tuple[str, ...] # non-fatal diagnostic messages + diagnostics: tuple[BuildDiagnostic, ...] 
= () # structured diagnostics + unchanged_count: int = 0 + added_count: int = 0 + updated_count: int = 0 + removed_count: int = 0 + output_path: Path = field(default_factory=lambda: Path(".")) + dry_run: bool = False + + +@dataclass(frozen=True) +class BuildReport: + """Summary of a marketplace build run across one or more output profiles.""" + + outputs: tuple[MarketplaceOutputReport, ...] + + @property + def primary_output(self) -> MarketplaceOutputReport: + """Return the first output report for legacy single-output callers.""" + if not self.outputs: + return MarketplaceOutputReport( + profile="", + resolved=(), + errors=(), + warnings=(), + ) + return self.outputs[0] + + @property + def resolved(self) -> tuple[ResolvedPackage, ...]: + return self.primary_output.resolved + + @property + def errors(self) -> tuple[tuple[str, str], ...]: + return self.primary_output.errors + + @property + def warnings(self) -> tuple[str, ...]: + return tuple(warn for output in self.outputs for warn in output.warnings) + + @property + def diagnostics(self) -> tuple[BuildDiagnostic, ...]: + return tuple(diag for output in self.outputs for diag in output.diagnostics) + + @property + def unchanged_count(self) -> int: + return self.primary_output.unchanged_count + + @property + def added_count(self) -> int: + return self.primary_output.added_count + + @property + def updated_count(self) -> int: + return self.primary_output.updated_count + + @property + def removed_count(self) -> int: + return self.primary_output.removed_count + + @property + def output_path(self) -> Path: + return self.primary_output.output_path + + @property + def dry_run(self) -> bool: + return any(output.dry_run for output in self.outputs) + + def to_json_dict(self) -> dict[str, Any]: + """Serialize build report as the S4 JSON contract. 
+ + Shape: {ok, dry_run, warnings[], errors[], + marketplace: {outputs: [{format, path, added, updated, + unchanged, skipped}]}, bundle: null} + """ + all_warnings = list(self.warnings) + all_errors: list[dict[str, str]] = [] + output_entries: list[dict[str, Any]] = [] + + for out in self.outputs: + output_entries.append( + { + "format": out.profile, + "path": str(out.output_path), + "added": out.added_count, + "updated": out.updated_count, + "unchanged": out.unchanged_count, + "skipped": out.removed_count, + } + ) + for pkg_name, err_msg in out.errors: + all_errors.append({"code": "build_error", "message": f"{pkg_name}: {err_msg}"}) + + ok = len(all_errors) == 0 + return { + "ok": ok, + "dry_run": self.dry_run, + "warnings": all_warnings, + "errors": all_errors, + "marketplace": { + "outputs": output_entries, + }, + "bundle": None, + } + + @classmethod + def failure_to_json_dict( + cls, + *, + errors: list[dict[str, str]], + warnings: list[str] | None = None, + dry_run: bool = False, + ) -> dict[str, Any]: + """Produce the S4 JSON shape for a pre-build failure.""" + return { + "ok": False, + "dry_run": dry_run, + "warnings": warnings or [], + "errors": errors, + "marketplace": { + "outputs": [], + }, + "bundle": None, + } + + +@dataclass +class BuildOptions: + """Configuration knobs for MarketplaceBuilder.""" + + concurrency: int = 8 + timeout_seconds: float = 10.0 + include_prerelease: bool = False + allow_head: bool = False + continue_on_error: bool = False + offline: bool = False + # Backwards-compatible spelling for callers that predate ``apm pack``. + output_override: Path | None = None + dry_run: bool = False diff --git a/src/apm_cli/marketplace/_builder_resolve.py b/src/apm_cli/marketplace/_builder_resolve.py new file mode 100644 index 000000000..239711d92 --- /dev/null +++ b/src/apm_cli/marketplace/_builder_resolve.py @@ -0,0 +1,374 @@ +"""Resolve mixin for MarketplaceBuilder. 
+ +Provides ``_BuilderResolveMixin`` which is mixed into ``MarketplaceBuilder`` +in ``builder.py``. Keeping these methods separate reduces the line count +of ``builder.py`` without splitting the public class interface. + +urllib Rule B +------------- +``_fetch_remote_metadata`` uses ``urllib.request`` but does NOT import it +at module scope here. Instead it performs a late import:: + + from apm_cli.marketplace import builder as _b + ... _b.urllib.request.urlopen(req, timeout=5) ... + +This keeps the patch target ``apm_cli.marketplace.builder.urllib`` valid +for the 20+ test-suite ``patch()`` calls that mock ``urlopen``. +""" + +from __future__ import annotations + +import logging +import re +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import TYPE_CHECKING, Any + +import yaml + +from ._builder_reports import BuildOptions, ResolvedPackage, ResolveResult +from ._shared import iter_semver_tags +from .errors import ( + BuildError, + HeadNotAllowedError, + NoMatchingVersionError, + RefNotFoundError, +) +from .ref_resolver import RefResolver +from .semver import SemVer, parse_semver, satisfies_range +from .tag_pattern import build_tag_regex +from .yml_schema import MarketplaceYml + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + +# 40-char hex SHA pattern (also used in builder.py -- defined here because +# _resolve_explicit_ref lives here). +_SHA40_RE = re.compile(r"^[0-9a-f]{40}$") + + +def _strip_ref_prefix(refname: str) -> str: + """Strip ``refs/tags/`` or ``refs/heads/`` prefix.""" + if refname.startswith("refs/tags/"): + return refname[len("refs/tags/") :] + if refname.startswith("refs/heads/"): + return refname[len("refs/heads/") :] + return refname + + +class _BuilderResolveMixin: + """Resolution methods factored out of MarketplaceBuilder. + + All methods access ``self`` attributes set by ``MarketplaceBuilder.__init__``. + This class should never be instantiated directly. 
+ """ + + # -- single-entry resolution -------------------------------------------- + + def _resolve_entry(self, entry: Any) -> ResolvedPackage: + """Resolve a single package entry to a concrete tag + SHA.""" + if entry.is_local: + return ResolvedPackage( + name=entry.name, + source_repo="", + subdir=entry.source, + ref="", + sha="", + requested_version=entry.version, + tags=tuple(entry.tags), + is_prerelease=False, + ) + yml = self._load_yml() # type: ignore[attr-defined] + resolver = self._get_resolver_for_host(entry.host) # type: ignore[attr-defined] + owner_repo = entry.source + + if entry.ref is not None: + return self._resolve_explicit_ref(entry, resolver, owner_repo) + return self._resolve_version_range(entry, resolver, owner_repo, yml) + + def _resolve_explicit_ref( + self, + entry: Any, + resolver: RefResolver, + owner_repo: str, + ) -> ResolvedPackage: + """Resolve an entry with an explicit ``ref:`` field.""" + ref_text = entry.ref + assert ref_text is not None # noqa: S101 + + if _SHA40_RE.match(ref_text): + sv = parse_semver(ref_text.lstrip("vV")) + return ResolvedPackage( + name=entry.name, + source_repo=owner_repo, + subdir=entry.subdir, + ref=ref_text, + sha=ref_text, + requested_version=entry.version, + tags=entry.tags, + is_prerelease=sv.is_prerelease if sv else False, + host=self._effective_host(entry.host), # type: ignore[attr-defined] + ) + + refs = resolver.list_remote_refs(owner_repo) + + # Try as tag first + for remote_ref in refs: + if not remote_ref.name.startswith("refs/tags/"): + continue + tag_name = _strip_ref_prefix(remote_ref.name) + if tag_name == ref_text: + sv = parse_semver(tag_name.lstrip("vV")) + return ResolvedPackage( + name=entry.name, + source_repo=owner_repo, + subdir=entry.subdir, + ref=tag_name, + sha=remote_ref.sha, + requested_version=entry.version, + tags=entry.tags, + is_prerelease=sv.is_prerelease if sv else False, + host=self._effective_host(entry.host), # type: ignore[attr-defined] + ) + + # Try as full refname + 
for remote_ref in refs: + if remote_ref.name == ref_text: + short = _strip_ref_prefix(remote_ref.name) + is_branch = remote_ref.name.startswith("refs/heads/") + if is_branch and not self._options.allow_head: # type: ignore[attr-defined] + raise HeadNotAllowedError(entry.name, short) + sv = parse_semver(short.lstrip("vV")) + return ResolvedPackage( + name=entry.name, + source_repo=owner_repo, + subdir=entry.subdir, + ref=short, + sha=remote_ref.sha, + requested_version=entry.version, + tags=entry.tags, + is_prerelease=sv.is_prerelease if sv else False, + host=self._effective_host(entry.host), # type: ignore[attr-defined] + ) + + # Try as branch name + for remote_ref in refs: + if remote_ref.name == f"refs/heads/{ref_text}": + if not self._options.allow_head: # type: ignore[attr-defined] + raise HeadNotAllowedError(entry.name, ref_text) + return ResolvedPackage( + name=entry.name, + source_repo=owner_repo, + subdir=entry.subdir, + ref=ref_text, + sha=remote_ref.sha, + requested_version=entry.version, + tags=entry.tags, + is_prerelease=False, + host=self._effective_host(entry.host), # type: ignore[attr-defined] + ) + + if ref_text.upper() == "HEAD": + if not self._options.allow_head: # type: ignore[attr-defined] + raise HeadNotAllowedError(entry.name, "HEAD") + + raise RefNotFoundError(entry.name, ref_text, owner_repo) + + def _resolve_version_range( + self, + entry: Any, + resolver: RefResolver, + owner_repo: str, + yml: MarketplaceYml, + ) -> ResolvedPackage: + """Resolve an entry using its ``version:`` semver range.""" + version_range = entry.version + assert version_range is not None # noqa: S101 + + pattern = entry.tag_pattern or yml.build.tag_pattern + tag_rx = build_tag_regex(pattern) + refs = resolver.list_remote_refs(owner_repo) + + candidates: list[tuple[SemVer, str, str]] = [] + for sv, tag_name, sha in iter_semver_tags(refs, tag_rx): + include_pre = ( + entry.include_prerelease or self._options.include_prerelease # type: ignore[attr-defined] + ) + if 
sv.is_prerelease and not include_pre: + continue + if satisfies_range(sv, version_range): + candidates.append((sv, tag_name, sha)) + + if not candidates: + raise NoMatchingVersionError( + entry.name, + version_range, + detail=f"pattern='{pattern}', remote='{owner_repo}'", + ) + + candidates.sort(key=lambda c: c[0], reverse=True) + best_sv, best_tag, best_sha = candidates[0] + + return ResolvedPackage( + name=entry.name, + source_repo=owner_repo, + subdir=entry.subdir, + ref=best_tag, + sha=best_sha, + requested_version=version_range, + tags=entry.tags, + is_prerelease=best_sv.is_prerelease, + host=self._effective_host(entry.host), # type: ignore[attr-defined] + ) + + # -- concurrent resolution ---------------------------------------------- + + def resolve(self) -> ResolveResult: + """Resolve every entry concurrently. + + Returns + ------- + ResolveResult + Contains resolved entries and any errors encountered. + + Raises + ------ + BuildError + On any resolution failure (unless ``continue_on_error``). 
+ """ + yml = self._load_yml() # type: ignore[attr-defined] + entries = yml.packages + if not entries: + return ResolveResult(entries=(), errors=()) + + results: dict[int, ResolvedPackage] = {} + errors: list[tuple[str, str]] = [] + + self._get_resolver() # type: ignore[attr-defined] + for entry in entries: + if entry.host: + self._get_resolver_for_host(entry.host) # type: ignore[attr-defined] + + options: BuildOptions = self._options # type: ignore[attr-defined] + with ThreadPoolExecutor(max_workers=min(options.concurrency, len(entries))) as pool: + future_to_index = { + pool.submit(self._resolve_entry, entry): idx for idx, entry in enumerate(entries) + } + for future in as_completed(future_to_index): + idx = future_to_index[future] + entry = entries[idx] + try: + resolved = future.result(timeout=options.timeout_seconds) + results[idx] = resolved + except BuildError as exc: + if options.continue_on_error: + errors.append((entry.name, str(exc))) + else: + raise + except Exception as exc: + logger.debug("Unexpected error resolving '%s'", entry.name, exc_info=True) + if options.continue_on_error: + errors.append((entry.name, str(exc))) + else: + raise BuildError( + f"Unexpected error resolving '{entry.name}': {exc}", + package=entry.name, + ) from exc + + ordered: list[ResolvedPackage] = [] + for idx in range(len(entries)): + if idx in results: + ordered.append(results[idx]) + return ResolveResult(entries=tuple(ordered), errors=tuple(errors)) + + # -- remote description fetcher ----------------------------------------- + + def _fetch_remote_metadata(self, pkg: ResolvedPackage) -> dict[str, str] | None: + """Best-effort: fetch ``description`` and ``version`` from the + package's remote ``apm.yml``. + + urllib Rule B: ``urllib`` is accessed via ``_b.urllib`` (late import of + ``builder`` module) so that test patches on + ``apm_cli.marketplace.builder.urllib.request.urlopen`` remain effective. 
+ """ + try: + path_prefix = f"{pkg.subdir}/" if pkg.subdir else "" + file_path = f"{path_prefix}apm.yml" + + effective_host = pkg.host or self._host # type: ignore[attr-defined] + if pkg.host is None or pkg.host == self._host: # type: ignore[attr-defined] + host_info = self._host_info # type: ignore[attr-defined] + token = self._github_token # type: ignore[attr-defined] + else: + from ..core.auth import AuthResolver # lazy import + + try: + host_info = AuthResolver.classify_host(effective_host) + except Exception: + host_info = None + token = self._resolve_token_for_host(effective_host) # type: ignore[attr-defined] + + host_kind = host_info.kind if host_info else "github" + + if host_kind not in ("github", "ghe_cloud", "ghes"): + logger.debug( + "Skipping metadata fetch for %s (non-GitHub host: %s)", + pkg.name, + effective_host, + ) + return None + + if host_kind == "ghe_cloud" and not token: + logger.debug( + "Skipping metadata fetch for %s (GHE Cloud requires auth)", + pkg.name, + ) + return None + + # Rule B: access urllib via builder module so patch target is preserved + from apm_cli.marketplace import builder as _b + + if effective_host == "github.com": + url = f"https://raw.githubusercontent.com/{pkg.source_repo}/{pkg.sha}/{file_path}" + req = _b.urllib.request.Request(url) + if token: + req.add_header("Authorization", f"token {token}") + else: + api_base = ( + host_info.api_base if host_info else None + ) or f"https://{effective_host}/api/v3" + url = f"{api_base}/repos/{pkg.source_repo}/contents/{file_path}?ref={pkg.sha}" + req = _b.urllib.request.Request(url) + req.add_header("Accept", "application/vnd.github.raw") + if token: + req.add_header("Authorization", f"token {token}") + + with _b.urllib.request.urlopen(req, timeout=5) as resp: + raw = resp.read().decode("utf-8") + data = yaml.safe_load(raw) + if not isinstance(data, dict): + return None + result: dict[str, str] = {} + desc = data.get("description") + if isinstance(desc, str) and desc: + 
result["description"] = desc + ver = data.get("version") + if ver is not None: + ver_str = str(ver).strip() + if ver_str: + result["version"] = ver_str + if result: + logger.debug( + "Fetched metadata for %s from remote apm.yml: %s", + pkg.name, + ", ".join(result.keys()), + ) + return result + except Exception: + logger.debug( + "Could not fetch remote metadata for %s", + pkg.name, + exc_info=True, + ) + return None diff --git a/src/apm_cli/marketplace/_client_cache.py b/src/apm_cli/marketplace/_client_cache.py new file mode 100644 index 000000000..fb6d751e7 --- /dev/null +++ b/src/apm_cli/marketplace/_client_cache.py @@ -0,0 +1,158 @@ +"""Cache I/O helpers for the marketplace JSON sidecar cache. + +Extracted from client.py to keep module complexity bounded. +All functions in this module are private to the marketplace package; +``client.py`` re-imports them so callers see no change. + +The only external dependency is the shared config dir (lazy-imported at +call time to avoid circular imports with client.py). 
+""" + +import contextlib +import json +import logging +import os +import re +import time +from urllib.parse import urlsplit + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +_CACHE_TTL_SECONDS = 3600 # 1 hour +_CACHE_DIR_NAME = os.path.join("cache", "marketplace") + +# --------------------------------------------------------------------------- +# URL utility (used by cache key; kept here so client can re-import it) +# --------------------------------------------------------------------------- + + +def _host_from_url(url: str) -> str: + """Extract host from a URL (handles SCP-like SSH URLs too).""" + if not url: + return "" + # SCP-like: git@host:path + if "@" in url and not url.startswith(("http", "git://", "ssh://", "file://")): + try: + return url.split("@", 1)[1].split(":", 1)[0] + except (IndexError, ValueError): + return "" + try: + return urlsplit(url).hostname or "" + except ValueError: + return "" + + +# --------------------------------------------------------------------------- +# Cache directory helpers +# --------------------------------------------------------------------------- + + +def _cache_dir() -> str: + """Return the cache directory, creating it if needed.""" + from ..config import CONFIG_DIR + + d = os.path.join(CONFIG_DIR, _CACHE_DIR_NAME) + os.makedirs(d, exist_ok=True) + return d + + +def _sanitize_cache_name(name: str) -> str: + """Sanitize marketplace name for safe use in file paths.""" + from ..utils.path_security import PathTraversalError, validate_path_segments + + safe = re.sub(r"[^a-zA-Z0-9._-]", "_", name) + # Prevent path traversal even after sanitization + safe = safe.strip(".").strip("_") or "unnamed" + # Defense-in-depth: validate with centralized path security + try: + validate_path_segments(safe, context="cache name") + except PathTraversalError: + safe = "unnamed" 
+ return safe + + +def _cache_key(source) -> str: + """Cache key that includes kind+host to avoid collisions across hosts.""" + kind = source.kind + if kind == "local": + return f"local__{_sanitize_cache_name(source.name)}" + if kind == "git": + # Generic git: include host so a.com/o/r vs b.com/o/r never collapse. + host = _host_from_url(source.url) or source.host or "unknown" + return f"git__{_sanitize_cache_name(host)}__{_sanitize_cache_name(source.name)}" + normalized_host = (source.host or "github.com").lower() + if normalized_host == "github.com": + return source.name + return f"{_sanitize_cache_name(normalized_host)}__{source.name}" + + +def _cache_data_path(name: str) -> str: + return os.path.join(_cache_dir(), f"{_sanitize_cache_name(name)}.json") + + +def _cache_meta_path(name: str) -> str: + return os.path.join(_cache_dir(), f"{_sanitize_cache_name(name)}.meta.json") + + +# --------------------------------------------------------------------------- +# Cache read / write / clear +# --------------------------------------------------------------------------- + + +def _read_cache(name: str) -> dict | None: + """Read cached marketplace data if valid (not expired).""" + data_path = _cache_data_path(name) + meta_path = _cache_meta_path(name) + if not os.path.exists(data_path) or not os.path.exists(meta_path): + return None + try: + with open(meta_path, encoding="utf-8") as f: + meta = json.load(f) + fetched_at = meta.get("fetched_at", 0) + ttl = meta.get("ttl_seconds", _CACHE_TTL_SECONDS) + if time.time() - fetched_at > ttl: + return None # Expired + with open(data_path, encoding="utf-8") as f: + return json.load(f) + except (json.JSONDecodeError, OSError, KeyError) as exc: + logger.debug("Cache read failed for '%s': %s", name, exc) + return None + + +def _read_stale_cache(name: str) -> dict | None: + """Read cached data even if expired (stale-while-revalidate).""" + data_path = _cache_data_path(name) + if not os.path.exists(data_path): + return None + try: + 
with open(data_path, encoding="utf-8") as f: + return json.load(f) + except (json.JSONDecodeError, OSError): + return None + + +def _write_cache(name: str, data: dict) -> None: + """Write marketplace data and metadata to cache.""" + data_path = _cache_data_path(name) + meta_path = _cache_meta_path(name) + try: + with open(data_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + with open(meta_path, "w", encoding="utf-8") as f: + json.dump( + {"fetched_at": time.time(), "ttl_seconds": _CACHE_TTL_SECONDS}, + f, + ) + except OSError as exc: + logger.debug("Cache write failed for '%s': %s", name, exc) + + +def _clear_cache(name: str) -> None: + """Remove cached data for a marketplace.""" + for path in (_cache_data_path(name), _cache_meta_path(name)): + with contextlib.suppress(OSError): + os.remove(path) diff --git a/src/apm_cli/marketplace/_publish_state.py b/src/apm_cli/marketplace/_publish_state.py new file mode 100644 index 000000000..abb9c384c --- /dev/null +++ b/src/apm_cli/marketplace/_publish_state.py @@ -0,0 +1,208 @@ +"""Publish data model and transactional state file. + +Extracted from publisher.py to keep module complexity bounded. +All public symbols are re-exported from publisher.py so existing +import paths (tests, patches) keep working unchanged. + +No module-level import of publisher.py (cycle-safe). 
+""" + +from __future__ import annotations + +import json +import re +from dataclasses import dataclass +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import Any + +from ..utils.path_security import ensure_path_within +from ._io import atomic_write + +# --------------------------------------------------------------------------- +# Validation regexes (used by ConsumerTarget) +# --------------------------------------------------------------------------- + +_REPO_RE = re.compile(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$") +_BRANCH_SAFE_RE = re.compile(r"^[a-zA-Z0-9._/-]+$") + +# --------------------------------------------------------------------------- +# Data model +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class ConsumerTarget: + """A consumer repository whose ``apm.yml`` should be updated.""" + + repo: str # e.g. "acme-org/service-a" + branch: str = "main" # base branch on the consumer to PR into + path_in_repo: str = "apm.yml" # location of the consumer's apm.yml + + def __post_init__(self) -> None: + if not _REPO_RE.match(self.repo): + raise ValueError( + f"ConsumerTarget.repo must be in 'owner/name' format " + f"using only alphanumerics, dots, hyphens, and underscores. " + f"Got: {self.repo!r}" + ) + if not _BRANCH_SAFE_RE.match(self.branch) or ".." in self.branch: + raise ValueError( + f"ConsumerTarget.branch contains disallowed characters. " + f"Only alphanumerics, dots, hyphens, underscores, and " + f"forward slashes are permitted (no '..' sequences). 
" + f"Got: {self.branch!r}" + ) + from ..utils.path_security import validate_path_segments + + validate_path_segments(self.path_in_repo, context="consumer-targets path_in_repo") + + +@dataclass(frozen=True) +class PublishPlan: + """Computed plan for a publish run -- frozen and deterministic.""" + + marketplace_name: str # name from the local marketplace.yml + marketplace_version: str # version from the local marketplace.yml + targets: tuple[ConsumerTarget, ...] + commit_message: str # pre-computed, contains the APM trailer + branch_name: str # pre-computed, deterministic + new_ref: str # rendered tag, e.g. "v2.0.0" + tag_pattern_used: str # tag pattern, e.g. "v{version}" + short_hash: str = "" # deterministic hash suffix for the branch name + allow_downgrade: bool = False + allow_ref_change: bool = False + target_package: str | None = None + + +class PublishOutcome(str, Enum): + """Outcome of processing a single consumer target.""" + + UPDATED = "updated" + NO_CHANGE = "no-change" + SKIPPED_DOWNGRADE = "skipped-downgrade" + SKIPPED_REF_CHANGE = "skipped-ref-change" + FAILED = "failed" + + +@dataclass(frozen=True) +class TargetResult: + """Result of processing a single consumer target.""" + + target: ConsumerTarget + outcome: PublishOutcome + message: str # human-readable detail + old_version: str | None = None + new_version: str | None = None + + +# --------------------------------------------------------------------------- +# Transactional state file +# --------------------------------------------------------------------------- + +_STATE_FILENAME = "publish-state.json" +_STATE_DIR = ".apm" +_MAX_HISTORY = 10 +_SCHEMA_VERSION = 1 + + +class PublishState: + """Transactional state file for publish runs. + + State is persisted at ``.apm/publish-state.json`` relative to the + marketplace repo root. All writes are atomic (write-tmp + fsync + + ``os.replace``). 
+ """ + + def __init__(self, root: Path) -> None: + self._root = root.resolve() + self._state_dir = self._root / _STATE_DIR + self._state_path = self._state_dir / _STATE_FILENAME + self._data: dict[str, Any] = { + "schemaVersion": _SCHEMA_VERSION, + "lastRun": None, + "history": [], + } + + @classmethod + def load(cls, root: Path) -> PublishState: + """Load state from disk or return a fresh instance. + + A missing file or corrupt JSON both result in a fresh state -- + no exception is raised. + """ + instance = cls(root) + if instance._state_path.exists(): + try: + text = instance._state_path.read_text(encoding="utf-8") + data = json.loads(text) + if isinstance(data, dict): + instance._data = data + except (json.JSONDecodeError, OSError): + pass # start fresh on corrupt state + return instance + + def _atomic_write(self) -> None: + """Write state atomically via temp file + fsync + os.replace. + + Path validation and directory creation happen here; the actual + write is delegated to the shared ``atomic_write()`` helper from + ``_io.py``. 
+ """ + ensure_path_within(self._state_dir, self._root) + self._state_dir.mkdir(parents=True, exist_ok=True) + + content = json.dumps(self._data, indent=2) + "\n" + atomic_write(self._state_path, content) + + def begin_run(self, plan: PublishPlan) -> None: + """Start a new publish run -- writes ``startedAt``.""" + self._data["lastRun"] = { + "startedAt": datetime.now(timezone.utc).isoformat(), + "finishedAt": None, + "marketplaceName": plan.marketplace_name, + "marketplaceVersion": plan.marketplace_version, + "branchName": plan.branch_name, + "results": [], + } + self._atomic_write() + + def record_result(self, result: TargetResult) -> None: + """Append a target result to the current run.""" + if self._data.get("lastRun") is None: + return + self._data["lastRun"]["results"].append( + { + "repo": result.target.repo, + "outcome": result.outcome.value, + "message": result.message, + "oldVersion": result.old_version, + "newVersion": result.new_version, + } + ) + self._atomic_write() + + def finalise(self, finished_at: datetime) -> None: + """Finalise the current run and rotate history.""" + if self._data.get("lastRun") is None: + return + self._data["lastRun"]["finishedAt"] = finished_at.isoformat() + + # Rotate history -- keep at most _MAX_HISTORY entries + history = self._data.get("history", []) + history.insert(0, dict(self._data["lastRun"])) + self._data["history"] = history[:_MAX_HISTORY] + self._atomic_write() + + def abort(self, reason: str) -> None: + """Mark the current run as aborted.""" + if self._data.get("lastRun") is None: + return + self._data["lastRun"]["finishedAt"] = f"ABORTED: {reason}" + self._atomic_write() + + @property + def data(self) -> dict[str, Any]: + """Return the raw state data (read-only snapshot for inspection).""" + return dict(self._data) diff --git a/src/apm_cli/marketplace/_resolver_match.py b/src/apm_cli/marketplace/_resolver_match.py new file mode 100644 index 000000000..3c8989e22 --- /dev/null +++ 
b/src/apm_cli/marketplace/_resolver_match.py @@ -0,0 +1,342 @@ +"""Cross-repo misconfig detection and in-marketplace source-matching helpers. + +Extracted from resolver.py to keep module complexity bounded. +All symbols are re-exported from resolver.py so existing import paths +(tests, patches) keep working unchanged. + +No module-level import of resolver.py (cycle-safe). +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING +from urllib.parse import urlparse + +from ..utils.github_host import ( + is_azure_devops_hostname, + is_github_hostname, + is_supported_git_host, +) + +if TYPE_CHECKING: + from ..models.dependency.reference import DependencyReference + from .models import MarketplacePlugin, MarketplaceSource + + +# --------------------------------------------------------------------------- +# CrossRepoMisconfigRisk sentinel +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class CrossRepoMisconfigRisk: + """Signal that a cross-repo dict ``type: github`` source on an enterprise + GitHub-family marketplace declares a bare ``owner/repo`` whose canonical + falls back to ``github.com`` -- the same syntactic ambiguity that powers + a dependency-confusion attack (#1326, formerly diagnosed only as #1305). + + Attached to :class:`MarketplacePluginResolution` when the marketplace is on + ``*.ghe.com`` and the plugin's dict source declares a bare ``owner/repo`` + that does not match the marketplace project. The resolver deliberately + leaves these canonicals bare (PR #1292 scoped its host backfill to + in-marketplace sources), so ``DependencyReference.parse`` defaults the host + to ``github.com``. Two intents share this syntax -- a legitimate cross-host + ``github.com`` open-source dep, or a misconfigured same-host entry that + should have been ``corp.ghe.com/owner/repo`` -- and the resolver cannot + distinguish them. 
+ + Consumer contract (#1326): the install command consults this sentinel + BEFORE any outbound validation HTTP call and refuses the package + fail-closed when it is non-``None``. The earlier #1305 design surfaced + only an advisory hint on validation failure, which left the success + path (attacker pre-stages the bare namespace on public github.com) + silently exploitable. Cross-host explicit qualification by the + marketplace author -- ``repo: github.com/owner/repo`` -- prevents + the sentinel from attaching at the resolver layer (see + :func:`_compute_cross_repo_misconfig_risk`), which is the supported + escape hatch for declared cross-host intent. + """ + + marketplace_host: str + bare_repo_field: str + suggested_qualified_repo: str + + +# --------------------------------------------------------------------------- +# Owner/repo slug normalisation +# --------------------------------------------------------------------------- + + +def _normalize_owner_repo_slug(repo: str) -> str: + """Lowercase ``owner/repo`` slug with optional ``.git`` suffix stripped.""" + r = repo.strip().rstrip("/").lower() + if r.endswith(".git"): + r = r[:-4] + return r + + +def _marketplace_project_slug(owner: str, repo: str) -> str: + return _normalize_owner_repo_slug(f"{owner}/{repo}") + + +def _normalize_repo_field_for_match(repo_field: str, marketplace_host: str) -> str: + """Normalize a repo field to a logical project path for matching. + + Accept bare ``owner/repo`` paths, host-qualified shorthand like + ``git.epam.com/owner/repo``, and URL / SSH forms. If the field explicitly names + a different host than the marketplace host, return an empty string so it does + not match by suffix alone. 
+ """ + raw = repo_field.strip().rstrip("/") + if raw.endswith(".git"): + raw = raw[:-4] + + host_l = marketplace_host.strip().lower() + + if raw.startswith(("http://", "https://", "ssh://")): + parsed = urlparse(raw) + parsed_host = (parsed.hostname or "").strip().lower() + if parsed_host and parsed_host != host_l: + return "" + return parsed.path.lstrip("/").lower() + + if raw.startswith("git@") and ":" in raw: + host_part, path_part = raw[4:].split(":", 1) + if host_part.strip().lower() != host_l: + return "" + return path_part.lstrip("/").lower() + + parts = [p for p in raw.split("/") if p] + if len(parts) >= 3 and parts[0].strip().lower() == host_l: + parts = parts[1:] + return "/".join(parts).lower() + + +def _repo_field_matches_marketplace( + repo_field: str, owner: str, repo: str, marketplace_host: str +) -> bool: + """True if dict ``repo`` identifies the same project as the marketplace source.""" + if not repo_field or "/" not in repo_field: + return False + normalized_repo = _normalize_repo_field_for_match(repo_field, marketplace_host) + if not normalized_repo: + return False + return normalized_repo == _marketplace_project_slug(owner, repo) + + +def _coerce_dict_plugin_type(s: dict) -> str: + """Return normalized source ``type`` for a plugin entry dict (``type`` / ``source`` / ``kind``). + + ``type`` is case-insensitive. When it is missing, infers ``github`` or + ``git-subdir`` from ``repo`` plus path fields so in-marketplace matching and + ``path``/``subdir`` extraction match manifests that only set ``kind`` or omit + ``type`` (still require a valid ``repo`` for dict sources). 
+ """ + for key in ("type", "source", "kind"): + v = s.get(key, "") + if isinstance(v, str) and v.strip(): + return v.strip().lower() + repo = s.get("repo", "") + if not isinstance(repo, str) or "/" not in repo.strip(): + return "" + subdir = s.get("subdir", "") + if isinstance(subdir, str) and subdir.strip(): + return "git-subdir" + path = s.get("path", "") + if isinstance(path, str) and path.strip(): + return "github" + return "github" + + +def _is_in_marketplace_source(plugin: MarketplacePlugin, source: MarketplaceSource) -> bool: + """Per spec §Interface Contract -- in-marketplace detection.""" + s = plugin.source + if s is None: + return False + if isinstance(s, str): + return True + if not isinstance(s, dict): + return False + source_type = _coerce_dict_plugin_type(s) + if source_type in ("github", "git-subdir", "gitlab"): + return _repo_field_matches_marketplace( + s.get("repo", ""), source.owner, source.repo, source.host + ) + return False + + +# --------------------------------------------------------------------------- +# Host-routing helpers +# --------------------------------------------------------------------------- + + +def _marketplace_host_needs_explicit_git_path(host: str) -> bool: + """True when in-repo marketplace plugins must use ``git`` + ``path`` (clone root + subdir). + + ``github.com`` and ``*.ghe.com`` virtual shorthand is reliable. Azure DevOps uses + a different URL shape and is excluded. Self-managed GitLab FQDNs are often + classified as ``generic`` by :meth:`AuthResolver.classify_host` when not listed in + ``GITLAB_HOST`` / ``APM_GITLAB_HOSTS`` -- they still need explicit clone URLs so + paths like ``registry/pkg`` are not treated as extra project namespace segments. 
+ """ + if not host or not str(host).strip(): + return False + h = str(host).strip().split("/", 1)[0] + if is_azure_devops_hostname(h): + return False + return not is_github_hostname(h) + + +def _source_needs_explicit_git_path(source: MarketplaceSource) -> bool: + """Kind-aware variant of :func:`_marketplace_host_needs_explicit_git_path`. + + For URL-first sources, the ``kind`` derivation already encodes the routing + decision: any host APM doesn't classify as github-family needs the explicit + git+path canonical (mirrors the existing GitLab self-managed pattern), and + that now includes Azure DevOps and generic git hosts since their + ``marketplace.json`` is fetched via subprocess git instead of an API. + + Local marketplaces handle relative sources via :func:`_resolve_local_relative_source` + on the fast path and never reach this helper. + """ + kind = source.kind + if kind == "github": + return False + if kind in ("gitlab", "git"): + return True + # Fall back to legacy host-based behaviour for any kind we don't recognise + return _marketplace_host_needs_explicit_git_path(source.host) + + +def _needs_canonical_host_prefix(canonical: str, host: str) -> bool: + """True when a GitHub-family enterprise host must be prefixed to ``canonical``. + + GitHub-family hosts (``github.com`` + ``*.ghe.com``) keep virtual shorthand -- + ``resolve_plugin_source`` emits a bare ``owner/repo[/path]`` canonical because + there is no nested-group ambiguity to disambiguate. ``DependencyReference.parse`` + defaults missing hosts to ``github.com``, which is correct for ``github.com`` but + silently mis-routes auth for every ``*.ghe.com`` marketplace. + + Returns True only for enterprise GitHub hosts (``*.ghe.com``) so the caller can + backfill the host while preserving shorthand semantics. 
Idempotent: when the + canonical already starts with ``host`` (case-insensitive) -- as happens when the + manifest's dict source carries a host-qualified ``repo`` -- this returns False + so the prefix is not duplicated. + + GHES (GitHub Enterprise Server, configured via ``GITHUB_HOST``) is not handled + here. Those hosts return True from ``_marketplace_host_needs_explicit_git_path`` + (neither GitHub-family nor ADO) so ``resolve_marketplace_plugin`` builds a + structured ``dep_ref`` upstream and this helper is never reached. The + ``is_github_hostname`` check below is defense-in-depth that would also reject + them if a future change ever bypassed the upstream guard. + + Also returns False when ``canonical`` is in URL form (``https://...``) or SSH + SCP shorthand (``git@host:owner/repo``). Manifests that put a full URL in the + ``repo`` field reach this point via ``_resolve_github_source`` (which only + requires a ``/``); detecting those by ``":"`` in the first slash-split segment + avoids producing malformed ``host/https://...`` canonicals. Those forms already + carry a host and ``DependencyReference.parse`` resolves them natively. + """ + h = (host or "").strip() + if not h or not is_github_hostname(h) or h.lower() == "github.com": + return False + first_segment = canonical.split("/", 1)[0] + if ":" in first_segment: + return False + return first_segment.lower() != h.lower() + + +# --------------------------------------------------------------------------- +# Cross-repo misconfig risk computation +# --------------------------------------------------------------------------- + + +def _cross_repo_early_exit( + plugin: MarketplacePlugin, + source: MarketplaceSource, + canonical: str, + dep_ref: DependencyReference | None, +) -> bool: + """Return True when ``_compute_cross_repo_misconfig_risk`` should short-circuit. + + Consolidates the five guard conditions that all lead to ``return None`` in + the parent so it stays within the PLR0911 return-statement budget. 
+ """ + if dep_ref is not None: + return True + if not isinstance(plugin.source, dict): + return True + if _coerce_dict_plugin_type(plugin.source) != "github": + return True + if _is_in_marketplace_source(plugin, source): + return True + return not _needs_canonical_host_prefix(canonical, source.host) + + +def _compute_cross_repo_misconfig_risk( + plugin: MarketplacePlugin, + source: MarketplaceSource, + canonical: str, + dep_ref: DependencyReference | None, +) -> CrossRepoMisconfigRisk | None: + """Identify the #1305 misconfiguration: cross-repo dict ``type: github`` + source with bare ``repo`` on an enterprise GitHub-family marketplace. + + Returns a :class:`CrossRepoMisconfigRisk` when **all** of: + + - ``dep_ref`` is ``None`` (GitHub-family virtual-shorthand path; GitLab and + self-managed FQDNs build a structured ref upstream and sidestep the bug) + - ``plugin.source`` is a dict whose normalized type is ``github`` (other + dict types -- ``gitlab``, ``git-subdir`` -- hit the same auth-routing + bug but the "host-qualify with marketplace host" remediation only + matches operator intent for the GitHub family) + - the source is **not** an in-marketplace reference (PR #1292 already + backfills the host for those) + - ``_needs_canonical_host_prefix`` agrees the canonical is bare and the + host is GitHub-family enterprise (``*.ghe.com``; idempotent against + already host-qualified, URL, and SSH forms) + - the ``repo`` field is a non-empty ``owner/repo`` shorthand + + Otherwise returns ``None``. Pure -- no logging, no side effects. + """ + if _cross_repo_early_exit(plugin, source, canonical, dep_ref): + return None + repo_field = plugin.source.get("repo", "") # type: ignore[union-attr] + if not isinstance(repo_field, str): + return None + bare = repo_field.strip().lstrip("/") + if "/" not in bare: + return None + # #1326: an already-host-qualified `repo:` field declares explicit intent + # (e.g. 
``repo: github.com/owner/repo`` on a ``*.ghe.com`` marketplace is + # an unambiguous declared cross-host dependency). Only the truly-bare + # ``owner/repo`` form is the dependency-confusion vector this sentinel + # flags. ``_needs_canonical_host_prefix`` above already returns False + # for SAME-host qualification (its idempotency clause) and for URL / + # SSH SCP shorthand canonicals; this is the symmetric guard for the + # remaining case -- CROSS-host shorthand qualification (``github.com/...`` + # on a ``*.ghe.com`` marketplace), which the idempotency check cannot + # detect because the canonical starts with a different host than + # ``source.host``. + # + # Defense in depth: extract the host from URL and SCP shorthand forms + # too, so the guard is robust even if a future upstream refactor lets + # those forms reach this point. + explicit_host = "" + bare_lower = bare.lower() + if bare_lower.startswith(("https://", "http://", "ssh://")): + explicit_host = (urlparse(bare).hostname or "").strip() + elif bare.startswith("git@") and ":" in bare: + # SCP shorthand: ``git@host:owner/repo`` + explicit_host = bare[4:].split(":", 1)[0].strip() + else: + explicit_host = bare.split("/", 1)[0] + # ``is_supported_git_host`` accepts any valid FQDN, not an allowlist. + if is_supported_git_host(explicit_host): + return None + return CrossRepoMisconfigRisk( + marketplace_host=source.host, + bare_repo_field=bare, + suggested_qualified_repo=f"{source.host}/{bare}", + ) diff --git a/src/apm_cli/marketplace/_yml_models.py b/src/apm_cli/marketplace/_yml_models.py new file mode 100644 index 000000000..23bed1429 --- /dev/null +++ b/src/apm_cli/marketplace/_yml_models.py @@ -0,0 +1,177 @@ +"""Dataclasses for marketplace authoring configuration. + +Leaf module -- contains only frozen dataclasses. No imports from +``yml_schema`` or ``_yml_parsers`` (cycle-safe). All public symbols +are re-exported by ``yml_schema`` so existing import paths continue to +work. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Mapping # noqa: UP035 + +from .output_profiles import MARKETPLACE_OUTPUTS + +__all__ = [ + "MarketplaceBuild", + "MarketplaceClaudeConfig", + "MarketplaceCodexConfig", + "MarketplaceConfig", + "MarketplaceOutputSpec", + "MarketplaceOwner", + "MarketplaceVersioning", + "PackageEntry", +] + + +@dataclass(frozen=True) +class MarketplaceOwner: + """Owner block of ``marketplace.yml``.""" + + name: str + email: str | None = None + url: str | None = None + + +@dataclass(frozen=True) +class MarketplaceBuild: + """APM-only build configuration block.""" + + tag_pattern: str = "v{version}" + + +@dataclass(frozen=True) +class MarketplaceVersioning: + """Release-time versioning strategy for the marketplace. + + Controls how ``apm pack --check-versions`` verifies per-package + version alignment across local-path packages: + + * ``lockstep`` (default) -- every local package's top-level + ``version`` must equal the marketplace's top-level ``version``. + * ``tag_pattern`` -- each rendered tag must be unique across all + local packages; missing ``version`` still fails. + * ``per_package`` -- only requires that each local package declare + a ``version``; equality is not enforced. + """ + + strategy: str = "lockstep" + + +@dataclass(frozen=True) +class MarketplaceClaudeConfig: + """Claude-specific marketplace output configuration.""" + + output: str = ".claude-plugin/marketplace.json" + + +@dataclass(frozen=True) +class MarketplaceCodexConfig: + """Codex-specific marketplace output configuration.""" + + output: str = MARKETPLACE_OUTPUTS["codex"].default_output + + +@dataclass(frozen=True) +class PackageEntry: + """A single entry in the ``packages`` list. 
+ + Attributes that are Anthropic pass-through (``description``, + ``homepage``, ``tags``) are stored alongside APM-only attributes + (``subdir``, ``version``, ``ref``, ``tag_pattern``, + ``include_prerelease``) so the builder can partition them at + compile time. + + ``is_local`` is derived by the loader from the ``source`` field -- + a leading ``./`` marks a local-path package that skips git + resolution. + """ + + name: str + source: str + # APM-only fields + subdir: str | None = None + version: str | None = None + ref: str | None = None + tag_pattern: str | None = None + include_prerelease: bool = False + # Anthropic pass-through fields + description: str | None = None + homepage: str | None = None + tags: tuple[str, ...] = () + # ``author`` is normalized to a Claude-Code-compliant object: + # ``{"name": str, "email"?: str, "url"?: str}``. Accepts either a + # bare string (treated as ``name``) or a mapping at parse time. + author: Mapping[str, str] | None = None + license: str | None = None + repository: str | None = None + # Marketplace category metadata. Emitted only by output formats that + # consume categories, currently Codex repo marketplace output. + category: str | None = None + # Derived (set by loader, not by user) + is_local: bool = False + # Optional non-default git host parsed from ``source`` of the form + # ``host.tld/owner/repo``. ``None`` means use the default host + # (``GITHUB_HOST`` env or ``github.com``). + host: str | None = None + + +@dataclass(frozen=True) +class MarketplaceOutputSpec: + """Resolved specification for one marketplace output format. + + Produced by the map-form ``outputs:`` parser. When ``path_explicit`` + is True, the manifest set an explicit ``path:`` value (vs. the + profile default). 
+ """ + + name: str + """Format name (matches a key in ``MARKETPLACE_OUTPUTS``).""" + + path: str + """Resolved output path (explicit or profile default).""" + + path_explicit: bool = False + """True if the user set an explicit ``path:`` in the outputs map.""" + + +@dataclass(frozen=True) +class MarketplaceConfig: + """Parsed marketplace configuration. + + May originate from apm.yml's ``marketplace:`` block (current) or + from a standalone ``marketplace.yml`` (legacy, deprecated). + + ``metadata`` is stored as a plain ``dict`` preserving the original + key casing so the builder can forward it verbatim to + ``marketplace.json``. + + Override flags (``*_overridden``) record whether the marketplace + block explicitly set each inheritable field. The builder uses + these flags to decide whether to emit ``description``/``version`` + at the top level of ``marketplace.json`` -- per the Anthropic + azure-skills convention, inherited values are omitted from output. + """ + + name: str + description: str + version: str + owner: MarketplaceOwner + output: str = ".claude-plugin/marketplace.json" + outputs: tuple[str, ...] = ("claude",) + claude: MarketplaceClaudeConfig = field(default_factory=MarketplaceClaudeConfig) + codex: MarketplaceCodexConfig = field(default_factory=MarketplaceCodexConfig) + metadata: dict[str, Any] = field(default_factory=dict) + build: MarketplaceBuild = field(default_factory=MarketplaceBuild) + versioning: MarketplaceVersioning = field(default_factory=MarketplaceVersioning) + packages: tuple[PackageEntry, ...] = () + output_specs: tuple[MarketplaceOutputSpec, ...] = () + warnings: tuple[str, ...] 
= () + # Origin tracking + override-detection metadata + source_path: Path | None = None + is_legacy: bool = False + name_overridden: bool = False + description_overridden: bool = False + version_overridden: bool = False diff --git a/src/apm_cli/marketplace/_yml_parsers.py b/src/apm_cli/marketplace/_yml_parsers.py new file mode 100644 index 000000000..674ddb095 --- /dev/null +++ b/src/apm_cli/marketplace/_yml_parsers.py @@ -0,0 +1,733 @@ +"""Parse helpers and validation functions for marketplace YAML configs. + +Leaf module -- imports from ``._yml_models`` but never from +``yml_schema`` (cycle-safe). All public symbols used by ``yml_schema`` +are imported from here. +""" + +from __future__ import annotations + +import re +from pathlib import Path +from typing import Any + +import yaml + +from ..utils.path_security import PathTraversalError, validate_path_segments +from ._yml_models import ( + MarketplaceBuild, + MarketplaceClaudeConfig, + MarketplaceCodexConfig, + MarketplaceOutputSpec, + MarketplaceOwner, + MarketplaceVersioning, + PackageEntry, +) +from .errors import MarketplaceYmlError +from .output_profiles import MARKETPLACE_OUTPUTS, known_output_names + +__all__ = [ + "LOCAL_SOURCE_RE", + "SOURCE_RE", + "split_host_from_source", +] + +# --------------------------------------------------------------------------- +# Semver validation (regex, no external lib) +# --------------------------------------------------------------------------- + +_SEMVER_RE = re.compile( + r"^\d+\.\d+\.\d+" + r"(?:-[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*)?" 
+ r"(?:\+[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*)?$" +) + +# --------------------------------------------------------------------------- +# Source field patterns +# --------------------------------------------------------------------------- + +# Source field accepts: +# - ``owner/repo`` (remote, default host) +# - ``host.tld/owner/repo`` (remote on a non-default host, shorthand) +# - ``https://host.tld/owner/repo`` (remote on a non-default host, full URL) +# - ``https://host.tld/owner/repo.git`` (same, with optional ``.git`` suffix) +# - ``./...`` (local path within the same repo) +_HOST_PAT = r"(?:[A-Za-z0-9](?:[A-Za-z0-9-]*[A-Za-z0-9])?\.)+[A-Za-z][A-Za-z0-9-]*" +_OWNER_REPO_PAT = r"[A-Za-z0-9._-]+/[A-Za-z0-9._-]+" + +SOURCE_RE = re.compile( + r"^(?:" + rf"https://{_HOST_PAT}/{_OWNER_REPO_PAT}(?:\.git)?" + rf"|{_HOST_PAT}/{_OWNER_REPO_PAT}" + rf"|{_OWNER_REPO_PAT}" + r"|\./.*" + r")$" +) +LOCAL_SOURCE_RE = re.compile(r"^\./") +_HOST_PREFIXED_SOURCE_RE = re.compile(rf"^({_HOST_PAT})/({_OWNER_REPO_PAT})$") +_HTTPS_URL_SOURCE_RE = re.compile(rf"^https://({_HOST_PAT})/({_OWNER_REPO_PAT})(?:\.git)?$") + +# --------------------------------------------------------------------------- +# Tag-pattern placeholders +# --------------------------------------------------------------------------- + +_TAG_PLACEHOLDERS = ("{version}", "{name}") + +# --------------------------------------------------------------------------- +# Permitted key sets (strict mode) +# --------------------------------------------------------------------------- + +_BUILD_KEYS = frozenset({"tagPattern"}) + +_PACKAGE_ENTRY_KEYS = frozenset( + { + "name", + "source", + "subdir", + "version", + "ref", + "tag_pattern", + "include_prerelease", + "description", + "homepage", + "tags", + "author", + "license", + "repository", + "keywords", + "category", + } +) + +# Limits for keywords/tags array to prevent DoS via oversized manifests (S4). 
+_MAX_TAGS_COUNT = 50 +_MAX_TAG_LENGTH = 100 + +# Keys permitted inside an ``author`` object. +_AUTHOR_OBJECT_KEYS = frozenset({"name", "email", "url"}) + +_APM_MARKETPLACE_KEYS = frozenset( + { + "name", + "description", + "version", + "owner", + "output", + "outputs", + "claude", + "metadata", + "build", + "codex", + "packages", + "versioning", + } +) + +_VERSIONING_KEYS = frozenset({"strategy"}) +_VERSIONING_STRATEGIES = frozenset({"lockstep", "tag_pattern", "per_package"}) +_CLAUDE_KEYS = frozenset({"output"}) +_CODEX_KEYS = frozenset({"output"}) + +# --------------------------------------------------------------------------- +# Public: source field splitter +# --------------------------------------------------------------------------- + + +def split_host_from_source(source: str) -> tuple[str | None, str]: + """Split a host-qualified source into ``(host, owner/repo)``. + + Accepts both shorthand (``host.tld/owner/repo``) and full HTTPS URL + (``https://host.tld/owner/repo[.git]``) forms. Returns ``(None, source)`` + for the plain ``owner/repo`` shorthand or local ``./...`` paths. 
+ """ + m = _HTTPS_URL_SOURCE_RE.match(source) + if m: + host, owner_repo = m.group(1), m.group(2) + if owner_repo.endswith(".git"): + owner_repo = owner_repo[: -len(".git")] + return host, owner_repo + m = _HOST_PREFIXED_SOURCE_RE.match(source) + if m: + return m.group(1), m.group(2) + return None, source + + +# --------------------------------------------------------------------------- +# Validation helpers +# --------------------------------------------------------------------------- + + +def _require_str(data: dict[str, Any], key: str, *, context: str = "") -> str: + """Return a non-empty string value or raise ``MarketplaceYmlError``.""" + path = f"{context}.{key}" if context else key + value = data.get(key) + if value is None: + raise MarketplaceYmlError(f"'{path}' is required") + if not isinstance(value, str) or not value.strip(): + raise MarketplaceYmlError(f"'{path}' must be a non-empty string") + return value.strip() + + +def _validate_semver(version: str, *, context: str = "version") -> None: + """Raise if *version* is not a valid semver string.""" + if not _SEMVER_RE.match(version): + raise MarketplaceYmlError( + f"'{context}' value '{version}' is not valid semver (expected x.y.z)" + ) + + +def _validate_source(source: str, *, index: int) -> None: + """Validate ``source`` field shape and path safety.""" + ctx = f"packages[{index}].source" + if not SOURCE_RE.match(source): + raise MarketplaceYmlError( + f"'{ctx}' must be one of " + f"'/', '//', " + f"'https:////[.git]', or './', " + f"got '{source}'" + ) + is_local = bool(LOCAL_SOURCE_RE.match(source)) + try: + validate_path_segments(source, context=ctx, allow_current_dir=is_local) + except PathTraversalError as exc: + raise MarketplaceYmlError(str(exc)) from exc + + +def _validate_tag_pattern(pattern: str, *, context: str) -> None: + """Ensure *pattern* contains at least one recognised placeholder.""" + if not any(ph in pattern for ph in _TAG_PLACEHOLDERS): + raise MarketplaceYmlError( + f"'{context}' 
must contain at least one of " + f"{', '.join(_TAG_PLACEHOLDERS)}, got '{pattern}'" + ) + + +def _check_unknown_keys( + data: dict[str, Any], + permitted: frozenset, + *, + context: str, +) -> None: + """Raise on any key not in *permitted*.""" + unknown = set(data.keys()) - permitted + if unknown: + raise MarketplaceYmlError( + f"Unknown key(s) in {context}: {', '.join(sorted(unknown))}. " + f"Permitted keys: {', '.join(sorted(permitted))}" + ) + + +# --------------------------------------------------------------------------- +# Internal parse helpers +# --------------------------------------------------------------------------- + + +def _parse_author(raw: Any, index: int) -> dict[str, str] | None: + """Normalize a curator-supplied ``author`` value. + + Accepts either a non-empty string (treated as ``name``) or a mapping + with at least ``name`` and only the permitted keys. + Returns ``None`` when ``raw`` is ``None``. + """ + if raw is None: + return None + ctx = f"packages[{index}].author" + if isinstance(raw, str): + name = raw.strip() + if not name: + raise MarketplaceYmlError(f"'{ctx}' must be a non-empty string or object with 'name'") + return {"name": name} + if isinstance(raw, dict): + unknown = set(raw.keys()) - _AUTHOR_OBJECT_KEYS + if unknown: + raise MarketplaceYmlError( + f"'{ctx}' has unknown key(s): " + f"{', '.join(sorted(unknown))}; allowed: " + f"{', '.join(sorted(_AUTHOR_OBJECT_KEYS))}" + ) + name = raw.get("name") + if not isinstance(name, str) or not name.strip(): + raise MarketplaceYmlError(f"'{ctx}.name' is required and must be a non-empty string") + out: dict[str, str] = {"name": name.strip()} + for key in ("email", "url"): + val = raw.get(key) + if val is None: + continue + if not isinstance(val, str) or not val.strip(): + raise MarketplaceYmlError(f"'{ctx}.{key}' must be a non-empty string") + out[key] = val.strip() + return out + raise MarketplaceYmlError(f"'{ctx}' must be a string or object, got {type(raw).__name__}") + + +def 
_parse_owner(raw: Any) -> MarketplaceOwner: + """Parse and validate the ``owner`` block.""" + if not isinstance(raw, dict): + raise MarketplaceYmlError("'owner' must be a mapping with at least a 'name' key") + name = _require_str(raw, "name", context="owner") + email = raw.get("email") + if email is not None: + email = str(email).strip() or None + url = raw.get("url") + if url is not None: + url = str(url).strip() or None + return MarketplaceOwner(name=name, email=email, url=url) + + +def _parse_build(raw: Any) -> MarketplaceBuild: + """Parse and validate the ``build`` block.""" + if raw is None: + return MarketplaceBuild() + if not isinstance(raw, dict): + raise MarketplaceYmlError("'build' must be a mapping") + _check_unknown_keys(raw, _BUILD_KEYS, context="build") + tag_pattern = raw.get("tagPattern", "v{version}") + if not isinstance(tag_pattern, str) or not tag_pattern.strip(): + raise MarketplaceYmlError("'build.tagPattern' must be a non-empty string") + tag_pattern = tag_pattern.strip() + _validate_tag_pattern(tag_pattern, context="build.tagPattern") + return MarketplaceBuild(tag_pattern=tag_pattern) + + +def _parse_versioning(raw: Any) -> MarketplaceVersioning: + """Parse and validate the optional ``marketplace.versioning`` block.""" + if raw is None: + return MarketplaceVersioning() + if not isinstance(raw, dict): + raise MarketplaceYmlError(f"'versioning' must be a mapping, got {type(raw).__name__}") + _check_unknown_keys(raw, _VERSIONING_KEYS, context="versioning") + strategy = raw.get("strategy", "lockstep") + if not isinstance(strategy, str) or not strategy.strip(): + raise MarketplaceYmlError("'versioning.strategy' must be a non-empty string") + strategy = strategy.strip() + if strategy not in _VERSIONING_STRATEGIES: + valid = ", ".join(sorted(_VERSIONING_STRATEGIES)) + raise MarketplaceYmlError( + f"'versioning.strategy' must be one of: {valid}; got {strategy!r}" + ) + return MarketplaceVersioning(strategy=strategy) + + +def _parse_claude(raw: Any, 
*, default_output: str) -> MarketplaceClaudeConfig: + """Parse and validate the optional ``marketplace.claude`` block.""" + if raw is None: + return MarketplaceClaudeConfig(output=default_output) + if not isinstance(raw, dict): + raise MarketplaceYmlError("'claude' must be a mapping") + _check_unknown_keys(raw, _CLAUDE_KEYS, context="claude") + output = raw.get("output", default_output) + if not isinstance(output, str) or not output.strip(): + raise MarketplaceYmlError("'claude.output' must be a non-empty string") + output = output.strip() + try: + validate_path_segments(output, context="claude.output") + except PathTraversalError as exc: + raise MarketplaceYmlError(str(exc)) from exc + return MarketplaceClaudeConfig(output=output) + + +def _parse_codex(raw: Any) -> MarketplaceCodexConfig: + """Parse and validate the optional ``marketplace.codex`` block.""" + if raw is None: + return MarketplaceCodexConfig() + if not isinstance(raw, dict): + raise MarketplaceYmlError("'codex' must be a mapping") + _check_unknown_keys(raw, _CODEX_KEYS, context="codex") + output = raw.get("output", MARKETPLACE_OUTPUTS["codex"].default_output) + if not isinstance(output, str) or not output.strip(): + raise MarketplaceYmlError("'codex.output' must be a non-empty string") + output = output.strip() + try: + validate_path_segments(output, context="codex.output") + except PathTraversalError as exc: + raise MarketplaceYmlError(str(exc)) from exc + return MarketplaceCodexConfig(output=output) + + +def _parse_outputs( + raw: Any, + warnings_sink: list[str] | None = None, +) -> tuple[tuple[str, ...], tuple[MarketplaceOutputSpec, ...]]: + """Parse the marketplace output selector. + + Accepts: + - ``None`` -> default (claude only). + - A list of strings -> back-compat list form (emits deprecation warning). + - A string -> single-element back-compat list form. + - A dict -> new map form with optional per-format ``path:``. + + Returns ``(outputs_tuple, output_specs_tuple)``. 
+ """ + if raw is None: + default_spec = MarketplaceOutputSpec( + name="claude", + path=MARKETPLACE_OUTPUTS["claude"].default_output, + path_explicit=False, + ) + return ("claude",), (default_spec,) + + # --- Map form (new) --- + if isinstance(raw, dict): + return _parse_outputs_map(raw) + + # --- List / string form (deprecated back-compat) --- + if isinstance(raw, str): + raw_items: list[Any] = [raw] + elif isinstance(raw, list): + raw_items = raw + else: + raise MarketplaceYmlError("'outputs' must be a string, list, or mapping") + + outputs_list: list[str] = [] + specs_list: list[MarketplaceOutputSpec] = [] + seen_set: set[str] = set() + for index, item in enumerate(raw_items): + if not isinstance(item, str) or not item.strip(): + raise MarketplaceYmlError(f"'outputs[{index}]' must be a non-empty string") + output = item.strip() + known_outputs = known_output_names() + if output not in known_outputs: + raise MarketplaceYmlError( + f"Unknown marketplace output '{output}'. " + f"Permitted outputs: {', '.join(sorted(known_outputs))}" + ) + if output in seen_set: + raise MarketplaceYmlError(f"Duplicate marketplace output '{output}'") + seen_set.add(output) + outputs_list.append(output) + specs_list.append( + MarketplaceOutputSpec( + name=output, + path=MARKETPLACE_OUTPUTS[output].default_output, + path_explicit=False, + ) + ) + + if not outputs_list: + raise MarketplaceYmlError("'outputs' must contain at least one marketplace output") + + names_str = ", ".join(outputs_list) + map_lines = "\n".join(f" {n}: {{}}" for n in outputs_list) + deprecation_msg = ( + f"outputs: [{names_str}] is deprecated; use the map form:\n\n" + f" outputs:\n{map_lines}\n\n" + f" The list form will be removed in v0.15." 
+ ) + if warnings_sink is not None: + warnings_sink.append(deprecation_msg) + + return tuple(outputs_list), tuple(specs_list) + + +def _parse_outputs_map( + raw: dict[Any, Any], +) -> tuple[tuple[str, ...], tuple[MarketplaceOutputSpec, ...]]: + """Parse the map form of the ``outputs:`` block.""" + outputs: list[str] = [] + specs: list[MarketplaceOutputSpec] = [] + seen: set[str] = set() + known = known_output_names() + + for key, value in raw.items(): + if not isinstance(key, str) or not key.strip(): + raise MarketplaceYmlError("'outputs' map keys must be non-empty strings") + name = key.strip() + if name not in known: + raise MarketplaceYmlError( + f"Unknown marketplace output '{name}'. " + f"Permitted outputs: {', '.join(sorted(known))}" + ) + if name in seen: + raise MarketplaceYmlError(f"Duplicate marketplace output '{name}'") + seen.add(name) + + path_explicit = False + path = MARKETPLACE_OUTPUTS[name].default_output + if value is not None: + if not isinstance(value, dict): + raise MarketplaceYmlError(f"'outputs.{name}' must be a mapping or null") + raw_path = value.get("path") + if raw_path is not None: + if not isinstance(raw_path, str) or not raw_path.strip(): + raise MarketplaceYmlError(f"'outputs.{name}.path' must be a non-empty string") + path = raw_path.strip() + path_explicit = True + try: + validate_path_segments(path, context=f"outputs.{name}.path") + except PathTraversalError as exc: + raise MarketplaceYmlError(str(exc)) from exc + unknown = set(value.keys()) - {"path"} + if unknown: + raise MarketplaceYmlError( + f"Unknown key(s) in 'outputs.{name}': {', '.join(sorted(unknown))}" + ) + + outputs.append(name) + specs.append(MarketplaceOutputSpec(name=name, path=path, path_explicit=path_explicit)) + + if not outputs: + raise MarketplaceYmlError("'outputs' must contain at least one marketplace output") + return tuple(outputs), tuple(specs) + + +def _parse_package_entry(raw: Any, index: int) -> PackageEntry: + """Parse and validate a single 
``packages`` entry.""" + if not isinstance(raw, dict): + raise MarketplaceYmlError(f"packages[{index}] must be a mapping") + + _check_unknown_keys(raw, _PACKAGE_ENTRY_KEYS, context=f"packages[{index}]") + + name = _require_str(raw, "name", context=f"packages[{index}]") + source = _require_str(raw, "source", context=f"packages[{index}]") + _validate_source(source, index=index) + is_local = bool(LOCAL_SOURCE_RE.match(source)) + host: str | None = None + if not is_local: + host, source = split_host_from_source(source) + + # APM-only: subdir + subdir: str | None = raw.get("subdir") + if subdir is not None: + if not isinstance(subdir, str) or not subdir.strip(): + raise MarketplaceYmlError(f"'packages[{index}].subdir' must be a non-empty string") + subdir = subdir.strip() + try: + validate_path_segments(subdir, context=f"packages[{index}].subdir") + except PathTraversalError as exc: + raise MarketplaceYmlError(str(exc)) from exc + + # APM-only: version + version: str | None = raw.get("version") + if version is not None: + version = str(version).strip() + if not version: + raise MarketplaceYmlError(f"'packages[{index}].version' must be a non-empty string") + + # APM-only: ref + ref: str | None = raw.get("ref") + if ref is not None: + ref = str(ref).strip() + if not ref: + raise MarketplaceYmlError(f"'packages[{index}].ref' must be a non-empty string") + + if not is_local and version is None and ref is None: + raise MarketplaceYmlError( + f"packages[{index}] ('{name}'): remote packages require at " + f"least one of 'version' or 'ref'" + ) + + # APM-only: tag_pattern + tag_pattern: str | None = raw.get("tag_pattern") + if tag_pattern is not None: + if not isinstance(tag_pattern, str) or not tag_pattern.strip(): + raise MarketplaceYmlError(f"'packages[{index}].tag_pattern' must be a non-empty string") + tag_pattern = tag_pattern.strip() + _validate_tag_pattern(tag_pattern, context=f"packages[{index}].tag_pattern") + + # APM-only: include_prerelease + include_prerelease = 
raw.get("include_prerelease", False) + if not isinstance(include_prerelease, bool): + raise MarketplaceYmlError(f"'packages[{index}].include_prerelease' must be a boolean") + + # Anthropic pass-through: description + description: str | None = raw.get("description") + if description is not None: + if not isinstance(description, str) or not description.strip(): + raise MarketplaceYmlError(f"'packages[{index}].description' must be a non-empty string") + description = description.strip() + + # Anthropic pass-through: homepage + homepage: str | None = raw.get("homepage") + if homepage is not None: + if not isinstance(homepage, str) or not homepage.strip(): + raise MarketplaceYmlError(f"'packages[{index}].homepage' must be a non-empty string") + homepage = homepage.strip() + + # Anthropic pass-through: tags + keywords (merged, deduplicated) + tags = _parse_tags_and_keywords(raw, index) + + # Anthropic pass-through: author + author = _parse_author(raw.get("author"), index) + + # Anthropic pass-through: license + license_val: str | None = raw.get("license") + if license_val is not None: + if not isinstance(license_val, str) or not license_val.strip(): + raise MarketplaceYmlError(f"'packages[{index}].license' must be a non-empty string") + license_val = license_val.strip() + + # Anthropic pass-through: repository + repository: str | None = raw.get("repository") + if repository is not None: + if not isinstance(repository, str) or not repository.strip(): + raise MarketplaceYmlError(f"'packages[{index}].repository' must be a non-empty string") + repository = repository.strip() + + # Marketplace category + category: str | None = None + raw_category = raw.get("category") + if raw_category is not None: + if not isinstance(raw_category, str) or not raw_category.strip(): + raise MarketplaceYmlError(f"'packages[{index}].category' must be a non-empty string") + category = raw_category.strip() + + return PackageEntry( + name=name, + source=source, + subdir=subdir, + version=version, + 
ref=ref, + tag_pattern=tag_pattern, + include_prerelease=include_prerelease, + description=description, + homepage=homepage, + tags=tags, + author=author, + license=license_val, + repository=repository, + category=category, + is_local=is_local, + host=host, + ) + + +def _parse_tags_and_keywords(raw: dict[str, Any], index: int) -> tuple[str, ...]: + """Parse and merge ``tags`` and ``keywords`` fields, capped per S4.""" + raw_tags = raw.get("tags") + tags: tuple[str, ...] = () + if raw_tags is not None: + if not isinstance(raw_tags, list): + raise MarketplaceYmlError(f"'packages[{index}].tags' must be a list of strings") + for i, item in enumerate(raw_tags): + if not isinstance(item, str): + raise MarketplaceYmlError( + f"'packages[{index}].tags[{i}]' must be a string, got {type(item).__name__}" + ) + tags = tuple(str(t) for t in raw_tags) + + raw_keywords = raw.get("keywords") + if raw_keywords is not None: + if not isinstance(raw_keywords, list): + raise MarketplaceYmlError(f"'packages[{index}].keywords' must be a list of strings") + for i, item in enumerate(raw_keywords): + if not isinstance(item, str): + raise MarketplaceYmlError( + f"'packages[{index}].keywords[{i}]' must be a string, got {type(item).__name__}" + ) + seen = set(tags) + merged = list(tags) + for kw in raw_keywords: + if kw not in seen: + seen.add(kw) + merged.append(kw) + tags = tuple(merged) + + # S4: cap array length and item length + if len(tags) > _MAX_TAGS_COUNT: + import logging as _lg + + _lg.getLogger(__name__).warning( + "packages[%d]: tags truncated from %d to %d items", + index, + len(tags), + _MAX_TAGS_COUNT, + ) + tags = tags[:_MAX_TAGS_COUNT] + return tuple(t[:_MAX_TAG_LENGTH] for t in tags) + + +# --------------------------------------------------------------------------- +# Config field assembler (shared by both loaders via yml_schema._build_config) +# --------------------------------------------------------------------------- + + +def _build_config_fields( + marketplace_dict: 
dict[str, Any], + default_output: str, + warnings_sink: list[str], +) -> tuple[ + MarketplaceOwner, + tuple[str, ...], + tuple[MarketplaceOutputSpec, ...], + str, + MarketplaceClaudeConfig, + dict[str, Any], + MarketplaceBuild, + MarketplaceVersioning, + MarketplaceCodexConfig, +]: + """Parse all sub-blocks from *marketplace_dict*. + + Returns ``(owner, outputs, output_specs, output, claude, + metadata, build, versioning, codex)``. The ``output`` string + is already path-traversal-checked. + """ + # owner + raw_owner = marketplace_dict.get("owner") + if raw_owner is None: + raise MarketplaceYmlError("'owner' is required") + owner = _parse_owner(raw_owner) + + # output selection + outputs, output_specs = _parse_outputs( + marketplace_dict.get("outputs"), warnings_sink=warnings_sink + ) + + # Claude output -- legacy shorthand ``output:`` is the default_output + legacy_output = marketplace_dict.get("output") + output = default_output if legacy_output is None else legacy_output + if not isinstance(output, str) or not output.strip(): + raise MarketplaceYmlError("'output' must be a non-empty string") + output = output.strip() + + # Path traversal guard for raw ``output`` value + try: + validate_path_segments(output, context="marketplace output") + except PathTraversalError as exc: + raise MarketplaceYmlError(str(exc)) from exc + + claude = _parse_claude(marketplace_dict.get("claude"), default_output=output) + # After parse_claude the canonical output is claude.output + output = claude.output + + # metadata (Anthropic pass-through, preserve verbatim) + metadata: dict[str, Any] = {} + raw_metadata = marketplace_dict.get("metadata") + if raw_metadata is not None: + if not isinstance(raw_metadata, dict): + raise MarketplaceYmlError("'metadata' must be a mapping") + metadata = dict(raw_metadata) + + build = _parse_build(marketplace_dict.get("build")) + versioning = _parse_versioning(marketplace_dict.get("versioning")) + codex = _parse_codex(marketplace_dict.get("codex")) + + 
return owner, outputs, output_specs, output, claude, metadata, build, versioning, codex + + +# --------------------------------------------------------------------------- +# YAML file reader (shared by both loaders) +# --------------------------------------------------------------------------- + + +def _read_yaml_mapping(path: Path) -> dict[str, Any]: + """Read *path* and return its top-level mapping or raise.""" + try: + text = path.read_text(encoding="utf-8") + except OSError as exc: + raise MarketplaceYmlError(f"Cannot read '{path}': {exc}") from exc + + try: + data = yaml.safe_load(text) + except yaml.YAMLError as exc: + detail = "" + if hasattr(exc, "problem_mark") and exc.problem_mark is not None: + mark = exc.problem_mark + detail = f" (line {mark.line + 1}, column {mark.column + 1})" + raise MarketplaceYmlError(f"YAML parse error in '{path}'{detail}: {exc}") from exc + + if data is None: + return {} + if not isinstance(data, dict): + raise MarketplaceYmlError(f"'{path}' must contain a YAML mapping at the top level") + return data diff --git a/src/apm_cli/marketplace/builder.py b/src/apm_cli/marketplace/builder.py index 67a61567c..9fe0477f8 100644 --- a/src/apm_cli/marketplace/builder.py +++ b/src/apm_cli/marketplace/builder.py @@ -13,37 +13,49 @@ Hard rule: the output ``marketplace.json`` conforms byte-for-byte to Anthropic's schema. No APM-specific keys, no extensions, no renamed fields. ``packages`` in yml becomes ``plugins`` in json. 
+ +Internal implementation is split across sibling leaf modules: + +* ``._builder_reports`` -- frozen result/report dataclasses + BuildOptions +* ``._builder_resolve`` -- ``_BuilderResolveMixin`` (resolve + metadata fetch) """ from __future__ import annotations import json import logging -import re import threading import urllib.error -import urllib.request +import urllib.request # noqa: F401 -- patchable at apm_cli.marketplace.builder.urllib.request.urlopen from concurrent.futures import ThreadPoolExecutor, as_completed -from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Any -import yaml - if TYPE_CHECKING: from ..core.auth import HostInfo from ..utils.github_host import default_host from ..utils.path_security import ensure_path_within -from ._io import atomic_write -from ._shared import iter_semver_tags -from .diagnostics import BuildDiagnostic -from .errors import ( - BuildError, - HeadNotAllowedError, - NoMatchingVersionError, - RefNotFoundError, +from ._builder_reports import ( + BuildOptions as BuildOptions, +) +from ._builder_reports import ( + BuildReport as BuildReport, +) +from ._builder_reports import ( + MarketplaceOutputReport as MarketplaceOutputReport, ) +from ._builder_reports import ( + ResolvedPackage as ResolvedPackage, +) +from ._builder_reports import ( + ResolveResult as ResolveResult, +) +from ._builder_resolve import _BuilderResolveMixin +from ._builder_resolve import _strip_ref_prefix as _strip_ref_prefix +from ._io import atomic_write +from .diagnostics import BuildDiagnostic as BuildDiagnostic +from .errors import BuildError from .output_mappers import ( MARKETPLACE_OUTPUT_MAPPERS, MapperResult, @@ -60,9 +72,7 @@ MarketplaceOutputProfile, ) from .ref_resolver import RefResolver -from .semver import SemVer, parse_semver, satisfies_range -from .tag_pattern import build_tag_regex -from .yml_schema import MarketplaceYml, PackageEntry, load_marketplace_yml +from .yml_schema import 
MarketplaceYml, load_marketplace_yml logger = logging.getLogger(__name__) @@ -75,198 +85,6 @@ "ResolvedPackage", ] -# --------------------------------------------------------------------------- -# Public dataclasses -# --------------------------------------------------------------------------- - - -@dataclass(frozen=True) -class ResolvedPackage: - """A package entry after ref resolution.""" - - name: str - source_repo: str # "owner/repo" only - subdir: str | None # APM-only (used to compose the output ``source`` object) - ref: str # resolved tag name, e.g. "v1.2.0" - sha: str # 40-char git SHA - requested_version: str | None # original APM-only range (for diagnostics) - tags: tuple[str, ...] - is_prerelease: bool # True if the resolved ref was a prerelease semver - host: str | None = None # non-default git host parsed from apm.yml source - - -@dataclass(frozen=True) -class ResolveResult: - """Result of resolving package refs in a marketplace build.""" - - entries: tuple[ResolvedPackage, ...] - errors: tuple[tuple[str, str], ...] # (package name, error message) pairs - - @property - def ok(self) -> bool: - """True when every package resolved without error.""" - return len(self.errors) == 0 - - -@dataclass(frozen=True) -class MarketplaceOutputReport: - """Summary for one generated marketplace output profile.""" - - profile: str - resolved: tuple[ResolvedPackage, ...] - errors: tuple[tuple[str, str], ...] # (package name, error message) pairs - warnings: tuple[str, ...] # non-fatal diagnostic messages - diagnostics: tuple[BuildDiagnostic, ...] = () # structured diagnostics - unchanged_count: int = 0 - added_count: int = 0 - updated_count: int = 0 - removed_count: int = 0 - output_path: Path = field(default_factory=lambda: Path(".")) - dry_run: bool = False - - -@dataclass(frozen=True) -class BuildReport: - """Summary of a marketplace build run across one or more output profiles.""" - - outputs: tuple[MarketplaceOutputReport, ...] 
- - @property - def primary_output(self) -> MarketplaceOutputReport: - """Return the first output report for legacy single-output callers.""" - if not self.outputs: - return MarketplaceOutputReport( - profile="", - resolved=(), - errors=(), - warnings=(), - ) - return self.outputs[0] - - @property - def resolved(self) -> tuple[ResolvedPackage, ...]: - return self.primary_output.resolved - - @property - def errors(self) -> tuple[tuple[str, str], ...]: - return self.primary_output.errors - - @property - def warnings(self) -> tuple[str, ...]: - return tuple(warn for output in self.outputs for warn in output.warnings) - - @property - def diagnostics(self) -> tuple[BuildDiagnostic, ...]: - return tuple(diag for output in self.outputs for diag in output.diagnostics) - - @property - def unchanged_count(self) -> int: - return self.primary_output.unchanged_count - - @property - def added_count(self) -> int: - return self.primary_output.added_count - - @property - def updated_count(self) -> int: - return self.primary_output.updated_count - - @property - def removed_count(self) -> int: - return self.primary_output.removed_count - - @property - def output_path(self) -> Path: - return self.primary_output.output_path - - @property - def dry_run(self) -> bool: - return any(output.dry_run for output in self.outputs) - - def to_json_dict(self) -> dict[str, Any]: - """Serialize build report as the §4 JSON contract. 
- - Shape: {ok, dry_run, warnings[], errors[], - marketplace: {outputs: [{format, path, added, updated, - unchanged, skipped}]}, bundle: null} - """ - all_warnings = list(self.warnings) - all_errors: list[dict[str, str]] = [] - output_entries: list[dict[str, Any]] = [] - - for out in self.outputs: - output_entries.append( - { - "format": out.profile, - "path": str(out.output_path), - "added": out.added_count, - "updated": out.updated_count, - "unchanged": out.unchanged_count, - "skipped": out.removed_count, - } - ) - for pkg_name, err_msg in out.errors: - all_errors.append({"code": "build_error", "message": f"{pkg_name}: {err_msg}"}) - - ok = len(all_errors) == 0 - return { - "ok": ok, - "dry_run": self.dry_run, - "warnings": all_warnings, - "errors": all_errors, - "marketplace": { - "outputs": output_entries, - }, - "bundle": None, - } - - @classmethod - def failure_to_json_dict( - cls, - *, - errors: list[dict[str, str]], - warnings: list[str] | None = None, - dry_run: bool = False, - ) -> dict[str, Any]: - """Produce the §4 JSON shape for a pre-build failure. - - Used when the build cannot even start (e.g., config parse error, - unknown format filter). - """ - return { - "ok": False, - "dry_run": dry_run, - "warnings": warnings or [], - "errors": errors, - "marketplace": { - "outputs": [], - }, - "bundle": None, - } - - -@dataclass -class BuildOptions: - """Configuration knobs for MarketplaceBuilder.""" - - concurrency: int = 8 - timeout_seconds: float = 10.0 - include_prerelease: bool = False - allow_head: bool = False - continue_on_error: bool = False - offline: bool = False - # Backwards-compatible spelling for callers that predate ``apm pack``. 
- output_override: Path | None = None - dry_run: bool = False - - -# --------------------------------------------------------------------------- -# Builder -# --------------------------------------------------------------------------- - -# 40-char hex SHA pattern -_SHA40_RE = re.compile(r"^[0-9a-f]{40}$") - def _is_display_version(version: str | None) -> bool: """Return True if *version* looks like a fixed display version, not a range.""" @@ -278,7 +96,7 @@ def _subtract_plugin_root(source: str, plugin_root: str) -> str: return _mapper_subtract_plugin_root(source, plugin_root) -class MarketplaceBuilder: +class MarketplaceBuilder(_BuilderResolveMixin): """Load marketplace.yml, resolve refs, compose and write marketplace.json. Parameters @@ -305,14 +123,10 @@ def __init__( self._yml: MarketplaceYml | None = None self._resolver: RefResolver | None = None self._auth_resolver = auth_resolver - # Resolved once per build, used by worker threads (read-only). self._github_token: str | None = None self._host: str = default_host() or "github.com" self._host_info: HostInfo | None = None self._auth_resolved: bool = False - # Per-host RefResolver cache, keyed by host override on PackageEntry. - # Pre-warmed on the main thread before workers spawn; lock guards - # against future refactors that allow worker-side cache misses. self._host_resolvers: dict[str, RefResolver] = {} self._host_resolvers_lock = threading.Lock() @@ -324,15 +138,7 @@ def from_config( options: BuildOptions | None = None, auth_resolver: object | None = None, ) -> MarketplaceBuilder: - """Construct a builder from an already-loaded MarketplaceConfig. - - Use this when the caller has already chosen between apm.yml and - the legacy ``marketplace.yml`` (typically via - ``migration.load_marketplace_config``). ``project_root`` is the - directory output paths are resolved against. - """ - # Use a synthetic path so legacy code paths that consult - # ``self._yml_path.parent`` still resolve to the project root. 
+ """Construct a builder from an already-loaded MarketplaceConfig.""" synthetic_path = project_root / ( config.source_path.name if config.source_path is not None else "apm.yml" ) @@ -345,10 +151,6 @@ def from_config( def _load_yml(self) -> MarketplaceYml: if self._yml is None: - # Shape-aware load: when the configured path is an apm.yml - # file, use the apm.yml loader; otherwise default to the - # legacy marketplace.yml loader. Callers that have already - # loaded a config should use ``from_config`` to bypass this. from .yml_schema import load_marketplace_from_apm_yml if self._yml_path.name == "apm.yml": @@ -371,27 +173,14 @@ def _get_resolver(self) -> RefResolver: def _effective_host(self, host: str | None) -> str | None: """Normalize ``host`` for marketplace.json emission. - Returns ``None`` when ``host`` matches the active default host so - an explicit ``github.com/owner/repo`` source in apm.yml emits the - same shorthand (``source: github``, ``repo: owner/repo``) shape as - the bare ``owner/repo`` form. Non-default hosts pass through - unchanged and downstream mappers emit ``source: url`` / - ``source: git-subdir`` with the full HTTPS URL. + Returns ``None`` when ``host`` matches the active default host. """ if host is None or host == self._host: return None return host def _get_resolver_for_host(self, host: str | None) -> RefResolver: - """Return a RefResolver bound to *host* (default when ``None``). - - Non-default hosts go through ``AuthResolver.resolve(host)`` so that - ``GITHUB_APM_PAT``, ``GITHUB_APM_PAT_{ORG}``, ``GITHUB_TOKEN`` and - ``GH_TOKEN`` are consulted before falling back to ambient git - credentials (SSH key / credential helper). Per-host resolvers are - cached for the lifetime of the build so each unique host pays the - auth-resolution cost only once. 
- """ + """Return a RefResolver bound to *host* (default when ``None``).""" if host is None or host == self._host: return self._get_resolver() with self._host_resolvers_lock: @@ -414,12 +203,7 @@ def _get_resolver_for_host(self, host: str | None) -> RefResolver: return resolver def _resolve_token_for_host(self, host: str) -> str | None: - """Resolve an auth token for a non-default *host* via ``AuthResolver``. - - Returns ``None`` -- letting ``git`` fall back to ambient credentials - -- when offline, when no token is configured for the host, or when - ``AuthResolver`` raises. Never raises. - """ + """Resolve an auth token for a non-default *host* via ``AuthResolver``.""" if self._options.offline: return None try: @@ -438,14 +222,7 @@ def _resolve_token_for_host(self, host: str) -> str | None: return None def _ensure_auth(self) -> None: - """Lazily resolve host classification and GitHub token. - - Short-circuits when already resolved (even if no token was found) - or when running in offline mode. Offline mode is still marked as - resolved so repeated calls remain idempotent. Called by - ``_get_resolver()`` so both ``resolve()`` and ``build()`` benefit - from authenticated ``git ls-remote`` when available. - """ + """Lazily resolve host classification and GitHub token.""" if self._auth_resolved: return if self._options.offline: @@ -461,7 +238,6 @@ def _output_path(self) -> Path: return self._options.output_override yml = self._load_yml() output_path = self._project_root / yml.claude.output - # Containment guard -- reject output paths that escape the project root. ensure_path_within(output_path, self._project_root) return output_path @@ -496,351 +272,10 @@ def _map_output( remote_metadata=remote_metadata, ) - # -- single-entry resolution -------------------------------------------- - - def _resolve_entry(self, entry: PackageEntry) -> ResolvedPackage: - """Resolve a single package entry to a concrete tag + SHA.""" - # Local-path packages skip git resolution entirely. 
- if entry.is_local: - return ResolvedPackage( - name=entry.name, - source_repo="", - subdir=entry.source, - ref="", - sha="", - requested_version=entry.version, - tags=tuple(entry.tags), - is_prerelease=False, - ) - yml = self._load_yml() - resolver = self._get_resolver_for_host(entry.host) - owner_repo = entry.source - - if entry.ref is not None: - return self._resolve_explicit_ref(entry, resolver, owner_repo) - # version range resolution - return self._resolve_version_range(entry, resolver, owner_repo, yml) - - def _resolve_explicit_ref( - self, - entry: PackageEntry, - resolver: RefResolver, - owner_repo: str, - ) -> ResolvedPackage: - """Resolve an entry with an explicit ``ref:`` field.""" - ref_text = entry.ref - assert ref_text is not None # noqa: S101 - - # If it looks like a 40-char SHA, accept it directly - if _SHA40_RE.match(ref_text): - sv = parse_semver(ref_text.lstrip("vV")) - return ResolvedPackage( - name=entry.name, - source_repo=owner_repo, - subdir=entry.subdir, - ref=ref_text, - sha=ref_text, - requested_version=entry.version, - tags=entry.tags, - is_prerelease=sv.is_prerelease if sv else False, - host=self._effective_host(entry.host), - ) - - refs = resolver.list_remote_refs(owner_repo) - - # Try as tag first (only check tag refs) - for remote_ref in refs: - if not remote_ref.name.startswith("refs/tags/"): - continue - tag_name = _strip_ref_prefix(remote_ref.name) - if tag_name == ref_text: - sv = parse_semver(tag_name.lstrip("vV")) - return ResolvedPackage( - name=entry.name, - source_repo=owner_repo, - subdir=entry.subdir, - ref=tag_name, - sha=remote_ref.sha, - requested_version=entry.version, - tags=entry.tags, - is_prerelease=sv.is_prerelease if sv else False, - host=self._effective_host(entry.host), - ) - - # Try as full refname - for remote_ref in refs: - if remote_ref.name == ref_text: - short = _strip_ref_prefix(remote_ref.name) - is_branch = remote_ref.name.startswith("refs/heads/") - if is_branch and not self._options.allow_head: - 
raise HeadNotAllowedError(entry.name, short) - sv = parse_semver(short.lstrip("vV")) - return ResolvedPackage( - name=entry.name, - source_repo=owner_repo, - subdir=entry.subdir, - ref=short, - sha=remote_ref.sha, - requested_version=entry.version, - tags=entry.tags, - is_prerelease=sv.is_prerelease if sv else False, - host=self._effective_host(entry.host), - ) - - # Try as branch name - for remote_ref in refs: - if remote_ref.name == f"refs/heads/{ref_text}": - if not self._options.allow_head: - raise HeadNotAllowedError(entry.name, ref_text) - return ResolvedPackage( - name=entry.name, - source_repo=owner_repo, - subdir=entry.subdir, - ref=ref_text, - sha=remote_ref.sha, - requested_version=entry.version, - tags=entry.tags, - is_prerelease=False, - host=self._effective_host(entry.host), - ) - - # HEAD special case - if ref_text.upper() == "HEAD": - if not self._options.allow_head: - raise HeadNotAllowedError(entry.name, "HEAD") - - raise RefNotFoundError(entry.name, ref_text, owner_repo) - - def _resolve_version_range( - self, - entry: PackageEntry, - resolver: RefResolver, - owner_repo: str, - yml: MarketplaceYml, - ) -> ResolvedPackage: - """Resolve an entry using its ``version:`` semver range.""" - version_range = entry.version - assert version_range is not None # noqa: S101 - - # Determine tag pattern: entry > build > default - pattern = entry.tag_pattern or yml.build.tag_pattern - - tag_rx = build_tag_regex(pattern) - refs = resolver.list_remote_refs(owner_repo) - - # Filter tags matching the pattern and extract versions - candidates: list[tuple[SemVer, str, str]] = [] # (semver, tag_name, sha) - for sv, tag_name, sha in iter_semver_tags(refs, tag_rx): - # Prerelease filter - include_pre = entry.include_prerelease or self._options.include_prerelease - if sv.is_prerelease and not include_pre: - continue - - # Range filter - if satisfies_range(sv, version_range): - candidates.append((sv, tag_name, sha)) - - if not candidates: - raise NoMatchingVersionError( - 
entry.name, - version_range, - detail=f"pattern='{pattern}', remote='{owner_repo}'", - ) - - # Pick highest - candidates.sort(key=lambda c: c[0], reverse=True) - best_sv, best_tag, best_sha = candidates[0] - - return ResolvedPackage( - name=entry.name, - source_repo=owner_repo, - subdir=entry.subdir, - ref=best_tag, - sha=best_sha, - requested_version=version_range, - tags=entry.tags, - is_prerelease=best_sv.is_prerelease, - host=self._effective_host(entry.host), - ) - - # -- concurrent resolution ---------------------------------------------- - - def resolve(self) -> ResolveResult: - """Resolve every entry concurrently. - - Returns - ------- - ResolveResult - Contains resolved entries and any errors encountered. - - Raises - ------ - BuildError - On any resolution failure (unless ``continue_on_error``). - """ - yml = self._load_yml() - entries = yml.packages - if not entries: - return ResolveResult(entries=(), errors=()) - - results: dict[int, ResolvedPackage] = {} - errors: list[tuple[str, str]] = [] - - # Eagerly resolve auth + create the shared RefResolver before - # spawning workers -- avoids a race on _ensure_auth() and - # matches the pattern used in _prefetch_metadata(). - self._get_resolver() - # Pre-warm any per-host resolvers needed by entries that override the - # default host via the ``host.tld/owner/repo`` source form. Done on - # the main thread so workers never race to create the same resolver. 
- for entry in entries: - if entry.host: - self._get_resolver_for_host(entry.host) - - with ThreadPoolExecutor(max_workers=min(self._options.concurrency, len(entries))) as pool: - future_to_index = { - pool.submit(self._resolve_entry, entry): idx for idx, entry in enumerate(entries) - } - for future in as_completed(future_to_index): - idx = future_to_index[future] - entry = entries[idx] - try: - resolved = future.result(timeout=self._options.timeout_seconds) - results[idx] = resolved - except BuildError as exc: - if self._options.continue_on_error: - errors.append((entry.name, str(exc))) - else: - raise - except Exception as exc: - logger.debug("Unexpected error resolving '%s'", entry.name, exc_info=True) - if self._options.continue_on_error: - errors.append((entry.name, str(exc))) - else: - raise BuildError( - f"Unexpected error resolving '{entry.name}': {exc}", - package=entry.name, - ) from exc - - # Return in yml order - ordered: list[ResolvedPackage] = [] - for idx in range(len(entries)): - if idx in results: - ordered.append(results[idx]) - return ResolveResult(entries=tuple(ordered), errors=tuple(errors)) - - # -- remote description fetcher ----------------------------------------- - - def _fetch_remote_metadata(self, pkg: ResolvedPackage) -> dict[str, str] | None: - """Best-effort: fetch ``description`` and ``version`` from the - package's remote ``apm.yml``. - - Returns a dict with ``description`` and/or ``version`` keys, or - ``None`` on any error. This is purely cosmetic enrichment -- - failures are silently logged at debug level and never propagate. - - When a token is available for the package's host, it is included - as an ``Authorization`` header so private repos can be accessed. - A token resolved for the builder's default host is never sent to - another host. 
- - Each package is fetched from its own host: ``github.com`` - packages use the fast ``raw.githubusercontent.com`` CDN; GHES - and GHE Cloud packages use the GitHub REST API on the package's - host. For non-GitHub-class hosts, metadata enrichment is - skipped. - """ - try: - path_prefix = f"{pkg.subdir}/" if pkg.subdir else "" - file_path = f"{path_prefix}apm.yml" - - # Resolve the effective host for this package and its - # classification. Falls back to the builder default when the - # package did not carry an explicit host override. - effective_host = pkg.host or self._host - if pkg.host is None or pkg.host == self._host: - host_info = self._host_info - token = self._github_token - else: - from ..core.auth import AuthResolver # lazy import - - try: - host_info = AuthResolver.classify_host(effective_host) - except Exception: - host_info = None - token = self._resolve_token_for_host(effective_host) - - host_kind = host_info.kind if host_info else "github" - - if host_kind not in ("github", "ghe_cloud", "ghes"): - # Non-GitHub hosts -- skip metadata enrichment - logger.debug( - "Skipping metadata fetch for %s (non-GitHub host: %s)", - pkg.name, - effective_host, - ) - return None - - if host_kind == "ghe_cloud" and not token: - logger.debug( - "Skipping metadata fetch for %s (GHE Cloud requires auth)", - pkg.name, - ) - return None - - if effective_host == "github.com": - # github.com -- use fast raw.githubusercontent.com CDN - url = f"https://raw.githubusercontent.com/{pkg.source_repo}/{pkg.sha}/{file_path}" - req = urllib.request.Request(url) # noqa: S310 - if token: - req.add_header("Authorization", f"token {token}") - else: - # GHES / GHE Cloud -- use REST API on the package's host - api_base = ( - host_info.api_base if host_info else None - ) or f"https://{effective_host}/api/v3" - url = f"{api_base}/repos/{pkg.source_repo}/contents/{file_path}?ref={pkg.sha}" - req = urllib.request.Request(url) # noqa: S310 - req.add_header("Accept", 
"application/vnd.github.raw") - if token: - req.add_header("Authorization", f"token {token}") - - with urllib.request.urlopen(req, timeout=5) as resp: # noqa: S310 - raw = resp.read().decode("utf-8") - data = yaml.safe_load(raw) - if not isinstance(data, dict): - return None - result: dict[str, str] = {} - desc = data.get("description") - if isinstance(desc, str) and desc: - result["description"] = desc - ver = data.get("version") - if ver is not None: - ver_str = str(ver).strip() - if ver_str: - result["version"] = ver_str - if result: - logger.debug( - "Fetched metadata for %s from remote apm.yml: %s", - pkg.name, - ", ".join(result.keys()), - ) - return result - except Exception: - logger.debug( - "Could not fetch remote metadata for %s", - pkg.name, - exc_info=True, - ) - return None + # -- auth + metadata prefetch ------------------------------------------- def _resolve_github_token(self) -> str | None: - """Resolve a GitHub token using ``AuthResolver``. - - Called once before concurrent fetches. Returns the token string - or ``None`` if no credentials are available. Never raises -- - auth failures are logged at debug and silently ignored. - """ + """Resolve a GitHub token using ``AuthResolver``.""" try: from ..core.auth import AuthResolver # lazy import @@ -848,8 +283,6 @@ def _resolve_github_token(self) -> str | None: if resolver is None: resolver = AuthResolver() self._auth_resolver = resolver - # Always classify the host, regardless of token availability, - # so _fetch_remote_metadata() can branch on host kind. 
if self._host_info is None: self._host_info = AuthResolver.classify_host(self._host) ctx = resolver.resolve(self._host) # type: ignore[union-attr] @@ -860,27 +293,16 @@ def _resolve_github_token(self) -> str | None: logger.debug("Could not resolve GitHub token for metadata fetch", exc_info=True) return None - def _prefetch_metadata(self, resolved: list[ResolvedPackage]) -> dict[str, dict[str, str]]: - """Concurrently fetch remote metadata for all packages. - - Returns a mapping of ``{package_name: {"description": ..., "version": ...}}`` - for successful fetches. Skipped entirely when ``--offline`` is set. - Local-path packages are skipped (they carry their own metadata). - - A GitHub token is resolved once before spawning worker threads and - stored on ``self._github_token`` for the workers to read. - """ + def _prefetch_metadata( + self, resolved: tuple[ResolvedPackage, ...] | list[ResolvedPackage] + ) -> dict[str, dict[str, str]]: + """Concurrently fetch remote metadata for all packages.""" if self._options.offline: return {} - - # Filter out local-path entries -- they don't have a remote to fetch from. remote = [pkg for pkg in resolved if pkg.source_repo] if not remote: return {} - - # Resolve token once -- threads read self._github_token (immutable). self._ensure_auth() - results: dict[str, dict[str, str]] = {} workers = min(self._options.concurrency, len(remote)) with ThreadPoolExecutor(max_workers=workers) as pool: @@ -900,21 +322,7 @@ def _prefetch_metadata(self, resolved: list[ResolvedPackage]) -> dict[str, dict[ # -- composition -------------------------------------------------------- def compose_marketplace_json(self, resolved: list[ResolvedPackage]) -> dict[str, Any]: - """Produce an Anthropic-compliant marketplace.json dict. - - All APM-only fields are stripped. Key order follows the Anthropic - schema exactly. - - Parameters - ---------- - resolved: - List of resolved packages (from ``resolve()``). 
- - Returns - ------- - dict - An ``OrderedDict``-style dict ready to be serialised as JSON. - """ + """Produce an Anthropic-compliant marketplace.json dict.""" resolved_tuple = tuple(resolved) mapper_result = self._map_output( DEFAULT_MARKETPLACE_OUTPUT, @@ -1010,36 +418,10 @@ def _compute_diff( if old_json is None: return (0, len(new_json.get("plugins", [])), 0, 0) - old_plugins: dict[str, str] = {} - for p in old_json.get("plugins", []): - name = p.get("name", "") - sha = "" - src = p.get("source", {}) - if isinstance(src, dict): - # Accept both the new ``sha`` field (Claude-spec compliant) - # and the legacy ``commit`` field for backward-compatibility - # with marketplace.json files written before this PR. - sha = src.get("sha") or src.get("commit", "") - elif isinstance(src, str): - sha = src # local-path packages: use the path string itself - old_plugins[name] = sha - - new_plugins: dict[str, str] = {} - for p in new_json.get("plugins", []): - name = p.get("name", "") - sha = "" - src = p.get("source", {}) - if isinstance(src, dict): - sha = src.get("sha") or src.get("commit", "") - elif isinstance(src, str): - sha = src - new_plugins[name] = sha - - unchanged = 0 - updated = 0 - added = 0 - removed = 0 + old_plugins = _extract_plugin_shas(old_json) + new_plugins = _extract_plugin_shas(new_json) + unchanged = updated = added = removed = 0 for name, sha in new_plugins.items(): if name not in old_plugins: added += 1 @@ -1047,11 +429,9 @@ def _compute_diff( unchanged += 1 else: updated += 1 - for name in old_plugins: if name not in new_plugins: removed += 1 - return (unchanged, added, updated, removed) # -- atomic write ------------------------------------------------------- @@ -1079,13 +459,7 @@ def _load_existing_json(self, path: Path) -> dict[str, Any] | None: # -- full pipeline ------------------------------------------------------ def build(self) -> BuildReport: - """Full pipeline: load -> resolve -> compose -> write. 
- - Returns - ------- - BuildReport - Summary including diff statistics. - """ + """Full pipeline: load -> resolve -> compose -> write.""" result = self.resolve() report = self.write_output( DEFAULT_MARKETPLACE_OUTPUT, @@ -1099,32 +473,34 @@ def build(self) -> BuildReport: ), ) - # Cleanup default + per-host resolvers so long-lived builder - # instances do not leak caches or thread locks across builds. if self._resolver is not None: self._resolver.close() with self._host_resolvers_lock: for host_resolver in self._host_resolvers.values(): try: host_resolver.close() - except Exception: # pragma: no cover - close is best-effort + except Exception: # pragma: no cover logger.debug("Failed to close per-host RefResolver", exc_info=True) self._host_resolvers.clear() - return BuildReport( - outputs=report.outputs, - ) + return BuildReport(outputs=report.outputs) # --------------------------------------------------------------------------- -# Helpers +# Module-level helpers # --------------------------------------------------------------------------- -def _strip_ref_prefix(refname: str) -> str: - """Strip ``refs/tags/`` or ``refs/heads/`` prefix.""" - if refname.startswith("refs/tags/"): - return refname[len("refs/tags/") :] - if refname.startswith("refs/heads/"): - return refname[len("refs/heads/") :] - return refname +def _extract_plugin_shas(doc: dict[str, Any]) -> dict[str, str]: + """Return {plugin_name: sha_or_path} from a marketplace.json document.""" + plugins: dict[str, str] = {} + for p in doc.get("plugins", []): + name = p.get("name", "") + sha = "" + src = p.get("source", {}) + if isinstance(src, dict): + sha = src.get("sha") or src.get("commit", "") + elif isinstance(src, str): + sha = src + plugins[name] = sha + return plugins diff --git a/src/apm_cli/marketplace/client.py b/src/apm_cli/marketplace/client.py index 63752a502..de9f1fb38 100644 --- a/src/apm_cli/marketplace/client.py +++ b/src/apm_cli/marketplace/client.py @@ -15,19 +15,40 @@ 
``~/.apm/cache/marketplace/`` with a 1-hour TTL. """ -import contextlib import json import logging -import os import re import subprocess -import time from collections.abc import Callable from pathlib import Path -from urllib.parse import quote, urlsplit +from urllib.parse import quote import requests +from ._client_cache import ( + _cache_data_path as _cache_data_path, +) +from ._client_cache import ( + _cache_key as _cache_key, +) +from ._client_cache import ( + _cache_meta_path as _cache_meta_path, +) +from ._client_cache import ( + _clear_cache as _clear_cache, +) +from ._client_cache import ( + _host_from_url as _host_from_url, +) +from ._client_cache import ( + _read_cache as _read_cache, +) +from ._client_cache import ( + _read_stale_cache as _read_stale_cache, +) +from ._client_cache import ( + _write_cache as _write_cache, +) from .errors import MarketplaceError, MarketplaceFetchError from .models import ( MarketplaceManifest, @@ -39,9 +60,6 @@ logger = logging.getLogger(__name__) -_CACHE_TTL_SECONDS = 3600 # 1 hour -_CACHE_DIR_NAME = os.path.join("cache", "marketplace") - # Candidate locations for marketplace.json in a repository (priority order) _MARKETPLACE_PATHS = [ "marketplace.json", @@ -69,108 +87,6 @@ def _validate_ref(ref: str, source_name: str) -> str: return ref -def _cache_dir() -> str: - """Return the cache directory, creating it if needed.""" - from ..config import CONFIG_DIR - - d = os.path.join(CONFIG_DIR, _CACHE_DIR_NAME) - os.makedirs(d, exist_ok=True) - return d - - -def _sanitize_cache_name(name: str) -> str: - """Sanitize marketplace name for safe use in file paths.""" - from ..utils.path_security import PathTraversalError, validate_path_segments - - safe = re.sub(r"[^a-zA-Z0-9._-]", "_", name) - # Prevent path traversal even after sanitization - safe = safe.strip(".").strip("_") or "unnamed" - # Defense-in-depth: validate with centralized path security - try: - validate_path_segments(safe, context="cache name") - except 
PathTraversalError: - safe = "unnamed" - return safe - - -def _cache_key(source: MarketplaceSource) -> str: - """Cache key that includes kind+host to avoid collisions across hosts.""" - kind = source.kind - if kind == "local": - return f"local__{_sanitize_cache_name(source.name)}" - if kind == "git": - # Generic git: include host so a.com/o/r vs b.com/o/r never collapse. - host = _host_from_url(source.url) or source.host or "unknown" - return f"git__{_sanitize_cache_name(host)}__{_sanitize_cache_name(source.name)}" - normalized_host = (source.host or "github.com").lower() - if normalized_host == "github.com": - return source.name - return f"{_sanitize_cache_name(normalized_host)}__{source.name}" - - -def _cache_data_path(name: str) -> str: - return os.path.join(_cache_dir(), f"{_sanitize_cache_name(name)}.json") - - -def _cache_meta_path(name: str) -> str: - return os.path.join(_cache_dir(), f"{_sanitize_cache_name(name)}.meta.json") - - -def _read_cache(name: str) -> dict | None: - """Read cached marketplace data if valid (not expired).""" - data_path = _cache_data_path(name) - meta_path = _cache_meta_path(name) - if not os.path.exists(data_path) or not os.path.exists(meta_path): - return None - try: - with open(meta_path, encoding="utf-8") as f: - meta = json.load(f) - fetched_at = meta.get("fetched_at", 0) - ttl = meta.get("ttl_seconds", _CACHE_TTL_SECONDS) - if time.time() - fetched_at > ttl: - return None # Expired - with open(data_path, encoding="utf-8") as f: - return json.load(f) - except (json.JSONDecodeError, OSError, KeyError) as exc: - logger.debug("Cache read failed for '%s': %s", name, exc) - return None - - -def _read_stale_cache(name: str) -> dict | None: - """Read cached data even if expired (stale-while-revalidate).""" - data_path = _cache_data_path(name) - if not os.path.exists(data_path): - return None - try: - with open(data_path, encoding="utf-8") as f: - return json.load(f) - except (json.JSONDecodeError, OSError): - return None - - -def 
_write_cache(name: str, data: dict) -> None: - """Write marketplace data and metadata to cache.""" - data_path = _cache_data_path(name) - meta_path = _cache_meta_path(name) - try: - with open(data_path, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2) - with open(meta_path, "w", encoding="utf-8") as f: - json.dump( - {"fetched_at": time.time(), "ttl_seconds": _CACHE_TTL_SECONDS}, - f, - ) - except OSError as exc: - logger.debug("Cache write failed for '%s': %s", name, exc) - - -def _clear_cache(name: str) -> None: - """Remove cached data for a marketplace.""" - for path in (_cache_data_path(name), _cache_meta_path(name)): - with contextlib.suppress(OSError): - os.remove(path) - - # --------------------------------------------------------------------------- # Network fetch -- API path (GitHub / GitLab) # --------------------------------------------------------------------------- @@ -667,22 +583,6 @@ def _fetch_file( return fetcher(source, file_path, host_info=host_info, auth_resolver=auth_resolver) -def _host_from_url(url: str) -> str: - """Extract host from a URL (handles SCP-like SSH URLs too).""" - if not url: - return "" - # SCP-like: git@host:path - if "@" in url and not url.startswith(("http", "git://", "ssh://", "file://")): - try: - return url.split("@", 1)[1].split(":", 1)[0] - except (IndexError, ValueError): - return "" - try: - return urlsplit(url).hostname or "" - except ValueError: - return "" - - def _auto_detect_path( source: MarketplaceSource, auth_resolver: object | None = None, diff --git a/src/apm_cli/marketplace/publisher.py b/src/apm_cli/marketplace/publisher.py index 4ce48cf24..eed375f53 100644 --- a/src/apm_cli/marketplace/publisher.py +++ b/src/apm_cli/marketplace/publisher.py @@ -24,7 +24,6 @@ from __future__ import annotations import hashlib -import json import logging import os import re @@ -32,11 +31,9 @@ import tempfile from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor, 
as_completed -from dataclasses import dataclass from datetime import datetime, timezone -from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING if TYPE_CHECKING: from .semver import SemVer @@ -49,7 +46,21 @@ validate_path_segments, ) from ._git_utils import redact_token as _redact_token -from ._io import atomic_write +from ._publish_state import ( + ConsumerTarget as ConsumerTarget, +) +from ._publish_state import ( + PublishOutcome as PublishOutcome, +) +from ._publish_state import ( + PublishPlan as PublishPlan, +) +from ._publish_state import ( + PublishState as PublishState, +) +from ._publish_state import ( + TargetResult as TargetResult, +) from .errors import MarketplaceError from .git_stderr import translate_git_stderr from .migration import load_marketplace_config @@ -69,11 +80,6 @@ "TargetResult", ] -# --------------------------------------------------------------------------- -# Token redaction -- delegated to _git_utils; alias kept for call-site compat. -# --------------------------------------------------------------------------- - - # --------------------------------------------------------------------------- # Branch name sanitisation # --------------------------------------------------------------------------- @@ -92,191 +98,6 @@ def _sanitise_branch_segment(text: str) -> str: return _BRANCH_UNSAFE_RE.sub("-", text) -# --------------------------------------------------------------------------- -# Data model -# --------------------------------------------------------------------------- - - -_REPO_RE = re.compile(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$") -_BRANCH_SAFE_RE = re.compile(r"^[a-zA-Z0-9._/-]+$") - - -@dataclass(frozen=True) -class ConsumerTarget: - """A consumer repository whose ``apm.yml`` should be updated.""" - - repo: str # e.g. 
"acme-org/service-a" - branch: str = "main" # base branch on the consumer to PR into - path_in_repo: str = "apm.yml" # location of the consumer's apm.yml - - def __post_init__(self) -> None: - if not _REPO_RE.match(self.repo): - raise ValueError( - f"ConsumerTarget.repo must be in 'owner/name' format " - f"using only alphanumerics, dots, hyphens, and underscores. " - f"Got: {self.repo!r}" - ) - if not _BRANCH_SAFE_RE.match(self.branch) or ".." in self.branch: - raise ValueError( - f"ConsumerTarget.branch contains disallowed characters. " - f"Only alphanumerics, dots, hyphens, underscores, and " - f"forward slashes are permitted (no '..' sequences). " - f"Got: {self.branch!r}" - ) - from ..utils.path_security import validate_path_segments - - validate_path_segments(self.path_in_repo, context="consumer-targets path_in_repo") - - -@dataclass(frozen=True) -class PublishPlan: - """Computed plan for a publish run -- frozen and deterministic.""" - - marketplace_name: str # name from the local marketplace.yml - marketplace_version: str # version from the local marketplace.yml - targets: tuple[ConsumerTarget, ...] - commit_message: str # pre-computed, contains the APM trailer - branch_name: str # pre-computed, deterministic - new_ref: str # rendered tag, e.g. "v2.0.0" - tag_pattern_used: str # tag pattern, e.g. 
"v{version}" - short_hash: str = "" # deterministic hash suffix for the branch name - allow_downgrade: bool = False - allow_ref_change: bool = False - target_package: str | None = None - - -class PublishOutcome(str, Enum): - """Outcome of processing a single consumer target.""" - - UPDATED = "updated" - NO_CHANGE = "no-change" - SKIPPED_DOWNGRADE = "skipped-downgrade" - SKIPPED_REF_CHANGE = "skipped-ref-change" - FAILED = "failed" - - -@dataclass(frozen=True) -class TargetResult: - """Result of processing a single consumer target.""" - - target: ConsumerTarget - outcome: PublishOutcome - message: str # human-readable detail - old_version: str | None = None - new_version: str | None = None - - -# --------------------------------------------------------------------------- -# Transactional state file -# --------------------------------------------------------------------------- - -_STATE_FILENAME = "publish-state.json" -_STATE_DIR = ".apm" -_MAX_HISTORY = 10 -_SCHEMA_VERSION = 1 - - -class PublishState: - """Transactional state file for publish runs. - - State is persisted at ``.apm/publish-state.json`` relative to the - marketplace repo root. All writes are atomic (write-tmp + fsync + - ``os.replace``). - """ - - def __init__(self, root: Path) -> None: - self._root = root.resolve() - self._state_dir = self._root / _STATE_DIR - self._state_path = self._state_dir / _STATE_FILENAME - self._data: dict[str, Any] = { - "schemaVersion": _SCHEMA_VERSION, - "lastRun": None, - "history": [], - } - - @classmethod - def load(cls, root: Path) -> PublishState: - """Load state from disk or return a fresh instance. - - A missing file or corrupt JSON both result in a fresh state -- - no exception is raised. 
- """ - instance = cls(root) - if instance._state_path.exists(): - try: - text = instance._state_path.read_text(encoding="utf-8") - data = json.loads(text) - if isinstance(data, dict): - instance._data = data - except (json.JSONDecodeError, OSError): - pass # start fresh on corrupt state - return instance - - def _atomic_write(self) -> None: - """Write state atomically via temp file + fsync + os.replace. - - Path validation and directory creation happen here; the actual - write is delegated to the shared ``atomic_write()`` helper from - ``_io.py``. - """ - ensure_path_within(self._state_dir, self._root) - self._state_dir.mkdir(parents=True, exist_ok=True) - - content = json.dumps(self._data, indent=2) + "\n" - atomic_write(self._state_path, content) - - def begin_run(self, plan: PublishPlan) -> None: - """Start a new publish run -- writes ``startedAt``.""" - self._data["lastRun"] = { - "startedAt": datetime.now(timezone.utc).isoformat(), - "finishedAt": None, - "marketplaceName": plan.marketplace_name, - "marketplaceVersion": plan.marketplace_version, - "branchName": plan.branch_name, - "results": [], - } - self._atomic_write() - - def record_result(self, result: TargetResult) -> None: - """Append a target result to the current run.""" - if self._data.get("lastRun") is None: - return - self._data["lastRun"]["results"].append( - { - "repo": result.target.repo, - "outcome": result.outcome.value, - "message": result.message, - "oldVersion": result.old_version, - "newVersion": result.new_version, - } - ) - self._atomic_write() - - def finalise(self, finished_at: datetime) -> None: - """Finalise the current run and rotate history.""" - if self._data.get("lastRun") is None: - return - self._data["lastRun"]["finishedAt"] = finished_at.isoformat() - - # Rotate history -- keep at most _MAX_HISTORY entries - history = self._data.get("history", []) - history.insert(0, dict(self._data["lastRun"])) - self._data["history"] = history[:_MAX_HISTORY] - self._atomic_write() - - def 
abort(self, reason: str) -> None: - """Mark the current run as aborted.""" - if self._data.get("lastRun") is None: - return - self._data["lastRun"]["finishedAt"] = f"ABORTED: {reason}" - self._atomic_write() - - @property - def data(self) -> dict[str, Any]: - """Return the raw state data (read-only snapshot for inspection).""" - return dict(self._data) - - # --------------------------------------------------------------------------- # Publisher service # --------------------------------------------------------------------------- @@ -654,6 +475,59 @@ def _check_ref_guards( # -- per-target processing ---------------------------------------------- + def _clone_and_checkout( + self, + target: ConsumerTarget, + plan: PublishPlan, + tmpdir: str, + clone_dir: Path, + ) -> TargetResult | None: + """Shallow-clone target repo and create the publish branch. + + Returns ``None`` on success, or a :class:`TargetResult` with + ``FAILED`` outcome on any subprocess error. + """ + url = f"https://github.com/{target.repo}.git" + try: + self._run_git( + [ + "git", + "clone", + "--depth=1", + "--branch", + target.branch, + url, + str(clone_dir), + ], + cwd=tmpdir, + ) + except subprocess.CalledProcessError as exc: + stderr = _redact_token(exc.stderr or "") + translated = translate_git_stderr( + stderr, + exit_code=exc.returncode, + operation="clone", + remote=target.repo, + ) + return TargetResult( + target=target, + outcome=PublishOutcome.FAILED, + message=f"Clone failed: {translated.summary}", + ) + + try: + self._run_git( + ["git", "checkout", "-B", plan.branch_name], + cwd=str(clone_dir), + ) + except subprocess.CalledProcessError as exc: + return TargetResult( + target=target, + outcome=PublishOutcome.FAILED, + message=("Branch creation failed: " + _redact_token(str(exc))), + ) + return None + def _process_single_target( self, target: ConsumerTarget, @@ -665,47 +539,10 @@ def _process_single_target( with tempfile.TemporaryDirectory(prefix="apm-publish-") as tmpdir: clone_dir = 
Path(tmpdir) / "repo" - # 1. Shallow clone - url = f"https://github.com/{target.repo}.git" - try: - self._run_git( - [ - "git", - "clone", - "--depth=1", - "--branch", - target.branch, - url, - str(clone_dir), - ], - cwd=tmpdir, - ) - except subprocess.CalledProcessError as exc: - stderr = _redact_token(exc.stderr or "") - translated = translate_git_stderr( - stderr, - exit_code=exc.returncode, - operation="clone", - remote=target.repo, - ) - return TargetResult( - target=target, - outcome=PublishOutcome.FAILED, - message=f"Clone failed: {translated.summary}", - ) - - # 2. Create publish branch - try: - self._run_git( - ["git", "checkout", "-B", plan.branch_name], - cwd=str(clone_dir), - ) - except subprocess.CalledProcessError as exc: - return TargetResult( - target=target, - outcome=PublishOutcome.FAILED, - message=("Branch creation failed: " + _redact_token(str(exc))), - ) + # 1+2. Shallow clone + create publish branch + clone_err = self._clone_and_checkout(target, plan, tmpdir, clone_dir) + if clone_err is not None: + return clone_err # 3. 
Load consumer apm.yml data, apm_yml_path, manifest_err = self._load_consumer_manifest(clone_dir, target, plan) diff --git a/src/apm_cli/marketplace/resolver.py b/src/apm_cli/marketplace/resolver.py index ba97e3407..2ca84dc3d 100644 --- a/src/apm_cli/marketplace/resolver.py +++ b/src/apm_cli/marketplace/resolver.py @@ -29,15 +29,43 @@ import re from collections.abc import Callable, Iterator from dataclasses import dataclass -from urllib.parse import quote, urlparse +from urllib.parse import quote from ..models.dependency.reference import DependencyReference -from ..utils.github_host import ( - is_azure_devops_hostname, - is_github_hostname, - is_supported_git_host, -) from ..utils.path_security import PathTraversalError, validate_path_segments +from ._resolver_match import ( + CrossRepoMisconfigRisk as CrossRepoMisconfigRisk, +) +from ._resolver_match import ( + _coerce_dict_plugin_type as _coerce_dict_plugin_type, +) +from ._resolver_match import ( + _compute_cross_repo_misconfig_risk as _compute_cross_repo_misconfig_risk, +) +from ._resolver_match import ( + _is_in_marketplace_source as _is_in_marketplace_source, +) +from ._resolver_match import ( + _marketplace_host_needs_explicit_git_path as _marketplace_host_needs_explicit_git_path, +) +from ._resolver_match import ( + _marketplace_project_slug as _marketplace_project_slug, +) +from ._resolver_match import ( + _needs_canonical_host_prefix as _needs_canonical_host_prefix, +) +from ._resolver_match import ( + _normalize_owner_repo_slug as _normalize_owner_repo_slug, +) +from ._resolver_match import ( + _normalize_repo_field_for_match as _normalize_repo_field_for_match, +) +from ._resolver_match import ( + _repo_field_matches_marketplace as _repo_field_matches_marketplace, +) +from ._resolver_match import ( + _source_needs_explicit_git_path as _source_needs_explicit_git_path, +) from .client import fetch_or_cache from .errors import PluginNotFoundError from .models import MarketplacePlugin, MarketplaceSource @@ 
-51,40 +79,6 @@ _SEMVER_RANGE_CHARS = re.compile(r"[~^<>=!]") -@dataclass(frozen=True) -class CrossRepoMisconfigRisk: - """Signal that a cross-repo dict ``type: github`` source on an enterprise - GitHub-family marketplace declares a bare ``owner/repo`` whose canonical - falls back to ``github.com`` -- the same syntactic ambiguity that powers - a dependency-confusion attack (#1326, formerly diagnosed only as #1305). - - Attached to :class:`MarketplacePluginResolution` when the marketplace is on - ``*.ghe.com`` and the plugin's dict source declares a bare ``owner/repo`` - that does not match the marketplace project. The resolver deliberately - leaves these canonicals bare (PR #1292 scoped its host backfill to - in-marketplace sources), so ``DependencyReference.parse`` defaults the host - to ``github.com``. Two intents share this syntax -- a legitimate cross-host - ``github.com`` open-source dep, or a misconfigured same-host entry that - should have been ``corp.ghe.com/owner/repo`` -- and the resolver cannot - distinguish them. - - Consumer contract (#1326): the install command consults this sentinel - BEFORE any outbound validation HTTP call and refuses the package - fail-closed when it is non-``None``. The earlier #1305 design surfaced - only an advisory hint on validation failure, which left the success - path (attacker pre-stages the bare namespace on public github.com) - silently exploitable. Cross-host explicit qualification by the - marketplace author -- ``repo: github.com/owner/repo`` -- prevents - the sentinel from attaching at the resolver layer (see - :func:`_compute_cross_repo_misconfig_risk`), which is the supported - escape hatch for declared cross-host intent. - """ - - marketplace_host: str - bare_repo_field: str - suggested_qualified_repo: str - - @dataclass class MarketplacePluginResolution: """Outcome of :func:`resolve_marketplace_plugin`. 
@@ -111,263 +105,6 @@ def __iter__(self) -> Iterator[str | MarketplacePlugin]: yield self.plugin -def _normalize_owner_repo_slug(repo: str) -> str: - """Lowercase ``owner/repo`` slug with optional ``.git`` suffix stripped.""" - r = repo.strip().rstrip("/").lower() - if r.endswith(".git"): - r = r[:-4] - return r - - -def _marketplace_project_slug(owner: str, repo: str) -> str: - return _normalize_owner_repo_slug(f"{owner}/{repo}") - - -def _normalize_repo_field_for_match(repo_field: str, marketplace_host: str) -> str: - """Normalize a repo field to a logical project path for matching. - - Accept bare ``owner/repo`` paths, host-qualified shorthand like - ``git.epam.com/owner/repo``, and URL / SSH forms. If the field explicitly names - a different host than the marketplace host, return an empty string so it does - not match by suffix alone. - """ - raw = repo_field.strip().rstrip("/") - if raw.endswith(".git"): - raw = raw[:-4] - - host_l = marketplace_host.strip().lower() - - if raw.startswith(("http://", "https://", "ssh://")): - parsed = urlparse(raw) - parsed_host = (parsed.hostname or "").strip().lower() - if parsed_host and parsed_host != host_l: - return "" - return parsed.path.lstrip("/").lower() - - if raw.startswith("git@") and ":" in raw: - host_part, path_part = raw[4:].split(":", 1) - if host_part.strip().lower() != host_l: - return "" - return path_part.lstrip("/").lower() - - parts = [p for p in raw.split("/") if p] - if len(parts) >= 3 and parts[0].strip().lower() == host_l: - parts = parts[1:] - return "/".join(parts).lower() - - -def _repo_field_matches_marketplace( - repo_field: str, owner: str, repo: str, marketplace_host: str -) -> bool: - """True if dict ``repo`` identifies the same project as the marketplace source.""" - if not repo_field or "/" not in repo_field: - return False - normalized_repo = _normalize_repo_field_for_match(repo_field, marketplace_host) - if not normalized_repo: - return False - return normalized_repo == 
_marketplace_project_slug(owner, repo) - - -def _coerce_dict_plugin_type(s: dict) -> str: - """Return normalized source ``type`` for a plugin entry dict (``type`` / ``source`` / ``kind``). - - ``type`` is case-insensitive. When it is missing, infers ``github`` or - ``git-subdir`` from ``repo`` plus path fields so in-marketplace matching and - ``path``/``subdir`` extraction match manifests that only set ``kind`` or omit - ``type`` (still require a valid ``repo`` for dict sources). - """ - for key in ("type", "source", "kind"): - v = s.get(key, "") - if isinstance(v, str) and v.strip(): - return v.strip().lower() - repo = s.get("repo", "") - if not isinstance(repo, str) or "/" not in repo.strip(): - return "" - subdir = s.get("subdir", "") - if isinstance(subdir, str) and subdir.strip(): - return "git-subdir" - path = s.get("path", "") - if isinstance(path, str) and path.strip(): - return "github" - return "github" - - -def _is_in_marketplace_source(plugin: MarketplacePlugin, source: MarketplaceSource) -> bool: - """Per spec §Interface Contract — in-marketplace detection.""" - s = plugin.source - if s is None: - return False - if isinstance(s, str): - return True - if not isinstance(s, dict): - return False - source_type = _coerce_dict_plugin_type(s) - if source_type in ("github", "git-subdir", "gitlab"): - return _repo_field_matches_marketplace( - s.get("repo", ""), source.owner, source.repo, source.host - ) - return False - - -def _marketplace_host_needs_explicit_git_path(host: str) -> bool: - """True when in-repo marketplace plugins must use ``git`` + ``path`` (clone root + subdir). - - ``github.com`` and ``*.ghe.com`` virtual shorthand is reliable. Azure DevOps uses - a different URL shape and is excluded. 
Self-managed GitLab FQDNs are often - classified as ``generic`` by :meth:`AuthResolver.classify_host` when not listed in - ``GITLAB_HOST`` / ``APM_GITLAB_HOSTS`` -- they still need explicit clone URLs so - paths like ``registry/pkg`` are not treated as extra project namespace segments. - """ - if not host or not str(host).strip(): - return False - h = str(host).strip().split("/", 1)[0] - if is_azure_devops_hostname(h): - return False - return not is_github_hostname(h) - - -def _source_needs_explicit_git_path(source: MarketplaceSource) -> bool: - """Kind-aware variant of :func:`_marketplace_host_needs_explicit_git_path`. - - For URL-first sources, the ``kind`` derivation already encodes the routing - decision: any host APM doesn't classify as github-family needs the explicit - git+path canonical (mirrors the existing GitLab self-managed pattern), and - that now includes Azure DevOps and generic git hosts since their - ``marketplace.json`` is fetched via subprocess git instead of an API. - - Local marketplaces handle relative sources via :func:`_resolve_local_relative_source` - on the fast path and never reach this helper. - """ - kind = source.kind - if kind == "github": - return False - if kind in ("gitlab", "git"): - return True - # Fall back to legacy host-based behaviour for any kind we don't recognise - return _marketplace_host_needs_explicit_git_path(source.host) - - -def _needs_canonical_host_prefix(canonical: str, host: str) -> bool: - """True when a GitHub-family enterprise host must be prefixed to ``canonical``. - - GitHub-family hosts (``github.com`` + ``*.ghe.com``) keep virtual shorthand -- - ``resolve_plugin_source`` emits a bare ``owner/repo[/path]`` canonical because - there is no nested-group ambiguity to disambiguate. ``DependencyReference.parse`` - defaults missing hosts to ``github.com``, which is correct for ``github.com`` but - silently mis-routes auth for every ``*.ghe.com`` marketplace. 
- - Returns True only for enterprise GitHub hosts (``*.ghe.com``) so the caller can - backfill the host while preserving shorthand semantics. Idempotent: when the - canonical already starts with ``host`` (case-insensitive) -- as happens when the - manifest's dict source carries a host-qualified ``repo`` -- this returns False - so the prefix is not duplicated. - - GHES (GitHub Enterprise Server, configured via ``GITHUB_HOST``) is not handled - here. Those hosts return True from ``_marketplace_host_needs_explicit_git_path`` - (neither GitHub-family nor ADO) so ``resolve_marketplace_plugin`` builds a - structured ``dep_ref`` upstream and this helper is never reached. The - ``is_github_hostname`` check below is defense-in-depth that would also reject - them if a future change ever bypassed the upstream guard. - - Also returns False when ``canonical`` is in URL form (``https://...``) or SSH - SCP shorthand (``git@host:owner/repo``). Manifests that put a full URL in the - ``repo`` field reach this point via ``_resolve_github_source`` (which only - requires a ``/``); detecting those by ``":"`` in the first slash-split segment - avoids producing malformed ``host/https://...`` canonicals. Those forms already - carry a host and ``DependencyReference.parse`` resolves them natively. - """ - h = (host or "").strip() - if not h or not is_github_hostname(h) or h.lower() == "github.com": - return False - first_segment = canonical.split("/", 1)[0] - if ":" in first_segment: - return False - return first_segment.lower() != h.lower() - - -def _compute_cross_repo_misconfig_risk( - plugin: MarketplacePlugin, - source: MarketplaceSource, - canonical: str, - dep_ref: DependencyReference | None, -) -> CrossRepoMisconfigRisk | None: - """Identify the #1305 misconfiguration: cross-repo dict ``type: github`` - source with bare ``repo`` on an enterprise GitHub-family marketplace. 
- - Returns a :class:`CrossRepoMisconfigRisk` when **all** of: - - - ``dep_ref`` is ``None`` (GitHub-family virtual-shorthand path; GitLab and - self-managed FQDNs build a structured ref upstream and sidestep the bug) - - ``plugin.source`` is a dict whose normalized type is ``github`` (other - dict types -- ``gitlab``, ``git-subdir`` -- hit the same auth-routing - bug but the "host-qualify with marketplace host" remediation only - matches operator intent for the GitHub family) - - the source is **not** an in-marketplace reference (PR #1292 already - backfills the host for those) - - ``_needs_canonical_host_prefix`` agrees the canonical is bare and the - host is GitHub-family enterprise (``*.ghe.com``; idempotent against - already host-qualified, URL, and SSH forms) - - the ``repo`` field is a non-empty ``owner/repo`` shorthand - - Otherwise returns ``None``. Pure -- no logging, no side effects. - """ - if dep_ref is not None: - return None - if not isinstance(plugin.source, dict): - return None - if _coerce_dict_plugin_type(plugin.source) != "github": - return None - if _is_in_marketplace_source(plugin, source): - return None - if not _needs_canonical_host_prefix(canonical, source.host): - return None - repo_field = plugin.source.get("repo", "") - if not isinstance(repo_field, str): - return None - bare = repo_field.strip().lstrip("/") - if "/" not in bare: - return None - # #1326: an already-host-qualified `repo:` field declares explicit intent - # (e.g. ``repo: github.com/owner/repo`` on a ``*.ghe.com`` marketplace is - # an unambiguous declared cross-host dependency). Only the truly-bare - # ``owner/repo`` form is the dependency-confusion vector this sentinel - # flags. 
``_needs_canonical_host_prefix`` above already returns False - # for SAME-host qualification (its idempotency clause) and for URL / - # SSH SCP shorthand canonicals; this is the symmetric guard for the - # remaining case -- CROSS-host shorthand qualification (``github.com/...`` - # on a ``*.ghe.com`` marketplace), which the idempotency check cannot - # detect because the canonical starts with a different host than - # ``source.host``. - # - # Defense in depth: extract the host from URL and SCP shorthand forms - # too, so the guard is robust even if a future upstream refactor lets - # those forms reach this point. A bare ``split("/", 1)[0]`` would - # otherwise classify ``https://...`` as having a ``https:`` first - # segment (not a host) and incorrectly attach the sentinel. - explicit_host = "" - bare_lower = bare.lower() - if bare_lower.startswith(("https://", "http://", "ssh://")): - explicit_host = (urlparse(bare).hostname or "").strip() - elif bare.startswith("git@") and ":" in bare: - # SCP shorthand: ``git@host:owner/repo`` - explicit_host = bare[4:].split(":", 1)[0].strip() - else: - explicit_host = bare.split("/", 1)[0] - # ``is_supported_git_host`` accepts any valid FQDN, not an allowlist. - # This is intentional: the goal is to distinguish "looks like a - # hostname" (explicit intent) from "bare owner/repo" (ambiguous). - # Restricting to known hosts would silently refuse legitimate - # self-hosted Git servers and create a false sense of security -- - # the real protection is the fail-closed refusal of the bare form. - if is_supported_git_host(explicit_host): - return None - return CrossRepoMisconfigRisk( - marketplace_host=source.host, - bare_repo_field=bare, - suggested_qualified_repo=f"{source.host}/{bare}", - ) - - def _marketplace_https_git_url(source: MarketplaceSource) -> str: """HTTPS clone URL for the registered marketplace project. 
@@ -389,6 +126,29 @@ def _marketplace_https_git_url(source: MarketplaceSource) -> str: return f"https://{source.host}/{encoded}.git" +def _extract_dict_path_ref( + src: dict, source_type: str, ref: str | None +) -> tuple[str | None, str | None]: + """Extract (path, ref) from a dict plugin source; used by _extract_in_repo_path_and_ref.""" + if source_type == "github": + path = src.get("path", "") + path = path.strip("/") if isinstance(path, str) else "" + if not path: + return None, ref + validate_path_segments(path, context="github source path") + return path, ref + + if source_type in ("git-subdir", "gitlab"): + sub = (src.get("subdir", "") or src.get("path", "")) or "" + sub = sub.strip("/") if isinstance(sub, str) else "" + if not sub: + return None, ref + validate_path_segments(sub, context="git-subdir source path") + return sub, ref + + return None, None + + def _extract_in_repo_path_and_ref( plugin: MarketplacePlugin, plugin_root: str = "" ) -> tuple[str | None, str | None]: @@ -426,24 +186,7 @@ def _extract_in_repo_path_and_ref( source_type = _coerce_dict_plugin_type(src) ref_val = src.get("ref", "") ref: str | None = ref_val.strip() if isinstance(ref_val, str) and ref_val.strip() else None - - if source_type == "github": - path = src.get("path", "") - path = path.strip("/") if isinstance(path, str) else "" - if not path: - return None, ref - validate_path_segments(path, context="github source path") - return path, ref - - if source_type in ("git-subdir", "gitlab"): - sub = (src.get("subdir", "") or src.get("path", "")) or "" - sub = sub.strip("/") if isinstance(sub, str) else "" - if not sub: - return None, ref - validate_path_segments(sub, context="git-subdir source path") - return sub, ref - - return None, None + return _extract_dict_path_ref(src, source_type, ref) def _gitlab_in_marketplace_dependency_reference( diff --git a/src/apm_cli/marketplace/semver.py b/src/apm_cli/marketplace/semver.py index eb78072e8..d32651b84 100644 --- 
a/src/apm_cli/marketplace/semver.py +++ b/src/apm_cli/marketplace/semver.py @@ -176,38 +176,46 @@ def satisfies_range(version: SemVer, range_spec: str) -> bool: ] +def _satisfies_caret(version: SemVer, spec: str) -> bool: + """Check a caret (``^``) range constraint.""" + base = parse_semver(spec) + if base is None: + return False + if base.major != 0: + # ^1.2.3 := >=1.2.3 <2.0.0 + return version >= base and version.major == base.major + if base.minor != 0: + # ^0.2.3 := >=0.2.3 <0.3.0 + return version >= base and version.major == 0 and version.minor == base.minor + # ^0.0.3 := >=0.0.3 <0.0.4 + return ( + version >= base + and version.major == 0 + and version.minor == 0 + and version.patch == base.patch + ) + + def _satisfies_single(version: SemVer, spec: str) -> bool: """Check a single constraint.""" spec = spec.strip() if not spec: return True - # Caret range: ^major.minor.patch + # Caret range: ^major.minor.patch (delegated to keep return count low) if spec.startswith("^"): - base = parse_semver(spec[1:]) - if base is None: - return False - if base.major != 0: - # ^1.2.3 := >=1.2.3 <2.0.0 - return version >= base and version.major == base.major - if base.minor != 0: - # ^0.2.3 := >=0.2.3 <0.3.0 - return version >= base and version.major == 0 and version.minor == base.minor - # ^0.0.3 := >=0.0.3 <0.0.4 - return ( - version >= base - and version.major == 0 - and version.minor == 0 - and version.patch == base.patch - ) + return _satisfies_caret(version, spec[1:]) # Tilde range: ~major.minor.patch if spec.startswith("~"): base = parse_semver(spec[1:]) - if base is None: - return False # ~1.2.3 := >=1.2.3 <1.3.0 - return version >= base and version.major == base.major and version.minor == base.minor + return ( + base is not None + and version >= base + and version.major == base.major + and version.minor == base.minor + ) # Comparison operators (table-driven dispatch) for prefix, cmp in _CMP_OPS: @@ -232,9 +240,7 @@ def _satisfies_single(version: SemVer, spec: str) -> 
bool: # Exact match (also handles explicit-equality after prefix strip) base = parse_semver(spec) - if base is None: - return False - return ( + return base is not None and ( version.major == base.major and version.minor == base.minor and version.patch == base.patch diff --git a/src/apm_cli/marketplace/yml_schema.py b/src/apm_cli/marketplace/yml_schema.py index bfcf4d44d..4650760f5 100644 --- a/src/apm_cli/marketplace/yml_schema.py +++ b/src/apm_cli/marketplace/yml_schema.py @@ -27,20 +27,120 @@ ``marketplace:`` subtree is validated by this module. * **Local-path packages.** ``source`` accepts ``./...`` paths in addition to ``owner/repo`` shape. Local packages skip ref resolution. + +Internal implementation is split across sibling leaf modules to keep +complexity manageable: + +* ``._yml_models`` -- frozen dataclasses (no parse logic) +* ``._yml_parsers`` -- constants, validators, parse helpers, YAML reader """ from __future__ import annotations -import re -from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Mapping # noqa: UP035 - -import yaml +from typing import Any from ..utils.path_security import PathTraversalError, validate_path_segments + +# --------------------------------------------------------------------------- +# Re-export dataclasses from the leaf model module so that existing callers +# such as ``from apm_cli.marketplace.yml_schema import PackageEntry`` keep +# working without any changes. 
+# --------------------------------------------------------------------------- +from ._yml_models import ( + MarketplaceBuild as MarketplaceBuild, +) +from ._yml_models import ( + MarketplaceClaudeConfig as MarketplaceClaudeConfig, +) +from ._yml_models import ( + MarketplaceCodexConfig as MarketplaceCodexConfig, +) +from ._yml_models import ( + MarketplaceConfig as MarketplaceConfig, +) +from ._yml_models import ( + MarketplaceOutputSpec as MarketplaceOutputSpec, +) +from ._yml_models import ( + MarketplaceOwner as MarketplaceOwner, +) +from ._yml_models import ( + MarketplaceVersioning as MarketplaceVersioning, +) +from ._yml_models import ( + PackageEntry as PackageEntry, +) + +# --------------------------------------------------------------------------- +# Re-export parse helpers so test files that do +# ``from apm_cli.marketplace.yml_schema import _parse_author`` keep working. +# The ``X as X`` form signals intentional re-export to ruff (suppresses F401). +# --------------------------------------------------------------------------- +from ._yml_parsers import ( + _APM_MARKETPLACE_KEYS as _APM_MARKETPLACE_KEYS, +) + +# --------------------------------------------------------------------------- +# Re-export public parser symbols so callers that import SOURCE_RE / +# LOCAL_SOURCE_RE / split_host_from_source from this module keep working. 
+# --------------------------------------------------------------------------- +from ._yml_parsers import ( + LOCAL_SOURCE_RE as LOCAL_SOURCE_RE, +) +from ._yml_parsers import ( + SOURCE_RE as SOURCE_RE, +) +from ._yml_parsers import ( + _build_config_fields as _build_config_fields, +) +from ._yml_parsers import ( + _check_unknown_keys as _check_unknown_keys, +) +from ._yml_parsers import ( + _parse_author as _parse_author, +) +from ._yml_parsers import ( + _parse_build as _parse_build, +) +from ._yml_parsers import ( + _parse_claude as _parse_claude, +) +from ._yml_parsers import ( + _parse_codex as _parse_codex, +) +from ._yml_parsers import ( + _parse_outputs as _parse_outputs, +) +from ._yml_parsers import ( + _parse_owner as _parse_owner, +) +from ._yml_parsers import ( + _parse_package_entry as _parse_package_entry, +) +from ._yml_parsers import ( + _parse_versioning as _parse_versioning, +) +from ._yml_parsers import ( + _read_yaml_mapping as _read_yaml_mapping, +) +from ._yml_parsers import ( + _require_str as _require_str, +) +from ._yml_parsers import ( + _validate_semver as _validate_semver, +) +from ._yml_parsers import ( + _validate_source as _validate_source, +) +from ._yml_parsers import ( + _validate_tag_pattern as _validate_tag_pattern, +) +from ._yml_parsers import ( + split_host_from_source as split_host_from_source, +) from .errors import MarketplaceYmlError -from .output_profiles import MARKETPLACE_OUTPUTS, known_output_names +from .output_profiles import MARKETPLACE_OUTPUTS __all__ = [ "LOCAL_SOURCE_RE", @@ -57,822 +157,13 @@ "load_marketplace_from_apm_yml", "load_marketplace_from_legacy_yml", "load_marketplace_yml", + "split_host_from_source", ] -# --------------------------------------------------------------------------- -# Semver validation (matches codebase convention -- regex, no external lib) -# --------------------------------------------------------------------------- - -_SEMVER_RE = re.compile( - r"^\d+\.\d+\.\d+" - 
r"(?:-[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*)?" - r"(?:\+[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*)?$" -) - -# Source field accepts: -# - ``owner/repo`` (remote, default host) -# - ``host.tld/owner/repo`` (remote on a non-default host, shorthand) -# - ``https://host.tld/owner/repo`` (remote on a non-default host, full URL) -# - ``https://host.tld/owner/repo.git`` (same, with optional ``.git`` suffix) -# - ``./...`` (local path within the same repo) -# -# Used by both yml_schema and yml_editor for source field validation. -# -# The host segment is restricted to RFC-1123 hostname characters -# (letters, digits, hyphens, dots) and must contain at least one dot -# (i.e. look like a FQDN, to disambiguate from ``owner/repo``). Userinfo -# (``user@host``), port (``host:port``), query strings, fragments, SSH SCP -# (``git@host:path``) and non-``https`` URL schemes are explicitly rejected -# to avoid RFC 3986 confused-deputy attacks. -_HOST_PAT = r"(?:[A-Za-z0-9](?:[A-Za-z0-9-]*[A-Za-z0-9])?\.)+[A-Za-z][A-Za-z0-9-]*" -_OWNER_REPO_PAT = r"[A-Za-z0-9._-]+/[A-Za-z0-9._-]+" - -SOURCE_RE = re.compile( - r"^(?:" - rf"https://{_HOST_PAT}/{_OWNER_REPO_PAT}(?:\.git)?" - rf"|{_HOST_PAT}/{_OWNER_REPO_PAT}" - rf"|{_OWNER_REPO_PAT}" - r"|\./.*" - r")$" -) -LOCAL_SOURCE_RE = re.compile(r"^\./") -# Matches ``host.tld/owner/repo`` (3 segments, first is FQDN-ish). -_HOST_PREFIXED_SOURCE_RE = re.compile(rf"^({_HOST_PAT})/({_OWNER_REPO_PAT})$") -# Matches ``https://host.tld/owner/repo[.git]`` and captures host + owner/repo. -_HTTPS_URL_SOURCE_RE = re.compile(rf"^https://({_HOST_PAT})/({_OWNER_REPO_PAT})(?:\.git)?$") - - -def split_host_from_source(source: str) -> tuple[str | None, str]: - """Split a host-qualified source into ``(host, owner/repo)``. - - Accepts both shorthand (``host.tld/owner/repo``) and full HTTPS URL - (``https://host.tld/owner/repo[.git]``) forms. Returns ``(None, source)`` - for the plain ``owner/repo`` shorthand or local ``./...`` paths. 
- - A trailing ``.git`` suffix on the repo segment is stripped so the - returned ``owner/repo`` is normalized regardless of input form. - """ - m = _HTTPS_URL_SOURCE_RE.match(source) - if m: - host, owner_repo = m.group(1), m.group(2) - if owner_repo.endswith(".git"): - owner_repo = owner_repo[: -len(".git")] - return host, owner_repo - m = _HOST_PREFIXED_SOURCE_RE.match(source) - if m: - return m.group(1), m.group(2) - return None, source - - -# Placeholder tokens accepted in ``tag_pattern`` / ``build.tagPattern``. -_TAG_PLACEHOLDERS = ("{version}", "{name}") - -# --------------------------------------------------------------------------- -# Permitted key sets (strict mode) -# --------------------------------------------------------------------------- - -_BUILD_KEYS = frozenset( - { - "tagPattern", - } -) - -_PACKAGE_ENTRY_KEYS = frozenset( - { - "name", - "source", - "subdir", - "version", - "ref", - "tag_pattern", - "include_prerelease", - "description", - "homepage", - "tags", - "author", - "license", - "repository", - "keywords", - "category", - } -) - -# Limits for keywords/tags array to prevent DoS via oversized manifests (S4). -_MAX_TAGS_COUNT = 50 -_MAX_TAG_LENGTH = 100 - -# Keys permitted inside an ``author`` object (rejected if anything else -# present). Mirrors the Claude Code plugin manifest schema. -_AUTHOR_OBJECT_KEYS = frozenset({"name", "email", "url"}) - - -def _parse_author(raw: Any, index: int) -> dict[str, str] | None: - """Normalize a curator-supplied ``author`` value to a Claude-Code- - compliant object ``{name, email?, url?}``. - - Accepts either a non-empty string (treated as ``name``) or a mapping - with at least ``name`` and only the permitted keys. Returns ``None`` - when ``raw`` is ``None``. Raises :class:`MarketplaceYmlError` on any - other shape. 
- """ - if raw is None: - return None - ctx = f"packages[{index}].author" - if isinstance(raw, str): - name = raw.strip() - if not name: - raise MarketplaceYmlError(f"'{ctx}' must be a non-empty string or object with 'name'") - return {"name": name} - if isinstance(raw, dict): - unknown = set(raw.keys()) - _AUTHOR_OBJECT_KEYS - if unknown: - raise MarketplaceYmlError( - f"'{ctx}' has unknown key(s): " - f"{', '.join(sorted(unknown))}; allowed: " - f"{', '.join(sorted(_AUTHOR_OBJECT_KEYS))}" - ) - name = raw.get("name") - if not isinstance(name, str) or not name.strip(): - raise MarketplaceYmlError(f"'{ctx}.name' is required and must be a non-empty string") - out: dict[str, str] = {"name": name.strip()} - for key in ("email", "url"): - val = raw.get(key) - if val is None: - continue - if not isinstance(val, str) or not val.strip(): - raise MarketplaceYmlError(f"'{ctx}.{key}' must be a non-empty string") - out[key] = val.strip() - return out - raise MarketplaceYmlError(f"'{ctx}' must be a string or object, got {type(raw).__name__}") - - -# Keys permitted inside the ``marketplace:`` block of apm.yml. This is -# distinct from the legacy top-level keys (which include ``name``, -# ``description``, ``version`` -- those are inherited from apm.yml's -# top-level scalars in the new world). 
-_APM_MARKETPLACE_KEYS = frozenset( - { - "name", # optional override of top-level apm.yml name - "description", # optional override of top-level apm.yml description - "version", # optional override of top-level apm.yml version - "owner", - "output", - "outputs", - "claude", - "metadata", - "build", - "codex", - "packages", - "versioning", - } -) - -_VERSIONING_KEYS = frozenset({"strategy"}) - -_VERSIONING_STRATEGIES = frozenset({"lockstep", "tag_pattern", "per_package"}) - -_CLAUDE_KEYS = frozenset( - { - "output", - } -) - -_CODEX_KEYS = frozenset( - { - "output", - } -) - -# --------------------------------------------------------------------------- -# Dataclasses -# --------------------------------------------------------------------------- - - -@dataclass(frozen=True) -class MarketplaceOwner: - """Owner block of ``marketplace.yml``.""" - - name: str - email: str | None = None - url: str | None = None - - -@dataclass(frozen=True) -class MarketplaceBuild: - """APM-only build configuration block.""" - - tag_pattern: str = "v{version}" - - -@dataclass(frozen=True) -class MarketplaceVersioning: - """Release-time versioning strategy for the marketplace. - - Controls how ``apm pack --check-versions`` verifies per-package - version alignment across local-path packages: - - * ``lockstep`` (default) -- every local package's top-level - ``version`` must equal the marketplace's top-level ``version``. - * ``tag_pattern`` -- each rendered tag must be unique across all - local packages; missing ``version`` still fails. - * ``per_package`` -- only requires that each local package declare - a ``version``; equality is not enforced. 
- """ - - strategy: str = "lockstep" - - -@dataclass(frozen=True) -class MarketplaceClaudeConfig: - """Claude-specific marketplace output configuration.""" - - output: str = ".claude-plugin/marketplace.json" - - -@dataclass(frozen=True) -class MarketplaceCodexConfig: - """Codex-specific marketplace output configuration.""" - - output: str = MARKETPLACE_OUTPUTS["codex"].default_output - - -@dataclass(frozen=True) -class PackageEntry: - """A single entry in the ``packages`` list. - - Attributes that are Anthropic pass-through (``description``, - ``homepage``, ``tags``) are stored alongside APM-only attributes - (``subdir``, ``version``, ``ref``, ``tag_pattern``, - ``include_prerelease``) so the builder can partition them at - compile time. - - ``is_local`` is derived by the loader from the ``source`` field -- - a leading ``./`` marks a local-path package that skips git - resolution. - """ - - name: str - source: str - # APM-only fields - subdir: str | None = None - version: str | None = None - ref: str | None = None - tag_pattern: str | None = None - include_prerelease: bool = False - # Anthropic pass-through fields - description: str | None = None - homepage: str | None = None - tags: tuple[str, ...] = () - # ``author`` is normalized to a Claude-Code-compliant object: - # ``{"name": str, "email"?: str, "url"?: str}``. Accepts either a - # bare string (treated as ``name``) or a mapping at parse time. - author: Mapping[str, str] | None = None - license: str | None = None - repository: str | None = None - # Marketplace category metadata. Emitted only by output formats that - # consume categories, currently Codex repo marketplace output. - category: str | None = None - # Derived (set by loader, not by user) - is_local: bool = False - # Optional non-default git host parsed from ``source`` of the form - # ``host.tld/owner/repo``. ``None`` means use the default host - # (``GITHUB_HOST`` env or ``github.com``). 
- host: str | None = None - - -@dataclass(frozen=True) -class MarketplaceOutputSpec: - """Resolved specification for one marketplace output format. - - Produced by the map-form ``outputs:`` parser. When ``path_explicit`` - is True, the manifest set an explicit ``path:`` value (vs. the - profile default). - """ - - name: str - """Format name (matches a key in ``MARKETPLACE_OUTPUTS``).""" - - path: str - """Resolved output path (explicit or profile default).""" - - path_explicit: bool = False - """True if the user set an explicit ``path:`` in the outputs map.""" - - -@dataclass(frozen=True) -class MarketplaceConfig: - """Parsed marketplace configuration. - - May originate from apm.yml's ``marketplace:`` block (current) or - from a standalone ``marketplace.yml`` (legacy, deprecated). - - ``metadata`` is stored as a plain ``dict`` preserving the original - key casing so the builder can forward it verbatim to - ``marketplace.json``. - - Override flags (``*_overridden``) record whether the marketplace - block explicitly set each inheritable field. The builder uses - these flags to decide whether to emit ``description``/``version`` - at the top level of ``marketplace.json`` -- per the Anthropic - azure-skills convention, inherited values are omitted from output. - """ - - name: str - description: str - version: str - owner: MarketplaceOwner - output: str = ".claude-plugin/marketplace.json" - outputs: tuple[str, ...] = ("claude",) - claude: MarketplaceClaudeConfig = field(default_factory=MarketplaceClaudeConfig) - codex: MarketplaceCodexConfig = field(default_factory=MarketplaceCodexConfig) - metadata: dict[str, Any] = field(default_factory=dict) - build: MarketplaceBuild = field(default_factory=MarketplaceBuild) - versioning: MarketplaceVersioning = field(default_factory=MarketplaceVersioning) - packages: tuple[PackageEntry, ...] = () - output_specs: tuple[MarketplaceOutputSpec, ...] = () - warnings: tuple[str, ...] 
= () - # Origin tracking + override-detection metadata - source_path: Path | None = None - is_legacy: bool = False - name_overridden: bool = False - description_overridden: bool = False - version_overridden: bool = False - - -# Backwards-compatibility alias for callers that still import -# ``MarketplaceYml``. Will be removed in a future minor release. +# Backwards-compatibility alias for callers that still import ``MarketplaceYml``. MarketplaceYml = MarketplaceConfig -# --------------------------------------------------------------------------- -# Validation helpers -# --------------------------------------------------------------------------- - - -def _require_str( - data: dict[str, Any], - key: str, - *, - context: str = "", -) -> str: - """Return a non-empty string value or raise ``MarketplaceYmlError``.""" - path = f"{context}.{key}" if context else key - value = data.get(key) - if value is None: - raise MarketplaceYmlError(f"'{path}' is required") - if not isinstance(value, str) or not value.strip(): - raise MarketplaceYmlError(f"'{path}' must be a non-empty string") - return value.strip() - - -def _validate_semver(version: str, *, context: str = "version") -> None: - """Raise if *version* is not a valid semver string.""" - if not _SEMVER_RE.match(version): - raise MarketplaceYmlError( - f"'{context}' value '{version}' is not valid semver (expected x.y.z)" - ) - - -def _validate_source(source: str, *, index: int) -> None: - """Validate ``source`` field shape and path safety. - - Accepts ``owner/repo``, ``host.tld/owner/repo``, ``https://host.tld/ - owner/repo[.git]``, or ``./``. - """ - ctx = f"packages[{index}].source" - if not SOURCE_RE.match(source): - raise MarketplaceYmlError( - f"'{ctx}' must be one of " - f"'/', '//', " - f"'https:////[.git]', or './', " - f"got '{source}'" - ) - is_local = bool(LOCAL_SOURCE_RE.match(source)) - try: - # Local paths legitimately start with ``.`` (current dir) and - # may have trailing-slash forms like ``./``. 
Allow ``.`` here. - validate_path_segments(source, context=ctx, allow_current_dir=is_local) - except PathTraversalError as exc: - raise MarketplaceYmlError(str(exc)) from exc - - -def _validate_tag_pattern(pattern: str, *, context: str) -> None: - """Ensure *pattern* contains at least one recognised placeholder.""" - if not any(ph in pattern for ph in _TAG_PLACEHOLDERS): - raise MarketplaceYmlError( - f"'{context}' must contain at least one of " - f"{', '.join(_TAG_PLACEHOLDERS)}, got '{pattern}'" - ) - - -def _check_unknown_keys( - data: dict[str, Any], - permitted: frozenset, - *, - context: str, -) -> None: - """Raise on any key not in *permitted*.""" - unknown = set(data.keys()) - permitted - if unknown: - sorted_unknown = sorted(unknown) - sorted_permitted = sorted(permitted) - raise MarketplaceYmlError( - f"Unknown key(s) in {context}: {', '.join(sorted_unknown)}. " - f"Permitted keys: {', '.join(sorted_permitted)}" - ) - - -# --------------------------------------------------------------------------- -# Internal parse helpers -# --------------------------------------------------------------------------- - - -def _parse_owner(raw: Any) -> MarketplaceOwner: - """Parse and validate the ``owner`` block.""" - if not isinstance(raw, dict): - raise MarketplaceYmlError("'owner' must be a mapping with at least a 'name' key") - name = _require_str(raw, "name", context="owner") - email = raw.get("email") - if email is not None: - email = str(email).strip() or None - url = raw.get("url") - if url is not None: - url = str(url).strip() or None - return MarketplaceOwner(name=name, email=email, url=url) - - -def _parse_build(raw: Any) -> MarketplaceBuild: - """Parse and validate the ``build`` block.""" - if raw is None: - return MarketplaceBuild() - if not isinstance(raw, dict): - raise MarketplaceYmlError("'build' must be a mapping") - _check_unknown_keys(raw, _BUILD_KEYS, context="build") - tag_pattern = raw.get("tagPattern", "v{version}") - if not isinstance(tag_pattern, 
str) or not tag_pattern.strip(): - raise MarketplaceYmlError("'build.tagPattern' must be a non-empty string") - tag_pattern = tag_pattern.strip() - _validate_tag_pattern(tag_pattern, context="build.tagPattern") - return MarketplaceBuild(tag_pattern=tag_pattern) - - -def _parse_versioning(raw: Any) -> MarketplaceVersioning: - """Parse and validate the optional ``marketplace.versioning`` block.""" - if raw is None: - return MarketplaceVersioning() - if not isinstance(raw, dict): - raise MarketplaceYmlError(f"'versioning' must be a mapping, got {type(raw).__name__}") - _check_unknown_keys(raw, _VERSIONING_KEYS, context="versioning") - strategy = raw.get("strategy", "lockstep") - if not isinstance(strategy, str) or not strategy.strip(): - raise MarketplaceYmlError("'versioning.strategy' must be a non-empty string") - strategy = strategy.strip() - if strategy not in _VERSIONING_STRATEGIES: - valid = ", ".join(sorted(_VERSIONING_STRATEGIES)) - raise MarketplaceYmlError( - f"'versioning.strategy' must be one of: {valid}; got {strategy!r}" - ) - return MarketplaceVersioning(strategy=strategy) - - -def _parse_claude(raw: Any, *, default_output: str) -> MarketplaceClaudeConfig: - """Parse and validate the optional ``marketplace.claude`` block.""" - if raw is None: - return MarketplaceClaudeConfig(output=default_output) - if not isinstance(raw, dict): - raise MarketplaceYmlError("'claude' must be a mapping") - _check_unknown_keys(raw, _CLAUDE_KEYS, context="claude") - - output = raw.get("output", default_output) - if not isinstance(output, str) or not output.strip(): - raise MarketplaceYmlError("'claude.output' must be a non-empty string") - output = output.strip() - try: - validate_path_segments(output, context="claude.output") - except PathTraversalError as exc: - raise MarketplaceYmlError(str(exc)) from exc - - return MarketplaceClaudeConfig(output=output) - - -def _parse_codex(raw: Any) -> MarketplaceCodexConfig: - """Parse and validate the optional ``marketplace.codex`` 
block.""" - if raw is None: - return MarketplaceCodexConfig() - if not isinstance(raw, dict): - raise MarketplaceYmlError("'codex' must be a mapping") - _check_unknown_keys(raw, _CODEX_KEYS, context="codex") - - output = raw.get("output", MARKETPLACE_OUTPUTS["codex"].default_output) - if not isinstance(output, str) or not output.strip(): - raise MarketplaceYmlError("'codex.output' must be a non-empty string") - output = output.strip() - try: - validate_path_segments(output, context="codex.output") - except PathTraversalError as exc: - raise MarketplaceYmlError(str(exc)) from exc - - return MarketplaceCodexConfig(output=output) - - -def _parse_outputs( - raw: Any, - warnings_sink: list[str] | None = None, -) -> tuple[tuple[str, ...], tuple[MarketplaceOutputSpec, ...]]: - """Parse the marketplace output selector. - - Accepts: - - ``None`` → default (claude only). - - A list of strings → back-compat list form (emits deprecation warning). - - A string → single-element back-compat list form. - - A dict → new map form with optional per-format ``path:``. - - Returns ``(outputs_tuple, output_specs_tuple)``. - """ - if raw is None: - default_spec = MarketplaceOutputSpec( - name="claude", - path=MARKETPLACE_OUTPUTS["claude"].default_output, - path_explicit=False, - ) - return ("claude",), (default_spec,) - - # --- Map form (new) --- - if isinstance(raw, dict): - outputs: list[str] = [] - specs: list[MarketplaceOutputSpec] = [] - seen: set[str] = set() - known = known_output_names() - - for key, value in raw.items(): - if not isinstance(key, str) or not key.strip(): - raise MarketplaceYmlError("'outputs' map keys must be non-empty strings") - name = key.strip() - if name not in known: - raise MarketplaceYmlError( - f"Unknown marketplace output '{name}'. 
" - f"Permitted outputs: {', '.join(sorted(known))}" - ) - if name in seen: - raise MarketplaceYmlError(f"Duplicate marketplace output '{name}'") - seen.add(name) - - # Value can be null/{}/mapping with optional path - path_explicit = False - path = MARKETPLACE_OUTPUTS[name].default_output - if value is not None: - if not isinstance(value, dict): - raise MarketplaceYmlError(f"'outputs.{name}' must be a mapping or null") - raw_path = value.get("path") - if raw_path is not None: - if not isinstance(raw_path, str) or not raw_path.strip(): - raise MarketplaceYmlError( - f"'outputs.{name}.path' must be a non-empty string" - ) - path = raw_path.strip() - path_explicit = True - try: - validate_path_segments(path, context=f"outputs.{name}.path") - except PathTraversalError as exc: - raise MarketplaceYmlError(str(exc)) from exc - # Check for unknown keys inside the format entry - _valid_output_entry_keys = {"path"} - unknown = set(value.keys()) - _valid_output_entry_keys - if unknown: - raise MarketplaceYmlError( - f"Unknown key(s) in 'outputs.{name}': {', '.join(sorted(unknown))}" - ) - - outputs.append(name) - specs.append(MarketplaceOutputSpec(name=name, path=path, path_explicit=path_explicit)) - - if not outputs: - raise MarketplaceYmlError("'outputs' must contain at least one marketplace output") - return tuple(outputs), tuple(specs) - - # --- List / string form (deprecated back-compat) --- - if isinstance(raw, str): - raw_items = [raw] - elif isinstance(raw, list): - raw_items = raw - else: - raise MarketplaceYmlError("'outputs' must be a string, list, or mapping") - - outputs_list: list[str] = [] - specs_list: list[MarketplaceOutputSpec] = [] - seen_set: set[str] = set() - for index, item in enumerate(raw_items): - if not isinstance(item, str) or not item.strip(): - raise MarketplaceYmlError(f"'outputs[{index}]' must be a non-empty string") - output = item.strip() - known_outputs = known_output_names() - if output not in known_outputs: - raise MarketplaceYmlError( - 
f"Unknown marketplace output '{output}'. " - f"Permitted outputs: {', '.join(sorted(known_outputs))}" - ) - if output in seen_set: - raise MarketplaceYmlError(f"Duplicate marketplace output '{output}'") - seen_set.add(output) - outputs_list.append(output) - specs_list.append( - MarketplaceOutputSpec( - name=output, - path=MARKETPLACE_OUTPUTS[output].default_output, - path_explicit=False, - ) - ) - - if not outputs_list: - raise MarketplaceYmlError("'outputs' must contain at least one marketplace output") - - # Emit deprecation warning for list/string form - names_str = ", ".join(outputs_list) - map_lines = "\n".join(f" {n}: {{}}" for n in outputs_list) - deprecation_msg = ( - f"outputs: [{names_str}] is deprecated; use the map form:\n\n" - f" outputs:\n{map_lines}\n\n" - f" The list form will be removed in v0.15." - ) - if warnings_sink is not None: - warnings_sink.append(deprecation_msg) - - return tuple(outputs_list), tuple(specs_list) - - -def _parse_package_entry(raw: Any, index: int) -> PackageEntry: - """Parse and validate a single ``packages`` entry.""" - if not isinstance(raw, dict): - raise MarketplaceYmlError(f"packages[{index}] must be a mapping") - - # -- strict key check -- - _check_unknown_keys(raw, _PACKAGE_ENTRY_KEYS, context=f"packages[{index}]") - - name = _require_str(raw, "name", context=f"packages[{index}]") - source = _require_str(raw, "source", context=f"packages[{index}]") - _validate_source(source, index=index) - is_local = bool(LOCAL_SOURCE_RE.match(source)) - # Detect host-prefixed source (e.g. ``host.tld/owner/repo``) and split - # the host off so downstream consumers continue to see ``owner/repo``. 
- host: str | None = None - if not is_local: - host, source = split_host_from_source(source) - - # APM-only: subdir (irrelevant for local packages but harmless) - subdir: str | None = raw.get("subdir") - if subdir is not None: - if not isinstance(subdir, str) or not subdir.strip(): - raise MarketplaceYmlError(f"'packages[{index}].subdir' must be a non-empty string") - subdir = subdir.strip() - try: - validate_path_segments(subdir, context=f"packages[{index}].subdir") - except PathTraversalError as exc: - raise MarketplaceYmlError(str(exc)) from exc - - # APM-only: version (semver range -- stored as string, not parsed here) - version: str | None = raw.get("version") - if version is not None: - version = str(version).strip() - if not version: - raise MarketplaceYmlError(f"'packages[{index}].version' must be a non-empty string") - - # APM-only: ref - ref: str | None = raw.get("ref") - if ref is not None: - ref = str(ref).strip() - if not ref: - raise MarketplaceYmlError(f"'packages[{index}].ref' must be a non-empty string") - - # At least one of version or ref must be present for REMOTE packages. - # Local-path packages skip git resolution so the requirement does not - # apply to them. 
- if not is_local and version is None and ref is None: - raise MarketplaceYmlError( - f"packages[{index}] ('{name}'): remote packages require at " - f"least one of 'version' or 'ref'" - ) - - # APM-only: tag_pattern - tag_pattern: str | None = raw.get("tag_pattern") - if tag_pattern is not None: - if not isinstance(tag_pattern, str) or not tag_pattern.strip(): - raise MarketplaceYmlError(f"'packages[{index}].tag_pattern' must be a non-empty string") - tag_pattern = tag_pattern.strip() - _validate_tag_pattern(tag_pattern, context=f"packages[{index}].tag_pattern") - - # APM-only: include_prerelease - include_prerelease = raw.get("include_prerelease", False) - if not isinstance(include_prerelease, bool): - raise MarketplaceYmlError(f"'packages[{index}].include_prerelease' must be a boolean") - - # Anthropic pass-through: description - description: str | None = raw.get("description") - if description is not None: - if not isinstance(description, str) or not description.strip(): - raise MarketplaceYmlError(f"'packages[{index}].description' must be a non-empty string") - description = description.strip() - - # Anthropic pass-through: homepage - homepage: str | None = raw.get("homepage") - if homepage is not None: - if not isinstance(homepage, str) or not homepage.strip(): - raise MarketplaceYmlError(f"'packages[{index}].homepage' must be a non-empty string") - homepage = homepage.strip() - - # Anthropic pass-through: tags - raw_tags = raw.get("tags") - tags: tuple[str, ...] 
= () - if raw_tags is not None: - if not isinstance(raw_tags, list): - raise MarketplaceYmlError(f"'packages[{index}].tags' must be a list of strings") - for i, item in enumerate(raw_tags): - if not isinstance(item, str): - raise MarketplaceYmlError( - f"'packages[{index}].tags[{i}]' must be a string, got {type(item).__name__}" - ) - tags = tuple(str(t) for t in raw_tags) - - # Anthropic pass-through: keywords (alias for tags -- merged, deduplicated) - raw_keywords = raw.get("keywords") - if raw_keywords is not None: - if not isinstance(raw_keywords, list): - raise MarketplaceYmlError(f"'packages[{index}].keywords' must be a list of strings") - for i, item in enumerate(raw_keywords): - if not isinstance(item, str): - raise MarketplaceYmlError( - f"'packages[{index}].keywords[{i}]' must be a string, got {type(item).__name__}" - ) - # Merge: tags first, then keywords entries (deduplicated) - seen = set(tags) - merged = list(tags) - for kw in raw_keywords: - if kw not in seen: - seen.add(kw) - merged.append(kw) - tags = tuple(merged) - - # S4: cap tags array length and item length - if len(tags) > _MAX_TAGS_COUNT: - import logging as _logging - - _logging.getLogger(__name__).warning( - "packages[%d] ('%s'): tags truncated from %d to %d items", - index, - name, - len(tags), - _MAX_TAGS_COUNT, - ) - tags = tags[:_MAX_TAGS_COUNT] - tags = tuple(t[:_MAX_TAG_LENGTH] for t in tags) - - # Anthropic pass-through: author -- accept string OR object input, - # normalize to ``{name, email?, url?}`` per the Claude Code plugin - # manifest schema (json.schemastore.org/claude-code-plugin-manifest.json). 
- author = _parse_author(raw.get("author"), index) - - # Anthropic pass-through: license (S3 -- must be str) - license_val: str | None = raw.get("license") - if license_val is not None: - if not isinstance(license_val, str) or not license_val.strip(): - raise MarketplaceYmlError(f"'packages[{index}].license' must be a non-empty string") - license_val = license_val.strip() - - # Anthropic pass-through: repository (S3 -- must be str) - repository: str | None = raw.get("repository") - if repository is not None: - if not isinstance(repository, str) or not repository.strip(): - raise MarketplaceYmlError(f"'packages[{index}].repository' must be a non-empty string") - repository = repository.strip() - - # Optional marketplace category. Claude output strips this; Codex output - # requires and emits it. - category: str | None = None - raw_category = raw.get("category") - if raw_category is not None: - if not isinstance(raw_category, str) or not raw_category.strip(): - raise MarketplaceYmlError(f"'packages[{index}].category' must be a non-empty string") - category = raw_category.strip() - - return PackageEntry( - name=name, - source=source, - subdir=subdir, - version=version, - ref=ref, - tag_pattern=tag_pattern, - include_prerelease=include_prerelease, - description=description, - homepage=homepage, - tags=tags, - author=author, - license=license_val, - repository=repository, - category=category, - is_local=is_local, - host=host, - ) - - # --------------------------------------------------------------------------- # Public loader # --------------------------------------------------------------------------- @@ -913,10 +204,8 @@ def load_marketplace_from_legacy_yml(path: Path) -> MarketplaceConfig: """ data = _read_yaml_mapping(path) - # -- strict top-level key check -- _check_unknown_keys(data, _APM_MARKETPLACE_KEYS, context="top level") - # -- required scalars -- name = _require_str(data, "name") description = _require_str(data, "description") version_str = 
_require_str(data, "version") @@ -971,10 +260,8 @@ def load_marketplace_from_apm_yml(apm_yml_path: Path) -> MarketplaceConfig: if not isinstance(raw_block, dict): raise MarketplaceYmlError("'marketplace' in apm.yml must be a mapping") - # -- strict marketplace-block key check -- _check_unknown_keys(raw_block, _APM_MARKETPLACE_KEYS, context="marketplace") - # -- inheritance with optional overrides -- top_name = data.get("name") top_desc = data.get("description") top_ver = data.get("version") @@ -994,19 +281,13 @@ def load_marketplace_from_apm_yml(apm_yml_path: Path) -> MarketplaceConfig: if desc_overridden: description = _require_str(raw_block, "description", context="marketplace") - else: # noqa: PLR5501 - if not isinstance(top_desc, str) or not top_desc.strip(): - description = "" - else: - description = top_desc.strip() + else: + description = top_desc.strip() if isinstance(top_desc, str) and top_desc.strip() else "" if ver_overridden: version_str = _require_str(raw_block, "version", context="marketplace") - else: # noqa: PLR5501 - if top_ver is None: # noqa: SIM108 - version_str = "" - else: - version_str = str(top_ver).strip() + else: + version_str = str(top_ver).strip() if top_ver is not None else "" if version_str: _validate_semver(version_str, context="version") @@ -1025,33 +306,10 @@ def load_marketplace_from_apm_yml(apm_yml_path: Path) -> MarketplaceConfig: # --------------------------------------------------------------------------- -# Shared internal helpers +# Shared internal config assembler # --------------------------------------------------------------------------- -def _read_yaml_mapping(path: Path) -> dict[str, Any]: - """Read *path* and return its top-level mapping or raise.""" - try: - text = path.read_text(encoding="utf-8") - except OSError as exc: - raise MarketplaceYmlError(f"Cannot read '{path}': {exc}") from exc - - try: - data = yaml.safe_load(text) - except yaml.YAMLError as exc: - detail = "" - if hasattr(exc, "problem_mark") and 
exc.problem_mark is not None: - mark = exc.problem_mark - detail = f" (line {mark.line + 1}, column {mark.column + 1})" - raise MarketplaceYmlError(f"YAML parse error in '{path}'{detail}: {exc}") from exc - - if data is None: - return {} - if not isinstance(data, dict): - raise MarketplaceYmlError(f"'{path}' must contain a YAML mapping at the top level") - return data - - def _build_config( *, marketplace_dict: dict[str, Any], @@ -1065,47 +323,26 @@ def _build_config( version_overridden: bool, default_output: str = ".claude-plugin/marketplace.json", ) -> MarketplaceConfig: - """Shared parser for the marketplace fields once name/desc/version - have been resolved (either inherited or read directly). + """Assemble a MarketplaceConfig from an already-parsed dict. + + Delegates field-level parsing to ``_yml_parsers._build_config_fields`` + which owns the sub-block parsers. This function owns only the + top-level wiring: path-traversal guard on ``output``, sibling-vs-map + conflict detection, and the duplicate-package-name check. """ warnings_sink: list[str] = [] - # -- owner -- - raw_owner = marketplace_dict.get("owner") - if raw_owner is None: - raise MarketplaceYmlError("'owner' is required") - owner = _parse_owner(raw_owner) - - # -- output selection -- - outputs, output_specs = _parse_outputs( - marketplace_dict.get("outputs"), warnings_sink=warnings_sink - ) - - # -- Claude output (default differs between legacy and new layouts) -- - # ``output`` remains as a backwards-compatible shorthand for - # ``claude.output``. The explicit block wins when both are present. - legacy_output = marketplace_dict.get("output") - output = default_output if legacy_output is None else legacy_output - if not isinstance(output, str) or not output.strip(): - raise MarketplaceYmlError("'output' must be a non-empty string") - output = output.strip() - - # Path-traversal guard -- reject output paths containing ".." segments. 
- try: - validate_path_segments(output, context="marketplace output") - except PathTraversalError as exc: - raise MarketplaceYmlError(str(exc)) from exc - - claude = _parse_claude(marketplace_dict.get("claude"), default_output=output) - output = claude.output - - # -- metadata (Anthropic pass-through, preserve verbatim) -- - metadata: dict[str, Any] = {} - raw_metadata = marketplace_dict.get("metadata") - if raw_metadata is not None: - if not isinstance(raw_metadata, dict): - raise MarketplaceYmlError("'metadata' must be a mapping") - metadata = dict(raw_metadata) + ( + owner, + outputs, + output_specs, + output, + claude, + metadata, + build, + versioning, + codex, + ) = _build_config_fields(marketplace_dict, default_output, warnings_sink) # S1: validate pluginRoot with path-safety checks if present. plugin_root = metadata.get("pluginRoot") @@ -1119,57 +356,20 @@ def _build_config( except PathTraversalError as exc: raise MarketplaceYmlError(str(exc)) from exc - # -- build -- - build = _parse_build(marketplace_dict.get("build")) - - # -- versioning (release-gate strategy) -- - versioning = _parse_versioning(marketplace_dict.get("versioning")) - - # -- codex output -- - codex = _parse_codex(marketplace_dict.get("codex")) - - # -- Sibling-vs-map conflict detection (A1: sibling wins) -- - # Only fire when the user EXPLICITLY set a sibling block AND the map - # also has an explicit path. Default/absent sibling is not a conflict. + # Sibling-vs-map conflict detection (A1: sibling wins). 
has_explicit_claude = marketplace_dict.get("claude") is not None has_explicit_codex = marketplace_dict.get("codex") is not None + output_specs = _resolve_output_spec_conflicts( + output_specs, + claude, + codex, + has_explicit_claude, + has_explicit_codex, + warnings_sink, + ) - final_specs_list = list(output_specs) - for i, spec in enumerate(final_specs_list): - if spec.path_explicit: - sibling_path: str | None = None - if spec.name == "claude" and has_explicit_claude and claude.output != spec.path: - sibling_path = claude.output - elif spec.name == "codex" and has_explicit_codex and codex.output != spec.path: - sibling_path = codex.output - if sibling_path is not None: - warnings_sink.append( - f"marketplace.outputs.{spec.name}.path ('{spec.path}') " - f"conflicts with marketplace.{spec.name}.output " - f"('{sibling_path}').\n" - f" Using marketplace.{spec.name}.output for backwards " - f"compatibility.\n\n" - f" To resolve: pick one source and remove the other.\n" - f" Keep map form (recommended):\n" - f" outputs:\n" - f" {spec.name}:\n" - f" path: {sibling_path}\n" - f" # remove the marketplace.{spec.name}: block\n\n" - f" The marketplace.{spec.name} sibling block becomes a " - f"schema error in v0.15." 
- ) - # Sibling wins: override the spec's path - final_specs_list[i] = MarketplaceOutputSpec( - name=spec.name, - path=sibling_path, - path_explicit=True, - ) - output_specs = tuple(final_specs_list) - - # -- packages -- - raw_packages = marketplace_dict.get("packages") - if raw_packages is None: - raw_packages = [] + # Packages + raw_packages = marketplace_dict.get("packages") or [] if not isinstance(raw_packages, list): raise MarketplaceYmlError("'packages' must be a list") @@ -1189,12 +389,11 @@ def _build_config( for output_name in outputs: profile = MARKETPLACE_OUTPUTS[output_name] for field_name in profile.required_package_fields: - missing = [entry.name for entry in entries if not getattr(entry, field_name)] + missing = [e.name for e in entries if not getattr(e, field_name)] if missing: - names = ", ".join(missing) raise MarketplaceYmlError( f"packages must define '{field_name}' when marketplace.outputs includes " - f"'{output_name}' (missing: {names})" + f"'{output_name}' (missing: {', '.join(missing)})" ) return MarketplaceConfig( @@ -1218,3 +417,46 @@ def _build_config( description_overridden=description_overridden, version_overridden=version_overridden, ) + + +def _resolve_output_spec_conflicts( + output_specs: tuple[MarketplaceOutputSpec, ...], + claude: MarketplaceClaudeConfig, + codex: MarketplaceCodexConfig, + has_explicit_claude: bool, + has_explicit_codex: bool, + warnings_sink: list[str], +) -> tuple[MarketplaceOutputSpec, ...]: + """Apply sibling-wins rule when outputs map and sibling block conflict.""" + final_specs_list = list(output_specs) + for i, spec in enumerate(final_specs_list): + if not spec.path_explicit: + continue + sibling_path: str | None = None + if spec.name == "claude" and has_explicit_claude and claude.output != spec.path: + sibling_path = claude.output + elif spec.name == "codex" and has_explicit_codex and codex.output != spec.path: + sibling_path = codex.output + if sibling_path is None: + continue + warnings_sink.append( + 
f"marketplace.outputs.{spec.name}.path ('{spec.path}') " + f"conflicts with marketplace.{spec.name}.output " + f"('{sibling_path}').\n" + f" Using marketplace.{spec.name}.output for backwards " + f"compatibility.\n\n" + f" To resolve: pick one source and remove the other.\n" + f" Keep map form (recommended):\n" + f" outputs:\n" + f" {spec.name}:\n" + f" path: {sibling_path}\n" + f" # remove the marketplace.{spec.name}: block\n\n" + f" The marketplace.{spec.name} sibling block becomes a " + f"schema error in v0.15." + ) + final_specs_list[i] = MarketplaceOutputSpec( + name=spec.name, + path=sibling_path, + path_explicit=True, + ) + return tuple(final_specs_list) diff --git a/src/apm_cli/policy/_constraint_pinning.py b/src/apm_cli/policy/_constraint_pinning.py index 432c0d959..3f25683ac 100644 --- a/src/apm_cli/policy/_constraint_pinning.py +++ b/src/apm_cli/policy/_constraint_pinning.py @@ -114,6 +114,24 @@ def _classify_range(spec: str) -> UnboundedReason | None: return None +def _classify_registry_ref( + ref: str | None, + is_semver_range, +) -> UnboundedReason | None: + """Classify an unbounded reason for a registry dependency's constraint. + + Extracted from :func:`classify_unbounded_reason` to reduce that + function's return count below the PLR0911 threshold. + """ + if ref is None or not ref.strip(): + return UnboundedReason.NO_REF + if is_semver_range(ref): + return _classify_range(ref) + # Registry resolver rejects non-semver refs at parse time, but + # defence-in-depth: treat anything else as a bare branch. + return UnboundedReason.BARE_BRANCH + + def classify_unbounded_reason(dep: DependencyReference) -> UnboundedReason | None: """Return ``None`` if *dep*'s constraint is pinned, otherwise the reason. @@ -139,14 +157,7 @@ def classify_unbounded_reason(dep: DependencyReference) -> UnboundedReason | Non # 2. Registry deps: the ref IS the semver range (or a single version). 
if source == "registry": - if ref is None or not ref.strip(): - # A registry dep without a constraint is itself unbounded. - return UnboundedReason.NO_REF - if is_semver_range(ref): - return _classify_range(ref) - # Registry resolver rejects non-semver refs at parse time, but - # defence-in-depth: treat anything else as a bare branch. - return UnboundedReason.BARE_BRANCH + return _classify_registry_ref(ref, is_semver_range) # 3. Empty / missing ref. if ref is None or not ref.strip(): diff --git a/src/apm_cli/policy/_discovery_cache.py b/src/apm_cli/policy/_discovery_cache.py new file mode 100644 index 000000000..6a728efc0 --- /dev/null +++ b/src/apm_cli/policy/_discovery_cache.py @@ -0,0 +1,553 @@ +"""Cache I/O and hash-pin verification helpers for policy discovery. + +Leaf module -- does NOT import ``discovery.py`` at module scope. +``PolicyFetchResult`` (defined in ``discovery.py``) is imported +function-locally inside the three helpers that return it so the +import graph stays acyclic. + +Symbols that tests patch via ``apm_cli.policy.discovery.`` +remain patchable because ``discovery.py`` re-exports every public +name from this module with the ``NAME as NAME`` redundant-alias form. 
+""" + +from __future__ import annotations + +import hashlib +import json +import logging +import os +import threading +import time +from dataclasses import dataclass, field +from pathlib import Path + +import yaml + +from ..utils.path_security import PathTraversalError, ensure_path_within +from .parser import load_policy +from .project_config import ( + _DEFAULT_HASH_ALGORITHM, + _HASH_HEX_LEN, + _HEX_RE, + ALLOWED_HASH_ALGORITHMS, + ProjectPolicyConfigError, + compute_policy_hash, +) +from .schema import ApmPolicy + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Cache constants +# --------------------------------------------------------------------------- + +POLICY_CACHE_DIR = ".policy-cache" +DEFAULT_CACHE_TTL = 3600 # 1 hour +MAX_STALE_TTL = 7 * 24 * 3600 # 7 days -- stale cache usable on refresh failure +CACHE_SCHEMA_VERSION = "3" # Bump when cache format changes to auto-invalidate + + +# --------------------------------------------------------------------------- +# Hash-pin helpers +# --------------------------------------------------------------------------- + + +def _split_hash_pin(expected_hash: str) -> tuple[str, str]: + """Split an ``":"`` pin into (algorithm, lowercase_hex). + + Bare hex (no prefix) is interpreted as sha256 for backwards + compatibility -- callers that care about the algorithm should pass a + fully-qualified pin. Raises :class:`ProjectPolicyConfigError` on a + structurally invalid pin (unsupported algorithm, wrong length, non + hex). The discovery helpers translate that into a fail-closed + ``hash_mismatch`` outcome rather than crashing. 
+ """ + raw = expected_hash.strip() + if ":" in raw: + algo, _, hex_part = raw.partition(":") + algo = algo.strip().lower() + else: + algo = _DEFAULT_HASH_ALGORITHM + hex_part = raw + hex_part = hex_part.strip().lower() + if algo not in ALLOWED_HASH_ALGORITHMS: + raise ProjectPolicyConfigError(f"Unsupported policy.hash algorithm '{algo}'") + expected_len = _HASH_HEX_LEN[algo] + if len(hex_part) != expected_len or not _HEX_RE.match(hex_part): + raise ProjectPolicyConfigError(f"policy.hash is not a valid {algo} digest") + return algo, hex_part + + +def _compute_hash_normalized(content: str, expected_hash: str | None) -> str: + """Compute the digest of *content* under the algorithm declared by + *expected_hash*, returning the canonical ``":"`` form. + + When *expected_hash* is ``None`` the default algorithm (sha256) is + used so the cache always carries a digest for later pin verification. + """ + algo = _DEFAULT_HASH_ALGORITHM + if expected_hash: + try: + algo, _ = _split_hash_pin(expected_hash) + except ProjectPolicyConfigError: + algo = _DEFAULT_HASH_ALGORITHM + digest = compute_policy_hash(content, algo) + return f"{algo}:{digest}" + + +def _verify_hash_pin( + content: object, + expected_hash: str | None, + source_label: str, +) -> object: # PolicyFetchResult | None + """Verify fetched policy bytes against the project's pin (#827). + + Returns ``None`` when there is no pin, or the digest matches. On + mismatch -- or on a structurally invalid pin, which is treated as a + mismatch to stay fail-closed -- returns a :class:`PolicyFetchResult` + with ``outcome="hash_mismatch"`` that callers must propagate. + """ + # Deferred import: PolicyFetchResult lives in discovery.py; importing it + # here at module scope would create a cycle. 
+ from .discovery import PolicyFetchResult + + if expected_hash is None: + return None + + raw_bytes: bytes + if isinstance(content, bytes): + raw_bytes = content + elif isinstance(content, str): + raw_bytes = content.encode("utf-8") + else: + return PolicyFetchResult( + outcome="hash_mismatch", + source=source_label, + error=( + f"Policy hash mismatch from {source_label}: " + "no content available to verify against pin" + ), + expected_hash=expected_hash, + ) + + try: + algo, expected_hex = _split_hash_pin(expected_hash) + except ProjectPolicyConfigError as exc: + return PolicyFetchResult( + outcome="hash_mismatch", + source=source_label, + error=(f"Policy hash mismatch from {source_label}: invalid pin ({exc})"), + expected_hash=expected_hash, + ) + + digest = hashlib.new(algo) + digest.update(raw_bytes) + actual_hex = digest.hexdigest().lower() + if actual_hex == expected_hex: + return None + + expected_norm = f"{algo}:{expected_hex}" + actual_norm = f"{algo}:{actual_hex}" + return PolicyFetchResult( + outcome="hash_mismatch", + source=source_label, + error=( + f"Policy hash mismatch from {source_label}: expected {expected_norm}, got {actual_norm}" + ), + expected_hash=expected_norm, + raw_bytes_hash=actual_norm, + ) + + +# --------------------------------------------------------------------------- +# Cache entry dataclass +# --------------------------------------------------------------------------- + + +@dataclass +class _CacheEntry: + """Internal representation of a cached policy read.""" + + policy: ApmPolicy + source: str + age_seconds: int + stale: bool # True if past TTL (but within MAX_STALE_TTL) + chain_refs: list[str] = field(default_factory=list) + fingerprint: str = "" + raw_bytes_hash: str = "" # ":" of leaf bytes off wire (#827) + + +# --------------------------------------------------------------------------- +# Cache helpers +# --------------------------------------------------------------------------- + + +def _get_cache_dir(project_root: Path) 
-> Path: + """Get the policy cache directory. + + Path-security guard (#832): the resulting path is asserted to live + within ``project_root``. This catches the edge case where + ``apm_modules`` itself is a symlink that points outside the + project root -- a configuration that, while unusual, would let + cache reads/writes escape the project tree. + """ + # Resolve early so candidate inherits long-name form on Windows; + # without this, resolve() on a not-yet-existing candidate keeps + # 8.3 short names while the base resolves to long names (#886). + project_root = project_root.resolve() + base = project_root / "apm_modules" + candidate = base / POLICY_CACHE_DIR + # Resolve both ends and assert containment under ``project_root``, + # not under ``base`` -- otherwise a symlinked apm_modules pointing + # outside the project would resolve through the symlink on both + # sides and the check would silently pass. + try: + ensure_path_within(candidate, project_root) + except PathTraversalError: + raise PathTraversalError( # noqa: B904 + f"Policy cache path '{candidate}' resolves outside " + f"project root '{project_root}' -- refusing to read or " + "write the cache here." + ) + return candidate + + +def _cache_key(repo_ref: str) -> str: + """Generate a deterministic cache filename from repo ref.""" + return hashlib.sha256(repo_ref.encode()).hexdigest()[:16] + + +def _policy_to_dict(policy: ApmPolicy) -> dict: + """Serialize an ApmPolicy to a dict matching the YAML schema.""" + + def _opt_list(val: tuple[str, ...] 
| None) -> list | None: + return None if val is None else list(val) + + return { + "name": policy.name, + "version": policy.version, + "enforcement": policy.enforcement, + "fetch_failure": policy.fetch_failure, + "cache": {"ttl": policy.cache.ttl}, + "dependencies": { + "allow": _opt_list(policy.dependencies.allow), + "deny": _opt_list(policy.dependencies.deny), + "require": _opt_list(policy.dependencies.require), + "require_resolution": policy.dependencies.require_resolution, + "max_depth": policy.dependencies.max_depth, + }, + "mcp": { + "allow": _opt_list(policy.mcp.allow), + "deny": list(policy.mcp.deny), + "transport": { + "allow": _opt_list(policy.mcp.transport.allow), + }, + "self_defined": policy.mcp.self_defined, + "trust_transitive": policy.mcp.trust_transitive, + }, + "compilation": { + "target": { + "allow": _opt_list(policy.compilation.target.allow), + "enforce": policy.compilation.target.enforce, + }, + "strategy": { + "enforce": policy.compilation.strategy.enforce, + }, + "source_attribution": policy.compilation.source_attribution, + }, + "manifest": { + "required_fields": list(policy.manifest.required_fields), + "scripts": policy.manifest.scripts, + "content_types": policy.manifest.content_types, + }, + "unmanaged_files": { + "action": policy.unmanaged_files.action, + "directories": list(policy.unmanaged_files.directories or ()), + }, + } + + +def _serialize_policy(policy: ApmPolicy) -> str: + """Serialize an ApmPolicy to deterministic YAML for caching.""" + return yaml.dump( + _policy_to_dict(policy), default_flow_style=False, sort_keys=True + ) # yaml-io-exempt + + +def _policy_fingerprint(serialized: str) -> str: + """Compute a fingerprint of a serialized policy.""" + return hashlib.sha256(serialized.encode("utf-8")).hexdigest()[:32] + + +def _is_policy_empty(policy: ApmPolicy) -> bool: + """Return True if a policy has no actionable restrictions. 
+ + An 'empty' policy is syntactically valid but imposes no constraints + beyond the permissive defaults. + """ + return ( + not policy.dependencies.effective_deny + and policy.dependencies.allow is None + and not policy.dependencies.effective_require + and not policy.mcp.deny + and policy.mcp.allow is None + and policy.mcp.transport.allow is None + and policy.compilation.target.allow is None + and not policy.manifest.required_fields + and policy.manifest.scripts == "allow" + and policy.manifest.content_types is None + and policy.unmanaged_files.effective_action == "ignore" + ) + + +def _stale_fallback_or_error( + cache_entry: _CacheEntry | None, + fetch_error_msg: str, + source_label: str, + outcome_on_miss: str, +) -> object: # PolicyFetchResult + """Return stale cache if available, otherwise error with given outcome.""" + from .discovery import PolicyFetchResult + + if cache_entry is not None: + return PolicyFetchResult( + policy=cache_entry.policy, + source=cache_entry.source, + cached=True, + cache_stale=True, + cache_age_seconds=cache_entry.age_seconds, + fetch_error=fetch_error_msg, + outcome="cached_stale", + ) + return PolicyFetchResult( + error=fetch_error_msg, + source=source_label, + fetch_error=fetch_error_msg, + outcome=outcome_on_miss, + ) + + +def _detect_garbage( + content: str | None, + identifier: str, + source_label: str, + cache_entry: _CacheEntry | None, +) -> object: # PolicyFetchResult | None + """Detect garbage responses (200 OK with non-YAML body). + + Returns a PolicyFetchResult if the content is garbage (stale fallback + or garbage_response outcome), or None if the content looks parseable. 
+ """ + from .discovery import PolicyFetchResult + + if content is None: + return None + + try: + raw_data = yaml.safe_load(content) + except yaml.YAMLError: + msg = f"Response from {identifier} is not valid YAML" + if cache_entry is not None: + return PolicyFetchResult( + policy=cache_entry.policy, + source=cache_entry.source, + cached=True, + cache_stale=True, + cache_age_seconds=cache_entry.age_seconds, + fetch_error=msg, + outcome="cached_stale", + ) + return PolicyFetchResult( + error=msg + " (possible captive portal or redirect)", + source=source_label, + fetch_error=msg, + outcome="garbage_response", + ) + + if raw_data is not None and not isinstance(raw_data, dict): + msg = f"Response from {identifier} is not a YAML mapping" + if cache_entry is not None: + return PolicyFetchResult( + policy=cache_entry.policy, + source=cache_entry.source, + cached=True, + cache_stale=True, + cache_age_seconds=cache_entry.age_seconds, + fetch_error=msg, + outcome="cached_stale", + ) + return PolicyFetchResult( + error=msg, + source=source_label, + fetch_error=msg, + outcome="garbage_response", + ) + + return None # Not garbage -- proceed with normal parsing + + +def _read_cache_entry( + repo_ref: str, + project_root: Path, + ttl: int = DEFAULT_CACHE_TTL, + *, + expected_hash: str | None = None, +) -> _CacheEntry | None: + """Read cache entry with stale-awareness. + + Returns: + * ``_CacheEntry(stale=False)`` -- within TTL, ready for immediate use + * ``_CacheEntry(stale=True)`` -- past TTL but within MAX_STALE_TTL + * ``None`` -- no cache file, corrupt, past MAX_STALE_TTL, + or pin verification failure (#827). + + When *expected_hash* is provided the cached ``raw_bytes_hash`` is + compared against it; a mismatch invalidates the cache entry so the + caller falls through to a fresh fetch where the pin can be verified + against authoritative bytes off the wire. 
+ """ + cache_dir = _get_cache_dir(project_root) + key = _cache_key(repo_ref) + policy_file = cache_dir / f"{key}.yml" + meta_file = cache_dir / f"{key}.meta.json" + + if not policy_file.exists() or not meta_file.exists(): + return None + + try: + meta = json.loads(meta_file.read_text(encoding="utf-8")) + + # Schema version check -- auto-invalidate on format change + if meta.get("schema_version") != CACHE_SCHEMA_VERSION: + return None + + cached_at = meta.get("cached_at", 0) + age = int(time.time() - cached_at) + + if age > MAX_STALE_TTL: + return None # Past MAX_STALE_TTL, unusable + + raw_bytes_hash = meta.get("raw_bytes_hash", "") or "" + + # Pin verification (#827): if the project pinned a hash and the + # cache was written without one (legacy entry) or with a different + # one, ignore the cache so the fetcher can verify the pin against + # fresh authoritative bytes. + if expected_hash is not None: + try: + exp_algo, exp_hex = _split_hash_pin(expected_hash) + expected_norm = f"{exp_algo}:{exp_hex}" + except ProjectPolicyConfigError: + return None + if raw_bytes_hash.lower() != expected_norm: + return None + + policy, _warnings = load_policy(policy_file) + + # Determine source label + if repo_ref.startswith("http://") or repo_ref.startswith("https://"): + source = f"url:{repo_ref}" + else: + source = f"org:{repo_ref}" + + return _CacheEntry( + policy=policy, + source=source, + age_seconds=age, + stale=age > ttl, + chain_refs=meta.get("chain_refs", [repo_ref]), + fingerprint=meta.get("fingerprint", ""), + raw_bytes_hash=raw_bytes_hash, + ) + except Exception: + return None + + +def _read_cache( + repo_ref: str, + project_root: Path, + ttl: int = DEFAULT_CACHE_TTL, +) -> object: # PolicyFetchResult | None + """Read policy from cache if still valid (within TTL). + + Legacy wrapper around ``_read_cache_entry`` for backward compatibility. + Returns None if cache miss, expired, or past MAX_STALE_TTL. 
+ """ + from .discovery import PolicyFetchResult + + entry = _read_cache_entry(repo_ref, project_root, ttl=ttl) + if entry is None or entry.stale: + return None + outcome = "empty" if _is_policy_empty(entry.policy) else "found" + return PolicyFetchResult( + policy=entry.policy, + source=entry.source, + cached=True, + cache_age_seconds=entry.age_seconds, + outcome=outcome, + ) + + +def _write_cache( + repo_ref: str, + policy: ApmPolicy, + project_root: Path, + *, + chain_refs: list[str] | None = None, + raw_bytes_hash: str | None = None, +) -> None: + """Write merged effective policy and metadata to cache atomically. + + Uses temp file + ``os.replace()`` to prevent torn writes from parallel + installs. Both the policy file and metadata sidecar are written + atomically and independently. + + The optional ``raw_bytes_hash`` (canonical ``":"``) is the + digest of the leaf bytes off the wire and is persisted to the meta + sidecar so subsequent cached reads can verify against the project's + pin without re-fetching (#827). 
+ """ + cache_dir = _get_cache_dir(project_root) + cache_dir.mkdir(parents=True, exist_ok=True) + + key = _cache_key(repo_ref) + policy_file = cache_dir / f"{key}.yml" + meta_file = cache_dir / f"{key}.meta.json" + + serialized = _serialize_policy(policy) + fingerprint = _policy_fingerprint(serialized) + + # Unique tmp suffix to avoid collisions from parallel writers + uid = f"{os.getpid()}.{threading.get_ident()}" + + # Atomic write: policy file + tmp_policy = cache_dir / f"{key}.{uid}.yml.tmp" + try: + tmp_policy.write_text(serialized, encoding="utf-8") + os.replace(str(tmp_policy), str(policy_file)) + except OSError: + # Best-effort cleanup + try: # noqa: SIM105 + tmp_policy.unlink(missing_ok=True) + except OSError: + pass + return + + # Atomic write: metadata sidecar + meta = { + "repo_ref": repo_ref, + "cached_at": time.time(), + "chain_refs": chain_refs if chain_refs is not None else [repo_ref], + "schema_version": CACHE_SCHEMA_VERSION, + "fingerprint": fingerprint, + "raw_bytes_hash": raw_bytes_hash or "", + } + tmp_meta = cache_dir / f"{key}.{uid}.meta.json.tmp" + try: + tmp_meta.write_text(json.dumps(meta), encoding="utf-8") + os.replace(str(tmp_meta), str(meta_file)) + except OSError: + try: # noqa: SIM105 + tmp_meta.unlink(missing_ok=True) + except OSError: + pass diff --git a/src/apm_cli/policy/_discovery_chain.py b/src/apm_cli/policy/_discovery_chain.py new file mode 100644 index 000000000..019b9264a --- /dev/null +++ b/src/apm_cli/policy/_discovery_chain.py @@ -0,0 +1,272 @@ +"""Chain-resolution and host-pin helpers for policy discovery. + +Leaf module -- does NOT import ``discovery.py`` at module scope. +Back-references to ``discovery`` symbols (``urlparse``, +``_extract_org_from_git_remote``, ``discover_policy``, ``_write_cache``) +use Rule-B function-local imports so that test patches applied via +``apm_cli.policy.discovery.`` are still honoured. 
+ +Rule-B routing table (all inside ``from apm_cli.policy import discovery as _d``): +- ``_derive_leaf_host``: ``_d.urlparse`` (6 patches), ``_d._extract_org_from_git_remote`` (13 patches) +- ``_extract_extends_host``: ``_d.urlparse`` (6 patches) +- ``_resolve_and_persist_chain``: ``_d.discover_policy`` (20 patches), ``_d._write_cache`` (22 patches) +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .schema import ApmPolicy + + +# --------------------------------------------------------------------------- +# Source-label helpers +# --------------------------------------------------------------------------- + + +def _strip_source_prefix(src: str) -> str: + """Strip 'org:' / 'url:' / 'file:' prefix from a PolicyFetchResult.source.""" + return src.removeprefix("org:").removeprefix("url:").removeprefix("file:") + + +def _derive_leaf_host(source: str, project_root: Path) -> str | None: + """Derive the origin host of the leaf policy. + + The leaf host pins which host an ``extends:`` reference may resolve + against (Security Finding F1 -- prevents credential leakage to + attacker-controlled hosts via cross-host extends chains). + + Returns the host in lowercase, or None if it cannot be determined. + + Source forms: + * ``url:https:///...`` -> ```` + * ``org://`` (3+ slash-segments) -> ```` + * ``org:/`` (2 slash-segments) -> ``github.com`` (default) + * ``file:`` -> fall back to git remote of *project_root* + """ + # Rule B: import discovery so urlparse and _extract_org_from_git_remote + # are looked up in the discovery module's namespace at call time, making + # test patches on apm_cli.policy.discovery.urlparse / ._extract_org_... + # visible here too. 
+ from apm_cli.policy import discovery as _d + + if not source: # noqa: SIM108 + bare = "" + else: + bare = _strip_source_prefix(source) + + if source.startswith("url:") or bare.startswith("https://") or bare.startswith("http://"): + try: + parsed = _d.urlparse(bare) + if parsed.hostname: + return parsed.hostname.lower() + except Exception: + return None + return None + + if source.startswith("org:") or (bare and "://" not in bare and bare.count("/") >= 1): + parts = bare.split("/") + if len(parts) >= 3: + return parts[0].lower() + if len(parts) == 2: + # owner/repo shorthand defaults to github.com (matches + # _fetch_github_contents convention). + return "github.com" + + # File source (or unrecognized): fall back to project's git remote. + org_and_host = _d._extract_org_from_git_remote(project_root) + if org_and_host is not None: + _, host = org_and_host + if host: + return host.lower() + return None + + +def _extract_extends_host(ref: str) -> str | None: + """Return the host an ``extends:`` ref resolves against, if explicit. + + * Full URL -> URL host (lowercase) + * ``//`` (3+ slash-segments) -> ```` (lowercase) + * ``/`` shorthand -> None (intrinsically same-host) + * ```` shorthand (no slash) -> None (intrinsically same-host) + """ + # Rule B: use _d.urlparse so test patches on discovery.urlparse apply. + from apm_cli.policy import discovery as _d + + if not ref: + return None + if ref.startswith("http://") or ref.startswith("https://"): + try: + parsed = _d.urlparse(ref) + if parsed.hostname: + return parsed.hostname.lower() + except Exception: + return None + return None + if "/" not in ref: + return None + parts = ref.split("/") + if len(parts) >= 3: + return parts[0].lower() + return None + + +def _validate_extends_host(leaf_host: str | None, extends_ref: str) -> None: + """Reject ``extends:`` refs that point at a different host than the leaf. 
+ + Raises :class:`PolicyInheritanceError` (imported lazily to avoid a + module-level cycle) when the ``extends:`` ref names a host that does + not match *leaf_host*. Pure shorthand refs (``owner/repo``, ``org``) + are intrinsically same-host and always pass. + + See Security Finding F1: a malicious org policy author setting + ``extends: "evil.example.com/org/.github"`` could otherwise route + ``git credential fill`` against an attacker-controlled host. + """ + from . import inheritance as _inheritance_mod + + extends_host = _extract_extends_host(extends_ref) + if extends_host is None: + return # shorthand: intrinsically same-host, allowed. + + if leaf_host is None: + raise _inheritance_mod.PolicyInheritanceError( + f"Policy extends: cross-host reference rejected " + f"(leaf host: , extends host: {extends_host}); " + f"cross-host policy chains are not allowed" + ) + + if extends_host != leaf_host.lower(): + raise _inheritance_mod.PolicyInheritanceError( + f"Policy extends: cross-host reference rejected " + f"(leaf host: {leaf_host}, extends host: {extends_host}); " + f"cross-host policy chains are not allowed" + ) + + +def _resolve_and_persist_chain( + fetch_result: object, # PolicyFetchResult + project_root: Path, +) -> None: + """Resolve inheritance chain and update cache with merged policy + chain_refs. + + Walks the ``extends:`` chain depth-first, fetching each parent via the + single-policy ``discover_policy`` (so each fetch still hits the + well-tested fetch path). Cycle detection on normalized ``extends:`` + refs and ``MAX_CHAIN_DEPTH`` enforcement protect against runaway or + self-referential chains. + + Partial-chain policy: if any parent fetch fails, emit a warning via + ``_rich_warning`` and merge whatever was resolved so far -- never + silently drop ancestors. + + Mutates *fetch_result*.policy in-place with the merged effective policy. + Called by :func:`discover_policy_with_chain` -- not intended for direct + use. 
+ """ + # Rule B: discover_policy and _write_cache are patched via + # apm_cli.policy.discovery.* in tests; look them up via _d so patches apply. + from apm_cli.policy import discovery as _d + + from ..utils.console import _rich_warning + from . import inheritance as _inheritance_mod + + leaf_policy = fetch_result.policy + leaf_source = fetch_result.source + + # Host pin: extends: refs may only resolve against the leaf's origin + # host. Prevents credential leakage to attacker-controlled hosts via + # cross-host extends chains (Security Finding F1). + leaf_host = _derive_leaf_host(leaf_source, project_root) + + # Ordered ancestors collected as we walk parents. Built leaf-first + # for traversal convenience; reversed before merging. + chain_policies: list[ApmPolicy] = [leaf_policy] + chain_sources: list[str] = [leaf_source] + + # Track normalized refs we've already followed to break cycles. + # We seed with the leaf's source so an extends pointing back at the + # leaf is also detected. + visited: list[str] = [_strip_source_prefix(leaf_source)] if leaf_source else [] + + current = leaf_policy + partial_warning: tuple[str, int, int] | None = None + + while current.extends: + next_ref = current.extends + + # Host pin enforcement: must validate BEFORE any fetch so we never + # call git credential fill against an attacker-controlled host. + _validate_extends_host(leaf_host, next_ref) + + if _inheritance_mod.detect_cycle(visited, next_ref): + raise _inheritance_mod.PolicyInheritanceError( + f"Cycle detected in policy extends chain: {' -> '.join(visited)} -> {next_ref}" + ) + + # Depth check: chain_policies already has len() entries; next fetch + # would push us to len()+1. resolve_policy_chain enforces this + # afterwards, but failing here gives a clearer error. 
+ if len(chain_policies) + 1 > _inheritance_mod.MAX_CHAIN_DEPTH: + raise _inheritance_mod.PolicyInheritanceError( + f"Policy chain depth exceeds maximum of " + f"{_inheritance_mod.MAX_CHAIN_DEPTH} " + f"(chain: {' -> '.join(visited)} -> {next_ref})" + ) + + parent_result = _d.discover_policy( + project_root, + policy_override=next_ref, + no_cache=False, + ) + + if parent_result.policy is None: + # Parent fetch failed -- merge what we have so far and warn. + attempted = len(chain_policies) + 1 + resolved = len(chain_policies) + partial_warning = (next_ref, resolved, attempted) + break + + chain_policies.append(parent_result.policy) + chain_sources.append(parent_result.source) + visited.append(next_ref) + current = parent_result.policy + + # No actual ancestors fetched -- nothing to merge or re-cache. + if len(chain_policies) == 1: + if partial_warning is not None: + ref, resolved, attempted = partial_warning + _rich_warning( + f"Policy chain incomplete: {ref} unreachable, " + f"using {resolved} of {attempted} policies", + symbol="warning", + ) + return + + # Merge in [root, ..., leaf] order. We collected leaf-first, so reverse. + ordered = list(reversed(chain_policies)) + ordered_sources = list(reversed(chain_sources)) + + try: + merged = _inheritance_mod.resolve_policy_chain(ordered) + except _inheritance_mod.PolicyInheritanceError: + # Re-raise depth errors from the canonical validator so callers + # see a single consistent error type. 
+ raise + + chain_refs: list[str] = [_strip_source_prefix(src) for src in ordered_sources if src] + + cache_key = _strip_source_prefix(leaf_source) if leaf_source else "" + if cache_key: + _d._write_cache(cache_key, merged, project_root, chain_refs=chain_refs) + + fetch_result.policy = merged + + if partial_warning is not None: + ref, resolved, attempted = partial_warning + _rich_warning( + f"Policy chain incomplete: {ref} unreachable, using {resolved} of {attempted} policies", + symbol="warning", + ) diff --git a/src/apm_cli/policy/_policy_checks_mcp.py b/src/apm_cli/policy/_policy_checks_mcp.py new file mode 100644 index 000000000..cc137b326 --- /dev/null +++ b/src/apm_cli/policy/_policy_checks_mcp.py @@ -0,0 +1,465 @@ +"""MCP, compilation, manifest, and unmanaged-files policy checks. + +Leaf module -- does NOT import ``policy_checks.py`` at module scope. +All symbols that tests import from ``apm_cli.policy.policy_checks`` +are re-exported from there with the ``NAME as NAME`` redundant-alias form. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +from .models import CheckResult + +if TYPE_CHECKING: + from .schema import ( + CompilationPolicy, + ManifestPolicy, + McpPolicy, + ) + +_logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Raw manifest loader +# --------------------------------------------------------------------------- + + +def _load_raw_apm_yml(project_root: Path) -> dict | None: + """Load raw apm.yml as a dict for policy checks that inspect raw fields. + + This helper is called **after** :pymethod:`APMPackage.from_apm_yml` has + already succeeded in :func:`run_policy_checks`. The primary security + gate is ``from_apm_yml()`` -- if it fails, the audit aborts with a + ``manifest-parse`` check result and this function is never reached. 
+ + Returning ``None`` here is therefore **defence-in-depth**: it covers + edge cases (TOCTOU race, transient I/O error) where the file becomes + unreadable between the two calls. Callers that receive ``None`` + gracefully skip supplementary raw-field checks (e.g. + ``compilation-target``, ``extensions-present``) rather than hard-failing. + + Returns ``None`` when the file is absent, unreadable, malformed YAML, + or not a mapping -- but logs a warning so the failure is visible + rather than silently swallowed. + """ + import yaml + + apm_yml_path = project_root / "apm.yml" + if not apm_yml_path.exists(): + return None + try: + with open(apm_yml_path, encoding="utf-8") as f: + data = yaml.safe_load(f) + except FileNotFoundError: + # TOCTOU: file disappeared between exists() check and open(); normal condition. + return None + except yaml.YAMLError as exc: + _logger.warning("Malformed YAML in %s: %s", apm_yml_path, exc) + return None + except OSError as exc: + _logger.warning("Cannot read %s: %s", apm_yml_path, exc) + return None + except UnicodeDecodeError as exc: + _logger.warning("Cannot decode %s as UTF-8: %s", apm_yml_path, exc) + return None + if not isinstance(data, dict): + _logger.warning( + "apm.yml is not a YAML mapping (got %s) -- skipping raw-field checks", + type(data).__name__, + ) + return None + return data + + +# --------------------------------------------------------------------------- +# MCP checks (7-10) +# --------------------------------------------------------------------------- + + +def _check_mcp_allowlist( + mcp_deps: list, + policy: McpPolicy, +) -> CheckResult: + """Check 7: MCP server names match allow list.""" + from .matcher import check_mcp_allowed + + if policy.allow is None: + return CheckResult( + name="mcp-allowlist", + passed=True, + message="No MCP allow list configured", + ) + + violations: list[str] = [] + for mcp in mcp_deps: + allowed, reason = check_mcp_allowed(mcp.name, policy) + if not allowed and "not in allowed" in 
reason: + violations.append(f"{mcp.name}: {reason}") + + if not violations: + return CheckResult( + name="mcp-allowlist", + passed=True, + message="All MCP servers match allow list", + ) + return CheckResult( + name="mcp-allowlist", + passed=False, + message=f"{len(violations)} MCP server(s) not in allow list", + details=violations, + ) + + +def _check_mcp_denylist( + mcp_deps: list, + policy: McpPolicy, +) -> CheckResult: + """Check 8: no MCP server matches deny list.""" + from .matcher import check_mcp_allowed + + if not policy.deny: + return CheckResult( + name="mcp-denylist", + passed=True, + message="No MCP deny list configured", + ) + + violations: list[str] = [] + for mcp in mcp_deps: + allowed, reason = check_mcp_allowed(mcp.name, policy) + if not allowed and "denied by pattern" in reason: + violations.append(f"{mcp.name}: {reason}") + + if not violations: + return CheckResult( + name="mcp-denylist", + passed=True, + message="No MCP servers match deny list", + ) + return CheckResult( + name="mcp-denylist", + passed=False, + message=f"{len(violations)} MCP server(s) match deny list", + details=violations, + ) + + +def _check_mcp_transport( + mcp_deps: list, + policy: McpPolicy, +) -> CheckResult: + """Check 9: MCP transport values match policy allow list.""" + allowed_transports = policy.transport.allow + if allowed_transports is None: + return CheckResult( + name="mcp-transport", + passed=True, + message="No MCP transport restrictions configured", + ) + + violations: list[str] = [] + for mcp in mcp_deps: + if mcp.transport and mcp.transport not in allowed_transports: + violations.append( + f"{mcp.name}: transport '{mcp.transport}' not in allowed {allowed_transports}" + ) + + if not violations: + return CheckResult( + name="mcp-transport", + passed=True, + message="All MCP transports comply with policy", + ) + return CheckResult( + name="mcp-transport", + passed=False, + message=f"{len(violations)} MCP transport violation(s)", + details=violations, + ) + + 
+def _check_mcp_self_defined( + mcp_deps: list, + policy: McpPolicy, +) -> CheckResult: + """Check 10: self-defined MCP servers comply with policy.""" + self_defined_policy = policy.self_defined + if self_defined_policy == "allow": + return CheckResult( + name="mcp-self-defined", + passed=True, + message="Self-defined MCP servers allowed", + ) + + self_defined = [m for m in mcp_deps if m.registry is False] + if not self_defined: + return CheckResult( + name="mcp-self-defined", + passed=True, + message="No self-defined MCP servers found", + ) + + details = [f"{m.name}: self-defined server" for m in self_defined] + if self_defined_policy == "deny": + return CheckResult( + name="mcp-self-defined", + passed=False, + message=f"{len(self_defined)} self-defined MCP server(s) denied by policy", + details=details, + ) + # warn -- pass but with details + return CheckResult( + name="mcp-self-defined", + passed=True, + message=f"{len(self_defined)} self-defined MCP server(s) (warn)", + details=details, + ) + + +# --------------------------------------------------------------------------- +# Compilation checks (11-13) +# --------------------------------------------------------------------------- + + +def _check_compilation_target( + raw_yml: dict | None, + policy: CompilationPolicy, +) -> CheckResult: + """Check 11: compilation target matches policy.""" + enforce = policy.target.enforce + allow = policy.target.allow + + if not enforce and allow is None: + return CheckResult( + name="compilation-target", + passed=True, + message="No compilation target restrictions configured", + ) + + target = (raw_yml or {}).get("target") + if not target: + return CheckResult( + name="compilation-target", + passed=True, + message="No compilation target set in manifest", + ) + + # Normalize target to a list for uniform checking + target_list = target if isinstance(target, list) else [target] + + if enforce: + if enforce not in target_list: + return CheckResult( + name="compilation-target", + 
passed=False, + message=f"Enforced target '{enforce}' not present in {target_list}", + details=[f"target: {target}, enforced: {enforce}"], + ) + elif allow is not None: + allow_set = set(allow) if isinstance(allow, (list, tuple)) else {allow} + disallowed = [t for t in target_list if t not in allow_set] + if disallowed: + return CheckResult( + name="compilation-target", + passed=False, + message=f"Target(s) {disallowed} not in allowed list {sorted(allow_set)}", + details=[f"target: {target}, allowed: {sorted(allow_set)}"], + ) + + return CheckResult( + name="compilation-target", + passed=True, + message="Compilation target compliant", + ) + + +def _check_compilation_strategy( + raw_yml: dict | None, + policy: CompilationPolicy, +) -> CheckResult: + """Check 12: compilation strategy matches policy.""" + enforce = policy.strategy.enforce + if not enforce: + return CheckResult( + name="compilation-strategy", + passed=True, + message="No compilation strategy enforced", + ) + + compilation = (raw_yml or {}).get("compilation", {}) + strategy = compilation.get("strategy") if isinstance(compilation, dict) else None + if not strategy: + return CheckResult( + name="compilation-strategy", + passed=True, + message="No compilation strategy set in manifest", + ) + + if strategy != enforce: + return CheckResult( + name="compilation-strategy", + passed=False, + message=f"Strategy '{strategy}' does not match enforced '{enforce}'", + details=[f"strategy: {strategy}, enforced: {enforce}"], + ) + return CheckResult( + name="compilation-strategy", + passed=True, + message="Compilation strategy compliant", + ) + + +def _check_source_attribution( + raw_yml: dict | None, + policy: CompilationPolicy, +) -> CheckResult: + """Check 13: source attribution enabled if policy requires.""" + if not policy.source_attribution: + return CheckResult( + name="source-attribution", + passed=True, + message="Source attribution not required by policy", + ) + + compilation = (raw_yml or 
{}).get("compilation", {}) + attribution = compilation.get("source_attribution") if isinstance(compilation, dict) else None + if attribution is True: + return CheckResult( + name="source-attribution", + passed=True, + message="Source attribution enabled", + ) + return CheckResult( + name="source-attribution", + passed=False, + message="Source attribution required by policy but not enabled in manifest", + details=["Set compilation.source_attribution: true in apm.yml"], + ) + + +# --------------------------------------------------------------------------- +# Manifest checks (14-15 + explicit-includes) +# --------------------------------------------------------------------------- + + +def _check_required_manifest_fields( + raw_yml: dict | None, + policy: ManifestPolicy, +) -> CheckResult: + """Check 14: all required fields are present with non-empty values.""" + if not policy.required_fields: + return CheckResult( + name="required-manifest-fields", + passed=True, + message="No required manifest fields configured", + ) + + data = raw_yml or {} + missing: list[str] = [] + for field_name in policy.required_fields: + value = data.get(field_name) + if not value: # None, empty string, missing + missing.append(field_name) + + if not missing: + return CheckResult( + name="required-manifest-fields", + passed=True, + message="All required manifest fields present", + ) + return CheckResult( + name="required-manifest-fields", + passed=False, + message=f"{len(missing)} required manifest field(s) missing", + details=missing, + ) + + +def _check_includes_explicit( + manifest_includes, + policy: ManifestPolicy, +) -> CheckResult: + """Check: manifest declares an explicit ``includes:`` list when policy requires it. + + ``manifest_includes`` is the parsed value of the manifest's ``includes:`` + field as exposed by :class:`APMPackage` -- one of ``None`` (field + absent), the literal string ``"auto"``, or a list of repo-relative + path strings. 
+ + Violation when ``policy.require_explicit_includes`` is True and + ``manifest_includes`` is ``None`` or ``"auto"``. + """ + if not policy.require_explicit_includes: + return CheckResult( + name="explicit-includes", + passed=True, + message="Explicit includes not required by policy", + ) + + if manifest_includes is None: + return CheckResult( + name="explicit-includes", + passed=False, + message=( + "Policy requires explicit 'includes:' paths but none are " + "declared. Add 'includes: [, ...]' to apm.yml with " + "the paths you intend to publish." + ), + details=[ + "includes: , require_explicit_includes: true", + ], + ) + + if manifest_includes == "auto": + return CheckResult( + name="explicit-includes", + passed=False, + message=( + "Policy requires explicit 'includes:' paths but manifest " + "uses 'includes: auto'. Replace with an explicit list of " + "paths." + ), + details=[ + "includes: 'auto', require_explicit_includes: true", + ], + ) + + return CheckResult( + name="explicit-includes", + passed=True, + message="Manifest declares explicit includes paths", + ) + + +def _check_scripts_policy( + raw_yml: dict | None, + policy: ManifestPolicy, +) -> CheckResult: + """Check 15: scripts section absent if policy denies it.""" + if policy.scripts != "deny": + return CheckResult( + name="scripts-policy", + passed=True, + message="Scripts allowed by policy", + ) + + scripts = (raw_yml or {}).get("scripts") + if scripts: + return CheckResult( + name="scripts-policy", + passed=False, + message="Scripts section present but denied by policy", + details=list(scripts.keys()) if isinstance(scripts, dict) else ["scripts"], + ) + return CheckResult( + name="scripts-policy", + passed=True, + message="No scripts section (compliant with deny policy)", + ) + + +# End of _policy_checks_mcp.py diff --git a/src/apm_cli/policy/ci_checks.py b/src/apm_cli/policy/ci_checks.py index 93ad7d7ff..ead8811b8 100644 --- a/src/apm_cli/policy/ci_checks.py +++ b/src/apm_cli/policy/ci_checks.py 
@@ -595,29 +595,19 @@ def _run(check: CheckResult) -> bool: result.checks.append(check) return fail_fast and not check.passed - # Check 2: Ref consistency - if _run(_check_ref_consistency(manifest, lock)): - return result - - # Check 3: Deployed files present - if _run(_check_deployed_files_present(project_root, lock)): - return result - - # Check 4: No orphaned packages - if _run(_check_no_orphans(manifest, lock)): - return result - - # Check 4.5: Skill subset consistency (manifest vs lockfile) - if _run(_check_skill_subset_consistency(manifest, lock)): - return result - - # Check 5: Config consistency (MCP) - if _run(_check_config_consistency(manifest, lock)): - return result - - # Check 6: Content integrity - if _run(_check_content_integrity(project_root, lock)): - return result + # Checks 2-6: ordered sequence; stop on first failure when fail_fast. + # Lambdas ensure lazy evaluation so expensive checks (content integrity) + # are skipped when an earlier check fails in fail_fast mode. 
+ for _check_fn in [ + lambda: _check_ref_consistency(manifest, lock), # 2 + lambda: _check_deployed_files_present(project_root, lock), # 3 + lambda: _check_no_orphans(manifest, lock), # 4 + lambda: _check_skill_subset_consistency(manifest, lock), # 4.5 + lambda: _check_config_consistency(manifest, lock), # 5 + lambda: _check_content_integrity(project_root, lock), # 6 + ]: + if _run(_check_fn()): + return result # Check 7: Includes consent (advisory; never hard-fails) _run(_check_includes_consent(manifest, lock)) diff --git a/src/apm_cli/policy/discovery.py b/src/apm_cli/policy/discovery.py index 97e52f29c..0b9df1142 100644 --- a/src/apm_cli/policy/discovery.py +++ b/src/apm_cli/policy/discovery.py @@ -18,30 +18,92 @@ from __future__ import annotations import base64 -import hashlib -import json import logging import os import subprocess -import threading -import time -from dataclasses import dataclass, field +import threading # noqa: F401 +from dataclasses import dataclass from pathlib import Path from urllib.parse import urlparse import requests -import yaml from ..cache.url_normalize import SCP_LIKE_RE -from ..utils.path_security import PathTraversalError, ensure_path_within +from ._discovery_cache import ( + CACHE_SCHEMA_VERSION as CACHE_SCHEMA_VERSION, +) +from ._discovery_cache import ( + DEFAULT_CACHE_TTL as DEFAULT_CACHE_TTL, +) +from ._discovery_cache import ( + MAX_STALE_TTL as MAX_STALE_TTL, +) +from ._discovery_cache import ( + POLICY_CACHE_DIR as POLICY_CACHE_DIR, +) +from ._discovery_cache import ( + _cache_key as _cache_key, +) +from ._discovery_cache import ( + _CacheEntry as _CacheEntry, +) +from ._discovery_cache import ( + _compute_hash_normalized as _compute_hash_normalized, +) +from ._discovery_cache import ( + _detect_garbage as _detect_garbage, +) +from ._discovery_cache import ( + _get_cache_dir as _get_cache_dir, +) +from ._discovery_cache import ( + _is_policy_empty as _is_policy_empty, +) +from ._discovery_cache import ( + 
_policy_fingerprint as _policy_fingerprint, +) +from ._discovery_cache import ( + _policy_to_dict as _policy_to_dict, +) +from ._discovery_cache import ( + _read_cache as _read_cache, +) +from ._discovery_cache import ( + _read_cache_entry as _read_cache_entry, +) +from ._discovery_cache import ( + _serialize_policy as _serialize_policy, +) +from ._discovery_cache import ( + _split_hash_pin as _split_hash_pin, +) +from ._discovery_cache import ( + _stale_fallback_or_error as _stale_fallback_or_error, +) +from ._discovery_cache import ( + _verify_hash_pin as _verify_hash_pin, +) +from ._discovery_cache import ( + _write_cache as _write_cache, +) +from ._discovery_chain import ( + _derive_leaf_host as _derive_leaf_host, +) +from ._discovery_chain import ( + _extract_extends_host as _extract_extends_host, +) +from ._discovery_chain import ( + _resolve_and_persist_chain as _resolve_and_persist_chain, +) +from ._discovery_chain import ( + _strip_source_prefix as _strip_source_prefix, +) +from ._discovery_chain import ( + _validate_extends_host as _validate_extends_host, +) from .parser import PolicyValidationError, load_policy from .project_config import ( - _DEFAULT_HASH_ALGORITHM, - _HASH_HEX_LEN, - _HEX_RE, - ALLOWED_HASH_ALGORITHMS, ProjectPolicyConfigError, - compute_policy_hash, read_project_policy_hash_pin, ) from .schema import ApmPolicy @@ -49,119 +111,6 @@ logger = logging.getLogger(__name__) -def _split_hash_pin(expected_hash: str) -> tuple[str, str]: - """Split an ``":"`` pin into (algorithm, lowercase_hex). - - Bare hex (no prefix) is interpreted as sha256 for backwards - compatibility -- callers that care about the algorithm should pass a - fully-qualified pin. Raises :class:`ProjectPolicyConfigError` on a - structurally invalid pin (unsupported algorithm, wrong length, non - hex). The discovery helpers translate that into a fail-closed - ``hash_mismatch`` outcome rather than crashing. 
- """ - raw = expected_hash.strip() - if ":" in raw: - algo, _, hex_part = raw.partition(":") - algo = algo.strip().lower() - else: - algo = _DEFAULT_HASH_ALGORITHM - hex_part = raw - hex_part = hex_part.strip().lower() - if algo not in ALLOWED_HASH_ALGORITHMS: - raise ProjectPolicyConfigError(f"Unsupported policy.hash algorithm '{algo}'") - expected_len = _HASH_HEX_LEN[algo] - if len(hex_part) != expected_len or not _HEX_RE.match(hex_part): - raise ProjectPolicyConfigError(f"policy.hash is not a valid {algo} digest") - return algo, hex_part - - -def _compute_hash_normalized(content: str, expected_hash: str | None) -> str: - """Compute the digest of *content* under the algorithm declared by - *expected_hash*, returning the canonical ``":"`` form. - - When *expected_hash* is ``None`` the default algorithm (sha256) is - used so the cache always carries a digest for later pin verification. - """ - algo = _DEFAULT_HASH_ALGORITHM - if expected_hash: - try: - algo, _ = _split_hash_pin(expected_hash) - except ProjectPolicyConfigError: - algo = _DEFAULT_HASH_ALGORITHM - digest = compute_policy_hash(content, algo) - return f"{algo}:{digest}" - - -def _verify_hash_pin( - content: object, - expected_hash: str | None, - source_label: str, -) -> PolicyFetchResult | None: - """Verify fetched policy bytes against the project's pin (#827). - - Returns ``None`` when there is no pin, or the digest matches. On - mismatch -- or on a structurally invalid pin, which is treated as a - mismatch to stay fail-closed -- returns a :class:`PolicyFetchResult` - with ``outcome="hash_mismatch"`` that callers must propagate. The - hash is computed on the raw UTF-8 bytes that get parsed (matching - ``yaml.safe_load`` semantics) so a malicious mirror cannot bypass the - check by re-serializing semantically-equivalent YAML. 
- """ - if expected_hash is None: - return None - - raw_bytes: bytes - if isinstance(content, bytes): - raw_bytes = content - elif isinstance(content, str): - raw_bytes = content.encode("utf-8") - else: - return PolicyFetchResult( - outcome="hash_mismatch", - source=source_label, - error=( - f"Policy hash mismatch from {source_label}: " - "no content available to verify against pin" - ), - expected_hash=expected_hash, - ) - - try: - algo, expected_hex = _split_hash_pin(expected_hash) - except ProjectPolicyConfigError as exc: - return PolicyFetchResult( - outcome="hash_mismatch", - source=source_label, - error=(f"Policy hash mismatch from {source_label}: invalid pin ({exc})"), - expected_hash=expected_hash, - ) - - digest = hashlib.new(algo) - digest.update(raw_bytes) - actual_hex = digest.hexdigest().lower() - if actual_hex == expected_hex: - return None - - expected_norm = f"{algo}:{expected_hex}" - actual_norm = f"{algo}:{actual_hex}" - return PolicyFetchResult( - outcome="hash_mismatch", - source=source_label, - error=( - f"Policy hash mismatch from {source_label}: expected {expected_norm}, got {actual_norm}" - ), - expected_hash=expected_norm, - raw_bytes_hash=actual_norm, - ) - - -# Cache location: apm_modules/.policy-cache/.yml + .meta.json -POLICY_CACHE_DIR = ".policy-cache" -DEFAULT_CACHE_TTL = 3600 # 1 hour -MAX_STALE_TTL = 7 * 24 * 3600 # 7 days -- stale cache usable on refresh failure -CACHE_SCHEMA_VERSION = "3" # Bump when cache format changes to auto-invalidate - - @dataclass class PolicyFetchResult: """Result of a policy fetch attempt. @@ -245,17 +194,10 @@ def discover_policy_with_chain( is present. Outcome follows the 9-outcome matrix (section B). """ # -- Escape hatch (defence-in-depth) ------------------------------- - # The CLI's --no-policy flag is handled by callers; this env-var - # check stays so third-party use of the API still respects the - # global disable switch. 
if os.environ.get("APM_POLICY_DISABLE") == "1": return PolicyFetchResult(outcome="disabled") # -- Resolve project-side hash pin (#827) -------------------------- - # An explicit *expected_hash* argument always wins (test seam, future - # CLI override). Otherwise fall back to ``policy.hash`` in the - # project's apm.yml. A malformed pin surfaces as ``hash_mismatch`` - # rather than a crash so install fails closed with a clear error. if expected_hash is None: try: pin = read_project_policy_hash_pin(project_root) @@ -283,239 +225,6 @@ def discover_policy_with_chain( return fetch_result -def _strip_source_prefix(src: str) -> str: - """Strip 'org:' / 'url:' / 'file:' prefix from a PolicyFetchResult.source.""" - return src.removeprefix("org:").removeprefix("url:").removeprefix("file:") - - -def _derive_leaf_host(source: str, project_root: Path) -> str | None: - """Derive the origin host of the leaf policy. - - The leaf host pins which host an ``extends:`` reference may resolve - against (Security Finding F1 -- prevents credential leakage to - attacker-controlled hosts via cross-host extends chains). - - Returns the host in lowercase, or None if it cannot be determined. 
- - Source forms: - * ``url:https:///...`` -> ```` - * ``org://`` (3+ slash-segments) -> ```` - * ``org:/`` (2 slash-segments) -> ``github.com`` (default) - * ``file:`` -> fall back to git remote of *project_root* - """ - if not source: # noqa: SIM108 - bare = "" - else: - bare = _strip_source_prefix(source) - - if source.startswith("url:") or bare.startswith("https://") or bare.startswith("http://"): - try: - parsed = urlparse(bare) - if parsed.hostname: - return parsed.hostname.lower() - except Exception: - return None - return None - - if source.startswith("org:") or (bare and "://" not in bare and bare.count("/") >= 1): - parts = bare.split("/") - if len(parts) >= 3: - return parts[0].lower() - if len(parts) == 2: - # owner/repo shorthand defaults to github.com (matches - # _fetch_github_contents convention). - return "github.com" - - # File source (or unrecognized): fall back to project's git remote. - org_and_host = _extract_org_from_git_remote(project_root) - if org_and_host is not None: - _, host = org_and_host - if host: - return host.lower() - return None - - -def _extract_extends_host(ref: str) -> str | None: - """Return the host an ``extends:`` ref resolves against, if explicit. - - * Full URL -> URL host (lowercase) - * ``//`` (3+ slash-segments) -> ```` (lowercase) - * ``/`` shorthand -> None (intrinsically same-host) - * ```` shorthand (no slash) -> None (intrinsically same-host) - """ - if not ref: - return None - if ref.startswith("http://") or ref.startswith("https://"): - try: - parsed = urlparse(ref) - if parsed.hostname: - return parsed.hostname.lower() - except Exception: - return None - return None - if "/" not in ref: - return None - parts = ref.split("/") - if len(parts) >= 3: - return parts[0].lower() - return None - - -def _validate_extends_host(leaf_host: str | None, extends_ref: str) -> None: - """Reject ``extends:`` refs that point at a different host than the leaf. 
- - Raises :class:`PolicyInheritanceError` (imported lazily to avoid a - module-level cycle) when the ``extends:`` ref names a host that does - not match *leaf_host*. Pure shorthand refs (``owner/repo``, ``org``) - are intrinsically same-host and always pass. - - See Security Finding F1: a malicious org policy author setting - ``extends: "evil.example.com/org/.github"`` could otherwise route - ``git credential fill`` against an attacker-controlled host. - """ - from . import inheritance as _inheritance_mod - - extends_host = _extract_extends_host(extends_ref) - if extends_host is None: - return # shorthand: intrinsically same-host, allowed. - - if leaf_host is None: - raise _inheritance_mod.PolicyInheritanceError( - f"Policy extends: cross-host reference rejected " - f"(leaf host: , extends host: {extends_host}); " - f"cross-host policy chains are not allowed" - ) - - if extends_host != leaf_host.lower(): - raise _inheritance_mod.PolicyInheritanceError( - f"Policy extends: cross-host reference rejected " - f"(leaf host: {leaf_host}, extends host: {extends_host}); " - f"cross-host policy chains are not allowed" - ) - - -def _resolve_and_persist_chain( - fetch_result: PolicyFetchResult, - project_root: Path, -) -> None: - """Resolve inheritance chain and update cache with merged policy + chain_refs. - - Walks the ``extends:`` chain depth-first, fetching each parent via the - single-policy ``discover_policy`` (so each fetch still hits the - well-tested fetch path). Cycle detection on normalized ``extends:`` - refs and ``MAX_CHAIN_DEPTH`` enforcement protect against runaway or - self-referential chains. - - Partial-chain policy: if any parent fetch fails, emit a warning via - ``_rich_warning`` and merge whatever was resolved so far -- never - silently drop ancestors. - - Mutates *fetch_result*.policy in-place with the merged effective policy. - Called by :func:`discover_policy_with_chain` -- not intended for direct - use. 
- """ - from ..utils.console import _rich_warning - from . import inheritance as _inheritance_mod - - leaf_policy = fetch_result.policy - leaf_source = fetch_result.source - - # Host pin: extends: refs may only resolve against the leaf's origin - # host. Prevents credential leakage to attacker-controlled hosts via - # cross-host extends chains (Security Finding F1). - leaf_host = _derive_leaf_host(leaf_source, project_root) - - # Ordered ancestors collected as we walk parents. Built leaf-first - # for traversal convenience; reversed before merging. - chain_policies: list[ApmPolicy] = [leaf_policy] - chain_sources: list[str] = [leaf_source] - - # Track normalized refs we've already followed to break cycles. - # We seed with the leaf's source so an extends pointing back at the - # leaf is also detected. - visited: list[str] = [_strip_source_prefix(leaf_source)] if leaf_source else [] - - current = leaf_policy - partial_warning: tuple[str, int, int] | None = None - - while current.extends: - next_ref = current.extends - - # Host pin enforcement: must validate BEFORE any fetch so we never - # call git credential fill against an attacker-controlled host. - _validate_extends_host(leaf_host, next_ref) - - if _inheritance_mod.detect_cycle(visited, next_ref): - raise _inheritance_mod.PolicyInheritanceError( - f"Cycle detected in policy extends chain: {' -> '.join(visited)} -> {next_ref}" - ) - - # Depth check: chain_policies already has len() entries; next fetch - # would push us to len()+1. resolve_policy_chain enforces this - # afterwards, but failing here gives a clearer error. 
- if len(chain_policies) + 1 > _inheritance_mod.MAX_CHAIN_DEPTH: - raise _inheritance_mod.PolicyInheritanceError( - f"Policy chain depth exceeds maximum of " - f"{_inheritance_mod.MAX_CHAIN_DEPTH} " - f"(chain: {' -> '.join(visited)} -> {next_ref})" - ) - - parent_result = discover_policy( - project_root, - policy_override=next_ref, - no_cache=False, - ) - - if parent_result.policy is None: - # Parent fetch failed -- merge what we have so far and warn. - attempted = len(chain_policies) + 1 - resolved = len(chain_policies) - partial_warning = (next_ref, resolved, attempted) - break - - chain_policies.append(parent_result.policy) - chain_sources.append(parent_result.source) - visited.append(next_ref) - current = parent_result.policy - - # No actual ancestors fetched -- nothing to merge or re-cache. - if len(chain_policies) == 1: - if partial_warning is not None: - ref, resolved, attempted = partial_warning - _rich_warning( - f"Policy chain incomplete: {ref} unreachable, " - f"using {resolved} of {attempted} policies", - symbol="warning", - ) - return - - # Merge in [root, ..., leaf] order. We collected leaf-first, so reverse. - ordered = list(reversed(chain_policies)) - ordered_sources = list(reversed(chain_sources)) - - try: - merged = _inheritance_mod.resolve_policy_chain(ordered) - except _inheritance_mod.PolicyInheritanceError: - # Re-raise depth errors from the canonical validator so callers - # see a single consistent error type. 
- raise - - chain_refs: list[str] = [_strip_source_prefix(src) for src in ordered_sources if src] - - cache_key = _strip_source_prefix(leaf_source) if leaf_source else "" - if cache_key: - _write_cache(cache_key, merged, project_root, chain_refs=chain_refs) - - fetch_result.policy = merged - - if partial_warning is not None: - ref, resolved, attempted = partial_warning - _rich_warning( - f"Policy chain incomplete: {ref} unreachable, using {resolved} of {attempted} policies", - symbol="warning", - ) - - def discover_policy( project_root: Path, *, @@ -534,11 +243,6 @@ def discover_policy( -> fetch from that repo via GitHub Contents API 5. If policy_override is None -> auto-discover from project's git remote - The user-facing forms are documented in - ``apm_cli.policy._help_text.POLICY_SOURCE_FORMS_HELP``; that constant - is the single source of truth shared by ``apm audit --policy`` and - ``apm policy status --policy-source``. - The optional ``expected_hash`` (``":"``) pins the leaf policy bytes; mismatches return ``outcome="hash_mismatch"`` and must always be treated fail-closed by callers. @@ -575,9 +279,6 @@ def discover_policy( def _load_from_file(path: Path, *, expected_hash: str | None = None) -> PolicyFetchResult: """Load policy from a local file.""" try: - # Read raw bytes ourselves so we can verify the pin against the - # exact bytes that get parsed (matches the on-the-wire semantics - # used by the URL/repo fetchers). content = path.read_text(encoding="utf-8") except Exception as e: return PolicyFetchResult( @@ -613,12 +314,7 @@ def _auto_discover( no_cache: bool = False, expected_hash: str | None = None, ) -> PolicyFetchResult: - """Auto-discover policy from org's .github repo. - - 1. Run git remote get-url origin - 2. Parse org from URL - 3. 
Fetch /.github/apm-policy.yml - """ + """Auto-discover policy from org's .github repo.""" org_and_host = _extract_org_from_git_remote(project_root) if org_and_host is None: return PolicyFetchResult( @@ -664,18 +360,14 @@ def _parse_remote_url(url: str) -> tuple[str, str] | None: """Parse a git remote URL into (org, host). Accepts SCP-style SSH URLs with any username (not just ``git@``), so - EMU/GHE deployments that use a non-``git`` SSH user - (e.g. ``enterprise-user@ghe.corp.com:org/repo.git``) parse correctly. - Also handles Azure DevOps SSH URLs which carry an extra ``v3/`` - path prefix (``git@ssh.dev.azure.com:v3///``). + EMU/GHE deployments that use a non-``git`` SSH user parse correctly. + Also handles Azure DevOps SSH URLs (``v3/`` path prefix). Returns None if URL can't be parsed. """ if not url: return None - # SCP-like SSH: @: -- any user, not just `git`. - # Closes #1159 for non-`git` SSH users (EMU, custom GHE accounts). scp_match = SCP_LIKE_RE.match(url) if scp_match: host = scp_match.group("host") @@ -685,16 +377,12 @@ def _parse_remote_url(url: str) -> tuple[str, str] | None: parts = [p for p in parts if p] if not parts: return None - # Azure DevOps SSH carries a leading 'v3/' segment that is - # NOT the org. The org is the second segment. 
if host == "ssh.dev.azure.com" and parts[0] == "v3" and len(parts) >= 2: return (parts[1], host) return (parts[0], host) except (ValueError, IndexError): return None - # HTTPS: https://github.com/owner/repo.git - # ADO: https://dev.azure.com/org/project/_git/repo if "://" in url: try: parsed = urlparse(url) @@ -719,7 +407,6 @@ def _fetch_from_url( source_label = f"url:{url}" cache_entry: _CacheEntry | None = None - # Use URL as cache key if not no_cache: cache_entry = _read_cache_entry(url, project_root, expected_hash=expected_hash) if cache_entry is not None and not cache_entry.stale: @@ -746,9 +433,6 @@ def _fetch_from_url( outcome="absent", ) if 300 <= resp.status_code < 400: - # Redirects are refused: a malicious or compromised origin - # could otherwise bounce us to an attacker-controlled host - # (SSRF / Referer leakage). Treat as fetch failure. location = resp.headers.get("Location", "") fetch_error = f"Refusing HTTP redirect ({resp.status_code}) from {url} to {location}" elif resp.status_code != 200: @@ -767,15 +451,10 @@ def _fetch_from_url( cache_entry, fetch_error, source_label, "cache_miss_fetch_fail" ) - # Garbage-response detection: body must be valid YAML mapping garbage_result = _detect_garbage(content, url, source_label, cache_entry) if garbage_result is not None: return garbage_result - # Hash pin verification (#827) -- BEFORE parse, on raw bytes off wire. - # A mismatch is a hard failure regardless of cache_entry availability: - # falling back to a "good" cache when the pin doesn't match would mask - # exactly the compromise this pin is designed to catch. 
mismatch = _verify_hash_pin(content, expected_hash, source_label) if mismatch is not None: return mismatch @@ -791,13 +470,7 @@ def _fetch_from_url( chain_refs = [url] actual_hash = _compute_hash_normalized(content, expected_hash) - _write_cache( - url, - policy, - project_root, - chain_refs=chain_refs, - raw_bytes_hash=actual_hash, - ) + _write_cache(url, policy, project_root, chain_refs=chain_refs, raw_bytes_hash=actual_hash) outcome = "empty" if _is_policy_empty(policy) else "found" return PolicyFetchResult( policy=policy, @@ -839,21 +512,17 @@ def _fetch_from_repo( content, error = _fetch_github_contents(repo_ref, "apm-policy.yml") if error: - # 404 = no policy, not an error if "404" in error: return PolicyFetchResult(source=source_label, outcome="absent") - # Fetch failed -- try stale cache fallback return _stale_fallback_or_error(cache_entry, error, source_label, "cache_miss_fetch_fail") if content is None: return PolicyFetchResult(source=source_label, outcome="absent") - # Garbage-response detection garbage_result = _detect_garbage(content, repo_ref, source_label, cache_entry) if garbage_result is not None: return garbage_result - # Hash pin verification (#827) -- BEFORE parse, on raw bytes off wire. 
mismatch = _verify_hash_pin(content, expected_hash, source_label) if mismatch is not None: return mismatch @@ -869,13 +538,7 @@ def _fetch_from_repo( chain_refs = [repo_ref] actual_hash = _compute_hash_normalized(content, expected_hash) - _write_cache( - repo_ref, - policy, - project_root, - chain_refs=chain_refs, - raw_bytes_hash=actual_hash, - ) + _write_cache(repo_ref, policy, project_root, chain_refs=chain_refs, raw_bytes_hash=actual_hash) outcome = "empty" if _is_policy_empty(policy) else "found" return PolicyFetchResult( policy=policy, @@ -886,6 +549,58 @@ def _fetch_from_repo( ) +# --------------------------------------------------------------------------- +# GitHub API helpers -- decomposed to keep _fetch_github_contents <= 8 returns +# --------------------------------------------------------------------------- + + +def _parse_github_repo_ref(repo_ref: str) -> tuple[str, str, str] | None: + """Parse repo_ref into (host, owner, repo_path), or None if invalid.""" + parts = repo_ref.split("/") + if len(parts) == 2: + return ("github.com", parts[0], parts[1]) + if len(parts) >= 3: + return (parts[0], parts[1], "/".join(parts[2:])) + return None + + +def _decode_github_content(data: dict, repo_ref: str) -> tuple[str | None, str | None]: + """Decode GitHub API response body to (content_str, error_str).""" + if data.get("encoding") == "base64" and data.get("content"): + content = base64.b64decode(data["content"]).decode("utf-8") + return content, None + if data.get("content"): + return data["content"], None + return None, f"Unexpected response format from {repo_ref}" + + +def _call_github_api( + api_url: str, + headers: dict, + repo_ref: str, +) -> tuple[str | None, str | None]: + """Call GitHub Contents API and return (content_str, error_str).""" + try: + resp = requests.get(api_url, headers=headers, timeout=10, allow_redirects=False) + except requests.exceptions.Timeout: + return None, f"Timeout fetching policy from {repo_ref}" + except 
requests.exceptions.ConnectionError: + return None, f"Connection error fetching policy from {repo_ref}" + except Exception as e: + return None, f"Error fetching policy from {repo_ref}: {e}" + + if resp.status_code == 404: + return None, "404: Policy file not found" + if resp.status_code == 403: + return None, f"403: Access denied to {repo_ref}" + if 300 <= resp.status_code < 400: + location = resp.headers.get("Location", "") + return None, (f"Refusing HTTP redirect ({resp.status_code}) from {api_url} to {location}") + if resp.status_code != 200: + return None, f"HTTP {resp.status_code} fetching policy from {repo_ref}" + return _decode_github_content(resp.json(), repo_ref) + + def _fetch_github_contents( repo_ref: str, file_path: str, @@ -894,20 +609,11 @@ def _fetch_github_contents( Returns (content_string, error_string). One will be None. """ - - # Parse repo_ref: "owner/repo" or "host/owner/repo" - parts = repo_ref.split("/") - if len(parts) == 2: - host = "github.com" - owner, repo = parts - elif len(parts) >= 3: - host = parts[0] - owner = parts[1] - repo = "/".join(parts[2:]) - else: + parsed = _parse_github_repo_ref(repo_ref) + if parsed is None: return None, f"Invalid repo reference: {repo_ref}" - # Build API URL + host, owner, repo = parsed if host == "github.com": api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{file_path}" else: @@ -918,34 +624,7 @@ def _fetch_github_contents( if token: headers["Authorization"] = f"token {token}" - try: - resp = requests.get(api_url, headers=headers, timeout=10, allow_redirects=False) - if resp.status_code == 404: - return None, "404: Policy file not found" - if resp.status_code == 403: - return None, f"403: Access denied to {repo_ref}" - if 300 <= resp.status_code < 400: - location = resp.headers.get("Location", "") - return None, ( - f"Refusing HTTP redirect ({resp.status_code}) from {api_url} to {location}" - ) - if resp.status_code != 200: - return None, f"HTTP {resp.status_code} fetching policy from 
{repo_ref}" - - data = resp.json() - if data.get("encoding") == "base64" and data.get("content"): - content = base64.b64decode(data["content"]).decode("utf-8") - return content, None - elif data.get("content"): - return data["content"], None - else: - return None, f"Unexpected response format from {repo_ref}" - except requests.exceptions.Timeout: - return None, f"Timeout fetching policy from {repo_ref}" - except requests.exceptions.ConnectionError: - return None, f"Connection error fetching policy from {repo_ref}" - except Exception as e: - return None, f"Error fetching policy from {repo_ref}: {e}" + return _call_github_api(api_url, headers, repo_ref) def _is_github_host(host: str) -> bool: @@ -981,384 +660,3 @@ def _get_token_for_host(host: str) -> str | None: or os.environ.get("GH_TOKEN") ) return None - - -# -- Cache ---------------------------------------------------------- - - -@dataclass -class _CacheEntry: - """Internal representation of a cached policy read.""" - - policy: ApmPolicy - source: str - age_seconds: int - stale: bool # True if past TTL (but within MAX_STALE_TTL) - chain_refs: list[str] = field(default_factory=list) - fingerprint: str = "" - raw_bytes_hash: str = "" # ":" of leaf bytes off wire (#827) - - -def _get_cache_dir(project_root: Path) -> Path: - """Get the policy cache directory. - - Path-security guard (#832): the resulting path is asserted to live - within ``project_root``. This catches the edge case where - ``apm_modules`` itself is a symlink that points outside the - project root -- a configuration that, while unusual, would let - cache reads/writes escape the project tree. - """ - # Resolve early so candidate inherits long-name form on Windows; - # without this, resolve() on a not-yet-existing candidate keeps - # 8.3 short names while the base resolves to long names (#886). 
- project_root = project_root.resolve() - base = project_root / "apm_modules" - candidate = base / POLICY_CACHE_DIR - # Resolve both ends and assert containment under ``project_root``, - # not under ``base`` -- otherwise a symlinked apm_modules pointing - # outside the project would resolve through the symlink on both - # sides and the check would silently pass. - try: - ensure_path_within(candidate, project_root) - except PathTraversalError: - raise PathTraversalError( # noqa: B904 - f"Policy cache path '{candidate}' resolves outside " - f"project root '{project_root}' -- refusing to read or " - "write the cache here." - ) - return candidate - - -def _cache_key(repo_ref: str) -> str: - """Generate a deterministic cache filename from repo ref.""" - return hashlib.sha256(repo_ref.encode()).hexdigest()[:16] - - -def _policy_to_dict(policy: ApmPolicy) -> dict: - """Serialize an ApmPolicy to a dict matching the YAML schema.""" - - def _opt_list(val: tuple[str, ...] | None) -> list | None: - return None if val is None else list(val) - - return { - "name": policy.name, - "version": policy.version, - "enforcement": policy.enforcement, - "fetch_failure": policy.fetch_failure, - "cache": {"ttl": policy.cache.ttl}, - "dependencies": { - "allow": _opt_list(policy.dependencies.allow), - "deny": _opt_list(policy.dependencies.deny), - "require": _opt_list(policy.dependencies.require), - "require_resolution": policy.dependencies.require_resolution, - "max_depth": policy.dependencies.max_depth, - }, - "mcp": { - "allow": _opt_list(policy.mcp.allow), - "deny": list(policy.mcp.deny), - "transport": { - "allow": _opt_list(policy.mcp.transport.allow), - }, - "self_defined": policy.mcp.self_defined, - "trust_transitive": policy.mcp.trust_transitive, - }, - "compilation": { - "target": { - "allow": _opt_list(policy.compilation.target.allow), - "enforce": policy.compilation.target.enforce, - }, - "strategy": { - "enforce": policy.compilation.strategy.enforce, - }, - "source_attribution": 
policy.compilation.source_attribution, - }, - "manifest": { - "required_fields": list(policy.manifest.required_fields), - "scripts": policy.manifest.scripts, - "content_types": policy.manifest.content_types, - }, - "unmanaged_files": { - "action": policy.unmanaged_files.action, - "directories": list(policy.unmanaged_files.directories or ()), - }, - } - - -def _serialize_policy(policy: ApmPolicy) -> str: - """Serialize an ApmPolicy to deterministic YAML for caching.""" - return yaml.dump( - _policy_to_dict(policy), default_flow_style=False, sort_keys=True - ) # yaml-io-exempt - - -def _policy_fingerprint(serialized: str) -> str: - """Compute a fingerprint of a serialized policy.""" - return hashlib.sha256(serialized.encode("utf-8")).hexdigest()[:32] - - -def _is_policy_empty(policy: ApmPolicy) -> bool: - """Return True if a policy has no actionable restrictions. - - An 'empty' policy is syntactically valid but imposes no constraints - beyond the permissive defaults. - """ - return ( - not policy.dependencies.effective_deny - and policy.dependencies.allow is None - and not policy.dependencies.effective_require - and not policy.mcp.deny - and policy.mcp.allow is None - and policy.mcp.transport.allow is None - and policy.compilation.target.allow is None - and not policy.manifest.required_fields - and policy.manifest.scripts == "allow" - and policy.manifest.content_types is None - and policy.unmanaged_files.effective_action == "ignore" - ) - - -def _stale_fallback_or_error( - cache_entry: _CacheEntry | None, - fetch_error_msg: str, - source_label: str, - outcome_on_miss: str, -) -> PolicyFetchResult: - """Return stale cache if available, otherwise error with given outcome.""" - if cache_entry is not None: - return PolicyFetchResult( - policy=cache_entry.policy, - source=cache_entry.source, - cached=True, - cache_stale=True, - cache_age_seconds=cache_entry.age_seconds, - fetch_error=fetch_error_msg, - outcome="cached_stale", - ) - return PolicyFetchResult( - 
error=fetch_error_msg, - source=source_label, - fetch_error=fetch_error_msg, - outcome=outcome_on_miss, - ) - - -def _detect_garbage( - content: str | None, - identifier: str, - source_label: str, - cache_entry: _CacheEntry | None, -) -> PolicyFetchResult | None: - """Detect garbage responses (200 OK with non-YAML body). - - Returns a PolicyFetchResult if the content is garbage (stale fallback - or garbage_response outcome), or None if the content looks parseable. - """ - if content is None: - return None - - try: - raw_data = yaml.safe_load(content) - except yaml.YAMLError: - msg = f"Response from {identifier} is not valid YAML" - if cache_entry is not None: - return PolicyFetchResult( - policy=cache_entry.policy, - source=cache_entry.source, - cached=True, - cache_stale=True, - cache_age_seconds=cache_entry.age_seconds, - fetch_error=msg, - outcome="cached_stale", - ) - return PolicyFetchResult( - error=msg + " (possible captive portal or redirect)", - source=source_label, - fetch_error=msg, - outcome="garbage_response", - ) - - if raw_data is not None and not isinstance(raw_data, dict): - msg = f"Response from {identifier} is not a YAML mapping" - if cache_entry is not None: - return PolicyFetchResult( - policy=cache_entry.policy, - source=cache_entry.source, - cached=True, - cache_stale=True, - cache_age_seconds=cache_entry.age_seconds, - fetch_error=msg, - outcome="cached_stale", - ) - return PolicyFetchResult( - error=msg, - source=source_label, - fetch_error=msg, - outcome="garbage_response", - ) - - return None # Not garbage -- proceed with normal parsing - - -def _read_cache_entry( - repo_ref: str, - project_root: Path, - ttl: int = DEFAULT_CACHE_TTL, - *, - expected_hash: str | None = None, -) -> _CacheEntry | None: - """Read cache entry with stale-awareness. 
- - Returns: - * ``_CacheEntry(stale=False)`` -- within TTL, ready for immediate use - * ``_CacheEntry(stale=True)`` -- past TTL but within MAX_STALE_TTL - * ``None`` -- no cache file, corrupt, past MAX_STALE_TTL, - or pin verification failure (#827). - - When *expected_hash* is provided the cached ``raw_bytes_hash`` is - compared against it; a mismatch invalidates the cache entry so the - caller falls through to a fresh fetch where the pin can be verified - against authoritative bytes off the wire. - """ - cache_dir = _get_cache_dir(project_root) - key = _cache_key(repo_ref) - policy_file = cache_dir / f"{key}.yml" - meta_file = cache_dir / f"{key}.meta.json" - - if not policy_file.exists() or not meta_file.exists(): - return None - - try: - meta = json.loads(meta_file.read_text(encoding="utf-8")) - - # Schema version check -- auto-invalidate on format change - if meta.get("schema_version") != CACHE_SCHEMA_VERSION: - return None - - cached_at = meta.get("cached_at", 0) - age = int(time.time() - cached_at) - - if age > MAX_STALE_TTL: - return None # Past MAX_STALE_TTL, unusable - - raw_bytes_hash = meta.get("raw_bytes_hash", "") or "" - - # Pin verification (#827): if the project pinned a hash and the - # cache was written without one (legacy entry) or with a different - # one, ignore the cache so the fetcher can verify the pin against - # fresh authoritative bytes. 
- if expected_hash is not None: - try: - exp_algo, exp_hex = _split_hash_pin(expected_hash) - expected_norm = f"{exp_algo}:{exp_hex}" - except ProjectPolicyConfigError: - return None - if raw_bytes_hash.lower() != expected_norm: - return None - - policy, _warnings = load_policy(policy_file) - - # Determine source label - if repo_ref.startswith("http://") or repo_ref.startswith("https://"): - source = f"url:{repo_ref}" - else: - source = f"org:{repo_ref}" - - return _CacheEntry( - policy=policy, - source=source, - age_seconds=age, - stale=age > ttl, - chain_refs=meta.get("chain_refs", [repo_ref]), - fingerprint=meta.get("fingerprint", ""), - raw_bytes_hash=raw_bytes_hash, - ) - except Exception: - return None - - -def _read_cache( - repo_ref: str, - project_root: Path, - ttl: int = DEFAULT_CACHE_TTL, -) -> PolicyFetchResult | None: - """Read policy from cache if still valid (within TTL). - - Legacy wrapper around ``_read_cache_entry`` for backward compatibility. - Returns None if cache miss, expired, or past MAX_STALE_TTL. - """ - entry = _read_cache_entry(repo_ref, project_root, ttl=ttl) - if entry is None or entry.stale: - return None - outcome = "empty" if _is_policy_empty(entry.policy) else "found" - return PolicyFetchResult( - policy=entry.policy, - source=entry.source, - cached=True, - cache_age_seconds=entry.age_seconds, - outcome=outcome, - ) - - -def _write_cache( - repo_ref: str, - policy: ApmPolicy, - project_root: Path, - *, - chain_refs: list[str] | None = None, - raw_bytes_hash: str | None = None, -) -> None: - """Write merged effective policy and metadata to cache atomically. - - Uses temp file + ``os.replace()`` to prevent torn writes from parallel - installs. Both the policy file and metadata sidecar are written - atomically and independently. 
- - The optional ``raw_bytes_hash`` (canonical ``":"``) is the - digest of the leaf bytes off the wire and is persisted to the meta - sidecar so subsequent cached reads can verify against the project's - pin without re-fetching (#827). - """ - cache_dir = _get_cache_dir(project_root) - cache_dir.mkdir(parents=True, exist_ok=True) - - key = _cache_key(repo_ref) - policy_file = cache_dir / f"{key}.yml" - meta_file = cache_dir / f"{key}.meta.json" - - serialized = _serialize_policy(policy) - fingerprint = _policy_fingerprint(serialized) - - # Unique tmp suffix to avoid collisions from parallel writers - uid = f"{os.getpid()}.{threading.get_ident()}" - - # Atomic write: policy file - tmp_policy = cache_dir / f"{key}.{uid}.yml.tmp" - try: - tmp_policy.write_text(serialized, encoding="utf-8") - os.replace(str(tmp_policy), str(policy_file)) - except OSError: - # Best-effort cleanup - try: # noqa: SIM105 - tmp_policy.unlink(missing_ok=True) - except OSError: - pass - return - - # Atomic write: metadata sidecar - meta = { - "repo_ref": repo_ref, - "cached_at": time.time(), - "chain_refs": chain_refs if chain_refs is not None else [repo_ref], - "schema_version": CACHE_SCHEMA_VERSION, - "fingerprint": fingerprint, - "raw_bytes_hash": raw_bytes_hash or "", - } - tmp_meta = cache_dir / f"{key}.{uid}.meta.json.tmp" - try: - tmp_meta.write_text(json.dumps(meta), encoding="utf-8") - os.replace(str(tmp_meta), str(meta_file)) - except OSError: - try: # noqa: SIM105 - tmp_meta.unlink(missing_ok=True) - except OSError: - pass diff --git a/src/apm_cli/policy/policy_checks.py b/src/apm_cli/policy/policy_checks.py index 2b9554fb3..2f0d8a32b 100644 --- a/src/apm_cli/policy/policy_checks.py +++ b/src/apm_cli/policy/policy_checks.py @@ -9,63 +9,71 @@ import logging from pathlib import Path - +from typing import TYPE_CHECKING + +from ._policy_checks_mcp import ( + _check_compilation_strategy as _check_compilation_strategy, +) +from ._policy_checks_mcp import ( + _check_compilation_target as 
_check_compilation_target, +) +from ._policy_checks_mcp import ( + _check_includes_explicit as _check_includes_explicit, +) +from ._policy_checks_mcp import ( + _check_mcp_allowlist as _check_mcp_allowlist, +) +from ._policy_checks_mcp import ( + _check_mcp_denylist as _check_mcp_denylist, +) +from ._policy_checks_mcp import ( + _check_mcp_self_defined as _check_mcp_self_defined, +) +from ._policy_checks_mcp import ( + _check_mcp_transport as _check_mcp_transport, +) +from ._policy_checks_mcp import ( + _check_required_manifest_fields as _check_required_manifest_fields, +) +from ._policy_checks_mcp import ( + _check_scripts_policy as _check_scripts_policy, +) +from ._policy_checks_mcp import ( + _check_source_attribution as _check_source_attribution, +) +from ._policy_checks_mcp import ( + _load_raw_apm_yml as _load_raw_apm_yml, +) from .models import CheckResult, CIAuditResult -_logger = logging.getLogger(__name__) - - -# -- Helpers ------------------------------------------------------- - - -def _load_raw_apm_yml(project_root: Path) -> dict | None: - """Load raw apm.yml as a dict for policy checks that inspect raw fields. +if TYPE_CHECKING: + from ..deps.lockfile import LockFile + from .schema import ( + ApmPolicy, + DependencyPolicy, + DependencyReference, + RegistrySourcePolicy, + UnmanagedFilesPolicy, + ) - This helper is called **after** :pymethod:`APMPackage.from_apm_yml` has - already succeeded in :func:`run_policy_checks`. The primary security - gate is ``from_apm_yml()`` -- if it fails, the audit aborts with a - ``manifest-parse`` check result and this function is never reached. +_logger = logging.getLogger(__name__) - Returning ``None`` here is therefore **defence-in-depth**: it covers - edge cases (TOCTOU race, transient I/O error) where the file becomes - unreadable between the two calls. Callers that receive ``None`` - gracefully skip supplementary raw-field checks (e.g. - ``compilation-target``, ``extensions-present``) rather than hard-failing. 
+# -- Sentinel for "manifest_includes not provided" in run_dependency_policy_checks -- +_INCLUDES_NOT_PROVIDED = object() - Returns ``None`` when the file is absent, unreadable, malformed YAML, - or not a mapping -- but logs a warning so the failure is visible - rather than silently swallowed. - """ - import yaml +_DEFAULT_GOVERNANCE_DIRS = [ + ".github/agents", + ".github/instructions", + ".github/hooks", + ".cursor/rules", + ".claude", + ".opencode", +] - apm_yml_path = project_root / "apm.yml" - if not apm_yml_path.exists(): - return None - try: - with open(apm_yml_path, encoding="utf-8") as f: - data = yaml.safe_load(f) - except FileNotFoundError: - # TOCTOU: file disappeared between exists() check and open(); normal condition. - return None - except yaml.YAMLError as exc: - _logger.warning("Malformed YAML in %s: %s", apm_yml_path, exc) - return None - except OSError as exc: - _logger.warning("Cannot read %s: %s", apm_yml_path, exc) - return None - except UnicodeDecodeError as exc: - _logger.warning("Cannot decode %s as UTF-8: %s", apm_yml_path, exc) - return None - if not isinstance(data, dict): - _logger.warning( - "apm.yml is not a YAML mapping (got %s) -- skipping raw-field checks", - type(data).__name__, - ) - return None - return data +_MAX_UNMANAGED_SCAN_FILES = 10_000 -# -- Individual policy checks -------------------------------------- +# -- Individual policy checks (dependency cluster) ------------------------- def _check_dependency_allowlist( @@ -192,7 +200,6 @@ def _check_required_packages_deployed( if pkg_name not in dep_names: continue # not in manifest -- check 3 handles this - # Find in lockfile by exact key match locked = lock_by_name.get(pkg_name) if not locked or not locked.deployed_files: not_deployed.append(pkg_name) @@ -297,394 +304,124 @@ def _check_transitive_depth( ) -def _check_mcp_allowlist( - mcp_deps: list, - policy: McpPolicy, -) -> CheckResult: - """Check 7: MCP server names match allow list.""" - from .matcher import 
check_mcp_allowed - - if policy.allow is None: - return CheckResult( - name="mcp-allowlist", - passed=True, - message="No MCP allow list configured", - ) - - violations: list[str] = [] - for mcp in mcp_deps: - allowed, reason = check_mcp_allowed(mcp.name, policy) - if not allowed and "not in allowed" in reason: - violations.append(f"{mcp.name}: {reason}") - - if not violations: - return CheckResult( - name="mcp-allowlist", - passed=True, - message="All MCP servers match allow list", - ) - return CheckResult( - name="mcp-allowlist", - passed=False, - message=f"{len(violations)} MCP server(s) not in allow list", - details=violations, - ) - - -def _check_mcp_denylist( - mcp_deps: list, - policy: McpPolicy, +def _check_registry_source( + deps: list[DependencyReference], + policy: RegistrySourcePolicy, + registries_map: dict[str, str] | None, ) -> CheckResult: - """Check 8: no MCP server matches deny list.""" - from .matcher import check_mcp_allowed + """Check registry source policy (require / allow_non_registry). - if not policy.deny: - return CheckResult( - name="mcp-denylist", - passed=True, - message="No MCP deny list configured", - ) + Fail-closed when a required registry name has no URL configured in + *registries_map* -- that means the registry source is unreachable by + definition and the install must not proceed. 
+ """ + check_name = "registry-source" + no_op = not policy.require and policy.allow_non_registry + if no_op: + return CheckResult(name=check_name, passed=True, message="No registry source policy") violations: list[str] = [] - for mcp in mcp_deps: - allowed, reason = check_mcp_allowed(mcp.name, policy) - if not allowed and "denied by pattern" in reason: - violations.append(f"{mcp.name}: {reason}") - if not violations: - return CheckResult( - name="mcp-denylist", - passed=True, - message="No MCP servers match deny list", - ) - return CheckResult( - name="mcp-denylist", - passed=False, - message=f"{len(violations)} MCP server(s) match deny list", - details=violations, - ) - - -def _check_mcp_transport( - mcp_deps: list, - policy: McpPolicy, -) -> CheckResult: - """Check 9: MCP transport values match policy allow list.""" - allowed_transports = policy.transport.allow - if allowed_transports is None: - return CheckResult( - name="mcp-transport", - passed=True, - message="No MCP transport restrictions configured", - ) - - violations: list[str] = [] - for mcp in mcp_deps: - if mcp.transport and mcp.transport not in allowed_transports: + # Fail-closed: required registry names must be configured. + for req_name in policy.require: + if not registries_map or req_name not in registries_map: violations.append( - f"{mcp.name}: transport '{mcp.transport}' not in allowed {allowed_transports}" + f"required registry '{req_name}' is not configured -- " + "add it to the 'registries:' block or via 'apm config set registry." 
+ f"{req_name}.url '" ) - if not violations: - return CheckResult( - name="mcp-transport", - passed=True, - message="All MCP transports comply with policy", - ) - return CheckResult( - name="mcp-transport", - passed=False, - message=f"{len(violations)} MCP transport violation(s)", - details=violations, - ) - - -def _check_mcp_self_defined( - mcp_deps: list, - policy: McpPolicy, -) -> CheckResult: - """Check 10: self-defined MCP servers comply with policy.""" - self_defined_policy = policy.self_defined - if self_defined_policy == "allow": - return CheckResult( - name="mcp-self-defined", - passed=True, - message="Self-defined MCP servers allowed", - ) - - self_defined = [m for m in mcp_deps if m.registry is False] - if not self_defined: - return CheckResult( - name="mcp-self-defined", - passed=True, - message="No self-defined MCP servers found", - ) - - details = [f"{m.name}: self-defined server" for m in self_defined] - if self_defined_policy == "deny": - return CheckResult( - name="mcp-self-defined", - passed=False, - message=f"{len(self_defined)} self-defined MCP server(s) denied by policy", - details=details, - ) - # warn -- pass but with details - return CheckResult( - name="mcp-self-defined", - passed=True, - message=f"{len(self_defined)} self-defined MCP server(s) (warn)", - details=details, - ) - - -def _check_compilation_target( - raw_yml: dict | None, - policy: CompilationPolicy, -) -> CheckResult: - """Check 11: compilation target matches policy.""" - enforce = policy.target.enforce - allow = policy.target.allow - - if not enforce and allow is None: - return CheckResult( - name="compilation-target", - passed=True, - message="No compilation target restrictions configured", - ) - - target = (raw_yml or {}).get("target") - if not target: - return CheckResult( - name="compilation-target", - passed=True, - message="No compilation target set in manifest", - ) - - # Normalize target to a list for uniform checking - target_list = target if isinstance(target, list) 
else [target] + for dep in deps: + key = dep.get_canonical_dependency_string() + is_registry = getattr(dep, "source", None) == "registry" + registry_name = getattr(dep, "registry_name", None) - if enforce: - if enforce not in target_list: - return CheckResult( - name="compilation-target", - passed=False, - message=f"Enforced target '{enforce}' not present in {target_list}", - details=[f"target: {target}, enforced: {enforce}"], - ) - elif allow is not None: - allow_set = set(allow) if isinstance(allow, (list, tuple)) else {allow} - disallowed = [t for t in target_list if t not in allow_set] - if disallowed: - return CheckResult( - name="compilation-target", - passed=False, - message=f"Target(s) {disallowed} not in allowed list {sorted(allow_set)}", - details=[f"target: {target}, allowed: {sorted(allow_set)}"], + if not policy.allow_non_registry and not is_registry: + violations.append( + f"{key}: non-registry source not permitted (policy requires registry sources only)" ) + continue - return CheckResult( - name="compilation-target", - passed=True, - message="Compilation target compliant", - ) - - -def _check_compilation_strategy( - raw_yml: dict | None, - policy: CompilationPolicy, -) -> CheckResult: - """Check 12: compilation strategy matches policy.""" - enforce = policy.strategy.enforce - if not enforce: - return CheckResult( - name="compilation-strategy", - passed=True, - message="No compilation strategy enforced", - ) - - compilation = (raw_yml or {}).get("compilation", {}) - strategy = compilation.get("strategy") if isinstance(compilation, dict) else None - if not strategy: - return CheckResult( - name="compilation-strategy", - passed=True, - message="No compilation strategy set in manifest", - ) + if policy.require and is_registry and registry_name not in policy.require: + violations.append( + f"{key}: sourced from registry '{registry_name}' " + f"but policy requires one of {sorted(policy.require)}" + ) - if strategy != enforce: + if violations: return 
CheckResult( - name="compilation-strategy", + name=check_name, passed=False, - message=f"Strategy '{strategy}' does not match enforced '{enforce}'", - details=[f"strategy: {strategy}, enforced: {enforce}"], + message=f"{len(violations)} registry source violation(s)", + details=violations, ) return CheckResult( - name="compilation-strategy", + name=check_name, passed=True, - message="Compilation strategy compliant", - ) - - -def _check_source_attribution( - raw_yml: dict | None, - policy: CompilationPolicy, -) -> CheckResult: - """Check 13: source attribution enabled if policy requires.""" - if not policy.source_attribution: - return CheckResult( - name="source-attribution", - passed=True, - message="Source attribution not required by policy", - ) - - compilation = (raw_yml or {}).get("compilation", {}) - attribution = compilation.get("source_attribution") if isinstance(compilation, dict) else None - if attribution is True: - return CheckResult( - name="source-attribution", - passed=True, - message="Source attribution enabled", - ) - return CheckResult( - name="source-attribution", - passed=False, - message="Source attribution required by policy but not enabled in manifest", - details=["Set compilation.source_attribution: true in apm.yml"], + message="All dependencies satisfy registry source policy", ) -def _check_required_manifest_fields( - raw_yml: dict | None, - policy: ManifestPolicy, +def _check_pinned_constraints( + deps: list[DependencyReference], + policy: DependencyPolicy, + direct_dep_keys: set[str] | None = None, ) -> CheckResult: - """Check 14: all required fields are present with non-empty values.""" - if not policy.required_fields: - return CheckResult( - name="required-manifest-fields", - passed=True, - message="No required manifest fields configured", - ) - - data = raw_yml or {} - missing: list[str] = [] - for field_name in policy.required_fields: - value = data.get(field_name) - if not value: # None, empty string, missing - 
missing.append(field_name) - - if not missing: - return CheckResult( - name="required-manifest-fields", - passed=True, - message="All required manifest fields present", - ) - return CheckResult( - name="required-manifest-fields", - passed=False, - message=f"{len(missing)} required manifest field(s) missing", - details=missing, - ) - - -_INCLUDES_NOT_PROVIDED = object() + """Check: every direct dep declares a bounded constraint. + Skipped (passes vacuously) when + ``policy.require_pinned_constraint`` is ``False`` -- the default. -def _check_includes_explicit( - manifest_includes, - policy: ManifestPolicy, -) -> CheckResult: - """Check: manifest declares an explicit ``includes:`` list when policy requires it. + Operates on the **declared** constraint (``dep.reference``), not + the resolved one, so authors learn before the install completes + that a moving ref slipped past review. - ``manifest_includes`` is the parsed value of the manifest's ``includes:`` - field as exposed by :class:`APMPackage` -- one of ``None`` (field - absent), the literal string ``"auto"``, or a list of repo-relative - path strings. + When ``direct_dep_keys`` is provided, the check is restricted to + direct dependencies only -- transitives are excluded, since the + consumer cannot rewrite a constraint declared in a transitive + package's own manifest. - Violation when ``policy.require_explicit_includes`` is True and - ``manifest_includes`` is ``None`` or ``"auto"``. + See ``_constraint_pinning.py`` for classification rules. """ - if not policy.require_explicit_includes: - return CheckResult( - name="explicit-includes", - passed=True, - message="Explicit includes not required by policy", - ) - - if manifest_includes is None: - return CheckResult( - name="explicit-includes", - passed=False, - message=( - "Policy requires explicit 'includes:' paths but none are " - "declared. Add 'includes: [, ...]' to apm.yml with " - "the paths you intend to publish." 
- ), - details=[ - "includes: , require_explicit_includes: true", - ], - ) + from ._constraint_pinning import classify_unbounded_reason, humanize_reason - if manifest_includes == "auto": + check_name = "dependency-pinned-constraint" + if not policy.require_pinned_constraint: return CheckResult( - name="explicit-includes", - passed=False, - message=( - "Policy requires explicit 'includes:' paths but manifest " - "uses 'includes: auto'. Replace with an explicit list of " - "paths." - ), - details=[ - "includes: 'auto', require_explicit_includes: true", - ], + name=check_name, + passed=True, + message="Pinned-constraint requirement disabled", ) - return CheckResult( - name="explicit-includes", - passed=True, - message="Manifest declares explicit includes paths", - ) - + violations: list[str] = [] + for dep in deps: + if direct_dep_keys is not None and dep.get_unique_key() not in direct_dep_keys: + continue + reason = classify_unbounded_reason(dep) + if reason is None: + continue + key = dep.get_canonical_dependency_string() + hint = humanize_reason(reason, dep) + violations.append(f"{key}: {hint}") -def _check_scripts_policy( - raw_yml: dict | None, - policy: ManifestPolicy, -) -> CheckResult: - """Check 15: scripts section absent if policy denies it.""" - if policy.scripts != "deny": + if not violations: return CheckResult( - name="scripts-policy", + name=check_name, passed=True, - message="Scripts allowed by policy", + message="All dependencies use pinned constraints", ) - scripts = (raw_yml or {}).get("scripts") - if scripts: - return CheckResult( - name="scripts-policy", - passed=False, - message="Scripts section present but denied by policy", - details=list(scripts.keys()) if isinstance(scripts, dict) else ["scripts"], - ) return CheckResult( - name="scripts-policy", - passed=True, - message="No scripts section (compliant with deny policy)", + name=check_name, + passed=False, + message=( + f"{len(violations)} dependency(ies) use unbounded constraints " + "(hint: 
pin to a semver range, literal tag, or SHA)" + ), + details=violations, ) -_DEFAULT_GOVERNANCE_DIRS = [ - ".github/agents", - ".github/instructions", - ".github/hooks", - ".cursor/rules", - ".claude", - ".opencode", -] - - -_MAX_UNMANAGED_SCAN_FILES = 10_000 - - def _check_unmanaged_files( project_root: Path, lock: LockFile | None, @@ -700,7 +437,6 @@ def _check_unmanaged_files( dirs = policy.directories if policy.directories else _DEFAULT_GOVERNANCE_DIRS - # Build set of deployed files AND directory prefixes from lockfile deployed: set = set() deployed_dir_prefixes: list = [] if lock: @@ -763,7 +499,6 @@ def _check_unmanaged_files( details=unmanaged, ) - # action == "deny" return CheckResult( name="unmanaged-files", passed=False, @@ -772,129 +507,6 @@ def _check_unmanaged_files( ) -def _check_registry_source( - deps: list[DependencyReference], - policy: RegistrySourcePolicy, - registries_map: dict[str, str] | None, -) -> CheckResult: - """Check registry source policy (require / allow_non_registry). - - Fail-closed when a required registry name has no URL configured in - *registries_map* — that means the registry source is unreachable by - definition and the install must not proceed. - """ - check_name = "registry-source" - no_op = not policy.require and policy.allow_non_registry - if no_op: - return CheckResult(name=check_name, passed=True, message="No registry source policy") - - violations: list[str] = [] - - # Fail-closed: required registry names must be configured. - for req_name in policy.require: - if not registries_map or req_name not in registries_map: - violations.append( - f"required registry '{req_name}' is not configured — " - "add it to the 'registries:' block or via 'apm config set registry." 
- f"{req_name}.url '" - ) - - for dep in deps: - key = dep.get_canonical_dependency_string() - is_registry = getattr(dep, "source", None) == "registry" - registry_name = getattr(dep, "registry_name", None) - - if not policy.allow_non_registry and not is_registry: - violations.append( - f"{key}: non-registry source not permitted (policy requires registry sources only)" - ) - continue - - if policy.require and is_registry and registry_name not in policy.require: - violations.append( - f"{key}: sourced from registry '{registry_name}' " - f"but policy requires one of {sorted(policy.require)}" - ) - - if violations: - return CheckResult( - name=check_name, - passed=False, - message=f"{len(violations)} registry source violation(s)", - details=violations, - ) - return CheckResult( - name=check_name, - passed=True, - message="All dependencies satisfy registry source policy", - ) - - -def _check_pinned_constraints( - deps: list[DependencyReference], - policy: DependencyPolicy, - direct_dep_keys: set[str] | None = None, -) -> CheckResult: - """Check: every direct dep declares a bounded constraint. - - Skipped (passes vacuously) when - ``policy.require_pinned_constraint`` is ``False`` -- the default. - - Operates on the **declared** constraint (``dep.reference``), not - the resolved one, so authors learn before the install completes - that a moving ref slipped past review. - - When ``direct_dep_keys`` is provided, the check is restricted to - direct dependencies only -- transitives are excluded, since the - consumer cannot rewrite a constraint declared in a transitive - package's own manifest. Callers that have direct-vs-transitive - context (the install pipeline gate, the target-aware re-check, - and the install preflight) should always pass it. When ``None`` - (legacy dep-only seam, or the audit wrapper that already iterates - direct-only manifest deps) the check falls back to evaluating - every dep in ``deps``. - - See ``_constraint_pinning.py`` for classification rules. 
- """ - from ._constraint_pinning import classify_unbounded_reason, humanize_reason - - check_name = "dependency-pinned-constraint" - if not policy.require_pinned_constraint: - return CheckResult( - name=check_name, - passed=True, - message="Pinned-constraint requirement disabled", - ) - - violations: list[str] = [] - for dep in deps: - if direct_dep_keys is not None and dep.get_unique_key() not in direct_dep_keys: - continue - reason = classify_unbounded_reason(dep) - if reason is None: - continue - key = dep.get_canonical_dependency_string() - hint = humanize_reason(reason, dep) - violations.append(f"{key}: {hint}") - - if not violations: - return CheckResult( - name=check_name, - passed=True, - message="All dependencies use pinned constraints", - ) - - return CheckResult( - name=check_name, - passed=False, - message=( - f"{len(violations)} dependency(ies) use unbounded constraints " - "(hint: pin to a semver range, literal tag, or SHA)" - ), - details=violations, - ) - - # -- Aggregate runners --------------------------------------------- @@ -934,28 +546,22 @@ def run_dependency_policy_checks( ``policy.mcp``. effective_target: The post-targets-phase compilation target string, or ``None``. - When ``None`` target/compilation checks are **skipped** (they - belong to the separate W2-target-aware call). + When ``None`` target/compilation checks are **skipped**. fetch_outcome: - Human-readable label for diagnostic context (e.g. - ``"cached"``, ``"fetched"``). Currently informational only. + Human-readable label for diagnostic context. Currently + informational only. fail_fast: Stop after the first failing check (default ``True``). manifest_includes: The parsed value of the manifest's ``includes:`` field (``None``, ``"auto"``, or a list of paths). When omitted, - the ``explicit-includes`` check is skipped -- callers that - do not have manifest information available (e.g. dep-only - seams) can leave it unset. + the ``explicit-includes`` check is skipped. 
direct_dep_keys: Optional set of ``DependencyReference.get_unique_key()`` for the direct (manifest-declared) deps. When supplied, the ``require_pinned_constraint`` check only evaluates direct - deps -- transitive entries are excluded because the consumer - cannot rewrite a constraint declared inside a transitive - package's own manifest. When ``None`` (legacy dep-only seam - and the audit wrapper that already iterates direct-only - manifest deps) every dep in ``deps_to_install`` is evaluated. + deps -- transitive entries are excluded. When ``None`` every + dep in ``deps_to_install`` is evaluated. Returns ------- @@ -963,17 +569,6 @@ def run_dependency_policy_checks( Contains individual :class:`CheckResult` entries. The caller decides how to map ``enforcement`` level (block vs warn) onto these results. - - Notes - ----- - ``require_resolution: project-wins`` semantics (rubber-duck I7): - version-pin mismatches are downgraded to warnings; missing required - packages still block; inherited org deny still wins. This is - handled inside ``_check_required_package_version`` which already - reads ``policy.dependencies.require_resolution``. - - Does **not** load ``apm.yml`` from disk -- the caller supplies the - resolved dep set directly. """ result = CIAuditResult() deps_list = list(deps_to_install) @@ -1003,16 +598,8 @@ def _run(check: CheckResult) -> bool: # stays within the max-returns threshold. remaining_checks: list[CheckResult] = [ _check_registry_source(deps_list, policy.registry_source, registries), - # Pinned-constraint: property check on declared refs. Cheap - # (O(N) string classification, no I/O) so it always runs. - # When direct_dep_keys is supplied, restrict to direct deps -- a - # transitive package with an unbounded ref in its own manifest is - # not actionable by the consumer (see Copilot review on #1494). _check_pinned_constraints(deps_list, policy.dependencies, direct_dep_keys), ] - # MCP checks -- when mcp_deps is None (not provided), skip entirely. 
- # When mcp_deps is an empty list (provided but no MCP deps), still - # run MCP checks so they report "no X configured" for completeness. if mcp_deps is not None: remaining_checks += [ _check_mcp_allowlist(mcp_list, policy.mcp), @@ -1020,7 +607,6 @@ def _run(check: CheckResult) -> bool: _check_mcp_transport(mcp_list, policy.mcp), _check_mcp_self_defined(mcp_list, policy.mcp), ] - # Target / compilation + manifest tail checks. if effective_target is not None: synthetic_yml = {"target": effective_target} remaining_checks.append(_check_compilation_target(synthetic_yml, policy.compilation)) @@ -1030,13 +616,6 @@ def _run(check: CheckResult) -> bool: if _run(check): return result - # NOTE: compilation strategy, source attribution, manifest fields, - # scripts policy, and unmanaged files are disk-level / manifest-level - # concerns. They are NOT included in the resolved-dep seam because - # the install pipeline does not have the raw manifest at this point - # and they are already covered by the full ``run_policy_checks`` - # wrapper that ``apm audit --ci`` calls. - return result @@ -1064,7 +643,6 @@ def run_policy_checks( result = CIAuditResult() - # Load manifest apm_yml_path = project_root / "apm.yml" if not apm_yml_path.exists(): return result @@ -1073,62 +651,43 @@ def run_policy_checks( if manifest is None: return result - # Load lockfile (optional -- some checks work without it) lockfile_path = get_lockfile_path(project_root) lock = LockFile.read(lockfile_path) if lockfile_path.exists() else None - - # Load raw YAML for field-level checks raw_yml = _load_raw_apm_yml(project_root) - # Get dependencies from manifest (disk view) apm_deps = manifest.get_apm_dependencies() mcp_deps = manifest.get_mcp_dependencies() - # Read effective target from raw manifest for the full-project path - # NOTE: the wrapper does NOT pass effective_target to the dep seam. - # Target checks run as disk-level checks below (reading raw_yml), - # because the wrapper has the on-disk manifest. 
The install pipeline - # will pass effective_target directly (W2-target-aware). - - # -- Delegate dependency + MCP checks to shared seam --------------- dep_result = run_dependency_policy_checks( apm_deps, lockfile=lock, policy=policy, mcp_deps=mcp_deps, - # effective_target=None: target checks handled below from raw_yml fail_fast=fail_fast, manifest_includes=manifest.includes, registries=getattr(manifest, "registries", None), ) result.checks.extend(dep_result.checks) - # Early exit if dep checks already failed in fail-fast mode if fail_fast and not dep_result.passed: return result def _run(check: CheckResult) -> bool: - """Append check and return True if fail-fast should stop.""" result.checks.append(check) return fail_fast and not check.passed - # -- Disk-level checks that only apply to full-project audits -- - - # Compilation checks (11-13) -- all run from raw_yml in wrapper - if _run(_check_compilation_target(raw_yml, policy.compilation)): - return result - if _run(_check_compilation_strategy(raw_yml, policy.compilation)): - return result - if _run(_check_source_attribution(raw_yml, policy.compilation)): - return result - - # Manifest checks (14-15) - if _run(_check_required_manifest_fields(raw_yml, policy.manifest)): - return result - if _run(_check_scripts_policy(raw_yml, policy.manifest)): - return result + # Disk-level checks: compilation (11-13), manifest (14-15), unmanaged (16). + # Eager evaluation is safe -- these check functions read dict/policy only, + # no side effects except _check_unmanaged_files which must stay last. 
+ for check in [ + _check_compilation_target(raw_yml, policy.compilation), + _check_compilation_strategy(raw_yml, policy.compilation), + _check_source_attribution(raw_yml, policy.compilation), + _check_required_manifest_fields(raw_yml, policy.manifest), + _check_scripts_policy(raw_yml, policy.manifest), + ]: + if _run(check): + return result - # Unmanaged files check (16) _run(_check_unmanaged_files(project_root, lock, policy.unmanaged_files)) - return result From 3434eb3f850a38dc408b971e3da47d5c1ee1dc54 Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Sat, 6 Jun 2026 08:43:54 +0200 Subject: [PATCH 20/21] refactor: enforce Stage 2 complexity + file-length thresholds (#1078) Final Strangler Stage 2 commit: flip the guardrails now that all violations are fixed (commits for integration, deps, commands, models, core/adapters, policy, marketplace landed earlier in this PR). pyproject.toml [tool.ruff.lint]: tighten to Stage 2 targets max-statements 200->120, max-args 15->12, max-branches 60->40, max-returns 12->8, mccabe max-complexity 50->35; refresh the stale Stage-1 roadmap comments. .github/workflows/ci.yml: file-length guardrail MAX_LINES 2100->800; refresh the stale comment. CHANGELOG.md: record the threshold + guardrail tightening under Unreleased. tests/integration/test_deps_resolver_{phase3b,resolution}.py: the new max-args=12 surfaced a 15-arg `_make_dep_ref` test factory in both files. Dropped the never-passed `explicit_scheme` param and grouped the three cohesive Azure DevOps coordinates into a single `ado` triple (15->12 args) - a parameter-object simplification, not an arg-count dodge. Verification (CI-mirror, all green): ruff check src/ tests/; ruff format --check; file-length guardrail (no src file >800); R0801 10.00/10; auth-signals; YAML and relative_to guards; full unit+acceptance 16605 passed; targeted integration green. Whole-src >800 backlog is 0. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/workflows/ci.yml | 2 +- CHANGELOG.md | 5 +++++ pyproject.toml | 16 ++++++++-------- .../integration/test_deps_resolver_phase3b.py | 18 ++++++++---------- .../test_deps_resolver_resolution.py | 18 ++++++++---------- 5 files changed, 30 insertions(+), 29 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9981343a3..6fd4ac918 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,7 +54,7 @@ jobs: run: | # Ruff has no max-module-lines rule. This check prevents new files from # exceeding the current worst case. Tighten the threshold over time. - MAX_LINES=2100 # Stage 1 (was 2450); target 1400 deferred to Stage 2 + MAX_LINES=800 # Stage 2 (issue #1078); Stage 1 was 2100 VIOLATIONS=$(find src/ -name '*.py' -print0 | xargs -0 -I{} awk -v max="$MAX_LINES" \ 'END { if (NR > max) printf "%s: %d lines (max %d)\n", FILENAME, NR, max }' {}) if [ -n "$VIOLATIONS" ]; then diff --git a/CHANGELOG.md b/CHANGELOG.md index 795425092..585980c2f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 making token misconfiguration diagnosable without adding noise to normal output. Enable with ``apm --verbose`` or ``APM_LOG_LEVEL=DEBUG``. (by @danielmeppiel, closes #935, #1664) +- Tightened Stage 2 code-complexity thresholds (`max-statements` 120, + `max-branches` 40, `max-complexity` 35, `max-args` 12, `max-returns` 8) and + lowered the source file-length guardrail to 800 lines, splitting the remaining + oversized `policy/` and `marketplace/` modules into focused submodules with no + behaviour change. (#1681) ## [0.18.0] - 2026-06-04 diff --git a/pyproject.toml b/pyproject.toml index 0d0fc3d8d..9d1d4aa3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,16 +94,16 @@ ignore = [ # High initial thresholds set just above current codebase maximums. 
# Prevents new code from exceeding the worst existing violations. # Tighten these over time via dedicated refactoring PRs. -# Stage 1 thresholds (PR #1464, issue #1077). -# Roadmap: Stage 2 targets max-complexity<=20, max-branches<=25 (McCabe standard). -max-statements = 200 # Stage 1 (was 275) -max-args = 15 # Stage 1 (was 18) -max-branches = 60 # Stage 1 (was 115) -max-returns = 12 # Stage 1 (was 18) +# Stage 2 thresholds (issue #1078). Stage 1 (#1077/#1464) values shown for history. +# A future stage may tighten further toward the McCabe standard (<=20 complexity). +max-statements = 120 # Stage 2 (Stage 1 was 200) +max-args = 12 # Stage 2 (Stage 1 was 15) +max-branches = 40 # Stage 2 (Stage 1 was 60) +max-returns = 8 # Stage 2 (Stage 1 was 12) [tool.ruff.lint.mccabe] -# Stage 1 (was 100). Stage 2 target: <=20 (McCabe industry standard). -max-complexity = 50 +# Stage 2 (issue #1078). Stage 1 was 50. A future stage may target <=20. +max-complexity = 35 [tool.ruff.lint.per-file-ignores] # Subprocess calls are intentional in a CLI tool diff --git a/tests/integration/test_deps_resolver_phase3b.py b/tests/integration/test_deps_resolver_phase3b.py index 228c1c9bd..80bd0a16b 100644 --- a/tests/integration/test_deps_resolver_phase3b.py +++ b/tests/integration/test_deps_resolver_phase3b.py @@ -42,16 +42,17 @@ def _make_dep_ref( is_local: bool = False, local_path: str | None = None, is_insecure: bool = False, - ado_organization: str | None = None, - ado_project: str | None = None, - ado_repo: str | None = None, alias: str | None = None, + ado: tuple[str | None, str | None, str | None] = (None, None, None), is_parent_repo_inheritance: bool = False, - explicit_scheme: str | None = None, ) -> Any: - """Build a DependencyReference instance without network calls.""" + """Build a DependencyReference instance without network calls. + + ``ado`` groups the Azure DevOps (organization, project, repo) triple. 
+ """ from apm_cli.models.dependency.reference import DependencyReference + ado_organization, ado_project, ado_repo = ado return DependencyReference( repo_url=repo_url, host=host, @@ -62,12 +63,11 @@ def _make_dep_ref( is_local=is_local, local_path=local_path, is_insecure=is_insecure, + alias=alias, ado_organization=ado_organization, ado_project=ado_project, ado_repo=ado_repo, - alias=alias, is_parent_repo_inheritance=is_parent_repo_inheritance, - explicit_scheme=explicit_scheme, ) @@ -176,9 +176,7 @@ def test_azure_devops_returns_false(self) -> None: dep = _make_dep_ref( host="dev.azure.com", repo_url="org/project/repo", - ado_organization="org", - ado_project="project", - ado_repo="repo", + ado=("org", "project", "repo"), ) assert dl._is_generic_dependency_host(dep) is False diff --git a/tests/integration/test_deps_resolver_resolution.py b/tests/integration/test_deps_resolver_resolution.py index 228c1c9bd..80bd0a16b 100644 --- a/tests/integration/test_deps_resolver_resolution.py +++ b/tests/integration/test_deps_resolver_resolution.py @@ -42,16 +42,17 @@ def _make_dep_ref( is_local: bool = False, local_path: str | None = None, is_insecure: bool = False, - ado_organization: str | None = None, - ado_project: str | None = None, - ado_repo: str | None = None, alias: str | None = None, + ado: tuple[str | None, str | None, str | None] = (None, None, None), is_parent_repo_inheritance: bool = False, - explicit_scheme: str | None = None, ) -> Any: - """Build a DependencyReference instance without network calls.""" + """Build a DependencyReference instance without network calls. + + ``ado`` groups the Azure DevOps (organization, project, repo) triple. 
+ """ from apm_cli.models.dependency.reference import DependencyReference + ado_organization, ado_project, ado_repo = ado return DependencyReference( repo_url=repo_url, host=host, @@ -62,12 +63,11 @@ def _make_dep_ref( is_local=is_local, local_path=local_path, is_insecure=is_insecure, + alias=alias, ado_organization=ado_organization, ado_project=ado_project, ado_repo=ado_repo, - alias=alias, is_parent_repo_inheritance=is_parent_repo_inheritance, - explicit_scheme=explicit_scheme, ) @@ -176,9 +176,7 @@ def test_azure_devops_returns_false(self) -> None: dep = _make_dep_ref( host="dev.azure.com", repo_url="org/project/repo", - ado_organization="org", - ado_project="project", - ado_repo="repo", + ado=("org", "project", "repo"), ) assert dl._is_generic_dependency_host(dep) is False From 0f29213bd4505fa2f8de6379570dcc1566b1fd73 Mon Sep 17 00:00:00 2001 From: Sergio Sisternes Date: Tue, 9 Jun 2026 15:41:59 +0100 Subject: [PATCH 21/21] fix(lint): dedupe _skill_subset_name_filter to clear R0801 The #1709 port left two copies of the subset-normalization helper: a module function in skill_deploy.py and an identical SkillIntegrator staticmethod. CI's pylint R0801 guard (evaluated on the merge commit) flagged the 15-line duplication. Make the staticmethod delegate to the skill_deploy module function, matching the file's delegator pattern. R0801 now exits 0; ruff/format clean; 378 targeted tests pass. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/apm_cli/integration/skill_integrator.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/src/apm_cli/integration/skill_integrator.py b/src/apm_cli/integration/skill_integrator.py index 870256329..d5340d3b0 100644 --- a/src/apm_cli/integration/skill_integrator.py +++ b/src/apm_cli/integration/skill_integrator.py @@ -158,21 +158,7 @@ def _copy_promoted_skill_tree(sub_skill_path: Path, target: Path) -> None: @staticmethod def _skill_subset_name_filter(skill_subset: tuple[str, ...] | None) -> set[str] | None: """Return promotion filter tokens for --skill subset values.""" - if not skill_subset: - return None - - name_filter: set[str] = set() - for skill_name in skill_subset: - raw_name = str(skill_name).strip() - if not raw_name: - continue - normalized_path = raw_name.replace("\\", "/") - leaf_name = Path(normalized_path).name - name_filter.add(raw_name) - name_filter.add(normalized_path) - if leaf_name: - name_filter.add(leaf_name) - return name_filter or None + return _skill_deploy._skill_subset_name_filter(skill_subset) @staticmethod def _promote_sub_skills(