From aca8c8386206e440cf08bcab50ff484e7d494101 Mon Sep 17 00:00:00 2001 From: Stevengre Date: Mon, 1 Jun 2026 18:24:34 +0800 Subject: [PATCH] Add progressive depth-halving policy to Prover.advance_proof Implements runtimeverification/k#4924: factor kontrol's `--per-depth-timeout` mechanism out of tool-specific code and into a generic policy in `Prover.advance_proof`. When `per_depth_timeout` is set, each step gets a stall window of `current_depth * per_depth_timeout` seconds to commit. A step that exceeds its window is interrupted, the step depth is halved (floor 1), and the node is retried at the shallower depth; at the minimum depth the proof stops and stays pending. Provers that do not expose a tunable depth are unaffected, and the path without `per_depth_timeout` is unchanged. - Prover: generic no-op hooks get_step_depth/set_step_depth/interrupt - APRProver: overrides them (execute_depth + kcfg_explore.interrupt) - KCFGExplore/CTermSymbolic/KoreClient/JsonRpc*/Transport: interrupt() that force-unblocks an in-flight single-socket request and reconnects - Unit tests covering halving, stop-at-floor, fast-path, and disabled --- pyk/src/pyk/cterm/symbolic.py | 4 + pyk/src/pyk/kcfg/explore.py | 4 + pyk/src/pyk/kore/rpc.py | 46 +++++++ pyk/src/pyk/proof/proof.py | 112 ++++++++++++++--- pyk/src/pyk/proof/reachability.py | 9 ++ pyk/src/tests/unit/test_advance_proof.py | 151 +++++++++++++++++++++++ 6 files changed, 306 insertions(+), 20 deletions(-) create mode 100644 pyk/src/tests/unit/test_advance_proof.py diff --git a/pyk/src/pyk/cterm/symbolic.py b/pyk/src/pyk/cterm/symbolic.py index cef9c9f8dbf..5929af07811 100644 --- a/pyk/src/pyk/cterm/symbolic.py +++ b/pyk/src/pyk/cterm/symbolic.py @@ -93,6 +93,10 @@ def kast_to_kore(self, kinner: KInner) -> Pattern: def kore_to_kast(self, pattern: Pattern) -> KInner: return kore_to_kast(self._definition, pattern) + def interrupt(self) -> None: + """Abort a backend request currently in flight on another thread; see `KoreClient.interrupt`.""" + self._kore_client.interrupt() + def execute( self, cterm: CTerm, diff --git a/pyk/src/pyk/kcfg/explore.py b/pyk/src/pyk/kcfg/explore.py index 5562132e3e2..9cae78177a5 100644 --- a/pyk/src/pyk/kcfg/explore.py +++ b/pyk/src/pyk/kcfg/explore.py @@ -59,6 +59,10 @@ def _pretty_printer(self) -> PrettyPrinter: def pretty_print(self, kinner: KInner) -> str: return self._pretty_printer.print(kinner) + def interrupt(self) -> None: + """Abort a backend request currently in flight on another thread; see `KoreClient.interrupt`.""" + self.cterm_symbolic.interrupt() + def _extract_rule_labels(self, _logs: tuple[LogEntry, ...]) -> list[str]: _rule_lines = [] for node_log in _logs: diff --git a/pyk/src/pyk/kore/rpc.py b/pyk/src/pyk/kore/rpc.py index 77ef641a8f3..0aba2d6f9c1 100644 --- a/pyk/src/pyk/kore/rpc.py +++ b/pyk/src/pyk/kore/rpc.py @@ -76,6 +76,15 @@ def __exit__(self, *args: Any) -> None: @abstractmethod def close(self) -> None: ... + def interrupt(self) -> None: + """Abort a request that is currently in flight on another thread. + + After `interrupt()` returns, a thread blocked in `request()` must raise promptly and + the transport must remain usable for subsequent requests. The default implementation + does nothing; transports backed by an interruptible connection should override it. + """ + ... + @abstractmethod def _request(self, req: str) -> str: ... @@ -92,6 +101,7 @@ class TransportType(Enum): class SingleSocketTransport(Transport): _host: str _port: int + _timeout: int | None _sock: socket.socket _file: IO[str] @@ -104,6 +114,7 @@ def __init__( ): self._host = host self._port = port + self._timeout = timeout self._sock = self._create_connection(host, port, timeout) self._file = self._sock.makefile('r') @@ -131,6 +142,23 @@ def close(self) -> None: self._file.close() self._sock.close() + def interrupt(self) -> None: + # Shutting down the socket unblocks a thread currently blocked in `readline`, causing + # its read to raise. We then reconnect so the transport stays usable for later requests. + # The old socket is closed; the old file object is left to be reclaimed by the garbage + # collector to avoid racing a `close()` against the unwinding reader thread. + old_sock = self._sock + try: + old_sock.shutdown(socket.SHUT_RDWR) + except OSError: + pass + self._sock = self._create_connection(self._host, self._port, self._timeout) + self._file = self._sock.makefile('r') + try: + old_sock.close() + except OSError: + pass + def _request(self, req: str) -> str: self._sock.sendall(req.encode()) server_addr = self._description() @@ -235,6 +263,12 @@ def close(self) -> None: for client in clients: client.close() + def interrupt(self) -> None: + self._default_client.interrupt() + for clients in self._clients.values(): + for client in clients: + client.interrupt() + def request(self, method: str, **params: Any) -> dict[str, Any]: if method in self._clients: for client in self._clients[method]: @@ -289,6 +323,9 @@ def __exit__(self, *args: Any) -> None: def close(self) -> None: self._transport.close() + def interrupt(self) -> None: + self._transport.interrupt() + def request(self, method: str, **params: Any) -> dict[str, Any]: req_id = f'{id(self)}-{self._req_id:03}' self._req_id += 1 @@ -918,6 +955,15 @@ def __exit__(self, *args: Any) -> None: def close(self) -> None: self._client.close() + def interrupt(self) -> None: + """Abort an `execute`/`simplify`/… request currently in flight on another thread. + + After this returns the interrupted call raises and the client stays usable. Only + effective for the single-socket transport; a no-op for transports that cannot abort + an in-flight request. + """ + self._client.interrupt() + def _request(self, method: str, **params: Any) -> dict[str, Any]: try: return self._client.request(method, **params) diff --git a/pyk/src/pyk/proof/proof.py b/pyk/src/pyk/proof/proof.py index 0d1eef04d2e..b39cad5fe0d 100644 --- a/pyk/src/pyk/proof/proof.py +++ b/pyk/src/pyk/proof/proof.py @@ -3,7 +3,9 @@ import json import logging from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor, wait +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import TimeoutError as FuturesTimeoutError +from concurrent.futures import wait from dataclasses import dataclass from enum import Enum from itertools import chain @@ -499,6 +501,30 @@ def init_proof(self, proof: P) -> None: """ ... + def get_step_depth(self) -> int | None: + """Return the current per-step exploration depth, if this prover exposes a tunable one. + + Returning `None` (the default) opts the prover out of the progressive depth-halving + policy in `advance_proof`. Provers with a tunable execution depth (e.g. `APRProver`) + should override this together with `set_step_depth`. + """ + return None + + def set_step_depth(self, depth: int) -> None: + """Set the per-step exploration depth. No-op by default; see `get_step_depth`.""" + ... + + def interrupt(self) -> None: + """Abort a `step_proof` call currently running on another thread, as quickly as possible. + + Used by the progressive depth-halving policy in `advance_proof` to abandon a step that + has exhausted its stall window. After this returns, a thread blocked in `step_proof` must + raise promptly and the prover must remain usable for subsequent steps. The default + implementation does nothing; provers backed by an interruptible resource (e.g. a Kore RPC + connection) should override it. + """ + ... + def advance_proof( self, proof: P, @@ -506,6 +532,7 @@ def advance_proof( fail_fast: bool = False, callback: Callable[[P], None] = (lambda x: None), maintenance_rate: int = 1, + per_depth_timeout: float | None = None, ) -> None: """Advance a proof. @@ -518,29 +545,74 @@ def advance_proof( halt execution even if there are still available steps. callback: Callable to run in between each completed step, useful for getting real-time information about the proof. maintenance_rate: Number of iterations between proof maintenance (writing to disk and executing callback). + per_depth_timeout (optional): Enables progressive depth halving when set to a positive value. + Each step is given a stall window of `current_depth * per_depth_timeout` seconds (where + `current_depth` is the prover's `get_step_depth()`) to commit its result. If a step does not + finish within its window, it is interrupted, the step depth is halved (down to a floor of 1), + and the step is retried at the shallower depth. Has no effect for provers whose + `get_step_depth()` returns `None`. """ iterations = 0 _LOGGER.info(f'Initializing proof: {proof.id}') self.init_proof(proof) - while True: - steps = list(proof.get_steps()) - _LOGGER.info(f'Found {len(steps)} next steps for proof: {proof.id}') - if len(steps) == 0: - break - for step in steps: - if fail_fast and proof.failed: - _LOGGER.warning(f'Terminating proof early because fail_fast is set: {proof.id}') - proof.failure_info = self.failure_info(proof) - return - if max_iterations is not None and max_iterations <= iterations: - return - iterations += 1 - results = self.step_proof(step) - for result in results: - proof.commit(result) - if iterations % maintenance_rate == 0: - proof.write_proof_data() - callback(proof) + + progressive = per_depth_timeout is not None and per_depth_timeout > 0 and self.get_step_depth() is not None + executor = ThreadPoolExecutor(max_workers=1) if progressive else None + try: + while True: + steps = list(proof.get_steps()) + _LOGGER.info(f'Found {len(steps)} next steps for proof: {proof.id}') + if len(steps) == 0: + break + halved_depth = False + for step in steps: + if fail_fast and proof.failed: + _LOGGER.warning(f'Terminating proof early because fail_fast is set: {proof.id}') + proof.failure_info = self.failure_info(proof) + return + if max_iterations is not None and max_iterations <= iterations: + return + if progressive: + assert executor is not None + assert per_depth_timeout is not None + depth = self.get_step_depth() + assert depth is not None + window = max(depth, 1) * per_depth_timeout + future = executor.submit(self.step_proof, step) + try: + results = future.result(timeout=window) + except FuturesTimeoutError: + # The step exhausted its stall window: interrupt it, halve the depth, + # and re-fetch steps so the same node is retried at the shallower depth. + self.interrupt() + wait([future]) + new_depth = max(1, depth // 2) + if new_depth >= depth: + _LOGGER.warning( + f'Proof {proof.id}: step exhausted {window}s stall window at minimum ' + f'depth {depth}; stopping.' + ) + return + _LOGGER.warning( + f'Proof {proof.id}: step exhausted {window}s stall window at depth {depth}; ' + f'halving to {new_depth}.' + ) + self.set_step_depth(new_depth) + halved_depth = True + break + else: + results = self.step_proof(step) + iterations += 1 + for result in results: + proof.commit(result) + if iterations % maintenance_rate == 0: + proof.write_proof_data() + callback(proof) + if halved_depth: + continue + finally: + if executor is not None: + executor.shutdown(wait=False) if proof.failed: proof.failure_info = self.failure_info(proof) diff --git a/pyk/src/pyk/proof/reachability.py b/pyk/src/pyk/proof/reachability.py index 928fe5c25a1..4437feb8376 100644 --- a/pyk/src/pyk/proof/reachability.py +++ b/pyk/src/pyk/proof/reachability.py @@ -755,6 +755,15 @@ def __init__( def close(self) -> None: self.kcfg_explore.cterm_symbolic._kore_client.close() + def get_step_depth(self) -> int | None: + return self.execute_depth + + def set_step_depth(self, depth: int) -> None: + self.execute_depth = depth + + def interrupt(self) -> None: + self.kcfg_explore.interrupt() + def init_proof(self, proof: APRProof) -> None: main_module_name = self.main_module_name if self.extra_module: diff --git a/pyk/src/tests/unit/test_advance_proof.py b/pyk/src/tests/unit/test_advance_proof.py new file mode 100644 index 00000000000..bfdcbf02f3c --- /dev/null +++ b/pyk/src/tests/unit/test_advance_proof.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from threading import Event +from typing import TYPE_CHECKING + +from pyk.proof.proof import Proof, ProofStatus, Prover + +if TYPE_CHECKING: + from collections.abc import Mapping + from pathlib import Path + from typing import Any + + +class _StepInterrupted(Exception): + """Raised inside `step_proof` when the prover is interrupted, mimicking a backend abort.""" + + +class CountingProof(Proof[int, int]): + """Minimal proof that needs `target` committed steps to pass.""" + + target: int + committed: int + + def __init__(self, id: str, target: int) -> None: + super().__init__(id) + self.target = target + self.committed = 0 + + def commit(self, result: int) -> None: + self.committed += result + + @property + def own_status(self) -> ProofStatus: + return ProofStatus.PASSED if self.committed >= self.target else ProofStatus.PENDING + + @property + def can_progress(self) -> bool: + return self.committed < self.target + + @classmethod + def from_dict(cls: type[CountingProof], dct: Mapping[str, Any], proof_dir: Path | None = None) -> CountingProof: + raise NotImplementedError + + def write_proof_data(self) -> None: ... + + def get_steps(self) -> list[int]: + return [self.committed] if self.can_progress else [] + + +class CountingProver(Prover[CountingProof, int, int]): + """Prover whose `step_proof` stalls (blocks until interrupted) while `depth` exceeds `quick_at_depth`. + + Tracks the number of interruptions so tests can assert how many times the depth was halved. + """ + + depth: int + quick_at_depth: int + interrupt_count: int + _interrupt_event: Event + + def __init__(self, depth: int, quick_at_depth: int) -> None: + self.depth = depth + self.quick_at_depth = quick_at_depth + self.interrupt_count = 0 + self._interrupt_event = Event() + + def close(self) -> None: ... + + def failure_info(self, proof: CountingProof) -> Any: + return None + + def init_proof(self, proof: CountingProof) -> None: ... + + def get_step_depth(self) -> int | None: + return self.depth + + def set_step_depth(self, depth: int) -> None: + self.depth = depth + + def interrupt(self) -> None: + self.interrupt_count += 1 + self._interrupt_event.set() + + def step_proof(self, step: int) -> list[int]: + self._interrupt_event.clear() + if self.depth > self.quick_at_depth: + # Stall until `advance_proof` interrupts us once the stall window elapses. + if self._interrupt_event.wait(timeout=10.0): + raise _StepInterrupted() + raise AssertionError('step_proof was not interrupted within 10s') + return [1] + + +PER_DEPTH_TIMEOUT = 0.02 + + +def test_advance_proof_halves_depth_until_progress() -> None: + # Given: depth 8 stalls, but a step completes once depth drops to <= 2. + proof = CountingProof('counting', target=1) + prover = CountingProver(depth=8, quick_at_depth=2) + + # When + prover.advance_proof(proof, per_depth_timeout=PER_DEPTH_TIMEOUT) + + # Then: 8 -> 4 -> 2 (two halvings, two interrupts), then a step commits and the proof passes. + assert proof.status == ProofStatus.PASSED + assert prover.depth == 2 + assert prover.interrupt_count == 2 + + +def test_advance_proof_stops_at_minimum_depth_when_never_progressing() -> None: + # Given: every step stalls regardless of depth. + proof = CountingProof('counting', target=1) + prover = CountingProver(depth=4, quick_at_depth=0) + + # When + prover.advance_proof(proof, per_depth_timeout=PER_DEPTH_TIMEOUT) + + # Then: depth halves 4 -> 2 -> 1, then stops at the floor; the proof stays pending. + assert proof.status == ProofStatus.PENDING + assert proof.committed == 0 + assert prover.depth == 1 + assert prover.interrupt_count == 3 + + +def test_advance_proof_no_halving_when_steps_are_fast() -> None: + # Given: progressive policy enabled but steps always complete in time. + proof = CountingProof('counting', target=3) + prover = CountingProver(depth=8, quick_at_depth=8) + + # When + prover.advance_proof(proof, per_depth_timeout=PER_DEPTH_TIMEOUT) + + # Then: no interruptions, depth untouched, proof passes. + assert proof.status == ProofStatus.PASSED + assert prover.depth == 8 + assert prover.interrupt_count == 0 + + +def test_advance_proof_without_per_depth_timeout_is_unaffected() -> None: + # Given: no per_depth_timeout -> classic in-loop behavior, no watchdog thread. + proof = CountingProof('counting', target=2) + prover = CountingProver(depth=8, quick_at_depth=8) + + # When + prover.advance_proof(proof) + + # Then + assert proof.status == ProofStatus.PASSED + assert prover.depth == 8 + assert prover.interrupt_count == 0