
Numba: Sum(ExpandDims(x)) returns incorrect result #2131

@williambdean

Description

Sum{axes=None}(ExpandDims{axis=0}(x)) returns an incorrect result when compiled with the Numba linker. The py and cvm linkers return the correct value for the same graph.
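For reference, the expected semantics can be checked in plain NumPy. This is a minimal sketch that assumes x[None] behaves like np.expand_dims(x, 0), as the comment in the reproducer indicates. Adding a length-1 leading axis does not change which elements a sum over all axes sees.

```python
import numpy as np

x_val = np.array([[1, 0], [3, 4]], dtype="float64")

# x[None] corresponds to expanding a new leading axis: shape (2, 2) -> (1, 2, 2)
expanded = np.expand_dims(x_val, axis=0)

# A sum over all axes (axes=None) ignores the extra length-1 axis
total = expanded.sum()
print(expanded.shape, total)  # (1, 2, 2) 8.0
```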

Reproducer

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
out = pt.sum(x[None])  # Sum{axes=None}(ExpandDims{axis=0}(x))

x_val = np.array([[1, 0], [3, 4]], dtype="float64")  # sum = 8

for linker in ("py", "cvm", "numba"):
    mode = pytensor.compile.mode.get_mode("FAST_RUN").__class__(
        linker, pytensor.compile.mode.get_mode("FAST_RUN").optimizer
    )
    f = pytensor.function([x], out, mode=mode)
    print(f"{linker}: {f(x_val)}")

# Expected: 8.0 for all linkers
# Actual (numba): 1.0
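As a plain-Python reference (not PyTensor's Numba kernel, which is not shown in this issue), a full reduction over the expanded shape (1, 2, 2) has to visit all four elements of the underlying C-contiguous buffer. The leading length-1 axis only contributes a stride that is always multiplied by index 0, so it never changes an offset. The helper full_sum is hypothetical and exists only for illustration. Note that the Numba result, 1.0, happens to equal x_val[0, 0] alone. Whether that reflects the actual cause has not been confirmed.

```python
from itertools import product


def full_sum(data, shape):
    """Hypothetical helper: sum every element of a C-contiguous flat buffer."""
    # Compute C-order (row-major) strides, measured in elements
    strides = []
    acc = 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    strides.reverse()

    # Visit every multi-index of the shape and accumulate the element at its offset
    total = 0.0
    for idx in product(*(range(d) for d in shape)):
        offset = sum(i * s for i, s in zip(idx, strides))
        total += data[offset]
    return total


flat = [1.0, 0.0, 3.0, 4.0]  # x_val from the reproducer, flattened in C order
print(full_sum(flat, (1, 2, 2)))  # 8.0 (expanded shape)
print(full_sum(flat, (2, 2)))     # 8.0 (original shape)
```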

Expected behavior

All three linkers should return 8.0 (the sum of all elements in the 2x2 matrix).

Actual behavior

  • py: 8.0 ✅
  • cvm: 8.0 ✅
  • numba: 1.0 ❌

Environment

  • PyTensor HEAD (commit dcfe621)
  • Python 3.13.2
  • macOS

Notes

This was exposed by the local_careduce_join rewrite (#2130), which pushes the CAReduce in CAReduce(Join(...)) through the join and can therefore produce Sum(ExpandDims(x)) in some cases.
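The identity behind that rewrite can be sketched in NumPy for the full-reduction case (axes=None). A sum over a join equals the sum of the sums of the joined inputs. When those inputs each carry an added leading axis, which is an assumed example of how such a Join can arise, every term becomes the Sum(ExpandDims(...)) pattern from this issue. How local_careduce_join handles partial reductions is not covered by this sketch.

```python
import numpy as np

x = np.array([[1, 0], [3, 4]], dtype="float64")
y = np.array([[5, 6], [7, 8]], dtype="float64")

# Join along axis 0 of two inputs that each gained a leading length-1 axis
joined = np.concatenate([x[None], y[None]], axis=0)  # shape (2, 2, 2)

# Full reduction of the join == sum of full reductions of each Sum(ExpandDims(...)) term
lhs = joined.sum()
rhs = x[None].sum() + y[None].sum()
print(lhs, rhs)  # 34.0 34.0
```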

Labels: bug (Something isn't working)