From 0299fa6b5429e126b465edc64a003dfb3e872d09 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 11 May 2026 13:57:04 +0200 Subject: [PATCH 1/2] Replace unification-based PatternNodeRewriter with specialized matcher Drops logical-unification, kanren, etuples and cons as required deps; they move to the `pytensor[kanren]` extra (still pulled in by `pytensor[tests]`). Importing pytensor.graph.rewriting.kanren registers PatternVar via Var.register so external isinstance(x, unification.Var) / isvar(x) still hold. 10-20x faster match in microbenchmarks. --- doc/extending/graph_rewriting.rst | 229 +--------- doc/extending/index.rst | 1 + doc/extending/unification_kanren.rst | 250 +++++++++++ pyproject.toml | 13 +- pytensor/graph/rewriting/basic.py | 60 ++- pytensor/graph/rewriting/kanren.py | 244 ++++++++++- pytensor/graph/rewriting/unify.py | 552 +++++++++++-------------- tests/benchmarks/test_pattern_match.py | 91 ++++ tests/graph/rewriting/test_kanren.py | 31 +- tests/graph/rewriting/test_unify.py | 445 +++++++------------- 10 files changed, 1025 insertions(+), 891 deletions(-) create mode 100644 doc/extending/unification_kanren.rst create mode 100644 tests/benchmarks/test_pattern_match.py diff --git a/doc/extending/graph_rewriting.rst b/doc/extending/graph_rewriting.rst index 0bb4c9fa7f..0ab6c8d477 100644 --- a/doc/extending/graph_rewriting.rst +++ b/doc/extending/graph_rewriting.rst @@ -324,229 +324,14 @@ or :class:`PatternNodeRewriter`, it is highly recommended to use them. .. _unification: -Unification and reification -=========================== - -The :class:`PatternNodeRewriter` class uses `unification and reification -`_ to implement a -more succinct and reusable form of "pattern matching and replacement". -In general, *use of the unification and reification tools is preferable when -a rewrite's matching and replacement are non-trivial*, so we will briefly explain -them in the following. - -PyTensor's unification and reification tools are provided by the -`logical-unification `_ package. -The basic tools are :func:`unify`, :func:`reify`, and :class:`var`. The class :class:`var` -construct *logic variables*, which represent the elements to be unified/matched, :func:`unify` -performs the "matching", and :func:`reify` performs the "replacements". - -See :mod:`unification`'s documentation for an introduction to unification and reification. - -In order to use :func:`unify` and :func:`reify` with PyTensor graphs, we need an intermediate -structure that will allow us to represent PyTensor graphs that contain :class:`var`\s, because -PyTensor :class:`Op`\s and :class:`Apply` nodes will not accept these foreign objects as inputs. - -:class:`PatternNodeRewriter` uses Python ``tuple``\s to effectively represent :class:`Apply` nodes and -``str``\s to represent logic variables (i.e. :class:`var`\s in the :mod:`unification` library). -Behind the scenes, these ``tuple``\s are converted to a ``tuple`` subclass called :class:`ExpressionTuple`\s, -which behave just like normal ``tuple``\s except for some special caching features that allow for easy -evaluation and caching. These :class:`ExpressionTuple`\s are provided by the -`etuples `_ library. - -Here is an illustration of all the above components used together: - ->>> from unification import unify, reify, var ->>> from etuples import etuple ->>> y_lv = var() # Create a logic variable ->>> y_lv -~_1 ->>> s = unify(add(x, y), etuple(add, x, y_lv)) ->>> s -{~_1: y} - -In this example, :func:`unify` matched the PyTensor graph in the first argument with the "pattern" -given by the :func:`etuple` in the second. The result is a ``dict`` mapping logic variables to -the objects to which they were successfully unified. When a :func:`unify` doesn't succeed, it will -return ``False``. - -:func:`reify` uses ``dict``\s like the kind produced by :func:`unify` to replace -logic variables within structures: - ->>> res = reify(etuple(add, y_lv, y_lv), s) ->>> res -e(, y, y) - -Since :class:`ExpressionTuple`\s can be evaluated, we can produce a complete PyTensor graph from these -results as follows: - ->>> res.evaled_obj -add.0 ->>> pytensor.dprint(res.evaled_obj) -add [id A] '' - |y [id B] - |y [id B] - - -Because :class:`ExpressionTuple`\s effectively model `S-expressions -`_, they can be used with the `cons -`_ package to unify and reify -graphs structurally. - -Let's say we want to match graphs that use the :class:`add`\ :class:`Op` but could have a -varying number of arguments: - ->>> from cons import cons ->>> op_lv = var() ->>> args_lv = var() ->>> s = unify(cons(op_lv, args_lv), add(x, y)) ->>> s -{~_2: , ~_3: e(x, y)} ->>> s = unify(cons(op_lv, args_lv), add(x, y, z)) ->>> s -{~_2: , ~_3: e(x, y, z)} - -From here, we can check ``s[op_lv] == add`` to confirm that we have the correct :class:`Op` and -proceed with our rewrite. - ->>> res = reify(cons(mul, args_lv), s) ->>> res -e(, x, y, z) ->>> pytensor.dprint(res.evaled_obj) -mul [id A] '' - |x [id B] - |y [id C] - |z [id D] - - -.. _miniKanren_rewrites: - -miniKanren -========== - -Given that unification and reification are fully implemented for PyTensor objects via the :mod:`unificiation` package, -the `kanren `_ package can be used with PyTensor graphs, as well. -:mod:`kanren` implements the `miniKanren `_ domain-specific language for relational programming. - -Refer to the links above for a proper introduction to miniKanren, but suffice it to say that -miniKanren orchestrates the unification and reification operations described in :ref:`unification`, and -it does so in the context of relational operators (e.g. equations like :math:`x + x = 2 x`). -This means that a relation that--say--represents :math:`x + x = 2 x` can be -utilized in both directions. - -Currently, the node rewriter :class:`KanrenRelationSub` provides a means of -turning :mod:`kanren` relations into :class:`NodeRewriter`\s; however, -:mod:`kanren` can always be used directly from within a custom :class:`Rewriter`, so -:class:`KanrenRelationSub` is not necessary. - -The following is an example that distributes dot products across additions. - -.. code:: - - import pytensor - import pytensor.tensor as pt - from pytensor.graph.rewriting.kanren import KanrenRelationSub - from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter - from pytensor.graph.rewriting.utils import rewrite_graph - from pytensor.tensor.math import _dot - from etuples import etuple - from kanren import conso, eq, fact, heado, tailo - from kanren.assoccomm import assoc_flatten, associative - from kanren.core import lall - from kanren.graph import mapo - from unification import vars as lvars - - - # Make the graph pretty printing results a little more readable - pytensor.pprint.assign( - _dot, pytensor.printing.OperatorPrinter("@", -1, "left") - ) - - # Tell `kanren` that `add` is associative - fact(associative, pt.add) - - - def dot_distributeo(in_lv, out_lv): - """A `kanren` goal constructor relation for the relation ``A.dot(a + b ...) == A.dot(a) + A.dot(b) ...``.""" - A_lv, add_term_lv, add_cdr_lv, dot_cdr_lv, add_flat_lv = lvars(5) - - return lall( - # Make sure the input is a `_dot` - eq(in_lv, etuple(_dot, A_lv, add_term_lv)), - # Make sure the term being `_dot`ed is an `add` - heado(pt.add, add_term_lv), - # Flatten the associative pairings of `add` operations - assoc_flatten(add_term_lv, add_flat_lv), - # Get the flattened `add` arguments - tailo(add_cdr_lv, add_flat_lv), - # Add all the `_dot`ed arguments and set the output - conso(pt.add, dot_cdr_lv, out_lv), - # Apply the `_dot` to all the flattened `add` arguments - mapo(lambda x, y: conso(_dot, etuple(A_lv, x), y), add_cdr_lv, dot_cdr_lv), - ) - - - dot_distribute_rewrite = EquilibriumGraphRewriter([KanrenRelationSub(dot_distributeo)], max_use_ratio=10) - - -Below, we apply `dot_distribute_rewrite` to a few example graphs. First we create simple test graph: - ->>> x_at = pt.vector("x") ->>> y_at = pt.vector("y") ->>> A_at = pt.matrix("A") ->>> test_at = A_pt.dot(x_at + y_at) ->>> print(pytensor.pprint(test_at)) -(A @ (x + y)) - -Next we apply the rewrite to the graph: - ->>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False) ->>> print(pytensor.pprint(res)) -((A @ x) + (A @ y)) - -We see that the dot product has been distributed, as desired. Now, let's try a -few more test cases: - ->>> z_at = pt.vector("z") ->>> w_at = pt.vector("w") ->>> test_at = A_pt.dot((x_at + y_at) + (z_at + w_at)) ->>> print(pytensor.pprint(test_at)) -(A @ ((x + y) + (z + w))) ->>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False) ->>> print(pytensor.pprint(res)) -(((A @ x) + (A @ y)) + ((A @ z) + (A @ w))) - ->>> B_at = pt.matrix("B") ->>> w_at = pt.vector("w") ->>> test_at = A_pt.dot(x_at + (y_at + B_pt.dot(z_at + w_at))) ->>> print(pytensor.pprint(test_at)) -(A @ (x + (y + ((B @ z) + (B @ w))))) ->>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False) ->>> print(pytensor.pprint(res)) -((A @ x) + ((A @ y) + ((A @ (B @ z)) + (A @ (B @ w))))) - - -This example demonstrates how non-trivial matching and replacement logic can -be neatly expressed in miniKanren's DSL, but it doesn't quite demonstrate miniKanren's -relational properties. - -To do that, we will create another :class:`Rewriter` that simply reverses the arguments -to the relation :func:`dot_distributeo` and apply it to the distributed result in ``res``: - ->>> dot_gather_rewrite = EquilibriumGraphRewriter([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10) ->>> rev_res = rewrite_graph(res, include=[], custom_rewrite=dot_gather_rewrite, clone=False) ->>> print(pytensor.pprint(rev_res)) -(A @ (x + (y + (B @ (z + w))))) - -As we can see, the :mod:`kanren` relation works both ways, just like the underlying -mathematical relation does. - -miniKanren relations can be used to explore rewrites of graphs in sophisticated -ways. It also provides a framework that more directly maps to the mathematical -identities that drive graph rewrites. For some simple examples of relational graph rewriting -in :mod:`kanren` see `here `_. For a -high-level overview of miniKanren's use as a tool for symbolic computation see -`"miniKanren as a Tool for Symbolic Computation in Python" `_. +Unification, reification and miniKanren +======================================= +For non-trivial pattern matching and replacement, PyTensor optionally exposes +:func:`unify` / :func:`reify` over :mod:`logical-unification` and miniKanren +relations via :class:`KanrenRelationSub`. These rely on a set of optional +packages (:mod:`logical-unification`, :mod:`kanren`, :mod:`etuples`, +:mod:`cons`) and are documented separately in :ref:`unification_kanren`. .. _optdb: diff --git a/doc/extending/index.rst b/doc/extending/index.rst index 3f9e40ae17..cf9657d52f 100644 --- a/doc/extending/index.rst +++ b/doc/extending/index.rst @@ -34,6 +34,7 @@ with PyTensor itself. graphstructures graph_rewriting + unification_kanren op creating_an_op creating_a_c_op diff --git a/doc/extending/unification_kanren.rst b/doc/extending/unification_kanren.rst new file mode 100644 index 0000000000..91e8575a8f --- /dev/null +++ b/doc/extending/unification_kanren.rst @@ -0,0 +1,250 @@ +.. _unification_kanren: + +Unification, reification and miniKanren (optional) +================================================== + +.. note:: + + The :mod:`logical-unification`, :mod:`kanren`, :mod:`etuples` and :mod:`cons` + packages are **optional** dependencies. PyTensor's built-in + :class:`PatternNodeRewriter` ships its own specialized matcher and does not + require any of them. Install the extra explicitly to use the tools described + on this page:: + + pip install pytensor[kanren] + + or:: + + pip install logical-unification kanren etuples cons + + Importing :mod:`pytensor.graph.rewriting.kanren` registers the dispatchers + that let :func:`unification.unify` / :func:`unification.reify` and miniKanren + relations walk PyTensor :class:`Apply` nodes, :class:`Op`\s and + :class:`Type`\s. It also registers :class:`PatternVar` with the + :class:`unification.Var` ABC, so ``isinstance(x, unification.Var)`` and + :func:`isvar(x)` keep returning ``True`` for PyTensor pattern variables. + Make sure that import happens once before calling :func:`unify`, + :func:`reify` or :func:`kanren.run` on PyTensor graphs:: + + import pytensor.graph.rewriting.kanren # noqa: F401 -- registers dispatchers + +.. _unification: + +Unification and reification +--------------------------- + +`Unification and reification +`_ implement a +more succinct and reusable form of "pattern matching and replacement". +*Use of the unification and reification tools is preferable when +a rewrite's matching and replacement are non-trivial*, so we will briefly explain +them in the following. + +PyTensor's unification and reification tools are provided by the +`logical-unification `_ package. +The basic tools are :func:`unify`, :func:`reify`, and :class:`var`. The class :class:`var` +construct *logic variables*, which represent the elements to be unified/matched, :func:`unify` +performs the "matching", and :func:`reify` performs the "replacements". + +See :mod:`unification`'s documentation for an introduction to unification and reification. + +In order to use :func:`unify` and :func:`reify` with PyTensor graphs, we need an intermediate +structure that will allow us to represent PyTensor graphs that contain :class:`var`\s, because +PyTensor :class:`Op`\s and :class:`Apply` nodes will not accept these foreign objects as inputs. +The `etuples `_ library provides the +:class:`ExpressionTuple` (tuple-like, with caching for evaluation) that fills this role. + +Here is an illustration of all the above components used together: + +>>> import pytensor.graph.rewriting.kanren # noqa: F401 -- registers dispatchers +>>> from unification import unify, reify, var +>>> from etuples import etuple +>>> y_lv = var() # Create a logic variable +>>> y_lv +~_1 +>>> s = unify(add(x, y), etuple(add, x, y_lv)) +>>> s +{~_1: y} + +In this example, :func:`unify` matched the PyTensor graph in the first argument with the "pattern" +given by the :func:`etuple` in the second. The result is a ``dict`` mapping logic variables to +the objects to which they were successfully unified. When a :func:`unify` doesn't succeed, it will +return ``False``. + +:func:`reify` uses ``dict``\s like the kind produced by :func:`unify` to replace +logic variables within structures: + +>>> res = reify(etuple(add, y_lv, y_lv), s) +>>> res +e(, y, y) + +Since :class:`ExpressionTuple`\s can be evaluated, we can produce a complete PyTensor graph from these +results as follows: + +>>> res.evaled_obj +add.0 +>>> pytensor.dprint(res.evaled_obj) +add [id A] '' + |y [id B] + |y [id B] + + +Because :class:`ExpressionTuple`\s effectively model `S-expressions +`_, they can be used with the `cons +`_ package to unify and reify +graphs structurally. + +Let's say we want to match graphs that use the :class:`add`\ :class:`Op` but could have a +varying number of arguments: + +>>> from cons import cons +>>> op_lv = var() +>>> args_lv = var() +>>> s = unify(cons(op_lv, args_lv), add(x, y)) +>>> s +{~_2: , ~_3: e(x, y)} +>>> s = unify(cons(op_lv, args_lv), add(x, y, z)) +>>> s +{~_2: , ~_3: e(x, y, z)} + +From here, we can check ``s[op_lv] == add`` to confirm that we have the correct :class:`Op` and +proceed with our rewrite. + +>>> res = reify(cons(mul, args_lv), s) +>>> res +e(, x, y, z) +>>> pytensor.dprint(res.evaled_obj) +mul [id A] '' + |x [id B] + |y [id C] + |z [id D] + + +.. _miniKanren_rewrites: + +miniKanren +---------- + +Given that unification and reification are fully implemented for PyTensor objects via the :mod:`unification` package, +the `kanren `_ package can be used with PyTensor graphs, as well. +:mod:`kanren` implements the `miniKanren `_ domain-specific language for relational programming. + +Refer to the links above for a proper introduction to miniKanren, but suffice it to say that +miniKanren orchestrates the unification and reification operations described above, and +it does so in the context of relational operators (e.g. equations like :math:`x + x = 2 x`). +This means that a relation that--say--represents :math:`x + x = 2 x` can be +utilized in both directions. + +Currently, the node rewriter :class:`KanrenRelationSub` provides a means of +turning :mod:`kanren` relations into :class:`NodeRewriter`\s; however, +:mod:`kanren` can always be used directly from within a custom :class:`Rewriter`, so +:class:`KanrenRelationSub` is not necessary. + +The following is an example that distributes dot products across additions. + +.. code:: + + import pytensor + import pytensor.tensor as pt + from pytensor.graph.rewriting.kanren import KanrenRelationSub + from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter + from pytensor.graph.rewriting.utils import rewrite_graph + from pytensor.tensor.math import _dot + from etuples import etuple + from kanren import conso, eq, fact, heado, tailo + from kanren.assoccomm import assoc_flatten, associative + from kanren.core import lall + from kanren.graph import mapo + from unification import vars as lvars + + + # Make the graph pretty printing results a little more readable + pytensor.pprint.assign( + _dot, pytensor.printing.OperatorPrinter("@", -1, "left") + ) + + # Tell `kanren` that `add` is associative + fact(associative, pt.add) + + + def dot_distributeo(in_lv, out_lv): + """A `kanren` goal constructor relation for the relation ``A.dot(a + b ...) == A.dot(a) + A.dot(b) ...``.""" + A_lv, add_term_lv, add_cdr_lv, dot_cdr_lv, add_flat_lv = lvars(5) + + return lall( + # Make sure the input is a `_dot` + eq(in_lv, etuple(_dot, A_lv, add_term_lv)), + # Make sure the term being `_dot`ed is an `add` + heado(pt.add, add_term_lv), + # Flatten the associative pairings of `add` operations + assoc_flatten(add_term_lv, add_flat_lv), + # Get the flattened `add` arguments + tailo(add_cdr_lv, add_flat_lv), + # Add all the `_dot`ed arguments and set the output + conso(pt.add, dot_cdr_lv, out_lv), + # Apply the `_dot` to all the flattened `add` arguments + mapo(lambda x, y: conso(_dot, etuple(A_lv, x), y), add_cdr_lv, dot_cdr_lv), + ) + + + dot_distribute_rewrite = EquilibriumGraphRewriter([KanrenRelationSub(dot_distributeo)], max_use_ratio=10) + + +Below, we apply `dot_distribute_rewrite` to a few example graphs. First we create simple test graph: + +>>> x_at = pt.vector("x") +>>> y_at = pt.vector("y") +>>> A_at = pt.matrix("A") +>>> test_at = A_pt.dot(x_at + y_at) +>>> print(pytensor.pprint(test_at)) +(A @ (x + y)) + +Next we apply the rewrite to the graph: + +>>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False) +>>> print(pytensor.pprint(res)) +((A @ x) + (A @ y)) + +We see that the dot product has been distributed, as desired. Now, let's try a +few more test cases: + +>>> z_at = pt.vector("z") +>>> w_at = pt.vector("w") +>>> test_at = A_pt.dot((x_at + y_at) + (z_at + w_at)) +>>> print(pytensor.pprint(test_at)) +(A @ ((x + y) + (z + w))) +>>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False) +>>> print(pytensor.pprint(res)) +(((A @ x) + (A @ y)) + ((A @ z) + (A @ w))) + +>>> B_at = pt.matrix("B") +>>> w_at = pt.vector("w") +>>> test_at = A_pt.dot(x_at + (y_at + B_pt.dot(z_at + w_at))) +>>> print(pytensor.pprint(test_at)) +(A @ (x + (y + ((B @ z) + (B @ w))))) +>>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False) +>>> print(pytensor.pprint(res)) +((A @ x) + ((A @ y) + ((A @ (B @ z)) + (A @ (B @ w))))) + + +This example demonstrates how non-trivial matching and replacement logic can +be neatly expressed in miniKanren's DSL, but it doesn't quite demonstrate miniKanren's +relational properties. + +To do that, we will create another :class:`Rewriter` that simply reverses the arguments +to the relation :func:`dot_distributeo` and apply it to the distributed result in ``res``: + +>>> dot_gather_rewrite = EquilibriumGraphRewriter([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10) +>>> rev_res = rewrite_graph(res, include=[], custom_rewrite=dot_gather_rewrite, clone=False) +>>> print(pytensor.pprint(rev_res)) +(A @ (x + (y + (B @ (z + w))))) + +As we can see, the :mod:`kanren` relation works both ways, just like the underlying +mathematical relation does. + +miniKanren relations can be used to explore rewrites of graphs in sophisticated +ways. It also provides a framework that more directly maps to the mathematical +identities that drive graph rewrites. For some simple examples of relational graph rewriting +in :mod:`kanren` see `here `_. For a +high-level overview of miniKanren's use as a tool for symbolic computation see +`"miniKanren as a Tool for Symbolic Computation in Python" `_. diff --git a/pyproject.toml b/pyproject.toml index 1edd5034ec..3d3d2170c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,10 +52,6 @@ dependencies = [ "numpy>=2.0", "numba>=0.58,<=0.65.1", "filelock>=3.15", - "etuples", - "logical-unification", - "miniKanren", - "cons", ] [project.urls] @@ -78,10 +74,17 @@ tests = [ "pytest-benchmark", "pytest-mock", "pytest-sphinx", + "pytensor[kanren]", ] rtd = ["sphinx>=5.1.0,<6", "pygments", "pydot"] jax = ["jax", "jaxlib"] numba = ["numba>=0.58,<=0.65.1", "llvmlite"] +kanren = [ + "etuples", + "logical-unification", + "miniKanren", + "cons", +] [tool.setuptools.packages.find] include = ["pytensor*"] @@ -119,7 +122,7 @@ versionfile_build = "pytensor/_version.py" tag_prefix = "rel-" [tool.pytest.ini_options] -addopts = "--durations=50 --doctest-modules --ignore=pytensor/link --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/ipython.py" +addopts = "--durations=50 --doctest-modules --ignore=pytensor/link --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/ipython.py --ignore=pytensor/graph/rewriting/kanren.py" testpaths = ["pytensor/", "tests/"] xfail_strict = true filterwarnings =[ diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index e39465f416..df44c1587c 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -26,12 +26,18 @@ from pytensor.graph.features import AlreadyThere, Feature from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.op import Op -from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars +from pytensor.graph.rewriting.unify import ( + OpPattern, + PatternNode, + PatternVar, + convert_strs_to_vars, + match_pattern, + reify_pattern, +) from pytensor.graph.traversal import ( apply_ancestors, applys_between, toposort, - vars_between, ) from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.misc.ordered_set import OrderedSet @@ -1551,9 +1557,14 @@ def __init__( frequent `Op`, which will prevent the rewrite from being tried as often. """ - var_map: dict[str, Var] = {} + var_map: dict[str, PatternVar] = {} self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map) self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map) + if not isinstance(self.in_pattern, PatternNode): + raise TypeError( + "The in_pattern must be a tuple starting with an Op or OpPattern; " + f"got {in_pattern!r} of type {type(in_pattern)}" + ) self.values_eq_approx = values_eq_approx self.allow_cast = allow_cast self.allow_multiple_clients = allow_multiple_clients @@ -1565,27 +1576,15 @@ def __init__( raise ValueError("Custom `tracks` requires `get_nodes` to be provided.") self._tracks = tracks else: - if isinstance(in_pattern, list | tuple): - op = self.in_pattern[0] - elif isinstance(in_pattern, dict): - op = self.in_pattern["pattern"][0] - else: - raise TypeError( - f"The in_pattern must be a sequence or a dict, but got {in_pattern} of type {type(in_pattern)}" - ) + op = self.in_pattern.op_match if isinstance(op, Op): self._tracks = [op] - elif isinstance(op, type) and issubclass(op, Op): - raise ValueError( - f"The in_pattern starts with an Op class {op}, not an instance.\n" - "You can use pytensor.graph.unify.OpPattern instead if you want to match instances of a class." - ) elif isinstance(op, OpPattern): self._tracks = [op.op_type] else: raise ValueError( f"The in_pattern must start with a specific Op or an OpPattern instance. " - f"Got {op}, with type {type(op)}." + f"Got {op!r}, with type {type(op)}." ) def tracks(self): @@ -1597,9 +1596,6 @@ def transform(self, fgraph, node, enforce_tracks: bool = False, get_nodes=True): If it does, it constructs ``out_pattern`` and performs the replacement. """ - from etuples.core import ExpressionTuple - from unification import reify, unify - if get_nodes and self.get_nodes is not None: for real_node in self.get_nodes(fgraph, node): ret = self.transform(fgraph, real_node, get_nodes=False) @@ -1610,21 +1606,15 @@ def transform(self, fgraph, node, enforce_tracks: bool = False, get_nodes=True): # PatternNodeRewriter doesn't support replacing multi-output nodes return False - s = unify(self.in_pattern, node.out, {}) - - if s is False: + fgraph_clients = None if self.allow_multiple_clients else fgraph.clients + s = match_pattern( + self.in_pattern, + node, + fgraph_clients=fgraph_clients, + ) + if s is None: return False - if not self.allow_multiple_clients: - input_vars = set(s.values()) - clients = fgraph.clients - if any( - len(clients[v]) > 1 - for v in vars_between(input_vars, node.inputs) - if v not in input_vars - ): - return False - if callable(self.out_pattern): # token is the variable name used in the original pattern ret = self.out_pattern(fgraph, node, {k.token: v for k, v in s.items()}) @@ -1636,9 +1626,7 @@ def transform(self, fgraph, node, enforce_tracks: bool = False, get_nodes=True): f"The output of the PatternNodeRewriter callable must be a variable got {ret} of type {type(ret)}." ) else: - ret = reify(self.out_pattern, s) - if isinstance(ret, ExpressionTuple): - ret = ret.evaled_obj + ret = reify_pattern(self.out_pattern, s) if self.values_eq_approx: ret.tag.values_eq_approx = self.values_eq_approx diff --git a/pytensor/graph/rewriting/kanren.py b/pytensor/graph/rewriting/kanren.py index 8b45d85da8..f4a5335873 100644 --- a/pytensor/graph/rewriting/kanren.py +++ b/pytensor/graph/rewriting/kanren.py @@ -1,13 +1,243 @@ -from collections.abc import Callable, Iterator +from collections.abc import Callable, Iterator, Mapping -from etuples.core import ExpressionTuple -from kanren import run -from unification import var -from unification.variable import Var -from pytensor.graph.basic import Apply, Variable +try: + import numpy as np + from cons.core import ConsError, _car, _cdr + from etuples import apply, etuple, etuplize + from etuples.core import ExpressionTuple + from kanren import run + from unification import var + from unification.core import _unify, assoc + from unification.utils import transitive_get as walk + from unification.variable import Var, isvar +except ImportError as _err: + raise ImportError( + "pytensor.graph.rewriting.kanren requires the optional packages " + "'logical-unification', 'kanren', 'etuples' and 'cons'. " + "Install them with `pip install logical-unification kanren etuples cons`." + ) from _err + +from pytensor.graph.basic import Apply, Constant, Variable +from pytensor.graph.op import Op from pytensor.graph.rewriting.basic import NodeRewriter -from pytensor.graph.rewriting.unify import eval_if_etuple +from pytensor.graph.rewriting.unify import ( + ConstrainedVar, + OpPattern, + PatternVar, +) +from pytensor.graph.type import Type + + +Var.register(PatternVar) + + +def eval_if_etuple(x): + if isinstance(x, ExpressionTuple): + return x.evaled_obj + return x + + +def car_Variable(x): + if x.owner: + return x.owner.op + else: + raise ConsError("Not a cons pair.") + + +_car.add((Variable,), car_Variable) + + +def cdr_Variable(x): + if x.owner: + x_e = etuple(_car(x), *x.owner.inputs, evaled_obj=x) + else: + raise ConsError("Not a cons pair.") + + return x_e[1:] + + +_cdr.add((Variable,), cdr_Variable) + + +def car_Op(x): + if hasattr(x, "__props__"): + return type(x) + + raise ConsError("Not a cons pair.") + + +_car.add((Op,), car_Op) + + +def cdr_Op(x): + if not hasattr(x, "__props__"): + raise ConsError("Not a cons pair.") + + x_e = etuple( + _car(x), + *[getattr(x, p) for p in getattr(x, "__props__", ())], + evaled_obj=x, + ) + return x_e[1:] + + +_cdr.add((Op,), cdr_Op) + + +def car_Type(x): + return type(x) + + +_car.add((Type,), car_Type) + + +def cdr_Type(x): + x_e = etuple( + _car(x), *[getattr(x, p) for p in getattr(x, "__props__", ())], evaled_obj=x + ) + return x_e[1:] + + +_cdr.add((Type,), cdr_Type) + + +def apply_Op_ExpressionTuple(op, etuple_arg): + res = op.make_node(*etuple_arg) + + try: + return res.default_output() + except ValueError: + return res.outputs + + +apply.add((Op, ExpressionTuple), apply_Op_ExpressionTuple) + + +def _unify_etuplize_first_arg(u, v, s): + try: + u_et = etuplize(u, shallow=True) + yield _unify(u_et, v, s) + except TypeError: + yield False + return + + +_unify.add((Op, ExpressionTuple, Mapping), _unify_etuplize_first_arg) +_unify.add( + (ExpressionTuple, Op, Mapping), lambda u, v, s: _unify_etuplize_first_arg(v, u, s) +) + +_unify.add((Type, ExpressionTuple, Mapping), _unify_etuplize_first_arg) +_unify.add( + (ExpressionTuple, Type, Mapping), lambda u, v, s: _unify_etuplize_first_arg(v, u, s) +) + + +def _unify_Variable_Variable(u, v, s): + # Avoid converting to `etuple`s, when possible + if u == v: + yield s + return + + if u.owner is None and v.owner is None: + yield False + return + + yield _unify( + etuplize(u, shallow=True) if u.owner else u, + etuplize(v, shallow=True) if v.owner else v, + s, + ) + + +_unify.add((Variable, Variable, Mapping), _unify_Variable_Variable) + + +def _unify_Constant_Constant(u, v, s): + # XXX: This ignores shape and type differences. It's only implemented this + # way for backward compatibility + if np.array_equiv(u.data, v.data): + yield s + else: + yield False + + +_unify.add((Constant, Constant, Mapping), _unify_Constant_Constant) + + +def _unify_Variable_ExpressionTuple(u, v, s): + # `Constant`s are "atomic" + if u.owner is None: + yield False + return + + yield _unify(etuplize(u, shallow=True), v, s) + + +_unify.add( + (Variable, ExpressionTuple, Mapping), + _unify_Variable_ExpressionTuple, +) +_unify.add( + (ExpressionTuple, Variable, Mapping), + lambda u, v, s: _unify_Variable_ExpressionTuple(v, u, s), +) + + +@_unify.register(ConstrainedVar, (ConstrainedVar, Var, object), Mapping) +def _unify_ConstrainedVar_object(u, v, s): + u_w = walk(u, s) + + if isvar(v): + v_w = walk(v, s) + else: + v_w = v + + if u_w == v_w: + yield s + elif isvar(u_w): + if ( + not isvar(v_w) + and isinstance(u_w, ConstrainedVar) + and u_w.constraint is not None + and not u_w.constraint(eval_if_etuple(v_w)) + ): + yield False + return + yield assoc(s, u_w, v_w) + elif isvar(v_w): + if ( + not isvar(u_w) + and isinstance(v_w, ConstrainedVar) + and v_w.constraint is not None + and not v_w.constraint(eval_if_etuple(u_w)) + ): + yield False + return + yield assoc(s, v_w, u_w) + else: + yield _unify(u_w, v_w, s) + + +_unify.add((object, ConstrainedVar, Mapping), _unify_ConstrainedVar_object) + + +def _unify_parametrized_op(v, u, s): + if not isinstance(v, u.op_type): + yield False + return + for parameter_key, parameter_pattern in u.parameters: + parameter_value = getattr(v, parameter_key) + new_s = yield _unify(parameter_value, parameter_pattern, s) + if new_s is False: + yield False + return + s = new_s + yield s + + +_unify.add((Op, OpPattern, Mapping), _unify_parametrized_op) class KanrenRelationSub(NodeRewriter): diff --git a/pytensor/graph/rewriting/unify.py b/pytensor/graph/rewriting/unify.py index 195dd18564..c6f8d96da6 100644 --- a/pytensor/graph/rewriting/unify.py +++ b/pytensor/graph/rewriting/unify.py @@ -1,260 +1,16 @@ -""" -If you have two expressions containing unification variables, these expressions -can be "unified" if there exists an assignment to all unification variables -such that the two expressions are equal. - -For instance, [5, A, B] and [A, C, 9] can be unified if A=C=5 and B=9, -yielding [5, 5, 9]. -[5, [A, B]] and [A, [1, 2]] cannot be unified because there is no value for A -that satisfies the constraints. That's useful for pattern matching. - -""" - -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from numbers import Number from types import UnionType from typing import Any, TypeAlias import numpy as np -from cons.core import ConsError, _car, _cdr -from etuples import apply, etuple, etuplize -from etuples.core import ExpressionTuple -from unification.core import _unify, assoc -from unification.utils import transitive_get as walk -from unification.variable import Var, isvar, var from pytensor.graph.basic import Constant, Variable from pytensor.graph.op import Op -from pytensor.graph.type import Type - - -def eval_if_etuple(x): - if isinstance(x, ExpressionTuple): - return x.evaled_obj - return x - - -class ConstrainedVar(Var): - """A logical variable with a constraint. - - These will unify with other `Var`s regardless of the constraints. - """ - - __slots__ = ("constraint",) - - def __new__(cls, constraint, token=None, prefix=""): - if token is None: - token = f"{prefix}_{Var._id}" - Var._id += 1 - - key = (token, constraint) - obj = cls._refs.get(key, None) - - if obj is None: - obj = object.__new__(cls) - obj.token = token - obj.constraint = constraint - cls._refs[key] = obj - - return obj - - def __eq__(self, other): - if type(self) is type(other): - return self.token is other.token and self.constraint == other.constraint - return NotImplemented - - def __hash__(self): - return hash((type(self), self.token, self.constraint)) - - def __str__(self): - return f"~{self.token} [{self.constraint}]" - - def __repr__(self): - return f"{type(self).__name__}({self.constraint!r}, {self.token})" - - -def car_Variable(x): - if x.owner: - return x.owner.op - else: - raise ConsError("Not a cons pair.") - - -_car.add((Variable,), car_Variable) - - -def cdr_Variable(x): - if x.owner: - x_e = etuple(_car(x), *x.owner.inputs, evaled_obj=x) - else: - raise ConsError("Not a cons pair.") - - return x_e[1:] - - -_cdr.add((Variable,), cdr_Variable) - - -def car_Op(x): - if hasattr(x, "__props__"): - return type(x) - - raise ConsError("Not a cons pair.") - - -_car.add((Op,), car_Op) - - -def cdr_Op(x): - if not hasattr(x, "__props__"): - raise ConsError("Not a cons pair.") - - x_e = etuple( - _car(x), - *[getattr(x, p) for p in getattr(x, "__props__", ())], - evaled_obj=x, - ) - return x_e[1:] - - -_cdr.add((Op,), cdr_Op) - - -def car_Type(x): - return type(x) - - -_car.add((Type,), car_Type) - - -def cdr_Type(x): - x_e = etuple( - _car(x), *[getattr(x, p) for p in getattr(x, "__props__", ())], evaled_obj=x - ) - return x_e[1:] - - -_cdr.add((Type,), cdr_Type) - - -def apply_Op_ExpressionTuple(op, etuple_arg): - res = op.make_node(*etuple_arg) - - try: - return res.default_output() - except ValueError: - return res.outputs - - -apply.add((Op, ExpressionTuple), apply_Op_ExpressionTuple) - - -def _unify_etuplize_first_arg(u, v, s): - try: - u_et = etuplize(u, shallow=True) - yield _unify(u_et, v, s) - except TypeError: - yield False - return - - -_unify.add((Op, ExpressionTuple, Mapping), _unify_etuplize_first_arg) -_unify.add( - (ExpressionTuple, Op, Mapping), lambda u, v, s: _unify_etuplize_first_arg(v, u, s) -) - -_unify.add((Type, ExpressionTuple, Mapping), _unify_etuplize_first_arg) -_unify.add( - (ExpressionTuple, Type, Mapping), lambda u, v, s: _unify_etuplize_first_arg(v, u, s) -) - - -def _unify_Variable_Variable(u, v, s): - # Avoid converting to `etuple`s, when possible - if u == v: - yield s - return - - if u.owner is None and v.owner is None: - yield False - return - - yield _unify( - etuplize(u, shallow=True) if u.owner else u, - etuplize(v, shallow=True) if v.owner else v, - s, - ) - - -_unify.add((Variable, Variable, Mapping), _unify_Variable_Variable) - - -def _unify_Constant_Constant(u, v, s): - # XXX: This ignores shape and type differences. It's only implemented this - # way for backward compatibility - if np.array_equiv(u.data, v.data): - yield s - else: - yield False - - -_unify.add((Constant, Constant, Mapping), _unify_Constant_Constant) - - -def _unify_Variable_ExpressionTuple(u, v, s): - # `Constant`s are "atomic" - if u.owner is None: - yield False - return - - yield _unify(etuplize(u, shallow=True), v, s) - - -_unify.add( - (Variable, ExpressionTuple, Mapping), - _unify_Variable_ExpressionTuple, -) -_unify.add( - (ExpressionTuple, Variable, Mapping), - lambda u, v, s: _unify_Variable_ExpressionTuple(v, u, s), -) - -@_unify.register(ConstrainedVar, (ConstrainedVar, Var, object), Mapping) -def _unify_ConstrainedVar_object(u, v, s): - u_w = walk(u, s) - if isvar(v): - v_w = walk(v, s) - else: - v_w = v - - if u_w == v_w: - yield s - elif isvar(u_w): - if ( - not isvar(v_w) - and isinstance(u_w, ConstrainedVar) - and not u_w.constraint(eval_if_etuple(v_w)) - ): - yield False - return - yield assoc(s, u_w, v_w) - elif isvar(v_w): - if ( - not isvar(u_w) - and isinstance(v_w, ConstrainedVar) - and not v_w.constraint(eval_if_etuple(u_w)) - ): - yield False - return - yield assoc(s, v_w, u_w) - else: - yield _unify(u_w, v_w, s) - - -_unify.add((object, ConstrainedVar, Mapping), _unify_ConstrainedVar_object) +OpPatternOpTypeType: TypeAlias = type[Op] | tuple[type[Op], ...] | UnionType @dataclass(frozen=True) @@ -262,9 +18,6 @@ class LiteralString: value: str -OpPatternOpTypeType: TypeAlias = type[Op] | tuple[type[Op], ...] | UnionType - - @dataclass(unsafe_hash=True) class OpPattern: """Class that can be unified with Op instances of a given type (or instance) and parameters. @@ -340,50 +93,10 @@ def output_fn(fgraph, node, s): out_pattern=(OpPattern(CAReduce, scalar_op="scalar_op", axis=None), "x"), ) - - OpPattern can also be used with `unification.unify` to match Ops with specific parameters. - This is used by PatternNodeRewriter but can also be used directly. - - .. testcode:: - - from unification import var, unify - from etuples import etuple - - import pytensor.tensor as pt - from pytensor.graph.rewriting.unify import OpPattern - from pytensor.tensor.blockwise import Blockwise - from pytensor.tensor.linalg.solvers.general import Solve - - A = var("A") - b = var("b") - pattern = etuple( - OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a="gen")), - A, - b, - ) - - A_pt = pt.tensor3("A") - b_pt = pt.tensor3("b") - out1 = pt.linalg.solve(A_pt, b_pt) - out2 = pt.linalg.solve(A_pt, b_pt, assume_a="pos") - - assert unify(pattern, out1) == {A: A_pt, b: b_pt} - assert unify(pattern, out2) is False - - assume_a = var("assume_a") - pattern = etuple( - OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a=assume_a)), - A, - b, - ) - assert unify(pattern, out1) == {A: A_pt, b: b_pt, assume_a: "gen"} - assert unify(pattern, out2) == {A: A_pt, b: b_pt, assume_a: "pos"} - - """ op_type: OpPatternOpTypeType - parameters: tuple[tuple[str, Any]] + parameters: tuple[tuple[str, Any], ...] def __init__( self, @@ -406,12 +119,12 @@ def __init__( self.op_type = op_type self.parameters = parameters # type: ignore[assignment] - def match_op(self, op: Op): + def match_op(self, op: Op) -> bool: if not isinstance(op, self.op_type): return False return self.match_parameters(op) - def match_parameters(self, op): + def match_parameters(self, op: Op) -> bool: # This is used by methods that already check the op_type is satisfied # Some methods may index on the op_type and know in advance the op is matched # Also recursive calls to OpPattern.match_parameters do the op check outside to exit early (see below) @@ -426,6 +139,8 @@ def match_parameters(self, op): # Skip if there are no parameters if param.parameters and not param.match_parameters(sub_op): return False + elif isinstance(param, PatternVar): + continue elif getattr(op, key) != param: return False return True @@ -434,27 +149,53 @@ def __str__(self): return f"OpPattern({self.op_type}, {', '.join(f'{k}={v}' for k, v in self.parameters)})" -def _unify_parametrized_op(v: Op, u: OpPattern, s: Mapping): - if not isinstance(v, u.op_type): - yield False - return - for parameter_key, parameter_pattern in u.parameters: - parameter_value = getattr(v, parameter_key) - new_s = yield _unify(parameter_value, parameter_pattern, s) - if new_s is False: - yield False - return - s = new_s - yield s +class PatternVar: + __slots__ = ("constraint", "token") + + def __init__(self, token: str, constraint: Callable[[Any], bool] | None = None): + self.token = token + self.constraint = constraint + def __repr__(self) -> str: + if self.constraint is not None: + return f"PatternVar({self.token!r}, constraint=...)" + return f"PatternVar({self.token!r})" -_unify.add((Op, OpPattern, Mapping), _unify_parametrized_op) + +Var: TypeAlias = PatternVar + + +class ConstrainedVar(PatternVar): + """A logical variable with a constraint.""" + + def __new__(cls, constraint, token: str | None = None, prefix: str = ""): + return object.__new__(cls) + + def __init__(self, constraint, token: str | None = None, prefix: str = ""): + if token is None: + token = f"{prefix}_constrained" + super().__init__(token=token, constraint=constraint) + + +class PatternNode: + __slots__ = ("inputs", "op_match") + + def __init__(self, op_match, inputs: tuple): + self.op_match = op_match + self.inputs = inputs + + def __repr__(self) -> str: + return f"PatternNode({self.op_match!r}, {self.inputs!r})" + + +PatternElement: TypeAlias = PatternVar | PatternNode | Variable | Any def convert_strs_to_vars( - x: tuple | str | dict, var_map: dict[str, Var] | None = None -) -> ExpressionTuple | Var: - r"""Convert tuples and strings to `etuple`\s and logic variables, respectively. + x: tuple | str | dict | OpPattern, + var_map: dict[str, PatternVar] | None = None, +): + r"""Convert tuples and strings to pattern trees and logic variables, respectively. Constrained logic variables are specified via `dict`s with the keys `"pattern"`, which specifies the logic variable as a string, and @@ -463,35 +204,202 @@ def convert_strs_to_vars( if var_map is None: var_map = {} - def _convert(y, op_prop=False): + def _convert(y, op_prop: bool = False): if isinstance(y, str): - v = var_map.get(y, var(y)) - var_map[y] = v + v = var_map.get(y) + if v is None: + v = PatternVar(token=y) + var_map[y] = v return v if isinstance(y, LiteralString): return y.value - elif isinstance(y, dict): + if isinstance(y, dict): pattern = y["pattern"] if not isinstance(pattern, str): raise TypeError( "Constraints can only be assigned to logic variables (i.e. strings)" ) constraint = y["constraint"] - v = var_map.get(pattern, ConstrainedVar(constraint, pattern)) - var_map[pattern] = v + v = var_map.get(pattern) + if v is None: + v = PatternVar(token=pattern, constraint=constraint) + var_map[pattern] = v + elif v.constraint is None: + v.constraint = constraint return v - elif isinstance(y, tuple): - return etuple(*(_convert(e, op_prop=op_prop) for e in y)) - elif isinstance(y, OpPattern): - return OpPattern( - y.op_type, - {k: _convert(v, op_prop=True) for k, v in y.parameters}, + if isinstance(y, OpPattern): + new_params = tuple((k, _convert(v, op_prop=True)) for k, v in y.parameters) + return OpPattern(y.op_type, new_params) + if isinstance(y, tuple): + head, *rest = y + head_converted = ( + _convert(head, op_prop=True) if isinstance(head, OpPattern) else head ) - elif (not op_prop) and isinstance(y, Number | np.ndarray): + if not isinstance(head_converted, Op | OpPattern): + raise TypeError( + "Pattern tuples must start with an Op instance or OpPattern; " + f"got {head!r} of type {type(head)}" + ) + children = tuple(_convert(e) for e in rest) + return PatternNode(head_converted, children) + if (not op_prop) and isinstance(y, Number | np.ndarray): # If we are converting an Op property, we don't want to convert numbers to PyTensor constants from pytensor.tensor import as_tensor_variable - return as_tensor_variable(y) + return as_tensor_variable(y) # type: ignore[arg-type] return y return _convert(x) + + +def match_pattern( + pattern: PatternNode, + node, + subs: dict[PatternVar, Any] | None = None, + *, + fgraph_clients=None, +): + if subs is None: + subs = {} + if _match_node(pattern, node, subs, fgraph_clients): + return subs + return None + + +def _match_node( + pattern: PatternNode, + node, + subs: dict[PatternVar, Any], + fgraph_clients, +) -> bool: + op_match = pattern.op_match + node_op = node.op + if isinstance(op_match, OpPattern): + if not isinstance(node_op, op_match.op_type): + return False + if not _match_op_parameters(op_match, node_op, subs): + return False + else: + if node_op != op_match: + return False + + if len(pattern.inputs) != len(node.inputs): + return False + + for sub_pat, sub_var in zip(pattern.inputs, node.inputs): + if not _match_element(sub_pat, sub_var, subs, fgraph_clients): + return False + return True + + +def _match_element( + pattern, + var, + subs: dict[PatternVar, Any], + fgraph_clients, +) -> bool: + if isinstance(pattern, PatternVar): + return _bind_var(pattern, var, subs) + + if isinstance(pattern, PatternNode): + if var.owner is None: + return False + if fgraph_clients is not None and len(fgraph_clients[var]) > 1: + return False + return _match_node(pattern, var.owner, subs, fgraph_clients) + + if isinstance(pattern, Variable): + if isinstance(pattern, Constant) and isinstance(var, Constant): + return np.array_equiv(pattern.data, var.data) + return pattern is var + + if isinstance(var, Constant): + try: + return bool(np.array_equiv(pattern, var.data)) + except Exception: + return False + return bool(pattern == var) + + +_MISSING = object() + + +def _bind_var(pat_var: PatternVar, value, subs: dict[PatternVar, Any]) -> bool: + existing = subs.get(pat_var, _MISSING) + if existing is not _MISSING: + return _values_equal(existing, value) + if pat_var.constraint is not None and not pat_var.constraint(value): + return False + subs[pat_var] = value + return True + + +def _values_equal(a, b) -> bool: + if a is b: + return True + if isinstance(a, Variable) and isinstance(b, Variable): + if isinstance(a, Constant) and isinstance(b, Constant): + return bool(np.array_equiv(a.data, b.data)) + return bool(a == b) + try: + return bool(a == b) + except (ValueError, TypeError): + return False + + +def _match_op_parameters( + op_pat: OpPattern, op: Op, subs: dict[PatternVar, Any] +) -> bool: + for key, param in op_pat.parameters: + op_val = getattr(op, key) + if isinstance(param, PatternVar): + if not _bind_var(param, op_val, subs): + return False + elif isinstance(param, OpPattern): + if not isinstance(op_val, param.op_type): + return False + if not _match_op_parameters(param, op_val, subs): + return False + else: + if op_val != param: + return False + return True + + +def reify_pattern(pattern, subs: Mapping[PatternVar, Any]): + if isinstance(pattern, PatternVar): + try: + return subs[pattern] + except KeyError: + raise ValueError( + f"Output pattern references unbound variable {pattern.token!r}" + ) + + if isinstance(pattern, PatternNode): + op = _reify_op_match(pattern.op_match, subs) + inputs = [reify_pattern(p, subs) for p in pattern.inputs] + return op.make_node(*inputs).default_output() + + if isinstance(pattern, OpPattern): + return _reify_op_match(pattern, subs) + + return pattern + + +def _reify_op_match(op_match, subs: Mapping[PatternVar, Any]): + if isinstance(op_match, OpPattern): + op_type = op_match.op_type + if not isinstance(op_type, type): + raise ValueError( + f"Cannot instantiate OpPattern with non-type op_type {op_type!r}" + ) + params = {} + for key, val in op_match.parameters: + if isinstance(val, PatternVar): + params[key] = subs[val] + elif isinstance(val, OpPattern): + params[key] = _reify_op_match(val, subs) + else: + params[key] = val + return op_type(**params) + return op_match diff --git a/tests/benchmarks/test_pattern_match.py b/tests/benchmarks/test_pattern_match.py new file mode 100644 index 0000000000..223d41821c --- /dev/null +++ b/tests/benchmarks/test_pattern_match.py @@ -0,0 +1,91 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.rewriting.basic import PatternNodeRewriter +from pytensor.graph.rewriting.unify import ( + convert_strs_to_vars, + match_pattern, +) +from pytensor.tensor import exp, log, mul, sub +from pytensor.tensor.math import erf + + +unification = pytest.importorskip("unification") +etuples = pytest.importorskip("etuples") +import pytensor.graph.rewriting.kanren # noqa: E402, F401 + + +def _old_pattern_from_tuple(pat_tuple): + from etuples import etuple + from unification import var + + from pytensor.graph.rewriting.unify import ConstrainedVar + + if isinstance(pat_tuple, str): + return var(pat_tuple) + if isinstance(pat_tuple, dict): + return ConstrainedVar(pat_tuple["constraint"], pat_tuple["pattern"]) + if isinstance(pat_tuple, tuple): + return etuple(*[_old_pattern_from_tuple(p) for p in pat_tuple]) + if isinstance(pat_tuple, int | float | np.ndarray): + return pt.as_tensor_variable(pat_tuple) + return pat_tuple + + +def _make_case(case_id): + x = pt.vector("x") + if case_id == "shallow_match": + return x, (log, (exp, "x")), log(exp(x)) + if case_id == "deep_match": + return x, (log, (exp, (log, (exp, "x")))), log(exp(log(exp(x)))) + if case_id == "repeated_var": + return x, (mul, "x", "x"), mul(x, x) + if case_id == "with_constant": + return x, (sub, 1.0, (erf, "x")), sub(pt.as_tensor_variable(1.0), erf(x)) + if case_id == "no_match": + return x, (log, (exp, "x")), log(x) + raise ValueError(case_id) + + +@pytest.mark.parametrize( + "case_id", + ["shallow_match", "deep_match", "repeated_var", "with_constant", "no_match"], +) +def test_match_pattern_benchmark(benchmark, case_id): + _x, pat_tuple, out = _make_case(case_id) + new_pat = convert_strs_to_vars(pat_tuple) + + def run(): + return match_pattern(new_pat, out.owner) + + benchmark(run) + + +@pytest.mark.parametrize( + "case_id", + ["shallow_match", "deep_match", "repeated_var", "with_constant", "no_match"], +) +def test_unification_unify_benchmark(benchmark, case_id): + _x, pat_tuple, out = _make_case(case_id) + old_pat = _old_pattern_from_tuple(pat_tuple) + from unification import unify + + def run(): + return unify(old_pat, out, {}) + + benchmark(run) + + +@pytest.mark.parametrize("case_id", ["shallow_match", "deep_match", "no_match"]) +def test_pattern_rewriter_transform_benchmark(benchmark, case_id): + x, pat_tuple, out = _make_case(case_id) + fg = FunctionGraph([x], [out], clone=False) + rw = PatternNodeRewriter(pat_tuple, "x", allow_multiple_clients=True) + node = out.owner + + def run(): + rw.transform(fg, node) + + benchmark(run) diff --git a/tests/graph/rewriting/test_kanren.py b/tests/graph/rewriting/test_kanren.py index 7cb66a4ba0..cbe7e034b5 100644 --- a/tests/graph/rewriting/test_kanren.py +++ b/tests/graph/rewriting/test_kanren.py @@ -2,19 +2,28 @@ import numpy as np import pytest + + +pytest.importorskip("kanren") +pytest.importorskip("unification") +pytest.importorskip("etuples") +pytest.importorskip("cons") + +from cons import car, cdr from etuples import etuple from kanren import eq, fact, run from kanren.assoccomm import associative, commutative, eq_assoccomm from kanren.core import lall -from unification import var, vars +from unification import unify, var, vars +from unification.variable import isvar import pytensor.tensor as pt from pytensor.graph.basic import Apply from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter -from pytensor.graph.rewriting.kanren import KanrenRelationSub -from pytensor.graph.rewriting.unify import eval_if_etuple +from pytensor.graph.rewriting.kanren import KanrenRelationSub, eval_if_etuple +from pytensor.graph.rewriting.unify import ConstrainedVar, PatternVar from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.tensor.math import Dot, _dot from tests.graph.utils import MyType, MyVariable @@ -35,6 +44,22 @@ def clear_assoccomm(): associative.facts = old_associative_facts +def test_pytensor_unification_dispatchers(): + x_pt = pt.vector("x") + y_pt = pt.vector("y") + z_pt = x_pt + y_pt + + assert car(z_pt) == z_pt.owner.op + assert cdr(z_pt) == [x_pt, y_pt] + + op_lv = var("op") + s = unify(etuple(op_lv, x_pt, y_pt), z_pt, {}) + assert s[op_lv] == z_pt.owner.op + + assert isvar(PatternVar("v")) + assert isvar(ConstrainedVar(lambda v: True, "c")) + + def test_kanren_basic(): A_pt = pt.matrix("A") B_pt = pt.matrix("B") diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index 5ce8d04105..8467c25270 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -1,23 +1,20 @@ +"""Tests for the specialized pattern matcher in :mod:`pytensor.graph.rewriting.unify`.""" + import numpy as np -import pytest -from cons import car, cdr -from cons.core import ConsError -from etuples import apply, etuple, etuplize -from etuples.core import ExpressionTuple -from unification import reify, unify, var -from unification.variable import Var - -import pytensor.scalar as ps + import pytensor.tensor as pt -from pytensor.graph.basic import Apply, Constant, equal_computations +from pytensor.graph.basic import Apply, Constant from pytensor.graph.op import Op from pytensor.graph.rewriting.unify import ( ConstrainedVar, + LiteralString, OpPattern, + PatternNode, + PatternVar, convert_strs_to_vars, + match_pattern, + reify_pattern, ) -from pytensor.tensor.type import TensorType -from tests.graph.utils import MyType class CustomOp(Op): @@ -33,344 +30,200 @@ def perform(self, node, inputs, outputs): raise NotImplementedError() -class CustomOpNoPropsNoEq(Op): +class CustomOpNoProps(Op): def __init__(self, a): self.a = a - def make_node(self, *inputs): - return Apply(self, list(inputs), [pt.vector()]) - - def perform(self, node, inputs, outputs): - raise NotImplementedError() - - -class CustomOpNoProps(CustomOpNoPropsNoEq): def __eq__(self, other): return type(self) is type(other) and self.a == other.a def __hash__(self): return hash((type(self), self.a)) + def make_node(self, *inputs): + return Apply(self, list(inputs), [pt.vector()]) -def test_cons(): - x_pt = pt.vector("x") - y_pt = pt.vector("y") - - z_pt = x_pt + y_pt - - res = car(z_pt) - assert res == z_pt.owner.op - - res = cdr(z_pt) - assert res == [x_pt, y_pt] - - with pytest.raises(ConsError): - car(x_pt) - - with pytest.raises(ConsError): - cdr(x_pt) - - op1 = CustomOp(1) - - assert car(op1) == CustomOp - assert cdr(op1) == (1,) - - tt1 = TensorType("float32", shape=(1, None)) - - assert car(tt1) == TensorType - assert cdr(tt1) == ("float32", (1, None)) - - op1_np = CustomOpNoProps(1) - - with pytest.raises(ConsError): - car(op1_np) - - with pytest.raises(ConsError): - cdr(op1_np) - - atype_pt = ps.float64 - car_res = car(atype_pt) - cdr_res = cdr(atype_pt) - assert car_res is type(atype_pt) - assert cdr_res == [atype_pt.dtype] - - atype_pt = pt.lvector - car_res = car(atype_pt) - cdr_res = cdr(atype_pt) - assert car_res is type(atype_pt) - assert cdr_res == [atype_pt.dtype, atype_pt.shape] - - -def test_etuples(): - x_pt = pt.vector("x") - y_pt = pt.vector("y") - - z_pt = etuple(x_pt, y_pt) - - res = apply(pt.add, z_pt) - - assert res.owner.op == pt.add - assert res.owner.inputs == [x_pt, y_pt] - - w_pt = etuple(pt.add, x_pt, y_pt) - - res = w_pt.evaled_obj - assert res.owner.op == pt.add - assert res.owner.inputs == [x_pt, y_pt] - - # This `Op` doesn't expand into an `etuple` (i.e. it's "atomic") - op1_np = CustomOpNoProps(1) - - res = apply(op1_np, z_pt) - assert res.owner.op == op1_np - - q_pt = op1_np(x_pt, y_pt) - res = etuplize(q_pt) - assert res[0] == op1_np - - with pytest.raises(TypeError): - etuplize(op1_np) - - class MyMultiOutOp(Op): - def make_node(self, *inputs): - outputs = [MyType()(), MyType()()] - return Apply(self, list(inputs), outputs) - - def perform(self, node, inputs, outputs): - outputs[0] = np.array(inputs[0]) - outputs[1] = np.array(inputs[0]) - - x_pt = pt.vector("x") - op1_np = MyMultiOutOp() - res = apply(op1_np, etuple(x_pt)) - assert len(res) == 2 - assert res[0].owner.op == op1_np - assert res[1].owner.op == op1_np - - -def test_unify_Variable(): - x_pt = pt.vector("x") - y_pt = pt.vector("y") - - z_pt = x_pt + y_pt - - # `Variable`, `Variable` - s = unify(z_pt, z_pt) - assert s == {} - - # These `Variable`s have no owners - v1 = MyType()() - v2 = MyType()() - - assert v1 != v2 - - s = unify(v1, v2) - assert s is False - - op_lv = var() - z_ppt_et = etuple(op_lv, x_pt, y_pt) - - # `Variable`, `ExpressionTuple` - s = unify(z_pt, z_ppt_et, {}) - - assert op_lv in s - assert s[op_lv] == z_pt.owner.op - - res = reify(z_ppt_et, s) - - assert isinstance(res, ExpressionTuple) - assert equal_computations([res.evaled_obj], [z_pt]) - - z_et = etuple(pt.add, x_pt, y_pt) - - # `ExpressionTuple`, `ExpressionTuple` - s = unify(z_et, z_ppt_et, {}) - - assert op_lv in s - assert s[op_lv] == z_et[0] - - res = reify(z_ppt_et, s) - - assert isinstance(res, ExpressionTuple) - assert equal_computations([res.evaled_obj], [z_et.evaled_obj]) + def perform(self, node, inputs, outputs): + raise NotImplementedError() - # `ExpressionTuple`, `Variable` - s = unify(z_et, x_pt, {}) - assert s is False - # This `Op` doesn't expand into an `ExpressionTuple` - op1_np = CustomOpNoProps(1) +def test_convert_strs_to_vars_strings(): + res = convert_strs_to_vars("a") + assert isinstance(res, PatternVar) + assert res.token == "a" - q_pt = op1_np(x_pt, y_pt) - a_lv = var() - b_lv = var() - # `Variable`, `ExpressionTuple` - s = unify(q_pt, etuple(op1_np, a_lv, b_lv)) +def test_convert_strs_to_vars_shared_var_map(): + var_map = {} + res = convert_strs_to_vars((pt.add, "x", "x"), var_map=var_map) + assert isinstance(res, PatternNode) + assert res.inputs[0] is res.inputs[1] + assert res.inputs[0].token == "x" - assert s[a_lv] == x_pt - assert s[b_lv] == y_pt +def test_convert_strs_to_vars_constraint(): + def is_int(v): + return isinstance(v, int) -def test_unify_Op(): - # These `Op`s expand into `ExpressionTuple`s - op1 = CustomOp(1) - op2 = CustomOp(1) + var_map = {} + res = convert_strs_to_vars( + (pt.add, {"pattern": "x", "constraint": is_int}, "x"), + var_map=var_map, + ) + # Both references to "x" are the same PatternVar + assert res.inputs[0] is res.inputs[1] + assert res.inputs[0].constraint is is_int - # `Op`, `Op` - s = unify(op1, op2) - assert s == {} - # `ExpressionTuple`, `Op` - s = unify(etuplize(op1), op2) - assert s == {} +def test_convert_strs_to_vars_literal_string(): + res = convert_strs_to_vars(LiteralString("hello")) + assert res == "hello" - # These `Op`s don't expand into `ExpressionTuple`s - op1_np = CustomOpNoProps(1) - op2_np = CustomOpNoProps(1) - s = unify(op1_np, op2_np) - assert s == {} +def test_convert_strs_to_vars_numeric_constant(): + val = np.r_[1, 2] + res = convert_strs_to_vars((pt.add, val, "x")) + assert isinstance(res.inputs[0], Constant) + assert np.array_equal(res.inputs[0].data, val) - # Same, but this one also doesn't implement `__eq__` - op1_np_neq = CustomOpNoPropsNoEq(1) - s = unify(op1_np_neq, etuplize(op1)) - assert s is False +def test_match_simple_chain(): + x = pt.vector("x") + out = pt.log(pt.exp(x)) + pat = convert_strs_to_vars((pt.log, (pt.exp, "x"))) + subs = match_pattern(pat, out.owner) + assert subs is not None + [(_, bound)] = subs.items() + assert bound is x -def test_unify_Constant(): - # Make sure `Constant` unification works - c1_pt = pt.as_tensor(np.r_[1, 2]) - c2_pt = pt.as_tensor(np.r_[1, 2]) - # `Constant`, `Constant` - s = unify(c1_pt, c2_pt) - assert s == {} +def test_match_repeated_var_same_value(): + x = pt.vector("x") + out = pt.mul(x, x) + pat = convert_strs_to_vars((pt.mul, "z", "z")) + subs = match_pattern(pat, out.owner) + assert subs is not None + [(pat_var, bound)] = subs.items() + assert pat_var.token == "z" + assert bound is x -def test_unify_Type(): - t1 = TensorType(np.float64, shape=(1, None)) - t2 = TensorType(np.float64, shape=(1, None)) +def test_match_repeated_var_different_value_fails(): + x = pt.vector("x") + y = pt.vector("y") + out = pt.mul(x, y) + pat = convert_strs_to_vars((pt.mul, "z", "z")) + assert match_pattern(pat, out.owner) is None - # `Type`, `Type` - s = unify(t1, t2) - assert s == {} - # `Type`, `ExpressionTuple` - s = unify(t1, etuple(TensorType, "float64", (1, None))) - assert s == {} +def test_match_with_constraint(): + def is_scalar(v): + return all(s == 1 for s in v.type.shape) - from pytensor.scalar.basic import ScalarType + x = pt.vector("x") + s = pt.scalar("s") + out = pt.sub(s, x) + pat = convert_strs_to_vars((pt.sub, {"pattern": "a", "constraint": is_scalar}, "x")) + assert match_pattern(pat, out.owner) is not None - st1 = ScalarType(np.float64) - st2 = ScalarType(np.float64) + y = pt.vector("y") + out2 = pt.sub(y, x) + assert match_pattern(pat, out2.owner) is None - s = unify(st1, st2) - assert s == {} +def test_match_op_pattern_isinstance(): + x = pt.matrix("x") + out = pt.sum(x) -def test_ConstrainedVar(): - cvar = ConstrainedVar(lambda x: isinstance(x, str)) + from pytensor.tensor.elemwise import CAReduce - assert repr(cvar).startswith("ConstrainedVar(") - assert repr(cvar).endswith(f", {cvar.token})") + pat = convert_strs_to_vars((OpPattern(CAReduce, axis=None), "x")) + subs = match_pattern(pat, out.owner) + assert subs is not None + [(_, bound)] = subs.items() + assert bound is x - s = unify(cvar, 1) - assert s is False - s = unify(1, cvar) - assert s is False +def test_match_op_pattern_with_var_param(): + from pytensor.tensor.elemwise import CAReduce - s = unify(cvar, "hi") - assert s[cvar] == "hi" + x = pt.matrix("x") + out = pt.sum(x) + pat = convert_strs_to_vars((OpPattern(CAReduce, scalar_op="sop", axis=None), "x")) + subs = match_pattern(pat, out.owner) + assert subs is not None + by_name = {p.token: v for p, v in subs.items()} + assert by_name["x"] is x + assert by_name["sop"] is out.owner.op.scalar_op - s = unify("hi", cvar) - assert s[cvar] == "hi" - x_lv = var() - s = unify(cvar, x_lv) - assert s == {cvar: x_lv} +def test_match_op_pattern_param_mismatch(): + from pytensor.tensor.elemwise import CAReduce - s = unify(cvar, x_lv, {x_lv: "hi"}) - assert s[cvar] == "hi" + x = pt.matrix("x") + out = pt.sum(x, axis=0) # axis=(0,), not None + pat = convert_strs_to_vars((OpPattern(CAReduce, axis=None), "x")) + assert match_pattern(pat, out.owner) is None - s = unify(x_lv, cvar, {x_lv: "hi"}) - assert s[cvar] == "hi" - s_orig = {cvar: "hi", x_lv: "hi"} - s = unify(x_lv, cvar, s_orig) - assert s == s_orig +def test_match_nested_op_pattern(): + from pytensor.tensor.blockwise import Blockwise - s_orig = {cvar: "hi", x_lv: "bye"} - s = unify(x_lv, cvar, s_orig) - assert s is False + A = pt.tensor3("A") + b = pt.tensor3("b") + out = pt.linalg.solve(A, b) - x_pt = pt.vector("x") - y_pt = pt.vector("y") - op1_np = CustomOpNoProps(1) - r_pt = etuple(op1_np, x_pt, y_pt) + from pytensor.tensor.slinalg import Solve - def constraint(x): - return isinstance(x, tuple) + pat = convert_strs_to_vars( + ( + OpPattern( + Blockwise, + core_op=OpPattern(Solve, assume_a=LiteralString("gen")), + ), + "A", + "b", + ) + ) + res = match_pattern(pat, out.owner) + assert res is not None - a_lv = ConstrainedVar(constraint) - res = reify(etuple(op1_np, a_lv), {a_lv: r_pt}) - assert res[1] == r_pt +def test_match_literal_int_constant(): + x = pt.scalar("x") + out = pt.add(x, pt.as_tensor_variable(2.0)) + pat = convert_strs_to_vars((pt.add, "x", 2.0)) + assert match_pattern(pat, out.owner) is not None -def test_convert_strs_to_vars(): - res = convert_strs_to_vars("a") - assert isinstance(res, Var) - assert res.token == "a" +def test_reify_simple(): + x = pt.vector("x") + pv = PatternVar("x") + pat = PatternNode(pt.log, (pv,)) + out = reify_pattern(pat, {pv: x}) + assert out.owner is not None + assert out.owner.op == pt.log - x_pt = pt.vector() - y_pt = pt.vector() - res = convert_strs_to_vars((("a", x_pt), y_pt)) - assert res == etuple(etuple(var("a"), x_pt), y_pt) - def constraint(x): - return isinstance(x, str) +def test_reify_op_pattern_to_op(): + import pytensor.scalar as ps + from pytensor.tensor.elemwise import CAReduce - res = convert_strs_to_vars( - (({"pattern": "a", "constraint": constraint}, x_pt), y_pt) + pv_x = PatternVar("x") + pv_scalar = PatternVar("sop") + pat = PatternNode( + OpPattern(CAReduce, parameters=(("axis", None), ("scalar_op", pv_scalar))), + (pv_x,), ) - assert res == etuple(etuple(ConstrainedVar(constraint, "a"), x_pt), y_pt) + x = pt.matrix("x") + out = reify_pattern(pat, {pv_x: x, pv_scalar: ps.add}) + assert out.owner is not None + assert out.owner.op.axis is None - # Make sure constrained logic variables are the same across distinct uses - # of their string names - res = convert_strs_to_vars(({"pattern": "a", "constraint": constraint}, "a")) - assert res[0] is res[1] - var_map = {"a": var("a")} - res = convert_strs_to_vars(("a",), var_map=var_map) - assert res[0] is var_map["a"] - - # Make sure numbers and NumPy arrays are converted - val = np.r_[1, 2] - res = convert_strs_to_vars((val,)) - assert isinstance(res[0], Constant) - assert np.array_equal(res[0].data, val) - - -def test_unify_OpPattern(): - x_pt = MyType()("x_pt") - y_pt = MyType()("y_pt") - out1 = CustomOp(a=1)(x_pt, y_pt) - out2 = CustomOp(a=2)(x_pt, y_pt) - - x = var("x") - y = var("y") - pattern = etuple(OpPattern(CustomOp), x, y) - assert unify(pattern, out1) == {x: x_pt, y: y_pt} - assert unify(pattern, out2) == {x: x_pt, y: y_pt} - - pattern = etuple(OpPattern(CustomOp, a=1), x, y) - assert unify(pattern, out1) == {x: x_pt, y: y_pt} - assert unify(pattern, out2) is False - - a = var("a") - pattern = etuple(OpPattern(CustomOp, a=a), x, y) - assert unify(pattern, out1) == {x: x_pt, y: y_pt, a: 1} - assert unify(pattern, out2) == {x: x_pt, y: y_pt, a: 2} +def test_constrained_var_back_compat(): + cvar = ConstrainedVar(lambda v: isinstance(v, int), "tok") + assert isinstance(cvar, PatternVar) + assert cvar.constraint(1) + assert not cvar.constraint("x") From 2bae088bfb2ec14ca0c8abbaeb22028a4befbe27 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 11 May 2026 13:57:24 +0200 Subject: [PATCH 2/2] Add variadic Asterisk and auto-detected commutative match Commutativity is detected from op.commutative (or scalar_op.commutative for Elemwise). The matcher tries the original input order in place first; only on failure does it copy subs and backtrack through permutations. Asterisk preserves input order in both paths. --- pytensor/graph/rewriting/unify.py | 180 ++++++++++++++++++++++++++-- tests/graph/rewriting/test_unify.py | 88 ++++++++++++++ 2 files changed, 257 insertions(+), 11 deletions(-) diff --git a/pytensor/graph/rewriting/unify.py b/pytensor/graph/rewriting/unify.py index c6f8d96da6..73642cdee5 100644 --- a/pytensor/graph/rewriting/unify.py +++ b/pytensor/graph/rewriting/unify.py @@ -162,6 +162,23 @@ def __repr__(self) -> str: return f"PatternVar({self.token!r})" +class Asterisk: + """Pattern element that captures the remaining inputs of a node as a tuple. + + Must appear in last position of a pattern node's children. The captured + tuple is reified (during ``reify_pattern``) by splatting the bound tuple + back into the output op's input list. + """ + + __slots__ = ("token",) + + def __init__(self, token: str): + self.token = token + + def __repr__(self) -> str: + return f"Asterisk({self.token!r})" + + Var: TypeAlias = PatternVar @@ -181,6 +198,12 @@ class PatternNode: __slots__ = ("inputs", "op_match") def __init__(self, op_match, inputs: tuple): + for i, c in enumerate(inputs): + if isinstance(c, Asterisk) and i != len(inputs) - 1: + raise TypeError( + "Asterisk must appear in last position of a pattern's children; " + f"got it at index {i} of {inputs}" + ) self.op_match = op_match self.inputs = inputs @@ -193,7 +216,7 @@ def __repr__(self) -> str: def convert_strs_to_vars( x: tuple | str | dict | OpPattern, - var_map: dict[str, PatternVar] | None = None, + var_map: dict[str, PatternVar | Asterisk] | None = None, ): r"""Convert tuples and strings to pattern trees and logic variables, respectively. @@ -210,9 +233,25 @@ def _convert(y, op_prop: bool = False): if v is None: v = PatternVar(token=y) var_map[y] = v + elif isinstance(v, Asterisk): + raise TypeError( + f"Token {y!r} is already bound to an Asterisk; " + "cannot reuse as PatternVar" + ) return v if isinstance(y, LiteralString): return y.value + if isinstance(y, Asterisk): + existing = var_map.get(y.token) + if existing is None: + var_map[y.token] = y + return y + if not isinstance(existing, Asterisk): + raise TypeError( + f"Token {y.token!r} is already bound to a PatternVar; " + "cannot reuse as Asterisk" + ) + return existing if isinstance(y, dict): pattern = y["pattern"] if not isinstance(pattern, str): @@ -224,6 +263,11 @@ def _convert(y, op_prop: bool = False): if v is None: v = PatternVar(token=pattern, constraint=constraint) var_map[pattern] = v + elif isinstance(v, Asterisk): + raise TypeError( + f"Token {pattern!r} is already bound to an Asterisk; " + "cannot reuse as PatternVar" + ) elif v.constraint is None: v.constraint = constraint return v @@ -252,10 +296,17 @@ def _convert(y, op_prop: bool = False): return _convert(x) +def _is_commutative_op(op) -> bool: + if getattr(op, "commutative", False): + return True + scalar_op = getattr(op, "scalar_op", None) + return bool(getattr(scalar_op, "commutative", False)) + + def match_pattern( pattern: PatternNode, node, - subs: dict[PatternVar, Any] | None = None, + subs: dict[PatternVar | Asterisk, Any] | None = None, *, fgraph_clients=None, ): @@ -269,7 +320,7 @@ def match_pattern( def _match_node( pattern: PatternNode, node, - subs: dict[PatternVar, Any], + subs: dict[PatternVar | Asterisk, Any], fgraph_clients, ) -> bool: op_match = pattern.op_match @@ -283,19 +334,113 @@ def _match_node( if node_op != op_match: return False - if len(pattern.inputs) != len(node.inputs): + return _match_children(pattern.inputs, node.inputs, subs, fgraph_clients, node_op) + + +def _match_children( + pattern_inputs: tuple, + node_inputs: list, + subs: dict[PatternVar | Asterisk, Any], + fgraph_clients, + node_op, +) -> bool: + has_asterisk = bool(pattern_inputs) and isinstance(pattern_inputs[-1], Asterisk) + n_fixed = len(pattern_inputs) - 1 if has_asterisk else len(pattern_inputs) + + if has_asterisk: + if len(node_inputs) < n_fixed: + return False + elif len(pattern_inputs) != len(node_inputs): return False - for sub_pat, sub_var in zip(pattern.inputs, node.inputs): + fixed_pats = pattern_inputs[:n_fixed] + asterisk_pat: Asterisk | None = pattern_inputs[-1] if has_asterisk else None + + # n_fixed == 1 without asterisk has only one possible assignment, so + # commutativity doesn't add any choice — skip the backup/backtrack. + needs_commutative = ( + _is_commutative_op(node_op) and n_fixed >= 1 and (n_fixed > 1 or has_asterisk) + ) + saved_subs = dict(subs) if needs_commutative else None + ok = True + for sub_pat, sub_var in zip(fixed_pats, node_inputs[:n_fixed]): if not _match_element(sub_pat, sub_var, subs, fgraph_clients): + ok = False + break + if ok: + if asterisk_pat is not None: + return _bind_asterisk(asterisk_pat, tuple(node_inputs[n_fixed:]), subs) + return True + + if saved_subs is None: + return False + + subs.clear() + subs.update(saved_subs) + return _commutative_backtrack( + fixed_pats, + node_inputs, + [False] * len(node_inputs), + 0, + subs, + fgraph_clients, + has_asterisk, + asterisk_pat, + ) + + +def _commutative_backtrack( + fixed_pats, + node_inputs, + used, + idx, + subs, + fgraph_clients, + has_asterisk, + asterisk_pat, +) -> bool: + if idx == len(fixed_pats): + if has_asterisk: + remainder = tuple(v for v, u in zip(node_inputs, used) if not u) + return _bind_asterisk(asterisk_pat, remainder, subs) + return True + pat = fixed_pats[idx] + for j, var in enumerate(node_inputs): + if used[j]: + continue + saved_subs = dict(subs) + used[j] = True + if _match_element(pat, var, subs, fgraph_clients) and _commutative_backtrack( + fixed_pats, + node_inputs, + used, + idx + 1, + subs, + fgraph_clients, + has_asterisk, + asterisk_pat, + ): + return True + used[j] = False + subs.clear() + subs.update(saved_subs) + return False + + +def _bind_asterisk(asterisk_pat: Asterisk, value: tuple, subs: dict) -> bool: + existing = subs.get(asterisk_pat, _MISSING) + if existing is not _MISSING: + if len(existing) != len(value): return False + return all(_values_equal(a, b) for a, b in zip(existing, value)) + subs[asterisk_pat] = value return True def _match_element( pattern, var, - subs: dict[PatternVar, Any], + subs: dict[PatternVar | Asterisk, Any], fgraph_clients, ) -> bool: if isinstance(pattern, PatternVar): @@ -324,7 +469,9 @@ def _match_element( _MISSING = object() -def _bind_var(pat_var: PatternVar, value, subs: dict[PatternVar, Any]) -> bool: +def _bind_var( + pat_var: PatternVar, value, subs: dict[PatternVar | Asterisk, Any] +) -> bool: existing = subs.get(pat_var, _MISSING) if existing is not _MISSING: return _values_equal(existing, value) @@ -348,7 +495,7 @@ def _values_equal(a, b) -> bool: def _match_op_parameters( - op_pat: OpPattern, op: Op, subs: dict[PatternVar, Any] + op_pat: OpPattern, op: Op, subs: dict[PatternVar | Asterisk, Any] ) -> bool: for key, param in op_pat.parameters: op_val = getattr(op, key) @@ -366,7 +513,7 @@ def _match_op_parameters( return True -def reify_pattern(pattern, subs: Mapping[PatternVar, Any]): +def reify_pattern(pattern, subs: Mapping[PatternVar | Asterisk, Any]): if isinstance(pattern, PatternVar): try: return subs[pattern] @@ -377,7 +524,18 @@ def reify_pattern(pattern, subs: Mapping[PatternVar, Any]): if isinstance(pattern, PatternNode): op = _reify_op_match(pattern.op_match, subs) - inputs = [reify_pattern(p, subs) for p in pattern.inputs] + inputs = [] + for p in pattern.inputs: + if isinstance(p, Asterisk): + try: + captured = subs[p] + except KeyError: + raise ValueError( + f"Output pattern references unbound asterisk {p.token!r}" + ) + inputs.extend(captured) + else: + inputs.append(reify_pattern(p, subs)) return op.make_node(*inputs).default_output() if isinstance(pattern, OpPattern): @@ -386,7 +544,7 @@ def reify_pattern(pattern, subs: Mapping[PatternVar, Any]): return pattern -def _reify_op_match(op_match, subs: Mapping[PatternVar, Any]): +def _reify_op_match(op_match, subs: Mapping[PatternVar | Asterisk, Any]): if isinstance(op_match, OpPattern): op_type = op_match.op_type if not isinstance(op_type, type): diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index 8467c25270..0a446d7518 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -6,6 +6,7 @@ from pytensor.graph.basic import Apply, Constant from pytensor.graph.op import Op from pytensor.graph.rewriting.unify import ( + Asterisk, ConstrainedVar, LiteralString, OpPattern, @@ -197,6 +198,93 @@ def test_match_literal_int_constant(): assert match_pattern(pat, out.owner) is not None +def test_match_commutative_add_swapped(): + x = pt.scalar("x") + out_left = pt.add(pt.as_tensor_variable(1.0), x) + out_right = pt.add(x, pt.as_tensor_variable(1.0)) + pat = convert_strs_to_vars((pt.add, 1.0, "x")) + + s1 = match_pattern(pat, out_left.owner) + s2 = match_pattern(pat, out_right.owner) + assert s1 is not None + assert s2 is not None + [(_, b1)] = s1.items() + [(_, b2)] = s2.items() + assert b1 is x + assert b2 is x + + +def test_match_commutative_does_not_apply_to_sub(): + x = pt.scalar("x") + out = pt.sub(pt.as_tensor_variable(1.0), x) + pat_swapped = convert_strs_to_vars((pt.sub, "x", 1.0)) + assert match_pattern(pat_swapped, out.owner) is None + + +def test_match_variadic_asterisk(): + a = pt.vector("a") + b = pt.vector("b") + c = pt.vector("c") + out = pt.add(a, b, c) + pat = convert_strs_to_vars((pt.add, "first", Asterisk("rest"))) + + subs = match_pattern(pat, out.owner) + assert subs is not None + by_name = { + (p.token if isinstance(p, (PatternVar, Asterisk)) else p): v + for p, v in subs.items() + } + assert by_name["first"] is a + assert by_name["rest"] == (b, c) + + +def test_match_variadic_too_few_inputs(): + a = pt.vector("a") + out = pt.exp(a) + pat = convert_strs_to_vars((pt.exp, "x", "y", Asterisk("rest"))) + assert match_pattern(pat, out.owner) is None + + +def test_match_variadic_preserves_input_order_under_commutative_backtrack(): + a = pt.vector("a", dtype="float32") + b = pt.vector("b", dtype="float64") + c = pt.vector("c", dtype="float32") + + pat = convert_strs_to_vars( + ( + pt.add, + {"pattern": "first", "constraint": lambda v: v.dtype == "float64"}, + Asterisk("rest"), + ) + ) + + s1 = match_pattern(pat, pt.add(a, b, c).owner) + n1 = {k.token: v for k, v in s1.items()} + assert n1["first"] is b + assert n1["rest"] == (a, c) + + s2 = match_pattern(pat, pt.add(c, a, b).owner) + n2 = {k.token: v for k, v in s2.items()} + assert n2["first"] is b + assert n2["rest"] == (c, a) + + +def test_reify_variadic(): + a = pt.vector("a") + b = pt.vector("b") + c = pt.vector("c") + rest = Asterisk("rest") + var_map = {} + in_pat = convert_strs_to_vars((pt.add, "first", rest), var_map=var_map) + subs = match_pattern(in_pat, pt.add(a, b, c).owner) + assert subs is not None + + out_pat = convert_strs_to_vars((pt.mul, "first", rest), var_map=var_map) + result = reify_pattern(out_pat, subs) + assert result.owner.op == pt.mul + assert result.owner.inputs == [a, b, c] + + def test_reify_simple(): x = pt.vector("x") pv = PatternVar("x")