Feature or Model Request
Problem
The current MoE sparse_matmul implementation (layers/moe.py) and the underlying megablox/ops.py conflate two orthogonal concerns:
- Quantization orchestration — obtaining QWIX rules, quantizing lhs/rhs tensors
- Kernel backend dispatch — choosing between megablox Pallas, tokamax Pallas, or jax.lax.ragged_dot
This leads to confusing control flow where mblx.gmm() acts as both a "quantization-aware GMM wrapper" and a "backend router" via the
use_tokamax_backend flag. The megablox module shouldn't need to know about tokamax at all.
Current call paths in sparse_matmul (simplified)
```python
if use_tokamax_gmm:
  if quantization:
    mblx.gmm(..., use_tokamax_backend=True)   # megablox wraps tokamax?
  else:
    tokamax.ragged_dot(...)                   # direct tokamax call
elif megablox:
  mblx.gmm(..., use_tokamax_backend=False)    # megablox native
else:
  jax.lax.ragged_dot(...)                     # JAX fallback
```
Issues:
- mblx.gmm imports and calls tokamax_backend internally — megablox shouldn't depend on tokamax
- Quantization logic is buried inside ops.py fwd/bwd functions (qpl.get_current_rule, qpl.quantize) rather than being a composable layer
- 4 code paths in sparse_matmul with subtle differences in tiling, quantization, and dtype handling
- Adding a new backend (e.g., Mosaic, Triton) requires modifying megablox internals
Proposed Direction
Separate these concerns into three composable layers:
- Backend interface — a common GmmBackend protocol with gmm() / tgmm() methods, implemented by the megablox, tokamax, and ragged_dot backends independently
- Quantization wrapper — a standalone layer that handles qpl.get_current_rule / qpl.quantize and delegates the actual matmul to any backend
- sparse_matmul — selects backend based on config, wraps it with quantization if needed, calls it
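A minimal sketch of what the backend interface could look like, using Python's structural typing.Protocol; the names GmmBackend and RaggedDotBackend and the exact signatures are assumptions for illustration, not existing code:

```python
from __future__ import annotations

from typing import Any, Protocol, runtime_checkable

Array = Any  # stand-in for jax.Array in this sketch


@runtime_checkable
class GmmBackend(Protocol):
  """Common interface for grouped-matmul backends (hypothetical)."""

  def gmm(
      self, lhs: Array, rhs: Array, group_sizes: Array,
      tiling: tuple[int, int, int] | None = None,
  ) -> Array:
    """Grouped matmul: one rhs matrix per contiguous group of lhs rows."""
    ...

  def tgmm(
      self, lhs: Array, rhs: Array, group_sizes: Array,
      tiling: tuple[int, int, int] | None = None,
  ) -> Array:
    """Transposed grouped matmul, as needed by the backward pass."""
    ...


class RaggedDotBackend:
  """JAX-fallback backend delegating to jax.lax.ragged_dot (sketch)."""

  def gmm(self, lhs, rhs, group_sizes, tiling=None):
    import jax  # deferred so the interface itself stays dependency-free

    del tiling  # the XLA lowering chooses its own strategy here
    return jax.lax.ragged_dot(lhs, rhs, group_sizes)

  def tgmm(self, lhs, rhs, group_sizes, tiling=None):
    raise NotImplementedError("sketch only")
```

Because the protocol is structural, the megablox and tokamax backends can each satisfy it independently, without importing one another.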
Desired flow (conceptual)
```python
backend = get_backend(config)        # tokamax / megablox / ragged_dot
if quantization:
  backend = QuantizedGmm(backend)    # wraps any backend with FP8 quantization
output = backend.gmm(inputs, kernel, group_sizes, tiling)
```
This would:
- Remove tokamax imports from megablox
- Make quantization composable with any backend
- Simplify sparse_matmul from 4 branches to ~2 lines
- Make adding new backends trivial
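One way the quantization wrapper could compose with any backend is as a plain delegating class; this is a hypothetical sketch — the class name and the `quantize` callable standing in for qpl.get_current_rule / qpl.quantize are assumptions:

```python
from typing import Any, Callable

Array = Any  # stand-in for jax.Array in this sketch


class QuantizedGmm:
  """Wraps any GmmBackend-shaped object, quantizing lhs/rhs before
  delegating.

  Hypothetical sketch: the real wrapper would build an FP8 quantizer from
  the current QWIX rule (qpl.get_current_rule / qpl.quantize) instead of
  taking an arbitrary callable.
  """

  def __init__(self, backend,
               quantize: Callable[[Array], Array] = lambda x: x):
    self._backend = backend
    self._quantize = quantize

  def gmm(self, lhs, rhs, group_sizes, tiling=None):
    # Quantize both operands, then let the wrapped backend do the matmul.
    return self._backend.gmm(
        self._quantize(lhs), self._quantize(rhs), group_sizes, tiling)

  def tgmm(self, lhs, rhs, group_sizes, tiling=None):
    return self._backend.tgmm(
        self._quantize(lhs), self._quantize(rhs), group_sizes, tiling)
```

Since the wrapper only forwards gmm()/tgmm() calls, it composes with the megablox, tokamax, and ragged_dot backends alike, which is what removes the use_tokamax_backend flag from the megablox path.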
Additional Context
No response