From 10a374dbe9030d71a4220b8122339cf447e299df Mon Sep 17 00:00:00 2001
From: Wonhyuk Yang
Date: Thu, 18 Jun 2026 16:10:40 +0900
Subject: [PATCH] [Frontend] vcix: lower fused matmul whose operand is vector_store'd, fix exp chunk

Two fixes to the C++->Python vcix port (lower_to_vcix.py) that SDPA
exercises but the gemm/bmm/conv tests do not:

- _lower_matmul bailed with
  'if ATag is None or BTag is None: return False', gating on an MVIN
  dma_start tag for both operands. In SDPA's fused scores.V matmul,
  operand B is the softmax output produced in place by
  affine.vector_store, not DMAed, so BTag stayed None and the matmul was
  left un-lowered -> wrong attention output. Mirror the C++
  MatmulOpLowering: an operand is initialized by either a dma_start OR a
  preceding affine.vector_store into its root memref; bail only when an
  operand is truly uninitialized. BTag/BAsync stay None/0 and are only
  read under 'if BAsync:', so the B dma_wait is correctly skipped (as in
  C++).

- _make_sf_vc_v_iv n>1 transcendental chunking called
  vector.ExtractStridedSliceOp(offsets, sizes, strides, vec) -- wrong arg
  order, missing the result type and vector operand, raising TypeError
  under these MLIR bindings. Pass (result=legal_ty, vector=vec, offsets,
  sizes, strides). Only reached by large transcendentals (n>1), e.g. SDPA
  softmax exp, so CI's small-tile (n==1) tests never hit it.

Validated end-to-end (Spike+TOGSim allclose): SDPA 56 cases pass (was
crash/wrong); matmul/bmm/conv2d regress clean. Bisected: C++ vcix passes
SDPA, Python vcix did not; exp chunking and fine-grained ruled out
separately.

Co-Authored-By: Claude Opus 4.8 --- .../mlir/passes/lower_to_vcix.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py b/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py index 4286ba19..449fc3a5 100644 --- a/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py +++ b/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py @@ -113,9 +113,10 @@ def _make_sf_vc_v_iv(vec, op_vt, n, legal_ty, opcode, imm): else: for i in range(total // elt_count): ext = vector.ExtractStridedSliceOp( + legal_ty, vec, ir.ArrayAttr.get([_i64(i * elt_count)]), ir.ArrayAttr.get([_i64(elt_count)]), - ir.ArrayAttr.get([_i64(1)]), vec).result + ir.ArrayAttr.get([_i64(1)])).result v = _viv(ext, legal_ty, opcode, imm, rvl) res = vector.InsertStridedSliceOp( v, res, ir.ArrayAttr.get([_i64(i * elt_count)]), @@ -364,7 +365,28 @@ def a64(v): return ir.IntegerAttr.get(i64, v) AAsync = BAsync = BiasAsync = 0 BiasIdx = None subtileM, subtileN, subtileK = M, N, K + # Mirror the C++ isAInitialized / isBInitialized flags: an operand is + # "initialized" either by an MVIN dma_start (tag found below) or by a + # preceding affine.vector_store into its root memref (the fused case, e.g. + # SDPA scores.V where B is the softmax output produced in-place, not DMAed). 
+ isAInit = isBInit = False + + def _root(v): + owner = v.owner + if not isinstance(owner, ir.Block): + nm = owner.name + if nm in ("memref.reinterpret_cast", "memref.cast"): + return owner.operands[0] + return v + rootA, rootB = _root(A), _root(B) for o in _iter_ops(outer[-1].regions[0].blocks[0]): + if o.operation.name == "affine.vector_store": + dest = _root(o.operation.operands[1]) + if dest == rootA: + isAInit = True + elif dest == rootB: + isBInit = True + continue if o.operation.name != "memref.dma_start": continue d = _DmaView(o.operation) @@ -380,16 +402,18 @@ def a64(v): return ir.IntegerAttr.get(i64, v) sub = d.subtile_size() if argn == idxMap[0]: ATag, AAsync = d.tag, d.is_async() + isAInit = True if len(sub) >= 2: subtileM, subtileK = sub[-2], sub[-1] elif argn == idxMap[1]: BTag, BAsync = d.tag, d.is_async() + isBInit = True if len(sub) >= 2: subtileK, subtileN = sub[-2], sub[-1] elif argn == idxMap[2]: BiasTag, BiasAsync = d.tag, d.is_async() BiasIdx = d.tag_idx - if ATag is None or BTag is None: + if not isAInit or not isBInit: return False KStep = subtileK