Add _vectorize_node dispatchers for sparse ops #2190

Open
jaanerik wants to merge 3 commits into pymc-devs:main from jaanerik:sparse-vectorize-dispatchers
@jaanerik
Contributor

Description

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix

jaanerik and others added 3 commits May 29, 2026 12:13
The default Blockwise-based fallback in pytensor/graph/replace.py wraps
ops in Blockwise and rebuilds their make_node with dense dummy core
inputs, which contradicts the sparse-input contract enforced by
as_sparse_variable. As a result, vectorize_graph crashes with
"Variable type field must be a SparseTensorType" the moment it
encounters any sparse op — pmx-extras pathfinder hits this whenever a
PyMC model uses a sparse projection (e.g. a sum-to-zero constraint
encoded as pt.dot(flat, as_sparse_variable(csr))).

This patch:

- Registers an explicit dispatcher for StructuredDot that batches the
  dense (right) input via a moveaxis+reshape trick while keeping the
  sparse (left) input unbatched (scipy has no batched-sparse type).
  Raises NotImplementedError with a clear message if the caller tries
  to batch the sparse input.
- Registers NotImplementedError stubs for the other sparse ops likely
  to appear in user graphs (TrueDot, AddSS, AddSSData, AddSD,
  SparseSparseMultiply, SparseDenseMultiply) so callers see a
  descriptive error instead of the cryptic as_sparse_variable TypeError
  from the Blockwise fallback.
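
The moveaxis+reshape trick described for StructuredDot can be sketched in plain NumPy. This is a minimal illustration, not the code in this PR: the function name `batched_right_matmul` is hypothetical, and a dense 2-D array stands in for the unbatched sparse (left) operand so the sketch needs only NumPy.

```python
import numpy as np


def batched_right_matmul(a_2d, b_batched):
    """Compute a_2d @ b for every batch entry of b using one 2-D product.

    a_2d:      (m, n) matrix; stands in for the unbatched sparse operand.
    b_batched: (*batch, n, k) dense array.
    returns:   (*batch, m, k) dense array.
    """
    *batch_shape, n, k = b_batched.shape
    m = a_2d.shape[0]
    # Bring the contracted axis (n) to the front: (n, *batch, k).
    b_front = np.moveaxis(b_batched, -2, 0)
    # Fold all batch axes into the column axis: (n, prod(batch) * k).
    b_flat = b_front.reshape(n, -1)
    # One 2-D product covers every batch entry: (m, prod(batch) * k).
    out_flat = a_2d @ b_flat
    # Unfold the columns, then move the row axis back: (*batch, m, k).
    out_front = out_flat.reshape(m, *batch_shape, k)
    return np.moveaxis(out_front, 0, -2)


rng = np.random.default_rng(0)
a = rng.normal(size=(3, 4))          # plays the role of the sparse matrix
b = rng.normal(size=(2, 5, 4, 6))    # dense input with batch shape (2, 5)
out = batched_right_matmul(a, b)
```

Because the left operand only ever takes part in an ordinary 2-D product, it never needs a batch dimension, which is the property that matters when it is a scipy sparse matrix.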

Add TestVectorizeSparse covering the StructuredDot dispatcher (batched
dense input, no-batch no-op, batched-sparse error) and the AddSD /
SparseDenseMultiply NotImplementedError stubs. The structured_dot test
reproduces the original "Variable type field must be a SparseTensorType"
crash without the dispatcher (issue pymc-devs#2189).

Drop the NotImplementedError stubs for the all-sparse-input ops (TrueDot,
AddSS, AddSSData, SparseSparseMultiply): a sparse input can never become
batched, so vectorize_graph never dispatches to them. Keep only the
reachable AddSD / SparseDenseMultiply, and reword the error since AddSD's
output is dense, not sparse.
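
The stub pattern kept for AddSD and SparseDenseMultiply can be illustrated with `functools.singledispatch` from the standard library. This is a rough sketch under stated assumptions, not the PR's implementation: the classes below are empty stand-ins for the real ops, and `vectorize_node`, `_unsupported_sparse_op`, and the error wording are hypothetical.

```python
from functools import singledispatch


class Op:
    """Hypothetical base class standing in for a graph op."""


class AddSD(Op):
    """Stand-in for the sparse + dense addition op."""


class SparseDenseMultiply(Op):
    """Stand-in for the sparse * dense elementwise op."""


@singledispatch
def vectorize_node(op, *batched_inputs):
    # Generic fallback, playing the role of the Blockwise-based default.
    return f"generic fallback for {type(op).__name__}"


def _unsupported_sparse_op(op, *batched_inputs):
    # Fail early with a message that names the op.
    raise NotImplementedError(
        f"Vectorizing {type(op).__name__} with a batched dense input "
        "is not supported."
    )


# Register the same stub for each op that a batched dense input can reach.
for op_type in (AddSD, SparseDenseMultiply):
    vectorize_node.register(op_type, _unsupported_sparse_op)
```

Registering one shared stub keeps the error wording consistent across ops, while ops without a registration still go through the generic fallback.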

Move the _vectorize_node import to the top of the module (no circular
import) to satisfy ruff E402.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jaanerik jaanerik force-pushed the sparse-vectorize-dispatchers branch from 1851ec5 to 0659b8f Compare May 29, 2026 18:31
Comment thread pytensor/sparse/math.py
usmm = Usmm()


# ---------------------------------------------------------------------------
Member

please remove these global comments/separators

Comment thread pytensor/sparse/math.py
# That contradicts the sparse-input contract enforced by as_sparse_variable,
# so every sparse op needs a custom dispatcher (or a clear NotImplementedError).
@_vectorize_node.register(StructuredDot)
def _vectorize_structured_dot(op, node, batch_a, batch_b):
Member

nice!

@ricardoV94 ricardoV94 added the enhancement New feature or request label May 29, 2026
@ricardoV94
Member

looks good, just need style cleanup

Development

Successfully merging this pull request may close these issues.

BUG: Sparse vectorize dispatchers

2 participants