Description
PyTensor currently doesn't simplify a*b + a*c to a*(b + c), nor does it cancel a*b - b*a to 0. Both are missing for the same underlying reason: Mul arguments aren't canonicalized into a deterministic order, and there's no distributive factorization rewrite across Add/Sub. Concrete cases that are all currently missed: a*b ± a*c, b*a ± c*a, mixed positions like a*b + c*a, the n-ary form a*b + a*c + a*d, and the unit-factor degenerate cases a ± a*c. The fix is two small changes:
(1) sort Mul arguments by a stable key during canonicalize, so that a*b and b*a become the same node; this alone makes a*b - b*a → 0 fall out via the existing x - x → 0 rewrite;
(2) add a rewrite that recognizes Add(Mul(a, X_1), Mul(a, X_2), ...) and rewrites it to Mul(a, Add(X_1, X_2, ...)).
The same canonical-order principle applies to other commutative variadic ops (Add, Maximum, Minimum, bitwise/logical And/Or/Xor, and the axis tuple of commutative reductions). Mul is the immediate case here, but the canonicalization step is worth doing uniformly.
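The two steps above can be sketched on a toy expression tree. This is only an illustration of the idea, not PyTensor code: the tuple representation, `sort_key`, `canonicalize`, and `factor_common` are all hypothetical names, leaves are plain strings, and the unit factor is written as the leaf "1".

```python
# Toy expressions: a leaf is a str, a node is a tuple (op, *args).
# Everything here is a hypothetical stand-in, not PyTensor's API.
COMMUTATIVE = ("mul", "add")


def sort_key(expr):
    # Stable key: leaves sort by name and come before compound nodes,
    # which sort by their repr.
    return (0, expr) if isinstance(expr, str) else (1, repr(expr))


def canonicalize(expr):
    """Step (1): recursively sort the arguments of commutative ops."""
    if isinstance(expr, str):
        return expr
    op, *args = expr
    args = [canonicalize(a) for a in args]
    if op in COMMUTATIVE:
        args.sort(key=sort_key)
    return (op, *args)


def factors(term):
    # A Mul node contributes its arguments; anything else is a single factor.
    if isinstance(term, tuple) and term[0] == "mul":
        return list(term[1:])
    return [term]


def factor_common(expr):
    """Step (2): Add(Mul(a, X_1), Mul(a, X_2), ...) -> Mul(a, Add(X_1, X_2, ...))."""
    if not (isinstance(expr, tuple) and expr[0] == "add"):
        return expr
    terms = [factors(t) for t in expr[1:]]
    # Pick the first factor of the first term that appears in every other term.
    common = next((f for f in terms[0] if all(f in t for t in terms[1:])), None)
    if common is None:
        return expr
    rest = []
    for t in terms:
        t = list(t)
        t.remove(common)  # drop one occurrence of the shared factor
        if not t:
            rest.append("1")  # unit-factor case: a + a*c -> a*(1 + c)
        elif len(t) == 1:
            rest.append(t[0])
        else:
            rest.append(("mul", *t))
    return ("mul", common, ("add", *rest))


# a*b and b*a become the same node, so an x - x -> 0 rewrite could fire.
same = canonicalize(("mul", "a", "b")) == canonicalize(("mul", "b", "a"))

# Mixed positions: a*b + c*a -> a*(b + c)
mixed = factor_common(canonicalize(("add", ("mul", "a", "b"), ("mul", "c", "a"))))
```

In this sketch the membership test in `factor_common` already tolerates any argument order, so canonicalization matters mainly for making equal products structurally identical. A real rewrite would also have to handle Sub, broadcasting, and the choice among several shared factors, none of which is modelled here.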
The direct saving is one elementwise multiply per factored pair, and k - 1 multiplies for a group of k terms: a*b + a*c does two multiplies and one add, while a*(b + c) does one multiply and one add. For pure-scalar/pure-elementwise graphs, Composite fusion compiles both forms to one kernel with the same per-element work, so the rewrite is mostly a no-op there.
The analogous rewrite for Dot (A@B + A@C → A@(B+C)) saves an entire matmul and is a separate, higher-priority follow-up; it would be natural to implement once the Mul version exists as a template.
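To put a rough number on the Dot case, here is a back-of-the-envelope count. It assumes dense matrices with A of shape (m, k) and B, C of shape (k, n), and the usual 2*m*k*n estimate for a dense matmul; the function names are hypothetical and real timings depend on the BLAS backend and memory traffic.

```python
def matmul_flops(m, k, n):
    # Usual dense estimate: one multiply and one add per inner-product step.
    return 2 * m * k * n


def unfactored_flops(m, k, n):
    # A@B + A@C: two matmuls plus one (m x n) elementwise add.
    return 2 * matmul_flops(m, k, n) + m * n


def factored_flops(m, k, n):
    # A@(B + C): one (k x n) elementwise add plus one matmul.
    return k * n + matmul_flops(m, k, n)


before = unfactored_flops(1000, 1000, 1000)  # 4_001_000_000
after = factored_flops(1000, 1000, 1000)     # 2_001_000_000
```

Under these assumptions the factored form does roughly half the arithmetic, which is why the Dot rewrite matters more than the elementwise one.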