diff --git a/.github/workflows/pytorchsim_test.yml b/.github/workflows/pytorchsim_test.yml index 98dbd791..54a2345b 100644 --- a/.github/workflows/pytorchsim_test.yml +++ b/.github/workflows/pytorchsim_test.yml @@ -172,6 +172,25 @@ jobs: -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/ops/view/test_cat.py + test_floormod_axis_split: + name: Run test_floormod_axis_split.py + runs-on: ubuntu-latest + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_floormod_axis_split.py + run: | + echo "Running test_floormod_axis_split.py" + docker run --rm \ + -e vpu_num_lanes="${{ inputs.vector_lane }}" \ + -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/ops/view/test_floormod_axis_split.py + test_matmul: name: Run test_matmul.py runs-on: ubuntu-latest diff --git a/Dockerfile.base b/Dockerfile.base index 19bbbb2e..de023566 100644 --- a/Dockerfile.base +++ b/Dockerfile.base @@ -91,6 +91,11 @@ RUN curl -L -H "Accept: application/octet-stream" https://api.github.com/repos/P # Store RISC-V LLVM for TorchSim ENV TORCHSIM_LLVM_PATH=/riscv-llvm/bin +# MLIR Python bindings shipped inside the LLVM release artifact (built by the +# llvm-project CI with -DMLIR_ENABLE_BINDINGS_PYTHON=ON). Lets PyTorchSim load +# mlir.ir / dialects for Python-side MLIR passes. The artifact must be built +# against this image's Python (3.11) or `import mlir` fails on ABI mismatch. +ENV PYTHONPATH=/riscv-llvm/python_packages/mlir_core:$PYTHONPATH ENV TORCHSIM_DIR=/workspace/PyTorchSim # Download Spike simulator diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index efd4d4cb..492133a3 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -38,29 +38,16 @@ def dump_metadata(args, arg_attributes, path): return def mlir_compile_command(filename, vectorlane_size, vlen=256): + # The C++ -dma-fine-grained and -test-pytorchsim-to-vcix passes are ported to + # Python (passes/dma_fine_grained.py, lower_to_vcix.py), run in-process between + # loop-padding and the standard lowering. So mlir-opt now runs only loop-padding + # (-> _padded.mlir); the Python fine-grained + vcix passes produce _custom.mlir. return [re.sub(r"[ \n]+", " ", f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ -test-loop-padding \ - -dma-fine-grained='systolic-array-size={vectorlane_size}' \ - -global-idx='vlen={vlen}' \ - -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen={vlen}' \ - -test-memref-to-gemmini="vectorlane={vectorlane_size}" \ - -convert-linalg-to-loops \ - -convert-vector-to-scf='full-unroll' \ - -lower-affine \ - -finalize-memref-to-llvm \ - -lower-vector-multi-reduction \ - -convert-vector-to-llvm \ - -convert-arith-to-llvm \ - -convert-math-to-llvm \ - -convert-scf-to-cf \ - -convert-cf-to-llvm \ - -convert-func-to-llvm \ - -convert-index-to-llvm \ - -reconcile-unrealized-casts \ {'--mlir-print-ir-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_MLIR_IR else ''} \ - {filename}.mlir -o {filename}_llvm.mlir + {filename}.mlir -o {filename}_padded.mlir """, ).strip(), re.sub(r"[ \n]+", " ", @@ -88,30 +75,14 @@ def mlir_compile_command(filename, vectorlane_size, vlen=256): ).strip()] def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_size, vlen=256): + # See mlir_compile_command: -dma-fine-grained and -test-pytorchsim-to-vcix are + # Python passes run in-process; mlir-opt runs only loop-padding here. return [re.sub(r"[ \n]+", " ", f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ -test-loop-padding='timing_mode=1' \ - -dma-fine-grained='systolic-array-size={vectorlane_size}' \ - -global-idx='vlen={vlen}' \ - -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen={vlen}' \ - -test-tile-operation-graph='vectorlane={vectorlane_size} sample-mode={extension_config.CONFIG_TLS_MODE}' \ - -test-memref-to-gemmini="vectorlane={vectorlane_size} timing=1" \ - -convert-linalg-to-loops \ - -convert-vector-to-scf='full-unroll' \ - -lower-affine \ - -finalize-memref-to-llvm \ - -lower-vector-multi-reduction \ - -convert-vector-to-llvm \ - -convert-arith-to-llvm \ - -convert-math-to-llvm \ - -convert-scf-to-cf \ - -convert-cf-to-llvm \ - -convert-func-to-llvm \ - -convert-index-to-llvm \ - -reconcile-unrealized-casts \ {'--mlir-print-ir-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_MLIR_IR else ''} \ - {filename}.mlir -o {sample_filename}_llvm.mlir + {filename}.mlir -o {sample_filename}_padded.mlir """, ).strip(), re.sub(r"[ \n]+", " ", @@ -158,6 +129,14 @@ def load(cls, source_code, vlenb = vlen // 8 write_path = get_write_path(source_code) key, input_path = write(source_code, "mlir", specified_dir=write_path) + # Run the Python out-of-line MLIR passes (MLIR bindings) on the kernel + # .mlir in place, before mlir-opt. Currently lowers torchsim.vlane_idx + # (replaces the old C++ -global-idx pass); add more in passes/__init__.py. + from PyTorchSimFrontend.mlir.passes import ( + run_python_passes, run_module_passes, POST_OPT_PASSES, + run_standard_lowering, run_tog, + ) + run_python_passes(input_path, vectorlane=vectorlane_size) new_input_path = os.path.splitext(input_path)[0] raw_tog_path = new_input_path + "_tog.py" tog_path = os.path.join(write_path, "tile_graph.onnx") @@ -178,13 +157,21 @@ def load(cls, source_code, # Use custom malloc to avoid size error new_link_option = link_option + " -Wl,--wrap=malloc -Wl,--wrap=free" cmds = mlir_compile_command(new_input_path, vectorlane_size, vlen=vlen) - opt_cmd = shlex.split(cmds[0]) + opt_pad_cmd = shlex.split(cmds[0]) translate_cmd = shlex.split(cmds[1]) llc_cmd = shlex.split(cmds[2]) llc_asm_cmd = shlex.split(cmds[3]) with lock: try: - subprocess.check_call(opt_cmd) + # loop-padding (mlir-opt) -> Python fine-grained + vcix (one parse/print) + subprocess.check_call(opt_pad_cmd) + run_module_passes(new_input_path + "_padded.mlir", + new_input_path + "_custom.mlir", + POST_OPT_PASSES, vectorlane=vectorlane_size, vlen=vlen) + # Standard MLIR -> LLVM-dialect lowering (registered upstream + # passes) runs in-process via the bindings PassManager, picking + # up after the custom mlir-opt passes (memref-to-gemmini). + run_standard_lowering(new_input_path + "_custom.mlir", new_input_path + "_llvm.mlir") subprocess.check_call(translate_cmd) subprocess.check_call(llc_cmd) subprocess.check_call(llc_asm_cmd) @@ -213,16 +200,29 @@ def load(cls, source_code, return key # Launch tile graph generator - gem5_sample_cmd = shlex.split(gem5_cmds[0]) + gem5_pad_cmd = shlex.split(gem5_cmds[0]) gem5_translate_cmd = shlex.split(gem5_cmds[1]) gem5_llc_cmd = shlex.split(gem5_cmds[2]) lock = FileLock(get_lock_path(write_path), timeout=LOCK_TIMEOUT) with lock: try: - result = subprocess.check_output(gem5_sample_cmd) - with open(raw_tog_path, "wb") as file: - file.write(result) + # mlir-opt now runs only loop-padding/dma-fine-grained/pytorchsim-to-vcix + # and writes the post-vcix IR. The tile-operation-graph pass is ported + # to Python: run_tog reads that IR, writes the TOG (_tog.py) and the + # mutated IR (_custom.mlir: sample-mode step rewrite + compute markers), + # replacing the C++ -test-tile-operation-graph pass. + # loop-padding(timing, mlir-opt) -> Python fine-grained + vcix (one parse/print) + subprocess.check_call(gem5_pad_cmd) + run_module_passes(sample_mlir_path + "_padded.mlir", + sample_mlir_path + "_postvcix.mlir", + POST_OPT_PASSES, vectorlane=vectorlane_size, vlen=vlen) + run_tog(sample_mlir_path + "_postvcix.mlir", raw_tog_path, + sample_mlir_path + "_custom.mlir", + sample_mode=extension_config.CONFIG_TLS_MODE, + vectorlane=vectorlane_size) + # Standard MLIR -> LLVM-dialect lowering in-process (see functional path). + run_standard_lowering(sample_mlir_path + "_custom.mlir", sample_mlir_path + "_llvm.mlir", timing=True) subprocess.check_call(gem5_translate_cmd) subprocess.check_call(gem5_llc_cmd) except subprocess.CalledProcessError as e: diff --git a/PyTorchSimFrontend/mlir/axis_split.py b/PyTorchSimFrontend/mlir/axis_split.py new file mode 100644 index 00000000..71ec4809 --- /dev/null +++ b/PyTorchSimFrontend/mlir/axis_split.py @@ -0,0 +1,229 @@ +"""Aligned axis splitting at the Inductor scheduling layer. + +Goal: guarantee the MLIR codegen sees only per-axis affine index expressions +(no FloorDiv / ModularIndexing). When an index expr contains FloorDiv(v, k) or +ModularIndexing(v, k, m) where `v` is a single iteration variable of extent E +and the divisor (resp. k*m) divides E, the floor/mod is *aligned*: splitting the +loop axis v into (outer, inner) with v = outer*k + inner makes it collapse to a +plain affine term (outer), at zero data-movement cost. + +This is the cheap upstream tool of the affine-only contract. The misaligned case +(cat / non-factor reshape, divisor does not divide the extent) is NOT handled +here -- that needs graph-level copy insertion. + +The rebuild reuses Inductor's own LoopBody machinery, exactly like +MLIRScheduling.revert_group: feed a split var_ranges + iter_vars and re-trace the +node's store function so the index expressions are regenerated over the new +iteration domain. +""" +import sympy +from torch._inductor.ir import LoopBody +from torch._inductor.utils import sympy_index_symbol +from torch.utils._sympy.functions import FloorDiv, ModularIndexing + + +def _as_int(x): + try: + return int(x) + except (TypeError, ValueError): + return None + + +def collect_boundaries(exprs, var_to_axis, var_ranges): + """{axis_index: set(boundary cut points)} for the given index expressions. + + A FloorDiv(v, k) contributes boundary k; ModularIndexing(v, k, m) contributes + k and k*m. Only aligned terms count (boundary divides the var extent). Shared + by find_split_plan (fused LoopBody) and graph_copy (operand loaders). + """ + import collections + bset = collections.defaultdict(set) + for expr in exprs: + for fd in expr.atoms(FloorDiv): + base, div = fd.args + k = _as_int(div) + if base in var_to_axis and k and k > 1: + E = _as_int(var_ranges.get(base)) + if E and E % k == 0: + bset[var_to_axis[base]].add(k) + for mi in expr.atoms(ModularIndexing): + base, div, mod = mi.args + k, m = _as_int(div), _as_int(mod) + if base in var_to_axis and k and m: + E = _as_int(var_ranges.get(base)) + if E and E % (k * m) == 0: + ax = var_to_axis[base] + if k > 1: + bset[ax].add(k) + if k * m < E: + bset[ax].add(k * m) + return bset + + +def _is_chain(boundaries, E): + """True iff [1, sorted(boundaries in (1,E)), E] is a divisibility chain.""" + chain = [1] + sorted(b for b in boundaries if 1 < b < E) + [E] + return all(chain[i + 1] % chain[i] == 0 for i in range(len(chain) - 1)) + + +def find_split_plan(nodes): + """Inspect a group of scheduler nodes and return {axis_index: boundaries}. + + `boundaries` is an ascending divisibility chain [1, b1, ..., E] of cut points + for that axis: splitting the axis at these boundaries (mixed radix, + `v = sum_i d_i * b_i`) makes every FloorDiv/ModularIndexing on it collapse to + an affine combination of the split sub-vars. The cut points are gathered from + the terms on the axis: + - FloorDiv(v, k) -> boundary k + - ModularIndexing(v, k, m) -> boundaries k and k*m (the digit lives in [k, k*m)) + Only aligned terms count (the boundary must divide the extent E). If the + collected boundaries for an axis do NOT form a divisibility chain (e.g. + floor-by-2 and mod-by-3 on extent 6), the radices are incompatible -> the axis + is left unsplit (its floor/mod stays for the misaligned/recompile path). + + axis_index is positional in the group's iteration space, so the same plan + applies to every fused node sharing that space. + """ + import collections + bset = collections.defaultdict(set) # axis -> set of boundary cut points + ext_of = {} # axis -> extent + for n in nodes: + body = getattr(n, "_body", None) + if body is None: + continue + var_to_axis = {v: i for i, v in enumerate(body.iter_vars)} + nb = collect_boundaries(body.indexing_exprs.values(), var_to_axis, body.var_ranges) + for ax, bs in nb.items(): + bset[ax] |= bs + ext_of[ax] = _as_int(body.var_ranges[body.iter_vars[ax]]) + + plan = {} + for ax, bs in bset.items(): + E = ext_of[ax] + # require a real, divisibility-chain split (incompatible radices -> skip). + if E and any(1 < b < E for b in bs) and _is_chain(bs, E): + plan[ax] = [1] + sorted(b for b in bs if 1 < b < E) + [E] + + # A split may push the per-axis index rank past 4. The resulting >4D logical tile + # is peeled into <=4D physical descriptors by the decompose-transfer pass (an + # affine.for nest carrying the lane-banked physical SRAM offset), so there is no + # rank cap here. + return plan + + +def build_split_body(node, plan, prefix="z"): + """Rebuild node._body / sizes for the given split plan. + + Returns (body, (index_size, reduce_size)). Reindexes the EXISTING (already + collapsed/reordered) node._body via LoopBody's copy path instead of re-tracing + from the raw store function: pass the body as `fn` so LoopBody.__init__ takes + _init_with_copy, which substitutes each original iter var with our expression + and runs simplify_with_ranges. For a split axis the substitution + v -> sum_i d_i * b_i (mixed radix over the boundary chain) makes every + FloorDiv/ModularIndexing on it collapse to an affine combination of the d_i, + and reindexing the collapsed body keeps already-merged dims merged (no rank + blow-up). indexing_from_args requires exactly one replacement expr per original + var (index dims then reduce dims), flattened to len(body.var_ranges). + """ + body = node._body + orig_index_vars = list(body.iter_vars) + orig_reduce_vars = list(body.reduce_vars) + + iter_vars = [] + index_args = [] # one expr per ORIGINAL index dim (substituted in) + var_ranges = {} + index_size = [] + ctr = 0 + + for ax, v in enumerate(orig_index_vars): + ext = body.var_ranges[v] + if ax in plan: + bounds = plan[ax] # ascending chain [1, b1, ..., E] + # one sub-var per segment: d_i has extent b_{i+1}/b_i, significance b_i. + subs = [] # (symbol, extent, significance) low->high + expr = sympy.Integer(0) + for i in range(len(bounds) - 1): + seg_ext = bounds[i + 1] // bounds[i] + nv = sympy_index_symbol(f"{prefix}{ctr}"); ctr += 1 + subs.append((nv, seg_ext, bounds[i])) + expr = expr + nv * bounds[i] + # iteration nest: most-significant (outermost) dim first. + for nv, seg_ext, _sig in reversed(subs): + iter_vars.append(nv) + var_ranges[nv] = sympy.Integer(seg_ext) + index_size.append(sympy.Integer(seg_ext)) + index_args.append(expr) + else: + nv = sympy_index_symbol(f"{prefix}{ctr}"); ctr += 1 + iter_vars.append(nv) + var_ranges[nv] = ext + index_size.append(ext) + index_args.append(nv) + + # Reduction dims pass through unchanged (a fresh symbol with the same range), + # using the "r" prefix and kept after the index dims so the reduction axis + # stays innermost (var_ranges is ordered iter-then-reduce; sizes splits on + # len(iter_vars)). We do not split reduction dims here. + reduce_vars = [] + reduce_size = [] + reduce_args = [] + for rctr, v in enumerate(orig_reduce_vars): + ext = body.var_ranges[v] + nv = sympy_index_symbol(f"r{rctr}") + reduce_vars.append(nv) + var_ranges[nv] = ext + reduce_size.append(ext) + reduce_args.append(nv) + + args = [index_args, reduce_args] if orig_reduce_vars else [index_args] + new_body = LoopBody(body, args, var_ranges, iter_vars, reduce_vars) + new_body.indexing_exprs = { + name: _fold_with_ranges(e, var_ranges) + for name, e in new_body.indexing_exprs.items() + } + return new_body, (index_size, reduce_size) + + +def _fold_with_ranges(expr, var_ranges): + """Fold residual FloorDiv/ModularIndexing that simplify_with_ranges missed. + + A mixed-radix split leaves terms like FloorDiv(z1 + 4*z2, 12); these are 0 by + construction (the lower digits sum below the boundary), but the Inductor + simplifier cannot prove a multi-term numerator < divisor. We prove it directly + from the split sub-var ranges via bound_sympy: + FloorDiv(num, d) -> 0 if 0 <= num < d + ModularIndexing(num, k, m) -> num // k if 0 <= num < k*m (mod is a no-op) + Iterated to a fixpoint (folding a mod can expose a foldable floor). + """ + from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + ranges = {} + for v, sz in var_ranges.items(): + e = _as_int(sz) + if e is not None and e >= 1: + ranges[v] = ValueRanges(0, e - 1) + if not ranges: + return expr + + def vr(num): + try: + return bound_sympy(num, ranges) + except Exception: + return None + + for _ in range(8): + changed = False + for fd in list(expr.atoms(FloorDiv)): + num, div = fd.args + d = _as_int(div) + b = vr(num) if d else None + if b is not None and b.lower >= 0 and b.upper < d: + expr = expr.subs(fd, sympy.Integer(0)); changed = True + for mi in list(expr.atoms(ModularIndexing)): + num, k, m = mi.args + ki, mi_ = _as_int(k), _as_int(m) + b = vr(num) if (ki and mi_) else None + if b is not None and b.lower >= 0 and b.upper < ki * mi_: + expr = expr.subs(mi, FloorDiv(num, k)); changed = True + if not changed: + break + return expr diff --git a/PyTorchSimFrontend/mlir/graph_copy.py b/PyTorchSimFrontend/mlir/graph_copy.py new file mode 100644 index 00000000..0c49b86f --- /dev/null +++ b/PyTorchSimFrontend/mlir/graph_copy.py @@ -0,0 +1,155 @@ +"""Graph-copy (relayout) for incompatible-radix operands. + +When an elementwise consumer reads two operands whose floor/mod groupings on a +shared axis are incompatible (the boundary cut points do not form a divisibility +chain, e.g. floor-by-2 and mod-by-3 on extent 6), axis-split cannot linearize the +fused index. We `realize()` the cheaper operand at the consumer's lowering, which +materializes it as a contiguous buffer; the consumer then reads it affine and only +the other (single, compatible) grouping remains for axis-split to handle. + +Detection reuses axis_split.collect_boundaries on each operand's loader index, so +it is the same precise radix analysis used at the scheduling layer -- not an FX +view-chain heuristic. The hook wraps the already-registered lowering entries (the +make_pointwise results), so it sees every elementwise consumer in one place. The +realize() (not a clone, which Inductor inlines) is what actually forces the buffer +boundary; see the PoC notes in docs. + +Behavior-neutral unless a genuine incompatible-radix conflict is detected. +""" +from torch._inductor import lowering as L +from torch._inductor import dependencies +from torch._inductor import ir +from torch._inductor.ir import TensorBox +from torch.utils._sympy.functions import FloorDiv, ModularIndexing + +from . import axis_split + + +def _has_multivar_floormod(exprs): + """True if any FloorDiv/ModularIndexing argument spans >1 loop variable + (case 7: cross-axis floor/mod that axis-split cannot split).""" + for e in exprs: + for f in list(e.atoms(FloorDiv)) + list(e.atoms(ModularIndexing)): + if len(f.args[0].free_symbols) > 1: + return True + return False + + +def _numel(tb): + n = 1 + for s in tb.get_size(): + v = axis_split._as_int(s) + if v is None: + return float("inf") + n *= v + return n + + +def _relayout_args(args): + """Return a modified args list with one operand replaced by a forced copy when + it needs relayout, or None to leave args unchanged. The copy uses + ExternKernel.copy_input (a realized identity Pointwise) -- this materializes + *views* too, unlike StorageBox.realize() which is a no-op on a ReinterpretView. + The copy kernel iterates the operand's own (contiguous) shape, so its index + collapses to single-var and axis-split handles it; the consumer then reads the + copy affine.""" + pos = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + if not pos: + return None + tbs = [args[i] for i in pos] + # Output/iteration shape = the broadcast of all operands (the largest rank, + # max per dim). For a single-operand consumer (e.g. a reduction reading a + # multi-var-view input) this is just that operand's shape -- still enough to + # detect a multi-var floor and copy_input it (case 7); the 2-operand radix + # conflict (case 5) naturally needs >=2 operands. + # Per-dim max extent over the max-rank operands (order-independent). Picking a + # single operand by rank alone (max key=len) would, for two equal-rank operands + # with different per-dim extents, take a broadcast-from operand's smaller shape + # and then miss the genuine conflict on the broadcast-to dim. + maxrank = max(len(t.get_size()) for t in tbs) + full = [t.get_size() for t in tbs if len(t.get_size()) == maxrank] + ranges = [max((s[d] for s in full), key=lambda v: (axis_split._as_int(v) or -1)) + for d in range(maxrank)] + extents = [axis_split._as_int(s) for s in ranges] + if not extents or any(e is None for e in extents): + return None # scalar / dynamic -> skip + + # Only true elementwise consumers: each operand is broadcast-compatible with the + # output (same rank, every dim is 1 or == the output extent). This admits + # broadcasting operands (e.g. y[8,1] into [8,3]) while excluding mm/bmm/cat-style + # ops whose operands differ in a non-broadcast way. + for tb in tbs: + sz = [axis_split._as_int(s) for s in tb.get_size()] + if len(sz) != len(extents) or any( + d is not None and d != 1 and d != e for d, e in zip(sz, extents) + ): + return None + + # Trace each operand's loader to get its read indices (sympy) over the shared + # output iteration; make_loader returns a value, so extract_read_writes is what + # gives the index expressions. range_vars are positional per output axis, so the + # axis numbering is consistent across operands. + per_bnd = [] # [{axis: boundary set}] per operand + per_mv = [] # [bool] operand has multi-var floor/mod + for tb in tbs: + try: + rw = dependencies.extract_read_writes(tb.make_loader(), list(ranges)) + except Exception: + per_bnd.append({}) + per_mv.append(False) + continue + v2a = {v: i for i, v in enumerate(rw.range_vars)} + exprs = [r.index for r in rw.reads if hasattr(r, "index")] + b = axis_split.collect_boundaries(exprs, v2a, rw.var_ranges) + mv = _has_multivar_floormod(exprs) + per_bnd.append(b) + per_mv.append(mv) + + victim = None + + # Case 5 -- incompatible radices on a shared axis between two operands. + for axis, E in enumerate(extents): + contrib = [(i, per_bnd[i][axis]) for i in range(len(tbs)) if per_bnd[i].get(axis)] + if len(contrib) < 2: + continue # single grouping -> axis-split handles + union = {b for _, s in contrib for b in s} + if axis_split._is_chain(union, E): + continue # compatible -> axis-split handles + victim = min(contrib, key=lambda c: _numel(tbs[c[0]]))[0] + break + + # Case 7 -- an operand whose floor/mod argument spans multiple consumer axes + # (e.g. (3*p0+p1)//4 from a transpose+reshape feeding a broadcast/softmax that + # keeps the dims separate). axis-split cannot split a multi-var argument. + if victim is None: + mv_ops = [i for i in range(len(tbs)) if per_mv[i]] + if mv_ops: + victim = min(mv_ops, key=lambda i: _numel(tbs[i])) + + if victim is None: + return None + new = list(args) + p = pos[victim] + new[p] = ir.ExternKernel.copy_input(args[p]) + return new + + +def install(): + """Wrap registered lowering entries to insert relayout. Idempotent. Call once + at backend import (after torch._inductor.lowering is populated -- make_pointwise + runs at import to build the entries, so we wrap the entries, not the factory).""" + if getattr(L, "_torchsim_relayout_installed", False): + return + for key, fn in list(L.lowerings.items()): + def wrap(orig): + def wrapped(*a, **k): + try: + na = _relayout_args(a) + except Exception: + na = None # detection must never break lowering + if na is not None: + a = na + return orig(*a, **k) + return wrapped + L.lowerings[key] = wrap(fn) + L._torchsim_relayout_installed = True diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index b163ad1a..725e0dc6 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1,6 +1,5 @@ import contextlib import sympy -import sys import time import re import os @@ -536,10 +535,8 @@ def load(self, name: str, index: sympy.Expr): sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) - # MVIN Encoding - attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, int(padding)) - code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) + code = self.emit_transfer("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, dram_stride, tile_stride, int(padding)) self.cse.generate(dma_buffer, code, assignment = False) # FIXME: assignment = False does not support caching if not comptute_depedency: @@ -608,9 +605,8 @@ def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs) sram_index_var = self.spad_buffer_dict[str(value)][3] # Generate DMA instruction - attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, 0) - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) + code = self.emit_transfer("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, dram_stride, tile_stride, 0) self.dma_stores.writeline(common.DeferredLine(name, code)) def reduction(self, dtype, src_dtype, reduction_type, value): @@ -737,9 +733,8 @@ def store_reduction(self, name, index, value): ops._store(value, sram_var, sram_index_var, tile_shape, buffer_name=name) # Generate DMA instruction - attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, 0) - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) + code = self.emit_transfer("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, dram_stride, tile_stride, 0) self.reductions_suffix.writeline(common.DeferredLine(name, code)) def indirect_indexing(self, index_var, size, check=True, wrap_neg=True): @@ -1170,7 +1165,6 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe # Note: index could contain symbols that represent dynamic axies # Extract dimension of index(e.g, index0, index1) local_dims = [int(str(i)[5:]) for i in index.free_symbols if "index" in str(i)] - implicit_local_dims = list(index.args) total_dims = [int(str(i)[5:]) for i in self.itervars] local_tile_desc = mlir_common.MLIRMultiDimTile([1], self.vector_lane) local_dims.sort() # Assume that smaller index is placed in the outer loop @@ -1178,6 +1172,17 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe index = index.subs({s: 0 for s in indirect_syms}, simultaneous=True) indirect_dims = [f"{i}" for i in indirect_syms] + # axis-split + graph-copy linearize aligned floor/mod upstream. Anything that + # reaches here still carrying floor/mod (store-side ModularIndexing, + # reduction-axis floor/mod, incompatible-radix views) would be silently + # mis-strided in the dram_stride computation below, so fail loudly instead. + if index.has(FloorDiv) or index.has(ModularIndexing): + raise NotImplementedError( + f"Unlinearized floor/mod in DMA index: {index}. axis-split/graph-copy " + f"did not eliminate it; this view is unsupported " + f"(see docs/axis-split-scheduling.md)." + ) + # Reduction can have two type of tile size if broadcast and (total_dims != local_dims or (self.reduction_depth!=len(total_dims) and total_dims[:self.reduction_depth] == local_dims)): local_dims = total_dims # Brodatcast tile shape @@ -1243,15 +1248,9 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc.vmap.vlane_split_axis = local_vlane_split_axis local_tile_desc.vmap.vlane_stride = kg_tile_desc.vmap.vlane_stride else: - raise NotImplementedError("Currently not implemented... ;)") - - if len(implicit_local_dims)!=0 and len(local_dims) != len(implicit_local_dims) and self.is_modular_indexing(index): - for axis_constraints in self.kernel_group.tile_desc.implicit_dim_size.values(): - if len(axis_constraints) <= 1: - continue - sorted_constraints = sorted(axis_constraints, key=lambda c: int(c.args[1])) - for constraint in sorted_constraints[1:]: - index = index.replace(constraint.original_expr, 0) + local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(dim) for dim in local_dims]) + local_tile_desc.vmap.vlane_split_axis = local_vlane_split_axis + local_tile_desc.vmap.vlane_stride = kg_tile_desc.vmap.vlane_stride # Calculate dram stride in local tile-dim order. # This keeps dram/sram stride rank aligned with tile rank. @@ -1266,19 +1265,12 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe else: dram_dict = defaultdict(list) - implicit_dim_divisors = defaultdict(lambda: sys.maxsize) - # Assume that div will have high priority than mod for arg in index.as_ordered_terms(): coeff, dim = arg.as_coeff_mul() if len(dim) == 0: continue real_dim = list(dim[0].free_symbols)[0] - if dim[0].has(ModularIndexing): - if dim[0].args[1] < implicit_dim_divisors[str(real_dim)]: - implicit_dim_divisors[str(real_dim)] = dim[0].args[1] - dram_dict[str(real_dim)] = [coeff] - else: - dram_dict[str(real_dim)].append(coeff) + dram_dict[str(real_dim)].append(coeff) # Add missing dims if not added max_dim = len(self.ranges) if not store_reduction else len(self.ranges) - 1 @@ -1289,142 +1281,62 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe sorted_keys = sorted(dram_dict.keys()) dram_stride = sum((dram_dict[key] for key in sorted_keys), []) - # Support floordiv pattern - # FIXME. How to integrate implicit dims and floordiv? - # This was introduced to support GroupNorm - if index.has(FloorDiv) and not index.has(ModularIndexing): - dim_divisor = [1] * len(local_dims) - for sub in sympy.preorder_traversal(index): - if isinstance(sub, FloorDiv): - if not str(sub.args[0]).startswith("index"): - continue - dim_idx = int((str(sub.args[0])[5:])) - if dim_idx not in local_dim_to_axis: - continue - local_dim_idx = local_dim_to_axis[dim_idx] - if int(self.kernel_group.tile_desc.get_tile_size()[dim_idx] % sub.args[1]) != 0: - # In this case, need to recompile - original_tile = self.kernel_group.tile_desc.get_tile_size() - original_size = original_tile[dim_idx] - divisor = sub.args[1] * self.kernel_group.tile_desc.vmap.vlane_stride - new_size = ((original_size + divisor - 1) // divisor) * divisor - new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) - new_tile_sizes[dim_idx] = new_size - self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) - self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True - - # Can't use dim_idx as vlane_split_axis - if dim_idx == self.kernel_group.tile_desc.vmap.vlane_split_axis: - self.kernel_group.tile_desc.vmap.vlane_split_axis = (dim_idx + 1) % len(original_tile) - - # Send recompile signal - self.reset("recompile") - raise mlir_common.RecompileSignal(f"Tile size {self.kernel_group.tile_desc.get_tile_size()[dim_idx]} is not divisible by {sub.args[1]}") - dim_divisor[local_dim_idx] = sub.args[1] - - # Update dram_stride, just insert 0 next to target dim - offset = 0 - for dim_idx, divisor in enumerate(dim_divisor): - if divisor == 1: - continue - dram_stride.insert(dim_idx+offset+1, 0) - local_tile_desc.apply_divisor(dim_idx+offset, divisor, "pad") - local_tile_desc.apply_divisor(dim_idx+offset, divisor, "split") - offset = offset+1 - - # Support ModularIndexing pattern - # This pattern can be used to broadcast ex) torch.cat([a,a]) - # ModularIndexing(x, y, z) means (x // y) % z - # tile_size must be: multiple of y (floorDiv divisor) and divisor of z (modular divisor) - if index.has(ModularIndexing): - for sub in sympy.preorder_traversal(index): - if isinstance(sub, ModularIndexing): - if not str(sub.args[0]).startswith("index"): - continue - dim_idx = int((str(list(sub.args[0].free_symbols)[0])[5:])) - floor_divisor = sub.args[1] # y: floorDiv divisor - mod_divisor = sub.args[2] # z: modular divisor - current_tile_size = self.kernel_group.tile_desc.get_tile_size()[dim_idx] - - # Check if tile_size is multiple of floorDiv divisor - if int(current_tile_size % floor_divisor) != 0: - original_tile = self.kernel_group.tile_desc.get_tile_size() - original_size = original_tile[dim_idx] - divisor = floor_divisor * self.kernel_group.tile_desc.vmap.vlane_stride - new_size = ((original_size + divisor - 1) // divisor) * divisor - new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) - new_tile_sizes[dim_idx] = new_size - self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) - self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True - - self.reset("recompile") - raise mlir_common.RecompileSignal(f"Tile size {current_tile_size} is not a multiple of floorDiv divisor {floor_divisor} in ModularIndexing") - - # Check if tile_size is a divisor of modular divisor - if int((mod_divisor * floor_divisor) % current_tile_size) != 0: - original_tile = self.kernel_group.tile_desc.get_tile_size() - original_size = original_tile[dim_idx] - # Find the largest divisor of mod_divisor that is <= original_size - # and is a multiple of floor_divisor - new_size = original_size - while new_size > 0: - if mod_divisor % new_size == 0 and new_size % floor_divisor == 0: - break - new_size -= floor_divisor - - if new_size <= 0: - new_size = mod_divisor * floor_divisor - - new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) - new_tile_sizes[dim_idx] = new_size - self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) - self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True - - self.reset("recompile") - raise mlir_common.RecompileSignal(f"Tile size {current_tile_size} is not a divisor of modular divisor {mod_divisor} in ModularIndexing") - # FIXME. It will be nice to modify node instead of this exception handling... if len(self.itervars) == 1 and self.reduction_depth == 0: # In case of reduction loop only case, we will add dummy loop so shift it once dram_stride = [0] + dram_stride[:-1] return local_tile_desc, index_var, dram_stride - def get_dma_code(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, dram_index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute): + def emit_transfer(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype, + dram_var, dram_index_var, sram_var, sram_index_var, + dram_shape, tile_shape, dram_stride, tile_stride, padding, + subtile_size=None, async_type=None): + """Emit a generic togsim.transfer op for a DMA whose access exceeds the + 4D Gemmini descriptor limit. Carries the full N-D access (dram/tile + strides + shapes) plus the SSA operands a memref.dma_start needs + (dma_type / vlane_split_axis / vlane_stride), so the decompose pass + (passes/decompose_transfer.py) is purely mechanical: it peels the excess + dims into a loop of <=4D memref.dma_start, reusing these operands. + + Operand prep uses the read/write cache+counter for the dma_type enum and + CSE'd vlane consts so the transfer is self-contained; togsim is an + unregistered dialect -> generic form. + """ dma_key = (vlane_split_axis, vlane_stride, mlir_dtype) if dma_type_name == "MVIN" and dma_key in self.dma_read_cache: - dma_type, vlane_split_axis, vlane_stride = self.dma_read_cache[dma_key] + dma_type, vsa, vst = self.dma_read_cache[dma_key] elif dma_type_name == "MVOUT" and dma_key in self.dma_write_cache: - dma_type, vlane_split_axis, vlane_stride = self.dma_write_cache[dma_key] + dma_type, vsa, vst = self.dma_write_cache[dma_key] else: - vlane_split_axis = self.get_const_cse(vlane_split_axis) - vlane_stride = self.get_const_cse(vlane_stride) + vsa = self.get_const_cse(vlane_split_axis) + vst = self.get_const_cse(vlane_stride) if dma_type_name == "MVIN": dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_read_counter}"]) self.dma_read_counter += 1 - self.dma_read_cache[dma_key] = [dma_type, vlane_split_axis, vlane_stride] + self.dma_read_cache[dma_key] = [dma_type, vsa, vst] else: dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_write_counter}"]) - self.dma_write_cache[dma_key] = [dma_type, vlane_split_axis, vlane_stride] + self.dma_write_cache[dma_key] = [dma_type, vsa, vst] tag = self.get_tag_cse() zero_cse = self.get_const_cse(0) - - # Prepare opearnds and attributes - dram_operand = f"%{dram_var}[%{dram_index_var}]" - sram_operand = f"%{sram_var}[{sram_index_var}]" # Use string - tag_var = f"%{tag}[%{zero_cse}]" - dma_attribute = f"%{vlane_split_axis}, %{vlane_stride}" - sram_shape = tile_shape - tag_shape = "memref<1xi32>" - - if dma_type_name == "MVIN": - src_operand, dst_operand = dram_operand, sram_operand - src_shape, dst_shape = dram_shape, sram_shape - else: - src_operand, dst_operand = sram_operand, dram_operand - src_shape, dst_shape = sram_shape, dram_shape - - return f"memref.dma_start {src_operand}, {dst_operand}, %{dma_type}, {tag_var}, {dma_attribute} : {src_shape}, {dst_shape}, {tag_shape} {attribute}" + # vlane_split_axis is carried as a VALUE attr (not an SSA operand) because the + # decompose pass must remap it: collapsing unit tile dims renumbers the axes, + # so the descriptor's vlane axis index changes and the pass rebuilds the const. + attrs = ( + f'dma_kind = "{dma_type_name}", ' + f'vlane_split_axis = {int(vlane_split_axis)} : i64, ' + f'dram_stride = {dram_stride}, tile_stride = {tile_stride}, ' + f'padding = {int(padding)} : i64' + ) + if subtile_size: + av = int(async_type) if async_type is not None else 1 + attrs += f', subtile_size = {list(subtile_size)}, async = {av} : i64' + # operands: dram, dram_idx, sram, sram_idx, tag, dma_type, vlane_stride + return ( + f'"togsim.transfer"(%{dram_var}, %{dram_index_var}, %{sram_var}, %{zero_cse}, ' + f'%{tag}, %{dma_type}, %{vst}) {{{attrs}}} : ' + f'({dram_shape}, index, {tile_shape}, index, memref<1xi32>, index, index) -> ()' + ) def allocate_sram_buffer(self, dtype, dram_name, tile_desc, raw_index, buffer=None, forced_name=None): c_type = mlir_common.DTYPE_TO_C[dtype] diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 734ca967..a70d1c7d 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -1,5 +1,6 @@ import dataclasses import math +import os import contextvars from contextlib import contextmanager from dataclasses import dataclass @@ -14,9 +15,8 @@ from torch._inductor.codegen import cpp from torch._inductor.virtualized import V from torch._inductor.ir import MultiOutputLayout -from torch._inductor.dependencies import MemoryDep, StarDep, WeakDep from torch._inductor.codegen.wrapper import KernelDefinitionLine -from torch.utils._sympy.functions import ModularIndexing, FloorDiv, Mod, Identity +from torch.utils._sympy.functions import Identity import sympy import contextlib @@ -119,27 +119,6 @@ def get_dtype_nbytes(dtype): } } -def format_dma_op_attributes( - dram_stride: Sequence, - sram_stride: Sequence, - padding: int = 0, - *, - subtile_size: Optional[Sequence] = None, - async_type: Optional[int] = None, -) -> str: - """Attribute dict for memref.dma_start; stride lists as bracketed integer lists.""" - parts = [ - f"dram_stride = {dram_stride}", - f"sram_stride = {sram_stride}", - f"padding = {int(padding)}", - ] - if subtile_size: - parts.append(f"subtile_size = {subtile_size}") - av = int(async_type) if async_type is not None else 1 - parts.append(f"async = {av} : i64") - return "{" + ", ".join(parts) + "}" - - class ParallelLoopBuffer(IndentedBuffer): def indent(self, offset=1, attribute="", suffix=""): @contextlib.contextmanager @@ -456,45 +435,27 @@ def pad_vlane_tile(self): padded_size = used_vlane * vlane_stride self._tile_size[vlane_split_axis] = math.ceil(self._tile_size[vlane_split_axis] / padded_size) * padded_size - def apply_constraints(self, constraints, ranges): - for idx, (axis_constraints, axis_size) in enumerate(zip(constraints.values(), ranges)): - for const in axis_constraints: - if const.args[1] == 1: - continue - divider = int(const.args[1]) - - if not self.tile_constraint[idx].fixed: - self.tile_constraint[idx].fixed = True - self._tile_size[idx] = divider - elif self.tile_constraint[idx].fixed and self._tile_size[idx] > divider: - self._tile_size[idx] = divider - self.update_tile_stride() - @staticmethod def init_tile_size(ranges, vlane_stride, vector_lane): + # Logical tile init for ANY rank. Only the innermost dims carry the + # vectorized tile; all further-outer dims stay 1. The physical Gemmini DMA + # descriptor is <=4D -- a higher-rank logical tile is mapped onto <=4D + # descriptors by togsim.transfer + the decompose pass (logical/physical + # tile split), so no rank cap here. nr_dim = len(ranges) + if nr_dim == 0: # scalar + return [1] tile_size = [1] * nr_dim - if len(tile_size) == 2: + if nr_dim == 1: + tile_size[0] = 1 if ranges[0] == 1 else 2 * vlane_stride * vector_lane + elif nr_dim == 2: tile_size[-1] = vlane_stride * vector_lane tile_size[-2] = 2 * vector_lane - elif len(tile_size) == 0: # Scalar - tile_size = [1] - ranges = [1] - elif len(tile_size) == 1 and ranges[0]==1: - tile_size[0] = 1 - elif len(tile_size) == 1: - tile_size[0] = 2 * vlane_stride * vector_lane - elif len(tile_size) == 3: - tile_size[-1] = vector_lane - tile_size[-2] = 4 * vector_lane - tile_size[-3] = 2 - elif len(tile_size) == 4: + else: # 3D and up (general) tile_size[-1] = vector_lane tile_size[-2] = 4 * vector_lane tile_size[-3] = 2 - tile_size[-4] = 1 - else: - raise NotImplementedError("dummy tile size fail!") + # tile_size[:-3] stay 1 (subsumes the old 4D [-4]=1 and any higher rank) return tile_size @staticmethod @@ -544,7 +505,6 @@ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=N vlane_stride=vlane_stride ) - self.implicit_dim_size = {} self.nr_rdim = 0 self.offset = sympy.Integer(0) # Dram offset @@ -710,59 +670,6 @@ def call_kernel(self, kernel_name): # generate the code to call this wrapper.generate_kernel_call(kernel_name, call_args, triton=False) - def is_modular_indexing(self, expr): - return "ModularIndexing" in str(expr) - - def implicit_dim_ops(self, nodes): - target_patterns = (ModularIndexing, FloorDiv, Mod) - target_operands = [] - for target_node in nodes: - for read_operand in target_node.read_writes.reads: - read_operand: MemoryDep - if isinstance(read_operand, StarDep) or isinstance(read_operand, WeakDep): - continue - read_index = read_operand.index - for arg_expr in read_index.args: - if arg_expr.atoms(*target_patterns): - target_operands.append(read_operand) - return target_operands - - def extract_dividers(self, implicit_ops): - # When a specific axis is processed, the key constraint to verify is the divider. - # The tile size must be forced to match the divider size. - dim_dividers = defaultdict(set) - for operand in implicit_ops: - subs_map = { - s: sympy.symbols(s.name.replace("c", "index", 1)) - for s in operand.index.free_symbols - } - rev_subs_map = { - sympy.symbols(s.name.replace("c", "index", 1)) : s - for s in operand.index.free_symbols - } - new_index = operand.index.subs(subs_map) - for arg in new_index.args: - if arg.is_number: - continue - if len(arg.free_symbols) > 1: - raise NotImplementedError("Not supporting this view operation...!") - if arg.is_Mul and arg.args[0].is_number: - arg = arg.args[1] - - if isinstance(arg, ModularIndexing): - modular_expr = ModularIndexing(arg.args[0], arg.args[1], arg.args[2]) - modular_expr.original_expr = arg - elif arg.is_symbol: - modular_expr = ModularIndexing(arg, 1, operand.ranges[rev_subs_map[arg]]) - modular_expr.original_expr = arg - elif "//" in str(arg): - modular_expr = ModularIndexing(arg.args[0], arg.args[1], operand.ranges[rev_subs_map[arg.args[0]]]//arg.args[1]) - modular_expr.original_expr = arg - else: - raise NotImplementedError("What is this case?") - dim_dividers[modular_expr.args[0]].add(modular_expr) - return dim_dividers - def compute_tile_size(self, nodes, vars, reduction_vars): vlane_split_axis = len(vars) - 1 vlane_stride = 2 # Set minimum vlane stride @@ -782,14 +689,6 @@ def compute_tile_size(self, nodes, vars, reduction_vars): self.kernel_group.tile_desc.vmap.vlane_split_axis = 0 self.kernel_group.tile_desc.vmap.vlane_stride = self.kernel_group.tile_desc.get_tile_size()[0] - # Handle implict dims. Input operand could be high dimension tensor. - # Note: https://github.com/PSAL-POSTECH/PyTorchSim/issues/173 - implicit_ops = self.implicit_dim_ops(nodes) - if implicit_ops: - tile_constraints = self.extract_dividers(implicit_ops) - self.kernel_group.tile_desc.apply_constraints(tile_constraints, self.ranges) - self.kernel_group.tile_desc.implicit_dim_size = tile_constraints - # Check recodegen reason if self.recodegen is not None: if self.recodegen == "spad_overflow": @@ -838,12 +737,11 @@ def codegen_nodes(self, nodes, kernel_name): with self as kernel: for node in nodes: node.run(vars, reduction_vars) - except RecompileSignal as e: + except RecompileSignal: recompile_try += 1 if recompile_try > max_retry_compile: raise RuntimeError("Failed to compile kernel after multiple attempts.") # Retry compile nodes - #print(f"Try recompile({recompile_try}/{max_retry_compile}). Reason: {e}") continue V.graph.removed_buffers |= self.removed_buffers # V.graph.inplaced_to_remove |= self.inplaced_to_remove diff --git a/PyTorchSimFrontend/mlir/mlir_ops.py b/PyTorchSimFrontend/mlir/mlir_ops.py index 217129e8..f1fb4186 100644 --- a/PyTorchSimFrontend/mlir/mlir_ops.py +++ b/PyTorchSimFrontend/mlir/mlir_ops.py @@ -1135,11 +1135,19 @@ def extract_strided_slice(operand, target_size, offsets=None, sizes=None, stride @staticmethod def vlane_offset(operand1, operand2, *args, **kwargs): + # Emit a dedicated torchsim.vlane_idx op (generic form; torchsim is an + # unregistered dialect) instead of overloading arith.addi with a + # vlane_offset attribute. A Python out-of-line pass lowers it to + # (vcix.v.i per-lane index * offset); see + # PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py. tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - opcode = f'arith.add{ret_type[0]}' - op_str = f'{opcode} %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + offset = kwargs.get("attributes", {}).get("vlane_offset", 0) + op_str = '"torchsim.vlane_idx"()' + func_type = f'() -> {shape}' + return format_mlir_op(op_str, func_type, + attributes={"vlane_offset": f"{offset} : i64"}, + comment=kwargs.get("comment")), [tile_size, ret_type] @staticmethod def multi_reduction(acc, init, vec_size, red_size, red_shape, red_type, type_name, *args, **kwargs): diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 22d1011b..41ec61af 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -249,6 +249,19 @@ def codegen_node(self, _node): nodes, key=lambda x: int(x.is_reduction()) ).group + # axis-split: linearize compatible floor/mod radices at the scheduling layer. + from . import axis_split + plan = axis_split.find_split_plan(nodes) + if plan: + for _n in nodes: + if getattr(_n, "_body", None) is None: + continue + _body, _ranges = axis_split.build_split_body(_n, plan) + _n._sizes, _n._body, _n.group = _ranges, _body, (_n.get_device(), self.group_fn(_ranges)) + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + # Note: We assume that there is at least one loop in the nodes # But, inductor simplifies the group, there could be no loop # In that case, we add dummy loop(size=1) to the group @@ -353,3 +366,9 @@ def get_order(n): if origins: _, _, last = max(origins) V.graph.wrapper_code.enter_context(last) + + +# Install the graph-copy (incompatible-radix relayout) lowering hook once at import. +# See graph_copy.py. +from . import graph_copy as _graph_copy +_graph_copy.install() diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index c8fc036f..529a49b5 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -952,18 +952,9 @@ def generate_dma_code(): zero_cse = self.get_const_cse(0, "index") sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) - if subtile_size: - attribute = mlir_common.format_dma_op_attributes( - _dram_stride, - sram_strides, - int(padding), - subtile_size=subtile_size, - async_type=int(async_type) if async_type is not None else None, - ) - else: - attribute = mlir_common.format_dma_op_attributes(_dram_stride, sram_strides, int(padding)) - code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) + code = self.emit_transfer(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, _dram_stride, sram_strides, int(padding), + subtile_size=subtile_size if subtile_size else None, async_type=async_type) local_code.writeline(code) return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() @@ -1030,9 +1021,8 @@ def load_epilogue(self, name: str, index: sympy.Expr): # Allocate sram buffer dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index) - attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, 0) - code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) + code = self.emit_transfer("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, dram_stride, tile_stride, 0) self.cse.generate(self.dma_loads, code, assignment = False) self.buffer_names[name] = sram_var else: @@ -1098,9 +1088,8 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): ops._store(value, sram_var, compute_index_var, tile_shape, buffer_name=buffer_name) # Generate DMA instruction - attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, 0) - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) + code = self.emit_transfer("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, dram_stride, tile_stride, 0) self.dma_stores.writeline(DeferredLine(name, code)) def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): @@ -1249,9 +1238,8 @@ def store_reduction_epilogue(self, name, index, value): # MVOUT Encoding # Generate DMA instruction - attribute = mlir_common.format_dma_op_attributes(dram_stride, final_tile_stride, 0) - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, final_tile_shape, attribute) + code = self.emit_transfer("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, final_tile_shape, dram_stride, final_tile_stride, 0) self.reductions_suffix.writeline(DeferredLine(name, code)) def set_tile_size(self, template_fusion_info, prologue=False): diff --git a/PyTorchSimFrontend/mlir/passes/__init__.py b/PyTorchSimFrontend/mlir/passes/__init__.py new file mode 100644 index 00000000..82cadc2f --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/__init__.py @@ -0,0 +1,86 @@ +"""Python out-of-line MLIR passes run on each kernel .mlir before mlir-opt. + +MLIR's PassManager only schedules *registered C++ passes*, not arbitrary Python +functions, so imperative Python rewrites are orchestrated here instead. The flow +is Module-centric: parse the .mlir once, run each registered pass on the shared +Module, print once. A text marker check skips parsing entirely when no pass's +target op is present (the common case). + +To add a pass, create a module exposing MARKERS (tuple of op-name strings) and +run(module) (mutates the Module in place), and append it to PASSES below. +""" +def _ensure_mlir_bindings_on_path(): + """Make `import mlir` work even when PYTHONPATH is not set, by deriving the + bindings location from TORCHSIM_LLVM_PATH (e.g. /riscv-llvm/bin -> + /riscv-llvm/python_packages/mlir_core). The container sets PYTHONPATH, but + plain local runs may not.""" + try: + import mlir.ir # noqa: F401 + return + except ModuleNotFoundError: + pass + import os + import sys + from PyTorchSimFrontend import extension_config + llvm_path = (extension_config.CONFIG_TORCHSIM_LLVM_PATH or "").rstrip("/") + cand = os.path.join(os.path.dirname(llvm_path), "python_packages", "mlir_core") + if os.path.isdir(cand) and cand not in sys.path: + sys.path.insert(0, cand) + + +_ensure_mlir_bindings_on_path() + +from . import lower_vlane_idx +from . import decompose_transfer +from . import dma_fine_grained +from . import lower_to_vcix +from .lower_to_llvm import run_standard_lowering # noqa: F401 (re-exported) +from .build_tog import run_tog # noqa: F401 (re-exported; replaces C++ test-tile-operation-graph) +from .dma_fine_grained import run_fine_grained # noqa: F401 (re-exported; standalone/CLI) +from .lower_to_vcix import run_to_vcix # noqa: F401 (re-exported; standalone/CLI) + +# Module rewrite passes around the one remaining mlir-opt pass (-test-loop-padding). +# Each exposes MARKERS + run(module, **opts); run_module_passes parses once per phase. +# decompose_transfer first: togsim.transfer -> memref.dma_start (downstream expects it). +PRE_OPT_PASSES = [ + decompose_transfer, + lower_vlane_idx, +] +# fine-grained first: splits the matmul DMAs that the vcix lowering then reads. +POST_OPT_PASSES = [ + dma_fine_grained, + lower_to_vcix, +] + + +def run_module_passes(in_path, out_path, passes, **opts): + """Parse `in_path` once, run each marker-matched pass on the shared Module in + order, print once to `out_path` (in place if equal). `opts` forwarded to each + run(module, **opts). Returns True if any pass ran.""" + with open(in_path) as f: + text = f.read() + + active = [p for p in passes if any(mk in text for mk in p.MARKERS)] + if not active: + if out_path != in_path: + import shutil + shutil.copyfile(in_path, out_path) + return False + + from mlir.ir import Context, Module, Location + ctx = Context() + ctx.allow_unregistered_dialects = True + with ctx, Location.unknown(): + module = Module.parse(text) + for p in active: + p.run(module, **opts) + out = str(module) + + with open(out_path, "w") as f: + f.write(out) + return True + + +def run_python_passes(mlir_path, vectorlane=128): + """Run the pre-mlir-opt Module passes (PRE_OPT_PASSES) on `mlir_path`, in place.""" + return run_module_passes(mlir_path, mlir_path, PRE_OPT_PASSES, vectorlane=vectorlane) diff --git a/PyTorchSimFrontend/mlir/passes/build_tog.py b/PyTorchSimFrontend/mlir/passes/build_tog.py new file mode 100644 index 00000000..ae515010 --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/build_tog.py @@ -0,0 +1,1139 @@ +"""Python port of the C++ `-test-tile-operation-graph` analysis pass. + +Reads a post-vcix MLIR file, walks `func.func @kernel`, and prints the Tile +Operation Graph (TOG) as a Python `graph = { id: {dict}, ... }` literal to +stdout. The output is byte-exact with the upstream C++ pass +(mlir/test/lib/Analysis/TestTileOperationGraph.cpp + the display() methods in +mlir/include/mlir/Analysis/TileOperationGraph.h). + +Usage: + python3 build_tog.py + +Requires the MLIR Python bindings on PYTHONPATH +(/riscv-llvm/python_packages/mlir_core). +""" + +import os +import sys + +# Allow running standalone without PYTHONPATH set. +_DEFAULT_BINDINGS = "/riscv-llvm/python_packages/mlir_core" +if os.path.isdir(_DEFAULT_BINDINGS) and _DEFAULT_BINDINGS not in sys.path: + sys.path.insert(0, _DEFAULT_BINDINGS) + +import mlir.ir as ir # noqa: E402 + + +# --------------------------------------------------------------------------- +# Node kinds / compute types (mirror TileOperationGraph.h enums). +# --------------------------------------------------------------------------- +BASE_NODE = 0 +COMPUTE_NODE = 1 +LOOP_NODE = 2 +DMA_NODE = 3 +DMA_WAIT_NODE = 4 + +VECTOR_COMPUTE = 0 +MATMUL_COMPUTE = 1 +MATMUL_PRELOAD = 2 + +# Printed compute_type string for each compute-node type (mirrors the C++ +# inline-asm marker emission in TestTileOperationGraph.cpp). +_COMPUTE_TYPE_NAME = { + VECTOR_COMPUTE: "VectorCompute", + MATMUL_COMPUTE: "MatmulCompute", + MATMUL_PRELOAD: "MatmulPreload", +} + + +# Unique-id counter (TOGNode::unique_id), incremented at construction. The C++ +# pass keeps this as a process global, but each kernel runs in its own mlir-opt +# PROCESS so the counters never collide. Here the pass runs in-process and +# extension_codecache compiles kernels CONCURRENTLY in a thread pool, so a single +# module global would race across threads (one kernel's reset/increments interleave +# with another's, leaving nodes -- including the root -- with wrong ids). Keep the +# counter thread-local so each compile thread has an isolated counter, mirroring +# the per-process isolation of the C++ pass. +import threading # noqa: E402 + +_ids = threading.local() + + +def _new_id(): + i = getattr(_ids, "counter", 0) + _ids.counter = i + 1 + return i + + +def _reset_ids(): + _ids.counter = 0 + + +# --------------------------------------------------------------------------- +# Node classes with display() mirroring the C++ header byte-for-byte. +# --------------------------------------------------------------------------- +class TOGNode: + kind = BASE_NODE + + def __init__(self, name, node_id=None): + self.node_id = _new_id() if node_id is None else node_id + self.node_name = name + self.parents = [] + self.children = [] + self.op = None + + def get_last_child(self): + return self.children[-1] if self.children else None + + def add_child(self, c): + self.children.append(c) + + def add_parent(self, p): + self.parents.append(p) + + @staticmethod + def _print_list(name, vec, out): + out.append('\t"%s": [' % name) + out.append(",".join(str(v) for v in vec)) + out.append("]") + + @staticmethod + def _print_str_list(name, vec, out): + out.append('\t"%s": [' % name) + out.append(",".join('"%s"' % v for v in vec)) + out.append("]") + + def _print_node_list(self, name, vec, out): + out.append('\t"%s": [' % name) + out.append(",".join(str(n.node_id) for n in vec)) + out.append("]") + + def display(self, out): + out.append("%d : {\n" % self.node_id) + out.append('\t"node_id": %d,\n' % self.node_id) + out.append('\t"node_name": "%s",\n' % self.node_name) + out.append('\t"node_type": %d,\n' % self.kind) + self._print_node_list("parents", self.parents, out) + out.append(",\n") + self._print_node_list("children", self.children, out) + if self.kind == BASE_NODE: + out.append("\n}\n") + + def bfs(self, out): + if self.node_name == "root": + out.append("graph = {\n") + self.display(out) + out.append(",") + for child in self.children: + child.bfs(out) + if self.node_name == "root": + out.append("}") + + +class TOGLoopNode(TOGNode): + kind = LOOP_NODE + + def __init__(self, name, idx, start, end, step, loop_type): + super().__init__(name) + self.loop_idx = idx + self.loop_start = start + self.loop_end = end + self.loop_step = step + self.loop_type = loop_type + + def display(self, out): + TOGNode.display(self, out) + out.append(",\n") + out.append('\t"loop_index": "%s",\n' % self.loop_idx) + out.append('\t"loop_start": %d,\n' % self.loop_start) + out.append('\t"loop_end": %d,\n' % self.loop_end) + out.append('\t"loop_step": %d,\n' % self.loop_step) + out.append('\t"loop_type": "%s"\n' % self.loop_type) + out.append("}\n") + + +class TOGDMANode(TOGNode): + kind = DMA_NODE + + def __init__(self, name, addr, tile_size, tile_stride, elem_size, is_write, + is_async, tag_idx_list, tag_stride_list, loop_idx_list, + loop_stride_list, indirect_mode): + super().__init__(name) + self.base_addr = addr + self.tile_size = tile_size + self.tile_stride = tile_stride + self.element_size = elem_size + self.is_write = is_write + self.is_async = is_async + self.tag_idx_list = tag_idx_list + self.tag_stride_list = tag_stride_list + self.loop_idx_list = loop_idx_list + self.loop_stride_list = loop_stride_list + self.indirect_mode = indirect_mode + + def display(self, out): + TOGNode.display(self, out) + out.append(",\n") + out.append('\t"is_write": %d,\n' % int(self.is_write)) + out.append('\t"is_async": %d,\n' % int(self.is_async)) + out.append('\t"base_address": "%s",\n' % self.base_addr) + out.append('\t"indirect_mode": %d,\n' % int(self.indirect_mode)) + self._print_list("tile_size", self.tile_size, out) + out.append(",\n") + self._print_list("tile_stride", self.tile_stride, out) + out.append(",\n") + self._print_str_list("tag_idx_list", self.tag_idx_list, out) + out.append(",\n") + self._print_list("tag_stride_list", self.tag_stride_list, out) + out.append(",\n") + self._print_str_list("loop_idx_list", self.loop_idx_list, out) + out.append(",\n") + self._print_list("loop_stride_list", self.loop_stride_list, out) + out.append(",\n") + out.append('\t"element_size": %d\n' % self.element_size) + out.append("}\n") + + +class TOGDMAWaitNode(TOGNode): + kind = DMA_WAIT_NODE + + def __init__(self, name, idx_list, tag_stride_list, tag_divider_list, addr): + super().__init__(name) + self.tag_idx_list = idx_list + self.tag_stride_list = tag_stride_list + self.tag_divider_list = tag_divider_list + self.base_addr = addr + + def display(self, out): + TOGNode.display(self, out) + out.append(",\n") + out.append('\t"base_address": "%s",\n' % self.base_addr) + self._print_str_list("tag_idx_list", self.tag_idx_list, out) + out.append(",\n") + self._print_list("tag_stride_list", self.tag_stride_list, out) + out.append(",\n") + self._print_list("tag_divider_list", self.tag_divider_list, out) + out.append("}\n") + + +class TOGComputeNode(TOGNode): + kind = COMPUTE_NODE + + def __init__(self, name, cycle, ctype): + super().__init__(name) + self.compute_cycle = cycle + self.compute_type = ctype + self.operations = [] + + def display(self, out): + TOGNode.display(self, out) + out.append(",\n") + out.append('\t"compute_cycle": %d,\n' % self.compute_cycle) + out.append('\t"compute_type": %d\n' % self.compute_type) + out.append("}\n") + + +# --------------------------------------------------------------------------- +# MLIR helpers. +# --------------------------------------------------------------------------- +def _op_name(value_or_op): + return value_or_op.operation.name + + +def _is_block_arg(value): + return ir.BlockArgument.isinstance(value) + + +def _defining_op_name(value): + """Return the op name that defines `value`, or None if it is a block arg.""" + if _is_block_arg(value): + return None + return value.owner.name + + +def _const_index_value(value): + """If `value` is an arith.constant of index type, return its int value.""" + if _is_block_arg(value): + return None + owner = value.owner + if owner.name != "arith.constant": + return None + attr = owner.attributes["value"] + try: + return ir.IntegerAttr(attr).value + except Exception: + return None + + +def _value_key(value): + """Identity key for a Value (induction vars map -> loop name).""" + if _is_block_arg(value): + ba = ir.BlockArgument(value) + return ("ba", ba.owner, ba.arg_number) + return ("res", value.owner, 0) + + +def _memref_space(memref_type): + mt = ir.MemRefType(memref_type) + sp = mt.memory_space + if sp is None: + return 0 + return ir.IntegerAttr(sp).value + + +def _int_array_attr(op, key): + if key not in op.attributes: + return None + arr = ir.ArrayAttr(op.attributes[key]) + return [ir.IntegerAttr(x).value for x in arr] + + +# ----- affine expr utilities (mirror the C++ free functions) --------------- +def _is_function_of_dim(expr, dim): + if ir.AffineDimExpr.isinstance(expr): + return ir.AffineDimExpr(expr).position == dim + if ir.AffineBinaryExpr.isinstance(expr): + b = ir.AffineBinaryExpr(expr) + return _is_function_of_dim(b.lhs, dim) or _is_function_of_dim(b.rhs, dim) + return False + + +def _get_coefficient_from_dim(expr, dim): + """Port of getCoefficientFromDim: coefficient of `dim` in expr, or -1.""" + if ir.AffineMulExpr.isinstance(expr): + b = ir.AffineMulExpr(expr) + lhs, rhs = b.lhs, b.rhs + if ir.AffineConstantExpr.isinstance(lhs) and ir.AffineDimExpr.isinstance(rhs): + if _is_function_of_dim(rhs, dim): + return ir.AffineConstantExpr(lhs).value + elif ir.AffineConstantExpr.isinstance(rhs) and ir.AffineDimExpr.isinstance(lhs): + if _is_function_of_dim(lhs, dim): + return ir.AffineConstantExpr(rhs).value + return -1 + if ir.AffineAddExpr.isinstance(expr): + b = ir.AffineAddExpr(expr) + r = _get_coefficient_from_dim(b.lhs, dim) + if r != -1: + return r + r = _get_coefficient_from_dim(b.rhs, dim) + if r != -1: + return r + return -1 + if ir.AffineDimExpr.isinstance(expr): + if _is_function_of_dim(expr, dim): + return 1 + return -1 + + +def _collect_coefficients(expr): + """Port of collectCoefficientsFromAffineExpr.""" + out = [] + + def rec(e): + if ir.AffineMulExpr.isinstance(e): + b = ir.AffineMulExpr(e) + lhs, rhs = b.lhs, b.rhs + if ir.AffineConstantExpr.isinstance(lhs): + if ir.AffineDimExpr.isinstance(rhs): + out.append(ir.AffineConstantExpr(lhs).value) + elif ir.AffineFloorDivExpr.isinstance(rhs): + out.append(ir.AffineConstantExpr(lhs).value) + elif ir.AffineConstantExpr.isinstance(rhs): + if ir.AffineDimExpr.isinstance(lhs): + out.append(ir.AffineConstantExpr(rhs).value) + elif ir.AffineFloorDivExpr.isinstance(lhs): + out.append(ir.AffineConstantExpr(rhs).value) + elif ir.AffineAddExpr.isinstance(e): + b = ir.AffineAddExpr(e) + rec(b.lhs) + rec(b.rhs) + elif ir.AffineFloorDivExpr.isinstance(e): + b = ir.AffineFloorDivExpr(e) + if ir.AffineConstantExpr.isinstance(b.rhs): + out.append(ir.AffineConstantExpr(b.rhs).value) + elif ir.AffineDimExpr.isinstance(e): + out.append(1) + + rec(expr) + return out + + +def _collect_dividers(expr): + """Port of collectDividersFromAffineExpr.""" + out = [] + + def rec(e): + if ir.AffineMulExpr.isinstance(e): + b = ir.AffineMulExpr(e) + lhs, rhs = b.lhs, b.rhs + if ir.AffineConstantExpr.isinstance(lhs): + if ir.AffineDimExpr.isinstance(rhs): + out.append(1) + elif ir.AffineFloorDivExpr.isinstance(rhs): + rec(rhs) + elif ir.AffineConstantExpr.isinstance(rhs): + if ir.AffineDimExpr.isinstance(lhs): + out.append(1) + elif ir.AffineFloorDivExpr.isinstance(lhs): + rec(lhs) + elif ir.AffineAddExpr.isinstance(e): + b = ir.AffineAddExpr(e) + rec(b.lhs) + rec(b.rhs) + elif ir.AffineFloorDivExpr.isinstance(e): + b = ir.AffineFloorDivExpr(e) + if ir.AffineConstantExpr.isinstance(b.rhs): + out.append(ir.AffineConstantExpr(b.rhs).value) + elif ir.AffineDimExpr.isinstance(e): + out.append(1) + + rec(expr) + return out + + +# --------------------------------------------------------------------------- +# The builder (mirrors TestTileOperationGraph members). +# --------------------------------------------------------------------------- +SKIP_OPS = { + "affine.yield", "affine.apply", "memref.get_global", "arith.constant", + "memref.alloc", "memref.reinterpret_cast", +} + + +class MatmulFsm: + Idle = 0 + Preload = 1 + MMPush = 2 + MMVpop = 3 + MMEpilogue = 4 + + +class TogBuilder: + def __init__(self): + self.nr_loop = 0 + self.loop_var_name = {} # value-identity-key -> loop name + self.compute_nodes = [] + self.loop_nodes = [] + self._reset_matmul_fsm() + + # ---- matmul FSM ---- + def _reset_matmul_fsm(self): + self.matmul_fsm = MatmulFsm.Idle + self.mm_iv_op0_count = 0 + self.mm_expected_vpop = 0 + self.mm_vpop_seen = 0 + self.mm_expected_transfer_writes = 0 + self.mm_transfer_writes_seen = 0 + self.current_preload_node = None + self.current_matmul_compute_node = None + + def _finish_matmul_block(self): + self.matmul_fsm = MatmulFsm.Idle + self.current_matmul_compute_node = None + self.mm_iv_op0_count = 0 + self.mm_expected_vpop = 0 + self.mm_vpop_seen = 0 + self.mm_expected_transfer_writes = 0 + self.mm_transfer_writes_seen = 0 + + def _enter_epilogue_or_finish(self): + if self.mm_expected_transfer_writes == 0: + self._finish_matmul_block() + else: + self.matmul_fsm = MatmulFsm.MMEpilogue + + def _append_mm_with_write_count(self, op): + cn = self.current_matmul_compute_node + if cn is None: + return + cn.operations.append(op) + if _op_name(op) != "vector.transfer_write": + return + if self.mm_expected_transfer_writes == 0: + return + self.mm_transfer_writes_seen += 1 + if self.mm_transfer_writes_seen >= self.mm_expected_transfer_writes: + self._finish_matmul_block() + + def _steal_leading_transfer_read(self, parent): + prepend = [] + if parent is None: + return prepend + last = parent.get_last_child() + if last is None or not isinstance(last, TOGComputeNode): + return prepend + if last.compute_type != VECTOR_COMPUTE: + return prepend + ops = last.operations + idx = None + for i, o in enumerate(ops): + if _op_name(o) == "vector.transfer_read": + idx = i + break + if idx is None: + return prepend + prepend.append(ops[idx]) + del ops[idx] + if not ops: + if parent.children and parent.children[-1] is last: + parent.children.pop() + else: + parent.children = [c for c in parent.children if c is not last] + if last in self.compute_nodes: + self.compute_nodes.remove(last) + # node deleted; its id stays consumed (matches C++ unique_id). + return prepend + + def _append_vector_compute(self, parent, op): + if parent is None: + return + last = parent.get_last_child() + if (isinstance(last, TOGComputeNode) + and last.compute_type == VECTOR_COMPUTE): + last.operations.append(op) + return + cn = TOGComputeNode("ComputeNode", 0, VECTOR_COMPUTE) + cn.op = op + cn.operations.append(op) + parent.add_child(cn) + cn.add_parent(parent) + self.compute_nodes.append(cn) + + # ---- vcix classification ---- + @staticmethod + def _vcix_iv_opcode(op): + if _op_name(op) != "vcix.iv": + return None + if "opcode" not in op.operation.attributes: + return None + return ir.IntegerAttr(op.operation.attributes["opcode"]).value + + @staticmethod + def _vcix_vi_opcode(op): + if _op_name(op) != "vcix.v.i": + return None + if "opcode" not in op.operation.attributes: + return None + return ir.IntegerAttr(op.operation.attributes["opcode"]).value + + # ---- affine.for bounds ---- + @staticmethod + def _affine_for_bounds(forop): + oper = forop.operation + step = ir.IntegerAttr(oper.attributes["step"]).value + lb_map = ir.AffineMapAttr(oper.attributes["lowerBoundMap"]).value + ub_map = ir.AffineMapAttr(oper.attributes["upperBoundMap"]).value + # operandSegmentSizes: [lb operands, ub operands, iter operands] + seg_attr = oper.attributes["operandSegmentSizes"] + seg = [seg_attr[i] for i in range(len(seg_attr))] + operands = list(oper.operands) + + def single_const(m): + return len(m.results) == 1 and ir.AffineConstantExpr.isinstance(m.results[0]) + + if single_const(lb_map): + start = ir.AffineConstantExpr(lb_map.results[0]).value + else: + start = _const_index_value(operands[0]) + n_lb = seg[0] if seg else 0 + if single_const(ub_map): + end = ir.AffineConstantExpr(ub_map.results[0]).value + else: + end = _const_index_value(operands[n_lb]) + return start, end, step + + # ---- DRAM index processing ---- + def _process_dram_indices(self, value, loop_index_list, indirect_box): + if not _is_block_arg(value) and value.owner.name == "affine.apply": + apply_op = value.owner + amap = ir.AffineMapAttr(apply_op.attributes["map"]).value + operands = list(apply_op.operands) + if "indirect_access" in apply_op.attributes: + indirect_box[0] = True + for op_v in operands: + if _is_block_arg(op_v): + expr = amap.results[0] + index_pos = operands.index(op_v) + coeff = _get_coefficient_from_dim(expr, index_pos) + key = self.loop_var_name[_value_key(op_v)] + loop_index_list.append((key, coeff)) + else: + self._process_dram_indices(op_v, loop_index_list, indirect_box) + elif _is_block_arg(value): + key = self.loop_var_name[_value_key(value)] + loop_index_list.append((key, 1)) + else: + c = _const_index_value(value) + if c is not None: + loop_index_list.append(("c" + str(c), c)) + + # ---- main recursion ---- + def print_operation(self, op, node): + name = _op_name(op) + if name in SKIP_OPS: + return + + if name == "affine.for": + oper = op.operation + attrs = oper.attributes + loop_type = "" + + def bool_true(k): + return k in attrs and ir.BoolAttr(attrs[k]).value + + if bool_true("outer_loop"): + loop_type = "outer_loop" + elif bool_true("accumulation_loop"): + loop_type = "accumulation_loop" + elif bool_true("inner_loop"): + loop_type = "inner_loop" + + if (bool_true("outer_loop") or bool_true("accumulation_loop") + or bool_true("inner_loop")): + start, end, step = self._affine_for_bounds(op) + loop_index = "loop_arg%03d" % self.nr_loop + self.nr_loop += 1 + iter_var = oper.regions[0].blocks[0].arguments[0] + self.loop_var_name[_value_key(iter_var)] = loop_index + loop_node = TOGLoopNode("loopNode", loop_index, start, end, step, + loop_type) + loop_node.op = op + self.loop_nodes.append(loop_node) + if node is not None: + node.add_child(loop_node) + loop_node.add_parent(node) + for region in oper.regions: + for block in region.blocks: + for inner in block.operations: + self.print_operation(inner, loop_node) + return + + if name == "memref.dma_start": + self._handle_dma_start(op, node) + return + + if name == "memref.dma_wait": + self._handle_dma_wait(op, node) + return + + if node is None: + return + + self._handle_compute(op, node) + + # ---- compute / matmul FSM dispatch ---- + def _handle_compute(self, op, node): + if _op_name(op) == "vcix.iv": + if (self.matmul_fsm == MatmulFsm.MMEpilogue + and self.current_matmul_compute_node): + self._finish_matmul_block() + opc = self._vcix_iv_opcode(op) + if opc is not None: + if opc == 1: + if self.matmul_fsm == MatmulFsm.Idle: + prepend = self._steal_leading_transfer_read(node) + cn = TOGComputeNode("ComputeNode", 0, MATMUL_PRELOAD) + for o in prepend: + cn.operations.append(o) + cn.operations.append(op) + cn.op = cn.operations[0] + node.add_child(cn) + cn.add_parent(node) + self.compute_nodes.append(cn) + self.current_preload_node = cn + self.matmul_fsm = MatmulFsm.Preload + elif (self.matmul_fsm == MatmulFsm.Preload + and self.current_preload_node): + self.current_preload_node.operations.append(op) + else: + raise RuntimeError("vcix.iv opcode=1 invalid state") + return + if opc == 0: + if (self.matmul_fsm == MatmulFsm.Preload + and self.current_preload_node): + cn = TOGComputeNode("ComputeNode", 0, MATMUL_COMPUTE) + cn.op = op + cn.operations.append(op) + node.add_child(cn) + cn.add_parent(node) + self.compute_nodes.append(cn) + self.current_preload_node = None + self.current_matmul_compute_node = cn + self.matmul_fsm = MatmulFsm.MMPush + self.mm_iv_op0_count = 1 + return + if (self.matmul_fsm == MatmulFsm.MMPush + and self.current_matmul_compute_node): + self.current_matmul_compute_node.operations.append(op) + self.mm_iv_op0_count += 1 + return + raise RuntimeError("vcix.iv opcode=0 invalid state") + if (self.matmul_fsm == MatmulFsm.MMEpilogue + and self.current_matmul_compute_node): + self._append_mm_with_write_count(op) + return + self._append_vector_compute(node, op) + return + + if _op_name(op) == "vcix.v.i": + viopc = self._vcix_vi_opcode(op) + if viopc is not None and viopc == 2: + if (self.matmul_fsm == MatmulFsm.MMPush + and self.current_matmul_compute_node): + self.mm_expected_vpop = self.mm_iv_op0_count + self.mm_expected_transfer_writes = self.mm_expected_vpop + self.mm_transfer_writes_seen = 0 + self.matmul_fsm = MatmulFsm.MMVpop + self.mm_vpop_seen = 1 + self.current_matmul_compute_node.operations.append(op) + if self.mm_vpop_seen >= self.mm_expected_vpop: + self._enter_epilogue_or_finish() + return + if (self.matmul_fsm == MatmulFsm.MMVpop + and self.current_matmul_compute_node): + self.current_matmul_compute_node.operations.append(op) + self.mm_vpop_seen += 1 + if self.mm_vpop_seen >= self.mm_expected_vpop: + self._enter_epilogue_or_finish() + return + if (self.matmul_fsm == MatmulFsm.MMEpilogue + and self.current_matmul_compute_node): + self._append_mm_with_write_count(op) + return + self._append_vector_compute(node, op) + return + + if self.matmul_fsm == MatmulFsm.Preload and self.current_preload_node: + self.current_preload_node.operations.append(op) + return + if (self.matmul_fsm == MatmulFsm.MMEpilogue + and self.current_matmul_compute_node): + self._append_mm_with_write_count(op) + return + if (self.matmul_fsm in (MatmulFsm.MMPush, MatmulFsm.MMVpop) + and self.current_matmul_compute_node): + self._append_mm_with_write_count(op) + return + self._append_vector_compute(node, op) + + # ---- dma_start ---- + def _dma_start_fields(self, op): + """Decode memref.dma_start operands by memref ranks. + + Layout: src[srcIdx], dst[dstIdx], numElements, tag[tagIdx], stride, + numElementsPerStride. + """ + operands = list(op.operands) + i = 0 + src = operands[i] + src_type = src.type + src_rank = len(ir.MemRefType(src_type).shape) + i += 1 + src_indices = operands[i:i + src_rank] + i += src_rank + dst = operands[i] + dst_type = dst.type + dst_rank = len(ir.MemRefType(dst_type).shape) + i += 1 + dst_indices = operands[i:i + dst_rank] + i += dst_rank + i += 1 # numElements + tag = operands[i] + tag_rank = len(ir.MemRefType(tag.type).shape) + i += 1 + tag_indices = operands[i:i + tag_rank] + return { + "src": src, "src_type": src_type, "src_indices": src_indices, + "dst": dst, "dst_type": dst_type, "dst_indices": dst_indices, + "tag": tag, "tag_indices": tag_indices, + } + + def _handle_dma_start(self, op, node): + oper = op.operation + f = self._dma_start_fields(op) + dma_async = bool(self._get_async(oper)) + + dst_space = _memref_space(f["dst_type"]) + src_space = _memref_space(f["src_type"]) + + subtile = _int_array_attr(oper, "subtile_size") + + if dst_space == 0 and src_space == 1: + is_write = True + tile_type = f["src_type"] + dram_memref = f["dst"] + dram_indices = f["dst_indices"] + elif dst_space == 1 and src_space == 0: + is_write = False + tile_type = f["dst_type"] + dram_memref = f["src"] + dram_indices = f["src_indices"] + else: + raise RuntimeError("Unexpected memory space") + + tile_mt = ir.MemRefType(tile_type) + tile_shape = subtile if subtile else list(tile_mt.shape) + tile_size = [int(x) for x in tile_shape] + + tile_stride = [] + ds = _int_array_attr(oper, "dram_stride") + if ds: + tile_stride = list(ds) + + loop_index_map = [] + indirect_box = [False] + self._process_dram_indices(dram_indices[0], loop_index_map, indirect_box) + + # std::map => dedup + lexicographic sort by key. + reordered = {} + for key, stride in loop_index_map: + if key not in reordered: + reordered[key] = stride + loop_idx_list = [] + loop_stride_list = [] + for key in sorted(reordered.keys()): + loop_idx_list.append(key) + loop_stride_list.append(reordered[key]) + + # base address + address = "arg" + if _is_block_arg(dram_memref): + address += str(ir.BlockArgument(dram_memref).arg_number) + + # element size + et = tile_mt.element_type + if ir.IntegerType.isinstance(et): + element_size = ir.IntegerType(et).width + elif ir.FloatType.isinstance(et): + element_size = ir.FloatType(et).width + else: + raise RuntimeError("Unsupported element type") + + # tag indices + tag_index_list = [] + tag_stride_list = [] + for tag_idx in f["tag_indices"]: + c = _const_index_value(tag_idx) + if c is not None: + tag_index_list.append(str(c)) + continue + if not _is_block_arg(tag_idx) and tag_idx.owner.name == "affine.apply": + apply_op = tag_idx.owner + for operand in apply_op.operands: + if _is_block_arg(operand): + tag_index_list.append( + self.loop_var_name[_value_key(operand)]) + elif (not _is_block_arg(operand) + and operand.owner.name == "affine.apply"): + nested = operand.owner + for nop in nested.operands: + if _is_block_arg(nop): + tag_index_list.append( + self.loop_var_name[_value_key(nop)]) + else: + nc = _const_index_value(nop) + if nc is not None: + tag_index_list.append(str(nc)) + else: + oc = _const_index_value(operand) + if oc is not None: + tag_index_list.append(str(oc)) + amap = ir.AffineMapAttr(apply_op.attributes["map"]).value + tag_stride_list = _collect_coefficients(amap.results[0]) + + if len(tag_index_list) == 0: + tag_index_list.append("0") + if len(tag_stride_list) == 0: + tag_stride_list.append(1) + + dma_node = TOGDMANode("DMANode", address, tile_size, tile_stride, + element_size, is_write, dma_async, tag_index_list, + tag_stride_list, loop_idx_list, loop_stride_list, + indirect_box[0]) + dma_node.op = op + node.add_child(dma_node) + dma_node.add_parent(node) + + # ---- dma_wait ---- + def _handle_dma_wait(self, op, node): + oper = op.operation + operands = list(oper.operands) + tag = operands[0] + tag_rank = len(ir.MemRefType(tag.type).shape) + tag_indices = operands[1:1 + tag_rank] + + tag_index_list = [] + tag_stride_list = [] + tag_divider_list = [] + for tag_idx in tag_indices: + if not _is_block_arg(tag_idx) and tag_idx.owner.name == "affine.apply": + apply_op = tag_idx.owner + for operand in apply_op.operands: + if _is_block_arg(operand): + tag_index_list.append( + self.loop_var_name[_value_key(operand)]) + else: + c = _const_index_value(operand) + if c is not None: + tag_index_list.append(str(c)) + amap = ir.AffineMapAttr(apply_op.attributes["map"]).value + tag_stride_list = _collect_coefficients(amap.results[0]) + tag_divider_list = _collect_dividers(amap.results[0]) + + # base address: scan users of tag memref for a dma_start. + address = "arg" + for use in tag.uses: + user = use.owner + if user.name == "memref.dma_start": + f = self._dma_start_fields(user) + dst_space = _memref_space(f["dst_type"]) + src_space = _memref_space(f["src_type"]) + dram_memref = None + if dst_space == 0 and src_space == 1: + dram_memref = f["dst"] + elif dst_space == 1 and src_space == 0: + dram_memref = f["src"] + if dram_memref is not None and _is_block_arg(dram_memref): + address += str(ir.BlockArgument(dram_memref).arg_number) + + if len(tag_stride_list) == 0: + tag_stride_list.append(1) + tag_divider_list.append(1) + + wait_node = TOGDMAWaitNode("DMAWaitNode", tag_index_list, tag_stride_list, + tag_divider_list, address) + wait_node.op = op + node.add_child(wait_node) + wait_node.add_parent(node) + + @staticmethod + def _get_async(oper): + if "async" not in oper.attributes: + return 0 + attr = oper.attributes["async"] + try: + return ir.IntegerAttr(attr).value + except Exception: + pass + try: + return 1 if ir.BoolAttr(attr).value else 0 + except Exception: + return 1 + + +# --------------------------------------------------------------------------- +# Empty affine.for erasure (fixed point), mirroring eraseEmptyAffineForLoops. +# --------------------------------------------------------------------------- +def _is_empty_affine_for(op): + if op.operation.name != "affine.for": + return False + body = op.operation.regions[0].blocks[0] + ops = list(body.operations) + # body holds exactly the terminator (affine.yield) and nothing else. + if len(ops) != 1: + return False + return ops[0].operation.name == "affine.yield" + + +def _erase_empty_affine_for(func_op): + changed = True + while changed: + changed = False + to_erase = [] + + def walk(block): + for op in block.operations: + for region in op.operation.regions: + for b in region.blocks: + walk(b) + if _is_empty_affine_for(op): + to_erase.append(op) + + walk(func_op.regions[0].blocks[0]) + for op in to_erase: + op.operation.erase() + changed = True + + +# --------------------------------------------------------------------------- +# IR mutation (mirrors the side effects of -test-tile-operation-graph). +# --------------------------------------------------------------------------- +def _make_inline_asm(ctx, ip, asm_string, compute_type=None): + """Build an llvm.inline_asm compute marker identical to the C++ pass.""" + i64 = ir.IntegerType.get_signless(64) + attrs = { + "has_side_effects": ir.UnitAttr.get(), + "asm_dialect": ir.IntegerAttr.get(i64, 0), + "asm_string": ir.StringAttr.get(asm_string), + "constraints": ir.StringAttr.get("~{a0},~{memory}"), + } + if compute_type is not None: + attrs["compute_type"] = ir.StringAttr.get(compute_type) + return ir.Operation.create( + "llvm.inline_asm", results=[], operands=[], attributes=attrs, + loc=ir.Location.unknown(ctx), ip=ip) + + +def _containing_block(op): + """Return the Block that directly contains `op`.""" + parent = op.operation.parent + for region in parent.regions: + for block in region.blocks: + for sib in block.operations: + if sib.operation == op.operation: + return block + return None + + +def _next_sibling(op): + """Return the op immediately following `op` in its block, or None.""" + block = _containing_block(op) + siblings = list(block.operations) + for i, sib in enumerate(siblings): + if sib.operation == op.operation: + return siblings[i + 1] if i + 1 < len(siblings) else None + return None + + +_ASM_START = ".insn r CUSTOM_3, 0, 0x40, x0, x0, x0" +_ASM_END = ".insn r CUSTOM_3, 0, 0x41, x0, x0, x0" + + +def _rewrite_loop_steps(builder): + """Sample mode: rewrite each loop node's affine.for step to its end. + + Done after the graph is built so the printed graph keeps the original step. + """ + index_t = ir.IndexType.get() + for node in builder.loop_nodes: + forop = node.op.operation + new_step = ir.IntegerAttr.get(index_t, node.loop_end) + forop.attributes["step"] = new_step + + +def _insert_compute_markers(builder): + """Insert inline-asm start/end markers around each compute node's ops.""" + for node in builder.compute_nodes: + ops = node.operations + if not ops: + continue + front = ops[0] + back = ops[-1] + ctx = front.operation.context + ctype_name = _COMPUTE_TYPE_NAME[node.compute_type] + # End marker immediately after `back`: insert before back's next sibling + # (or append to the block when back is the last op). The bindings have no + # "insert after" point, and move_after on a detached op crashes. + after = _next_sibling(back) + if after is not None: + end_ip = ir.InsertionPoint(after) + else: + end_ip = ir.InsertionPoint(_containing_block(back)) + _make_inline_asm(ctx, end_ip, _ASM_END) + # Start marker immediately before `front`. Done after the end marker so + # that the single-op case (front == back) still brackets the op. + _make_inline_asm(ctx, ir.InsertionPoint(front), _ASM_START, ctype_name) + + +# --------------------------------------------------------------------------- +# Driver. +# --------------------------------------------------------------------------- +def _find_kernel(module): + for op in module.body.operations: + if op.operation.name != "func.func": + continue + if ir.StringAttr(op.operation.attributes["sym_name"]).value == "kernel": + return op + return None + + +def _build(module, builder): + """Build the graph and return its display string, populating `builder`.""" + func_op = _find_kernel(module) + if func_op is None: + return "" + + _erase_empty_affine_for(func_op) + + block = func_op.regions[0].blocks[0] + out = [] + for op in block.operations: + if op.operation.name != "affine.for": + continue + root = TOGNode("root") + builder._reset_matmul_fsm() + builder.print_operation(op, root) + root.bfs(out) + return "".join(out) + + +def build_tog_string(module): + # The C++ unique_id is a process global; each mlir-opt invocation starts at + # 0. When this runs in-process per kernel, reset so node ids match byte-exact. + _reset_ids() + return _build(module, TogBuilder()) + + +def build_tog_and_mutate(module, sample_mode=True): + """Build the TOG and apply the C++ pass IR mutation in place. + + Returns the graph display string (byte-exact with build_tog_string). The + caller is responsible for serializing the mutated module (str(module)). + """ + _reset_ids() + builder = TogBuilder() + result = _build(module, builder) + if sample_mode: + _rewrite_loop_steps(builder) + _insert_compute_markers(builder) + return result + + +def run_tog(in_path, tog_out_path, custom_out_path, sample_mode=True, vectorlane=128): + """In-process replacement for the C++ -test-tile-operation-graph pass. + + Reads the post-vcix MLIR at `in_path`, builds the TOG (written as the + `graph = {...}` Python literal to `tog_out_path`) and applies the pass IR + mutation (sample-mode step rewrite + compute markers), writing the mutated + module to `custom_out_path`. `vectorlane` is accepted for parity with the + C++ option; it does not affect the output. Requires the MLIR bindings. + """ + ctx = ir.Context() + ctx.allow_unregistered_dialects = True + with ctx: + module = ir.Module.parse(open(in_path).read(), ctx) + graph = build_tog_and_mutate(module, sample_mode=bool(sample_mode)) + with open(custom_out_path, "w") as fh: + fh.write(str(module)) + with open(tog_out_path, "w") as fh: + fh.write(graph) + + +def main(argv): + import argparse + + parser = argparse.ArgumentParser(prog="build_tog.py") + parser.add_argument("input") + parser.add_argument("--sample-mode", type=int, default=1, choices=(0, 1)) + parser.add_argument("--vectorlane", type=int, default=128) + parser.add_argument("--mutate-out", default=None) + args = parser.parse_args(argv[1:]) + + with open(args.input) as fh: + src = fh.read() + ctx = ir.Context() + ctx.allow_unregistered_dialects = True + with ctx: + module = ir.Module.parse(src, ctx) + if args.mutate_out is not None: + result = build_tog_and_mutate(module, sample_mode=bool(args.sample_mode)) + with open(args.mutate_out, "w") as fh: + fh.write(str(module)) + else: + result = build_tog_string(module) + sys.stdout.write(result) + return 0 + + +if __name__ == "__main__": + sys.exit(main(sys.argv)) diff --git a/PyTorchSimFrontend/mlir/passes/decompose_transfer.py b/PyTorchSimFrontend/mlir/passes/decompose_transfer.py new file mode 100644 index 00000000..c0e82b66 --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/decompose_transfer.py @@ -0,0 +1,297 @@ +"""Python out-of-line MLIR pass: decompose togsim.transfer -> <=4D memref.dma_start. + +A togsim.transfer carries a per-axis affine DMA whose descriptor rank may exceed +the 4D Gemmini limit. This pass lowers it to <=4D customized memref.dma_start +(see docs/dma-transfer-lowering.md): + + - drop unit (extent-1) tile dims: they contribute no descriptor axis; + - if the remaining (effective) rank <= 4 -> emit one customized + memref.dma_start, reusing the transfer's operands (fast path); + - if effective rank > 4 -> wrap the outer dims in an affine.for nest and emit + one <=4D memref.dma_start in the body, mirroring the C++ -dma-fine-grained + subtile loop. The slice DRAM/SRAM offsets are affine.apply over the loop vars; + the SRAM offset is the lane-banked physical offset (split-outer dims rescaled + by the lane coeff) delivered as the last SRAM index operand. + +It does NO floor/mod linearization (aligned split happens upstream at the +scheduling layer) and NO relayout (misaligned access is copy-inserted at the +graph level). A transfer whose access is not per-axis affine is a contract +violation -- but by construction codegen only emits affine transfers. + +togsim.transfer operands (see emit_transfer): + (dram, dram_idx, sram, sram_idx, tag, dma_type, vlane_split_axis, vlane_stride) +attrs: dma_kind ("MVIN"/"MVOUT"), dram_stride[], tile_stride[], padding. + +memref.dma_start (customized) operands: + src[idx], dst[idx], dma_type, tag[idx], vlane_split_axis, vlane_stride + : src_memref, dst_memref, memref<1xi32> {dram_stride, sram_stride, padding} + +Pass interface (passes/__init__.py): MARKERS + run(module). +""" + +OP_NAME = "togsim.transfer" +MARKERS = (OP_NAME,) + + +def _iter_ops(block): + for op in list(block.operations): + yield op + for region in op.operation.regions: + for b in region.blocks: + yield from _iter_ops(b) + + +def _int_array(attr): + from mlir.ir import ArrayAttr, IntegerAttr + return [IntegerAttr(a).value for a in ArrayAttr(attr)] + + +def _const_int(value, default=None): + """Read an arith.constant index/integer operand's value, else `default`.""" + from mlir.ir import IntegerAttr + try: + return IntegerAttr(value.owner.attributes["value"]).value + except Exception: + return default + + +def _squeeze_reassociation(shape): + """Group source dims so each group's product is one effective (non-unit) dim; + unit dims attach to a neighbor. Returns (groups, target_shape).""" + groups, cur = [], [] + for i, e in enumerate(shape): + cur.append(i) + if e > 1: + groups.append(cur) + cur = [] + if cur: # trailing unit dims + if groups: + groups[-1] += cur + else: + groups.append(cur) # all-ones -> single dim of size 1 + import math + target = [math.prod(shape[d] for d in g) for g in groups] + return groups, target + + +def run(module, vectorlane=128, **_): + """Lower every togsim.transfer in `module`, in place. Context must be active. + + vectorlane (= systolic-array size / number of vector lanes) feeds the lane-banked + physical SRAM offset in the >4D peel, matching -dma-fine-grained's + systolic-array-size option. + """ + from mlir.ir import (InsertionPoint, Operation, MemRefType, ArrayAttr, + IntegerAttr, IntegerType, IndexType, DenseI64ArrayAttr, + DenseI32ArrayAttr, StridedLayoutAttr, AffineMap, AffineMapAttr, + AffineExpr, BoolAttr) + from mlir.dialects import affine + i64 = IntegerType.get_signless(64) + idx_ty = IndexType.get() + + targets = [] + for region in module.operation.regions: + for b in region.blocks: + for op in _iter_ops(b): + if op.operation.name == OP_NAME: + targets.append(op.operation) + + for op in targets: + dram, dram_idx, sram, sram_idx, tag, dma_type, vst = op.operands + kind = op.attributes["dma_kind"].value # StringAttr -> "MVIN"/"MVOUT" + vlane_axis = IntegerAttr(op.attributes["vlane_split_axis"]).value + dram_stride = _int_array(op.attributes["dram_stride"]) + tile_stride = _int_array(op.attributes["tile_stride"]) + vlane_stride = _const_int(vst, 1) + padding = op.attributes["padding"] + try: + subtile = _int_array(op.attributes["subtile_size"]) + async_attr = op.attributes["async"] + except KeyError: + subtile, async_attr = None, None + + sram_ty = MemRefType(sram.type) + elem, space = sram_ty.element_type, sram_ty.memory_space + tile_shape = list(sram_ty.shape) + # effective (non-unit) dims carry the descriptor; unit dims drop out. + eff = [i for i, e in enumerate(tile_shape) if e > 1] + + def _const(v): + return Operation.create( + "arith.constant", results=[idx_ty], + attributes={"value": IntegerAttr.get(idx_ty, v)}).results[0] + + def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr, st_attr=None): + vsa = _const(vsa_val) + if kind == "MVIN": + operands = [dram, dram_idx_val, sram_mem, *sram_indices, + dma_type, tag, sram_idx, vsa, vst] + else: + operands = [sram_mem, *sram_indices, dram, dram_idx_val, + dma_type, tag, sram_idx, vsa, vst] + attrs = {"dram_stride": dr_attr, "sram_stride": tl_attr, "padding": padding} + if st_attr is not None: + attrs["subtile_size"] = st_attr + attrs["async"] = async_attr + Operation.create( + "memref.dma_start", results=[], operands=operands, attributes=attrs) + + if len(tile_shape) <= 4: + # Already <=4D: emit the descriptor directly on the original SRAM, no + # collapse_shape. The C++ -dma-fine-grained subtile split walks the SRAM + # operand and chokes on a collapse_shape result, so keep it a direct buffer. + dr_attr = ArrayAttr.get([IntegerAttr.get(i64, s) for s in dram_stride]) + tl_attr = ArrayAttr.get([IntegerAttr.get(i64, s) for s in tile_stride]) + st_attr = (ArrayAttr.get([IntegerAttr.get(i64, s) for s in subtile]) + if subtile is not None else None) + with InsertionPoint(op): + _emit(sram, [sram_idx] * len(tile_shape), dram_idx, vlane_axis, + dr_attr, tl_attr, st_attr) + op.erase() + continue + + if len(eff) <= 4: + # Fast path: drop unit dims so the descriptor reaches <=4D. The customized + # dma_start convention requires SRAM rank == #indices == len(sram_stride), + # so collapse the unit tile dims away. DRAM stays flat rank-1 (its N-D + # structure is in dram_stride). + groups, target = _squeeze_reassociation(tile_shape) + reassoc = ArrayAttr.get( + [ArrayAttr.get([IntegerAttr.get(i64, d) for d in g]) for g in groups]) + collapsed_ty = MemRefType.get(target, elem, memory_space=space) + # the non-unit dim in each group (g[-1] is wrong when trailing unit dims + # attach after it, e.g. [..,4,1,1] -> the kept dim must be the extent-4 one). + keep = [next((d for d in g if tile_shape[d] > 1), g[-1]) for g in groups] + dr_attr = ArrayAttr.get([IntegerAttr.get(i64, dram_stride[i]) for i in keep]) + tl_attr = ArrayAttr.get([IntegerAttr.get(i64, tile_stride[i]) for i in keep]) + st_attr = (ArrayAttr.get([IntegerAttr.get(i64, subtile[i]) for i in keep]) + if subtile is not None else None) + # Remap vlane axis to the collapsed-dim index (the group containing it). + new_vlane = next(gi for gi, g in enumerate(groups) if vlane_axis in g) + with InsertionPoint(op): + sram_c = Operation.create( + "memref.collapse_shape", results=[collapsed_ty], operands=[sram], + attributes={"reassociation": reassoc}).results[0] + _emit(sram_c, [sram_idx] * len(target), dram_idx, new_vlane, + dr_attr, tl_attr, st_attr) + op.erase() + continue + + # Peel path: >4 effective dims. Wrap the outer (len-4) effective dims in an + # affine.for nest (one loop per outer dim, marked inner_loop so build_tog/TOG + # registers the induction var) and emit a single <=4D memref.dma_start in the + # innermost body -- mirroring the C++ -dma-fine-grained subtile loop. + # + # The slice SRAM offset is the PHYSICAL lane-banked offset: dims outer than the + # vlane axis are rescaled by the lane coeff (stride/old_size*new_size, the MVIN + # block_stride / buildSramAffineMap rule). It is delivered as the last SRAM index + # operand (row-major stride 1), NOT a subview offset -- the gemmini lowering reads + # the spad base via extract_aligned_pointer_as_index, which strips the subview + # offset, so the slice must be selected through the index. The DRAM offset is the + # flat contiguous offset, folded with the original dram_idx into one affine.apply + # (an arith.addi would be opaque to processDramIndices -- #258); the affine.for + # induction vars feed both maps so TOG reads the loop indices through them. + peeled, inner = eff[:-4], eff[-4:] + ndim = len(tile_shape) + inner_shape = [tile_shape[d] for d in inner] + inner_strides = [tile_stride[d] for d in inner] + dr_attr = ArrayAttr.get([IntegerAttr.get(i64, dram_stride[d]) for d in inner]) + tl_attr = ArrayAttr.get([IntegerAttr.get(i64, tile_stride[d]) for d in inner]) + st_attr = (ArrayAttr.get([IntegerAttr.get(i64, subtile[d]) for d in inner]) + if subtile is not None else None) + # the vlane axis must survive into the inner descriptor (it is the lane dim). + if vlane_axis in inner: + new_vlane = inner.index(vlane_axis) + elif vlane_axis in peeled: + # lane dim peeled into the outer loop nest: it cannot be expressed in the + # <=4D descriptor, and _phys's lane-banking assumes it is a real axis. + raise NotImplementedError( + f"vlane split axis {vlane_axis} peeled into the outer loop nest; " + f">4D DMA peel cannot place the lane dim in the <=4D descriptor") + else: + new_vlane = 0 # vlane axis is a unit dim; lane on the first inner dim + + # Lane-banked physical stride for split-outer dims (vlane_stride defaults to 1). + split_extent = tile_shape[vlane_axis] + nr_outerloop = max( + (split_extent + vectorlane * vlane_stride - 1) // (vectorlane * vlane_stride), 1) + new_size = nr_outerloop * vlane_stride + target_stride = tile_stride[vlane_axis] + + def _phys(d): + s = tile_stride[d] + return s // split_extent * new_size if s > target_stride else s + + # subview to the inner <=4D block at the buffer start (offset 0); slice selection + # is done through the SRAM index, so the StridedLayout offset stays 0. + static_sizes = [1] * ndim + for d in inner: + static_sizes[d] = tile_shape[d] + res_ty = MemRefType.get( + inner_shape, elem, + layout=StridedLayoutAttr.get(0, inner_strides), memory_space=space) + + # affine.for nest over the peeled (outer) dims. + cur_ip = InsertionPoint(op) + ivs = [] + for d in peeled: + floop = affine.AffineForOp(0, tile_shape[d], 1, ip=cur_ip) + floop.operation.attributes["inner_loop"] = BoolAttr.get(True) + ivs.append(floop.induction_variable) + with InsertionPoint(floop.body): + affine.AffineYieldOp([]) + cur_ip = InsertionPoint.at_block_terminator(floop.body) + + npeel = len(peeled) + with cur_ip: + sub = Operation.create( + "memref.subview", results=[res_ty], operands=[sram], + attributes={"static_offsets": DenseI64ArrayAttr.get([0] * ndim), + "static_sizes": DenseI64ArrayAttr.get(static_sizes), + "static_strides": DenseI64ArrayAttr.get([1] * ndim), + # i32 [source, offsets, sizes, strides] dynamic-operand counts; + # all static -> source only. i64 silently zeroes and fails verify. + "operandSegmentSizes": DenseI32ArrayAttr.get([1, 0, 0, 0])} + ).results[0] + # physical SRAM offset = sum_k iv_k * phys_stride(peeled[k]) + sram_expr = AffineExpr.get_dim(0) * _phys(peeled[0]) + for k in range(1, npeel): + sram_expr = sram_expr + AffineExpr.get_dim(k) * _phys(peeled[k]) + sram_off_val = Operation.create( + "affine.apply", results=[idx_ty], operands=list(ivs), + attributes={"map": AffineMapAttr.get(AffineMap.get(npeel, 0, [sram_expr]))} + ).results[0] + # DRAM index = orig dram_idx + sum_k iv_k * dram_stride(peeled[k]) + dram_expr = AffineExpr.get_dim(0) + for k in range(npeel): + dram_expr = dram_expr + AffineExpr.get_dim(k + 1) * dram_stride[peeled[k]] + dram_idx_val = Operation.create( + "affine.apply", results=[idx_ty], operands=[dram_idx, *ivs], + attributes={"map": AffineMapAttr.get(AffineMap.get(npeel + 1, 0, [dram_expr]))} + ).results[0] + zero = _const(0) + _emit(sub, [zero, zero, zero, sram_off_val], dram_idx_val, new_vlane, + dr_attr, tl_attr, st_attr) + op.erase() + + +def lower_text(text: str) -> str: + """Parse `text`, run this pass, return the printed module. CLI/testing helper.""" + if OP_NAME not in text: + return text + from mlir.ir import Context, Module, Location + ctx = Context() + ctx.allow_unregistered_dialects = True + with ctx, Location.unknown(): + m = Module.parse(text) + run(m) + return str(m) + + +if __name__ == "__main__": + import sys + out = lower_text(open(sys.argv[1]).read()) + if len(sys.argv) > 2: + open(sys.argv[2], "w").write(out) + else: + sys.stdout.write(out) diff --git a/PyTorchSimFrontend/mlir/passes/dma_fine_grained.py b/PyTorchSimFrontend/mlir/passes/dma_fine_grained.py new file mode 100644 index 00000000..3f583ef2 --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/dma_fine_grained.py @@ -0,0 +1,406 @@ +"""Python port of the C++ `-dma-fine-grained` MLIR pass (TestDmaFineGrained.cpp). + +Splits the matmul MVIN DMAs (input / weight / optional bias) into subtile loops +(affine.for nests carrying per-subtile DRAM/SRAM offset affine.apply maps) and +fuses the input and weight loop nests, mirroring the C++ pass structurally. Runs +AFTER -test-loop-padding (it reads the padded tile shapes / loop bounds) and +BEFORE -test-pytorchsim-to-vcix, so extension_codecache splits the single mlir-opt +invocation around this pass. + +The C++ pass fuses the two subtile loop nests by cloning their bodies with an +IRMapping; the MLIR Python bindings expose no IRMapping, so this port builds the +fused nest directly and emits each DMA inside it using the fused induction vars +(equivalence target: same loop structure / counts, same offset maps, same +dma_start operands+attrs -- validated against mlir-opt -dma-fine-grained and the +end-to-end gemm/conv/model tests, not byte-exact SSA text). + +Operates on the customized memref.dma_start convention (see lower_dma_to_gemmini): +operands = src, *src_idx, dst, *dst_idx, num_elements(dma_type), tag, *tag_idx, +stride(=vlane_split_axis), num_elements_per_stride(=vlane_stride). MVIN dma_type in +{2,1,14}; tile shape = dst shape for MVIN. + +Pipeline entry point: run_fine_grained(in_path, out_path, vectorlane). +""" +import os +import sys + +_DEFAULT_BINDINGS = "/riscv-llvm/python_packages/mlir_core" +if os.path.isdir(_DEFAULT_BINDINGS) and _DEFAULT_BINDINGS not in sys.path: + sys.path.insert(0, _DEFAULT_BINDINGS) + +import mlir.ir as ir # noqa: E402 + +MARKERS = ("subtile_size",) # only subtile DMAs are split + +MVIN, MVIN2, MVIN3, MVOUT = 2, 1, 14, 3 + +# Per-rank subtile loop order and the fused-loop layout (mirror the C++ loopGroups). +# in_to_fused[d] / w_to_fused[d] give the fused-loop index that the input/weight +# DMA's dim d iterates; n_fused is the number of fused affine.for loops. +_FUSE = { + 2: dict(n_fused=3, in_to_fused=[0, 1], w_to_fused=[1, 2]), + 3: dict(n_fused=4, in_to_fused=[0, 1, 2], w_to_fused=[0, 2, 3]), + 4: dict(n_fused=7, in_to_fused=[0, 1, 4, 5], w_to_fused=[2, 3, 5, 6]), +} + + +# --------------------------------------------------------------------------- +# Small readers (mirror CustomDMAAttribute.h) +# --------------------------------------------------------------------------- +def _const_int(value, default=-1): + try: + return ir.IntegerAttr(value.owner.attributes["value"]).value + except Exception: + return default + + +def _int_array_attr(op, key): + if key not in op.attributes: + return [] + return [ir.IntegerAttr(a).value for a in ir.ArrayAttr(op.attributes[key])] + + +def _is_block_arg(v): + return isinstance(v, ir.BlockArgument) + + +class _Dma: + """Positional view of a customized memref.dma_start op.""" + + def __init__(self, op): + self.op = op + operands = list(op.operands) + src_rank = len(ir.MemRefType(operands[0].type).shape) + i = 0 + self.src = operands[i]; i += 1 + self.src_idx = operands[i:i + src_rank]; i += src_rank + self.dst = operands[i]; i += 1 + dst_rank = len(ir.MemRefType(self.dst.type).shape) + self.dst_idx = operands[i:i + dst_rank]; i += dst_rank + self.num_elements = operands[i]; i += 1 + self.tag = operands[i]; i += 1 + tag_rank = len(ir.MemRefType(self.tag.type).shape) + self.tag_idx = operands[i:i + tag_rank]; i += tag_rank + self.stride = operands[i]; i += 1 # = vlane_split_axis + self.num_elements_per_stride = operands[i] # = vlane_stride + self.src_rank, self.dst_rank, self.tag_rank = src_rank, dst_rank, tag_rank + + @property + def dma_type(self): + return _const_int(self.num_elements) + + @property + def is_mvin(self): + return self.dma_type in (MVIN, MVIN2, MVIN3) + + @property + def vlane_split_axis(self): + return _const_int(self.stride) + + @property + def vlane_stride(self): + return _const_int(self.num_elements_per_stride) & 0x7FFF + + def tile_shape(self): + mt = ir.MemRefType((self.dst if self.is_mvin else self.src).type) + return list(mt.shape) + + def subtile_size(self): + return _int_array_attr(self.op, "subtile_size") + + def sram_stride(self): + return _int_array_attr(self.op, "sram_stride") + + def dram_stride(self): + return _int_array_attr(self.op, "dram_stride") + + def is_async(self): + a = self.op.attributes + if "async" not in a: + return False + try: + return bool(ir.IntegerAttr(a["async"]).value) + except Exception: + return True + + +# --------------------------------------------------------------------------- +# Affine map builders (mirror buildDramAffineMap / buildSramAffineMap) +# --------------------------------------------------------------------------- +def _ceil_div(a, b): + return (a + b - 1) // b + + +def _build_dram_map(dma): + dram = dma.dram_stride() + sub = dma.subtile_size() + rank = len(dram) + expr = ir.AffineConstantExpr.get(0) + for i in range(rank): + expr = expr + ir.AffineDimExpr.get(i) * (dram[i] * sub[i]) + return ir.AffineMap.get(rank, 0, [expr]) + + +def _build_sram_map(dma, vectorlane): + tile_shape = dma.tile_shape() + tile_stride = dma.sram_stride() + sub = dma.subtile_size() + split = dma.vlane_split_axis + vstride = dma.vlane_stride + + target_stride = tile_stride[split] + old_size = tile_shape[split] + nr_outerloop = _ceil_div(old_size, vectorlane * vstride) + new_size = nr_outerloop * vstride + + expr = None + for i in range(len(tile_stride)): + subtilesize = sub[i] + stride = tile_stride[i] + if stride > target_stride: + stride = stride // old_size * new_size + d = ir.AffineDimExpr.get(i) + if i != split: + term = d * (subtilesize * stride) + else: + term = ir.AffineExpr.get_floor_div(d * subtilesize, vectorlane) * stride + expr = term if expr is None else expr + term + return ir.AffineMap.get(len(tile_stride), 0, [expr]) + + +def _build_tag_map(dma, loop_order): + """Mirror the tag stride map built inside buildSubtileLoop.""" + tile_sizes = dma.tile_shape() + sub = dma.subtile_size() + rank = len(tile_sizes) + strides = [1] * rank + for i in range(rank - 2, -1, -1): + cur, nxt = loop_order[i], loop_order[i + 1] + strides[cur] = strides[nxt] * _ceil_div(tile_sizes[nxt], sub[nxt]) + expr = ir.AffineConstantExpr.get(0) + for i in range(rank): + expr = expr + ir.AffineDimExpr.get(i) * strides[i] + return ir.AffineMap.get(rank, 0, [expr]) + + +def _loop_counts(dma, loop_order): + tile_sizes = dma.tile_shape() + sub = dma.subtile_size() + return [_ceil_div(tile_sizes[d], sub[d]) for d in range(len(tile_sizes))] + + +# --------------------------------------------------------------------------- +# DMA emission inside a body +# --------------------------------------------------------------------------- +def _sum_map(): + d0, d1 = ir.AffineDimExpr.get(0), ir.AffineDimExpr.get(1) + return ir.AffineMap.get(2, 0, [d0 + d1]) + + +def _apply(map_, operands, ip): + from mlir.dialects import affine + return affine.AffineApplyOp(map_, list(operands), ip=ip).result + + +def _dma_attrs(dma): + """Mirror getDmaAttrs: keep subtile/sram/dram strides, set async + fine_grained.""" + attrs = {} + op = dma.op + for k in ("subtile_size", "sram_stride", "dram_stride"): + if k in op.attributes: + attrs[k] = op.attributes[k] + attrs["async"] = ir.BoolAttr.get(dma.is_async()) + attrs["fine_grained"] = ir.BoolAttr.get(True) + return attrs + + +def _emit_dma(dma, ivs, vectorlane, ip): + """Emit one fine-grained memref.dma_start at `ip`, indexed by `ivs` (the fused + induction vars for this DMA's dims, in dim order).""" + idx_ty = ir.IndexType.get() + zero = _const_index(0, ip) + + dram_off = _apply(_build_dram_map(dma), ivs, ip) + src_idx0 = dma.src_idx[0] + dram_idx = _apply(_sum_map(), [dram_off, src_idx0], ip) + + sram_off = _apply(_build_sram_map(dma, vectorlane), ivs, ip) + tag_idx = _apply(_build_tag_map(dma, list(range(len(dma.tile_shape())))), ivs, ip) + + # SRAM indices: zeros except the last = sram offset (mirror sramIndices.back()). + sram_indices = [zero] * dma.dst_rank + sram_indices[-1] = sram_off + + operands = [dma.src, dram_idx, dma.dst, *sram_indices, + dma.num_elements, dma.tag, tag_idx, + dma.stride, dma.num_elements_per_stride] + ir.Operation.create("memref.dma_start", results=[], operands=operands, + attributes=_dma_attrs(dma), ip=ip) + + +def _const_index(v, ip): + from mlir.dialects import arith + return arith.ConstantOp(ir.IndexType.get(), + ir.IntegerAttr.get(ir.IndexType.get(), v), ip=ip).result + + +# --------------------------------------------------------------------------- +# Loop-nest construction +# --------------------------------------------------------------------------- +def _build_for_nest(bounds, ip): + """Create a nested affine.for over `bounds` (step 1, marked inner_loop). Returns + (induction_vars, innermost_body_ip_before_yield).""" + from mlir.dialects import affine + ivs = [] + cur_ip = ip + for b in bounds: + floop = affine.AffineForOp(0, b, 1, ip=cur_ip) + floop.operation.attributes["inner_loop"] = ir.BoolAttr.get(True) + ivs.append(floop.induction_variable) + with ir.InsertionPoint(floop.body): + affine.AffineYieldOp([]) + cur_ip = ir.InsertionPoint.at_block_terminator(floop.body) + return ivs, cur_ip + + +def _create_subtile_dma(dma, loop_order, ip, vectorlane): + """Standalone subtile loop for one DMA (used for bias). Mirrors createSubtileDMA.""" + counts = _loop_counts(dma, loop_order) + bounds = [counts[d] for d in loop_order] + ivs_in_order, body_ip = _build_for_nest(bounds, ip) + # map dim -> its induction var (loop_order[k] is the dim of the k-th loop) + iv_by_dim = [None] * len(counts) + for k, d in enumerate(loop_order): + iv_by_dim[d] = ivs_in_order[k] + _emit_dma(dma, iv_by_dim, vectorlane, body_ip) + + +# --------------------------------------------------------------------------- +# Operand reachability (mirror traverseOperands) +# --------------------------------------------------------------------------- +def _reaches(value, target): + if value == target: + return True + owner = value.owner + if isinstance(owner, ir.Block): # block argument: no defining op to walk + return False + for operand in owner.operands: + if _reaches(operand, target): + return True + return False + + +# --------------------------------------------------------------------------- +# Pass driver +# --------------------------------------------------------------------------- +def _iter_ops(block): + for op in list(block.operations): + yield op + for region in op.operation.regions: + for b in region.blocks: + yield from _iter_ops(b) + + +def _run_func(func, vectorlane): + from mlir.dialects import linalg + # First matmul only. + matmul = None + dmas = [] + for op in _iter_ops(func.regions[0].blocks[0]): + name = op.operation.name + if name == "linalg.matmul" and matmul is None: + matmul = op + elif name == "memref.dma_start": + dmas.append(op) + if matmul is None: + return + + m_in = matmul.operands[0] + m_w = matmul.operands[1] + m_res = list(matmul.operands)[-1] # output (init) operand + + mvin_input = mvin_weight = mvin_bias = None + for op in dmas: + d = _Dma(op) + if d.dma_type == MVOUT: + continue + if _reaches(m_in, d.dst): + mvin_input = d + elif _reaches(m_w, d.dst): + mvin_weight = d + elif _reaches(m_res, d.dst) and len(d.subtile_size()) > 1: + mvin_bias = d + + in_async = mvin_input is not None and mvin_input.is_async() + w_async = mvin_weight is not None and mvin_weight.is_async() + if not (in_async or w_async): + return + if mvin_input is None or mvin_weight is None: + return + + rank = len(mvin_input.tile_shape()) + if rank not in _FUSE: + return + fuse = _FUSE[rank] + loop_order = list(range(rank)) + + # Bias first (standalone), inserted before its own op. + if mvin_bias is not None: + brank = len(mvin_bias.tile_shape()) + border = {2: [0, 1], 4: [2, 3, 0, 1]}.get(brank) + if border is not None: + _create_subtile_dma(mvin_bias, border, + ir.InsertionPoint(mvin_bias.op), vectorlane) + mvin_bias.op.erase() + + # Fused input + weight nest. Fused loop bounds: take each fused loop's count + # from whichever DMA dim maps onto it. + in_counts = _loop_counts(mvin_input, loop_order) + w_counts = _loop_counts(mvin_weight, loop_order) + bounds = [None] * fuse["n_fused"] + for d, f in enumerate(fuse["in_to_fused"]): + bounds[f] = in_counts[d] + for d, f in enumerate(fuse["w_to_fused"]): + bounds[f] = w_counts[d] + + # Insert the fused nest at the weight DMA (the later of the two): both DMAs' + # original DRAM base indices (src_idx[0], computed in the enclosing loops) must + # dominate the nest. Codegen emits input before weight, matching the C++ pass + # which fuses after the weight subtile loop. + ip = ir.InsertionPoint(mvin_weight.op) + fused_ivs, body_ip = _build_for_nest(bounds, ip) + in_ivs = [fused_ivs[fuse["in_to_fused"][d]] for d in range(rank)] + w_ivs = [fused_ivs[fuse["w_to_fused"][d]] for d in range(rank)] + _emit_dma(mvin_input, in_ivs, vectorlane, body_ip) + _emit_dma(mvin_weight, w_ivs, vectorlane, body_ip) + mvin_input.op.erase() + mvin_weight.op.erase() + + +def run(module, vectorlane=128, **_): + """Apply fine-grained DMA subtiling to every func in `module`, in place.""" + from mlir.dialects import func as func_d # noqa: F401 (ensure dialect loaded) + for region in module.operation.regions: + for b in region.blocks: + for op in list(b.operations): + if op.operation.name == "func.func" and len(op.operation.regions[0].blocks): + _run_func(op, vectorlane) + + +def run_fine_grained(in_path, out_path, vectorlane=128): + """Parse `in_path`, run the pass, write `out_path`. Pipeline entry point.""" + with open(in_path) as f: + text = f.read() + ctx = ir.Context() + ctx.allow_unregistered_dialects = True + with ctx, ir.Location.unknown(): + module = ir.Module.parse(text) + run(module, vectorlane=vectorlane) + out = str(module) + with open(out_path, "w") as f: + f.write(out) + + +if __name__ == "__main__": + vl = int(sys.argv[3]) if len(sys.argv) > 3 else 128 + run_fine_grained(sys.argv[1], sys.argv[2], vl) diff --git a/PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py b/PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py new file mode 100644 index 00000000..f5b841bb --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py @@ -0,0 +1,227 @@ +"""Lower customized memref.dma_start ops to Gemmini RISC-V inline asm. + +Python port of the C++ test-memref-to-gemmini conversion. Each memref.dma_start +(carrying dram_stride / sram_stride / subtile_size attrs and vlane params encoded +in its stride / num_elements_per_stride / num_elements operands) becomes a +sequence of `llvm.inline_asm` ".insn r CUSTOM_1 ..." Gemmini instructions: +config_mvin/mvout, config2 (dram strides), config3 (spad strides), then the +mvin/mvout itself with the DRAM and scratchpad byte addresses. + +The conversion-framework coupling of the C++ pass (LLVMTypeConverter, +getStridedElementPtr, MemRefDescriptor) is avoided by working at the memref level: +addresses are computed with `memref.extract_aligned_pointer_as_index` + arith, +and the existing standard MLIR->LLVM lowering finalizes everything. Pass order: +this runs on memref-level IR (after test-pytorchsim-to-vcix), before +run_standard_lowering. + +NOTE: indirect-access (gather) dma_start is not yet handled (Phase 2); such ops +raise so they are caught rather than silently mishandled. +""" + +OP_NAME = "memref.dma_start" +WAIT_NAME = "memref.dma_wait" +MARKERS = (OP_NAME, WAIT_NAME) + +# func7 instruction codes (CustomDMAAttribute.h) +CONFIG, CONFIG2, CONFIG3, CONFIG4 = 0, 4, 5, 6 +MVIN, MVIN2, MVIN3, MVOUT = 2, 1, 14, 3 +CONFIG_TYPE = {MVIN: 0, MVIN2: 1, MVIN3: 2, MVOUT: 3} +MAX_TENSOR_DIM = 4 +CONSTRAINTS = "r,r,~{dirflag},~{fpsr},~{flags}" + + +def _asm(func7): + return f".insn r CUSTOM_1, 0x3, {func7}, x0, $0, $1" + + +def _i64_signed(v): + """Wrap an unsigned 64-bit packed value into signed int64 (matches C++ getI64IntegerAttr).""" + v &= 0xFFFFFFFFFFFFFFFF + return v - (1 << 64) if v >= (1 << 63) else v + + +def _row_major_strides(shape): + strides = [1] * len(shape) + for i in range(len(shape) - 2, -1, -1): + strides[i] = strides[i + 1] * shape[i + 1] + return strides + + +def run(module, timing=False): + """Lower memref.dma_start / dma_wait to Gemmini instructions. + + timing=False (functional/Spike): dma_start -> gemmini config + mvin/mvout asm. + timing=True (gem5 cycle path): dma_start is erased (the TOG already carries + DMA timing; the cycle binary needs no asm). + memref.dma_wait is erased in both modes (matches C++ DmaWaitOpLowering). + """ + from mlir.ir import (InsertionPoint, Operation, IntegerType, IndexType, + IntegerAttr, MemRefType) + from mlir.dialects import llvm, arith, memref + + i64 = IntegerType.get_signless(64) + idx = IndexType.get() + + def const_int(val): + return IntegerAttr(val.owner.attributes["value"]).value + + def i64_const(value): + return arith.ConstantOp(i64, IntegerAttr.get(i64, _i64_signed(value))).result + + def asm(func7, rs1, rs2): + llvm.InlineAsmOp(None, [rs1, rs2], _asm(func7), CONSTRAINTS, + has_side_effects=True, asm_dialect=0) + + def elem_addr_i64(memref_val, indices, mtype, elem_bytes): + """i64 byte address of memref_val[indices] (aligned ptr + linear elem offset).""" + base = memref.ExtractAlignedPointerAsIndexOp(memref_val).result # index = byte addr + strides = _row_major_strides(list(mtype.shape)) + off = None # element offset (index) + for k, ival in enumerate(indices): + if strides[k] == 0: + continue + term = ival + if strides[k] != 1: + term = arith.MulIOp(ival, arith.ConstantOp(idx, IntegerAttr.get(idx, strides[k])).result).result + off = term if off is None else arith.AddIOp(off, term).result + if off is not None: + byte = arith.MulIOp(off, arith.ConstantOp(idx, IntegerAttr.get(idx, elem_bytes)).result).result + base = arith.AddIOp(base, byte).result + return arith.IndexCastOp(i64, base).result + + starts, waits = [], [] + for region in module.operation.regions: + for b in region.blocks: + _collect(b, starts, waits) + + for op in waits: # dma_wait: erase in both modes + op.erase() + + for op in starts: + if timing: # gem5 cycle path: drop the dma_start (TOG has timing) + op.erase() + continue + operands = list(op.operands) + src, dst = operands[0], None + src_ty = MemRefType(src.type) + src_rank = len(src_ty.shape) + dst = operands[1 + src_rank] + dst_ty = MemRefType(dst.type) + dst_rank = len(dst_ty.shape) + src_idx = operands[1:1 + src_rank] + dst_idx = operands[1 + src_rank + 1:1 + src_rank + 1 + dst_rank] + + dma_type = const_int(operands[1 + src_rank + 1 + dst_rank]) # num_elements + vlane_split_axis = const_int(operands[-2]) # stride (always 2nd-to-last) + vlane_stride = const_int(operands[-1]) & 0x7FFF # num_elements_per_stride (last) + is_mvin = dma_type in (MVIN, MVIN2, MVIN3) + + elem_bytes = _elem_bytes(src_ty.element_type) + # Indirect (gather): the gather-side indices are src for mvin, dst for mvout. + gather_idx = src_idx if is_mvin else dst_idx + indirect, indirect_memref = _find_indirect(gather_idx) + + tile_shape = _subtile(op) + if tile_shape is None: + tile_shape = list(dst_ty.shape) if is_mvin else list(src_ty.shape) + dram_strides = _int_array(op, "dram_stride") + spad_strides = _int_array(op, "sram_stride") + assert len(tile_shape) == len(dram_strides) == len(spad_strides), \ + f"shape/stride rank mismatch: {tile_shape} {dram_strides} {spad_strides}" + + expand = MAX_TENSOR_DIM - len(tile_shape) + shape4 = [1] * expand + tile_shape + dram4 = [0] * expand + dram_strides + spad4 = [0] * expand + spad_strides + vlane_split_axis += expand + config_type = CONFIG_TYPE[dma_type] + + with InsertionPoint(op): + addrA = elem_addr_i64(src, src_idx, src_ty, elem_bytes) + addrB = elem_addr_i64(dst, dst_idx, dst_ty, elem_bytes) + dram_addr, spad_addr = (addrA, addrB) if is_mvin else (addrB, addrA) + + cfg_rs1 = i64_const(((shape4[0] & 0xFFFF) << 48) | ((shape4[1] & 0xFFFF) << 32) + | ((shape4[2] & 0xFFFF) << 16) | (shape4[3] & 0xFFFF)) + cfg_rs2 = i64_const((vlane_stride << 32) | ((config_type & 0x3) << 17) + | ((1 if indirect else 0) << 16) + | ((vlane_split_axis & 0x3) << 14) | elem_bytes) + asm(CONFIG, cfg_rs1, cfg_rs2) + asm(CONFIG2, i64_const((dram4[0] << 32) | (dram4[1] & 0xFFFFFFFF)), + i64_const((dram4[2] << 32) | (dram4[3] & 0xFFFFFFFF))) + asm(CONFIG3, i64_const((spad4[0] << 32) | (spad4[1] & 0xFFFFFFFF)), + i64_const((spad4[2] << 32) | (spad4[3] & 0xFFFFFFFF))) + if indirect: + # CONFIG4: rs1 = indirect index-spad base address, rs2 = (elem_size<<16)|stride(1) + ind_base = memref.ExtractAlignedPointerAsIndexOp(indirect_memref).result + ind_addr = arith.IndexCastOp(i64, ind_base).result + ind_esize = _elem_bytes(MemRefType(indirect_memref.type).element_type) + asm(CONFIG4, ind_addr, i64_const(((ind_esize & 0xFF) << 16) | (1 & 0xFFFF))) + asm(dma_type, dram_addr, spad_addr) + op.erase() + + +def _collect(block, starts, waits): + for op in list(block.operations): + name = op.operation.name + if name == OP_NAME: + starts.append(op.operation) + elif name == WAIT_NAME: + waits.append(op.operation) + for region in op.operation.regions: + for b in region.blocks: + _collect(b, starts, waits) + + +def _subtile(op): + from mlir.ir import ArrayAttr, IntegerAttr + if "subtile_size" not in op.attributes: + return None + return [IntegerAttr(a).value for a in ArrayAttr(op.attributes["subtile_size"])] + + +def _int_array(op, name): + from mlir.ir import ArrayAttr, IntegerAttr + return [IntegerAttr(a).value for a in ArrayAttr(op.attributes[name])] + + +def _elem_bytes(elem_type): + from mlir.ir import IntegerType, FloatType + bits = (IntegerType(elem_type).width if IntegerType.isinstance(elem_type) + else FloatType(elem_type).width) + return max(bits, 8) // 8 + + +def _find_indirect(indices): + """If a gather index is an affine.apply{indirect_access} whose operands include + index_cast(affine.load(%spad)), return (True, %spad memref); else (False, None).""" + for idx in indices: + ap = idx.owner + if getattr(ap, "name", None) != "affine.apply" or "indirect_access" not in ap.attributes: + continue + for operand in ap.operands: + ic = operand.owner + if getattr(ic, "name", None) != "arith.index_cast": + continue + ld = ic.operands[0].owner + if getattr(ld, "name", None) == "affine.load": + return True, ld.operands[0] # affine.load operand 0 == the index spad memref + return False, None + + +def lower_text(text): + if OP_NAME not in text: + return text + from mlir.ir import Context, Module, Location + ctx = Context() + ctx.allow_unregistered_dialects = True + with ctx, Location.unknown(): + m = Module.parse(text) + run(m) + return str(m) + + +if __name__ == "__main__": + import sys + out = lower_text(open(sys.argv[1]).read()) + (open(sys.argv[2], "w").write(out) if len(sys.argv) > 2 else sys.stdout.write(out)) diff --git a/PyTorchSimFrontend/mlir/passes/lower_to_llvm.py b/PyTorchSimFrontend/mlir/passes/lower_to_llvm.py new file mode 100644 index 00000000..ad287499 --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/lower_to_llvm.py @@ -0,0 +1,69 @@ +"""Standard MLIR -> LLVM-dialect lowering via the bindings PassManager. + +Runs the upstream *registered* lowering passes (convert-*-to-llvm, lower-affine, +reconcile-unrealized-casts, ...) in-process on the post-custom-pass IR, replacing +the tail of the mlir-opt pipeline. Most custom passes are now Python out-of-line +passes: dma-fine-grained -> dma_fine_grained, test-pytorchsim-to-vcix -> +lower_to_vcix, test-tile-operation-graph -> build_tog (gem5 path), memref-to-gemmini +-> lower_dma_to_gemmini (run inside this lowering), global-idx -> lower_vlane_idx. +Only test-loop-padding still runs in mlir-opt; once it migrates, mlir-opt drops out +entirely and the flow is fully in-process. + +Validated to produce byte-identical LLVM IR to running the same passes inside +mlir-opt. Note: only lower-vector-multi-reduction is func.func-scoped (the +bindings pass-pipeline parser does not auto-nest like the mlir-opt CLI, so it is +wrapped explicitly); order is preserved to match the original pipeline. +""" + +STANDARD_PIPELINE = ( + "builtin.module(" + "convert-linalg-to-loops," + "convert-vector-to-scf{full-unroll=true}," + "lower-affine," + "expand-strided-metadata," # decompose memref.collapse_shape/subview before LLVM + "finalize-memref-to-llvm," + "func.func(lower-vector-multi-reduction)," + "convert-vector-to-llvm," + "convert-arith-to-llvm," + "convert-math-to-llvm," + "convert-scf-to-cf," + "convert-cf-to-llvm," + "convert-func-to-llvm," + "convert-index-to-llvm," + "reconcile-unrealized-casts)" +) + + +def run_standard_lowering(in_path, out_path=None, timing=False): + """Lower the post-custom-pass MLIR at `in_path` to the LLVM dialect. + + Runs the imperative Gemmini lowering (memref.dma_start/dma_wait) then the + registered standard MLIR->LLVM passes. `timing` selects the Gemmini behavior: + False for the functional/Spike path (emit gemmini asm), True for the gem5 + cycle path (erase dma_start; the TOG already carries DMA timing) -- this + preserves the old test-memref-to-gemmini `timing=1` semantics. + + Writes the result to `out_path` (defaults to `in_path`, i.e. in place). + Requires the MLIR Python bindings on PYTHONPATH. + """ + if out_path is None: + out_path = in_path + from mlir.ir import Context, Module, Location + from mlir.passmanager import PassManager + from . import lower_dma_to_gemmini + ctx = Context() + ctx.allow_unregistered_dialects = True + with ctx, Location.unknown(): + with open(in_path) as f: + module = Module.parse(f.read()) + # Imperative Python pass: memref.dma_start/dma_wait -> Gemmini asm (replaces + # the C++ test-memref-to-gemmini), then the registered standard lowering. + lower_dma_to_gemmini.run(module, timing=timing) + PassManager.parse(STANDARD_PIPELINE, ctx).run(module.operation) + with open(out_path, "w") as f: + f.write(str(module)) + + +if __name__ == "__main__": + import sys + run_standard_lowering(sys.argv[1], sys.argv[2] if len(sys.argv) > 2 else None) diff --git a/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py b/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py new file mode 100644 index 00000000..ac93ebc8 --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py @@ -0,0 +1,661 @@ +"""Python port of the C++ `-test-pytorchsim-to-vcix` conversion pass +(TestPyTorchSimToVCIXConversion.cpp). + +Lowers `linalg.matmul` and the transcendental math ops (exp/erf/tanh/sin/cos) to +VCIX dialect ops (RISC-V vector custom instructions). The C++ pass is a +dialect-conversion (`applyPartialConversion`); the MLIR Python bindings expose no +conversion framework, so each matchAndRewrite is reimplemented as imperative IR +rewriting (walk + build replacement + replace uses + erase). + +The VCIX dialect is NOT registered in the Python bindings, so vcix ops are created +as unregistered generic ops. This round-trips: mlir-opt / mlir-translate (which do +have vcix registered) re-parse the `{}`-attr generic form fine, and the existing +`run_standard_lowering` already consumes the C++ vcix output via +`allow_unregistered_dialects` -- so emitting generic vcix ops here is consistent +with the current pipeline. + +Covers all 6 C++ patterns: linalg.matmul (gemm + conv2d) and exp/erf/tanh/sin/cos. +Wired into extension_codecache (run_to_vcix) after fine-grained, before the standard +lowering; mlir-opt then runs only -test-loop-padding. Validated structurally against +`mlir-opt -test-pytorchsim-to-vcix` (non-constant ops byte-identical incl. dma_wait tag +maps) and numerically end-to-end (gemm/bmm/conv2d/transcendental, Spike+gem5 allclose). +""" +import os +import sys + +_DEFAULT_BINDINGS = "/riscv-llvm/python_packages/mlir_core" +if os.path.isdir(_DEFAULT_BINDINGS) and _DEFAULT_BINDINGS not in sys.path: + sys.path.insert(0, _DEFAULT_BINDINGS) + +import mlir.ir as ir # noqa: E402 + +MARKERS = ("linalg.matmul", "math.exp", "math.erf", "math.tanh", "math.sin", "math.cos") + +# math op name -> (opcode, imm) for the vcix.v.iv lowering (mirror Math*ToVCIX). +_MATH_VIV = { + "math.exp": (0b000011, 0), + "math.erf": (0b000000, 0), + "math.tanh": (0b000001, 0), + "math.sin": (0b000010, 0), + "math.cos": (0b000010, 1), +} + + +def _sew(elt_ty): + # Mirror C++ legalizeVectorType: only F32/F64/integer/index get a sew. F16/BF16 + # return 0 so transcendental math ops stay unlowered (-convert-math-to-llvm), + # matching the validated path -- do NOT emit VCIX for them here. + if ir.F32Type.isinstance(elt_ty): + return 32 + if ir.F64Type.isinstance(elt_ty): + return 64 + if ir.IntegerType.isinstance(elt_ty): + return ir.IntegerType(elt_ty).width + if ir.IndexType.isinstance(elt_ty): + return 64 + return 0 + + +def _log2(x): + return x.bit_length() - 1 + + +def _legalize_vector_type(vt, vlen): + """Mirror legalizeVectorType: return (n, legal_vector_type) or (0, None).""" + if len(vt.shape) != 1: # C++ guards getRank() != 1 + return 0, None + elt_ty = vt.element_type + sew = _sew(elt_ty) + if sew == 0: + return 0, None + elt_count = vt.shape[0] + lmul = elt_count * sew // 64 + scalable = vt.scalable + if not scalable: + n = (_log2(lmul) - 2) if lmul > 32 else 1 + if n == 1: + return 1, vt + return n, ir.VectorType.get([vlen // (sew // 8)], elt_ty) + n = (_log2(lmul) - 2) if lmul > 8 else 1 + return n, ir.VectorType.get([elt_count >> (n - 1)], elt_ty, scalable=[True]) + + +def _i64(v): + return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), v) + + +def _i32(v): + return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), v) + + +def _viv(operand, result_ty, opcode, imm, rvl=None): + """Create an unregistered vcix.v.iv (vcix::BinaryImmOp) op at the current IP.""" + operands = [operand] if rvl is None else [operand, rvl] + return ir.Operation.create( + "vcix.v.iv", results=[result_ty], operands=operands, + attributes={"opcode": _i64(opcode), "imm": _i32(imm)}).results[0] + + +def _make_sf_vc_v_iv(vec, op_vt, n, legal_ty, opcode, imm): + """Mirror make_sf_vc_v_iv: chunk `vec` (type op_vt) into legal-width vcix.v.iv.""" + from mlir.dialects import arith, vector + total = op_vt.shape[0] + elt_count = legal_ty.shape[0] + scalable = legal_ty.scalable + rvl = None + if scalable: + rvl = arith.ConstantOp(ir.IntegerType.get_signless(64), _i64(9)).result + if n == 1: + return _viv(vec, legal_ty, opcode, imm, rvl) + elt_ty = legal_ty.element_type + zero = ir.DenseElementsAttr.get_splat(op_vt, ir.FloatAttr.get(elt_ty, 0.0)) + res = arith.ConstantOp(op_vt, zero).result + if scalable: + for i in range(n): + ext = vector.ScalableExtractOp(legal_ty, vec, i * elt_count).result + v = _viv(ext, legal_ty, opcode, imm, rvl) + res = vector.ScalableInsertOp(v, res, i * elt_count).result + else: + for i in range(total // elt_count): + ext = vector.ExtractStridedSliceOp( + legal_ty, vec, + ir.ArrayAttr.get([_i64(i * elt_count)]), + ir.ArrayAttr.get([_i64(elt_count)]), + ir.ArrayAttr.get([_i64(1)])).result + v = _viv(ext, legal_ty, opcode, imm, rvl) + res = vector.InsertStridedSliceOp( + v, res, ir.ArrayAttr.get([_i64(i * elt_count)]), + ir.ArrayAttr.get([_i64(1)])).result + return res + + +def _iter_ops(block): + for op in list(block.operations): + yield op + for region in op.operation.regions: + for b in region.blocks: + yield from _iter_ops(b) + + +# --------------------------------------------------------------------------- +# matmul lowering helpers (mirror MatmulOpLowering) +# --------------------------------------------------------------------------- +def _elt_bits(elt_ty): + if ir.IntegerType.isinstance(elt_ty): + return ir.IntegerType(elt_ty).width + return ir.FloatType(elt_ty).width + + +def _bool_attr_true(op, key): + a = op.attributes + return key in a and ir.BoolAttr(a[key]).value + + +def _enclosing_loops(op): + """Walk ancestor ops; return (accumulation, outer, inner) affine.for lists, + outermost-first (mirror the C++ insert-at-begin).""" + acc, outer, inner = [], [], [] + parent = op.operation.parent + while parent is not None: + if parent.name == "affine.for": + if _bool_attr_true(parent, "accumulation_loop"): + acc.insert(0, parent) + if _bool_attr_true(parent, "outer_loop"): + outer.insert(0, parent) + if _bool_attr_true(parent, "inner_loop"): + inner.insert(0, parent) + parent = parent.parent + return acc, outer, inner + + +def _loop_iv(forop): + return forop.regions[0].blocks[0].arguments[0] + + +def _loop_ub(forop): + # single constant upper bound + m = ir.AffineMapAttr(forop.attributes["upperBoundMap"]).value + return ir.AffineConstantExpr(m.results[0]).value + + +def _block_terminator(forop): + blk = forop.regions[0].blocks[0] + ops = list(blk.operations) + return ops[-1] + + +def _affine_consts(expr): + """All AffineConstantExpr values reachable in `expr` (recursive).""" + out = [] + if ir.AffineConstantExpr.isinstance(expr): + out.append(ir.AffineConstantExpr(expr).value) + elif ir.AffineBinaryExpr.isinstance(expr): + be = ir.AffineBinaryExpr(expr) + out += _affine_consts(be.lhs) + out += _affine_consts(be.rhs) + return out + + +def _scan_conv_offsets(ow_loop, o_h, k_h, o_w, k_w): + """Mirror the heuristic offset scan: find affine.apply(o_h,k_h)/(o_w,k_w) in the + o_w loop and read the constant in its map (default 1).""" + offset_h = offset_w = 1 + for o in _iter_ops(ow_loop.regions[0].blocks[0]): + if o.operation.name != "affine.apply": + continue + ops = list(o.operation.operands) + if len(ops) < 2: + continue + m = ir.AffineMapAttr(o.operation.attributes["map"]).value + consts = _affine_consts(m.results[0]) + if ops[0] == o_h and ops[1] == k_h and consts: + offset_h = consts[-1] + if ops[0] == o_w and ops[1] == k_w and consts: + offset_w = consts[-1] + return offset_h, offset_w + + +def _mem_space(v): + mt = ir.MemRefType(v.type) + ms = mt.memory_space + return ir.IntegerAttr(ms).value if ms is not None else 0 + + +def _dram_is_write(src, dst): + """(dram_memref, is_write) by memory space, mirror getDramMemRef.""" + ss, ds = _mem_space(src), _mem_space(dst) + if ds == 0 and ss == 1: + return dst, True + if ds == 1 and ss == 0: + return src, False + return None, False + + +def _idx(v): + return ir.IntegerAttr.get(ir.IndexType.get(), v) + + +def _const_index(v): + from mlir.dialects import arith + return arith.ConstantOp(ir.IndexType.get(), _idx(v)).result + + +def _apply(map_, operands): + from mlir.dialects import affine + return affine.AffineApplyOp(map_, list(operands)).result + + +def _spad_maps(): + d0, d1, d2 = (ir.AffineDimExpr.get(i) for i in range(3)) + s0, s1 = (ir.AffineSymbolExpr.get(i) for i in range(2)) + spad = ir.AffineMap.get(3, 2, [d0 * s0 + d1 * s1 + d2]) + x = ir.AffineMap.get(1, 1, [ir.AffineExpr.get_floor_div(ir.AffineDimExpr.get(0), + ir.AffineSymbolExpr.get(0))]) + y = ir.AffineMap.get(1, 1, [ir.AffineExpr.get_mod(ir.AffineDimExpr.get(0), + ir.AffineSymbolExpr.get(0))]) + return spad, x, y + + +def _transfer_read(vec_ty, source, indices, padding): + from mlir.dialects import vector + src_rank = len(ir.MemRefType(source.type).shape) + vec_rank = len(ir.VectorType(vec_ty).shape) + perm = ir.AffineMap.get_minor_identity(src_rank, vec_rank) + return vector.TransferReadOp(vec_ty, source, list(indices), perm, padding, + [False] * vec_rank).result + + +def _transfer_write(value, dest, indices): + from mlir.dialects import vector + dst_rank = len(ir.MemRefType(dest.type).shape) + vec_rank = len(ir.VectorType(value.type).shape) + perm = ir.AffineMap.get_minor_identity(dst_rank, vec_rank) + vector.TransferWriteOp(None, value, dest, list(indices), perm, [False] * vec_rank) + + +def _dma_wait(tag, idx, num_elements): + from mlir.dialects import memref + memref.DmaWaitOp(tag, [idx], num_elements) + + +def _vcix(name, operands, result_tys, attrs): + return ir.Operation.create(name, results=result_tys, operands=list(operands), + attributes=attrs) + + +def _reaches(value, target): + if value == target: + return True + owner = value.owner + if isinstance(owner, ir.Block): + return False + for operand in owner.operands: + if _reaches(operand, target): + return True + return False + + +class _DmaView: + """Positional view of a customized memref.dma_start (see lower_dma_to_gemmini).""" + + def __init__(self, op): + self.op = op + operands = list(op.operands) + src_rank = len(ir.MemRefType(operands[0].type).shape) + i = 0 + self.src = operands[i]; i += 1 + i += src_rank + self.dst = operands[i]; i += 1 + dst_rank = len(ir.MemRefType(self.dst.type).shape) + i += dst_rank + i += 1 # num_elements + self.tag = operands[i]; i += 1 + tag_rank = len(ir.MemRefType(self.tag.type).shape) + self.tag_idx = operands[i:i + tag_rank] + + def subtile_size(self): + a = self.op.attributes + if "subtile_size" not in a: + return [] + return [ir.IntegerAttr(x).value for x in ir.ArrayAttr(a["subtile_size"])] + + def is_async(self): + a = self.op.attributes + if "async" not in a: + return False + try: + return bool(ir.IntegerAttr(a["async"]).value) + except Exception: + return True + + +def _ceil_div(a, b): + return (a + b - 1) // b + + +def _lower_matmul(op, SS, vlen): + """Lower one linalg.matmul (gemm path) to the vcix push/compute/pop sequence. + Returns True if lowered, False if skipped (conv2d / unexpected nesting -> left + for the C++ pass / a later port). Mirrors MatmulOpLowering (gemm branch).""" + from mlir.dialects import arith + + A, B, C = op.operands[0], op.operands[1], op.operands[2] + mtA, mtB = ir.MemRefType(A.type), ir.MemRefType(B.type) + elt = mtA.element_type + M, K, N = mtA.shape[0], mtA.shape[1], mtB.shape[1] + # Mirror the C++ guard: a dimension > SS must be an exact multiple, else the + # N//SS / K//SS loop trip counts below silently drop the tail tile. + for _dim, _name in ((M, "M"), (N, "N"), (K, "K")): + if _dim > SS and _dim % SS != 0: + raise NotImplementedError( + f"matmul {_name}={_dim} must be a multiple of systolic size {SS} when > {SS}") + elen = _elt_bits(elt) + nr_element = vlen // elen + i64 = ir.IntegerType.get_signless(64) + def a64(v): return ir.IntegerAttr.get(i64, v) + + acc, outer, inner = _enclosing_loops(op) + is_conv2d = len(inner) == 4 + if not acc or len(outer) < 2: + return False + tile_kw = tile_oh = tile_ow = None + if is_conv2d: # inner = [k_h, k_w, o_h, o_w] + tile_kw, tile_oh, tile_ow = inner[1], inner[2], inner[3] + + vectorType = ir.VectorType.get([nr_element], elt) + nr_m = max(min(M, nr_element), 2) + vectorMType = ir.VectorType.get([nr_m], elt) + spad_map, spadX, spadY = _spad_maps() + + idxMap = [0, 1, 2] + if "idx_map" in op.attributes: + idxMap = list(ir.DenseI32ArrayAttr(op.attributes["idx_map"])) + + # Scan the outermost loop for the A/B/Bias load DMAs (tags + subtile). + ATag = BTag = BiasTag = None + AAsync = BAsync = BiasAsync = 0 + BiasIdx = None + subtileM, subtileN, subtileK = M, N, K + a_subk = b_subk = None + # Mirror the C++ isAInitialized / isBInitialized flags: an operand is + # "initialized" either by an MVIN dma_start (tag found below) or by a + # preceding affine.vector_store into its root memref (the fused case, e.g. + # SDPA scores.V where B is the softmax output produced in-place, not DMAed). + isAInit = isBInit = False + + def _root(v): + owner = v.owner + if not isinstance(owner, ir.Block): + nm = owner.name + if nm in ("memref.reinterpret_cast", "memref.cast"): + return owner.operands[0] + return v + rootA, rootB = _root(A), _root(B) + for o in _iter_ops(outer[-1].regions[0].blocks[0]): + if o.operation.name == "affine.vector_store": + dest = _root(o.operation.operands[1]) + if dest == rootA: + isAInit = True + elif dest == rootB: + isBInit = True + continue + if o.operation.name != "memref.dma_start": + continue + d = _DmaView(o.operation) + dram, is_write = _dram_is_write(d.src, d.dst) + if dram is None or is_write: + continue + sram = d.dst # MVIN: dst is the spad + if not any(_reaches(opnd, sram) for opnd in op.operands): + continue + if not isinstance(dram.owner, ir.Block): # must be a block argument + continue + argn = ir.BlockArgument(dram).arg_number + sub = d.subtile_size() + if argn == idxMap[0]: + ATag, AAsync = d.tag, d.is_async() + isAInit = True + if len(sub) >= 2: + subtileM, subtileK = sub[-2], sub[-1] + a_subk = sub[-1] + elif argn == idxMap[1]: + BTag, BAsync = d.tag, d.is_async() + isBInit = True + if len(sub) >= 2: + subtileK, subtileN = sub[-2], sub[-1] + b_subk = sub[-2] + elif argn == idxMap[2]: + BiasTag, BiasAsync = d.tag, d.is_async() + BiasIdx = d.tag_idx + if not isAInit or not isBInit: + return False + # A and B must agree on the K subtile (last-writer-wins would otherwise pick one silently). + if a_subk is not None and b_subk is not None and a_subk != b_subk: + raise NotImplementedError( + f"Mismatched subtile K between A ({a_subk}) and B ({b_subk}) matmul operands") + + KStep = subtileK + push_length = min(subtileM, SS) + MStep = min(M, push_length) + NStep = min(subtileN, SS) + M_LOOP = min(M, push_length) + + # conv2d builds inside the existing k_w loop; gemm builds at the matmul site. + ip = (ir.InsertionPoint.at_block_terminator(tile_kw.regions[0].blocks[0]) + if is_conv2d else ir.InsertionPoint(op)) + with ip: + c0 = _const_index(0) + rvl = arith.ConstantOp(i64, a64(nr_element)).result + K_val, N_val, M_val = _const_index(K), _const_index(N), _const_index(M) + push_val = _const_index(push_length) + num1 = _const_index(1) + zero_pad = arith.ConstantOp(elt, ir.FloatAttr.get(elt, 0.0)).result + + # --- inner N / K loops --- + from mlir.dialects import affine + body_ip = ip + n_idx = c0 + k_idx = c0 + nk_inner = None # innermost n/k loop created (conv2d hosts o_h/o_w here) + if N > SS: + with body_ip: + nl = affine.AffineForOp(0, N // SS, 1) + nl.operation.attributes["inner_loop"] = ir.BoolAttr.get(True) + n_idx = nl.induction_variable + with ir.InsertionPoint(nl.body): + affine.AffineYieldOp([]) + body_ip = ir.InsertionPoint.at_block_terminator(nl.body) + nk_inner = nl + zero_vector = None + if K > SS: + with body_ip: + kl = affine.AffineForOp(0, K // SS, 1) + kl.operation.attributes["inner_loop"] = ir.BoolAttr.get(True) + k_idx = kl.induction_variable + with ir.InsertionPoint(kl.body): + affine.AffineYieldOp([]) + body_ip = ir.InsertionPoint.at_block_terminator(kl.body) + nk_inner = kl + else: + with body_ip: + zv = ir.DenseElementsAttr.get_splat(vectorType, ir.FloatAttr.get(elt, 0.0)) + zero_vector = arith.ConstantOp(vectorType, zv).result + + n_tag = c0 if N == subtileN else n_idx + k_tag = c0 if K == subtileK else k_idx + + with body_ip: + # --- B dma_wait --- + nacc = len(acc) + acc_ivs = [_loop_iv(l) for l in acc] + bexpr = ir.AffineDimExpr.get(0) * -1 + for i in range(1, nacc): + bexpr = bexpr + ir.AffineDimExpr.get(i) * -1 + b_extra = [] + bdo = nacc + if is_conv2d: + kW = _loop_ub(tile_kw) + bdo = nacc + 2 + bexpr = (bexpr + + ir.AffineDimExpr.get(bdo - 2) * ((N // subtileN) * (K // subtileK) * kW) + + ir.AffineDimExpr.get(bdo - 1) * ((N // subtileN) * (K // subtileK))) + b_extra = [_loop_iv(inner[0]), _loop_iv(inner[1])] # k_h, k_w + bexpr = (bexpr + + ir.AffineExpr.get_floor_div(ir.AffineDimExpr.get(bdo), _ceil_div(NStep, SS)) * (K // KStep) + + ir.AffineExpr.get_floor_div(ir.AffineDimExpr.get(bdo + 1), _ceil_div(KStep, SS)) * 1) + bmap = ir.AffineMap.get(bdo + 2, 0, [bexpr]) + btag_idx = _apply(bmap, acc_ivs + b_extra + [n_tag, k_tag]) + if BAsync: + _dma_wait(BTag, btag_idx, num1) + + # --- weight push loop (K x N) --- + for i in range(0, SS, nr_element): + if i < K: + sp = _apply(spad_map, [n_idx, k_idx, _const_index(i), K_val, _const_index(SS)]) + wx = _apply(spadX, [sp, N_val]) + wy = _apply(spadY, [sp, N_val]) + wv = _transfer_read(vectorType, B, [wx, wy], zero_pad) + else: + wv = zero_vector + _vcix("vcix.iv", [wv, rvl], [], + {"opcode": a64(1), "imm": a64(0), "rd": a64(0)}) + + # conv2d: move the o_h/o_w spatial loops after the weight push and continue the + # input-push/compute/pop inside the o_w loop (heuristic, mirrors the C++ branch + # for the no-extra-inner-loop case). + if is_conv2d: + # host the o_h/o_w spatial loops inside the innermost n/k loop (so n_idx/k_idx + # stay in scope) or directly in the k_w loop when no n/k loop was created. + host = nk_inner if nk_inner is not None else tile_kw + tile_oh.operation.move_before(_block_terminator(host)) + body_ip = ir.InsertionPoint.at_block_terminator(tile_ow.regions[0].blocks[0]) + + # --- M loop --- + m_idx = c0 + if M > push_length: + with body_ip: + ml = affine.AffineForOp(0, M // push_length, 1) + ml.operation.attributes["inner_loop"] = ir.BoolAttr.get(True) + m_idx = ml.induction_variable + with ir.InsertionPoint(ml.body): + affine.AffineYieldOp([]) + body_ip = ir.InsertionPoint.at_block_terminator(ml.body) + m_tag = c0 if M == subtileM else m_idx + + with body_ip: + # --- A dma_wait --- + aexpr = ir.AffineDimExpr.get(0) * -1 + for i in range(1, nacc): + aexpr = aexpr + ir.AffineDimExpr.get(i) * -1 + a_extra = [] + ado = nacc + if is_conv2d: + k_h, k_w, o_h, o_w = (_loop_iv(inner[j]) for j in range(4)) + kW, oW = _loop_ub(tile_kw), _loop_ub(tile_ow) + offset_h, offset_w = _scan_conv_offsets(tile_ow, o_h, k_h, o_w, k_w) + coeff_h = 1 + (oW - 1) * offset_w + (kW - 1) + ado = nacc + 2 + aexpr = (aexpr + + ir.AffineDimExpr.get(ado - 2) * ((K // subtileK) * (M // subtileM) * offset_h * coeff_h) + + ir.AffineDimExpr.get(ado - 1) * ((K // subtileK) * (M // subtileM) * offset_w)) + a_extra = [o_h, o_w] + aexpr = (aexpr + + ir.AffineDimExpr.get(ado) * (M // MStep) + + ir.AffineExpr.get_floor_div(ir.AffineDimExpr.get(ado + 1), _ceil_div(MStep, SS))) + amap = ir.AffineMap.get(ado + 2, 0, [aexpr]) + atag_idx = _apply(amap, acc_ivs + a_extra + [k_tag, m_tag]) + if AAsync: + _dma_wait(ATag, atag_idx, num1) + + # --- Bias dma_wait --- + if BiasTag is not None: + bias_is_const = BiasIdx and BiasIdx[0].owner.name == "arith.constant" + first_i = c0 if bias_is_const else n_tag + third_i = c0 if bias_is_const else m_tag + d0, d1 = ir.AffineDimExpr.get(0), ir.AffineDimExpr.get(1) + bias_expr = (ir.AffineExpr.get_floor_div(d0, _ceil_div(NStep, SS)) * (M // MStep) + + ir.AffineExpr.get_floor_div(d1, _ceil_div(MStep, SS))) + bias_map = ir.AffineMap.get(2, 0, [bias_expr]) + bias_tag_idx = _apply(bias_map, [first_i, third_i]) + if BiasAsync: + _dma_wait(BiasTag, bias_tag_idx, num1) + + # --- input push loop (M x K) --- + for i in range(0, M_LOOP, nr_element): + sp = _apply(spad_map, [k_idx, m_idx, _const_index(i), M_val, push_val]) + x = _apply(spadX, [sp, K_val]) + y = _apply(spadY, [sp, K_val]) + iv = _transfer_read(vectorMType, A, [x, y], zero_pad) + _vcix("vcix.iv", [iv, rvl], [], + {"opcode": a64(0), "imm": a64(0), "rd": a64(0)}) + + # --- compute --- + _vcix("vcix.i", [rvl], [], + {"opcode": a64(1), "imm": a64(4), "rd": a64(0), "rs2": a64(0), + "sew": a64(elen), "lmul": a64(0)}) + + # --- pop loop (M x N) --- + for i in range(0, M_LOOP, nr_element): + sp = _apply(spad_map, [n_idx, m_idx, _const_index(i), M_val, push_val]) + vpop = _vcix("vcix.v.i", [rvl], [vectorMType], + {"opcode": a64(2), "imm": a64(0), "rs2": a64(0)}).results[0] + x = _apply(spadX, [sp, N_val]) + y = _apply(spadY, [sp, N_val]) + prev = _transfer_read(vectorMType, C, [x, y], zero_pad) + if ir.IntegerType.isinstance(elt): + out = arith.AddIOp(prev, vpop).result + else: + out = arith.AddFOp(prev, vpop).result + _transfer_write(out, C, [x, y]) + op.erase() + return True + + +def run(module, vectorlane=128, vlen=128, **_): + """Lower linalg.matmul (gemm) + transcendental math ops to VCIX ops, in place.""" + # matmul first (uses surrounding loop structure before any rewrites) + mms = [] + for region in module.operation.regions: + for b in region.blocks: + for o in _iter_ops(b): + if o.operation.name == "linalg.matmul": + mms.append(o.operation) + for o in mms: + _lower_matmul(o, vectorlane, vlen) + targets = [] + for region in module.operation.regions: + for b in region.blocks: + for op in _iter_ops(b): + if op.operation.name in _MATH_VIV: + targets.append(op.operation) + for op in targets: + opcode, imm = _MATH_VIV[op.name] + vec = op.operands[0] + res_ty = op.results[0].type + vt = ir.VectorType(res_ty) + n, legal_ty = _legalize_vector_type(vt, vlen) + if legal_ty is None: + continue + with ir.InsertionPoint(op): + new = _make_sf_vc_v_iv(vec, vt, n, legal_ty, opcode, imm) + op.results[0].replace_all_uses_with(new) + op.erase() + + +def run_to_vcix(in_path, out_path, vectorlane=128, vlen=128): + with open(in_path) as f: + text = f.read() + ctx = ir.Context() + ctx.allow_unregistered_dialects = True + with ctx, ir.Location.unknown(): + module = ir.Module.parse(text) + run(module, vectorlane=vectorlane, vlen=vlen) + out = str(module) + with open(out_path, "w") as f: + f.write(out) + + +if __name__ == "__main__": + vl = int(sys.argv[3]) if len(sys.argv) > 3 else 128 + vlen_ = int(sys.argv[4]) if len(sys.argv) > 4 else 128 + run_to_vcix(sys.argv[1], sys.argv[2], vl, vlen_) diff --git a/PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py b/PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py new file mode 100644 index 00000000..76e30cb3 --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py @@ -0,0 +1,92 @@ +"""Python out-of-line MLIR pass: lower torchsim.vlane_idx -> per-lane index * offset. + +Codegen emits a dedicated `torchsim.vlane_idx` op (generic form, unregistered +dialect) carrying a `vlane_offset` integer attribute. This pass rewrites each +such op to: + + %v = "vcix.v.i"(%K) {opcode = 0, rs2 = 0, imm = 0} : (i64) -> vector // per-lane index + %n = arith.constant dense : vector + %r = arith.muli %v, %n : vector + +and replaces uses of the original op with %r. Replaces the former C++ +`-global-idx` pass (which overloaded arith.addi with a vlane_offset attribute). + +Pass interface (see passes/__init__.py): MARKERS + run(module). Also runnable +standalone as a CLI: + python PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py in.mlir [out.mlir] + +Requires the MLIR Python bindings on PYTHONPATH +(/riscv-llvm/python_packages/mlir_core). The `vcix` dialect must be registered +in the consuming mlir-opt for the result to round-trip (see +registerVCIXDialectTranslation in mlir-opt.cpp). +""" + +OP_NAME = "torchsim.vlane_idx" +MARKERS = (OP_NAME,) + + +def _iter_ops(block): + for op in list(block.operations): + yield op + for region in op.operation.regions: + for b in region.blocks: + yield from _iter_ops(b) + + +def run(module, **_): + """Lower every torchsim.vlane_idx op in `module`, in place. + + Must be called with the module's Context active (the orchestrator provides it). + """ + from mlir.ir import (InsertionPoint, Operation, IntegerType, IntegerAttr, + DenseElementsAttr, VectorType) + i64 = IntegerType.get_signless(64) + i32 = IntegerType.get_signless(32) + + targets = [] + for region in module.operation.regions: + for b in region.blocks: + for op in _iter_ops(b): + if op.operation.name == OP_NAME: + targets.append(op.operation) + + for op in targets: + res = op.results[0] + vt = VectorType(res.type) + k, et = vt.shape[0], vt.element_type + offset = IntegerAttr(op.attributes["vlane_offset"]).value + with InsertionPoint(op): + rvl = Operation.create("arith.constant", results=[i64], + attributes={"value": IntegerAttr.get(i64, k)}).results[0] + lane = Operation.create("vcix.v.i", results=[vt], operands=[rvl], + attributes={"opcode": IntegerAttr.get(i64, 0), + "rs2": IntegerAttr.get(i32, 0), + "imm": IntegerAttr.get(i32, 0)}).results[0] + ovec = Operation.create("arith.constant", results=[vt], + attributes={"value": DenseElementsAttr.get_splat( + vt, IntegerAttr.get(et, offset))}).results[0] + mul = Operation.create("arith.muli", results=[vt], operands=[lane, ovec]).results[0] + res.replace_all_uses_with(mul) + op.erase() + + +def lower_text(text: str) -> str: + """Parse `text`, run this pass, return the printed module. CLI/testing helper.""" + if OP_NAME not in text: + return text + from mlir.ir import Context, Module, Location + ctx = Context() + ctx.allow_unregistered_dialects = True + with ctx, Location.unknown(): + m = Module.parse(text) + run(m) + return str(m) + + +if __name__ == "__main__": + import sys + out = lower_text(open(sys.argv[1]).read()) + if len(sys.argv) > 2: + open(sys.argv[2], "w").write(out) + else: + sys.stdout.write(out) diff --git a/docs/axis-split-scheduling.md b/docs/axis-split-scheduling.md new file mode 100644 index 00000000..48e7db2f --- /dev/null +++ b/docs/axis-split-scheduling.md @@ -0,0 +1,217 @@ +# Aligned axis splitting at the Inductor scheduling layer + +Status: **prototype / proposed**. Companion to `dma-transfer-lowering.md`. This doc +covers the *upstream* half of the affine-only contract: removing aligned +`FloorDiv` / `ModularIndexing` from index expressions before they reach MLIR +codegen, by splitting loop axes at the Inductor scheduling layer. + +## Goal: the affine-only contract + +We want the MLIR codegen (`get_dma_info` in `mlir_codegen_backend.py`) to receive +only per-axis affine index expressions: + + off(i,j,k,...) = base + Sum_k stride_k * loop_var_k (stride_k constant int) + +with **zero** `FloorDiv` / `ModularIndexing`. If that invariant holds, codegen no +longer fights non-affine indices: the recompile dance (RecompileSignal, forced +tile sizes, max_retry_compile), the heuristic `TestLoopPadding` pass, and the +hard-fail-on-conflict path all become unnecessary. Codegen's only remaining job +is the *mechanical* rank<=4 peel for the Gemmini descriptor (orthogonal; see +`dma-transfer-lowering.md`), which operates on already-affine input. + +Two tools produce this invariant, matching the alignment theory: + +- **aligned floor/mod -> axis split** (this doc): loop transformation, free, no + data movement. +- **misaligned floor/mod -> graph copy insertion** (XLA-style): genuine data + movement; out of scope here. + +"Aligned" means the floor/mod argument is a *single* iteration variable `v` of +extent `E` and the divisor `k` (resp. `k*m` for ModularIndexing) divides `E`, so +splitting `v = outer*k + inner` lands the wrap point on a fixed axis boundary. + +## Where: the scheduling layer already rebuilds LoopBody + +`mlir_scheduling.py` already does loop-IR surgery at the scheduling layer: + +- `revert_group` (line ~219) rebuilds a `LoopBody` from `get_store_function()` + with a chosen `var_ranges` -- it undoes Inductor's `simplify_and_reorder`. +- `codegen_node` (line ~246) injects dummy size-1 loops when Inductor + over-simplified the group. + +Axis splitting is the same operation with a different `var_ranges`: split the +axes carrying aligned floor/mod, then rebuild. No new infrastructure -- reuse +`LoopBody`. This is "upstream" of MLIR codegen and native to Inductor's IR (sympy +ranges + index exprs), so we are not reverse-engineering MLIR text. + +## How: detect / rebuild / hook + +Implemented in `PyTorchSimFrontend/mlir/axis_split.py`, wired into +`codegen_node` behind `TORCHSIM_AXIS_SPLIT=1` (dump with +`TORCHSIM_DEBUG_AXIS_SPLIT=1`). + +1. **Detect -- `find_split_plan(nodes)`**: scan each node's + `_body.indexing_exprs` for `FloorDiv(v, k)` / `ModularIndexing(v, k, m)` where + `v` is a single iter var and the divisor divides `v`'s extent. Return + `{axis_index: divisor}`, keyed positionally so it applies to every fused node + sharing the iteration space. +2. **Rebuild -- `build_split_body(node, plan)`**: rebuild `node._body` / + `_sizes` with the split var_ranges; feed the store function the index + expression `outer*k + inner` at the split dim so the floor/mod collapses. +3. **Hook -- `codegen_node`**: apply the plan to every node + (`_sizes, _body, group = ...`), then recompute the group. + +## Empirical validation (group norm) + +`group_norm(x[2,6,4,4], num_groups=3)` normalize kernel, before vs after split: + + before var_ranges={p0:2, p1:6, p2:16} + idx0 = 96*p0 + 16*p1 + p2 # x input, affine + idx1 = 3*p0 + (p1//2) # mean/rstd <- FloorDiv(p1,2), 2|6 aligned + idx2 = p1 # weight/bias, affine + + after plan={1: 2}, var_ranges={s0:2, s1:3, s2:2, ...} + idx1 = 3*s0 + (s1//1) # FloorDiv collapsed to identity -> s1 + ... # mean now affine; s2/spatial broadcast (stride 0) + +The FloorDiv is eliminated. group `(2,6,16) -> (2,3,2,...)`. + +## Coverage (what this framework can and cannot do) + +| Case | Example | Status | +|---|---|---| +| aligned FloorDiv, single var | group norm `c//2` (2\|6) | DONE (prototype) | +| aligned ModularIndexing | `(v//k)%m`, k*m\|E | needs mixed-radix multi-split | +| multiple radices on one axis | `//2` + `%3`, E=6 | needs nested split (now: first divisor only) | +| reduction-axis floor/mod | `r//k` inside reduce | needs reduction-var splitting | +| divisor does not divide extent | C=8 groups of 3; uneven cat | IMPOSSIBLE by split -> graph copy | +| multi-axis argument | `(4p+q)//6` non-factor reshape | IMPOSSIBLE by split -> graph copy | +| dynamic / symbolic | `v//s`, symbolic extent | separate symbolic/guard path | + +The aligned class is the framework's domain (currently only single-split +FloorDiv); the misaligned class is structurally a graph-copy problem. + +## Resolved + +- **5D blow-up (fixed).** `build_split_body` now reindexes the already-collapsed + `node._body` via `LoopBody`'s copy path (pass the body as `fn` -> + `_init_with_copy`), instead of re-tracing the raw store function over + `inode.data.get_size()`. This keeps merged dims merged (spatial stays `16`), + so group_norm goes `(2,6,16) -> (2,3,2,16)` (4D, no cap hit), and + `_init_with_copy`'s `simplify_with_ranges` folds the split floor. +- **`floor//1` residue (fixed).** The fold only happened once the new symbols + carried integer/non-negative assumptions: build them with + `torch._inductor.utils.sympy_index_symbol` (not bare `sympy.Symbol`), which is + also why the index prefix must not be `s` (reserved for shape symbols). With + this, `idx1 = 3*p0 + (p1//2)` becomes `3*z0 + z1` -- the channel FloorDiv is + gone, not left as `z1//1`. +- **Symbol conventions.** Index dims use the `z` prefix; reduction dims use the + `r` prefix and are kept after the index dims so the reduction axis stays + innermost (`var_ranges` is ordered iter-then-reduce; `LoopBody.sizes` splits on + `len(iter_vars)`). LoopBody var names are remapped to `index` during MLIR + codegen, so the prefix is internal -- but it must not collide with the original + body's names (those are `p`/`q`, so `z`/`r` are safe). + +## Resolved (cont.) + +- **`floor//1` / residual floor on multi-level split (fixed).** `simplify_with_ranges` + cannot prove a *multi-term* numerator is below the divisor (e.g. + `FloorDiv(z1 + 4*z2, 12)` with `z1<4, z2<3`), so a 3-level mixed-radix split left + a residual floor that codegen rejected ("Not supporting this view operation"). + `_fold_with_ranges` now proves it directly from the split sub-var ranges via + `bound_sympy`: `FloorDiv(num,d)->0` when `0<=numnum//k` + when `0<=num 5D). This now lowers through the decompose-transfer + >4D peel (affine.for nest, one <=4D dma_start per iteration), so the earlier + `find_split_plan` rank guard was removed (a6b7ebb9): plans are no longer dropped to + baseline for exceeding rank 4. pixel_shuffle splits to >4D and passes end-to-end + (Gem5+Spike+TOGSim); 3D group_norm still splits (rank 4). + +## Known issues / open + +- None open here. The decompose-transfer peel <-> TOG incompatibility (>4D peel + unreadable by TOG) was resolved by rewriting the peel as an affine.for nest -- see + Done below. + +## Done + +- **>4D peel via affine.for (fixed)** -- a6b7ebb9. The earlier peel emitted + `memref.subview` + unrolled constant-offset `dma_start` that the TOG pass could not + read (empty `loop_idx_list`) and that aliased one spad slot + (extract_aligned_pointer_as_index strips the subview offset -> pixel_shuffle + MISMATCH). Rewritten to wrap the outer dims in an `affine.for` nest (marked + inner_loop so build_tog registers the induction var), with the lane-banked physical + SRAM offset carried as the last SRAM index operand and the DRAM offset folded into + one affine.apply (#258). The axis-split rank guard was removed; pixel_shuffle passes + end-to-end. +- **Mixed-radix (ModularIndexing + multi-radix)**: `find_split_plan` returns a + per-axis divisibility-chain of boundaries; `build_split_body` splits into one + sub-var per segment (`v = sum_i d_i*b_i`). Validated allclose=True on group_norm + (FloorDiv, `[1,2,6]`) and `x.repeat(1,2)` (single-axis ModularIndexing, + `[1,8,16]`); pixel_shuffle (floor+mod on two axes) linearizes correctly. +- **Reduction pass-through**: reduction dims keep the `r` prefix and stay innermost + (after the index dims). Exercised via the `TORCHSIM_AXIS_SPLIT_FORCE` validation + gate (force-split a reduction kernel's index axis even without floor -- an + identity transform, so allclose must hold): layernorm `(512)->(256,2)` and + reduce `(68)->(34,2)` keep their reduction groups and pass. +- **Graph-copy for incompatible radices (case 5)** -- `graph_copy.py`, + `TORCHSIM_GRAPH_COPY`. When two operands of an elementwise consumer carry + incompatible-radix groupings on a shared axis (e.g. `a[c//2] + b[c%3]`, floor-by-2 + vs mod-by-3 on extent 6 -- not a divisibility chain), neither axis-split nor the + recompile-dance can express it. We wrap the registered lowering entries (the + make_pointwise results = every elementwise consumer, one place), trace each + operand's loader with `extract_read_writes` to get its read indices, run the same + `collect_boundaries` analysis, and if the union is not a chain, `realize()` the + cheaper operand. realize() (not clone -- Inductor inlines clone, confirmed) forces + a buffer: the consumer then reads it affine and the remaining single grouping is + handled by axis-split. Validated: `incompat` (`a.repeat_interleave(2)+b.repeat(2)`) + goes ERR -> allclose=True with `GRAPH_COPY+AXIS_SPLIT` (still ERR on default, + confirming graph-copy is the fix); no regression on the pattern battery, + test_add, resnet (compile overhead negligible). +- **Graph-copy for cross-axis floor/mod (case 7)** -- same hook. A transpose+reshape + feeding a consumer that keeps the output dims separate (broadcast / softmax / + layernorm / reduce-one-dim) produces a floor/mod whose argument spans *two* loop + vars, e.g. `(3*p0+p1)//4`; axis-split cannot split a multi-var argument. We detect + an operand whose read index has a floor/mod argument with >1 free symbol and + replace it with `ExternKernel.copy_input` (a realized identity Pointwise). This is + why copy_input and not `realize()`: `StorageBox.realize()` is a no-op on a + ReinterpretView (a reshape), so it does not materialize view operands; copy_input + forces the copy. The copy kernel iterates the operand's own contiguous shape, so + its index collapses to single-var for axis-split, and the consumer reads the copy + affine. Also covers single-operand consumers (a reduction reading a multi-var + view). Validated allclose=True: reshape+broadcast, softmax(reshape), + layernorm(reshape) (all ERR on default). NOTE the empirical correction: case 7 is + NOT rare -- it is the common attention/norm "reshape then reduce/broadcast" + shape; Inductor only avoids it when it can collapse the output to 1D (then the + floor is single-var). + +## Default-on + recompile-dance status + +axis-split and graph-copy are **ON by default** (disable with `TORCHSIM_AXIS_SPLIT=0` +/ `TORCHSIM_GRAPH_COPY=0`). With them on, the codegen recompile-dance (tile-forcing +for floor/mod divisibility) is demoted from primary mechanism to a rarely-hit +fallback. + +Measured under default-on (`TORCHSIM_RECOMPILE_LOG=1`), 33 tests, all pass: +- 16 core (elementwise/gemm/reduce/conv/view/fusion + mlp/resnet/transformer/vit): 0 recompiles. +- 7 broader families (cnn/pool/group_conv/sort/indirect_access/exponent/conv_fusion): 0 recompiles. +- 10 floor/mod patterns: 1 recompile total (an unrelated tile-divisibility in the + 3-level mixed-radix case). + +**Full retirement of the dance is deferred** (it is still a real dependency, not +just a safety net): removing the floor/mod recompile branches would break the +3-level mixed-radix case (1 recompile) and any case axis-split/graph-copy do not +yet cover (case 6: non-dividing divisor / uneven cat; reduction-axis floor/mod). +attention/sdpa families were not run here (too slow locally) and need CI validation +before retirement. + +## Next steps + +1. Eliminate the last recompile dependency (the 3-level mixed-radix sub-kernel) so + the dance reaches 0/all -> then retire the floor/mod recompile branches (keep the + non-floor/mod ones: non-power-of-2 vec size, indirect). +2. Graph-copy coverage: case 6 (non-dividing divisor / uneven cat -> pad or gather), + and conflicts internal to templates (gemm/conv/sdpa). +3. Reduction-axis floor/mod (`r//k` inside a reduce): needs reduction-var splitting. +4. Dynamic shapes -> symbolic divisibility / guards. diff --git a/docs/dma-transfer-lowering.md b/docs/dma-transfer-lowering.md new file mode 100644 index 00000000..1ab9be45 --- /dev/null +++ b/docs/dma-transfer-lowering.md @@ -0,0 +1,484 @@ +# DMA transfer op + decomposition lowering + +Status: **design / proposed**. Captures the plan to fix the recompile-dance +fragility by representing DMA as a high-level declarative transfer op and +decomposing it into affine descriptors in a lowering pass. + +Companion docs: `linalg-codegen-migration.md` (Plan B, the full structured-ops +rewrite this is a narrow tactical slice of). The near-term graph-level padding +work is referred to here as Plan A. + +## TL;DR + +The MLIR codegen forces tile sizes so that non-affine index expressions +(`FloorDiv` / `ModularIndexing`, produced by view/reshape/cat) collapse into the +DMA's strictly-affine 4D integer-stride address model. That forcing is a lazy, +greedy, monotonic, restart-based search (the "recompile dance") capped at 5 +retries; when operand constraints conflict or exceed 4D it hard-fails and the +model does not compile. This blocks model coverage, which is the primary goal. + +Proposed fix: stop forcing one affine descriptor. Introduce a **high-level +`togsim.transfer` op** that carries an iteration domain plus `iter->src` / +`iter->dst` affine maps (which may legally contain floordiv/mod), and a +**decomposition pass** that lowers it to a loop of the **existing customized +`memref.dma_start` descriptors** (kept unchanged as the leaf). The non-affine / +high-rank part is peeled into a base-pointer loop instead of being crammed into +one descriptor. This inverts the tile<->DMA dependency (the DMA adapts to the +tile, not the reverse), removes the rank cap, and removes the recompile dance. + +## Problem + +### Root cause: an impedance mismatch + +The DMA address model is `base + sum_i stride_i * idx_i`, with **integer strides, +4D**, i.e. strictly affine/linear. Inductor's index expressions are not: views, +reshapes, `cat`, and broadcasts introduce `FloorDiv` / `ModularIndexing`, which +are non-affine. The codegen copes by searching for a tiling under which the +floor/mod collapses to a linear stride within a tile -- that is exactly what the +ModularIndexing tile constraints ("tile must be a multiple of the floordiv +divisor and a divisor of the modular divisor") encode. + +### The recompile dance (where it breaks) + +`codegen_nodes` (`mlir_common.py`) is a `while True` loop, `max_retry_compile = 5`. +During emission, `get_dma_info` (`mlir_codegen_backend.py`) inspects the index: + +- the **split path** (good): `apply_divisor(axis, divisor, "split")` peels an axis + into two affine dims to represent floor/mod, inserting a `0` into `dram_stride`; +- the **pad path** (fragile): when the tile is not divisible it mutates the tile + (`set_tile_size`, `tile_constraint.fixed = True`) and raises `RecompileSignal` + to restart emission. + +It breaks because: + +1. **One global tile must satisfy every operand's divisibility** on a shared axis. + Fused ops with conflicting constraints (common with reshape/modular indexing) + cannot be satisfied at once -- this is the loop<->tensor mismatch. +2. **Greedy + monotonic + no backtracking + 5-retry cap.** `tile_constraint.fixed` + persists across retries (the tile descriptor lives on `kernel_group`, survives + `reset`), so the search ratchets one way; conflicting fixes oscillate and hit + the cap -> `RuntimeError("Failed to compile kernel after multiple attempts")`. +3. **4D rank cap.** A reshape needing more than 4 affine dims after splitting + raises `NotImplementedError`. +4. **vlane / LMUL entanglement.** Pad-forcing moves `vlane_split_axis` / relaxes + `vlane_stride`, and `compute_vec` must be a power of two; these can be mutually + unsatisfiable with the divisibility constraints. + +Padding logic is currently spread across three places: the Python recompile/tile +-adjust dance (1), the Python `get_mask` vector-tail handling (2), and the MLIR +`TestLoopPadding` pass (3). Removing (3) alone does not fix the fragility; (1) is +the larger source. + +### We already have a de-facto custom DMA op + +`get_dma_code` emits `memref.dma_start` overloaded via string formatting with +extra operands (`dma_type` MVIN/MVOUT, tag, `vlane_split_axis`, `vlane_stride`) +and extra attributes (`dram_stride`, `tile_stride`, padding type). This is a +custom descriptor in all but name -- and it is what Spike / gem5 / TOGSim already +consume. + +## Proposed design + +Two op levels, with a pass bridging them. + +``` +[high] togsim.transfer iteration domain + iter->src / iter->dst affine maps + (maps MAY contain floordiv/mod; rank unbounded) + | decompose-transfer pass (cost-aware peel) + v +[low] scf.for { customized memref.dma_start } <- existing leaf, UNCHANGED +``` + +### Low-level descriptor (keep as-is) + +The existing customized `memref.dma_start` is the lowering target / leaf: affine, +4D, integer stride, simulator-understood. **Do not add maps or floor/mod to it** -- +that would re-create the representational limit and blur the boundary. Optionally +formalize it into a real op (`togsim.dma_descriptor`) with a verifier so the pass +rewrites real ops instead of strings; not required to start. + +### High-level transfer op (new) + +Strawman: + +```mlir +togsim.transfer + ins(%src : memref) // DRAM + outs(%dst : memref<...xf16, 1>) // scratchpad + iter_bounds = [%M, %N, %K] // iteration domain (dynamic via SSA operands) + attributes { + src_map = affine_map<(m,n,k)[s0] -> (m, (n floordiv s0), (n mod s0), k)>, // non-affine lives here + dst_map = affine_map<(m,n,k) -> (m, n, k)>, + vlane_split_axis = 1, vlane_stride = 4, + dma_kind = "MVIN", tag_policy = "async", + peel_plan = [0] // optional: which iter dims to peel (decided in Python; see below) + } +``` + +Design choices: + +- **Iteration domain + two maps, not a single src->dst map.** A direct src->dst + relation only works for bijections; broadcast (`cat([a, a])`) and non-bijective + access need the loop-mediated form. This is the `linalg.generic` model, and peel + becomes "tile the iteration domain." +- **floor/mod ride in the `AffineMap`.** MLIR `AffineMap` supports constant-divisor + `floordiv`/`mod`/`ceildiv` natively; the codegen already produces these as + strings in `convert_index`. Symbolic divisors are semi-affine -- representable, + handled by our pass. +- **`memref`, not raw pointers.** The memref carries base + shape + layout so the + pass can reason about strides; `src_ptr`/`dst_ptr` are inside it. Buffer shape is + the memref type; the transfer region is `iter_bounds` + maps. +- **Closest existing op is `linalg.generic` / `linalg.copy`** (same shape + maps + + body structure) but it lacks DMA/vlane/tag/scratchpad semantics and its tiling + lowers to subview+scf, not our descriptors -- so a custom op modeled on linalg's + design, reusing AffineMap utilities. + +### Decomposition pass (contract): aligned-only mechanical peel + +> **Scope decision (narrowed).** This pass is a **pure mechanical rank peel** of an +> already-affine access. It does **not** linearize floor/mod and does **not** do +> relayout. Those two responsibilities moved upstream (see "Division of labor" +> below): aligned floor/mod is removed by **axis splitting at the Inductor +> scheduling layer** (`axis-split-scheduling.md`), and misaligned access is +> resolved by **graph-level copy insertion**. So every `togsim.transfer` that +> reaches this pass is guaranteed per-axis affine; the only thing left is that its +> rank may exceed the 4D Gemmini descriptor. + +The DMA descriptor is an **affine map of rank <= 4 with integer strides** +(`base + sum_i stride_i * idx_i`). The pass sees affine input (rank `D`) and: + +1. **`D <= 4`** -> emit **one** customized `memref.dma_start`; the dims become the + descriptor's <=4D shape/strides. Identical to today's output (fast path). +2. **`D > 4`** -> peel `D - 4` dims into an outer `affine.for` (marked `inner_loop` + so the TOG pass reads the induction var); each iteration computes the DRAM base + with one `affine.apply` (the peeled dims' linear contribution folded with the + original index) and the **lane-banked physical** SRAM offset (dims outer than the + vlane axis rescaled by the lane coeff -- the MVIN `block_stride` / + `-dma-fine-grained` `buildSramAffineMap` rule, which needs the vector-lane count), + delivered as the **last SRAM index operand**. The offset must go through the index, + not a subview offset: the gemmini lowering reads the spad base via + `extract_aligned_pointer_as_index`, which strips a subview offset. + +That is the whole pass. There is **no linearization step** (upstream guarantees +affine) and **no relayout fallback** (upstream graph copy handles misalignment). + +**Fail loud, not silent.** If the pass encounters floor/mod that does not reduce to +per-axis affine (misaligned), or a genuinely non-affine / indirect / gather index, +that is a **contract violation** -- upstream did not normalize it. The pass +**asserts/errors** rather than silently inserting a relayout. A silent in-pass copy +would be a hidden performance cliff and would duplicate, at the wrong layer, a +global layout decision only the graph can make correctly. + +The decision point maps onto existing code: `get_dma_info` already raises at >4D. +That exact site becomes "emit `togsim.transfer`" (done, Phase 1), and this pass +consumes it. The recompile/tile-forcing dance is unnecessary because (a) aligned +floor/mod is gone before codegen and (b) the outer peel loop's `ceil` bound absorbs +non-divisible remainders. + +### Division of labor (the affine-only contract) + +| floor/mod source | handled by | cost | layer | +|---|---|---|---| +| aligned (single axis, divisor \| extent; group norm, broadcast) | axis split | free | Inductor scheduling | +| misaligned (uneven cat, non-factor reshape, multi-axis arg) | copy insertion | copy | FX graph | +| affine but rank > 4 (e.g. 5D permute) | mechanical peel | free | **this pass** | +| data-dependent / indirect / gather | indirect-indexing path | -- | out of scope | + +Only the third row is this pass. The first two produce the affine-only invariant +this pass relies on. + +### Relationship to memref-to-gemmini (ISA lowering) -- keep separate + +`memref.dma_start` is the boundary, not the endpoint. The layering is: + + togsim.transfer --[Python decompose]--> memref.dma_start --[Python lower_dma_to_gemmini]--> Gemmini ISA + +decompose-transfer stops at `memref.dma_start` and must **not** emit Gemmini +instructions directly; the ISA encoding is a separate pass +(`passes/lower_dma_to_gemmini.py`, which replaced the C++ test-memref-to-gemmini). +Rationale: + +- **Separation of concerns**: decompose does descriptor decomposition (affine + algebra: rank / peel); gemmini does instruction encoding (hardware). Different + axes; merging couples affine logic with ISA detail. They stay distinct passes. +- **`memref.dma_start` is a shared contract** with multiple consumers + (`lower_dma_to_gemmini`, `dma_fine_grained`, `build_tog` -- all now Python + out-of-line passes; the C++ `-dma-fine-grained` / `-test-tile-operation-graph` + are ported). Keeping it as the interface lets all of them stay unchanged. +- **gemmini is now a Python out-of-line pass too** -- the conversion-framework + coupling (LLVMTypeConverter / getStridedElementPtr) was avoided by working at + the memref level (`memref.extract_aligned_pointer_as_index` + arith for + addresses, `llvm.inline_asm` for instructions; the existing standard lowering + finalizes to LLVM). So both decompose and gemmini live in Python; mlir-opt keeps + only the remaining custom passes. + +One constraint flows the other way: gemmini's ISA limits (max dims / size per MVIN) +set decompose's target inner-descriptor shape (the "<=4D" and max-extent bounds). +decompose must *respect* those limits when it picks what stays inner vs gets peeled +-- but respecting a constraint is not doing the lowering. + +### Cost-aware peeling (this is a cycle-accurate simulator) + +Descriptor count is a **modeled cost** (issue overhead + DRAM burst efficiency in +Ramulator). Rules: + +1. Peel the **outermost, lowest-trip-count** dims (descriptor count = product of + peeled extents). +2. Keep the inner descriptor **as large and contiguous as possible** (maximize + bytes per descriptor). + +(A pathological peel is not this pass's problem to fix: it means the operand's +layout is bad, which is a graph-level layout/copy decision, not an in-pass +relayout.) + +### Placement: hybrid (least burden) + +Keep the decision in Python (where shape/sympy info is available and iteration is +fast); keep the C++ pass purely mechanical. + +| Step | Where | +|---|---| +| peel-plan decision (which dims to peel, count estimate) | Python | +| encode plan as op attributes | Python -> MLIR | +| emit `scf.for { customized dma_start }` per the plan | C++ pass | + +The cost model can migrate into C++ later if desired. + +## Expected effects + +- **Removes recompile-dance hard-fails.** The `max_retry` `RuntimeError` path + disappears: access that does not linearize is peeled, not retried-then-killed. + This directly increases model coverage (the primary goal). +- **Removes the 4D rank cap.** Arbitrary-rank reshapes become expressible via the + base-pointer loop; the `NotImplementedError` for >4D goes away. +- **Inverts the tile<->DMA dependency.** Tile size is chosen for compute / vlane + efficiency only; the DMA conforms to whatever access results. No divisibility + forcing, no oscillation. Tile selection simplifies. +- **Shrinks the codegen.** The `FloorDiv` / `ModularIndexing` recompile branches in + `get_dma_info`, and the in-emission `RecompileSignal` paths, leave Python; the + codegen emits one declarative op instead of procedurally forcing tiles. +- **Collapses two of the three padding sites for the DMA case.** Once divisibility + is no longer required to represent access, the Python tile-adjust dance (1) is + unnecessary for DMA, and `get_mask` (2) shrinks. `TestLoopPadding` (3) is + addressed by Plan A. (Compute-side vectorization remainder is separate; see + Plan A.) +- **Behavior-preserving for the common case.** Access without floor/mod still emits + a single `dma_start` identical to today -> low-risk, incremental rollout. +- **Preserves the simulator contract.** The leaf is the existing customized + `dma_start`; Spike / gem5 / TOGSim see the same descriptor kind, just more of + them in a loop. +- **A clean tactical slice toward Plan B.** This factors out exactly the one piece + that is actually broken (DMA decomposition) into a lowering pass, without the + full linalg rewrite. +- **Cost-aware, so modeled performance is protected.** Peel small/outer, keep inner + contiguous. Pathological layouts are fixed upstream (graph copy), not by an + in-pass relayout. + +## Migration strategy + +1. Define `togsim.transfer` (op + verifier) above the existing descriptor. + Optionally formalize the descriptor as `togsim.dma_descriptor`. +2. Make the codegen emit `togsim.transfer` for loads/stores, carrying the access + maps and vlane attributes it already computes. +3. Implement `decompose-transfer` with the fast path first (<=4D affine -> one + `dma_start`), proving **bit-identical output** to today on a smoke test. +4. Add the **affine** peel path for >4D; validate end-to-end through all three + simulators (the loop-of-descriptors must satisfy the TOG / Spike / gem5 + contract). Make the pass **assert** on any non-affine residue (contract guard). +5. Land the upstream producers of the affine-only invariant: aligned axis split at + scheduling (`axis-split-scheduling.md`) and misaligned graph copy insertion. +6. Remove the `get_dma_info` recompile branches once the pass + upstream cover their + cases; use the failure ledger + assert-only `TestLoopPadding` to confirm nothing + regresses before deleting. + +## Relationship to Plan A and Plan B + +- **Plan A (graph-level padding)** reduces how often peeling/relayout is needed by + making dims granule-aligned, and retires `TestLoopPadding`. Complementary: this + op makes representation robust; Plan A reduces constraint frequency. +- **Plan B (linalg)** is the full structured-ops rewrite; `expand_shape` / + `collapse_shape` are the principled home for reshape, and the framework would + generate the same peel/relayout under the hood. This transfer op is the narrow, + now-achievable slice of that idea. + +## Risks / open questions + +- **C++ pass in the `PSAL-POSTECH/llvm-project` fork**: heavier iteration + (rebuild), logic split across two repos. Mitigated by the hybrid split (smarts in + Python, pass is mechanical). +- **TOG / Spike / gem5 contract on a loop of descriptors.** If TOG generation + assumes "one DMA = one node," the loop form needs handling. Validate at step 4. +- **Cost model accuracy** for the peel plan; start with a simple descriptor-count + threshold and refine against measured cycles. +- **Dynamic shapes**: `iter_bounds` as SSA operands; affine peel must handle + symbolic outer extents. (Symbolic-divisor floor/mod normalization is an upstream + concern, not this pass.) +- **Upstream completeness.** The pass's fail-loud contract is only safe if the + upstream producers (axis split + graph copy) actually normalize every misaligned + case. Until they do, the assert may fire on real models -- track which ops trip it + as the work-list for the upstream passes. +- **Async / tag management across the peel loop**: double-buffering / compute + overlap must survive decomposition (e.g. keep the inner large DMA async, sequence + the outer peel). + +## Appendix: alignment theory (when floor/mod is statically decomposable) + +This section records the math that decides, for a given DMA access, whether the +non-affine `FloorDiv`/`ModularIndexing` terms can be peeled into a *static* loop +of affine descriptors (free) or require a data movement (relayout / copy). + +### Setup + +A Gemmini-style descriptor addresses an element as + + addr(idx) = base + Σ_k stride_k · idx_k (integer strides, rank <= 4) + +i.e. each loop index `idx_k` contributes a **constant** stride. A DMA is +statically decomposable iff every index term it reads has constant stride over +the rectangular tile domain. Inductor index expressions, after fusion/view, carry +`FloorDiv(x, y)` and `ModularIndexing(x, y, z)` of the *flattened* loop variable +`x`. The question is when those reduce to constant-stride axes. + +### Mixed-radix decomposition + +Write the flattened index `x` (extent `E`) in mixed radix. For a `ModularIndexing` +with inner period `y` and modulus `z`, decompose uniquely as + + x = o·(y·z) + m·y + r, with 0 <= r < y, 0 <= m < z, o >= 0 + +Then `FloorDiv(x, y) = o·z + m`, and `ModularIndexing(x, y, z) = m`. Each of +`o, m, r` is a separate **implicit axis** with a constant per-axis stride — +*provided the axis boundaries do not move across the tile*. That holds iff the +period divides the extent it partitions: + +- `ModularIndexing(x, y, z)` is a valid rectangular axis **iff y·z | E**. +- `FloorDiv(x, y)` is a valid rectangular axis **iff y | E**. + +**Aligned** = the divisor (and modular period `y·z`) divides the extent, so the +wrap point lands on a fixed axis boundary -> constant stride -> peelable for free. +**Misaligned** = the wrap point falls at a loop-value-dependent position inside the +descriptor (e.g. uneven `cat`, ragged split) -> the stride is not constant -> +**not** statically decomposable; only a relayout (physical copy) fixes it. + +### One loop axis -> several implicit axes (complex fusion) + +When fusion merges many dims into one flattened loop variable, a *single* loop +axis can expand into **several** implicit axes through nested floor/mod, e.g. + + x in [0, D0·D1·D2): + a = FloorDiv(x, D1·D2) # outer + b = ModularIndexing(x, D2, D1) # middle + c = ModularIndexing(x, 1, D2) # inner + +That is three implicit descriptor axes coming from one loop axis. This is the +general case the un-flatten must handle: it is **not** limited to splitting one +axis into two. Key consequences: + +1. **The loop's own factorization is always aligned.** When the implicit axes + come from re-reading the loop's *own* contiguous factorization (the common + fusion case -- Inductor flattens contiguous dims then a consumer reads them + back via floor/mod), every period divides by construction (`D1·D2 | D0·D1·D2`, + etc.). So these un-flatten splits are **free** -- they just add descriptor + axes, never a copy. +2. **Rank blows past 4 fast.** k implicit axes per loop axis, across multiple + operands, means the descriptor rank exceeds the 4D Gemmini limit very quickly. + This is exactly why `togsim.transfer` + the peel pass matters *more* under + complex fusion, independent of any misalignment. The >4D branch in + `get_dma_info` already routes these to `togsim.transfer`. +3. **Misalignment is still only from non-factor views.** An implicit axis is + misaligned only when its period does not divide the extent -- i.e. the view + does not factor along the loop's factorization (uneven `cat`, ragged split, + group sizes that don't divide the channel count). Those, and only those, need + relayout. + +### Case-handling summary + +| Source of floor/mod | Aligned? | Handling | Cost | +|------------------------------------------------|----------|-----------------------------------|------| +| Broadcast / dim-merge (`[N,1]->[N,M]`, `i//M`) | always | un-merge (split loop axis back) | free | +| Reshape along the loop's own factorization | yes (`y·z\|E`) | un-flatten split, then peel for rank | free | +| >4D logical tile from complex fusion | yes | `togsim.transfer` -> peel into <=4D loop | free (extra DMA nodes) | +| Uneven `cat`, ragged split, non-dividing group | no | graph copy insertion (relayout, upstream) | copy = TPU `concatenate` | + +The TPU/XLA model is the reference: express only aligned views as +descriptor/bitcast (free reshape); never put a misaligned access in the +descriptor -- insert a copy (relayout) instead. Plan A (graph-level +force-contiguous / pad-to-granule, like XLA copy-insertion) is the upstream lever +that *reduces how often* the misaligned branch fires, keeping codegen affine-only. + +## Implementation status (Phase 1: codegen emission) + +Landed on branch `dma-transfer/codegen` (worktree), emission only -- the +decompose pass is deferred until explicitly signalled. A >4D access now emits a +`togsim.transfer` instead of hard-failing; without the pass it does not yet run +end-to-end (expected). + +- **`mlir_common.py` `init_tile_size`** generalized to any rank. Logical tile is + separated from the physical (<=4D) descriptor: only the innermost dims carry the + vectorized tile, all further-outer dims stay 1, and there is no rank cap. The + `nr_dim >= 3` formula reproduces the old 3D/4D values exactly (the old `[-4]=1` + is subsumed by "outer dims stay 1"); scalar/1D/2D keep their special cases. This + removes the old `raise NotImplementedError("dummy tile size fail!")` that + conflated logical and physical tile rank. +- **`mlir_codegen_backend.py`**: + - `__init__` adds `self._dma_needs_transfer = False`. + - `get_dma_info` >4D `else` branch (was + `raise NotImplementedError("Currently not implemented... ;)")`) now builds the + full N-D tile (`set_tile_size`, vlane split/stride) and sets + `self._dma_needs_transfer = True`. + - `emit_transfer(...)` emits the generic-form `"togsim.transfer"(...)` op + carrying `dma_kind`, `vlane_split_axis`, `vlane_stride`, `dram_stride`, + `tile_stride`, `padding`, with operands `(dram, dram_idx, sram, 0, tag)`. + `togsim` is an unregistered dialect, hence generic form. + - `load()` (MVIN) and `store()` (MVOUT) check the flag: if set, reset it and + call `emit_transfer`; otherwise the existing `get_dma_code` path is unchanged. + So aligned <=4D DMAs are **bit-identical** to before; only >4D accesses change. + +Validated: the 5D permute smoke test (`x.permute(4,3,2,1,0).contiguous() + 1.0`) +now emits MVIN/MVOUT `togsim.transfer` with 5D `dram_stride [1,6,30,120,360]` and a +`memref<1x1x2x4x2xf32,1>` tile, instead of crashing in `init_tile_size` or the +`get_dma_info` >4D branch. + +### Phase 2: aligned-only peel pass (landed: unit-collapse path) + +`passes/decompose_transfer.py` (registered in `passes/__init__.py`, runs before +`lower_vlane_idx`) lowers each `togsim.transfer` to a customized `memref.dma_start`: + +- **Unit-dim collapse (done, validated).** Drop extent-1 tile dims so the + descriptor reaches <=4D. The SRAM (spad) memref is collapsed to the effective + rank via `memref.collapse_shape` (the customized `dma_start` convention requires + SRAM rank == #indices == len(sram_stride)); DRAM stays flat rank-1 with its N-D + structure in `dram_stride`. The `vlane_split_axis` is **remapped** from the + original tile-dim index to the collapsed-dim index and rematerialized as a const + (carried as a value attr precisely so the pass can remap it). +- Supporting changes: `emit_transfer` now carries the SSA operands a `dma_start` + needs (`dma_type`, `vlane_stride`) + the `vlane_split_axis` value attr, so the + pass is mechanical. `lower_to_llvm.py` gains `expand-strided-metadata` to lower + `collapse_shape`. + +Validated end-to-end (Gem5 + Spike + TOGSim, `allclose=True`) on the 5D permute +`x.permute(4,3,2,1,0).contiguous() + 1.0`; no regression on 2D/3D/elementwise. + +- **Genuine >4 effective rank (affine.for peel; #258 resolved).** When >4 *non-unit* + dims survive, the pass keeps the inner 4 as the <=4D descriptor and peels the outer + dims into an `affine.for` nest (marked `inner_loop`), emitting one inner descriptor + per iteration -- mirroring the `-dma-fine-grained` subtile loop. The DRAM base is + `affine.apply(dram_idx + sum_k iv_k * dram_stride_k)` (one apply, not `arith.addi`, + so the TOG pass walks the loop index through it). The SRAM slice offset is the + **lane-banked physical** offset (split-outer dims rescaled by the lane coeff) + delivered as the **last SRAM index operand**, *not* a `memref.subview` offset -- + `extract_aligned_pointer_as_index` in the gemmini lowering strips a subview offset, + which is why the earlier full-unroll + subview attempt produced wrong data and the + C++ TOG read an empty `loop_idx_list` (#258). + + The earlier full-unroll + subview form was isolation-only and INCOMPATIBLE with the + TOG; the `affine.for` rework (this is exactly the #258 TODO) fixed both the TOG + read and the numerics, so the axis-split rank guard was removed. Validated + end-to-end (Gem5 + Spike + TOGSim, `allclose=True`) on `pixel_shuffle(x, 2) + 1.0` + (5D tile) plus the gemm/bmm/conv/model suite. + +The input stays per-axis affine by upstream guarantee. A non-affine residue is a +contract violation (aligned floor/mod +removal lives in `axis-split-scheduling.md`, misaligned relayout in graph copy +insertion -- see "Division of labor"); a genuinely non-affine / indirect index +would surface as a build failure here rather than being silently relaid out. diff --git a/docs/linalg-codegen-migration.md b/docs/linalg-codegen-migration.md new file mode 100644 index 00000000..9a60ba31 --- /dev/null +++ b/docs/linalg-codegen-migration.md @@ -0,0 +1,224 @@ +# Linalg-based codegen migration (Plan B) + +Status: **deferred / design only**. This is not scheduled work. It records *why* +a linalg-based rewrite of the MLIR codegen will eventually be worth doing, and a +rough plan, so the decision does not have to be re-derived from scratch later. + +For the near-term padding/dynamic-shape work, see Plan A (graph-level padding) in +the "Relationship to Plan A" section below. Plan A is the one to do first; it does +not depend on this document. + +## TL;DR + +The current MLIR codegen (`PyTorchSimFrontend/mlir/`) does not just emit loops — +it hand-implements the entire hardware mapping (tiling, vectorization, DMA, +scratchpad allocation, vector-lane distribution) as Python string emission. That +works, but it entangles three concerns that should be separable: + +1. **what to compute** (the op math), +2. **how to map it onto the NPU** (tile sizes, vlane layout, DMA/scratchpad), +3. **how to make shapes fit the hardware** (padding / divisibility). + +Concern 3 is currently spread across three places and is the source of the +"padding is heuristic and fragile" pain. Plan B factors concern 1 up to the +`linalg` dialect and rebuilds concern 2 as a set of MLIR lowering passes, so that +concern 3 falls out of the structured representation instead of being patched. + +Plan B is a multi-month, higher-risk effort because the hardware mapping (concern +2) is bespoke and has no upstream equivalent. Do it when the payoff (separation, +reuse of upstream tiling/vectorization/fusion, easier autotuning, lower codegen +maintenance) is worth that cost — not as a means to fix padding alone. + +## Where we are today + +Entry point: PyTorchSim is an **Inductor backend**. Inductor handles capture, +decomposition, lowering to Inductor IR, scheduling, and **fusion**. Our code runs +at the codegen (step-5) stage and turns scheduled `SchedulerNode`s into MLIR. + +`MLIRKernel` (`mlir/mlir_codegen_backend.py`) and friends emit, by hand: + +- explicit DMA: `memref.dma_start` with MVIN/MVOUT encoding, `vlane_split_axis`, + `vlane_stride` (`load`, `store`, `get_dma_code`); +- explicit scratchpad: `.spad` sections and `memref.global @bufN_spad` + (`allocate_sram_buffer`, `get_scratchpad_buffer`); +- explicit vector-lane (vlane) distribution: `vmap.vlane_split_axis` / + `vlane_stride`, `vlane_offset`, `get_used_vlane` (`_index_expr`, + `get_dma_info`); +- explicit tile descriptors: `MLIRMultiDimTile` with per-axis tile sizes; +- explicit reduction loops: manual accumulator/iterator vars + `affine_yield` + (`reduction`, `codegen_loops`); +- explicit vector-tail masking: `get_mask`. + +This is roughly 5,500 lines across `mlir_codegen_backend.py`, `mlir_common.py`, +`mlir_template.py`, and `mlir_ops.py`, plus per-op templates +(`mlir_gemm_template.py`, `mlir_conv_*`, `mlir_bmm_template.py`, +`mlir_sdpa_template.py`, `mlir_sort_template.py`, `mlir_cat_template.py`, +`mlir_maxpool_template.py`). Most of it encodes how this specific NPU's memory +hierarchy and vector unit work. **That accumulated hardware knowledge is the asset +and the cost center for any rewrite.** + +### Padding lives in three places + +This is the key observation motivating the separation. Divisibility / padding is +handled by: + +1. **Python recompile + tile adjustment** — `get_dma_info` FloorDiv / + ModularIndexing handling, `_index_expr`, `convert_indirect_indexing`: "if the + tile size does not divide the dim, bump the tile and raise `RecompileSignal` to + recompile." This is the actual heuristic-padding body, and it is in Python, not + MLIR. +2. **Python `get_mask`** — vector-tail masking for the innermost compute loop. +3. **MLIR `TestLoopPadding`** pass (in the `PSAL-POSTECH/llvm-project` fork) — + rounds affine loop bounds up to a multiple of the step and resizes buffers, + reverse-engineering the loop<->tensor mapping from affine maps. + +Removing only (3) does not fix the fragility; (1) is arguably the larger source of +"the loop and the tensor do not line up." Any real fix has to address all three. + +## Target architecture (Plan B) + +Split the codegen into two layers with a clean boundary: + +``` +L0 ATen / FX logical graph (dynamic dims as SymInt) <- Inductor, unchanged +L1 math layer: Inductor IR -> linalg.generic / named ops (untiled, unvectorized) + = iteration domain + affine indexing maps + scalar body. + Decides nothing about tiles / lanes / DMA. +L2 mapping layer: tiling(+pad) -> vectorize(+vl) -> bufferize + -> DMA / scratchpad / vlane lowering -> leaf replacement + = hardware mapping, parameterized by a target description. +L3 LLVM / RVV / systolic microkernel +``` + +Why this helps: + +- **The loop<->tensor mapping stops being reverse-engineered.** `linalg` ops carry + `indexing_maps` + `iterator_types`, i.e. exactly the information `TestLoopPadding` + tries to recover. Padding becomes a *parameter of the tiling transform* + (`tensor.pad` generated with full knowledge of the maps), not a separate + analysis pass. +- **The padding strategies we want become per-axis policy** in L2: systolic + operand axes -> pad-to-uniform-tile (keeps a single 128x128 microkernel and a + single gem5 latency entry); VPU / vector axes -> RVV `vl` (no padding); reduction + axes -> `affine.min` clamp. One mechanism, selected per axis, instead of three + scattered implementations. +- **Fusion policy stays in Inductor** (its scheduler decides what is one kernel), + while the *mechanism* is upstream `linalg` tile-and-fuse. Pointwise epilogue + fusion is essentially free because Inductor already composes `inner_fn`s into a + single fused body -> one `linalg.generic`. +- **Dynamic shapes** are carried as `?` dims + symbolic affine, uniformly handled + by L2 rather than by the recompile dance. +- **The cost model gets cleaner, not harder**: tile shape becomes an explicit + attribute, so the gem5 latency table / TOG key on it directly instead of on + inferred loop shapes. + +### What is reusable vs bespoke + +- **Reusable from upstream MLIR**: `linalg` ops, tiling + `tensor.pad`, + vectorization, bufferization, the TilingInterface. The L1 translation + (Inductor IR -> `linalg.generic`) is a *generic* translator (one path covers all + regular pointwise/reduction), not per-op work — Inductor IR is already in + structured iteration-domain + scalar-body form. +- **Bespoke, must be (re)written as MLIR passes**: MVIN/MVOUT DMA encoding, + `.spad` scratchpad assignment, and the vlane_split mapping. **These have no + upstream equivalent.** This is the bulk of the effort and the main risk: the + knowledge currently in ~5,500 lines of Python emission must be re-expressed as + custom bufferization-to-DMA / scratchpad / vlane-vectorization lowerings. + +## Expressibility boundary + +A regular (linalg.generic) op needs: a fixed rectangular iteration space; every +operand index an **affine** function of loop vars (no data-dependent indexing); +each axis purely `parallel` or a simple `reduction` (no scan/recurrence); a +statically-determined output shape (dynamic `?` ok, data-dependent shape not); and +a data-independent body (`arith.select` ok, data-dependent branching not). + +Expressible: elementwise, broadcast, transpose, reductions (incl. multiple +reduction axes), matmul / bmm / contractions, direct conv, pooling, and fused +chains of these (matmul+bias+activation, prologue cast/dequant/transpose, +pointwise->reduce). `slice` / `pad` / `cat` are structured `tensor` ops (not +`linalg.generic`) but are supported by the same pipeline. + +Not expressible -> stay as hand-written custom kernels: data-dependent indexing +(gather/scatter/embedding), sort/topk, data-dependent output shape +(nonzero/unique/masked_select), scan/recurrence (cumsum), and online/streaming +algorithms (flash-attention). In our op set this means **`sdpa` (online softmax) +and `sort` remain custom**; gemm/conv/bmm/maxpool are regular; `cat` is a +structured tensor op. + +## Migration strategy (when Plan B is scheduled) + +Incremental, op-by-op, with a numeric and a structural safety net. Do **not** +big-bang. + +1. **Stand up the L2 pipeline for one op (matmul first).** Emit `linalg.matmul` + from the matmul path; wire tiling(+pad, pad_value=0) -> bufferize -> a custom + pass that lowers the 128^3 leaf tile to the existing systolic intrinsic -> + LLVM. Milestone 1 is end-to-end correctness through all three simulators + (Spike functional, gem5 latency, TOGSim cycle) for a single matmul. +2. **Demote `TestLoopPadding` to assert-only** (check, do not modify; fail/log if a + loop bound is not a multiple of its step). Run the full test suite; anything it + flags is a case L1/L2 has not covered yet. +3. **Migrate the remaining regular ops** (conv, bmm, pointwise, reductions, + maxpool). Pointwise/reduction go through the generic L1 translator; VPU + remainder via `vl`. +4. **Delete `TestLoopPadding`** once the assert-only version is silent across the + suite, and retire the Python recompile/tile-adjust dance and most of `get_mask`. +5. Leave `sdpa` and `sort` as custom kernels that bypass L1/L2. + +### Risks + +- **Simulator-facing contract.** The current emission is tuned to produce exactly + the LLVM / TOG shape the three simulators expect. `linalg`'s standard lowering + emits different IR; re-validating the lowered artifact end-to-end (especially TOG + generation, which may assume specific loop/memory patterns) is the real + integration risk. This is why milestone 1 is "one matmul, end-to-end," not "all + ops, emission only." +- **Re-encoding the hardware mapping.** DMA/scratchpad/vlane lowerings are new code + with no upstream reference; budget for them dominating the schedule. +- **Inductor index expressions.** Inductor often collapses dims into one flat index + with `FloorDiv` / `ModularIndexing`, which are not affine; `linalg` indexing_maps + must be affine. Either keep dims uncollapsed or normalize div/mod back to + multi-dim affine. (We already convert these to affine strings today in + `_convert_sympy_to_mlir_expr`, but that path will need to be revisited for the + map-carrying representation.) +- **Fusion seams.** Not everything fuses cleanly (reductions with mismatched axes, + transpose/layout mismatches); expect some barriers, same as any framework. + +## Relationship to Plan A (graph-level padding) — do this first + +Plan A inserts padding at the FX/graph level (via Inductor's +`post_grad_custom_pass`) so that tiled dims arrive at codegen already aligned +(`tile_granule * symbol`). Under the hard constraint that we **keep the Inductor +spine**, Plan A is the high-ROI move: + +- It collapses all three padding sites at once: the recompile/tile-adjust dance (1) + becomes unnecessary (tiles always divide), `get_mask` (2) becomes trivial (no + tail), and `TestLoopPadding` (3) becomes unnecessary. +- It does **not** touch the bespoke DMA/scratchpad/vlane mapper. + +Key correctness facts that make Plan A tractable: + +- Matmul contraction (K) padding with zeros is *exact* (additive identity); weights + are constants, so they can be zero-padded once, offline, at no runtime cost. +- Padding only ever corrupts results when a *non-contraction* padded axis is later + reduced (softmax over keys; layernorm if hidden is padded). Those points need + masking; everything else is pad-transparent. +- Safety rule: default any op to slice-back-to-real-shape; only opt an op into + "propagate padded shape" once it is proven pad-transparent or given a mask + handler. Correct-by-construction; unknown ops cannot silently corrupt. + +Plan A and Plan B are compatible: Plan A's graph-level alignment makes the eventual +Plan B simpler (L2 tiling rarely needs to pad, because dims already divide). + +## Open questions + +- Does the current toolchain (the `PSAL-POSTECH/llvm-project` fork) already ship the + `linalg` + transform/tiling passes, or were they stripped? (Almost certainly + present if it tracks upstream — verify before committing.) +- Can the systolic leaf be expressed cleanly as a match-and-replace on a fixed-size + `linalg.matmul`, or does weight-stationary loading order force a more custom + representation? +- How much of `mlir_ops.py` (the scalar `OpsHandler`) survives? It currently emits + *vectorized* ops (compute_vec_size, broadcast) and is therefore entangled with + vlane; the linalg body should be scalar, with vectorization done in L2. diff --git a/docs/loop-padding-elimination.md b/docs/loop-padding-elimination.md new file mode 100644 index 00000000..cfb9ce3b --- /dev/null +++ b/docs/loop-padding-elimination.md @@ -0,0 +1,344 @@ +# Padding model + retiring test-loop-padding (two-layer: alignment vs compute-tile) + +Status: **decided**. Earlier drafts argued for *eliminating* padding via variable-extent +DMA (rejected), then for *porting* the pass as-is (also wrong -- it over-materializes). +The settled answer (grounded in `docs/tpu_layout_padding_report.md`): padding is **two +layers** -- (A) lane/sublane alignment is materialized traffic, (B) compute-block tail is +masked compute-util -- and `test-loop-padding`'s post-codegen heuristic is replaced by +informed emission at the scheduling/codegen layer. See "RESOLVED MODEL" below for the +authoritative conclusion; the earlier sections are the analysis trail. + +## DECISION: the modeled NPU has no partial-extent DMA -> padding is fundamental + +We model an NPU whose DMA **always moves full tiles** (TPU-class dense movement); it +does **not** do partial / variable-extent transfers. Therefore: + +- Padding (full-tile DMA over padded buffers) is a **real architectural cost**, not a + simulator convenience. Moving the padding bytes is what the hardware actually does. +- "Handle the tail instead of padding" (variable-extent DMA, boundary clamp) would model + a *different machine* (one with partial-transfer DMA) and would under-count traffic / + cycles. **Rejected.** +- You cannot have "logical DRAM + full-tile boundary traffic" -- moving a full tile + requires the padded bytes to exist in DRAM. So eliminating the padded buffer is + incompatible with the full-tile-DMA model. The two are linked. + +Consequence: **keep padding, but do NOT port the current mechanism** -- reimplement it +at the layer that has the information. Why the current C++ pass is fundamentally wrong +(not just buggy): + +- It runs **after codegen** and **reverse-engineers** the padding need from the emitted + IR: it walks `affine.for` step sizes, `dram_stride`, and `affine.apply` maps to *guess* + which memrefs to grow, by how much, and how to rewrite addressing. The info it needs + (tile size, tensor shape, which dims are reduction vs parallel, the access map) was + **known at codegen time and thrown away**, then heuristically reconstructed. +- That reconstruction is inherently partial: hardcoded conv geometry (k_h/k_w/o_h/o_w), + "find the stride", coefficient-from-dim guessing, etc. New op patterns / multi-dim / + edge cases break it. **It cannot be shown to cover all cases** -- it is a heuristic + retrofit, not a derivation. + +Correct plan: **decide padding at the scheduling / codegen layer**, where the tile `T` +and extent `E` are *known*. When `T` does not divide `E`, the codegen knows +`E' = ceil(E/T)*T` directly and emits the padded buffer + full-tile loops/DMA **by +construction** -- no post-hoc IR analysis, no guessing. This eliminates the heuristic +pass (`test-loop-padding` -> gone, mlir-opt drops out) while **preserving padding** +(padded buffers + full-tile DMA traffic) and being robust by derivation rather than +inference. + +This is "scheduling-level padding" -- but to *produce* padding correctly, NOT to +eliminate it (the earlier tail-handling framing). Padding stays; only the mechanism +moves from a fragile post-codegen heuristic to direct emission from the tiling decision. + +Also resolved for free: the CI robustness bug (current pass uses `emitError` without +`signalPassFailure` -> error paths exit 0 and silently drop `@wrapper_kernel` -> +cryptic `undefined reference to wrapper_kernel` at link under autotune + unseeded +Poisson). Direct emission has no such silent-guess failure mode. + +## RESOLVED MODEL: padding is TWO layers (see tpu_layout_padding_report.md) + +The TPU layout/padding investigation (`docs/tpu_layout_padding_report.md`) settles the +debate: padding is **not one thing**. There are two layers with *different cost +semantics*, and conflating them was the source of the back-and-forth above. + +| layer | what | on real TPU | PyTorchSim cost | +|---|---|---|---| +| **(A) lane/sublane alignment** (8x128; T(2,128)/T(4,128) for small 2nd-minor; bf16 T(8,128)(2,1)) | tile must be address-aligned -> tensor stored padded in HBM | **materialized** (`nofold`); tensor physically bigger, unavoidable | **footprint + DMA traffic** (padded bytes are stored AND moved) | +| **(B) compute-block (MXU tile) boundary** (>8x128) | the contraction/output block tail when a dim isn't a multiple of the MXU block | **masking / peeling** (mostly NOT materialized); MXU computes zeros then masks output | **compute utilization only** (wasted MXU cycles) -- **NOT traffic** | + +This corrects both earlier extremes: +- "eliminate all padding" (tail-handling) -- WRONG: (A) is materialized real traffic. +- "materialize all padding" (current loop-padding) -- WRONG: it buffer-grows (B) too, + so it **double-counts (B) as traffic**; TPU masks (B). loop-padding over-materializes. + +### The two-cost-function rule (report 7.1 -- the key modeling constraint) +- **footprint / HBM traffic** function: count ONLY (A) lane/sublane alignment padding as + physical size (e.g. extent 100 -> stored/moved as 128). Reflect bf16 packing and the + small-2nd-minor T(2,128)/T(4,128) variants. +- **compute-utilization** function: (B) compute-block tail lowers MXU utilization via + masking; **do not add it as traffic** (would over-estimate bandwidth). Only the rare + alignment-forced `tensor.pad` materialization adds copy traffic. +- Pipeline ordering (report 1): the layout **decision** (which axis is lane, how much + alignment padding) is early/metadata; **materialization** is late. Matches "decide at + scheduling, materialize at codegen." + +### Corrected plan for test-loop-padding +Reimplement at the scheduling/codegen layer (informed by `tile_desc`, which already has +the vlane axis + tile sizes), splitting the two layers: +- **(A)** materialize lane/sublane alignment padding -> the padded staging buffer + + full-tile DMA (this is the `wrapper_kernel`-style staging; structurally necessary + because the real DRAM tensor is logical -- you cannot pad it in place). Counts as + footprint + traffic. +- **(B)** handle the compute-block tail by **masking** (`get_mask` already exists) -> + compute-util only, NOT a buffer grow, NOT extra traffic. +Then the post-codegen heuristic `test-loop-padding` is gone, padding is faithful per +layer, and the modeled hardware is unchanged. (Open: whether to force (A) to fixed +(8,128)/T(packing,128) granularity -- robust, matches TPU -- vs minimal `ceil`.) + +Validation (report 2.3 / 7.4): dump real XLA layouts with +`XLA_FLAGS="--xla_dump_to=... --xla_dump_hlo_as_text=true"` and read the `:T(...)` +annotations to ground the lane-axis + alignment-padding model against the compiler, +rather than guessing. + +--- +(Below: the earlier elimination analysis, retained as the record of WHY tail-handling +was rejected. The "fundamental vs not" framing still correctly explains the *mechanics*; +the conclusion -- that padding is eliminable -- is overturned by the DECISION above, +because on a full-tile-DMA NPU the padding traffic is a real cost that must be modeled.) + +## Original goal (SUPERSEDED -- kept for the analysis trail) + +`-test-loop-padding` is the only C++ MLIR pass still invoked in `mlir-opt` (after +build_tog, dma_fine_grained, lower_to_vcix, lower_dma_to_gemmini, lower_vlane_idx were +ported to Python). The original goal was to **eliminate it** by handling +tile-vs-extent misalignment at the scheduling layer. (Superseded: see DECISION -- we +port, not eliminate.) + +## What test-loop-padding does today (fact, from TestLoopPadding.cpp) + +Runs on the post-codegen MLIR of `@kernel`: + +1. For each `affine.for`, round the upper bound **up to a multiple of its step** + (= tile size): `paddedUpperBound = roundUpToMultiple(upperBound, stepSize)`. +2. For every DRAM `memref` indexed by a padded loop: **resize the memref** to the + padded extent (`modifyMemrefWithPadding`), **update the func signature** + (`updateFunctionSignatureWithMemRef`), and **rewrite `dram_stride` + the + `affine.apply` addressing maps** so the addressing matches the larger buffer. + Has a conv2d-specific path (nested `affine.apply` over k_h/k_w/o_h/o_w). +3. `timing_mode=1`: skip copying the padding region (cycles only, no real data). + +Net: loop trip counts and the **DRAM-side buffers** are grown to aligned sizes after +codegen, with addressing rewritten to match. + +## Where padding lives today (the split = the problem) + +| layer | mechanism | what it pads | +|---|---|---| +| Python tile selection (`mlir_common`: `apply_divisor("pad")`, `pad_vlane_tile`, `roundup_vectorlane`) + recompile-dance (`RecompileSignal`) | pads the **tile size** to vlane/divisor multiples; forces tiles via restart | the tile shape | +| Python `get_mask` (vector tail) | masks the unaligned tail **within a tile** for vector ops | partial-tile compute | +| MLIR `test-loop-padding` | rounds **loop trip count** + grows **DRAM buffer** + rewrites strides/maps | the iteration domain + DRAM side | + +Three mechanisms, three layers, one underlying concern (tile does not divide extent). + +## The mismatch model + +A dim has logical extent `E` and a chosen tile `T`. If `T | E`, everything is +aligned and loop-padding is a no-op. If `T does not divide E`, the last tile is +partial (`E mod T` elements). Two ways to make the hardware see full tiles: + +- **Pad**: treat the extent as `E' = ceil(E/T)*T`; the loop runs `E'/T` full tiles; + the tail tile covers padding (garbage / zero) that must not corrupt results. +- **Mask**: keep `E` and mask the partial tile so padding lanes/rows are inert. + +Today: tiles are padded in Python, the within-tile tail is masked (`get_mask`), and +the loop+DRAM are padded in MLIR. We want one coherent story. + +## Why padding exists -- what is fundamental vs not + +Three *distinct* things happen at a tile boundary; only one is layout-fundamental, and +it is not what loop-padding does. + +1. **Parallel-dim tail** (M, N, output spatial): the last tile just produces fewer + outputs. Process fewer -- no value-fill, no padding; only "don't read/write past the + real extent." +2. **Reduction-dim tail** (matmul K, reduce axis): the inactive elements must contribute + the **reduction identity** (0 for sum, -inf for max) or the result is corrupted. This + is real -- but it is satisfied either by buffer value-fill (loop-padding) OR by + masked/identity-fill at compute/push granularity. Evidence the latter already exists: + the matmul vcix lowering pushes `zero_vector` for the K-tail (`i >= K`), and + `get_mask` masks vector-reduction tails. So no *buffer* padding is required for this. +3. **DMA boundary**: a full-tile transfer would run past the real tensor -> **clamp the + transfer extent** (Q1=(c)). No buffer growth. + +loop-padding bundles #2 (value-fill) + #3 (buffer grow) into one post-codegen step to +avoid per-tile tail logic. **None of that is fundamental** -- (c) replaces it with tail +handling (extent clamp + mask/push-zero). So: "handle the tail well and you don't need +separate padding" is correct, for everything loop-padding does. + +### The one genuinely-fundamental padding: lane / sublane granularity + +TPU pads to **(8, 128)**. These are two *different* kinds of padding: +- **128** = the MXU systolic dimension. Full tiles are a *throughput* choice, and the + contraction tail is identity-fillable (see #2). Tail-handleable; not fundamental as + buffer padding. +- **8** = the VMEM **sublane** -- memory is physically tiled in (8, 128) banks. This is a + **physical layout granularity**: a dimension laid across lanes/sublanes must be a whole + number of lane-tiles. You cannot have a "partial lane" in a lane-banked memory, and + **masking does not help** -- the issue is the data's physical placement, not which + compute lanes are active. + +PyTorchSim mirrors this: the SRAM scratchpad is lane-banked (`vlane_idx * vu_sram_byte`; +see the MVIN layout). `pad_vlane_tile` / `roundup_vectorlane` pad the vlane-split tile dim +to a multiple of the lane count. **This padding is layout-fundamental and is RETAINED in +(c)** -- but it is *internal SRAM* padding (the spad is over-allocated per lane), cheap, +and never touches DRAM. + +Conclusion: padding splits into +- (a) **layout-fundamental** lane/sublane padding (the TPU "8") -> KEEP, internal SRAM, + cheap; and +- (b) **compute/bound** padding (reduction identity + DMA bounds; the TPU "128" + DRAM + buffer grow) -> NOT fundamental, replaced by tail handling. + +`test-loop-padding` is entirely in category (b) -> eliminable. So TPU pads to 128/8 not +because tail handling is *impossible*, but because (i) for 128 it is a dense-throughput +design choice (TPU avoids runtime masking), and (ii) for 8 it is a real lane-banked +*memory layout* constraint -- which we already satisfy with cheap internal SRAM padding, +independent of loop-padding. + +## Proposed design (sketch -- to refine together) + +At the scheduling / codegen-prep layer, when a dim's extent is not a multiple of its +chosen tile: + +1. Compute `padded_extent = roundup(extent, tile)` at tile-selection time (the layer + already knows the tile). +2. Emit the loop nest with **padded bounds** and size the spad/DRAM tile descriptors + to `padded_extent` -- i.e., produce what loop-padding produces, but at emit time, + so the addressing maps / `dram_stride` are correct by construction. +3. Reuse `get_mask` for boundary correctness (no real data movement / compute on the + padding region). +4. `test-loop-padding` then has nothing to do -> delete it; drop `-test-loop-padding` + from `extension_codecache`; `mlir-opt` is gone. + +## Design decisions + +**Q1 (crux): DRAM buffer resize vs the real tensor. RESOLVED -> (c) hybrid.** +loop-padding grows the DRAM function-arg `memref`; the real `npu` tensor is +logical-sized. Decision: **pad the SRAM (spad) tile fully** (cheap, internal -- the +spad is already over-allocated per-lane), **keep DRAM logical**, and **clamp the +boundary tile's DMA** to the real tail so no OOB DRAM access happens. No device +buffer growth, no func-signature/stride rewrite -> genuinely scheduling-level. +(Rejected: (a) over-allocating the device DRAM buffer -- leaks into PyTorchSimDevice +allocation and needs a logical/padded shape bridge; (b) was the same as (c) minus the +explicit SRAM-full-padding framing.) + +Consequence: the loop still iterates `ceil(E/T)` tiles (padded trip count) and the +spad tile is full `T`, but for the **last tile** the DMA moves only `E - (ceil(E/T)-1)*T` +real rows; the spad tail rows are garbage, kept inert in compute by `get_mask`. + +### PRECONDITION: index/data-dependent tail handling + +(c) means the **last (tail) tile must behave differently** from the rest: move only +`E - (ceil(E/T)-1)*T` real rows from logical DRAM, leave the spad tail inert. That is +an **index-dependent operation** -- behavior varies with the loop induction variable. + +- **Compute side already supports this.** `get_mask` builds a per-iteration predicate + `step_vec < (upper_bound - compute_idx)` (depends on the loop index), masking the + tail lanes. Index-dependent masking already exists for vector compute. +- **DMA side does NOT.** The DMA transfer length today = the spad tile shape, a + **compile-time constant**. loop-padding exists precisely to avoid a variable-length + DMA: it grows DRAM so even the boundary reads a full `T`. Remove loop-padding and the + boundary DMA must transfer a **variable (index-dependent) length** -- a capability the + customized `memref.dma_start` / Spike MVIN does not have today. + +**So the gating precondition is: the DMA must support an index-dependent transfer +extent** (move `min(T, E - i*T)` rows for tile `i`). Establishing that is the real +foundation of this work; without it, scheduling-level (c) padding cannot be expressed. + +### Q1a: how to satisfy the precondition + + - **Variable-extent DMA (data-dependent).** Extend the customized `memref.dma_start` + + lowering + Spike MVIN to accept a runtime transfer length, and emit + `affine.min(T, E - i*T)` per tile. General (scales to multi-dim, any extent), + uniform loop body. Cost: a real hardware-model + descriptor extension. This is the + capability the precondition names. + - **Static tail-peel.** `floor(E/T)` full-tile iterations + a separate compile-time + partial DMA. No variable-length DMA needed, but **combinatorial in the number of + unaligned dims** (2^k corner DMAs) -- the same blow-up that made the decompose + unroll-peel a dead end (#258). Does not scale; rejected as the general path. + Leaning: the variable-extent DMA is the principled answer -- it is the missing + capability, and it generalizes. + +### Finding: variable extent is a codegen change, not a Spike change + +The transfer dim sizes are packed into the CONFIG instruction's **rs1 register** +(`lower_dma_to_gemmini.py:144`): today `cfg_rs1 = i64_const((shape4[0]&0xFFFF)<<48 | +... )` -- a compile-time constant. But `asm(CONFIG, cfg_rs1, cfg_rs2)` passes rs1 as a +**register operand**, and Spike reads the dim sizes from it into `P.VU.dma_dim_size`. +So the hardware model already takes the extent at runtime; today's codegen merely feeds +a constant. + +Implication: the variable-extent precondition is satisfiable **at the codegen / +lower_dma_to_gemmini layer, with no Spike change** -- emit `cfg_rs1` as a *computed* +value (pack a runtime dim size `min(T, E - i*T)` into the 16-bit field via arith) +instead of `i64_const`. The boundary tile then moves only the real tail; the spad stays +fully padded; DRAM stays logical. This is exactly (c). + +Evidence (strong): the CONFIG instruction is `CUSTOM_1` funct7=0 with rs1/rs2 as +**register operands** (`asm(func7, rs1, rs2)` -> `.insn r CUSTOM_1, 0x3, func7, x0, +$0, $1`); the MVIN reads the dims from `P.VU.dma_dim_size`, which is runtime VU state +set by that CONFIG insn. A config insn reading rs1 register bits into dma_dim_size is +the only sensible implementation -- so runtime-variable extent is already supported by +the model; today's codegen just feeds a constant rs1. (One file unread: the exact +torchsim config insn in riscv-isa-sim -- verify it stores all four 16-bit dim fields +from rs1.) + +Remaining work to thread it: (1) carry a dynamic per-axis transfer extent on the +customized `memref.dma_start` (a length operand, like the existing dynamic indices); +(2) `lower_dma_to_gemmini` builds `cfg_rs1` from those (constant when static, arith when +dynamic); (3) the scheduling layer computes `affine.min(T, E - i*T)` for unaligned dims +and passes it. + +**Q2: where is the transfer shape threaded? RESOLVED -> pass shape as an operand, +computed separately.** Compute the real (boundary-clamped) per-axis transfer extent in +a separate step (the scheduling layer, via `affine.min(T, E - i*T)` / arith) and pass +it to the customized `memref.dma_start` as an explicit **shape operand** (alongside the +existing dynamic index operands). `lower_dma_to_gemmini` then packs `cfg_rs1` from that +operand (constant-folds when static). Keeps the extent an explicit, separately-computed +value rather than implicit in the descriptor's static memref shape. + +**Q3: timing-mode skip-copy equivalent.** Padding iterations must cost cycles but move +no real data: loop = padded, DMA size = real-tail (the Q2 shape operand). Confirm +`get_mask` covers the compute side and the shape operand covers the DMA side. + +**Q4: conv2d -- NOT a special case under (c).** loop-padding has a conv-specific branch +(it walks the nested `affine.apply` over k_h/k_w/o_h/o_w and rewrites the maps + +`dram_stride`). That complexity exists because loop-padding **grows the DRAM buffer and +rewrites addressing** -- and conv's input address is a nested affine +(`input_row = o_h*stride + k_h - pad`), so growing the buffer forces rewriting those +nested maps. + +Under (c) we do neither: we keep DRAM logical and only **clamp the boundary tile's DMA +extent**, leaving the address computation untouched. A DMA is a rectangular DRAM<->SRAM +block transfer (base address + per-axis extents); "don't read past the DRAM end" is a +per-axis `min(tile, remaining)` regardless of gemm vs conv. The conv composite indexing +affects *where* the block starts (address) and *how* compute consumes it (handled by +`get_mask`), not *how many* elements move (the per-axis extent clamp). So conv reduces to +the same flat per-axis clamp as gemm; loop-padding's conv branch has no counterpart here. +(Residual check, not a design fork: conv tiles overlap (halo from stride/kernel) so input +tiles are not a clean DRAM partition -- confirm the per-axis "remaining" is still just +`E - base_on_that_axis`, which it is, since each tile's address is independent.) + +**Q5: incremental path.** Start by deleting loop-padding for kernels where no dim +needs padding (it is already a no-op there) and confirm zero diff; then handle the +real padding cases, possibly behind a flag, validating each. + +## Validation + +e2e suite (gemm/bmm/conv2d with deliberately non-multiple extents, + models), +`allclose` through Gem5+Spike+TOGSim. Where feasible, structurally compare the emitted +MLIR against the current `-test-loop-padding` output on the same kernels. + +## Relation to prior work + +Same move as floor/mod -> axis-split + graph-copy: handle the misalignment upstream so +the MLIR layer sees only the clean (aligned) case and the pass disappears. This is the +padding instance of that principle (Plan A in `dma-transfer-lowering.md`). diff --git a/docs/mlir-python-bindings.md b/docs/mlir-python-bindings.md new file mode 100644 index 00000000..6bb03339 --- /dev/null +++ b/docs/mlir-python-bindings.md @@ -0,0 +1,102 @@ +# Enabling MLIR Python bindings + +Goal: ship the MLIR Python bindings (`import mlir`, `mlir.ir`, `mlir.dialects`) +so we can write MLIR passes in Python (imperative IR rewriting via the bindings) +instead of only C++ passes in the `PSAL-POSTECH/llvm-project` fork. See +`dma-transfer-lowering.md` for the first intended use (a Python decompose pass). + +## How LLVM reaches the runtime (why this touches 3 places) + +``` +PSAL-POSTECH/llvm-project (fork, tag vX.Y.Z) + .github/workflows/build-torchsim.yaml -- CI builds + releases riscv-llvm-release.tar.gz + | (release asset) + v +thirdparty/github-releases.json -- pins llvm_project.release_tag + asset + | + v +Dockerfile.base -- downloads asset, extracts to /riscv-llvm, + sets TORCHSIM_LLVM_PATH (+ now PYTHONPATH) +``` + +`scripts/build_from_source.sh` is the alternative source-build path (not the +normal flow, but kept consistent). + +## The one real blocker: Python ABI must match + +The bindings are a native CPython extension (`_mlir.cpython-3XX-*.so`). They only +import under the **same Python minor version** they were built against. The +runtime base image uses **conda Python 3.11**. So the artifact must be built with +**Python 3.11**. Building with the build container's default (ubuntu-22.04 -> +3.10) produces bindings that fail to import at runtime with a confusing error +much later -- hence the fail-fast guard in the CI step. + +Patch version (3.11.x) does not matter; minor version (3.11 vs 3.10) does. + +## What was changed + +- **`scripts/build_from_source.sh`**: cmake gets + `-DMLIR_ENABLE_BINDINGS_PYTHON=ON -DPython3_EXECUTABLE=$(command -v python3)`; + build deps (nanobind/pybind11/numpy/PyYAML) pip-installed; after `make install` + the build-tree `tools/mlir/python_packages` is copied into `/riscv-llvm` + (install does not place it there). PYTHONPATH exported for the current shell. +- **`Dockerfile.base`**: `ENV PYTHONPATH=/riscv-llvm/python_packages/mlir_core:$PYTHONPATH` + after the LLVM artifact is extracted. +- **`llvm-project/.github/workflows/build-torchsim.yaml`** (fork): same cmake + flags + deps; copies `python_packages` into the `riscv-llvm` tree so the + existing `tar` includes it; fail-fast guard requiring `python3.11`. + +## Rollout sequence (must be done in order) + +1. **python3.11 in the build container: done, non-root.** The CI step keeps the + original `-u $(id -u):$(id -g)` (no root assumed) and fetches a standalone + CPython 3.11 with `uv` (`uv venv --python 3.11`), then points + `Python3_EXECUTABLE` at that venv. No apt / no root needed. ubuntu-22.04's + default 3.10 is not used for the bindings. + - ABI note: extensions built against a uv/python-build-standalone CPython 3.11 + are expected to import under the runtime conda CPython 3.11 (same minor + version, standard builds are C-ABI compatible). The verify step below is the + check; if it ever fails, build instead in the runtime image (`python:3.11` or + the pytorch base) so build Python == runtime Python by construction. +2. **Push the fork changes** to `PSAL-POSTECH/llvm-project` and cut a new tag + (e.g. `v1.0.9`). CI builds `riscv-llvm-release.tar.gz` now containing + `python_packages/`. +3. **Bump `thirdparty/github-releases.json`** -> `llvm_project.release_tag` to the + new tag (and `asset_name` unchanged). This triggers a new base image build. +4. **Rebuild the base image** (the fork CI already dispatches `build_base`; or run + the PyTorchSim docker-image workflow) so `Dockerfile.base` produces an image + with the bindings + PYTHONPATH. + +## Verify + +Inside the rebuilt container (or after `build_from_source.sh`): + +```bash +python -c "import mlir; print(mlir.__file__)" # -> /riscv-llvm/python_packages/mlir_core/mlir/__init__.py +python -c "from mlir.ir import Context; c=Context(); c.allow_unregistered_dialects=True; print('ok')" +python -c "from mlir.dialects import scf, affine, arith; print('dialects ok')" +``` + +`allow_unregistered_dialects=True` is what lets us read/write the custom ops +(`togsim.transfer`, the customized `memref.dma_start`) generically without +registering a dialect in the bindings. + +## Notes / gotchas + +- Keep the bindings statically linked (default, i.e. do NOT add + `-DBUILD_SHARED_LIBS=ON` / `-DLLVM_BUILD_LLVM_DYLIB=ON`); otherwise the `.so` + needs libMLIR/libLLVM at runtime and the artifact + LD_LIBRARY_PATH grow. +- Worktrees: add the same `PYTHONPATH` line to the worktree `.envrc` (see + `docs/worktrees.md`) if a worktree overrides paths. +- The bindings are an additive, optional dependency: text emission + C++ passes + keep working unchanged. Only new Python passes require the bindings present. +- This LLVM fork's MLIR bindings use **pybind11** (not nanobind) and require + **pybind11 <= 2.10.3**: newer pybind11 (3.x) fails to compile `IRCore.cpp` with + `def_property family does not currently support keep_alive`. Pin it + (`pybind11>=2.9.0,<=2.10.3`). See `mlir/python/requirements.txt` for the fork's + pins. pybind11 is build-time only; the runtime needs just the built `.so` + numpy. +- numpy: the fork's requirements pin `<=1.26`, but a local build against numpy 2.x + compiled and imported fine, so we keep numpy at the runtime version (2.x) to + avoid a numpy-1-built / numpy-2-runtime ABI mismatch. (Validated locally: + conda 3.11 + pybind11 2.10.3 + numpy 2.x -> `import mlir` and parsing a custom + `togsim.transfer` op with floordiv/mod affine maps both work.) diff --git a/docs/tpu_layout_padding_report.md b/docs/tpu_layout_padding_report.md new file mode 100644 index 00000000..68cbacbb --- /dev/null +++ b/docs/tpu_layout_padding_report.md @@ -0,0 +1,182 @@ +# TPU Layout Assignment & Padding 메커니즘 조사 보고서 + +> **목적**: PyTorchSim에서 TPU 워크로드의 메모리 footprint / compute utilization을 정확히 모델링하기 위해, XLA/Mosaic 컴파일 파이프라인에서 (1) 어느 축이 lane/sublane으로 선택되는지, (2) 패딩이 언제·어떻게 일어나는지, (3) 그 패딩이 물리적으로 물질화되는지 마스킹으로 처리되는지를 정리한 핸드오프 문서. +> +> **수신자**: PyTorchSim 모델링/구현 담당 agent +> **작성 기준일**: 2026-06-18 +> **신뢰도 표기**: [확정]=공개 문서/소스로 검증됨, [추론]=문서 기반 합리적 추론, [미확인]=공개 자료로 닿지 못함 + +--- + +## 0. 한 줄 요약 + +TPU에서 **lane(128) 축 선택과 패딩은 XLA의 layout assignment pass에서 동시에 결정**되며, 패딩은 두 층위로 나뉜다: **(A) 8×128 레이아웃 정렬 패딩은 주소 정렬상 강제 물질화**(실제 텐서가 HBM에서 커짐), **(B) 그보다 큰 연산 블록 크기의 경계(tail)는 masking/peeling 등으로 처리**(대체로 비물질화). 모델링 시 (A)는 footprint+traffic, (B)는 compute utilization만 반영해야 한다. + +--- + +## 1. 컴파일 파이프라인 순서 [확정] + +``` +프론트엔드(JAX/PyTorch) + → JAXPR/FX + → HLO (DotGeneral 등, layout 미확정) + → [layout assignment] ← lane/sublane 축 + 패딩 결정 + → [fusion] ← op들을 커널로 묶음 + → LLO (TPU-specific IR) + → VLIW bundles +``` + +- 출처: JAX→VLIW 컴파일러 추적 (patricktoulme.substack.com), OpenXLA 공식 문서. +- 핵심 함의: **layout 결정이 fusion보다 먼저**다. 따라서 "어느 축이 lane인가 / 얼마나 패딩되는가"는 fusion을 몰라도 결정 가능하지만, **실제 pad/relayout op의 삽입(물질화)은 fusion이 보이는 단계(LLO)에서** 해야 minimal하게 된다. +- PyTorchSim 관점: 이 분리를 그대로 따를 것. FX/상위 단계에서는 layout **결정**(메타데이터)만, 실제 padding/relayout **물질화**는 하위 단계에서. + +--- + +## 2. lane/sublane 축 선택 메커니즘 [확정] + +### 2.1 결정 시점과 표현 +layout assignment pass에서 HLO 텐서에 `{minor_to_major : T(tile)}` 어노테이션이 부착되는 순간 확정. + +관측된 예시 (layout assignment 전후): +``` +// Before +%dot.10 = f32[16,64]{1,0} dot(...) +%reduce.16 = f32[16]{0} reduce(...) +// After +%dot.10 = f32[16,64]{1,0:T(8,128)} dot(...) +%reduce.16 = f32[16]{0:T(128)} reduce(...) +``` +- `{1,0}` = minor_to_major 순서 (row-major: 마지막 차원이 메모리 연속). +- `:T(8,128)` = TPU 타일링, VPU의 **8 sublane × 128 lane**에 대응. + +### 2.2 규칙 +- **minor_to_major 리스트의 첫 원소(가장 minor한 차원) = lane(128) 방향**, 그 다음 = sublane(8) 방향. +- 타일링은 **항상 most-minor 두 축에만** 적용. 나머지 major 차원은 타일링 없이 그대로 (rank 무관). +- 누가 결정하나: **matmul/dot이 anchor로 layout 강제** → elementwise는 통과 → reduce는 cross-lane 비용 때문에 lane에서 빠지려는 압력 → graph 위로 전파(propagation). + +### 2.3 검증 방법 (PyTorchSim ground truth) +```bash +XLA_FLAGS="--xla_dump_to=/path --xla_dump_hlo_as_text=true" python model.py +``` +덤프된 HLO의 `:T(...)` 어노테이션으로 실제 lane 축/패딩을 추측 아닌 컴파일러 출력으로 확인 가능. +(주의: `--xla_enable_hlo_passes_only=layout-assignment` 단독은 후속 buffer assignment에서 에러 가능 → 전체 덤프 권장. 출처: openxla/xla issue #12850) + +--- + +## 3. 타일 크기 규칙 [확정] + +| 조건 | 타일 | 비고 | +|---|---|---| +| f32, 일반 | `T(8,128)` | 32-bit 8×128 벡터 레지스터에 대응 | +| bf16 | `T(8,128)(2,1)` | 2단계 타일링 = BF16 packing. 짝/홀수 행 16-bit 둘을 묶어 32-bit 하나로 | +| 2nd-minor 차원 = 1 or 2 | `T(2,128)` | "Compact 2nd Minor Layout" — 메모리 절약 | +| 2nd-minor 차원 = 3 or 4 | `T(4,128)` | 동일 목적 | + +- **중요 정정**: sublane 패딩이 항상 8은 아니다. 작은 2nd-minor 차원이면 2 또는 4로 줄어든다. + - 함의: **LLM 디코딩 token=1의 sublane 패딩은 8배가 아니라 2배** (`T(2,128)`). +- bf16 packing 이유: TPU는 32-bit 네이티브. most-minor보다 2nd-minor 가로지르는 데이터 이동이 효율적이라 같은 column에서 16-bit 둘을 모음. +- 출처: OpenXLA tiled_layout 문서, gdymind 블로그. + +### 3.1 MXU 크기 (세대별, 연산 단위 — 레이아웃 타일과 별개) [확정] +- v6e, TPU7x(Ironwood): **256×256** +- v6e 이전: **128×128** +- peak FLOPs 위해선 matmul 차원이 해당 세대 MXU 크기보다 커야 함. +- ⚠️ 이건 *연산* 단위. *메모리 레이아웃* 타일은 세대 무관 8×128 유지. 둘을 섞지 말 것. +- 출처: Google Cloud TPU performance guide. + +--- + +## 4. 패딩 처리: 두 층위 [핵심 — 확정] + +### 4.1 (A) 8×128 레이아웃 정렬 패딩 = 강제 물질화 +- **실제 텐서가 HBM에서 패딩된 크기로 저장됨.** 회피 불가. +- 이유: 타일이 HBM에 연속으로 깔리려면 각 타일이 꽉 찬 8×128이어야 함 → 주소 정렬(address alignment) 필수. +- MLIR 코드생성 일반론에서 이는 `nofold` 패딩에 해당: "address alignment가 필수인 경우 강제로 패딩". value padding이 불필요해 보여도 정렬 때문에 fold되지 않고 물질화됨. +- Google 공식 확인: 128×8 청크를 못 채우면 XLA가 텐서를 패딩하고, 이는 "on-chip 메모리 저장량을 늘리고 OOM 유발 가능" = 물리적 공간 점유. +- 패딩량 = `⌈d/tile⌉ × tile − d` (lane/sublane 각각). +- 출처: OpenXLA, Google Cloud performance guide, MLIR codegen 논문(arxiv 2202.03293). + +### 4.2 (B) 연산 블록 크기(>8×128) 경계 = tail 처리 +컴파일러가 비용 보고 세 전략 중 선택 (MLIR codegen 논문 §3.2): + +1. **Loop peeling / versioning** (비물질화): main loop는 정적 상수 부분, 경계는 cleanup loop(타일=1로 축소). 텐서 안 키움. +2. **실제 패딩** (물질화): 동적 타일을 정적 크기로 패딩, 값은 소비 연산의 **neutral(항등원)** — matmul이면 0. `tensor.pad` op으로 물질화, 크기 = 정적 타일 − 동적 타일. 추가 복사 비용 발생. +3. **명시적 masking** (비물질화): MXU가 0 포함 전체 계산 후 출력에서 마스킹. "MXU엔 '하지 마라' 신호가 없어 패딩 0을 실데이터와 곱하고 출력에서 마스킹 — 정확하나 느림(버려지는 work에 MXU 비용 지불)". + +- 기본은 비물질화(masking/peeling)가 흔함. pad 물질화는 정렬 필수 케이스. +- 출처: MLIR codegen 논문, Pallas matmul 튜토리얼(neuropurrfectai). + +### 4.3 두 층위의 관계 [추론 — 합리적] +- 연산 블록 tail이 8×128 정렬과도 안 맞으면: 이미 물질화된 레이아웃 패딩(A)을 그대로 읽고, 블록 크기와 정렬 사이의 간격(B의 순수분)만 masking/peeling. +- 즉 "레이아웃 패딩은 물리적으로 이미 존재, 그걸 넘어서는 블록 경계분만 tail 전략"으로 두 층이 포개짐. + +--- + +## 5. TPU 특유 제약: tiled 축 경계가 비싼 이유 [확정] + +- Ragged Paged Attention 논문: 논리/물리 레이아웃 불일치 + narrow dtype packing 때문에 **tiled 차원(특히 lane)에서 임의 메모리 슬라이스가 근본적으로 어렵다** (VREG blending 없이 ragged 입력을 메모리에 직접 쓰는 경우 특히). +- 실전 해법: **ragged/동적 차원을 non-tiled(major) 축에 배치**, packing을 2nd-minor에 삽입해 XLA가 최소 타일 `T(packing,128)`을 쓰도록 강제 → 임의 동적 슬라이스 가능. +- Pallas `pl.BoundedSlice` / `pl.ds`로 동적 크기 청크 처리 가능 (패딩 대신 정확 크기). +- 함의: 컴파일러/커널이 작은·동적 축을 lane에서 빼려 애쓰는 이유가 정량적으로 설명됨. +- 출처: Ragged Paged Attention (arxiv 2604.15464), JAX Pallas pipelining 문서. + +--- + +## 6. LLM 디코딩 특이사항 [확정 + 추론] + +- 디코딩 GEMV(`[1,hidden]×[hidden,out]`)에서 token=1 차원: + - 보통 hidden이 lane(128)에, token=1은 sublane으로 → **sublane 2배 패딩**(`T(2,128)`), lane까지 가지 않음. [추론, §3 규칙 기반] +- **activation 패딩 traffic은 디코딩에서 무시 가능**: activation(`[1,hidden]`, 수 KB)은 weight/KV cache(수십~수백 GB)보다 3~5 자릿수 작음. 8배든 2배든 전체 traffic에서 미미. [확정 — 디코딩 memory-bound 특성] +- 디코딩 traffic 모델은 **weight 전체 재읽기 + KV cache 읽기에 집중**할 것. activation 패딩 오차의 영향은 작음. +- 패딩의 실질 페널티는 traffic이 아니라 MXU utilization 저하인데, memory-bound regime에선 wall-clock 비결정적. +- → self-spec decoding 연구 동기(한 번 읽은 weight당 토큰 더 뽑기)와 직결. + +--- + +## 7. PyTorchSim 모델링 권고 [실행 항목] + +### 7.1 두 비용 함수를 분리하라 (가장 중요) +- **footprint / HBM traffic 함수**: (A) 8×128 레이아웃 정렬 패딩만 물리 크기로 계산. 예: 길이 100 → 128로 저장·전송. bf16 packing, small-tile(2/4×128) 변형 반영. +- **compute utilization 함수**: (B) 연산 블록 경계 패딩 처리. + - 기본 masking: MXU 사이클 낭비로 utilization ↓, **traffic 중복 계상 금지**. + - pad 물질화 케이스(정렬 필수)만 추가 복사 traffic. + - peeling 케이스는 작은 cleanup 커널로 별도. +- **함정**: 연산 경계 패딩을 traffic으로 또 더하면 대역폭 과대평가. 물리 패딩(A)만 traffic+compute, 연산 경계(B)는 compute만. + +### 7.2 layout 결정 로직 +- minor_to_major 축 선택 + 타일 크기 선택을 §2~§3 규칙으로 모델링. +- matmul anchor → 전파 → reduce는 lane에서 빼기 선호. +- 세대별 MXU 크기(128 vs 256)를 파라미터화. + +### 7.3 비대칭 반영 +- tiled 축(lane) 경계 처리 비용 > non-tiled 축. 이 비대칭을 넣으면 실제 커널이 동적 축을 major로 미는 동작 재현됨(§5). + +### 7.4 검증 +- §2.3의 `XLA_FLAGS` 덤프로 실제 `:T(...)` 어노테이션 떠서 모델 출력과 대조. +- 가능하면 LLO 덤프까지 떠서 경계 타일이 `tensor.pad`(물질화) vs masking/peeling 중 무엇으로 처리되는지 확인. + +--- + +## 8. 미확인 / 후속 조사 필요 [미확인] + +1. **Mosaic의 tail 전략 선택 휴리스틱**: peeling vs pad vs masking을 언제 고르는지의 내부 규칙. Mosaic이 상당 부분 비공개(Google 내부 컴파일러)라 공개 자료로 닿지 못함. → LLO 덤프 실측이 유일한 확실한 길. +2. **연산 블록 경계 0의 정확한 주입 위치**: VREG / VMEM / MXU 입력 중 어디서 0이 주입되고 VMEM 점유에 잡히는지. → Mosaic 소스 또는 LLO 덤프 필요. +3. **layout_assignment.cc의 minor_to_major 선택 휴리스틱 코드 레벨**: matmul이 정확히 어떤 layout을 강제하고 어떻게 전파하는지. OpenXLA 공개 소스(layout_assignment.cc, instruction_fusion.cc)에서 추적 가능 — 아직 코드 레벨로 파지 않음. +4. **DMA의 strided/partial transfer가 레이아웃 패딩을 정확히 어떻게 처리하는지** (세대별): v4부터 512B granularity striding 지원은 확인됨. 패딩 영역 전송 회피 가능 여부의 정밀 동작은 세대·XLA 버전 의존. + +--- + +## 9. 출처 목록 + +- OpenXLA — Tiled layout: https://openxla.org/xla/tiled_layout +- OpenXLA — Shapes and layout: https://openxla.org/xla/shapes +- Google Cloud — TPU performance guide: https://docs.cloud.google.com/tpu/docs/performance-guide +- Google Cloud — Intro to Cloud TPU: https://docs.cloud.google.com/tpu/docs/intro-to-tpu +- Google Cloud — TPU v4: https://docs.cloud.google.com/tpu/docs/v4 +- From JAX to VLIW (컴파일러 추적, layout assignment 전후 HLO): https://patricktoulme.substack.com/p/from-jax-to-vliw-tracing-a-computation +- Pallas matmul 튜토리얼 (MXU 마스킹, tail 패딩): https://neuropurrfectai.substack.com/p/part-2-your-first-pallas-kernel-tiled +- Composable/Modular Code Generation in MLIR (경계 타일 3전략, nofold): https://arxiv.org/pdf/2202.03293 +- Ragged Paged Attention (tiled 축 슬라이스 제약): https://arxiv.org/html/2604.15464 +- JAX — TPU pipelining (동적 슬라이스): https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html +- openxla/xla issue #12850 (pass 덤프 방법): https://github.com/openxla/xla/issues/12850 +- gdymind 블로그 (small-tile, packing 정리): https://gdymind.com/2026/02/26/XLA02-shapes-layout-tiling/ diff --git a/scripts/build_from_source.sh b/scripts/build_from_source.sh index 4e7ff604..f23eab82 100644 --- a/scripts/build_from_source.sh +++ b/scripts/build_from_source.sh @@ -45,12 +45,23 @@ export GEM5_PATH="$home/gem5/build/RISCV/gem5.opt" cd "$home" # LLVM + MLIR (RISCV target) +# MLIR Python bindings are enabled so Python-side MLIR passes can run. The +# bindings are a native extension: they MUST be built against the same Python +# that runs PyTorchSim at runtime (the conda 3.11 here) or `import mlir` will +# fail with an ABI mismatch. nanobind/pybind11/numpy/PyYAML are build-time deps. +python3 -m pip install --user "pybind11>=2.9.0,<=2.10.3" numpy PyYAML git clone --depth 1 --branch "$LLVM_TAG" "https://github.com/${LLVM_REPO}.git" cd llvm-project && mkdir -p build && cd build && \ cmake -DLLVM_ENABLE_PROJECTS=mlir -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/riscv-llvm -DLLVM_TARGETS_TO_BUILD=RISCV \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE="$(command -v python3)" \ -G "Unix Makefiles" ../llvm && \ - make -j && make install + make -j && make install && \ + rm -rf /riscv-llvm/python_packages && \ + cp -r tools/mlir/python_packages /riscv-llvm/python_packages +# Make the bindings importable in this shell (also set in .envrc / Dockerfile.base) +export PYTHONPATH="/riscv-llvm/python_packages/mlir_core:$PYTHONPATH" cd "$home" # Spike Simulator diff --git a/scripts/op_coverage.py b/scripts/op_coverage.py new file mode 100644 index 00000000..1f4567b6 --- /dev/null +++ b/scripts/op_coverage.py @@ -0,0 +1,540 @@ +"""Op-coverage diagnostic for new LLM models on PyTorchSim. + +Runs each model in two phases: + Phase 1 (enumerate): custom torch.compile backend captures the FX graph and + lists every aten op that appears, without touching NPU. + Phase 2 (run): torch.compile(model) on npu:0, real forward. On crash, + parses the traceback to identify the failing op. + +Usage: + python scripts/op_coverage.py # all models + python scripts/op_coverage.py --models qwen2 # subset + python scripts/op_coverage.py --enumerate-only # skip NPU compile (fast) +""" + +import argparse +import datetime as _dt +import os +import re +import sys +import traceback +from contextlib import contextmanager + +import torch + +REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) + + +# --------------------------------------------------------------------------- +# Model registry: each entry returns (model, kwargs_for_forward) on CPU. +# Sizes follow "small but realistic" variants (1-layer) so a forward is cheap +# enough to actually drive through TOGSim. +# --------------------------------------------------------------------------- + +def _causal_mask(batch, seq_len, dtype): + min_v = torch.finfo(dtype).min + m = torch.full((seq_len, seq_len), min_v, dtype=dtype) + if seq_len > 1: + m = torch.triu(m, diagonal=1) + return m[None, None, :, :].expand(batch, 1, -1, -1).contiguous() + + +def build_qwen2(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.qwen2.configuration_qwen2 import Qwen2Config + from transformers.models.qwen2.modeling_qwen2 import Qwen2Model + cfg = Qwen2Config( + vocab_size=4096, + hidden_size=1536, + num_attention_heads=12, + num_key_value_heads=2, + intermediate_size=8960, + num_hidden_layers=2, + max_position_embeddings=4096, + rms_norm_eps=1e-6, + rope_theta=1000000.0, + torch_dtype=dtype, + use_cache=False, + _attn_implementation="eager", + ) + model = Qwen2Model(cfg).eval().to(dtype=dtype) + input_ids = torch.randint(0, cfg.vocab_size, (batch, seq_len)) + attn_mask = _causal_mask(batch, seq_len, dtype) + return model, {"input_ids": input_ids, "attention_mask": attn_mask} + + +def build_gemma(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.gemma.configuration_gemma import GemmaConfig + from transformers.models.gemma.modeling_gemma import GemmaModel + cfg = GemmaConfig( + vocab_size=4096, + hidden_size=2048, + num_attention_heads=8, + num_key_value_heads=1, + intermediate_size=16384, + num_hidden_layers=2, + head_dim=256, + max_position_embeddings=4096, + rms_norm_eps=1e-6, + rope_theta=10000.0, + torch_dtype=dtype, + use_cache=False, + _attn_implementation="eager", + ) + model = GemmaModel(cfg).eval().to(dtype=dtype) + input_ids = torch.randint(0, cfg.vocab_size, (batch, seq_len)) + attn_mask = _causal_mask(batch, seq_len, dtype) + return model, {"input_ids": input_ids, "attention_mask": attn_mask} + + +def build_gemma2(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.gemma2.configuration_gemma2 import Gemma2Config + from transformers.models.gemma2.modeling_gemma2 import Gemma2Model + cfg = Gemma2Config( + vocab_size=4096, + hidden_size=2304, + num_attention_heads=8, + num_key_value_heads=4, + intermediate_size=9216, + num_hidden_layers=2, + head_dim=256, + max_position_embeddings=4096, + rms_norm_eps=1e-6, + rope_theta=10000.0, + torch_dtype=dtype, + use_cache=False, + attn_logit_softcapping=50.0, + final_logit_softcapping=30.0, + sliding_window=16, + _attn_implementation="eager", + ) + model = Gemma2Model(cfg).eval().to(dtype=dtype) + input_ids = torch.randint(0, cfg.vocab_size, (batch, seq_len)) + attn_mask = _causal_mask(batch, seq_len, dtype) + return model, {"input_ids": input_ids, "attention_mask": attn_mask} + + +def build_phi3(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.phi3.configuration_phi3 import Phi3Config + from transformers.models.phi3.modeling_phi3 import Phi3Model + cfg = Phi3Config( + vocab_size=4096, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + hidden_size=3072, + num_attention_heads=32, + num_key_value_heads=32, + intermediate_size=8192, + num_hidden_layers=2, + max_position_embeddings=4096, + rms_norm_eps=1e-5, + rope_theta=10000.0, + torch_dtype=dtype, + use_cache=False, + _attn_implementation="eager", + ) + model = Phi3Model(cfg).eval().to(dtype=dtype) + input_ids = torch.randint(0, cfg.vocab_size, (batch, seq_len)) + attn_mask = _causal_mask(batch, seq_len, dtype) + return model, {"input_ids": input_ids, "attention_mask": attn_mask} + + +def _build_lm(cfg, ModelCls, batch, seq_len, dtype): + """Shared helper: build a causal-LM-style model and matching token+mask inputs.""" + model = ModelCls(cfg).eval().to(dtype=dtype) + input_ids = torch.randint(0, cfg.vocab_size, (batch, seq_len)) + attn_mask = _causal_mask(batch, seq_len, dtype) + return model, {"input_ids": input_ids, "attention_mask": attn_mask} + + +def build_qwen3(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.qwen3.configuration_qwen3 import Qwen3Config + from transformers.models.qwen3.modeling_qwen3 import Qwen3Model + cfg = Qwen3Config( + vocab_size=4096, hidden_size=1024, num_attention_heads=8, num_key_value_heads=4, + intermediate_size=3072, num_hidden_layers=2, max_position_embeddings=4096, + rms_norm_eps=1e-6, rope_theta=1000000.0, torch_dtype=dtype, use_cache=False, + _attn_implementation="eager", + ) + return _build_lm(cfg, Qwen3Model, batch, seq_len, dtype) + + +def build_qwen3_moe(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig + from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel + cfg = Qwen3MoeConfig( + vocab_size=4096, hidden_size=1024, num_attention_heads=8, num_key_value_heads=4, + intermediate_size=3072, moe_intermediate_size=768, num_experts=4, num_experts_per_tok=2, + decoder_sparse_step=1, num_hidden_layers=2, max_position_embeddings=4096, + rms_norm_eps=1e-6, rope_theta=1000000.0, torch_dtype=dtype, use_cache=False, + _attn_implementation="eager", + ) + return _build_lm(cfg, Qwen3MoeModel, batch, seq_len, dtype) + + +def build_gemma3(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig + from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel + cfg = Gemma3TextConfig( + vocab_size=4096, hidden_size=2048, num_attention_heads=8, num_key_value_heads=4, + intermediate_size=8192, head_dim=256, num_hidden_layers=2, + sliding_window=16, sliding_window_pattern=2, + max_position_embeddings=4096, rms_norm_eps=1e-6, rope_theta=10000.0, + torch_dtype=dtype, use_cache=False, _attn_implementation="eager", + ) + return _build_lm(cfg, Gemma3TextModel, batch, seq_len, dtype) + + +def build_deepseek_v3(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config + from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3Model + cfg = DeepseekV3Config( + vocab_size=4096, hidden_size=1024, num_attention_heads=16, num_key_value_heads=16, + intermediate_size=4096, moe_intermediate_size=512, + n_routed_experts=8, num_experts_per_tok=2, n_shared_experts=1, + n_group=2, topk_group=1, + q_lora_rank=512, kv_lora_rank=128, qk_rope_head_dim=32, qk_nope_head_dim=32, v_head_dim=64, + num_hidden_layers=2, first_k_dense_replace=1, + max_position_embeddings=4096, rms_norm_eps=1e-6, rope_theta=10000.0, + torch_dtype=dtype, use_cache=False, _attn_implementation="eager", + ) + return _build_lm(cfg, DeepseekV3Model, batch, seq_len, dtype) + + +def build_llama4(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.llama4.configuration_llama4 import Llama4TextConfig + from transformers.models.llama4.modeling_llama4 import Llama4TextModel + cfg = Llama4TextConfig( + vocab_size=4096, hidden_size=1024, num_attention_heads=8, num_key_value_heads=4, + intermediate_size=3072, intermediate_size_mlp=3072, + num_local_experts=4, num_experts_per_tok=1, num_hidden_layers=2, interleave_moe_layer_step=2, + max_position_embeddings=4096, rms_norm_eps=1e-6, rope_theta=10000.0, + torch_dtype=dtype, use_cache=False, _attn_implementation="eager", + ) + return _build_lm(cfg, Llama4TextModel, batch, seq_len, dtype) + + +def build_glm4(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.glm4.configuration_glm4 import Glm4Config + from transformers.models.glm4.modeling_glm4 import Glm4Model + cfg = Glm4Config( + vocab_size=4096, pad_token_id=0, bos_token_id=1, eos_token_id=2, + hidden_size=1536, num_attention_heads=12, num_key_value_heads=2, + intermediate_size=4096, num_hidden_layers=2, max_position_embeddings=4096, + rms_norm_eps=1e-5, rope_theta=10000.0, torch_dtype=dtype, use_cache=False, + _attn_implementation="eager", + ) + return _build_lm(cfg, Glm4Model, batch, seq_len, dtype) + + +def build_olmo2(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.olmo2.configuration_olmo2 import Olmo2Config + from transformers.models.olmo2.modeling_olmo2 import Olmo2Model + cfg = Olmo2Config( + vocab_size=4096, hidden_size=2048, num_attention_heads=16, num_key_value_heads=16, + intermediate_size=8192, num_hidden_layers=2, max_position_embeddings=4096, + rms_norm_eps=1e-6, rope_theta=10000.0, torch_dtype=dtype, use_cache=False, + _attn_implementation="eager", + ) + return _build_lm(cfg, Olmo2Model, batch, seq_len, dtype) + + +def build_granite(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.granite.configuration_granite import GraniteConfig + from transformers.models.granite.modeling_granite import GraniteModel + cfg = GraniteConfig( + vocab_size=4096, hidden_size=2048, num_attention_heads=16, num_key_value_heads=8, + intermediate_size=8192, num_hidden_layers=2, max_position_embeddings=4096, + rms_norm_eps=1e-5, rope_theta=10000.0, torch_dtype=dtype, use_cache=False, + _attn_implementation="eager", + ) + return _build_lm(cfg, GraniteModel, batch, seq_len, dtype) + + +def build_phimoe(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.phimoe.configuration_phimoe import PhimoeConfig + from transformers.models.phimoe.modeling_phimoe import PhimoeModel + cfg = PhimoeConfig( + vocab_size=4096, hidden_size=1024, num_attention_heads=8, num_key_value_heads=4, + intermediate_size=3072, num_local_experts=4, num_experts_per_tok=2, + num_hidden_layers=2, max_position_embeddings=4096, + rms_norm_eps=1e-5, rope_theta=10000.0, torch_dtype=dtype, use_cache=False, + _attn_implementation="eager", + ) + return _build_lm(cfg, PhimoeModel, batch, seq_len, dtype) + + +def build_mamba2(batch=1, seq_len=32, dtype=torch.float32): + # State-space model: no attention, no RoPE -- completely different op profile. + # Invariant: num_heads * head_dim == intermediate_size == expand * hidden_size + # (modeling_mamba2.py:171 + the view(B, num_heads*head_dim) at line 365). + from transformers.models.mamba2.configuration_mamba2 import Mamba2Config + from transformers.models.mamba2.modeling_mamba2 import Mamba2Model + cfg = Mamba2Config( + vocab_size=4096, hidden_size=512, + num_heads=16, head_dim=64, + state_size=16, chunk_size=16, + expand=2, n_groups=1, + num_hidden_layers=2, torch_dtype=dtype, use_cache=False, + ) + model = Mamba2Model(cfg).eval().to(dtype=dtype) + input_ids = torch.randint(0, cfg.vocab_size, (batch, seq_len)) + # Mamba has no attention mask; pass none. + return model, {"input_ids": input_ids} + + +def build_mllama(batch=1, seq_len=32, dtype=torch.float32): + # Llama 3.2 Vision -- text branch only (text-only call path). + # MllamaRotaryEmbedding requires config.rope_scaling["rope_type"]; pass default. + from transformers.models.mllama.configuration_mllama import MllamaTextConfig + from transformers.models.mllama.modeling_mllama import MllamaTextModel + cfg = MllamaTextConfig( + vocab_size=4096, pad_token_id=0, bos_token_id=1, eos_token_id=2, + hidden_size=1024, num_attention_heads=8, num_key_value_heads=4, + intermediate_size=3072, num_hidden_layers=2, + cross_attention_layers=[], + max_position_embeddings=4096, rms_norm_eps=1e-5, rope_theta=10000.0, + rope_scaling={"rope_type": "default"}, + torch_dtype=dtype, use_cache=False, _attn_implementation="eager", + ) + return _build_lm(cfg, MllamaTextModel, batch, seq_len, dtype) + + +BUILDERS = { + "qwen2": build_qwen2, + "gemma": build_gemma, + "gemma2": build_gemma2, + "phi3": build_phi3, + # Models newly available with transformers 4.51.3 + "qwen3": build_qwen3, + "qwen3_moe": build_qwen3_moe, + "gemma3": build_gemma3, + "deepseek_v3": build_deepseek_v3, + "llama4": build_llama4, + "glm4": build_glm4, + "olmo2": build_olmo2, + "granite": build_granite, + "phimoe": build_phimoe, + "mamba2": build_mamba2, + "mllama": build_mllama, +} + + +# --------------------------------------------------------------------------- +# Phase 1: enumerate aten ops by intercepting the FX graph from torch.compile. +# --------------------------------------------------------------------------- + +def _node_op_name(target): + # OpOverload / OpOverloadPacket: has a .name() method returning "aten::mm.default" etc. + if hasattr(target, "name") and callable(target.name): + try: + return target.name() + except Exception: + pass + if hasattr(target, "_schema"): + try: + return str(target._schema.name) + ( + "." + target._schema.overload_name if target._schema.overload_name else "" + ) + except Exception: + pass + # torch.* python builtins: use their __module__/__qualname__ + mod = getattr(target, "__module__", "") + qn = getattr(target, "__qualname__", None) or getattr(target, "__name__", "") + if mod and qn: + return f"{mod}.{qn}" + return str(target) + + +@torch.no_grad() +def enumerate_ops(model, inputs): + """Capture the post-AOTAutograd aten graph(s) via aot_module_simplified. + + This is the same level of IR TOGSim/Inductor consumes, so the op set + matches what the NPU backend actually has to lower. + """ + from functorch.compile import aot_module_simplified + + seen = set() + graph_sizes = [] + + def fw_compiler(gm, example_inputs): + graph_sizes.append(sum(1 for _ in gm.graph.nodes)) + for node in gm.graph.nodes: + if node.op == "call_function": + seen.add(_node_op_name(node.target)) + return gm.forward + + def dynamo_backend(gm, example_inputs): + return aot_module_simplified(gm, example_inputs, fw_compiler=fw_compiler) + + torch._dynamo.reset() + compiled = torch.compile(model, backend=dynamo_backend, dynamic=False) + compiled(**inputs) + return sorted(seen), graph_sizes + + +# --------------------------------------------------------------------------- +# Phase 2: real NPU compile + run. Capture and parse failure tracebacks. +# --------------------------------------------------------------------------- + +ATEN_RE = re.compile(r"aten[.:][a-zA-Z_][a-zA-Z0-9_.]*") +NOTIMPL_RE = re.compile(r"NotImplementedError[: ]+(.*)") + + +def parse_failure(tb_text): + aten_hits = [] + for m in ATEN_RE.finditer(tb_text): + op = m.group(0).replace("aten:", "aten.").lstrip(".") + if op not in aten_hits: + aten_hits.append(op) + msg = "" + nm = NOTIMPL_RE.search(tb_text) + if nm: + msg = nm.group(1).strip().splitlines()[0] + return aten_hits, msg + + +@torch.no_grad() +def run_on_npu(model, inputs): + device = torch.device("npu:0") + model = model.to(device) + inputs = {k: v.to(device) for k, v in inputs.items()} + torch._dynamo.reset() + compiled = torch.compile(model, dynamic=False) + out = compiled(**inputs) + # touch the output to force completion + if hasattr(out, "last_hidden_state"): + out.last_hidden_state.cpu() + elif isinstance(out, torch.Tensor): + out.cpu() + return "OK", None, None + + +# --------------------------------------------------------------------------- +# Driver +# --------------------------------------------------------------------------- + +def run_model(name, args, out_dir): + builder = BUILDERS[name] + log_path = os.path.join(out_dir, f"{name}.log") + with open(log_path, "w") as fh: + def w(s=""): + print(s) + fh.write(s + "\n") + + w(f"=== {name} ===") + w(f"batch={args.batch} seq_len={args.seq_len} dtype={args.dtype}") + + try: + model, inputs = builder(args.batch, args.seq_len, _DTYPE_MAP[args.dtype]) + except Exception as e: + w(f"[BUILD FAIL] {type(e).__name__}: {e}") + return {"name": name, "status": "BUILD_FAIL", "ops": [], "fail_op": str(e)} + + # Phase 1 + w("\n[Phase 1] FX op enumeration (eager backend, no NPU)") + try: + ops, graph_sizes = enumerate_ops(model, inputs) + w(f" graphs: {len(graph_sizes)} total_nodes_per_graph: {graph_sizes}") + w(f" unique aten ops: {len(ops)}") + for op in ops: + w(f" {op}") + except Exception: + tb = traceback.format_exc() + w("[Phase 1 FAIL]\n" + tb) + ops = [] + + if args.enumerate_only: + return {"name": name, "status": "ENUM_ONLY", "ops": ops, "fail_op": None} + + # Phase 2 + w("\n[Phase 2] torch.compile on npu:0 + forward") + try: + status, fail_op, msg = run_on_npu(model, inputs) + w(f" status: {status}") + return {"name": name, "status": status, "ops": ops, "fail_op": None} + except Exception: + tb = traceback.format_exc() + hits, msg = parse_failure(tb) + w(" status: FAIL") + if msg: + w(f" NotImplemented message: {msg}") + if hits: + w(f" aten ops in traceback (first = most likely culprit):") + for h in hits[:10]: + w(f" {h}") + w("\n----- traceback -----\n" + tb) + return { + "name": name, + "status": "FAIL", + "ops": ops, + "fail_op": hits[0] if hits else "?", + "msg": msg, + } + + +_DTYPE_MAP = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--models", nargs="+", default=list(BUILDERS.keys()), + choices=list(BUILDERS.keys())) + p.add_argument("--batch", type=int, default=1) + p.add_argument("--seq-len", type=int, default=32) + p.add_argument("--dtype", default="float32", choices=list(_DTYPE_MAP.keys())) + p.add_argument("--enumerate-only", action="store_true", + help="Skip NPU compile; just list aten ops per model (fast).") + p.add_argument("--out-dir", default=None) + args = p.parse_args() + + ts = _dt.datetime.now().strftime("%Y%m%d_%H%M%S") + out_dir = args.out_dir or os.path.join( + os.environ.get("TORCHSIM_LOG_PATH", os.path.join(REPO_ROOT, "togsim_results")), + "op_coverage", ts, + ) + os.makedirs(out_dir, exist_ok=True) + print(f"Output dir: {out_dir}") + + results = [] + for name in args.models: + try: + results.append(run_model(name, args, out_dir)) + except KeyboardInterrupt: + print(f"[interrupt] aborted during {name}") + break + except Exception: + traceback.print_exc() + results.append({"name": name, "status": "DRIVER_ERR", "ops": [], "fail_op": None}) + + # Summary + summary_path = os.path.join(out_dir, "summary.txt") + with open(summary_path, "w") as fh: + def w(s=""): + print(s) + fh.write(s + "\n") + w("\n========== SUMMARY ==========") + w(f"{'model':10s} {'ops':>5s} {'status':10s} first_fail") + for r in results: + w(f"{r['name']:10s} {len(r['ops']):>5d} {r['status']:10s} {r.get('fail_op') or '-'}") + # Union & overlap across models + all_ops = set() + for r in results: + all_ops.update(r["ops"]) + w(f"\nUnion of aten ops across all models: {len(all_ops)}") + w("Per-model op set diff (ops unique to this model):") + for r in results: + others = set().union(*(set(r2["ops"]) for r2 in results if r2 is not r)) + unique = sorted(set(r["ops"]) - others) + w(f" {r['name']}: {len(unique)} unique") + for op in unique: + w(f" {op}") + + print(f"\nWrote: {summary_path}") + + +if __name__ == "__main__": + main() diff --git a/tests/models/DeepSeek/test_deepseek_v3_base.py b/tests/models/DeepSeek/test_deepseek_v3_base.py index 2941e776..79160127 100644 --- a/tests/models/DeepSeek/test_deepseek_v3_base.py +++ b/tests/models/DeepSeek/test_deepseek_v3_base.py @@ -230,6 +230,11 @@ def run_deepseek_v3_base( config.quantization_config = None config = _maybe_scale_config(config, scale=scale, max_layers=max_layers) + # Seed the global RNG so config-random weight init is deterministic. Without + # this every run builds a different network, so the worst-element NPU-vs-CPU + # error randomly crosses the (loose) allclose threshold and the test is flaky. + torch.manual_seed(0) + if init_mode == "config-random": model = AutoModelForCausalLM.from_config( config=config, diff --git a/tests/ops/view/test_floormod_axis_split.py b/tests/ops/view/test_floormod_axis_split.py new file mode 100644 index 00000000..19e32e69 --- /dev/null +++ b/tests/ops/view/test_floormod_axis_split.py @@ -0,0 +1,118 @@ +"""Floor/mod index handling: axis-split (aligned) + graph-copy (incompatible). + +Covers the index-expression shapes that view/reshape/tile/group ops produce and +how the frontend handles them: + + - aligned floor/mod (single iter var, divisor divides extent): removed by + axis-split at the scheduling layer. group_norm, repeat, repeat_interleave, + permute+reshape (mixed-radix). + - incompatible radices on a shared axis (case 5, e.g. a[c//2] + b[c%3]): the + conflicting operand is realized by graph-copy so the consumer reads it affine + and the remainder is axis-split's. + - cross-axis / multi-variable floor/mod argument (case 7, e.g. (3*p0+p1)//4 from + a transpose+reshape feeding a broadcast/softmax/layernorm that keeps the dims + separate): graph-copy materializes the multi-var operand with copy_input (which + forces a copy of a view, unlike realize()); the copy kernel iterates the + operand's own shape so its index collapses to single-var for axis-split. + +Both features are always on; graph-copy installs its lowering hook at import. + +Not in the CI allowlist (pytorchsim_test.yml) -- local feature/regression test. +""" +import os +import sys + +import torch +import torch.nn.functional as F + +sys.path.insert(0, os.path.join(os.environ.get("TORCHSIM_DIR", default="/workspace/PyTorchSim"), "tests")) +from _pytorchsim_utils import test_result + +from PyTorchSimFrontend.mlir import graph_copy +graph_copy.install() + + +def _run(device, name, fn, *inputs): + torch.manual_seed(0) + opt = torch.compile(dynamic=False)(fn) + res = opt(*[t.to(device=device) for t in inputs]) + ref = fn(*[t.cpu() for t in inputs]) + test_result(name, res, ref, rtol=1e-3, atol=1e-3) + + +# --- aligned floor/mod: handled by axis-split --------------------------------- +def test_group_norm(device): + _run(device, "group_norm c//(C/G)", lambda x: F.group_norm(x, 3), torch.randn(2, 6, 4, 4)) + + +def test_repeat(device): + # tile -> ModularIndexing(c, 1, n) + _run(device, "repeat (mod)", lambda x: x.repeat(1, 2) + 1.0, torch.randn(4, 8)) + + +def test_repeat_interleave(device): + # -> FloorDiv(c, k) + _run(device, "repeat_interleave (floor)", + lambda x: torch.repeat_interleave(x, 2, dim=1) + 1.0, torch.randn(2, 4, 8)) + + +def test_permute_reshape(device): + # permute+reshape -> single-var mixed-radix floor/mod + _run(device, "permute+reshape (mixed-radix)", + lambda x: x.permute(0, 2, 1).reshape(2, 12) + 1.0, torch.randn(2, 3, 4)) + + +def test_three_level_mixed_radix(device): + # reshape+permute+reshape -> chain [1,4,12,24]; the 3-level split leaves a + # residual FloorDiv that simplify_with_ranges cannot fold -> _fold_with_ranges. + _run(device, "3-level mixed-radix", + lambda x: x.reshape(2, 3, 2, 4).permute(0, 2, 1, 3).reshape(2, 24) + 1.0, + torch.randn(2, 6, 4)) + + +def test_pixel_shuffle(device): + # splits two spatial axes -> 5D logical tile; the decompose-transfer pass peels + # the outer dims into an affine.for nest with the lane-banked physical SRAM offset. + _run(device, "pixel_shuffle (>4D peel)", + lambda x: F.pixel_shuffle(x, 2) + 1.0, torch.randn(1, 8, 4, 4)) + + +# --- incompatible radices (case 5): handled by graph-copy --------------------- +def test_incompatible_radix(device): + # a[c//2] + b[c%3] on axis c=6 : floor-by-2 vs mod-by-3 (not a chain) + _run(device, "incompat a[c//2]+b[c%3]", + lambda a, b: torch.repeat_interleave(a, 2, dim=1) + b.repeat(1, 2), + torch.randn(2, 3), torch.randn(2, 3)) + + +# --- cross-axis multi-var floor/mod (case 7): handled by graph-copy copy_input - +def test_case7_reshape_broadcast(device): + # (3*p0+p1)//4 from transpose+reshape feeding an elementwise broadcast consumer + _run(device, "case7 reshape+broadcast", + lambda x, y: x.t().reshape(8, 3) + y, torch.randn(4, 6), torch.randn(8, 1)) + + +def test_case7_softmax_reshape(device): + # same multi-var floor feeding a reduction (softmax over the kept-separate dim) + _run(device, "case7 softmax(reshape)", + lambda x: F.softmax(x.t().reshape(8, 3), dim=1), torch.randn(4, 6)) + + +def test_case7_layernorm_reshape(device): + _run(device, "case7 layernorm(reshape)", + lambda x: F.layer_norm(x.t().reshape(8, 3), (3,)), torch.randn(4, 6)) + + +if __name__ == "__main__": + device = torch.device("npu:0") + with torch.no_grad(): + test_group_norm(device) + test_repeat(device) + test_repeat_interleave(device) + test_permute_reshape(device) + test_three_level_mixed_radix(device) + test_pixel_shuffle(device) + test_incompatible_radix(device) + test_case7_reshape_broadcast(device) + test_case7_softmax_reshape(device) + test_case7_layernorm_reshape(device) diff --git a/tests/test_mlir_bindings.py b/tests/test_mlir_bindings.py new file mode 100644 index 00000000..a0e5055d --- /dev/null +++ b/tests/test_mlir_bindings.py @@ -0,0 +1,56 @@ +"""Exercise the MLIR Python bindings the way a decompose-transfer pass would: +parse a custom op, read its AffineMap attr, build an scf.for loop with +affine.apply + an inner (unregistered) DMA op, erase the original, re-verify. +""" +from mlir.ir import (Context, Module, Location, InsertionPoint, Operation, + IndexType, IntegerAttr, AffineMap) +from mlir.dialects import scf, affine, arith, func, memref + +ctx = Context() +ctx.allow_unregistered_dialects = True + +with ctx, Location.unknown(): + src = ''' + func.func @kernel(%dram: memref<256x256xf16>, %sram: memref<128x128xf16, 1>) { + "togsim.transfer"(%dram, %sram) { + dma_kind = "MVIN", + src_map = affine_map<(d0, d1) -> (d0, d1 floordiv 16, d1 mod 16)> + } : (memref<256x256xf16>, memref<128x128xf16, 1>) -> () + return + } + ''' + m = Module.parse(src) + print("[1] parsed module ok") + + fn = m.body.operations[0] + blk = fn.regions[0].blocks[0] + transfer = next(op.operation for op in blk.operations + if op.operation.name == "togsim.transfer") + print("[2] found op:", transfer.name) + + src_map = transfer.attributes["src_map"] + print("[3] src_map attr:", src_map) + + idx = IndexType.get() + def cst(v): + return Operation.create("arith.constant", results=[idx], + attributes={"value": IntegerAttr.get(idx, v)}).result + + with InsertionPoint(transfer): + lb, ub, step = cst(0), cst(2), cst(1) + loop = scf.ForOp(lb, ub, step) + with InsertionPoint(loop.body): + iv = loop.induction_variable + base = affine.AffineApplyOp(AffineMap.get_identity(1), [iv]) + Operation.create("togsim.dma_descriptor", + operands=[base.result], results=[]) + scf.YieldOp([]) + print("[4] built scf.for + affine.apply + inner op") + + transfer.erase() + print("[5] erased original transfer") + + print("[6] verify:", m.operation.verify()) + print("----- rewritten IR -----") + print(str(m)) +print("ALL GOOD") diff --git a/thirdparty/github-releases.json b/thirdparty/github-releases.json index 3e836f2e..b641fd9a 100644 --- a/thirdparty/github-releases.json +++ b/thirdparty/github-releases.json @@ -8,7 +8,7 @@ }, "llvm_project": { "repository": "PSAL-POSTECH/llvm-project", - "release_tag": "v1.0.8", + "release_tag": "v1.0.10", "asset_name": "riscv-llvm-release.tar.gz" }, "spike": {