diff --git a/pytensor/assumptions/triangular.py b/pytensor/assumptions/triangular.py index af73c75e5c..fad2505341 100644 --- a/pytensor/assumptions/triangular.py +++ b/pytensor/assumptions/triangular.py @@ -136,21 +136,32 @@ def _lu_upper(key, op, feature, fgraph, node, input_states): return states -def _dimshuffle_left_expand_dims(key, op, feature, fgraph, node, input_states): - """Triangularity survives left-expand-dims (batch broadcast). +@register_assumption(LOWER_TRIANGULAR, DimShuffle) +def _lower_dimshuffle(key, op, feature, fgraph, node, input_states): + """Mirror lower-triangularity through left-expand-dims and matrix transpose. - Matrix transpose swaps lower<->upper, so it is *not* propagated under - the same key. + Left-expand-dims preserves the matrix exactly and transpose swaps the two + triangles, swapping their fact states. """ - if input_states[0] is not FactState.TRUE: - return [FactState.UNKNOWN] + if op.is_matrix_transpose: + return [feature.get(node.inputs[0], UPPER_TRIANGULAR)] if left_expand_dims_propagates_matrix_property(op): - return [FactState.TRUE] + return [input_states[0]] return [FactState.UNKNOWN] -register_assumption(LOWER_TRIANGULAR, DimShuffle)(_dimshuffle_left_expand_dims) -register_assumption(UPPER_TRIANGULAR, DimShuffle)(_dimshuffle_left_expand_dims) +@register_assumption(UPPER_TRIANGULAR, DimShuffle) +def _upper_dimshuffle(key, op, feature, fgraph, node, input_states): + """Mirror upper-triangularity through left-expand-dims and matrix transpose. + + Left-expand-dims preserves the matrix exactly and transpose swaps the two + triangles, swapping their fact states. 
+ """ + if op.is_matrix_transpose: + return [feature.get(node.inputs[0], LOWER_TRIANGULAR)] + if left_expand_dims_propagates_matrix_property(op): + return [input_states[0]] + return [FactState.UNKNOWN] @register_assumption(LOWER_TRIANGULAR, Elemwise) diff --git a/tests/assumptions/test_dimshuffle.py b/tests/assumptions/test_dimshuffle.py index f2ec187c94..e65e01b1dc 100644 --- a/tests/assumptions/test_dimshuffle.py +++ b/tests/assumptions/test_dimshuffle.py @@ -38,15 +38,6 @@ def test_multiple_left_axes(self, key): _, af = make_fgraph(y) assert af.check(y, key) - def test_triangular_transpose_does_not_propagate(self): - """Transpose swaps the triangle, so a lower-triangular matrix is no longer - known to be lower-triangular.""" - x = pt.matrix("x", shape=(4, 4)) - x_lower = assume(x, lower_triangular=True) - y = x_lower.T - _, af = make_fgraph(y) - assert af.get(y, LOWER_TRIANGULAR) == FactState.UNKNOWN - def test_right_expand_dims_does_not_propagate(self): """Adding a broadcast dim on the right shifts the matrix axes; the property is lost.""" x = pt.matrix("x", shape=(4, 4)) diff --git a/tests/assumptions/test_triangular.py b/tests/assumptions/test_triangular.py index 7a2e4569c7..9a09ea1c5b 100644 --- a/tests/assumptions/test_triangular.py +++ b/tests/assumptions/test_triangular.py @@ -166,3 +166,62 @@ def test_lu_permute_l_pl_is_not_lower_triangular(): _, af = make_fgraph(PL, U) assert af.get(PL, LOWER_TRIANGULAR) == FactState.UNKNOWN assert af.check(U, UPPER_TRIANGULAR) + + +class TestMatrixTransposeFlipsTriangle: + """Matrix transpose maps lower-triangular to upper-triangular and vice versa.""" + + @pytest.mark.parametrize( + "asserted_value, expected", + [(True, FactState.TRUE), (False, FactState.FALSE)], + ids=["true", "false"], + ) + @pytest.mark.parametrize( + "asserted, flipped", + [ + (LOWER_TRIANGULAR, UPPER_TRIANGULAR), + (UPPER_TRIANGULAR, LOWER_TRIANGULAR), + ], + ) + def test_transpose_flips_triangle( + self, asserted, flipped, asserted_value, 
expected + ): + x = pt.matrix("x", shape=(4, 4)) + x_tagged = assume(x, **{asserted.name: asserted_value}) + y = x_tagged.T + _, af = make_fgraph(y) + assert af.get(y, flipped) == expected + assert af.get(y, asserted) == FactState.UNKNOWN + + @pytest.mark.parametrize( + "asserted_value, expected", + [(True, FactState.TRUE), (False, FactState.FALSE)], + ids=["true", "false"], + ) + @pytest.mark.parametrize( + "asserted, flipped", + [ + (LOWER_TRIANGULAR, UPPER_TRIANGULAR), + (UPPER_TRIANGULAR, LOWER_TRIANGULAR), + ], + ) + def test_batched_transpose_flips_triangle( + self, asserted, flipped, asserted_value, expected + ): + x = pt.tensor("x", shape=(2, 3, 4, 4)) + x_tagged = assume(x, **{asserted.name: asserted_value}) + y = x_tagged.mT + _, af = make_fgraph(y) + assert af.get(y, flipped) == expected + assert af.get(y, asserted) == FactState.UNKNOWN + + @pytest.mark.parametrize("key", [LOWER_TRIANGULAR, UPPER_TRIANGULAR]) + def test_double_transpose_recovers_triangle(self, key): + # Exercises the cross-key inference chain: the first transpose flips + # via a feature.get of the opposite key on the input, and the second + # transpose flips back via a feature.get of the opposite key on the + # freshly-inferred intermediate. + x = pt.matrix("x", shape=(4, 4)) + x_tagged = assume(x, **{key.name: True}) + y = x_tagged.T.T + _, af = make_fgraph(y) + assert af.check(y, key)