From d550064c8e464f4ea9ed2b4892e889ad3c46f809 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 17 May 2026 10:02:51 -0500 Subject: [PATCH 1/7] Add Scan support to the assumption system scan_delegate propagates matrix-property facts through a Scan node during inference. lift_assumptions_into_scan is a scan_seqopt1 rewrite that re-asserts a Scan's outer-input assumptions on the inner graph, so assumption-driven rewrites of the loop body can fire. Wired in via assumptions/__init__.py. --- pytensor/assumptions/__init__.py | 1 + pytensor/assumptions/scan.py | 176 +++++++++++++++++++++++++++++++ tests/assumptions/test_scan.py | 145 +++++++++++++++++++++++++ 3 files changed, 322 insertions(+) create mode 100644 pytensor/assumptions/scan.py create mode 100644 tests/assumptions/test_scan.py diff --git a/pytensor/assumptions/__init__.py b/pytensor/assumptions/__init__.py index 38ca208867..103de78a89 100644 --- a/pytensor/assumptions/__init__.py +++ b/pytensor/assumptions/__init__.py @@ -9,6 +9,7 @@ import pytensor.assumptions.permutation import pytensor.assumptions.positive_definite import pytensor.assumptions.reshape +import pytensor.assumptions.scan import pytensor.assumptions.selection import pytensor.assumptions.shape import pytensor.assumptions.subtensor diff --git a/pytensor/assumptions/scan.py b/pytensor/assumptions/scan.py new file mode 100644 index 0000000000..b3c79ee443 --- /dev/null +++ b/pytensor/assumptions/scan.py @@ -0,0 +1,176 @@ +from copy import copy + +from pytensor.assumptions.core import ( + ALL_KEYS, + AssumptionFeature, + FactState, + check_assumption, + register_assumption, +) +from pytensor.assumptions.specify import SpecifyAssumptions +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.replace import clone_replace +from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter, node_rewriter +from pytensor.scan.op import Scan +from pytensor.scan.rewriting import scan_seqopt1 +from pytensor.tensor.subtensor import IncSubtensor + + +def _recurrent_init_fact(buffer_var, key, feature, fallback): + """Return the *key* fact of a recurrence's initial value. + + Scan stores a ``sit-sot`` recurrence's initial value as + ``SetSubtensor{:n_taps}(AllocEmpty(...), init)``. The buffer's own fact is + UNKNOWN -- the rows scan has yet to fill are uninitialised -- so read the + fact of the written ``init`` instead. Return *fallback* when the buffer is + not that shape. + """ + owner = buffer_var.owner + if ( + owner is not None + and isinstance(owner.op, IncSubtensor) + and owner.op.set_instead_of_inc + ): + return feature.get(owner.inputs[1], key) + return fallback + + +def scan_delegate(key, op, feature, fgraph, node, input_states): + """Infer *key* for a :class:`Scan`'s outer outputs by delegating into its inner graph. + + The outer-input facts seed the matching inner inputs; the inner graph is then + inferred and the inner-output facts are mapped back onto the outer outputs. + + For a non-recurrent (``nit-sot``) output the per-step inner output is stacked + along a new leading axis, so the fact carries straight through. For a + recurrent (``sit-sot``) output the carried state is seeded from the + recurrence's initial value, and the fact is kept only when the loop body + reproduces it -- a one-step fixpoint, exact because the per-key lattice + (``UNKNOWN`` < ``TRUE``) leaves no room to iterate. Multi-output recurrences + (``mit-mot``) are left UNKNOWN. + """ + mappings = op.get_oinp_iinp_iout_oout_mappings() + inner_inputs = op.inner_inputs + inner_outputs = op.inner_outputs + + # Inner inputs that carry recurrent state. Their outer input is the sit-sot + # buffer; the buffer's own fact is UNKNOWN, so seed from the initial value. + recurrent_iinps = { + iidx + for iidxs in mappings["inner_inp_from_outer_out"].values() + for iidx in iidxs + } + + inner_feature = AssumptionFeature() + op.fgraph.attach_feature(inner_feature) + try: + # Seed each inner input with the fact of the outer input feeding it. + # The seed is written straight into the cache: an inner input is a + # graph leaf, so this is the only way to inject a fact onto it. + for iinp_idx, iinp in enumerate(inner_inputs): + outer_iidx = mappings["outer_inp_from_inner_inp"][iinp_idx] + seed = input_states[outer_iidx] + if iinp_idx in recurrent_iinps: + seed = _recurrent_init_fact(node.inputs[outer_iidx], key, feature, seed) + if seed is not FactState.UNKNOWN: + inner_feature.cache[(iinp, key)] = seed + inner_feature._var_to_keys.setdefault(iinp, set()).add(key) + + out_states = [FactState.UNKNOWN] * len(node.outputs) + for outer_oidx in range(len(node.outputs)): + inner_oidxs = mappings["inner_out_from_outer_out"].get(outer_oidx, []) + if len(inner_oidxs) != 1: + # Multi-output (mit-mot) recurrences are left UNKNOWN for now. + continue + fact = inner_feature.get(inner_outputs[inner_oidxs[0]], key) + + if mappings["inner_inp_from_outer_out"].get(outer_oidx): + # Recurrent: the fact survives only if the loop body reproduces + # the initial value's fact. + outer_iidx = mappings["outer_inp_from_outer_out"][outer_oidx] + init_fact = _recurrent_init_fact( + node.inputs[outer_iidx], key, feature, input_states[outer_iidx] + ) + if fact is not init_fact: + fact = FactState.UNKNOWN + out_states[outer_oidx] = fact + return out_states + finally: + op.fgraph.remove_feature(inner_feature) + + +for _key in ALL_KEYS: + register_assumption(_key, Scan)(scan_delegate) + + +@node_rewriter([Scan]) +def lift_assumptions_into_scan(fgraph, node): + """Lift structural assumptions from a Scan's sequence and non-sequence inputs + onto the matching inner inputs. + + An inner input is a bare leaf, so an ``assume`` on the outer variable is + invisible to rewrites of the inner graph. This re-asserts it with a + :class:`SpecifyAssumptions` node inside, so those rewrites can fire -- e.g. + ``inv(X) @ y`` of a positive-definite :math:`X` specializes to a Cholesky + solve within the loop body. Matrix properties are invariant to batch axes, + so the assertion is valid for every per-step slice. + + Recurrent inner inputs are excluded: the loop body need not preserve the + initial value's properties, so the carried state cannot be assumed to keep + them past the first step. + """ + scan_op = node.op + inner_inputs = scan_op.inner_inputs + non_recurrent = set(scan_op.inner_seqs(inner_inputs)) + non_recurrent.update(scan_op.inner_non_seqs(inner_inputs)) + outer_from_inner = scan_op.get_oinp_iinp_iout_oout_mappings()[ + "outer_inp_from_inner_inp" + ] + + new_facts = {} + for inner_idx, inner_inp in enumerate(inner_inputs): + if inner_inp not in non_recurrent: + continue + clients = scan_op.fgraph.clients.get(inner_inp, ()) + if any( + not isinstance(client, str) and isinstance(client.op, SpecifyAssumptions) + for client, _ in clients + ): + # Already carries an inner assertion -- skip to avoid re-firing. + continue + outer_inp = node.inputs[outer_from_inner[inner_idx]] + facts = { + key.name: FactState.TRUE + for key in ALL_KEYS + if check_assumption(fgraph, outer_inp, key) + } + if facts: + new_facts[inner_inp] = facts + + if not new_facts: + return None + + # Rebuild the inner graph over fresh leaves, splicing the assertions on. + replace = {} + input_clones = [] + for inner_inp in inner_inputs: + clone = inner_inp.clone() + input_clones.append(clone) + facts = new_facts.get(inner_inp) + replace[inner_inp] = SpecifyAssumptions(facts)(clone) if facts else clone + new_inner_outputs = clone_replace(scan_op.inner_outputs, replace=replace) + + new_scan_op = copy(scan_op) + new_scan_op.fgraph = FunctionGraph(input_clones, new_inner_outputs, clone=False) + new_outs = new_scan_op.make_node(*node.inputs).outputs + copy_stack_trace(node.outputs, new_outs) + return new_outs + + +scan_seqopt1.register( + lift_assumptions_into_scan.__name__, + dfs_rewriter(lift_assumptions_into_scan, ignore_newtrees=True), + "fast_run", + "scan", + position=1, +) diff --git a/tests/assumptions/test_scan.py b/tests/assumptions/test_scan.py new file mode 100644 index 0000000000..80b6e2d796 --- /dev/null +++ b/tests/assumptions/test_scan.py @@ -0,0 +1,145 @@ +import pytensor.tensor as pt +from pytensor import function +from pytensor.assumptions import DIAGONAL, FactState +from pytensor.assumptions.specify import assume +from pytensor.scan.basic import scan +from pytensor.scan.op import Scan +from tests.assumptions.conftest import make_fgraph + + +def test_map_preserving_body_forwards_property(): + """map (nit-sot): ``s @ s`` of a diagonal ``s`` is diagonal, so the stack is.""" + x = pt.tensor3("seq") + seq = assume(x, diagonal=True) + out = scan(lambda s: s @ s, sequences=[seq], return_updates=False) + _, af = make_fgraph(out, inputs=[x]) + assert af.check(out, DIAGONAL) + + +def test_map_breaking_body_is_unknown(): + """map: ``exp`` does not preserve the zero pattern, so the property is lost.""" + x = pt.tensor3("seq") + seq = assume(x, diagonal=True) + out = scan(lambda s: pt.exp(s), sequences=[seq], return_updates=False) + _, af = make_fgraph(out, inputs=[x]) + assert af.get(out, DIAGONAL) == FactState.UNKNOWN + + +def test_non_sequence_tagged_input_forwards_property(): + """The delegate seeds non-sequence inner inputs; transposing a diagonal + non-sequence keeps it diagonal.""" + x = pt.matrix("m", shape=(4, 4)) + m = assume(x, diagonal=True) + out = scan(lambda mat: mat.T, non_sequences=[m], n_steps=5, return_updates=False) + _, af = make_fgraph(out, inputs=[x]) + assert af.check(out, DIAGONAL) + + +def test_recurrence_preserving_body_forwards_property(): + """recurrence (sit-sot): ``2 * prev`` keeps a diagonal carried state diagonal + at every step.""" + x = pt.matrix("init", shape=(3, 3)) + init = assume(x, diagonal=True) + out = scan( + lambda prev: 2.0 * prev, + outputs_info=[init], + n_steps=5, + return_updates=False, + ) + _, af = make_fgraph(out, inputs=[x]) + assert af.check(out, DIAGONAL) + + +def test_recurrence_breaking_body_is_unknown(): + """recurrence: adding a non-diagonal term each step breaks the property.""" + x = pt.matrix("init", shape=(3, 3)) + init = assume(x, diagonal=True) + out = scan( + lambda prev: prev + pt.ones_like(prev), + outputs_info=[init], + n_steps=5, + return_updates=False, + ) + _, af = make_fgraph(out, inputs=[x]) + assert af.get(out, DIAGONAL) == FactState.UNKNOWN + + +def test_mit_sot_multi_tap_recurrence_forwards_property(): + """2-tap recurrence: ``p1 + p2`` needs *both* taps diagonal, pinning that the + delegate seeds both from the multi-tap init buffer.""" + x = pt.tensor("init", shape=(2, 3, 3)) + init = assume(x, diagonal=True) + out = scan( + lambda p2, p1: p1 + p2, + outputs_info=[dict(initial=init, taps=[-2, -1])], + n_steps=5, + return_updates=False, + ) + _, af = make_fgraph(out, inputs=[x]) + assert af.check(out, DIAGONAL) + + +def test_map_and_recurrence_combined_forwards_property(): + """A scan with both a sequence and a recurrence -- ``prev @ s`` needs both the + carried state and the sequence element diagonal, pinning that the delegate + seeds the recurrent and sequence inner inputs from the right outer inputs.""" + sx = pt.tensor("seq", shape=(5, 3, 3)) + seq = assume(sx, diagonal=True) + ix = pt.matrix("init", shape=(3, 3)) + init = assume(ix, diagonal=True) + out = scan( + lambda s, prev: prev @ s, + sequences=[seq], + outputs_info=[init], + return_updates=False, + ) + _, af = make_fgraph(out, inputs=[sx, ix]) + assert af.check(out, DIAGONAL) + + +def test_multi_output_scan_maps_outputs_independently(): + """A scan with two outputs -- one diagonal-preserving, one breaking -- the + delegate maps each inner output to its own outer output independently.""" + x = pt.tensor3("seq") + seq = assume(x, diagonal=True) + doubled, exponential = scan( + lambda s: [s + s, pt.exp(s)], sequences=[seq], return_updates=False + ) + _, af = make_fgraph(doubled, exponential, inputs=[x]) + assert af.check(doubled, DIAGONAL) + assert af.get(exponential, DIAGONAL) == FactState.UNKNOWN + + +def test_nested_scan_forwards_property(): + """A scan whose body runs an inner scan -- the delegate recurses into the + inner Scan while still inferring the outer one.""" + x = pt.tensor3("seq") + seq = assume(x, diagonal=True) + + def body(s): + inner = scan( + lambda m: 2.0 * m, non_sequences=[s], n_steps=1, return_updates=False + ) + return inner[-1] + + out = scan(body, sequences=[seq], return_updates=False) + _, af = make_fgraph(out, inputs=[x]) + assert af.check(out, DIAGONAL) + + +def test_outer_assumption_lifts_into_scan_inner_graph(): + """`lift_assumptions_into_scan` re-asserts a sequence's assumption on the + inner input, so a rewrite of the inner graph sees it: ``inv(X) @ y`` of a + positive-definite ``X`` specializes to a Cholesky solve in the loop body.""" + Xs = pt.tensor("Xs", shape=(4, 3, 3)) + ys = pt.tensor("ys", shape=(4, 3, 3)) + out = scan( + lambda Xt, yt: pt.linalg.inv(Xt) @ yt, + sequences=[assume(Xs, positive_definite=True), ys], + return_updates=False, + ) + fn = function([Xs, ys], out) + [scan_node] = [n for n in fn.maker.fgraph.toposort() if isinstance(n.op, Scan)] + inner_ops = {type(n.op).__name__ for n in scan_node.op.fgraph.toposort()} + assert "MatrixInverse" not in inner_ops + assert "CholeskySolve" in inner_ops From c2a69263e60fe77734077853b9d0a75df2fc9c59 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 26 May 2026 11:08:24 -0400 Subject: [PATCH 2/7] Skip TypeCastingOp in scan_push_out_non_seq --- pytensor/scan/rewriting/push_out.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pytensor/scan/rewriting/push_out.py b/pytensor/scan/rewriting/push_out.py index 63450f5871..a3b5e808cc 100644 --- a/pytensor/scan/rewriting/push_out.py +++ b/pytensor/scan/rewriting/push_out.py @@ -18,7 +18,7 @@ import pytensor.scalar as ps import pytensor.tensor as pt -from pytensor.compile.ops import DeepCopyOp, ViewOp +from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.replace import clone_replace @@ -91,10 +91,9 @@ def add_to_replace(y): ) for x in nd.inputs ) - # We can (supposedly) do this because the assumption is that a - # `ViewOp` or `DeepCopyOp` will be just at the end of the - # function and not somewhere in the middle - and not isinstance(nd.op, ViewOp) + # Marker ops carry no computation; hoisting them may strip an + # inner-graph hint a later rewrite needs. + and not isinstance(nd.op, TypeCastingOp) and not isinstance(nd.op, DeepCopyOp) ): # We have a candidate node to remove from the inner-graph From 4c8c981f7585855ee51614d6d7d9c440c34af1d8 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 26 May 2026 11:08:37 -0400 Subject: [PATCH 3/7] Rebuild Scan via constructor in lift_assumptions_into_scan --- pytensor/assumptions/scan.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/pytensor/assumptions/scan.py b/pytensor/assumptions/scan.py index b3c79ee443..4ef09a68e0 100644 --- a/pytensor/assumptions/scan.py +++ b/pytensor/assumptions/scan.py @@ -1,5 +1,3 @@ -from copy import copy - from pytensor.assumptions.core import ( ALL_KEYS, AssumptionFeature, @@ -8,7 +6,6 @@ register_assumption, ) from pytensor.assumptions.specify import SpecifyAssumptions -from pytensor.graph.fg import FunctionGraph from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter, node_rewriter from pytensor.scan.op import Scan @@ -152,16 +149,24 @@ def lift_assumptions_into_scan(fgraph, node): # Rebuild the inner graph over fresh leaves, splicing the assertions on. replace = {} - input_clones = [] + new_inner_inputs = [] for inner_inp in inner_inputs: - clone = inner_inp.clone() - input_clones.append(clone) + dummy = inner_inp.type() + new_inner_inputs.append(dummy) facts = new_facts.get(inner_inp) - replace[inner_inp] = SpecifyAssumptions(facts)(clone) if facts else clone + replace[inner_inp] = SpecifyAssumptions(facts)(dummy) if facts else dummy new_inner_outputs = clone_replace(scan_op.inner_outputs, replace=replace) - new_scan_op = copy(scan_op) - new_scan_op.fgraph = FunctionGraph(input_clones, new_inner_outputs, clone=False) + new_scan_op = Scan( + new_inner_inputs, + new_inner_outputs, + scan_op.info, + mode=scan_op.mode, + profile=scan_op.profile, + truncate_gradient=scan_op.truncate_gradient, + name=scan_op.name, + allow_gc=scan_op.allow_gc, + ) new_outs = new_scan_op.make_node(*node.inputs).outputs copy_stack_trace(node.outputs, new_outs) return new_outs From feff12b57aeb70a8282e636064170e44ea442fc4 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 26 May 2026 11:08:55 -0400 Subject: [PATCH 4/7] Test lift_assumptions_into_scan on a non-sequence input --- tests/assumptions/test_scan.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/assumptions/test_scan.py b/tests/assumptions/test_scan.py index 80b6e2d796..c897c99be2 100644 --- a/tests/assumptions/test_scan.py +++ b/tests/assumptions/test_scan.py @@ -143,3 +143,22 @@ def test_outer_assumption_lifts_into_scan_inner_graph(): inner_ops = {type(n.op).__name__ for n in scan_node.op.fgraph.toposort()} assert "MatrixInverse" not in inner_ops assert "CholeskySolve" in inner_ops + + +def test_non_sequence_assumption_lifts_into_scan_inner_graph(): + """Mirror of the sequence test but for a non-sequence: a positive-definite + ``X`` passed via ``non_sequences`` lifts into the inner graph too, so the + per-step ``inv(X) @ y_t`` still specializes to a Cholesky solve.""" + X = pt.matrix("X", shape=(3, 3)) + ys = pt.tensor("ys", shape=(4, 3, 3)) + out = scan( + lambda yt, X: pt.linalg.inv(X) @ yt, + sequences=[ys], + non_sequences=[assume(X, positive_definite=True)], + return_updates=False, + ) + fn = function([X, ys], out) + [scan_node] = [n for n in fn.maker.fgraph.toposort() if isinstance(n.op, Scan)] + inner_ops = {type(n.op).__name__ for n in scan_node.op.fgraph.toposort()} + assert "MatrixInverse" not in inner_ops + assert "CholeskySolve" in inner_ops From 697e6585220dcf57fae7945f107d70080a4fd7c8 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 26 May 2026 13:42:18 -0400 Subject: [PATCH 5/7] gradient test --- tests/assumptions/test_scan.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/assumptions/test_scan.py b/tests/assumptions/test_scan.py index c897c99be2..c974926723 100644 --- a/tests/assumptions/test_scan.py +++ b/tests/assumptions/test_scan.py @@ -162,3 +162,22 @@ def test_non_sequence_assumption_lifts_into_scan_inner_graph(): inner_ops = {type(n.op).__name__ for n in scan_node.op.fgraph.toposort()} assert "MatrixInverse" not in inner_ops assert "CholeskySolve" in inner_ops + + +def test_scan_grad_compiles_with_recurrence_assumption(): + """Backward AD flips a sit-sot forward output into a sequence input of the + backward scan. With ``diagonal=True`` on the recurrence's init the + gradient must still compile and stay numerically correct.""" + import numpy as np + + init_raw = pt.matrix("init", shape=(3, 3)) + init = assume(init_raw, diagonal=True) + out = scan( + lambda prev: 2.0 * prev, + outputs_info=[init], + n_steps=4, + return_updates=False, + ) + g = pt.grad(out[-1].sum(), init_raw) + fn = function([init_raw], g) + np.testing.assert_allclose(fn(np.eye(3)), 16.0 * np.ones((3, 3))) From 64c5f74e902f8ef20a8aa661be7b7e955f652f5a Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Wed, 27 May 2026 22:00:50 -0500 Subject: [PATCH 6/7] rename lift -> push --- pytensor/assumptions/scan.py | 8 ++++---- tests/assumptions/test_scan.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytensor/assumptions/scan.py b/pytensor/assumptions/scan.py index 4ef09a68e0..76e0950f58 100644 --- a/pytensor/assumptions/scan.py +++ b/pytensor/assumptions/scan.py @@ -101,8 +101,8 @@ def scan_delegate(key, op, feature, fgraph, node, input_states): @node_rewriter([Scan]) -def lift_assumptions_into_scan(fgraph, node): - """Lift structural assumptions from a Scan's sequence and non-sequence inputs +def push_assumptions_into_scan(fgraph, node): + """Push structural assumptions from a Scan's sequence and non-sequence inputs onto the matching inner inputs. An inner input is a bare leaf, so an ``assume`` on the outer variable is @@ -173,8 +173,8 @@ def lift_assumptions_into_scan(fgraph, node): scan_seqopt1.register( - lift_assumptions_into_scan.__name__, - dfs_rewriter(lift_assumptions_into_scan, ignore_newtrees=True), + push_assumptions_into_scan.__name__, + dfs_rewriter(push_assumptions_into_scan, ignore_newtrees=True), "fast_run", "scan", position=1, diff --git a/tests/assumptions/test_scan.py b/tests/assumptions/test_scan.py index c974926723..24e699e282 100644 --- a/tests/assumptions/test_scan.py +++ b/tests/assumptions/test_scan.py @@ -127,8 +127,8 @@ def body(s): assert af.check(out, DIAGONAL) -def test_outer_assumption_lifts_into_scan_inner_graph(): - """`lift_assumptions_into_scan` re-asserts a sequence's assumption on the +def test_outer_assumption_pushed_into_scan_inner_graph(): + """`push_assumptions_into_scan` re-asserts a sequence's assumption on the inner input, so a rewrite of the inner graph sees it: ``inv(X) @ y`` of a positive-definite ``X`` specializes to a Cholesky solve in the loop body.""" Xs = pt.tensor("Xs", shape=(4, 3, 3)) @@ -145,9 +145,9 @@ def test_outer_assumption_lifts_into_scan_inner_graph(): assert "CholeskySolve" in inner_ops -def test_non_sequence_assumption_lifts_into_scan_inner_graph(): +def test_non_sequence_assumption_pushed_into_scan_inner_graph(): """Mirror of the sequence test but for a non-sequence: a positive-definite - ``X`` passed via ``non_sequences`` lifts into the inner graph too, so the + ``X`` passed via ``non_sequences`` is pushed into the inner graph too, so the per-step ``inv(X) @ y_t`` still specializes to a Cholesky solve.""" X = pt.matrix("X", shape=(3, 3)) ys = pt.tensor("ys", shape=(4, 3, 3)) From 6f576e7358a9d7cf4461a477ba993d96981850c5 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Wed, 27 May 2026 22:16:15 -0500 Subject: [PATCH 7/7] skip rewrite test when mode = FAST_COMPILE --- tests/assumptions/test_scan.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/assumptions/test_scan.py b/tests/assumptions/test_scan.py index 24e699e282..6f1a017261 100644 --- a/tests/assumptions/test_scan.py +++ b/tests/assumptions/test_scan.py @@ -1,7 +1,10 @@ +import pytest + import pytensor.tensor as pt from pytensor import function from pytensor.assumptions import DIAGONAL, FactState from pytensor.assumptions.specify import assume +from pytensor.configdefaults import config from pytensor.scan.basic import scan from pytensor.scan.op import Scan from tests.assumptions.conftest import make_fgraph @@ -127,6 +130,10 @@ def body(s): assert af.check(out, DIAGONAL) +@pytest.mark.skipif( + config.mode == "FAST_COMPILE", + reason="assumption push-in and the inner-graph solve specialization are fast_run rewrites", +) def test_outer_assumption_pushed_into_scan_inner_graph(): """`push_assumptions_into_scan` re-asserts a sequence's assumption on the inner input, so a rewrite of the inner graph sees it: ``inv(X) @ y`` of a @@ -145,6 +152,10 @@ def test_outer_assumption_pushed_into_scan_inner_graph(): assert "CholeskySolve" in inner_ops +@pytest.mark.skipif( + config.mode == "FAST_COMPILE", + reason="assumption push-in and the inner-graph solve specialization are fast_run rewrites", +) def test_non_sequence_assumption_pushed_into_scan_inner_graph(): """Mirror of the sequence test but for a non-sequence: a positive-definite ``X`` passed via ``non_sequences`` is pushed into the inner graph too, so the