Description
Pattern
For a unary Elemwise f and any SetSubtensor variant (basic, advanced, advanced1) in set mode:
# Before
f(x[idx].set(c))
# After
f(x)[idx].set(f(c))
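The identity can be checked numerically. The sketch below uses NumPy as a stand-in for the PyTensor graph; `set_at` is a hypothetical helper that mimics `x[idx].set(c)`, and `np.exp` is an arbitrary choice for `f`.

```python
import numpy as np


def set_at(x, idx, c):
    """Return a copy of x with x[idx] replaced by c (mimics x[idx].set(c))."""
    out = x.copy()
    out[idx] = c
    return out


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
c = rng.normal(size=(2, 3))
idx = np.array([0, 2])  # advanced index along the first axis
f = np.exp              # any unary elementwise function

before = f(set_at(x, idx, c))    # f applied to the full result
after = set_at(f(x), idx, f(c))  # f lifted through the set

print(np.allclose(before, after))
```

Because `f` acts on each element independently, every output entry sees either an element of `x` or an element of `c`, which is why the two forms agree.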
Motivating case: LKJCorr logp
The LKJCorr Cholesky factor L is built by filling a zero matrix:
L = zeros(n, n)
L = L[tril_indices].set(params) # fill lower triangle
L = L[diag_indices].set(1) # set diagonal to 1
The logp computes Sqr(L). Currently Sqr operates on the full (n, n) matrix. Lifting it through the SetSubtensors gives:
└─ AdvancedSetSubtensor [diag=Sqr(1)=1]
   └─ AdvancedSetSubtensor [tril=Sqr(params)]
      └─ Alloc(Sqr(0)=0, n, n)
Now Sqr is applied only to the params vector (length n*(n-1)/2) rather than to the full (n, n) matrix. The scalar applications Sqr(0) and Sqr(1) stay scalar and constant-fold.
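A NumPy sketch of the same construction, comparing the current and lifted forms. It assumes the strictly lower triangle (`k=-1`), since that is what gives a params vector of length n*(n-1)/2; the variable names are illustrative and this is not the PyMC implementation.

```python
import numpy as np

n = 4
rows, cols = np.tril_indices(n, k=-1)  # strictly lower triangle (assumption)
diag = np.diag_indices(n)
params = np.random.default_rng(1).normal(size=rows.size)

# Current form: build L, then square all n*n entries.
L = np.zeros((n, n))
L[rows, cols] = params
L[diag] = 1.0
before = np.square(L)

# Lifted form: square only the pieces, then assemble.
lifted = np.full((n, n), np.square(0.0))  # Sqr(0) stays a scalar
lifted[rows, cols] = np.square(params)    # Sqr on n*(n-1)/2 entries
lifted[diag] = np.square(1.0)             # Sqr(1) stays a scalar

print(np.allclose(before, lifted))
```

In the lifted form the only non-scalar `square` call touches `params`, so the elementwise work drops from n*n entries to n*(n-1)/2.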
When to apply
The rewrite is always valid for set mode, but it is only profitable when the base or the set-value is scalar/broadcast. In that case, f applied to that piece remains a cheap scalar op. If both the base and the set-value are full-sized arrays, lifting splits one f into two full-sized ones with no benefit.
Guard: apply when all(base.type.broadcastable) OR all(set_value.type.broadcastable).
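A minimal sketch of the guard as a plain predicate over broadcastable flags. `should_lift` is a hypothetical name, and the tuples stand in for `base.type.broadcastable` and `set_value.type.broadcastable`; this is not an existing PyTensor helper.

```python
def should_lift(base_broadcastable, value_broadcastable):
    """Lift only if one side is scalar/broadcast in every dimension."""
    return all(base_broadcastable) or all(value_broadcastable)


# 0-d set-value (e.g. the constant 1): all(()) is True, so lift.
print(should_lift((False, False), ()))
# (1, 1) base with a vector set-value: lift.
print(should_lift((True, True), (False,)))
# Both sides full-sized: lifting would duplicate the work, so skip.
print(should_lift((False, False), (False,)))
```

One caveat under this reading: the output of Alloc(0, n, n) has non-broadcastable dimensions, so a real guard would presumably also need to recognize an Alloc of a scalar as a broadcast base. That detail is an assumption here, not something the text above specifies.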
Set mode only: for inc mode, f(x + inc_at_idx) ≠ f(x)[idx].set(f(inc)) in general. As usual, inc on a zeros base is semantically the same as set when the index has no repeats (should we canonicalize that case into set?).
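A small NumPy counterexample for inc mode, followed by a check of the inc-on-zeros case. `np.add.at` is used as a stand-in for inc semantics (it accumulates over repeated indices); the values are arbitrary.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
idx = np.array([1])
inc = np.array([10.0])
f = np.square

# inc mode, then f: x becomes [1, 12, 3], squared to [1, 144, 9].
incremented = x.copy()
np.add.at(incremented, idx, inc)
correct = f(incremented)

# Naive lift f(x)[idx].set(f(inc)) gives [1, 100, 9], which is wrong.
naive = f(x)
naive[idx] = f(inc)

# inc on zeros matches set when idx has no repeats...
zeros_inc = np.zeros(3)
np.add.at(zeros_inc, idx, inc)
zeros_set = np.zeros(3)
zeros_set[idx] = inc

# ...but not when an index repeats: inc accumulates, set overwrites.
rep, vals = np.array([1, 1]), np.array([10.0, 5.0])
inc_rep = np.zeros(3)
np.add.at(inc_rep, rep, vals)
set_rep = np.zeros(3)
set_rep[rep] = vals

print(correct, naive)
print(np.array_equal(zeros_inc, zeros_set), np.array_equal(inc_rep, set_rep))
```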
Scope
- All SetSubtensor variants: SetSubtensor, AdvancedSetSubtensor, AdvancedSetSubtensor1.
- Unary Elemwise always qualifies. Multi-input Elemwise qualifies when the other inputs are scalar/broadcast (e.g., Add(x[idx].set(c), 1) → Add(x, 1)[idx].set(Add(c, 1))). If another input has full-sized dimensions, lifting can't be done without splitting the other inputs, which isn't trivial.
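The multi-input case can be illustrated in NumPy as well. The first check uses a scalar second input, as in the Add example. The second shows what "splitting" a full-sized input would mean: the set-value branch needs that input indexed with the same idx. `set_at` and the variable names are hypothetical.

```python
import numpy as np


def set_at(x, idx, c):
    """Return a copy of x with x[idx] replaced by c (mimics x[idx].set(c))."""
    out = x.copy()
    out[idx] = c
    return out


rng = np.random.default_rng(2)
x = rng.normal(size=5)
c = rng.normal(size=2)
idx = np.array([1, 3])

# Scalar second input: the same scalar is reused on both branches.
before = set_at(x, idx, c) + 1
after = set_at(x + 1, idx, c + 1)

# Full-sized second input y: the set-value branch must use y[idx].
y = rng.normal(size=5)
before_full = set_at(x, idx, c) + y
after_full = set_at(x + y, idx, c + y[idx])

print(np.allclose(before, after), np.allclose(before_full, after_full))
```

The second form is still correct, but it introduces an extra indexing op on `y` and keeps a full-sized Add on the base, which is consistent with leaving that case out of scope.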