Add _vectorize_node dispatchers for sparse ops #2190
Open
jaanerik wants to merge 3 commits into
The default Blockwise-based fallback in pytensor/graph/replace.py wraps ops in Blockwise and rebuilds their make_node with dense dummy core inputs, which contradicts the sparse-input contract enforced by as_sparse_variable. As a result, vectorize_graph crashes with "Variable type field must be a SparseTensorType" the moment it encounters any sparse op. The pmx-extras pathfinder hits this whenever a PyMC model uses a sparse projection (e.g. a sum-to-zero constraint encoded as pt.dot(flat, as_sparse_variable(csr))).

This patch:

- Registers an explicit dispatcher for StructuredDot that batches the dense (right) input via a moveaxis+reshape trick while keeping the sparse (left) input unbatched (scipy has no batched-sparse type). Raises NotImplementedError with a clear message if the caller tries to batch the sparse input.
- Registers NotImplementedError stubs for the other sparse ops likely to appear in user graphs (TrueDot, AddSS, AddSSData, AddSD, SparseSparseMultiply, SparseDenseMultiply) so callers see a descriptive error instead of the cryptic as_sparse_variable TypeError from the Blockwise fallback.
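The moveaxis+reshape trick mentioned above can be illustrated outside PyTensor. The following is only a minimal NumPy sketch, not the code from this PR: the function name `batched_matmul_via_reshape` is hypothetical, and a dense `a` stands in for the unbatched sparse operand, under the assumption that the underlying product only accepts a 2-D right-hand side.

```python
import numpy as np


def batched_matmul_via_reshape(a, b_batched):
    """Compute a @ b_batched[i] for every batch index with a single 2-D product.

    a:         (m, k) matrix; stands in for the unbatched sparse operand.
    b_batched: (*batch, k, n) dense array with any number of leading batch axes.
    Returns an array of shape (*batch, m, n).
    """
    batch_shape = b_batched.shape[:-2]
    k, n = b_batched.shape[-2:]
    # Bring the contracted axis k to the front: (k, *batch, n).
    b_front = np.moveaxis(b_batched, -2, 0)
    # Fold all batch axes and the column axis together: (k, prod(batch) * n).
    b_2d = b_front.reshape(k, -1)
    # One plain 2-D product: (m, prod(batch) * n).
    out_2d = a @ b_2d
    # Unfold to (m, *batch, n), then move m behind the batch axes.
    out = out_2d.reshape(a.shape[0], *batch_shape, n)
    return np.moveaxis(out, 0, -2)


rng = np.random.default_rng(0)
a = rng.normal(size=(3, 4))
b = rng.normal(size=(2, 5, 4, 6))
result = batched_matmul_via_reshape(a, b)
```

The point of the reshape is that the left operand is multiplied exactly once and never gains a batch axis, which is what a 2-D-only sparse matrix type requires.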
Add TestVectorizeSparse covering the StructuredDot dispatcher (batched dense input, no-batch no-op, batched-sparse error) and the AddSD / SparseDenseMultiply NotImplementedError stubs. The structured_dot test reproduces the original "Variable type field must be a SparseTensorType" crash without the dispatcher (issue pymc-devs#2189).

Drop the NotImplementedError stubs for the all-sparse-input ops (TrueDot, AddSS, AddSSData, SparseSparseMultiply): a sparse input can never become batched, so vectorize_graph never dispatches to them. Keep only the reachable AddSD / SparseDenseMultiply, and reword the error since AddSD's output is dense, not sparse.

Move the _vectorize_node import to the top of the module (no circular import) to satisfy ruff E402.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
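The stub pattern described in these commit messages can be sketched with the standard library alone. This is an illustration under assumptions, not PyTensor's implementation: the `Op` and `AddSD` classes here are hypothetical stand-ins, `vectorize_node` is a plain `functools.singledispatch` function, and the error text is invented for the example.

```python
from functools import singledispatch


class Op:
    """Hypothetical stand-in for a graph op base class."""


class AddSD(Op):
    """Hypothetical stand-in for a sparse + dense addition op."""


@singledispatch
def vectorize_node(op, node, *batched_inputs):
    # Generic fallback, playing the role of the Blockwise-based default.
    return f"blockwise fallback for {type(op).__name__}"


@vectorize_node.register(AddSD)
def _vectorize_add_sd(op, node, *batched_inputs):
    # A registered stub replaces the fallback with a descriptive error.
    raise NotImplementedError(
        "Vectorizing AddSD with a batched dense input is not supported "
        "in this sketch."
    )


fallback_result = vectorize_node(Op(), None)

try:
    vectorize_node(AddSD(), None, "sparse_x", "batched_dense_y")
    stub_message = None
except NotImplementedError as exc:
    stub_message = str(exc)
```

Registering a stub only makes sense for ops the dispatcher can actually reach, which is the reasoning given above for dropping the stubs on ops whose inputs are all sparse.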
1851ec5 to 0659b8f
ricardoV94
reviewed
May 29, 2026
usmm = Usmm()

# ---------------------------------------------------------------------------
Member
please remove these global comments/separators
ricardoV94
reviewed
May 29, 2026
# That contradicts the sparse-input contract enforced by as_sparse_variable,
# so every sparse op needs a custom dispatcher (or a clear NotImplementedError).
@_vectorize_node.register(StructuredDot)
def _vectorize_structured_dot(op, node, batch_a, batch_b):
Member
looks good, just need style cleanup