diff --git a/CLAUDE.md b/CLAUDE.md index a6ecda6464..7dfeb2baa4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -172,6 +172,10 @@ CLI flags for these fields are defined in `booster/tools/booster/Server.hs` in - Traverses bottom-up, applies function/simplification equations. - Concrete sub-terms are sent to the LLVM backend for evaluation (top-down first to maximise batch size). +- Passes restart from the top until a fixed point (bounded by + `--equation-max-iterations`); `--equation-max-local-steps N` additionally + re-simplifies rewritten sub-terms in place, up to N chained equation + applications (default 0 = restart-only). **Term representation** (`Booster/Pattern/Base.hs`): - `Term` is a tagged union with a `TermAttributes` field tracking `isConstructorLike`, diff --git a/booster/library/Booster/CLOptions.hs b/booster/library/Booster/CLOptions.hs index 122dfaabab..9cdd1ecb72 100644 --- a/booster/library/Booster/CLOptions.hs +++ b/booster/library/Booster/CLOptions.hs @@ -400,7 +400,7 @@ parseSMTOptions = parseEquationOptions :: Parser EquationOptions parseEquationOptions = - (\x y -> EquationOptions (Bound x) (Bound y)) + (\x y z -> EquationOptions (Bound x) (Bound y) (Bound z)) <$> option nonnegativeInt ( metavar "ITERATION_LIMIT" @@ -417,9 +417,22 @@ parseEquationOptions = <> value defaultMaxRecursion <> showDefault ) + <*> option + nonnegativeInt + ( metavar "LOCAL_STEP_LIMIT" + <> long "equation-max-local-steps" + <> help + "Number of equations applied in place at a rewritten subterm \ + \(per chain of in-place rewrites) before restarting the \ + \traversal from the top (0, the default, is restart-only \ + \evaluation)" + <> value defaultMaxLocalSteps + <> showDefault + ) where defaultMaxIterations = 100 defaultMaxRecursion = 5 + defaultMaxLocalSteps = 0 parseRewriteOptions :: Parser RewriteOptions parseRewriteOptions = diff --git a/booster/library/Booster/GlobalState.hs b/booster/library/Booster/GlobalState.hs index b40dc4026e..1a95ecf349 100644 --- a/booster/library/Booster/GlobalState.hs +++ b/booster/library/Booster/GlobalState.hs @@ -13,6 +13,11 @@ import Booster.Util (Bound (..)) data EquationOptions = EquationOptions { maxIterations :: Bound "Iterations" , maxRecursion :: Bound "Recursion" + , maxLocalSteps :: Bound "LocalSteps" + -- ^ how many equations may be applied in place at a rewritten + -- subterm (per chain of in-place rewrites) before falling back + -- to restarting the traversal from the top (0, the default, is + -- restart-only evaluation) } deriving stock (Show, Eq) @@ -20,7 +25,11 @@ data EquationOptions = EquationOptions globalEquationOptions :: IORef EquationOptions globalEquationOptions = unsafePerformIO . newIORef $ - EquationOptions{maxIterations = 100, maxRecursion = 5} + EquationOptions + { maxIterations = 100 + , maxRecursion = 5 + , maxLocalSteps = 0 + } readGlobalEquationOptions :: IO EquationOptions readGlobalEquationOptions = readIORef globalEquationOptions diff --git a/booster/library/Booster/Pattern/ApplyEquations.hs b/booster/library/Booster/Pattern/ApplyEquations.hs index dd4a573185..928c941afe 100644 --- a/booster/library/Booster/Pattern/ApplyEquations.hs +++ b/booster/library/Booster/Pattern/ApplyEquations.hs @@ -149,6 +149,7 @@ data EquationConfig = EquationConfig , smtSolver :: SMT.SMTContext , maxRecursion :: Bound "Recursion" , maxIterations :: Bound "Iterations" + , maxLocalSteps :: Bound "LocalSteps" , logger :: Logger LogMessage , prettyModifiers :: ModifiersRep } @@ -156,6 +157,14 @@ data EquationConfig = EquationConfig data EquationState = EquationState { termStack :: Seq Term , recursionStack :: [Term] + , localSteps :: [Term] + -- ^ chain of locally-rewritten node values on the current + -- traversal path, for loop detection of in-place rewriting + -- (path-scoped: saved and restored around each local recursion) + , localStepCount :: Int + -- ^ length of 'localSteps' (path-scoped like the chain itself), + -- kept separately to bound in-place rewriting by 'maxLocalSteps' + -- without computing the length on every step , changed :: Bool , predicates :: Set Predicate , cache :: SimplifierCache @@ -193,6 +202,8 @@ startState cache known = EquationState { termStack = mempty , recursionStack = [] + , localSteps = [] + , localStepCount = 0 , changed = False , predicates = known , -- replacements from predicates are rebuilt from the path conditions every time @@ -353,6 +364,7 @@ runEquationT definition llvmApi smtSolver sCache known (EquationT m) = do , smtSolver , maxIterations = globalEquationOptions.maxIterations , maxRecursion = globalEquationOptions.maxRecursion + , maxLocalSteps = globalEquationOptions.maxLocalSteps , logger , prettyModifiers } @@ -406,13 +418,80 @@ iterateEquations direction preference startTerm = do in simp llvmResult -- evaluate functions and simplify (recursively at each level) newTerm <- - let simp = cached Equations $ traverseTerm direction simp (applyHooksAndEquations preference) + let onEval + | direction == BottomUp = localFixpointEval simp + | otherwise = applyHooksAndEquations preference + simp = cached Equations $ traverseTerm direction simp onEval in simp replacedTerm changeFlag <- getChanged if changeFlag then checkForLoop newTerm >> resetChanged >> go newTerm else pure llvmResult + {- Local-fixpoint evaluation (BottomUp mode): when a node was + rewritten, run the LLVM pass on the result (preserving the + LLVM-before-equations ordering the global loop provides) and + re-enter the cached bottom-up traversal on it, normalizing + everything the rewrite produced in place instead of restarting + the whole-term traversal (the rewrite builds a new subterm + from the rule's RHS, which needs full evaluation in all its + arguments; the cache cuts the descent short at substituted + subject parts that are already normal). Ancestors then see + children in final form and the global loop converges in a few + passes instead of one per causal chain step. + + The in-place rewriting effort is bounded: both 'localSteps' + and 'localStepCount' are path-scoped (saved and restored + around the recursion, maintaining the invariant + localStepCount == length localSteps), so each chain of + in-place rewrites is at most 'maxLocalSteps' deep. Once the + budget is exhausted, rewritten nodes are returned without + recursion, which is exactly the restart-only strategy (the + changed flag is already set, so the global loop picks the + node up on the next pass). Each chain therefore advances by + at most 'maxLocalSteps' plus one application per pass, total + work stays bounded by 'maxIterations' passes, and + 'TooManyIterations' with its partial result is reached as + before. A budget of 0 restores the restart-only strategy + entirely. + + Loop detection is two-layered: the in-place rewrite chain in + 'localSteps' is checked per step, catching oscillations + shorter than the budget immediately; cycles that survive a + pass boundary recur in the whole-term snapshots within at most + one cycle period of passes and are caught by 'checkForLoop' + (a cycle period dividing the budget repeats the snapshot on + the very next pass). + -} + localFixpointEval :: LoggerMIO io => (Term -> EquationT io Term) -> Term -> EquationT io Term + localFixpointEval recurse t = do + t' <- applyHooksAndEquations preference t + if t' == t + then pure t + else do + config <- getConfig + st <- getState + if coerce st.localStepCount >= config.maxLocalSteps + then pure t' -- budget exhausted: defer to the global loop + else do + when (t' `elem` st.localSteps) $ do + withContext CtxAbort $ do + logWarn "Equation loop detected (local fixpoint)." + throw . EquationLoop . reverse $ t' : st.localSteps + let !newCount = st.localStepCount + 1 + eqState . put $ + st + { localSteps = t' : st.localSteps + , localStepCount = newCount + } + result <- llvmSimplify t' >>= recurse + eqState . modify $ \s -> + s + { localSteps = st.localSteps + , localStepCount = st.localStepCount + } + pure result + llvmSimplify :: forall io. LoggerMIO io => Term -> EquationT io Term llvmSimplify term = do config <- getConfig @@ -1201,9 +1280,15 @@ simplifyConstraint' recurseIntoEvalBool = \case evalBool :: LoggerMIO io => Term -> EquationT io Term evalBool t = withTermContext t $ do prior <- getState -- save prior state so we can revert - eqState $ put prior{termStack = mempty, changed = False} + eqState $ put prior{termStack = mempty, changed = False, localSteps = [], localStepCount = 0} result <- iterateEquations BottomUp PreferFunctions t - -- reset change flag and term stack to prior values + -- reset change flag, term stack, and local steps to prior values -- (keep the updated cache and added predicates, if any) - eqState $ modify $ \s -> s{changed = prior.changed, termStack = prior.termStack} + eqState $ modify $ \s -> + s + { changed = prior.changed + , termStack = prior.termStack + , localSteps = prior.localSteps + , localStepCount = prior.localStepCount + } pure result diff --git a/booster/unit-tests/Test/Booster/Pattern/ApplyEquations.hs b/booster/unit-tests/Test/Booster/Pattern/ApplyEquations.hs index b2353bae83..d8cb3ed9b2 100644 --- a/booster/unit-tests/Test/Booster/Pattern/ApplyEquations.hs +++ b/booster/unit-tests/Test/Booster/Pattern/ApplyEquations.hs @@ -12,9 +12,11 @@ module Test.Booster.Pattern.ApplyEquations ( test_simplify, test_simplifyPattern, test_simplifyConstraint, + test_localFixpoint, test_errors, ) where +import Control.Exception (finally) import Control.Monad.Logger (runNoLoggingT) import Data.ByteString (ByteString) import Data.Map (Map) @@ -25,6 +27,11 @@ import Test.Tasty.HUnit import Booster.Definition.Attributes.Base import Booster.Definition.Base +import Booster.GlobalState ( + EquationOptions (..), + readGlobalEquationOptions, + writeGlobalEquationOptions, + ) import Booster.Pattern.ApplyEquations import Booster.Pattern.Base import Booster.Pattern.Bool @@ -69,7 +76,9 @@ test_evaluateFunction = n `times` f = foldr (.) id (replicate n $ apply f) -- top-down evaluation: a single iteration is enough eval TopDown (subj 101) @?>>= Right (101 `times` con2 $ a) - -- bottom-up evaluation: `depth` many iterations + -- bottom-up evaluation: with the default in-place budget + -- of 0, each pass advances the chain by one application, + -- so the iteration limit caps the chain depth eval BottomUp (subj 100) @?>>= Right (100 `times` con2 $ a) isTooManyIterations =<< eval BottomUp (subj 101) , -- con3(f1(con2(a)), f1(con1(con2(b)))) => con3(con2(a), con2(con2(b))) @@ -226,6 +235,69 @@ test_simplifyConstraint = ns <- noSolver runNoLoggingT $ fst <$> simplifyConstraint testDefinition Nothing ns mempty mempty t +test_localFixpoint :: TestTree +test_localFixpoint = + -- must run after the iteration-limit test ("Recursive evaluation"), + -- whose outcome the temporary budget window below would change + -- if the two ran concurrently (the test binary runs tests in + -- parallel) + after AllFinish "Recursive evaluation" $ + testCase "In-place rewriting: deeper chains, loop detection, and bounded effort" $ do + -- at the default budget of 0 (restart-only evaluation), + -- node-level oscillation is detected by the whole-term + -- snapshots of the global passes + isLoop =<< evalWith loopDef (app f1 [app con1 [a]]) + -- explicit construction instead of record update: the + -- field names are shared with EquationConfig, making an + -- update ambiguous under DuplicateRecordFields + defaults <- readGlobalEquationOptions + writeGlobalEquationOptions + EquationOptions + { maxIterations = defaults.maxIterations + , maxRecursion = defaults.maxRecursion + , maxLocalSteps = 20 + } + budgetChecks defaults `finally` writeGlobalEquationOptions defaults + where + budgetChecks :: EquationOptions -> IO () + budgetChecks defaults = do + -- each pass advances a chain by the budget plus one + -- application, so depths beyond the restart-only limit of + -- maxIterations complete now (the chain rule produces its + -- redex inside the RHS, which the in-place recursion follows) + evalWith funDef (subj 101) >>= (@?= Right (101 `times` con2 $ start)) + -- node-level oscillation is detected per local step (cycle + -- shorter than the budget) + isLoop =<< evalWith loopDef (app f1 [app con1 [a]]) + -- the combined bound (passes times the per-chain budget plus + -- one application) still terminates evaluation with a partial + -- result: with 10 passes and a budget of 20, a chain of depth + -- 300 cannot finish + writeGlobalEquationOptions + EquationOptions + { maxIterations = 10 + , maxRecursion = defaults.maxRecursion + , maxLocalSteps = 20 + } + isTooMany =<< evalWith funDef (subj 300) + + subj depth = app f1 [iterate (apply con1) start !! depth] + start = app con2 [a] + n `times` f = foldr (.) id (replicate n $ apply f) + + a = var "A" someSort + apply f = app f . (: []) + + isLoop (Left (EquationLoop _)) = pure () + isLoop other = assertFailure $ "Expected an equation loop, got " <> show other + + isTooMany (Left (TooManyIterations _ _ _)) = pure () + isTooMany other = assertFailure $ "Expected an iteration-limit abort, got " <> show other + + evalWith def t = do + ns <- noSolver + runNoLoggingT $ fst <$> evaluateTerm BottomUp def Nothing ns mempty mempty t + test_errors :: TestTree test_errors = testGroup diff --git a/docs/2024-10-18-booster-description.md b/docs/2024-10-18-booster-description.md index c49a9794f1..d6d8e33ac7 100644 --- a/docs/2024-10-18-booster-description.md +++ b/docs/2024-10-18-booster-description.md @@ -229,11 +229,13 @@ The simplification code path is used at two different points of the execution, a Concrete function evaluation is handled by the LLVM backend and thus requires the semantics to be written in such a way, so as to be able to build both the kore definition used by the haskell backend, as well as the LLVM kore definition. The booster relies on the LLVM version of a semantics, compiled as a dynamic library, which is loaded when the server starts. During simplification, the term is traversed bottom up and any concrete sub-terms are sent to he LLVM backend to be evaluated. -The symbolic parts of a term are handled directly by the booster. Similarly to rewrite rules, function rules may also have side conditions. As a result, the simplifier may have to recurse into evaluating whether the side-condition of a function/simplification rule evaluates to true/false before successfully rewriting the term. At the moment, the evaluation strategy is hard-coded in the booster, and it is generally is the following: +The symbolic parts of a term are handled directly by the booster. Similarly to rewrite rules, function rules may also have side conditions. As a result, the simplifier may have to recurse into evaluating whether the side-condition of a function/simplification rule evaluates to true/false before successfully rewriting the term. The evaluation strategy is generally the following: - traverse the term top-down once and apply LLVM simplifications to the concrete sub-terms. It is essential to discover the concrete terms top-down and thus track the concreteness of sub-terms with attributes. By doing that, we make sure that we send a few big terms to the LLVM backend and not many small terms, thus minimising the overhead. -- traverser the term bottom-up, applying equations at every level until we make progress with at least one equation; +- traverse the term bottom-up, applying equations at every level until we make progress with at least one equation; - when applying equations, prefer functions, and only apply simplifications when function do not produce a result anymore, i.e. no functions apply. +These traversal passes restart from the top of the term after every pass that changed it, until a fixed point is reached or the pass limit (`--equation-max-iterations`, default 100) aborts evaluation. By default, a rewrite chain at a single position therefore advances by one equation application per pass. The `--equation-max-local-steps N` option generalises this: when an equation rewrites a sub-term, the result is re-simplified in place for up to `N` chained equation applications before deferring back to the restart loop, so chains of side-condition evaluations can complete at the sub-term where they arise instead of paying one global pass per step. `N = 0` (the default) is exactly the restart-only behaviour. The in-place step re-enters the bottom-up traversal of the rewritten result, fully evaluating the new subterm the rule's RHS produced (the descent stops early at sub-terms that are already cached as evaluated). Total effort remains bounded by the pass limit times the per-chain budget, and equation loops are detected both per local step (cycles shorter than the budget) and by whole-term snapshots across passes (cycles longer than the budget). + **TODO**: discuss the abort conditions of function vs. simplifications. In short, simplifications are optional, and functions are mandatory, i.e. we abort if a function equation produces an indeterminate match or a function condition is indeterminate. **TODO**: discuss the process of applying a single equation. **TODO**: discuss caching and how it's implemented in Booster.