Description
Sum{axes=None}(ExpandDims{axis=0}(x)) produces incorrect results when compiled with the Numba linker. The py and cvm linkers give the correct result.
Reproducer
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
out = pt.sum(x[None])  # Sum{axes=None}(ExpandDims{axis=0}(x))
x_val = np.array([[1, 0], [3, 4]], dtype="float64")  # sum = 8

for linker in ("py", "cvm", "numba"):
    mode = pytensor.compile.mode.get_mode("FAST_RUN").__class__(
        linker, pytensor.compile.mode.get_mode("FAST_RUN").optimizer
    )
    f = pytensor.function([x], out, mode=mode)
    print(f"{linker}: {f(x_val)}")

# Expected: 8.0 for all linkers
# Actual (numba): 1.0
Expected behavior
All three linkers should return 8.0 (the sum of all elements in the 2x2 matrix).
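As a sanity check, the expected value can be confirmed with plain NumPy, independent of any PyTensor linker. This is only a reference computation, assuming that `x_val[None]` and `np.sum` with no `axis` argument mirror `ExpandDims{axis=0}` and `Sum{axes=None}` respectively:

```python
import numpy as np

x_val = np.array([[1, 0], [3, 4]], dtype="float64")

# Adding a leading length-1 axis does not change the elements,
# so a reduction over all axes must give the same total.
expanded = x_val[None]        # shape (1, 2, 2)
reference = np.sum(expanded)  # reduce over every axis

print(expanded.shape, reference)
```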
Actual behavior
- py: 8.0 ✅
- cvm: 8.0 ✅
- numba: 1.0 ❌
Environment
- PyTensor HEAD (commit dcfe621)
- Python 3.13.2
- macOS
Notes
Exposed by the local_careduce_join rewrite (#2130), which pushes CAReduce(Join(...)) through the join and in some cases produces Sum(ExpandDims(x)).
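For context, the identity such a rewrite relies on can be illustrated in NumPy. This is a rough sketch of the algebra only, not of the rewrite's actual implementation. The arrays `a` and `b` are made up for the example, and it is an assumption that the joined inputs carry an added leading axis (as they would when stacking):

```python
import numpy as np

# Hypothetical inputs, chosen only for illustration.
a = np.array([[1.0, 0.0], [3.0, 4.0]])
b = np.array([[5.0, 6.0], [7.0, 8.0]])

# Reduce after joining the expanded inputs along the new axis...
joined_sum = np.sum(np.concatenate([a[None], b[None]], axis=0))

# ...versus reducing each expanded input first and combining the results.
# Each term here has the Sum(ExpandDims(x)) shape discussed in this issue.
pushed_sum = np.sum(a[None]) + np.sum(b[None])

print(joined_sum, pushed_sum)
```

Both forms should agree, so a backend that miscomputes the per-input term would change the result of the rewritten graph.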