diff --git a/pytensor/assumptions/__init__.py b/pytensor/assumptions/__init__.py index 38ca208867..103de78a89 100644 --- a/pytensor/assumptions/__init__.py +++ b/pytensor/assumptions/__init__.py @@ -9,6 +9,7 @@ import pytensor.assumptions.permutation import pytensor.assumptions.positive_definite import pytensor.assumptions.reshape +import pytensor.assumptions.scan import pytensor.assumptions.selection import pytensor.assumptions.shape import pytensor.assumptions.subtensor diff --git a/pytensor/assumptions/scan.py b/pytensor/assumptions/scan.py new file mode 100644 index 0000000000..76e0950f58 --- /dev/null +++ b/pytensor/assumptions/scan.py @@ -0,0 +1,181 @@ +from pytensor.assumptions.core import ( + ALL_KEYS, + AssumptionFeature, + FactState, + check_assumption, + register_assumption, +) +from pytensor.assumptions.specify import SpecifyAssumptions +from pytensor.graph.replace import clone_replace +from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter, node_rewriter +from pytensor.scan.op import Scan +from pytensor.scan.rewriting import scan_seqopt1 +from pytensor.tensor.subtensor import IncSubtensor + + +def _recurrent_init_fact(buffer_var, key, feature, fallback): + """Return the *key* fact of a recurrence's initial value. + + Scan stores a ``sit-sot`` recurrence's initial value as + ``SetSubtensor{:n_taps}(AllocEmpty(...), init)``. The buffer's own fact is + UNKNOWN -- the rows scan has yet to fill are uninitialised -- so read the + fact of the written ``init`` instead. Return *fallback* when the buffer is + not that shape. + """ + owner = buffer_var.owner + if ( + owner is not None + and isinstance(owner.op, IncSubtensor) + and owner.op.set_instead_of_inc + ): + return feature.get(owner.inputs[1], key) + return fallback + + +def scan_delegate(key, op, feature, fgraph, node, input_states): + """Infer *key* for a :class:`Scan`'s outer outputs by delegating into its inner graph. + + The outer-input facts seed the matching inner inputs; the inner graph is then + inferred and the inner-output facts are mapped back onto the outer outputs. + + For a non-recurrent (``nit-sot``) output the per-step inner output is stacked + along a new leading axis, so the fact carries straight through. For a + recurrent (``sit-sot``) output the carried state is seeded from the + recurrence's initial value, and the fact is kept only when the loop body + reproduces it -- a one-step fixpoint, exact because the per-key lattice + (``UNKNOWN`` < ``TRUE``) leaves no room to iterate. Multi-output recurrences + (``mit-mot``) are left UNKNOWN. + """ + mappings = op.get_oinp_iinp_iout_oout_mappings() + inner_inputs = op.inner_inputs + inner_outputs = op.inner_outputs + + # Inner inputs that carry recurrent state. Their outer input is the sit-sot + # buffer; the buffer's own fact is UNKNOWN, so seed from the initial value. + recurrent_iinps = { + iidx + for iidxs in mappings["inner_inp_from_outer_out"].values() + for iidx in iidxs + } + + inner_feature = AssumptionFeature() + op.fgraph.attach_feature(inner_feature) + try: + # Seed each inner input with the fact of the outer input feeding it. + # The seed is written straight into the cache: an inner input is a + # graph leaf, so this is the only way to inject a fact onto it. + for iinp_idx, iinp in enumerate(inner_inputs): + outer_iidx = mappings["outer_inp_from_inner_inp"][iinp_idx] + seed = input_states[outer_iidx] + if iinp_idx in recurrent_iinps: + seed = _recurrent_init_fact(node.inputs[outer_iidx], key, feature, seed) + if seed is not FactState.UNKNOWN: + inner_feature.cache[(iinp, key)] = seed + inner_feature._var_to_keys.setdefault(iinp, set()).add(key) + + out_states = [FactState.UNKNOWN] * len(node.outputs) + for outer_oidx in range(len(node.outputs)): + inner_oidxs = mappings["inner_out_from_outer_out"].get(outer_oidx, []) + if len(inner_oidxs) != 1: + # Multi-output (mit-mot) recurrences are left UNKNOWN for now. + continue + fact = inner_feature.get(inner_outputs[inner_oidxs[0]], key) + + if mappings["inner_inp_from_outer_out"].get(outer_oidx): + # Recurrent: the fact survives only if the loop body reproduces + # the initial value's fact. + outer_iidx = mappings["outer_inp_from_outer_out"][outer_oidx] + init_fact = _recurrent_init_fact( + node.inputs[outer_iidx], key, feature, input_states[outer_iidx] + ) + if fact is not init_fact: + fact = FactState.UNKNOWN + out_states[outer_oidx] = fact + return out_states + finally: + op.fgraph.remove_feature(inner_feature) + + +for _key in ALL_KEYS: + register_assumption(_key, Scan)(scan_delegate) + + +@node_rewriter([Scan]) +def push_assumptions_into_scan(fgraph, node): + """Push structural assumptions from a Scan's sequence and non-sequence inputs + onto the matching inner inputs. + + An inner input is a bare leaf, so an ``assume`` on the outer variable is + invisible to rewrites of the inner graph. This re-asserts it with a + :class:`SpecifyAssumptions` node inside, so those rewrites can fire -- e.g. + ``inv(X) @ y`` of a positive-definite :math:`X` specializes to a Cholesky + solve within the loop body. Matrix properties are invariant to batch axes, + so the assertion is valid for every per-step slice. + + Recurrent inner inputs are excluded: the loop body need not preserve the + initial value's properties, so the carried state cannot be assumed to keep + them past the first step. + """ + scan_op = node.op + inner_inputs = scan_op.inner_inputs + non_recurrent = set(scan_op.inner_seqs(inner_inputs)) + non_recurrent.update(scan_op.inner_non_seqs(inner_inputs)) + outer_from_inner = scan_op.get_oinp_iinp_iout_oout_mappings()[ + "outer_inp_from_inner_inp" + ] + + new_facts = {} + for inner_idx, inner_inp in enumerate(inner_inputs): + if inner_inp not in non_recurrent: + continue + clients = scan_op.fgraph.clients.get(inner_inp, ()) + if any( + not isinstance(client, str) and isinstance(client.op, SpecifyAssumptions) + for client, _ in clients + ): + # Already carries an inner assertion -- skip to avoid re-firing. + continue + outer_inp = node.inputs[outer_from_inner[inner_idx]] + facts = { + key.name: FactState.TRUE + for key in ALL_KEYS + if check_assumption(fgraph, outer_inp, key) + } + if facts: + new_facts[inner_inp] = facts + + if not new_facts: + return None + + # Rebuild the inner graph over fresh leaves, splicing the assertions on. + replace = {} + new_inner_inputs = [] + for inner_inp in inner_inputs: + dummy = inner_inp.type() + new_inner_inputs.append(dummy) + facts = new_facts.get(inner_inp) + replace[inner_inp] = SpecifyAssumptions(facts)(dummy) if facts else dummy + new_inner_outputs = clone_replace(scan_op.inner_outputs, replace=replace) + + new_scan_op = Scan( + new_inner_inputs, + new_inner_outputs, + scan_op.info, + mode=scan_op.mode, + profile=scan_op.profile, + truncate_gradient=scan_op.truncate_gradient, + name=scan_op.name, + allow_gc=scan_op.allow_gc, + ) + new_outs = new_scan_op.make_node(*node.inputs).outputs + copy_stack_trace(node.outputs, new_outs) + return new_outs + + +scan_seqopt1.register( + push_assumptions_into_scan.__name__, + dfs_rewriter(push_assumptions_into_scan, ignore_newtrees=True), + "fast_run", + "scan", + position=1, +) diff --git a/pytensor/scan/rewriting/push_out.py b/pytensor/scan/rewriting/push_out.py index 63450f5871..a3b5e808cc 100644 --- a/pytensor/scan/rewriting/push_out.py +++ b/pytensor/scan/rewriting/push_out.py @@ -18,7 +18,7 @@ import pytensor.scalar as ps import pytensor.tensor as pt -from pytensor.compile.ops import DeepCopyOp, ViewOp +from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.replace import clone_replace @@ -91,10 +91,9 @@ def add_to_replace(y): ) for x in nd.inputs ) - # We can (supposedly) do this because the assumption is that a - # `ViewOp` or `DeepCopyOp` will be just at the end of the - # function and not somewhere in the middle - and not isinstance(nd.op, ViewOp) + # Marker ops carry no computation; hoisting them may strip an + # inner-graph hint a later rewrite needs. + and not isinstance(nd.op, TypeCastingOp) and not isinstance(nd.op, DeepCopyOp) ): # We have a candidate node to remove from the inner-graph diff --git a/tests/assumptions/test_scan.py b/tests/assumptions/test_scan.py new file mode 100644 index 0000000000..6f1a017261 --- /dev/null +++ b/tests/assumptions/test_scan.py @@ -0,0 +1,194 @@ +import pytest + +import pytensor.tensor as pt +from pytensor import function +from pytensor.assumptions import DIAGONAL, FactState +from pytensor.assumptions.specify import assume +from pytensor.configdefaults import config +from pytensor.scan.basic import scan +from pytensor.scan.op import Scan +from tests.assumptions.conftest import make_fgraph + + +def test_map_preserving_body_forwards_property(): + """map (nit-sot): ``s @ s`` of a diagonal ``s`` is diagonal, so the stack is.""" + x = pt.tensor3("seq") + seq = assume(x, diagonal=True) + out = scan(lambda s: s @ s, sequences=[seq], return_updates=False) + _, af = make_fgraph(out, inputs=[x]) + assert af.check(out, DIAGONAL) + + +def test_map_breaking_body_is_unknown(): + """map: ``exp`` does not preserve the zero pattern, so the property is lost.""" + x = pt.tensor3("seq") + seq = assume(x, diagonal=True) + out = scan(lambda s: pt.exp(s), sequences=[seq], return_updates=False) + _, af = make_fgraph(out, inputs=[x]) + assert af.get(out, DIAGONAL) == FactState.UNKNOWN + + +def test_non_sequence_tagged_input_forwards_property(): + """The delegate seeds non-sequence inner inputs; transposing a diagonal + non-sequence keeps it diagonal.""" + x = pt.matrix("m", shape=(4, 4)) + m = assume(x, diagonal=True) + out = scan(lambda mat: mat.T, non_sequences=[m], n_steps=5, return_updates=False) + _, af = make_fgraph(out, inputs=[x]) + assert af.check(out, DIAGONAL) + + +def test_recurrence_preserving_body_forwards_property(): + """recurrence (sit-sot): ``2 * prev`` keeps a diagonal carried state diagonal + at every step.""" + x = pt.matrix("init", shape=(3, 3)) + init = assume(x, diagonal=True) + out = scan( + lambda prev: 2.0 * prev, + outputs_info=[init], + n_steps=5, + return_updates=False, + ) + _, af = make_fgraph(out, inputs=[x]) + assert af.check(out, DIAGONAL) + + +def test_recurrence_breaking_body_is_unknown(): + """recurrence: adding a non-diagonal term each step breaks the property.""" + x = pt.matrix("init", shape=(3, 3)) + init = assume(x, diagonal=True) + out = scan( + lambda prev: prev + pt.ones_like(prev), + outputs_info=[init], + n_steps=5, + return_updates=False, + ) + _, af = make_fgraph(out, inputs=[x]) + assert af.get(out, DIAGONAL) == FactState.UNKNOWN + + +def test_mit_sot_multi_tap_recurrence_forwards_property(): + """2-tap recurrence: ``p1 + p2`` needs *both* taps diagonal, pinning that the + delegate seeds both from the multi-tap init buffer.""" + x = pt.tensor("init", shape=(2, 3, 3)) + init = assume(x, diagonal=True) + out = scan( + lambda p2, p1: p1 + p2, + outputs_info=[dict(initial=init, taps=[-2, -1])], + n_steps=5, + return_updates=False, + ) + _, af = make_fgraph(out, inputs=[x]) + assert af.check(out, DIAGONAL) + + +def test_map_and_recurrence_combined_forwards_property(): + """A scan with both a sequence and a recurrence -- ``prev @ s`` needs both the + carried state and the sequence element diagonal, pinning that the delegate + seeds the recurrent and sequence inner inputs from the right outer inputs.""" + sx = pt.tensor("seq", shape=(5, 3, 3)) + seq = assume(sx, diagonal=True) + ix = pt.matrix("init", shape=(3, 3)) + init = assume(ix, diagonal=True) + out = scan( + lambda s, prev: prev @ s, + sequences=[seq], + outputs_info=[init], + return_updates=False, + ) + _, af = make_fgraph(out, inputs=[sx, ix]) + assert af.check(out, DIAGONAL) + + +def test_multi_output_scan_maps_outputs_independently(): + """A scan with two outputs -- one diagonal-preserving, one breaking -- the + delegate maps each inner output to its own outer output independently.""" + x = pt.tensor3("seq") + seq = assume(x, diagonal=True) + doubled, exponential = scan( + lambda s: [s + s, pt.exp(s)], sequences=[seq], return_updates=False + ) + _, af = make_fgraph(doubled, exponential, inputs=[x]) + assert af.check(doubled, DIAGONAL) + assert af.get(exponential, DIAGONAL) == FactState.UNKNOWN + + +def test_nested_scan_forwards_property(): + """A scan whose body runs an inner scan -- the delegate recurses into the + inner Scan while still inferring the outer one.""" + x = pt.tensor3("seq") + seq = assume(x, diagonal=True) + + def body(s): + inner = scan( + lambda m: 2.0 * m, non_sequences=[s], n_steps=1, return_updates=False + ) + return inner[-1] + + out = scan(body, sequences=[seq], return_updates=False) + _, af = make_fgraph(out, inputs=[x]) + assert af.check(out, DIAGONAL) + + +@pytest.mark.skipif( + config.mode == "FAST_COMPILE", + reason="assumption push-in and the inner-graph solve specialization are fast_run rewrites", +) +def test_outer_assumption_pushed_into_scan_inner_graph(): + """`push_assumptions_into_scan` re-asserts a sequence's assumption on the + inner input, so a rewrite of the inner graph sees it: ``inv(X) @ y`` of a + positive-definite ``X`` specializes to a Cholesky solve in the loop body.""" + Xs = pt.tensor("Xs", shape=(4, 3, 3)) + ys = pt.tensor("ys", shape=(4, 3, 3)) + out = scan( + lambda Xt, yt: pt.linalg.inv(Xt) @ yt, + sequences=[assume(Xs, positive_definite=True), ys], + return_updates=False, + ) + fn = function([Xs, ys], out) + [scan_node] = [n for n in fn.maker.fgraph.toposort() if isinstance(n.op, Scan)] + inner_ops = {type(n.op).__name__ for n in scan_node.op.fgraph.toposort()} + assert "MatrixInverse" not in inner_ops + assert "CholeskySolve" in inner_ops + + +@pytest.mark.skipif( + config.mode == "FAST_COMPILE", + reason="assumption push-in and the inner-graph solve specialization are fast_run rewrites", +) +def test_non_sequence_assumption_pushed_into_scan_inner_graph(): + """Mirror of the sequence test but for a non-sequence: a positive-definite + ``X`` passed via ``non_sequences`` is pushed into the inner graph too, so the + per-step ``inv(X) @ y_t`` still specializes to a Cholesky solve.""" + X = pt.matrix("X", shape=(3, 3)) + ys = pt.tensor("ys", shape=(4, 3, 3)) + out = scan( + lambda yt, X: pt.linalg.inv(X) @ yt, + sequences=[ys], + non_sequences=[assume(X, positive_definite=True)], + return_updates=False, + ) + fn = function([X, ys], out) + [scan_node] = [n for n in fn.maker.fgraph.toposort() if isinstance(n.op, Scan)] + inner_ops = {type(n.op).__name__ for n in scan_node.op.fgraph.toposort()} + assert "MatrixInverse" not in inner_ops + assert "CholeskySolve" in inner_ops + + +def test_scan_grad_compiles_with_recurrence_assumption(): + """Backward AD flips a sit-sot forward output into a sequence input of the + backward scan. With ``diagonal=True`` on the recurrence's init the + gradient must still compile and stay numerically correct.""" + import numpy as np + + init_raw = pt.matrix("init", shape=(3, 3)) + init = assume(init_raw, diagonal=True) + out = scan( + lambda prev: 2.0 * prev, + outputs_info=[init], + n_steps=4, + return_updates=False, + ) + g = pt.grad(out[-1].sum(), init_raw) + fn = function([init_raw], g) + np.testing.assert_allclose(fn(np.eye(3)), 16.0 * np.ones((3, 3)))