Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions PyTorchSimFrontend/mlir/passes/lower_to_vcix.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ def _make_sf_vc_v_iv(vec, op_vt, n, legal_ty, opcode, imm):
else:
for i in range(total // elt_count):
ext = vector.ExtractStridedSliceOp(
legal_ty, vec,
ir.ArrayAttr.get([_i64(i * elt_count)]),
ir.ArrayAttr.get([_i64(elt_count)]),
ir.ArrayAttr.get([_i64(1)]), vec).result
ir.ArrayAttr.get([_i64(1)])).result
v = _viv(ext, legal_ty, opcode, imm, rvl)
res = vector.InsertStridedSliceOp(
v, res, ir.ArrayAttr.get([_i64(i * elt_count)]),
Expand Down Expand Up @@ -364,7 +365,28 @@ def a64(v): return ir.IntegerAttr.get(i64, v)
AAsync = BAsync = BiasAsync = 0
BiasIdx = None
subtileM, subtileN, subtileK = M, N, K
# Mirror the C++ isAInitialized / isBInitialized flags: an operand is
# "initialized" either by an MVIN dma_start (tag found below) or by a
# preceding affine.vector_store into its root memref (the fused case, e.g.
# SDPA scores.V where B is the softmax output produced in-place, not DMAed).
isAInit = isBInit = False

def _root(v):
    """Return the value that ``v`` is a cast of, or ``v`` itself.

    If ``v`` is produced by a ``memref.reinterpret_cast`` or ``memref.cast``
    operation, the first operand of that operation is returned. In every
    other case ``v`` is returned unchanged. Only a single cast is looked
    through; a chain of casts is not followed.
    """
    producer = v.owner
    # A value owned by an ir.Block has no producing operation to look
    # through, so it is returned as is.
    if isinstance(producer, ir.Block):
        return v
    if producer.name not in ("memref.reinterpret_cast", "memref.cast"):
        return v
    return producer.operands[0]
rootA, rootB = _root(A), _root(B)
for o in _iter_ops(outer[-1].regions[0].blocks[0]):
if o.operation.name == "affine.vector_store":
dest = _root(o.operation.operands[1])
if dest == rootA:
isAInit = True
elif dest == rootB:
isBInit = True
continue
if o.operation.name != "memref.dma_start":
continue
d = _DmaView(o.operation)
Expand All @@ -380,16 +402,18 @@ def a64(v): return ir.IntegerAttr.get(i64, v)
sub = d.subtile_size()
if argn == idxMap[0]:
ATag, AAsync = d.tag, d.is_async()
isAInit = True
if len(sub) >= 2:
subtileM, subtileK = sub[-2], sub[-1]
elif argn == idxMap[1]:
BTag, BAsync = d.tag, d.is_async()
isBInit = True
if len(sub) >= 2:
subtileK, subtileN = sub[-2], sub[-1]
elif argn == idxMap[2]:
BiasTag, BiasAsync = d.tag, d.is_async()
BiasIdx = d.tag_idx
if ATag is None or BTag is None:
if not isAInit or not isBInit:
return False

KStep = subtileK
Expand Down