Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions PyTorchSimFrontend/mlir/graph_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
realize() (not a clone, which Inductor inlines) is what actually forces the buffer
boundary; see the PoC notes in docs.

Gated by TORCHSIM_GRAPH_COPY (install() is a no-op otherwise). Behavior-neutral
unless a genuine incompatible-radix conflict is detected.
Default-on; set TORCHSIM_GRAPH_COPY=0 to disable (install() is then a no-op).
Behavior-neutral unless a genuine incompatible-radix conflict is detected.
"""
import os
from torch._inductor import lowering as L
Expand Down Expand Up @@ -64,7 +64,14 @@ def _relayout_args(args):
# multi-var-view input) this is just that operand's shape -- still enough to
# detect a multi-var floor and copy_input it (case 7); the 2-operand radix
# conflict (case 5) naturally needs >=2 operands.
ranges = max((t.get_size() for t in tbs), key=len)
# Per-dim max extent over the max-rank operands (order-independent). Picking a
# single operand by rank alone (max key=len) would, for two equal-rank operands
# with different per-dim extents, take a broadcast-from operand's smaller shape
# and then miss the genuine conflict on the broadcast-to dim.
maxrank = max(len(t.get_size()) for t in tbs)
full = [t.get_size() for t in tbs if len(t.get_size()) == maxrank]
ranges = [max((s[d] for s in full), key=lambda v: (axis_split._as_int(v) or -1))
for d in range(maxrank)]
extents = [axis_split._as_int(s) for s in ranges]
dbg = os.environ.get("TORCHSIM_GRAPH_COPY_DEBUG")
if dbg:
Expand Down
11 changes: 11 additions & 0 deletions PyTorchSimFrontend/mlir/mlir_codegen_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,17 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
index = index.subs({s: 0 for s in indirect_syms}, simultaneous=True)
indirect_dims = [f"{i}" for i in indirect_syms]

# axis-split + graph-copy linearize aligned floor/mod upstream. Anything that
# reaches here still carrying floor/mod (store-side ModularIndexing,
# reduction-axis floor/mod, incompatible-radix views) would be silently
# mis-strided in the dram_stride computation below, so fail loudly instead.
if index.has(FloorDiv) or index.has(ModularIndexing):
raise NotImplementedError(
f"Unlinearized floor/mod in DMA index: {index}. axis-split/graph-copy "
f"did not eliminate it; this view is unsupported "
f"(see docs/axis-split-scheduling.md)."
)

# Reduction can have two types of tile size
if broadcast and (total_dims != local_dims or (self.reduction_depth!=len(total_dims) and total_dims[:self.reduction_depth] == local_dims)):
local_dims = total_dims # Broadcast tile shape
Expand Down
2 changes: 1 addition & 1 deletion PyTorchSimFrontend/mlir/mlir_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,6 @@ def get_order(n):


# Install the graph-copy (incompatible-radix relayout) lowering hook once at import.
# No-op unless TORCHSIM_GRAPH_COPY is set; see graph_copy.py.
# Default-on; set TORCHSIM_GRAPH_COPY=0 to disable. See graph_copy.py.
from . import graph_copy as _graph_copy
_graph_copy.install()
15 changes: 13 additions & 2 deletions PyTorchSimFrontend/mlir/passes/decompose_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr, st_at
reassoc = ArrayAttr.get(
[ArrayAttr.get([IntegerAttr.get(i64, d) for d in g]) for g in groups])
collapsed_ty = MemRefType.get(target, elem, memory_space=space)
keep = [g[-1] for g in groups] # the non-unit dim in each group
# the non-unit dim in each group (g[-1] is wrong when trailing unit dims
# attach after it, e.g. [..,4,1,1] -> the kept dim must be the extent-4 one).
keep = [next((d for d in g if tile_shape[d] > 1), g[-1]) for g in groups]
dr_attr = ArrayAttr.get([IntegerAttr.get(i64, dram_stride[i]) for i in keep])
tl_attr = ArrayAttr.get([IntegerAttr.get(i64, tile_stride[i]) for i in keep])
st_attr = (ArrayAttr.get([IntegerAttr.get(i64, subtile[i]) for i in keep])
Expand Down Expand Up @@ -198,7 +200,16 @@ def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr, st_at
st_attr = (ArrayAttr.get([IntegerAttr.get(i64, subtile[d]) for d in inner])
if subtile is not None else None)
# the vlane axis must survive into the inner descriptor (it is the lane dim).
new_vlane = inner.index(vlane_axis) if vlane_axis in inner else 0
if vlane_axis in inner:
new_vlane = inner.index(vlane_axis)
elif vlane_axis in peeled:
# lane dim peeled into the outer loop nest: it cannot be expressed in the
# <=4D descriptor, and _phys's lane-banking assumes it is a real axis.
raise NotImplementedError(
f"vlane split axis {vlane_axis} peeled into the outer loop nest; "
f">4D DMA peel cannot place the lane dim in the <=4D descriptor")
else:
new_vlane = 0 # vlane axis is a unit dim; lane on the first inner dim

# Lane-banked physical stride for split-outer dims (vlane_stride defaults to 1).
split_extent = tile_shape[vlane_axis]
Expand Down
48 changes: 44 additions & 4 deletions PyTorchSimFrontend/mlir/passes/lower_to_vcix.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@


def _sew(elt_ty):
if ir.F16Type.isinstance(elt_ty) or ir.BF16Type.isinstance(elt_ty):
return 16
# Mirror C++ legalizeVectorType: only F32/F64/integer/index get a sew. F16/BF16
# return 0 so transcendental math ops stay unlowered (-convert-math-to-llvm),
# matching the validated path -- do NOT emit VCIX for them here.
if ir.F32Type.isinstance(elt_ty):
return 32
if ir.F64Type.isinstance(elt_ty):
Expand All @@ -59,6 +60,8 @@ def _log2(x):

def _legalize_vector_type(vt, vlen):
"""Mirror legalizeVectorType: return (n, legal_vector_type) or (0, None)."""
if len(vt.shape) != 1: # C++ guards getRank() != 1
return 0, None
elt_ty = vt.element_type
sew = _sew(elt_ty)
if sew == 0:
Expand Down Expand Up @@ -113,9 +116,10 @@ def _make_sf_vc_v_iv(vec, op_vt, n, legal_ty, opcode, imm):
else:
for i in range(total // elt_count):
ext = vector.ExtractStridedSliceOp(
legal_ty, vec,
ir.ArrayAttr.get([_i64(i * elt_count)]),
ir.ArrayAttr.get([_i64(elt_count)]),
ir.ArrayAttr.get([_i64(1)]), vec).result
ir.ArrayAttr.get([_i64(1)])).result
v = _viv(ext, legal_ty, opcode, imm, rvl)
res = vector.InsertStridedSliceOp(
v, res, ir.ArrayAttr.get([_i64(i * elt_count)]),
Expand Down Expand Up @@ -337,6 +341,12 @@ def _lower_matmul(op, SS, vlen):
mtA, mtB = ir.MemRefType(A.type), ir.MemRefType(B.type)
elt = mtA.element_type
M, K, N = mtA.shape[0], mtA.shape[1], mtB.shape[1]
# Mirror the C++ guard: a dimension > SS must be an exact multiple, else the
# N//SS / K//SS loop trip counts below silently drop the tail tile.
for _dim, _name in ((M, "M"), (N, "N"), (K, "K")):
if _dim > SS and _dim % SS != 0:
raise NotImplementedError(
f"matmul {_name}={_dim} must be a multiple of systolic size {SS} when > {SS}")
elen = _elt_bits(elt)
nr_element = vlen // elen
i64 = ir.IntegerType.get_signless(64)
Expand Down Expand Up @@ -364,7 +374,29 @@ def a64(v): return ir.IntegerAttr.get(i64, v)
AAsync = BAsync = BiasAsync = 0
BiasIdx = None
subtileM, subtileN, subtileK = M, N, K
a_subk = b_subk = None
# Mirror the C++ isAInitialized / isBInitialized flags: an operand is
# "initialized" either by an MVIN dma_start (tag found below) or by a
# preceding affine.vector_store into its root memref (the fused case, e.g.
# SDPA scores.V where B is the softmax output produced in-place, not DMAed).
isAInit = isBInit = False

def _root(v):
owner = v.owner
if not isinstance(owner, ir.Block):
nm = owner.name
if nm in ("memref.reinterpret_cast", "memref.cast"):
return owner.operands[0]
return v
rootA, rootB = _root(A), _root(B)
for o in _iter_ops(outer[-1].regions[0].blocks[0]):
if o.operation.name == "affine.vector_store":
dest = _root(o.operation.operands[1])
if dest == rootA:
isAInit = True
elif dest == rootB:
isBInit = True
continue
if o.operation.name != "memref.dma_start":
continue
d = _DmaView(o.operation)
Expand All @@ -380,17 +412,25 @@ def a64(v): return ir.IntegerAttr.get(i64, v)
sub = d.subtile_size()
if argn == idxMap[0]:
ATag, AAsync = d.tag, d.is_async()
isAInit = True
if len(sub) >= 2:
subtileM, subtileK = sub[-2], sub[-1]
a_subk = sub[-1]
elif argn == idxMap[1]:
BTag, BAsync = d.tag, d.is_async()
isBInit = True
if len(sub) >= 2:
subtileK, subtileN = sub[-2], sub[-1]
b_subk = sub[-2]
elif argn == idxMap[2]:
BiasTag, BiasAsync = d.tag, d.is_async()
BiasIdx = d.tag_idx
if ATag is None or BTag is None:
if not isAInit or not isBInit:
return False
# A and B must agree on the K subtile (last-writer-wins would otherwise pick one silently).
if a_subk is not None and b_subk is not None and a_subk != b_subk:
raise NotImplementedError(
f"Mismatched subtile K between A ({a_subk}) and B ({b_subk}) matmul operands")

KStep = subtileK
push_length = min(subtileM, SS)
Expand Down