diff --git a/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py b/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py index 4286ba19..449fc3a5 100644 --- a/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py +++ b/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py @@ -113,9 +113,10 @@ def _make_sf_vc_v_iv(vec, op_vt, n, legal_ty, opcode, imm): else: for i in range(total // elt_count): ext = vector.ExtractStridedSliceOp( + legal_ty, vec, ir.ArrayAttr.get([_i64(i * elt_count)]), ir.ArrayAttr.get([_i64(elt_count)]), - ir.ArrayAttr.get([_i64(1)]), vec).result + ir.ArrayAttr.get([_i64(1)])).result v = _viv(ext, legal_ty, opcode, imm, rvl) res = vector.InsertStridedSliceOp( v, res, ir.ArrayAttr.get([_i64(i * elt_count)]), @@ -364,7 +365,28 @@ def a64(v): return ir.IntegerAttr.get(i64, v) AAsync = BAsync = BiasAsync = 0 BiasIdx = None subtileM, subtileN, subtileK = M, N, K + # Mirror the C++ isAInitialized / isBInitialized flags: an operand is + # "initialized" either by an MVIN dma_start (tag found below) or by a + # preceding affine.vector_store into its root memref (the fused case, e.g. + # SDPA scores.V where B is the softmax output produced in-place, not DMAed). + isAInit = isBInit = False + + def _root(v): + owner = v.owner + if not isinstance(owner, ir.Block): + nm = owner.name + if nm in ("memref.reinterpret_cast", "memref.cast"): + return owner.operands[0] + return v + rootA, rootB = _root(A), _root(B) for o in _iter_ops(outer[-1].regions[0].blocks[0]): + if o.operation.name == "affine.vector_store": + dest = _root(o.operation.operands[1]) + if dest == rootA: + isAInit = True + elif dest == rootB: + isBInit = True + continue if o.operation.name != "memref.dma_start": continue d = _DmaView(o.operation) @@ -380,16 +402,18 @@ def a64(v): return ir.IntegerAttr.get(i64, v) sub = d.subtile_size() if argn == idxMap[0]: ATag, AAsync = d.tag, d.is_async() + isAInit = True if len(sub) >= 2: subtileM, subtileK = sub[-2], sub[-1] elif argn == idxMap[1]: BTag, BAsync = d.tag, d.is_async() + isBInit = True if len(sub) >= 2: subtileK, subtileN = sub[-2], sub[-1] elif argn == idxMap[2]: BiasTag, BiasAsync = d.tag, d.is_async() BiasIdx = d.tag_idx - if ATag is None or BTag is None: + if not isAInit or not isBInit: return False KStep = subtileK