From 18996f0a28828e06588583d1d303516d536ea50f Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 22 May 2026 21:31:14 +0900 Subject: [PATCH 1/3] Inline Einsum OFG in xtensor lower_dot to avoid ShapeFeature blow-up `pt.einsum` wraps its output in an `Einsum` `OpFromGraph`. The OFG is only inlined by `inline_optimized_einsum` during `specialize`, but while it is alive `ShapeFeature.on_import` calls `OpFromGraph.infer_shape` on every node import during canonicalize, and `infer_shape` re-walks the OFG's inner graph each time. When several xtensor dots are composed (e.g. multi-layer attention), this becomes super-linear and dominates compile time. Inlining the OFG immediately after `einsum` removes it before any shape-using pass ever sees it. The 2-operand case `lower_dot` produces has no path optimisation to defer, so inlining is safe and behaviour- preserving. Effect on the toy multi-head attention reproducer (block_size=32, n_embd=64, n_head=4, with grad): n_layer plain xtensor (before) xtensor (after) 1 0.94s 3.04s 1.12s 2 2.03s 72.50s 4.07s Adds a structural test that locks in the post-`lower_xtensor` invariant "no OpFromGraph nodes left in the lowered graph". All existing xtensor / tensordot / einsum tests still pass. Co-authored-by: Cursor --- pytensor/xtensor/rewriting/math.py | 10 ++++++++++ tests/xtensor/test_math.py | 26 ++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/pytensor/xtensor/rewriting/math.py b/pytensor/xtensor/rewriting/math.py index c767ec490e..bdbebb26e8 100644 --- a/pytensor/xtensor/rewriting/math.py +++ b/pytensor/xtensor/rewriting/math.py @@ -2,6 +2,8 @@ from pytensor.graph import node_rewriter from pytensor.tensor import einsum +from pytensor.tensor.einsum import Einsum +from pytensor.tensor.rewriting.ofg import inline_ofg_node from pytensor.tensor.shape import specify_shape from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.math import Dot @@ -41,6 +43,14 @@ def lower_dot(fgraph, node): # Perform the einsum operation out_tensor = einsum(einsum_str, x_tensor, y_tensor) + # Inline the Einsum OFG eagerly. `inline_optimized_einsum` only fires + # during `specialize`, but while the OFG is alive `ShapeFeature` calls + # `OpFromGraph.infer_shape` on every import, re-walking the inner graph + # each time. With many composed xtensor dots that dominates compile + # time. The 2-operand case has no path optimisation to defer. + if out_tensor.owner is not None and isinstance(out_tensor.owner.op, Einsum): + [out_tensor] = inline_ofg_node(out_tensor.owner) + # Reshape to match the output shape out_tensor = specify_shape(out_tensor, out.type.shape) diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index bb187a0fac..c095c8ca78 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -316,6 +316,32 @@ def test_dot(): xr_assert_allclose(z_test, expected) +def test_dot_lowering_inlines_einsum_ofg(): + """``lower_dot`` must inline the ``Einsum`` OFG that ``pt.einsum`` wraps. + + Leaving the OFG in place lets ``ShapeFeature.on_import`` call + ``OpFromGraph.infer_shape`` on every node import during canonicalize, + which re-walks the inner graph and dominates compile time once several + xtensor dots are composed (e.g. multi-layer attention). + """ + from pytensor.compile.builders import OpFromGraph + from pytensor.graph.rewriting.utils import rewrite_graph + from pytensor.graph.traversal import io_toposort + + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = xtensor("y", dims=("b", "c"), shape=(3, 4)) + z = x.dot(y) + + lowered = rewrite_graph(z.values, include=("lower_xtensor",)) + ofg_nodes = [ + n for n in io_toposort([], [lowered]) if isinstance(n.op, OpFromGraph) + ] + assert ofg_nodes == [], ( + "lower_dot should inline the Einsum OpFromGraph eagerly; got: " + f"{[type(n.op).__name__ for n in ofg_nodes]}" + ) + + def test_dot_errors(): # No matching dimensions x = xtensor("x", dims=("a", "b"), shape=(2, 3)) From 11236f26fcbf59604b6aca16d17584d3f1d64904 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 22 May 2026 22:20:29 +0900 Subject: [PATCH 2/3] style: ruff format test_dot_lowering_inlines_einsum_ofg Co-authored-by: Cursor --- tests/xtensor/test_math.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index c095c8ca78..fbc938db04 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -333,9 +333,7 @@ def test_dot_lowering_inlines_einsum_ofg(): z = x.dot(y) lowered = rewrite_graph(z.values, include=("lower_xtensor",)) - ofg_nodes = [ - n for n in io_toposort([], [lowered]) if isinstance(n.op, OpFromGraph) - ] + ofg_nodes = [n for n in io_toposort([], [lowered]) if isinstance(n.op, OpFromGraph)] assert ofg_nodes == [], ( "lower_dot should inline the Einsum OpFromGraph eagerly; got: " f"{[type(n.op).__name__ for n in ofg_nodes]}" From 48b5a5a89884ae151efb9a1440c6de2eb2bb6841 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Fri, 22 May 2026 22:28:37 +0900 Subject: [PATCH 3/3] tests: widen test_dot_errors match for inlined Dot's runtime shape check After inlining the Einsum OFG in lower_dot, the runtime shape mismatch is now raised by the inlined Dot directly (`Shape mismatch: x has ...`) instead of by np.einsum/np.dot inside the OFG's wrapper. Add that message to the regex so the test passes on all backends. Co-authored-by: Cursor --- tests/xtensor/test_math.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index fbc938db04..df69a7a22d 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -363,10 +363,15 @@ def test_dot_errors(): fn = xr_function([x, y], z) x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) y_test = DataArray(np.ones((4, 5)), dims=("b", "c")) - # Doesn't fail until the rewrite + # Doesn't fail until the rewrite. The exact message depends on which op + # raises (np.einsum vs np.dot vs the inlined Dot's runtime shape check). with pytest.raises( ValueError, - match=r"(Input operand 1 has a mismatch in its core dimension 0|incompatible array sizes for np.dot)", + match=( + r"Input operand 1 has a mismatch in its core dimension 0" + r"|incompatible array sizes for np.dot" + r"|Shape mismatch: x has" + ), ): fn(x_test, y_test)