From 7de911ddf9979fcb18455f7bf83f5be99b574c9c Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 26 May 2026 21:13:34 -0400 Subject: [PATCH 1/2] Propagate triangularity through matrix transpose Replace the dead _dimshuffle_left_expand_dims rule (shadowed by the default DimShuffle handler) with cross-key inference: a matrix-transpose DimShuffle now flips lower<->upper, so downstream rewrites can see triangular structure through .T / .mT. --- pytensor/assumptions/triangular.py | 33 +++++++++++++------- tests/assumptions/test_dimshuffle.py | 9 ------ tests/assumptions/test_triangular.py | 45 ++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 19 deletions(-) diff --git a/pytensor/assumptions/triangular.py b/pytensor/assumptions/triangular.py index af73c75e5c..b8f9ed6395 100644 --- a/pytensor/assumptions/triangular.py +++ b/pytensor/assumptions/triangular.py @@ -136,21 +136,34 @@ def _lu_upper(key, op, feature, fgraph, node, input_states): return states -def _dimshuffle_left_expand_dims(key, op, feature, fgraph, node, input_states): - """Triangularity survives left-expand-dims (batch broadcast). - - Matrix transpose swaps lower<->upper, so it is *not* propagated under - the same key. +@register_assumption(LOWER_TRIANGULAR, DimShuffle) +def _lower_dimshuffle(key, op, feature, fgraph, node, input_states): + """Lower-triangularity survives left-expand-dims and arises from the matrix + transpose of an upper-triangular operand (since transposing swaps the two + triangles). """ - if input_states[0] is not FactState.TRUE: - return [FactState.UNKNOWN] - if left_expand_dims_propagates_matrix_property(op): + if op.is_matrix_transpose: + return true_if(feature.check(node.inputs[0], UPPER_TRIANGULAR)) + if input_states[ + 0 + ] is FactState.TRUE and left_expand_dims_propagates_matrix_property(op): return [FactState.TRUE] return [FactState.UNKNOWN] -register_assumption(LOWER_TRIANGULAR, DimShuffle)(_dimshuffle_left_expand_dims) -register_assumption(UPPER_TRIANGULAR, DimShuffle)(_dimshuffle_left_expand_dims) +@register_assumption(UPPER_TRIANGULAR, DimShuffle) +def _upper_dimshuffle(key, op, feature, fgraph, node, input_states): + """Upper-triangularity survives left-expand-dims and arises from the matrix + transpose of a lower-triangular operand (since transposing swaps the two + triangles). + """ + if op.is_matrix_transpose: + return true_if(feature.check(node.inputs[0], LOWER_TRIANGULAR)) + if input_states[ + 0 + ] is FactState.TRUE and left_expand_dims_propagates_matrix_property(op): + return [FactState.TRUE] + return [FactState.UNKNOWN] @register_assumption(LOWER_TRIANGULAR, Elemwise) diff --git a/tests/assumptions/test_dimshuffle.py b/tests/assumptions/test_dimshuffle.py index f2ec187c94..e65e01b1dc 100644 --- a/tests/assumptions/test_dimshuffle.py +++ b/tests/assumptions/test_dimshuffle.py @@ -38,15 +38,6 @@ def test_multiple_left_axes(self, key): _, af = make_fgraph(y) assert af.check(y, key) - def test_triangular_transpose_does_not_propagate(self): - """Transpose swaps the triangle, so a lower-triangular matrix is no longer - known to be lower-triangular.""" - x = pt.matrix("x", shape=(4, 4)) - x_lower = assume(x, lower_triangular=True) - y = x_lower.T - _, af = make_fgraph(y) - assert af.get(y, LOWER_TRIANGULAR) == FactState.UNKNOWN - def test_right_expand_dims_does_not_propagate(self): """Adding a broadcast dim on the right shifts the matrix axes; the property is lost.""" x = pt.matrix("x", shape=(4, 4)) diff --git a/tests/assumptions/test_triangular.py b/tests/assumptions/test_triangular.py index 7a2e4569c7..c589639c81 100644 --- a/tests/assumptions/test_triangular.py +++ b/tests/assumptions/test_triangular.py @@ -166,3 +166,48 @@ def test_lu_permute_l_pl_is_not_lower_triangular(): _, af = make_fgraph(PL, U) assert af.get(PL, LOWER_TRIANGULAR) == FactState.UNKNOWN assert af.check(U, UPPER_TRIANGULAR) + + +class TestMatrixTransposeFlipsTriangle: + """Matrix transpose maps lower-triangular to upper-triangular and vice versa.""" + + @pytest.mark.parametrize( + "asserted, flipped", + [ + (LOWER_TRIANGULAR, UPPER_TRIANGULAR), + (UPPER_TRIANGULAR, LOWER_TRIANGULAR), + ], + ) + def test_transpose_flips_triangle(self, asserted, flipped): + x = pt.matrix("x", shape=(4, 4)) + x_tagged = assume(x, **{asserted.name: True}) + y = x_tagged.T + _, af = make_fgraph(y) + assert af.check(y, flipped) + assert af.get(y, asserted) == FactState.UNKNOWN + + @pytest.mark.parametrize( + "asserted, flipped", + [ + (LOWER_TRIANGULAR, UPPER_TRIANGULAR), + (UPPER_TRIANGULAR, LOWER_TRIANGULAR), + ], + ) + def test_batched_transpose_flips_triangle(self, asserted, flipped): + x = pt.tensor("x", shape=(2, 3, 4, 4)) + x_tagged = assume(x, **{asserted.name: True}) + y = x_tagged.mT + _, af = make_fgraph(y) + assert af.check(y, flipped) + assert af.get(y, asserted) == FactState.UNKNOWN + + @pytest.mark.parametrize("key", [LOWER_TRIANGULAR, UPPER_TRIANGULAR]) + def test_double_transpose_recovers_triangle(self, key): + # Exercises the cross-key inference chain: the first transpose flips + # via a feature.check on the input, and the second transpose flips + # back via a feature.check on the freshly-inferred intermediate. + x = pt.matrix("x", shape=(4, 4)) + x_tagged = assume(x, **{key.name: True}) + y = x_tagged.T.T + _, af = make_fgraph(y) + assert af.check(y, key) From 65611276f8784eb62aec9456cdb82236b9f93a3f Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Fri, 29 May 2026 23:09:03 -0500 Subject: [PATCH 2/2] propagate False facts --- pytensor/assumptions/triangular.py | 30 +++++++++++++--------------- tests/assumptions/test_triangular.py | 26 ++++++++++++++++++------ 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/pytensor/assumptions/triangular.py b/pytensor/assumptions/triangular.py index b8f9ed6395..fad2505341 100644 --- a/pytensor/assumptions/triangular.py +++ b/pytensor/assumptions/triangular.py @@ -138,31 +138,29 @@ def _lu_upper(key, op, feature, fgraph, node, input_states): @register_assumption(LOWER_TRIANGULAR, DimShuffle) def _lower_dimshuffle(key, op, feature, fgraph, node, input_states): - """Lower-triangularity survives left-expand-dims and arises from the matrix - transpose of an upper-triangular operand (since transposing swaps the two - triangles). + """Mirror lower-triangularity through left-expand-dims and matrix transpose. + + Left-expand-dims preserves the matrix exactly and transpose swaps the two + triangles, swapping their fact states. """ if op.is_matrix_transpose: - return true_if(feature.check(node.inputs[0], UPPER_TRIANGULAR)) - if input_states[ - 0 - ] is FactState.TRUE and left_expand_dims_propagates_matrix_property(op): - return [FactState.TRUE] + return [feature.get(node.inputs[0], UPPER_TRIANGULAR)] + if left_expand_dims_propagates_matrix_property(op): + return [input_states[0]] return [FactState.UNKNOWN] @register_assumption(UPPER_TRIANGULAR, DimShuffle) def _upper_dimshuffle(key, op, feature, fgraph, node, input_states): - """Upper-triangularity survives left-expand-dims and arises from the matrix - transpose of a lower-triangular operand (since transposing swaps the two - triangles). + """Mirror upper-triangularity through left-expand-dims and matrix transpose. + + Left-expand-dims preserves the matrix exactly and transpose swaps the two + triangles, swapping their fact states. """ if op.is_matrix_transpose: - return true_if(feature.check(node.inputs[0], LOWER_TRIANGULAR)) - if input_states[ - 0 - ] is FactState.TRUE and left_expand_dims_propagates_matrix_property(op): - return [FactState.TRUE] + return [feature.get(node.inputs[0], LOWER_TRIANGULAR)] + if left_expand_dims_propagates_matrix_property(op): + return [input_states[0]] return [FactState.UNKNOWN] diff --git a/tests/assumptions/test_triangular.py b/tests/assumptions/test_triangular.py index c589639c81..9a09ea1c5b 100644 --- a/tests/assumptions/test_triangular.py +++ b/tests/assumptions/test_triangular.py @@ -171,6 +171,11 @@ def test_lu_permute_l_pl_is_not_lower_triangular(): class TestMatrixTransposeFlipsTriangle: """Matrix transpose maps lower-triangular to upper-triangular and vice versa.""" + @pytest.mark.parametrize( + "asserted_value, expected", + [(True, FactState.TRUE), (False, FactState.FALSE)], + ids=["true", "false"], + ) @pytest.mark.parametrize( "asserted, flipped", [ @@ -178,14 +183,21 @@ class TestMatrixTransposeFlipsTriangle: (UPPER_TRIANGULAR, LOWER_TRIANGULAR), ], ) - def test_transpose_flips_triangle(self, asserted, flipped): + def test_transpose_flips_triangle( + self, asserted, flipped, asserted_value, expected + ): x = pt.matrix("x", shape=(4, 4)) - x_tagged = assume(x, **{asserted.name: True}) + x_tagged = assume(x, **{asserted.name: asserted_value}) y = x_tagged.T _, af = make_fgraph(y) - assert af.check(y, flipped) + assert af.get(y, flipped) == expected assert af.get(y, asserted) == FactState.UNKNOWN + @pytest.mark.parametrize( + "asserted_value, expected", + [(True, FactState.TRUE), (False, FactState.FALSE)], + ids=["true", "false"], + ) @pytest.mark.parametrize( "asserted, flipped", [ @@ -193,12 +205,14 @@ def test_transpose_flips_triangle(self, asserted, flipped): (UPPER_TRIANGULAR, LOWER_TRIANGULAR), ], ) - def test_batched_transpose_flips_triangle(self, asserted, flipped): + def test_batched_transpose_flips_triangle( + self, asserted, flipped, asserted_value, expected + ): x = pt.tensor("x", shape=(2, 3, 4, 4)) - x_tagged = assume(x, **{asserted.name: True}) + x_tagged = assume(x, **{asserted.name: asserted_value}) y = x_tagged.mT _, af = make_fgraph(y) - assert af.check(y, flipped) + assert af.get(y, flipped) == expected assert af.get(y, asserted) == FactState.UNKNOWN @pytest.mark.parametrize("key", [LOWER_TRIANGULAR, UPPER_TRIANGULAR])