diff --git a/pyk/src/pyk/cli/pyk.py b/pyk/src/pyk/cli/pyk.py
index 8a23469ed8..75a9ecc1ed 100644
--- a/pyk/src/pyk/cli/pyk.py
+++ b/pyk/src/pyk/cli/pyk.py
@@ -316,6 +316,7 @@ class ProveOptions(LoggingOptions, SpecOptions, SaveDirOptions):
     kore_rpc_command: str | Iterable[str] | None
     max_depth: int | None
     max_iterations: int | None
+    step_timeout: int | None
     assume_defined: bool
     show_kcfg: bool
     haskell_logging: bool
@@ -329,6 +330,7 @@ def default() -> dict[str, Any]:
             'kore_rpc_command': None,
             'max_depth': None,
             'max_iterations': None,
+            'step_timeout': None,
             'assume_defined': False,
             'show_kcfg': False,
             'haskell_logging': False,
@@ -515,6 +517,17 @@ def create_argument_parser() -> ArgumentParser:
         type=int,
         help='Maximum number of KCFG explorations to take in attempting to discharge proof.',
     )
+    prove_args.add_argument(
+        '--step-timeout',
+        dest='step_timeout',
+        type=int,
+        default=None,
+        help=(
+            'Per-step wall-clock budget in whole seconds (floored at 1). When a symbolic-execution step'
+            ' exceeds it, the step is interrupted, its execution depth is halved, and it is retried;'
+            ' proving stops once the depth cannot be reduced further. Omit to disable the timeout.'
+        ),
+    )
     prove_args.add_argument(
         '--kore-rpc-command',
         dest='kore_rpc_command',
diff --git a/pyk/src/pyk/cterm/symbolic.py b/pyk/src/pyk/cterm/symbolic.py
index af8f9a4e36..66d178dcb3 100644
--- a/pyk/src/pyk/cterm/symbolic.py
+++ b/pyk/src/pyk/cterm/symbolic.py
@@ -127,6 +127,10 @@ def kast_to_kore(self, kinner: KInner) -> Pattern:
     def kore_to_kast(self, pattern: Pattern) -> KInner:
         return kore_to_kast(self._definition, pattern)
 
+    def interrupt(self) -> None:
+        """Abort a backend request currently in flight on another thread; see `KoreClient.interrupt`."""
+        self._kore_client.interrupt()
+
     def _haskell_logging_request(self, haskell_logging: bool | None) -> tuple[str, ...] | None:
         """Resolve the per-call on/off flag to the list of log entries to request.
 
diff --git a/pyk/src/pyk/kcfg/explore.py b/pyk/src/pyk/kcfg/explore.py
index 1eef51588a..81716fdd4b 100644
--- a/pyk/src/pyk/kcfg/explore.py
+++ b/pyk/src/pyk/kcfg/explore.py
@@ -59,6 +59,10 @@ def _pretty_printer(self) -> PrettyPrinter:
     def pretty_print(self, kinner: KInner) -> str:
         return self._pretty_printer.print(kinner)
 
+    def interrupt(self) -> None:
+        """Abort a backend request currently in flight on another thread; see `KoreClient.interrupt`."""
+        self.cterm_symbolic.interrupt()
+
     def _extract_rule_labels(self, _logs: tuple[LogEntry, ...]) -> list[str]:
         _rule_lines = []
         for node_log in _logs:
diff --git a/pyk/src/pyk/kore/rpc.py b/pyk/src/pyk/kore/rpc.py
index e4e0355883..099c86839f 100644
--- a/pyk/src/pyk/kore/rpc.py
+++ b/pyk/src/pyk/kore/rpc.py
@@ -85,6 +85,15 @@ def __exit__(self, *args: Any) -> None:
     @abstractmethod
     def close(self) -> None: ...
 
+    def send_interrupt(self, data: str) -> None:
+        """Send `data` on the live connection without waiting for a reply.
+
+        Used to deliver a `cancel` to a connection whose reply another thread is already
+        awaiting. Default: no-op. Override only for connections that can be written to
+        while a request is in flight.
+        """
+        ...
+
     @abstractmethod
     def _request(self, req: str) -> str: ...
 
@@ -101,6 +110,7 @@ class TransportType(Enum):
 class SingleSocketTransport(Transport):
     _host: str
     _port: int
+    _timeout: int | None
     _sock: socket.socket
     _file: IO[str]
 
@@ -113,6 +123,7 @@ def __init__(
     ):
         self._host = host
         self._port = port
+        self._timeout = timeout
         self._sock = self._create_connection(host, port, timeout)
         self._file = self._sock.makefile('r')
 
@@ -141,6 +152,12 @@ def close(self) -> None:
         self._file.close()
         self._sock.close()
 
+    def send_interrupt(self, data: str) -> None:
+        # Write the cancel to the socket but don't read the reply: the thread already blocked in
+        # `_request`'s `readline` will read the server's "cancelled" reply. The socket stays open.
+        # The leading newline separates the cancel from the request bytes already on the stream.
+        self._sock.sendall(b'\n' + data.encode())
+
     def _request(self, req: str) -> str:
         self._sock.sendall(req.encode())
         server_addr = self._description()
@@ -252,6 +269,12 @@ def last_request_id(self) -> str | None:
         """The JSON-RPC id of the most recent request issued through this facade (``None`` if none yet)."""
         return self._last_request_id
 
+    def interrupt(self) -> None:
+        self._default_client.interrupt()
+        for clients in self._clients.values():
+            for client in clients:
+                client.interrupt()
+
     def request(self, method: str, **params: Any) -> dict[str, Any]:
         if method in self._clients:
             for client in self._clients[method]:
@@ -316,6 +339,19 @@ def __exit__(self, *args: Any) -> None:
     def close(self) -> None:
         self._transport.close()
 
+    def interrupt(self) -> None:
+        # Send a `cancel` so the server aborts the in-flight request. The cancel gets no reply of its
+        # own; the thread awaiting that request gets a "cancelled" error instead. The id is only for
+        # logs, so we derive it from the last request rather than touching the requester's `_req_id`.
+        cancel_id = f'{self._last_request_id}-cancel' if self._last_request_id is not None else 'cancel'
+        payload = {
+            'jsonrpc': self._JSON_RPC_VERSION,
+            'id': cancel_id,
+            'method': 'cancel',
+            'params': {},
+        }
+        self._transport.send_interrupt(json.dumps(payload))
+
     def request(self, method: str, **params: Any) -> dict[str, Any]:
         label = client_label.get()
         prefix = label if label is not None else str(id(self))
@@ -1014,6 +1050,15 @@ def last_request_id(self) -> str | None:
         """
         return self._client.last_request_id
 
+    def interrupt(self) -> None:
+        """Abort an `execute`/`simplify`/… request running on another thread.
+
+        Sends a `cancel` so the server stops computing; the interrupted call raises a
+        "cancelled" error and the connection stays usable. Works on the single-socket
+        transport only; a no-op for HTTP (one connection per request, nothing to cancel).
+        """
+        self._client.interrupt()
+
     def _request(self, method: str, **params: Any) -> dict[str, Any]:
         try:
             return self._client.request(method, **params)
diff --git a/pyk/src/pyk/proof/proof.py b/pyk/src/pyk/proof/proof.py
index 0d1eef04d2..3edc8d45d3 100644
--- a/pyk/src/pyk/proof/proof.py
+++ b/pyk/src/pyk/proof/proof.py
@@ -3,7 +3,9 @@
 import json
 import logging
 from abc import ABC, abstractmethod
-from concurrent.futures import ThreadPoolExecutor, wait
+from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import TimeoutError as FuturesTimeoutError
+from concurrent.futures import wait
 from dataclasses import dataclass
 from enum import Enum
 from itertools import chain
@@ -499,6 +501,33 @@ def init_proof(self, proof: P) -> None:
         """
         ...
 
+    #: Per-step wall-clock budget in whole seconds (minimum 1). When set, `advance_proof` runs each
+    #: step under this budget and, on timeout, interrupts it, calls `shrink_step`, and retries.
+    #: `None` (the default) disables the policy, so steps run synchronously with no time limit. A
+    #: prover that can do less work per step (e.g. `APRProver`, by lowering its execution depth)
+    #: should pair this with `shrink_step`.
+    step_timeout: int | None = None
+
+    def shrink_step(self) -> bool:
+        """Reduce the amount of work a single `step_proof` does, after a step timed out.
+
+        Return `True` if the step was made smaller (so it is worth retrying), or `False` if the
+        step is already at its minimum size (so `advance_proof` should stop). The default is a
+        no-op that returns `False`; see `step_timeout`.
+        """
+        return False
+
+    def interrupt(self) -> None:
+        """Abort a `step_proof` call currently running on another thread, as quickly as possible.
+
+        Used by the timeout-and-shrink policy in `advance_proof` to abandon a step that has
+        exhausted its `step_timeout` budget. After this returns, a thread blocked in `step_proof`
+        must raise promptly and the prover must remain usable for subsequent steps. The default
+        implementation does nothing; provers backed by an interruptible resource (e.g. a Kore RPC
+        connection) should override it.
+        """
+        ...
+
     def advance_proof(
         self,
         proof: P,
@@ -509,7 +538,8 @@ def advance_proof(
     ) -> None:
         """Advance a proof.
 
-        Performs loop `Proof.get_steps()` -> `Prover.step_proof()` -> `Proof.commit()`.
+        Performs loop `Proof.get_steps()` -> `Prover.step_proof()` (within `step_timeout`, else
+        `interrupt` + `shrink_step` and retry, or stop if it cannot shrink) -> `Proof.commit()`.
 
         Args:
             proof: proof to advance.
@@ -522,25 +552,73 @@ def advance_proof(
         iterations = 0
         _LOGGER.info(f'Initializing proof: {proof.id}')
         self.init_proof(proof)
-        while True:
-            steps = list(proof.get_steps())
-            _LOGGER.info(f'Found {len(steps)} next steps for proof: {proof.id}')
-            if len(steps) == 0:
-                break
-            for step in steps:
-                if fail_fast and proof.failed:
-                    _LOGGER.warning(f'Terminating proof early because fail_fast is set: {proof.id}')
-                    proof.failure_info = self.failure_info(proof)
-                    return
-                if max_iterations is not None and max_iterations <= iterations:
-                    return
-                iterations += 1
-                results = self.step_proof(step)
-                for result in results:
-                    proof.commit(result)
-                if iterations % maintenance_rate == 0:
-                    proof.write_proof_data()
-                callback(proof)
+
+        # A worker thread is only needed when a per-step budget is configured; otherwise steps run
+        # synchronously on the calling thread.
+        executor = ThreadPoolExecutor(max_workers=1) if self.step_timeout is not None else None
+        try:
+            while True:
+                steps = list(proof.get_steps())
+                _LOGGER.info(f'Found {len(steps)} next steps for proof: {proof.id}')
+                if len(steps) == 0:
+                    break
+                for step in steps:
+                    if fail_fast and proof.failed:
+                        _LOGGER.warning(f'Terminating proof early because fail_fast is set: {proof.id}')
+                        proof.failure_info = self.failure_info(proof)
+                        return
+                    if max_iterations is not None and max_iterations <= iterations:
+                        return
+                    budget = self.step_timeout
+                    if executor is None or budget is None:
+                        results = self.step_proof(step)
+                    else:
+                        future = executor.submit(self.step_proof, step)
+                        try:
+                            results = future.result(timeout=budget)
+                        except FuturesTimeoutError:
+                            if future.done():
+                                # Not a budget overrun we need to act on: the step raised its own
+                                # `TimeoutError` (the same class as `FuturesTimeoutError` on Python
+                                # 3.11+) or finished just as the budget expired. Surface its outcome.
+                                results = future.result()
+                            else:
+                                # The step exhausted its budget: interrupt it and wait for the worker
+                                # to unwind before touching the prover again.
+                                self.interrupt()
+                                wait([future])
+                                error = future.exception()
+                                if error is not None:
+                                    # Aborted: ask the prover to do less work per step, then re-fetch
+                                    # steps so the same node is retried smaller.
+                                    if not self.shrink_step():
+                                        _LOGGER.warning(
+                                            f'Proof {proof.id}: step exhausted {budget}s budget and cannot be '
+                                            f'shrunk further; stopping. Interrupted step raised: {error!r}'
+                                        )
+                                        return
+                                    _LOGGER.warning(
+                                        f'Proof {proof.id}: step exhausted {budget}s budget; shrinking and '
+                                        f'retrying. Interrupted step raised: {error!r}'
+                                    )
+                                    break
+                                # The step completed despite the interrupt (e.g. `interrupt` is a no-op
+                                # for this prover): keep the finished work and shrink best-effort.
+                                _LOGGER.warning(
+                                    f'Proof {proof.id}: step overran {budget}s budget but completed; '
+                                    f'keeping its results.'
+                                )
+                                self.shrink_step()
+                                results = future.result()
+                    iterations += 1
+                    for result in results:
+                        proof.commit(result)
+                    if iterations % maintenance_rate == 0:
+                        proof.write_proof_data()
+                    callback(proof)
+        finally:
+            if executor is not None:
+                executor.shutdown(wait=False)
 
         if proof.failed:
             proof.failure_info = self.failure_info(proof)
diff --git a/pyk/src/pyk/proof/prove_rpc.py b/pyk/src/pyk/proof/prove_rpc.py
index df01ad30b4..9b852daaa6 100644
--- a/pyk/src/pyk/proof/prove_rpc.py
+++ b/pyk/src/pyk/proof/prove_rpc.py
@@ -52,6 +52,7 @@ def prove_rpc(self, options: ProveOptions) -> list[Proof]:
                 max_depth=options.max_depth,
                 save_directory=options.save_directory,
                 max_iterations=options.max_iterations,
+                step_timeout=options.step_timeout,
             )
             for claim in all_claims
         ]
@@ -63,6 +64,7 @@ def _prove_claim_rpc(
         max_depth: int | None = None,
         save_directory: Path | None = None,
         max_iterations: int | None = None,
+        step_timeout: int | None = None,
     ) -> Proof:
         definition = self._kprove.definition
 
@@ -90,7 +92,12 @@ def _prove_claim_rpc(
                 prover = ImpliesProver(proof, kcfg_explore, assume_defined=assume_defined)
             else:
                 assert type(proof) is APRProof
-                prover = APRProver(kcfg_explore, execute_depth=max_depth, assume_defined=assume_defined)
+                prover = APRProver(
+                    kcfg_explore,
+                    execute_depth=max_depth,
+                    assume_defined=assume_defined,
+                    step_timeout=step_timeout,
+                )
             prover.advance_proof(proof, max_iterations=max_iterations)  # type: ignore [arg-type]
 
         if proof.passed:
diff --git a/pyk/src/pyk/proof/reachability.py b/pyk/src/pyk/proof/reachability.py
index 0b4a66d2b3..b7c01d913d 100644
--- a/pyk/src/pyk/proof/reachability.py
+++ b/pyk/src/pyk/proof/reachability.py
@@ -725,6 +725,7 @@ class APRProver(Prover[APRProof, APRProofStep, APRProofResult]):
     kcfg_explore: KCFGExplore
     extra_module: KFlatModule | None
     optimize_kcfg: bool
+    step_timeout: int | None
 
     def __init__(
         self,
@@ -738,6 +739,7 @@ def __init__(
         assume_defined: bool = False,
         extra_module: KFlatModule | None = None,
         optimize_kcfg: bool = False,
+        step_timeout: int | None = None,
     ) -> None:
 
         self.kcfg_explore = kcfg_explore
@@ -751,10 +753,24 @@ def __init__(
         self.assume_defined = assume_defined
         self.extra_module = extra_module
         self.optimize_kcfg = optimize_kcfg
+        # Whole seconds, floored at 1; None disables the per-step timeout/shrink policy entirely.
+        self.step_timeout = max(1, step_timeout) if step_timeout is not None else None
 
     def close(self) -> None:
         self.kcfg_explore.cterm_symbolic._kore_client.close()
 
+    def shrink_step(self) -> bool:
+        # On step timeout, halve the execution depth (floor 1) so the next attempt does less work
+        # per `extend_cterm`. Returns False once `execute_depth` is unset or already at the minimum.
+        if self.execute_depth is None or self.execute_depth <= 1:
+            return False
+        self.execute_depth = max(1, self.execute_depth // 2)
+        return True
+
+    def interrupt(self) -> None:
+        """Forward the interrupt to `kcfg_explore`, which passes it on towards the backend client."""
+        self.kcfg_explore.interrupt()
+
     def init_proof(self, proof: APRProof) -> None:
         # Stamp proof.id on every subsequent kore-RPC request from this thread so
         # booster's `{request: ...}` log lines self-identify the originating
diff --git a/pyk/src/tests/integration/kore/test_interrupt.py b/pyk/src/tests/integration/kore/test_interrupt.py
new file mode 100644
index 0000000000..1a1b3d2e64
--- /dev/null
+++ b/pyk/src/tests/integration/kore/test_interrupt.py
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+import threading
+import time
+from string import Template
+from typing import TYPE_CHECKING
+
+from pyk.kore.parser import KoreParser
+from pyk.kore.rpc import DefaultError
+from pyk.testing import KoreClientTest
+
+if TYPE_CHECKING:
+    from pyk.kore.rpc import KoreClient
+    from pyk.kore.syntax import Pattern
+
+
+def term(n: int) -> Pattern:
+    template = Template(
+        r"""
+        Lbl'-LT-'generatedTop'-GT-'{}(
+            Lbl'-LT-'k'-GT-'{}(
+                kseq{}(
+                    inj{SortInt{}, SortKItem{}}(\dv{SortInt{}}("$n")),
+                    K:SortK{}
+                )
+            ),
+            GCC:SortGeneratedCounterCell{}
+        )
+        """
+    )
+    parser = KoreParser(template.substitute(n=n))
+    pattern = parser.pattern()
+    assert parser.eof
+    return pattern
+
+
+class TestInterrupt(KoreClientTest):
+    # The interrupt mechanism (cancel over the single-socket transport) is what `advance_proof`'s
+    # step-timeout policy relies on. `inc` never terminates, so an `execute` only ever returns by
+    # being interrupted -- which is exactly what this test asserts.
+    DISABLE_BOOSTER = True  # exercise the legacy (pure haskell) kore-rpc server
+
+    KOMPILE_DEFINITION = """
+        module INTERRUPT-TEST
+            imports INT
+            rule [inc]: I:Int => I +Int 1
+        endmodule
+    """
+    KOMPILE_MAIN_MODULE = 'INTERRUPT-TEST'
+    KOMPILE_ARGS = {'syntax_module': 'INTERRUPT-TEST'}
+
+    def test_interrupt_aborts_in_flight_request_and_keeps_connection(self, kore_client: KoreClient) -> None:
+        # Given: a non-terminating `execute` running on another thread.
+        box: dict = {}
+
+        def run() -> None:
+            try:
+                kore_client.execute(term(0), max_depth=1_000_000_000)
+            except BaseException as e:  # noqa: B036 - record whatever the interrupted call raises
+                box['exc'] = e
+
+        thread = threading.Thread(target=run, daemon=True)
+        thread.start()
+        time.sleep(2.0)  # let the step get well underway
+        assert thread.is_alive()  # sanity: it is genuinely long-running, not terminating on its own
+
+        # When: the in-flight request is interrupted.
+        kore_client.interrupt()
+
+        # Then: the call is aborted promptly (rather than running ~1e9 steps to completion)...
+        thread.join(timeout=10.0)
+        assert not thread.is_alive(), 'execute() was not aborted by interrupt() within 10s'
+        exc = box.get('exc')
+        assert isinstance(exc, DefaultError)
+        assert exc.message == 'Request cancelled'
+
+        # ...and the connection survives the cancel: a fresh request still succeeds on it.
+        result = kore_client.execute(term(0), max_depth=1)
+        assert result.depth == 1
diff --git a/pyk/src/tests/unit/test_advance_proof.py b/pyk/src/tests/unit/test_advance_proof.py
new file mode 100644
index 0000000000..fe955f78e9
--- /dev/null
+++ b/pyk/src/tests/unit/test_advance_proof.py
@@ -0,0 +1,160 @@
+from __future__ import annotations
+
+from threading import Event
+from typing import TYPE_CHECKING
+
+from pyk.proof.proof import Proof, ProofStatus, Prover
+
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+    from pathlib import Path
+    from typing import Any
+
+
+class _StepInterrupted(Exception):
+    """Raised inside `step_proof` when the prover is interrupted, mimicking a backend abort."""
+
+
+class CountingProof(Proof[int, int]):
+    """Minimal proof that needs `target` committed steps to pass."""
+
+    target: int
+    committed: int
+
+    def __init__(self, id: str, target: int) -> None:
+        super().__init__(id)
+        self.target = target
+        self.committed = 0
+
+    def commit(self, result: int) -> None:
+        self.committed += result
+
+    @property
+    def own_status(self) -> ProofStatus:
+        return ProofStatus.PASSED if self.committed >= self.target else ProofStatus.PENDING
+
+    @property
+    def can_progress(self) -> bool:
+        return self.committed < self.target
+
+    @classmethod
+    def from_dict(cls: type[CountingProof], dct: Mapping[str, Any], proof_dir: Path | None = None) -> CountingProof:
+        raise NotImplementedError
+
+    def write_proof_data(self) -> None: ...
+
+    def get_steps(self) -> list[int]:
+        return [self.committed] if self.can_progress else []
+
+
+class CountingProver(Prover[CountingProof, int, int]):
+    """Prover whose `step_proof` is "slow" (blocks for up to `slow_step_secs`) while `depth` exceeds `quick_at_depth`.
+
+    Mirrors `APRProver`: a fixed `step_timeout` budgets each step, and `shrink_step` halves the depth.
+    A slow step that is interrupted before its budget elapses raises (mimicking a backend abort); a slow
+    step that is never interrupted finishes its work on its own. Tracks the number of interruptions so
+    tests can assert how many times a step was shrunk.
+    """
+
+    depth: int
+    quick_at_depth: int
+    step_timeout: int | None
+    slow_step_secs: float
+    interrupt_count: int
+    _interrupt_event: Event
+
+    def __init__(
+        self, depth: int, quick_at_depth: int, step_timeout: int | None = 1, slow_step_secs: float = 10.0
+    ) -> None:
+        self.depth = depth
+        self.quick_at_depth = quick_at_depth
+        self.step_timeout = step_timeout
+        self.slow_step_secs = slow_step_secs
+        self.interrupt_count = 0
+        self._interrupt_event = Event()
+
+    def close(self) -> None: ...
+
+    def failure_info(self, proof: CountingProof) -> Any:
+        return None
+
+    def init_proof(self, proof: CountingProof) -> None: ...
+
+    def shrink_step(self) -> bool:
+        if self.depth <= 1:
+            return False
+        self.depth = max(1, self.depth // 2)
+        return True
+
+    def interrupt(self) -> None:
+        self.interrupt_count += 1
+        self._interrupt_event.set()
+
+    def step_proof(self, step: int) -> list[int]:
+        self._interrupt_event.clear()
+        if self.depth > self.quick_at_depth:
+            # A "slow" step: block for up to `slow_step_secs`. If `advance_proof` interrupts us first
+            # (because the step budget elapsed) abort like a real backend; otherwise the step finishes
+            # its work on its own and commits normally.
+            if self._interrupt_event.wait(timeout=self.slow_step_secs):
+                raise _StepInterrupted()
+        return [1]
+
+
+def test_advance_proof_shrinks_until_progress() -> None:
+    # Given: depth 4 stalls, but a step completes once depth drops to <= 2.
+    proof = CountingProof('counting', target=1)
+    prover = CountingProver(depth=4, quick_at_depth=2)
+
+    # When
+    prover.advance_proof(proof)
+
+    # Then: one timeout shrinks 4 -> 2, then a step commits and the proof passes.
+    assert proof.status == ProofStatus.PASSED
+    assert prover.depth == 2
+    assert prover.interrupt_count == 1
+
+
+def test_advance_proof_stops_when_cannot_shrink_further() -> None:
+    # Given: every step stalls regardless of depth.
+    proof = CountingProof('counting', target=1)
+    prover = CountingProver(depth=2, quick_at_depth=0)
+
+    # When
+    prover.advance_proof(proof)
+
+    # Then: depth shrinks 2 -> 1, then stops at the floor; the proof stays pending.
+    assert proof.status == ProofStatus.PENDING
+    assert proof.committed == 0
+    assert prover.depth == 1
+    assert prover.interrupt_count == 2
+
+
+def test_advance_proof_no_shrink_when_steps_are_fast() -> None:
+    # Given: step_timeout set but steps always complete in time.
+    proof = CountingProof('counting', target=3)
+    prover = CountingProver(depth=2, quick_at_depth=2)
+
+    # When
+    prover.advance_proof(proof)
+
+    # Then: no interruptions, depth untouched, proof passes.
+    assert proof.status == ProofStatus.PASSED
+    assert prover.depth == 2
+    assert prover.interrupt_count == 0
+
+
+def test_advance_proof_without_step_timeout_is_unaffected() -> None:
+    # Given: step_timeout is None -> classic in-loop behavior, no watchdog thread. Each step is "slow"
+    # (depth 8 > quick_at_depth 4), but without a budget nothing interrupts it, so the step runs to
+    # completion synchronously instead of being aborted and shrunk.
+    proof = CountingProof('counting', target=2)
+    prover = CountingProver(depth=8, quick_at_depth=4, step_timeout=None, slow_step_secs=0.05)
+
+    # When
+    prover.advance_proof(proof)
+
+    # Then: both slow steps complete on their own; nothing is interrupted or shrunk.
+    assert proof.status == ProofStatus.PASSED
+    assert prover.depth == 8
+    assert prover.interrupt_count == 0