Canonicalize associative Op argument order and add distributive factorization for Add/Sub #2140

@ricardoV94

Description

PyTensor currently doesn't simplify a*b + a*c to a*(b + c), nor does it cancel a*b - b*a to 0. Both gaps come from the same two missing pieces: Mul arguments aren't canonicalized into a deterministic order, and there's no distributive factorization rewrite across Add/Sub. Concrete cases that currently go unsimplified: a*b ± a*c, b*a ± c*a, mixed positions such as a*b + c*a, the n-ary form a*b + a*c + a*d, and the unit-factor degenerate cases a ± a*c. The fix is two small changes:

(1) Sort Mul arguments by a stable key during canonicalization, so that a*b and b*a become the same node. This alone makes a*b - b*a → 0 fall out of the existing x - x → 0 rewrite.
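A minimal sketch of step (1), assuming a toy expression encoding rather than PyTensor's actual graph classes: a leaf is a string and an application is a tuple `(op_name, *args)`. The sort key (`repr` of each subexpression) and the helpers `canonicalize` and `simplify_sub_self` are hypothetical stand-ins; the issue only requires that the key be stable and deterministic, not any particular choice.

```python
# Hypothetical toy encoding (not PyTensor's graph classes):
# a leaf is a str, an application is a tuple (op_name, arg1, arg2, ...).
COMMUTATIVE = {"add", "mul", "maximum", "minimum", "and", "or", "xor"}


def sort_key(expr):
    """Deterministic ordering key; repr() is a stand-in for whatever stable key is chosen."""
    return repr(expr)


def canonicalize(expr):
    """Recursively sort the arguments of commutative ops into a deterministic order."""
    if isinstance(expr, str):
        return expr
    op, *args = expr
    args = [canonicalize(arg) for arg in args]
    if op in COMMUTATIVE:
        args.sort(key=sort_key)
    return (op, *args)


def simplify_sub_self(expr):
    """Stand-in for an existing x - x -> 0 rewrite: fires only on identical operands."""
    if isinstance(expr, tuple) and expr[0] == "sub" and expr[1] == expr[2]:
        return "0"
    return expr


diff = ("sub", ("mul", "a", "b"), ("mul", "b", "a"))
print(simplify_sub_self(diff))                # unchanged: the two Mul nodes differ
print(simplify_sub_self(canonicalize(diff)))  # prints 0
```

Once both operands of the `sub` are the same canonical tuple, the `x - x → 0` stand-in fires without any new cancellation logic.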

(2) Add a rewrite that recognizes Add(Mul(a, X_1), Mul(a, X_2), ...) and rewrites it to Mul(a, Add(X_1, X_2, ...)), with the analogous Sub case. The same canonical-order principle applies to other commutative variadic ops (Add, Maximum, Minimum, bitwise/logical And/Or/Xor, and the axis tuple of commutative reductions); Mul is the immediate case here, but the canonicalization step is worth doing uniformly.
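A minimal sketch of the factorization in step (2), again on a hypothetical toy encoding (strings for leaves, `(op_name, *args)` tuples for applications) rather than PyTensor's rewriter API. The helper `factor_out` looks for a factor shared by every term of an `add` or `sub`, regardless of its position inside each `mul`, and treats a bare term `a` as a one-factor product so that the unit-factor case `a ± a*c` becomes `a*(1 ± c)`. It pulls out one shared factor per call; a real rewriter would also revisit the new inner node.

```python
def factors(term):
    """Multiplicative factors of a term; a non-Mul term is its own single factor."""
    if isinstance(term, tuple) and term[0] == "mul":
        return list(term[1:])
    return [term]


def rebuild_mul(fs):
    """Inverse of factors(): no factors -> '1', one factor -> itself, else a Mul node."""
    if not fs:
        return "1"
    if len(fs) == 1:
        return fs[0]
    return ("mul", *fs)


def factor_out(expr):
    """Rewrite Add/Sub(Mul(a, X_1), Mul(a, X_2), ...) -> Mul(a, Add/Sub(X_1, X_2, ...))."""
    if not (isinstance(expr, tuple) and expr[0] in ("add", "sub")):
        return expr
    op, *terms = expr
    term_factors = [factors(t) for t in terms]
    # Try each factor of the first term; keep the first one present in every other term.
    for candidate in term_factors[0]:
        if all(candidate in fs for fs in term_factors[1:]):
            rests = []
            for fs in term_factors:
                fs = list(fs)
                fs.remove(candidate)  # drop one occurrence, whatever its position
                rests.append(rebuild_mul(fs))
            return ("mul", candidate, (op, *rests))
    return expr  # no factor shared by all terms: leave the node alone


print(factor_out(("add", ("mul", "a", "b"), ("mul", "c", "a"))))
# ('mul', 'a', ('add', 'b', 'c'))
```

The same function covers the other cases listed above: `b*a - c*a` gives `a*(b - c)`, the n-ary `a*b + a*c + a*d` gives `a*(b + c + d)`, and `a - a*c` gives `a*(1 - c)`.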

The direct saving is one elementwise multiply per additional term in a factored group (k - 1 multiplies for a group of k terms): a*b + a*c does two multiplies and one add, while a*(b + c) does one multiply and one add. For pure-scalar/pure-elementwise graphs, Composite fusion compiles both forms into a single kernel with nearly the same per-element cost, so the rewrite is mostly a no-op there.
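To make the operation count concrete, the sketch below counts elementwise operations on a hypothetical toy tree (strings for leaves, `(op_name, *args)` tuples for applications), assuming an n-ary node with n arguments costs n - 1 binary operations. For the two-term example it reproduces the counts given above; for a k-term group the number of multiplies drops from k to 1 while the number of adds stays at k - 1.

```python
from collections import Counter


def count_ops(expr):
    """Count elementwise ops in a toy tree; an n-ary node costs n - 1 binary ops."""
    counts = Counter()

    def visit(node):
        if isinstance(node, str):  # leaves cost nothing
            return
        op, *args = node
        counts[op] += len(args) - 1
        for arg in args:
            visit(arg)

    visit(expr)
    return counts


for k in (2, 3, 4):
    xs = [f"x{i}" for i in range(k)]
    unfactored = count_ops(("add", *[("mul", "a", x) for x in xs]))  # a*x0 + a*x1 + ...
    factored = count_ops(("mul", "a", ("add", *xs)))                  # a*(x0 + x1 + ...)
    print(k, "mul:", unfactored["mul"], "->", factored["mul"],
          "add:", unfactored["add"], "->", factored["add"])
# for k = 3 this prints: 3 mul: 3 -> 1 add: 2 -> 2
```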

The analogous rewrite for Dot (A@B + A@C → A@(B+C)) saves an entire matmul and is a separate, higher-priority follow-up; it would be natural to implement it once the Mul version exists as a template.
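For the Dot follow-up, a quick NumPy check (NumPy is used here only for illustration; this is not PyTensor code) shows that `A @ B + A @ C` and `A @ (B + C)` agree while the factored form performs one matmul instead of two. Because matrix multiplication is not commutative, the shared operand must sit on the same side of every product, so this case cannot rely on argument sorting the way Mul can.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((64, 32))
B = rng.standard_normal((32, 16))
C = rng.standard_normal((32, 16))

unfactored = A @ B + A @ C  # two matmuls and one elementwise add
factored = A @ (B + C)      # one elementwise add and one matmul

# Same result up to floating-point rounding, not bit-for-bit.
print(np.allclose(unfactored, factored))  # True
```

The two forms agree only up to floating-point rounding, which is why the check uses `np.allclose` rather than exact equality; the same holds for the elementwise a*b + a*c → a*(b + c) rewrite.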
