From 21e568fda1aa749ec4ace28408a9c117ecccbae8 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 20 May 2026 13:55:19 +0200 Subject: [PATCH] Numba sparse dot: Use final precision in intermediate computations Numba underpromotes relative to numpy/scipy mixed scalar * array dtypes --- pytensor/link/numba/dispatch/sparse/math.py | 8 +++-- tests/link/numba/sparse/test_math.py | 35 +++++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/pytensor/link/numba/dispatch/sparse/math.py b/pytensor/link/numba/dispatch/sparse/math.py index 828f19c7fd..9b74e433c0 100644 --- a/pytensor/link/numba/dispatch/sparse/math.py +++ b/pytensor/link/numba/dispatch/sparse/math.py @@ -136,7 +136,9 @@ def numba_funcify_SparseDot(op, node, **kwargs): x_format = x.type.format if x_is_sparse else None y_format = y.type.format if y_is_sparse else None - cache_version = 2 + out_type = np.dtype(out_dtype).type + + cache_version = 3 cache_key = sha256( str( ( @@ -371,7 +373,7 @@ def spmdm_csr(x, y): for row_idx in range(n): for idx in range(x_ptr[row_idx], x_ptr[row_idx + 1]): col_idx = x_ind[idx] - value = x_data[idx] + value = out_type(x_data[idx]) z[row_idx] += value * y[col_idx] return z @@ -390,7 +392,7 @@ def spmdm_csc(x, y): for col_idx in range(p): for idx in range(x_ptr[col_idx], x_ptr[col_idx + 1]): row_idx = x_ind[idx] - value = x_data[idx] + value = out_type(x_data[idx]) z[row_idx] += value * y[col_idx] return z diff --git a/tests/link/numba/sparse/test_math.py b/tests/link/numba/sparse/test_math.py index 4bc34518e3..3f9bbceff6 100644 --- a/tests/link/numba/sparse/test_math.py +++ b/tests/link/numba/sparse/test_math.py @@ -97,6 +97,41 @@ def test_sparse_spmv(sp_format): compare_numba_and_py_sparse([x, y], z, [x_test, y_test]) +@pytest.mark.parametrize( + "x_dtype, y_dtype", + [ + ("int64", "complex64"), + ("int64", "float32"), + ], +) +def test_structured_dot_upcast(x_dtype, y_dtype): + """Numba scalar-array mul keeps the array dtype; numpy upcasts to a wider type.""" + x = ps.matrix(format="csc", name="x", dtype=x_dtype, shape=(4, 3)) + y = pt.matrix("y", dtype=y_dtype, shape=(3, 5)) + z = ps.structured_dot(x, y) + + x_test = scipy.sparse.csc_matrix( + np.array([[97, 0, 0], [0, 83, 0], [0, 0, 71], [42, 0, 0]], dtype=x_dtype) + ) + y_test = np.array( + [ + [9.12345, -3.98765, 7.55555, 1.23456, -5.67890], + [2.34567, 8.76543, -4.32109, 6.54321, 0.98765], + [-1.11111, 3.33333, 9.99999, -7.77777, 2.22222], + ], + dtype=y_dtype, + ) + + def strict_assert(a, b): + if scipy.sparse.issparse(a): + a = a.toarray() + if scipy.sparse.issparse(b): + b = b.toarray() + np.testing.assert_allclose(a, b, rtol=1e-14, atol=0, strict=True) + + compare_numba_and_py_sparse([x, y], z, [x_test, y_test], assert_fn=strict_assert) + + @pytest.mark.parametrize("x_format", ["csr", "csc"]) @pytest.mark.parametrize("y_format", ["csr", "csc", "dense"]) @pytest.mark.parametrize("x_shape, y_shape", DOT_SHAPES)