From 2d1f4303f1acde9238b542e8fafba0b45ce8ea39 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 17:33:52 +0900 Subject: [PATCH 01/20] [Frontend] MLIR Python bindings: in-process passes and lowering Run the MLIR -> LLVM pipeline in-process via the bindings PassManager. Add Python out-of-line passes: lower_to_llvm, lower_dma_to_gemmini, lower_vlane_idx. Auto-resolve the bindings path from TORCHSIM_LLVM_PATH and ship the bindings in the LLVM artifact. Add op_coverage tooling and the bindings smoke test. Bump the LLVM pin and rebuild the thirdparty base image. --- Dockerfile.base | 5 + PyTorchSimFrontend/extension_codecache.py | 59 +- .../mlir/mlir_codegen_backend.py | 81 ++- PyTorchSimFrontend/mlir/mlir_ops.py | 14 +- PyTorchSimFrontend/mlir/passes/__init__.py | 71 +++ .../mlir/passes/lower_dma_to_gemmini.py | 227 ++++++++ .../mlir/passes/lower_to_llvm.py | 68 +++ .../mlir/passes/lower_vlane_idx.py | 92 +++ docs/linalg-codegen-migration.md | 224 ++++++++ docs/mlir-python-bindings.md | 102 ++++ scripts/build_from_source.sh | 13 +- scripts/op_coverage.py | 540 ++++++++++++++++++ tests/test_mlir_bindings.py | 56 ++ thirdparty/github-releases.json | 2 +- 14 files changed, 1506 insertions(+), 48 deletions(-) create mode 100644 PyTorchSimFrontend/mlir/passes/__init__.py create mode 100644 PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py create mode 100644 PyTorchSimFrontend/mlir/passes/lower_to_llvm.py create mode 100644 PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py create mode 100644 docs/linalg-codegen-migration.md create mode 100644 docs/mlir-python-bindings.md create mode 100644 scripts/op_coverage.py create mode 100644 tests/test_mlir_bindings.py diff --git a/Dockerfile.base b/Dockerfile.base index 19bbbb2e..de023566 100644 --- a/Dockerfile.base +++ b/Dockerfile.base @@ -91,6 +91,11 @@ RUN curl -L -H "Accept: application/octet-stream" https://api.github.com/repos/P # Store RISC-V LLVM for TorchSim ENV TORCHSIM_LLVM_PATH=/riscv-llvm/bin +# MLIR Python bindings shipped inside the LLVM release artifact (built by the +# llvm-project CI with -DMLIR_ENABLE_BINDINGS_PYTHON=ON). Lets PyTorchSim load +# mlir.ir / dialects for Python-side MLIR passes. The artifact must be built +# against this image's Python (3.11) or `import mlir` fails on ABI mismatch. +ENV PYTHONPATH=/riscv-llvm/python_packages/mlir_core:$PYTHONPATH ENV TORCHSIM_DIR=/workspace/PyTorchSim # Download Spike simulator diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index efd4d4cb..7bd43079 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -43,24 +43,9 @@ def mlir_compile_command(filename, vectorlane_size, vlen=256): {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ -test-loop-padding \ -dma-fine-grained='systolic-array-size={vectorlane_size}' \ - -global-idx='vlen={vlen}' \ -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen={vlen}' \ - -test-memref-to-gemmini="vectorlane={vectorlane_size}" \ - -convert-linalg-to-loops \ - -convert-vector-to-scf='full-unroll' \ - -lower-affine \ - -finalize-memref-to-llvm \ - -lower-vector-multi-reduction \ - -convert-vector-to-llvm \ - -convert-arith-to-llvm \ - -convert-math-to-llvm \ - -convert-scf-to-cf \ - -convert-cf-to-llvm \ - -convert-func-to-llvm \ - -convert-index-to-llvm \ - -reconcile-unrealized-casts \ {'--mlir-print-ir-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_MLIR_IR else ''} \ - {filename}.mlir -o {filename}_llvm.mlir + {filename}.mlir -o {filename}_custom.mlir """, ).strip(), re.sub(r"[ \n]+", " ", @@ -93,25 +78,9 @@ def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_si {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ -test-loop-padding='timing_mode=1' \ -dma-fine-grained='systolic-array-size={vectorlane_size}' \ - -global-idx='vlen={vlen}' \ -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen={vlen}' \ - -test-tile-operation-graph='vectorlane={vectorlane_size} sample-mode={extension_config.CONFIG_TLS_MODE}' \ - -test-memref-to-gemmini="vectorlane={vectorlane_size} timing=1" \ - -convert-linalg-to-loops \ - -convert-vector-to-scf='full-unroll' \ - -lower-affine \ - -finalize-memref-to-llvm \ - -lower-vector-multi-reduction \ - -convert-vector-to-llvm \ - -convert-arith-to-llvm \ - -convert-math-to-llvm \ - -convert-scf-to-cf \ - -convert-cf-to-llvm \ - -convert-func-to-llvm \ - -convert-index-to-llvm \ - -reconcile-unrealized-casts \ {'--mlir-print-ir-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_MLIR_IR else ''} \ - {filename}.mlir -o {sample_filename}_llvm.mlir + {filename}.mlir -o {sample_filename}_postvcix.mlir """, ).strip(), re.sub(r"[ \n]+", " ", @@ -158,6 +127,11 @@ def load(cls, source_code, vlenb = vlen // 8 write_path = get_write_path(source_code) key, input_path = write(source_code, "mlir", specified_dir=write_path) + # Run the Python out-of-line MLIR passes (MLIR bindings) on the kernel + # .mlir in place, before mlir-opt. Currently lowers torchsim.vlane_idx + # (replaces the old C++ -global-idx pass); add more in passes/__init__.py. + from PyTorchSimFrontend.mlir.passes import run_python_passes, run_standard_lowering, run_tog + run_python_passes(input_path) new_input_path = os.path.splitext(input_path)[0] raw_tog_path = new_input_path + "_tog.py" tog_path = os.path.join(write_path, "tile_graph.onnx") @@ -185,6 +159,10 @@ def load(cls, source_code, with lock: try: subprocess.check_call(opt_cmd) + # Standard MLIR -> LLVM-dialect lowering (registered upstream + # passes) runs in-process via the bindings PassManager, picking + # up after the custom mlir-opt passes (memref-to-gemmini). + run_standard_lowering(new_input_path + "_custom.mlir", new_input_path + "_llvm.mlir") subprocess.check_call(translate_cmd) subprocess.check_call(llc_cmd) subprocess.check_call(llc_asm_cmd) @@ -220,9 +198,18 @@ def load(cls, source_code, lock = FileLock(get_lock_path(write_path), timeout=LOCK_TIMEOUT) with lock: try: - result = subprocess.check_output(gem5_sample_cmd) - with open(raw_tog_path, "wb") as file: - file.write(result) + # mlir-opt now runs only loop-padding/dma-fine-grained/pytorchsim-to-vcix + # and writes the post-vcix IR. The tile-operation-graph pass is ported + # to Python: run_tog reads that IR, writes the TOG (_tog.py) and the + # mutated IR (_custom.mlir: sample-mode step rewrite + compute markers), + # replacing the C++ -test-tile-operation-graph pass. + subprocess.check_call(gem5_sample_cmd) + run_tog(sample_mlir_path + "_postvcix.mlir", raw_tog_path, + sample_mlir_path + "_custom.mlir", + sample_mode=extension_config.CONFIG_TLS_MODE, + vectorlane=vectorlane_size) + # Standard MLIR -> LLVM-dialect lowering in-process (see functional path). + run_standard_lowering(sample_mlir_path + "_custom.mlir", sample_mlir_path + "_llvm.mlir", timing=True) subprocess.check_call(gem5_translate_cmd) subprocess.check_call(gem5_llc_cmd) except subprocess.CalledProcessError as e: diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index b163ad1a..19ae3af5 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -322,6 +322,10 @@ def __init__(self, kernel_group, reason=None): self.spad_buffer_dict = dict() self.base_vector_initialized = False self.loop_size = None + # Set by get_dma_info when a DMA access cannot fit one <=4D Gemmini + # descriptor; load()/store() then emit a togsim.transfer for the + # decompose pass to peel into a loop of <=4D dma_start. + self._dma_needs_transfer = False def reset(self, reason): save = self.exit_stack, self._nested_context_depth @@ -537,9 +541,14 @@ def load(self, name: str, index: sympy.Expr): compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) # MVIN Encoding - attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, int(padding)) - code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) + if self._dma_needs_transfer: + self._dma_needs_transfer = False + code = self.emit_transfer("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, dram_stride, tile_stride, int(padding)) + else: + attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, int(padding)) + code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, attribute) self.cse.generate(dma_buffer, code, assignment = False) # FIXME: assignment = False does not support caching if not comptute_depedency: @@ -608,9 +617,14 @@ def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs) sram_index_var = self.spad_buffer_dict[str(value)][3] # Generate DMA instruction - attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, 0) - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) + if self._dma_needs_transfer: + self._dma_needs_transfer = False + code = self.emit_transfer("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, dram_stride, tile_stride, 0) + else: + attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, 0) + code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, attribute) self.dma_stores.writeline(common.DeferredLine(name, code)) def reduction(self, dtype, src_dtype, reduction_type, value): @@ -1243,7 +1257,13 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc.vmap.vlane_split_axis = local_vlane_split_axis local_tile_desc.vmap.vlane_stride = kg_tile_desc.vmap.vlane_stride else: - raise NotImplementedError("Currently not implemented... ;)") + # >4D access: one Gemmini DMA descriptor (<=4D) cannot represent this. + # Build the full N-D tile and flag it for togsim.transfer; the decompose + # pass peels the excess dims into a loop of <=4D memref.dma_start. + local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(dim) for dim in local_dims]) + local_tile_desc.vmap.vlane_split_axis = local_vlane_split_axis + local_tile_desc.vmap.vlane_stride = kg_tile_desc.vmap.vlane_stride + self._dma_needs_transfer = True if len(implicit_local_dims)!=0 and len(local_dims) != len(implicit_local_dims) and self.is_modular_indexing(index): for axis_constraints in self.kernel_group.tile_desc.implicit_dim_size.values(): @@ -1426,6 +1446,53 @@ def get_dma_code(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype return f"memref.dma_start {src_operand}, {dst_operand}, %{dma_type}, {tag_var}, {dma_attribute} : {src_shape}, {dst_shape}, {tag_shape} {attribute}" + def emit_transfer(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype, + dram_var, dram_index_var, sram_var, sram_index_var, + dram_shape, tile_shape, dram_stride, tile_stride, padding): + """Emit a generic togsim.transfer op for a DMA whose access exceeds the + 4D Gemmini descriptor limit. Carries the full N-D access (dram/tile + strides + shapes) plus the SSA operands a memref.dma_start needs + (dma_type / vlane_split_axis / vlane_stride), so the decompose pass + (passes/decompose_transfer.py) is purely mechanical: it peels the excess + dims into a loop of <=4D memref.dma_start, reusing these operands. + + The operand prep mirrors get_dma_code (dma_type enum via the read/write + cache+counter, vlane consts via CSE) so the transfer is self-contained; + togsim is an unregistered dialect -> generic form. + """ + dma_key = (vlane_split_axis, vlane_stride, mlir_dtype) + if dma_type_name == "MVIN" and dma_key in self.dma_read_cache: + dma_type, vsa, vst = self.dma_read_cache[dma_key] + elif dma_type_name == "MVOUT" and dma_key in self.dma_write_cache: + dma_type, vsa, vst = self.dma_write_cache[dma_key] + else: + vsa = self.get_const_cse(vlane_split_axis) + vst = self.get_const_cse(vlane_stride) + if dma_type_name == "MVIN": + dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_read_counter}"]) + self.dma_read_counter += 1 + self.dma_read_cache[dma_key] = [dma_type, vsa, vst] + else: + dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_write_counter}"]) + self.dma_write_cache[dma_key] = [dma_type, vsa, vst] + tag = self.get_tag_cse() + zero_cse = self.get_const_cse(0) + # vlane_split_axis is carried as a VALUE attr (not an SSA operand) because the + # decompose pass must remap it: collapsing unit tile dims renumbers the axes, + # so the descriptor's vlane axis index changes and the pass rebuilds the const. + attrs = ( + f'dma_kind = "{dma_type_name}", ' + f'vlane_split_axis = {int(vlane_split_axis)} : i64, ' + f'dram_stride = {dram_stride}, tile_stride = {tile_stride}, ' + f'padding = {int(padding)} : i64' + ) + # operands: dram, dram_idx, sram, sram_idx, tag, dma_type, vlane_stride + return ( + f'"togsim.transfer"(%{dram_var}, %{dram_index_var}, %{sram_var}, %{zero_cse}, ' + f'%{tag}, %{dma_type}, %{vst}) {{{attrs}}} : ' + f'({dram_shape}, index, {tile_shape}, index, memref<1xi32>, index, index) -> ()' + ) + def allocate_sram_buffer(self, dtype, dram_name, tile_desc, raw_index, buffer=None, forced_name=None): c_type = mlir_common.DTYPE_TO_C[dtype] mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] diff --git a/PyTorchSimFrontend/mlir/mlir_ops.py b/PyTorchSimFrontend/mlir/mlir_ops.py index 217129e8..f1fb4186 100644 --- a/PyTorchSimFrontend/mlir/mlir_ops.py +++ b/PyTorchSimFrontend/mlir/mlir_ops.py @@ -1135,11 +1135,19 @@ def extract_strided_slice(operand, target_size, offsets=None, sizes=None, stride @staticmethod def vlane_offset(operand1, operand2, *args, **kwargs): + # Emit a dedicated torchsim.vlane_idx op (generic form; torchsim is an + # unregistered dialect) instead of overloading arith.addi with a + # vlane_offset attribute. A Python out-of-line pass lowers it to + # (vcix.v.i per-lane index * offset); see + # PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py. tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - opcode = f'arith.add{ret_type[0]}' - op_str = f'{opcode} %{operand1}, %{operand2}' - return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] + offset = kwargs.get("attributes", {}).get("vlane_offset", 0) + op_str = '"torchsim.vlane_idx"()' + func_type = f'() -> {shape}' + return format_mlir_op(op_str, func_type, + attributes={"vlane_offset": f"{offset} : i64"}, + comment=kwargs.get("comment")), [tile_size, ret_type] @staticmethod def multi_reduction(acc, init, vec_size, red_size, red_shape, red_type, type_name, *args, **kwargs): diff --git a/PyTorchSimFrontend/mlir/passes/__init__.py b/PyTorchSimFrontend/mlir/passes/__init__.py new file mode 100644 index 00000000..a5b81e91 --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/__init__.py @@ -0,0 +1,71 @@ +"""Python out-of-line MLIR passes run on each kernel .mlir before mlir-opt. + +MLIR's PassManager only schedules *registered C++ passes*, not arbitrary Python +functions, so imperative Python rewrites are orchestrated here instead. The flow +is Module-centric: parse the .mlir once, run each registered pass on the shared +Module, print once. A text marker check skips parsing entirely when no pass's +target op is present (the common case). + +To add a pass, create a module exposing MARKERS (tuple of op-name strings) and +run(module) (mutates the Module in place), and append it to PASSES below. +""" +def _ensure_mlir_bindings_on_path(): + """Make `import mlir` work even when PYTHONPATH is not set, by deriving the + bindings location from TORCHSIM_LLVM_PATH (e.g. /riscv-llvm/bin -> + /riscv-llvm/python_packages/mlir_core). The container sets PYTHONPATH, but + plain local runs may not.""" + try: + import mlir.ir # noqa: F401 + return + except ModuleNotFoundError: + pass + import os + import sys + from PyTorchSimFrontend import extension_config + llvm_path = (extension_config.CONFIG_TORCHSIM_LLVM_PATH or "").rstrip("/") + cand = os.path.join(os.path.dirname(llvm_path), "python_packages", "mlir_core") + if os.path.isdir(cand) and cand not in sys.path: + sys.path.insert(0, cand) + + +_ensure_mlir_bindings_on_path() + +from . import lower_vlane_idx +from . import decompose_transfer +from .lower_to_llvm import run_standard_lowering # noqa: F401 (re-exported) +from .build_tog import run_tog # noqa: F401 (re-exported; replaces C++ test-tile-operation-graph) + +# Ordered passes applied to each kernel .mlir before mlir-opt. +# decompose_transfer first: it lowers togsim.transfer -> memref.dma_start, which +# downstream passes (and the gemmini lowering) expect. +PASSES = [ + decompose_transfer, + lower_vlane_idx, +] + + +def run_python_passes(mlir_path): + """Apply all registered Python MLIR passes to the .mlir at `mlir_path`, in place. + + Returns True if the file was modified, False otherwise. + """ + with open(mlir_path) as f: + text = f.read() + + # Fast path: nothing to do if no pass's target op appears in the text. + active = [p for p in PASSES if any(mk in text for mk in p.MARKERS)] + if not active: + return False + + from mlir.ir import Context, Module, Location + ctx = Context() + ctx.allow_unregistered_dialects = True + with ctx, Location.unknown(): + module = Module.parse(text) + for p in active: + p.run(module) + out = str(module) + + with open(mlir_path, "w") as f: + f.write(out) + return True diff --git a/PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py b/PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py new file mode 100644 index 00000000..f5b841bb --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py @@ -0,0 +1,227 @@ +"""Lower customized memref.dma_start ops to Gemmini RISC-V inline asm. + +Python port of the C++ test-memref-to-gemmini conversion. Each memref.dma_start +(carrying dram_stride / sram_stride / subtile_size attrs and vlane params encoded +in its stride / num_elements_per_stride / num_elements operands) becomes a +sequence of `llvm.inline_asm` ".insn r CUSTOM_1 ..." Gemmini instructions: +config_mvin/mvout, config2 (dram strides), config3 (spad strides), then the +mvin/mvout itself with the DRAM and scratchpad byte addresses. + +The conversion-framework coupling of the C++ pass (LLVMTypeConverter, +getStridedElementPtr, MemRefDescriptor) is avoided by working at the memref level: +addresses are computed with `memref.extract_aligned_pointer_as_index` + arith, +and the existing standard MLIR->LLVM lowering finalizes everything. Pass order: +this runs on memref-level IR (after test-pytorchsim-to-vcix), before +run_standard_lowering. + +NOTE: indirect-access (gather) dma_start is not yet handled (Phase 2); such ops +raise so they are caught rather than silently mishandled. +""" + +OP_NAME = "memref.dma_start" +WAIT_NAME = "memref.dma_wait" +MARKERS = (OP_NAME, WAIT_NAME) + +# func7 instruction codes (CustomDMAAttribute.h) +CONFIG, CONFIG2, CONFIG3, CONFIG4 = 0, 4, 5, 6 +MVIN, MVIN2, MVIN3, MVOUT = 2, 1, 14, 3 +CONFIG_TYPE = {MVIN: 0, MVIN2: 1, MVIN3: 2, MVOUT: 3} +MAX_TENSOR_DIM = 4 +CONSTRAINTS = "r,r,~{dirflag},~{fpsr},~{flags}" + + +def _asm(func7): + return f".insn r CUSTOM_1, 0x3, {func7}, x0, $0, $1" + + +def _i64_signed(v): + """Wrap an unsigned 64-bit packed value into signed int64 (matches C++ getI64IntegerAttr).""" + v &= 0xFFFFFFFFFFFFFFFF + return v - (1 << 64) if v >= (1 << 63) else v + + +def _row_major_strides(shape): + strides = [1] * len(shape) + for i in range(len(shape) - 2, -1, -1): + strides[i] = strides[i + 1] * shape[i + 1] + return strides + + +def run(module, timing=False): + """Lower memref.dma_start / dma_wait to Gemmini instructions. + + timing=False (functional/Spike): dma_start -> gemmini config + mvin/mvout asm. + timing=True (gem5 cycle path): dma_start is erased (the TOG already carries + DMA timing; the cycle binary needs no asm). + memref.dma_wait is erased in both modes (matches C++ DmaWaitOpLowering). + """ + from mlir.ir import (InsertionPoint, Operation, IntegerType, IndexType, + IntegerAttr, MemRefType) + from mlir.dialects import llvm, arith, memref + + i64 = IntegerType.get_signless(64) + idx = IndexType.get() + + def const_int(val): + return IntegerAttr(val.owner.attributes["value"]).value + + def i64_const(value): + return arith.ConstantOp(i64, IntegerAttr.get(i64, _i64_signed(value))).result + + def asm(func7, rs1, rs2): + llvm.InlineAsmOp(None, [rs1, rs2], _asm(func7), CONSTRAINTS, + has_side_effects=True, asm_dialect=0) + + def elem_addr_i64(memref_val, indices, mtype, elem_bytes): + """i64 byte address of memref_val[indices] (aligned ptr + linear elem offset).""" + base = memref.ExtractAlignedPointerAsIndexOp(memref_val).result # index = byte addr + strides = _row_major_strides(list(mtype.shape)) + off = None # element offset (index) + for k, ival in enumerate(indices): + if strides[k] == 0: + continue + term = ival + if strides[k] != 1: + term = arith.MulIOp(ival, arith.ConstantOp(idx, IntegerAttr.get(idx, strides[k])).result).result + off = term if off is None else arith.AddIOp(off, term).result + if off is not None: + byte = arith.MulIOp(off, arith.ConstantOp(idx, IntegerAttr.get(idx, elem_bytes)).result).result + base = arith.AddIOp(base, byte).result + return arith.IndexCastOp(i64, base).result + + starts, waits = [], [] + for region in module.operation.regions: + for b in region.blocks: + _collect(b, starts, waits) + + for op in waits: # dma_wait: erase in both modes + op.erase() + + for op in starts: + if timing: # gem5 cycle path: drop the dma_start (TOG has timing) + op.erase() + continue + operands = list(op.operands) + src, dst = operands[0], None + src_ty = MemRefType(src.type) + src_rank = len(src_ty.shape) + dst = operands[1 + src_rank] + dst_ty = MemRefType(dst.type) + dst_rank = len(dst_ty.shape) + src_idx = operands[1:1 + src_rank] + dst_idx = operands[1 + src_rank + 1:1 + src_rank + 1 + dst_rank] + + dma_type = const_int(operands[1 + src_rank + 1 + dst_rank]) # num_elements + vlane_split_axis = const_int(operands[-2]) # stride (always 2nd-to-last) + vlane_stride = const_int(operands[-1]) & 0x7FFF # num_elements_per_stride (last) + is_mvin = dma_type in (MVIN, MVIN2, MVIN3) + + elem_bytes = _elem_bytes(src_ty.element_type) + # Indirect (gather): the gather-side indices are src for mvin, dst for mvout. + gather_idx = src_idx if is_mvin else dst_idx + indirect, indirect_memref = _find_indirect(gather_idx) + + tile_shape = _subtile(op) + if tile_shape is None: + tile_shape = list(dst_ty.shape) if is_mvin else list(src_ty.shape) + dram_strides = _int_array(op, "dram_stride") + spad_strides = _int_array(op, "sram_stride") + assert len(tile_shape) == len(dram_strides) == len(spad_strides), \ + f"shape/stride rank mismatch: {tile_shape} {dram_strides} {spad_strides}" + + expand = MAX_TENSOR_DIM - len(tile_shape) + shape4 = [1] * expand + tile_shape + dram4 = [0] * expand + dram_strides + spad4 = [0] * expand + spad_strides + vlane_split_axis += expand + config_type = CONFIG_TYPE[dma_type] + + with InsertionPoint(op): + addrA = elem_addr_i64(src, src_idx, src_ty, elem_bytes) + addrB = elem_addr_i64(dst, dst_idx, dst_ty, elem_bytes) + dram_addr, spad_addr = (addrA, addrB) if is_mvin else (addrB, addrA) + + cfg_rs1 = i64_const(((shape4[0] & 0xFFFF) << 48) | ((shape4[1] & 0xFFFF) << 32) + | ((shape4[2] & 0xFFFF) << 16) | (shape4[3] & 0xFFFF)) + cfg_rs2 = i64_const((vlane_stride << 32) | ((config_type & 0x3) << 17) + | ((1 if indirect else 0) << 16) + | ((vlane_split_axis & 0x3) << 14) | elem_bytes) + asm(CONFIG, cfg_rs1, cfg_rs2) + asm(CONFIG2, i64_const((dram4[0] << 32) | (dram4[1] & 0xFFFFFFFF)), + i64_const((dram4[2] << 32) | (dram4[3] & 0xFFFFFFFF))) + asm(CONFIG3, i64_const((spad4[0] << 32) | (spad4[1] & 0xFFFFFFFF)), + i64_const((spad4[2] << 32) | (spad4[3] & 0xFFFFFFFF))) + if indirect: + # CONFIG4: rs1 = indirect index-spad base address, rs2 = (elem_size<<16)|stride(1) + ind_base = memref.ExtractAlignedPointerAsIndexOp(indirect_memref).result + ind_addr = arith.IndexCastOp(i64, ind_base).result + ind_esize = _elem_bytes(MemRefType(indirect_memref.type).element_type) + asm(CONFIG4, ind_addr, i64_const(((ind_esize & 0xFF) << 16) | (1 & 0xFFFF))) + asm(dma_type, dram_addr, spad_addr) + op.erase() + + +def _collect(block, starts, waits): + for op in list(block.operations): + name = op.operation.name + if name == OP_NAME: + starts.append(op.operation) + elif name == WAIT_NAME: + waits.append(op.operation) + for region in op.operation.regions: + for b in region.blocks: + _collect(b, starts, waits) + + +def _subtile(op): + from mlir.ir import ArrayAttr, IntegerAttr + if "subtile_size" not in op.attributes: + return None + return [IntegerAttr(a).value for a in ArrayAttr(op.attributes["subtile_size"])] + + +def _int_array(op, name): + from mlir.ir import ArrayAttr, IntegerAttr + return [IntegerAttr(a).value for a in ArrayAttr(op.attributes[name])] + + +def _elem_bytes(elem_type): + from mlir.ir import IntegerType, FloatType + bits = (IntegerType(elem_type).width if IntegerType.isinstance(elem_type) + else FloatType(elem_type).width) + return max(bits, 8) // 8 + + +def _find_indirect(indices): + """If a gather index is an affine.apply{indirect_access} whose operands include + index_cast(affine.load(%spad)), return (True, %spad memref); else (False, None).""" + for idx in indices: + ap = idx.owner + if getattr(ap, "name", None) != "affine.apply" or "indirect_access" not in ap.attributes: + continue + for operand in ap.operands: + ic = operand.owner + if getattr(ic, "name", None) != "arith.index_cast": + continue + ld = ic.operands[0].owner + if getattr(ld, "name", None) == "affine.load": + return True, ld.operands[0] # affine.load operand 0 == the index spad memref + return False, None + + +def lower_text(text): + if OP_NAME not in text: + return text + from mlir.ir import Context, Module, Location + ctx = Context() + ctx.allow_unregistered_dialects = True + with ctx, Location.unknown(): + m = Module.parse(text) + run(m) + return str(m) + + +if __name__ == "__main__": + import sys + out = lower_text(open(sys.argv[1]).read()) + (open(sys.argv[2], "w").write(out) if len(sys.argv) > 2 else sys.stdout.write(out)) diff --git a/PyTorchSimFrontend/mlir/passes/lower_to_llvm.py b/PyTorchSimFrontend/mlir/passes/lower_to_llvm.py new file mode 100644 index 00000000..ce6e081a --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/lower_to_llvm.py @@ -0,0 +1,68 @@ +"""Standard MLIR -> LLVM-dialect lowering via the bindings PassManager. + +Runs the upstream *registered* lowering passes (convert-*-to-llvm, lower-affine, +reconcile-unrealized-casts, ...) in-process on the post-custom-pass IR, replacing +the tail of the mlir-opt pipeline. The remaining custom passes (test-loop-padding, +dma-fine-grained, test-pytorchsim-to-vcix) still run in mlir-opt; the gem5 path's +test-tile-operation-graph is now the Python build_tog pass, and memref-to-gemmini +is the Python lower_dma_to_gemmini pass (run inside this lowering). As the custom +passes migrate to Python, mlir-opt shrinks toward an all-in-process flow. + +Validated to produce byte-identical LLVM IR to running the same passes inside +mlir-opt. Note: only lower-vector-multi-reduction is func.func-scoped (the +bindings pass-pipeline parser does not auto-nest like the mlir-opt CLI, so it is +wrapped explicitly); order is preserved to match the original pipeline. +""" + +STANDARD_PIPELINE = ( + "builtin.module(" + "convert-linalg-to-loops," + "convert-vector-to-scf{full-unroll=true}," + "lower-affine," + "expand-strided-metadata," # decompose memref.collapse_shape/subview before LLVM + "finalize-memref-to-llvm," + "func.func(lower-vector-multi-reduction)," + "convert-vector-to-llvm," + "convert-arith-to-llvm," + "convert-math-to-llvm," + "convert-scf-to-cf," + "convert-cf-to-llvm," + "convert-func-to-llvm," + "convert-index-to-llvm," + "reconcile-unrealized-casts)" +) + + +def run_standard_lowering(in_path, out_path=None, timing=False): + """Lower the post-custom-pass MLIR at `in_path` to the LLVM dialect. + + Runs the imperative Gemmini lowering (memref.dma_start/dma_wait) then the + registered standard MLIR->LLVM passes. `timing` selects the Gemmini behavior: + False for the functional/Spike path (emit gemmini asm), True for the gem5 + cycle path (erase dma_start; the TOG already carries DMA timing) -- this + preserves the old test-memref-to-gemmini `timing=1` semantics. + + Writes the result to `out_path` (defaults to `in_path`, i.e. in place). + Requires the MLIR Python bindings on PYTHONPATH. + """ + if out_path is None: + out_path = in_path + from mlir.ir import Context, Module, Location + from mlir.passmanager import PassManager + from . import lower_dma_to_gemmini + ctx = Context() + ctx.allow_unregistered_dialects = True + with ctx, Location.unknown(): + with open(in_path) as f: + module = Module.parse(f.read()) + # Imperative Python pass: memref.dma_start/dma_wait -> Gemmini asm (replaces + # the C++ test-memref-to-gemmini), then the registered standard lowering. + lower_dma_to_gemmini.run(module, timing=timing) + PassManager.parse(STANDARD_PIPELINE, ctx).run(module.operation) + with open(out_path, "w") as f: + f.write(str(module)) + + +if __name__ == "__main__": + import sys + run_standard_lowering(sys.argv[1], sys.argv[2] if len(sys.argv) > 2 else None) diff --git a/PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py b/PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py new file mode 100644 index 00000000..c9898f4b --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py @@ -0,0 +1,92 @@ +"""Python out-of-line MLIR pass: lower torchsim.vlane_idx -> per-lane index * offset. + +Codegen emits a dedicated `torchsim.vlane_idx` op (generic form, unregistered +dialect) carrying a `vlane_offset` integer attribute. This pass rewrites each +such op to: + + %v = "vcix.v.i"(%K) {opcode = 0, rs2 = 0, imm = 0} : (i64) -> vector // per-lane index + %n = arith.constant dense : vector + %r = arith.muli %v, %n : vector + +and replaces uses of the original op with %r. Replaces the former C++ +`-global-idx` pass (which overloaded arith.addi with a vlane_offset attribute). + +Pass interface (see passes/__init__.py): MARKERS + run(module). Also runnable +standalone as a CLI: + python PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py in.mlir [out.mlir] + +Requires the MLIR Python bindings on PYTHONPATH +(/riscv-llvm/python_packages/mlir_core). The `vcix` dialect must be registered +in the consuming mlir-opt for the result to round-trip (see +registerVCIXDialectTranslation in mlir-opt.cpp). +""" + +OP_NAME = "torchsim.vlane_idx" +MARKERS = (OP_NAME,) + + +def _iter_ops(block): + for op in list(block.operations): + yield op + for region in op.operation.regions: + for b in region.blocks: + yield from _iter_ops(b) + + +def run(module): + """Lower every torchsim.vlane_idx op in `module`, in place. + + Must be called with the module's Context active (the orchestrator provides it). + """ + from mlir.ir import (InsertionPoint, Operation, IntegerType, IntegerAttr, + DenseElementsAttr, VectorType) + i64 = IntegerType.get_signless(64) + i32 = IntegerType.get_signless(32) + + targets = [] + for region in module.operation.regions: + for b in region.blocks: + for op in _iter_ops(b): + if op.operation.name == OP_NAME: + targets.append(op.operation) + + for op in targets: + res = op.results[0] + vt = VectorType(res.type) + k, et = vt.shape[0], vt.element_type + offset = IntegerAttr(op.attributes["vlane_offset"]).value + with InsertionPoint(op): + rvl = Operation.create("arith.constant", results=[i64], + attributes={"value": IntegerAttr.get(i64, k)}).results[0] + lane = Operation.create("vcix.v.i", results=[vt], operands=[rvl], + attributes={"opcode": IntegerAttr.get(i64, 0), + "rs2": IntegerAttr.get(i32, 0), + "imm": IntegerAttr.get(i32, 0)}).results[0] + ovec = Operation.create("arith.constant", results=[vt], + attributes={"value": DenseElementsAttr.get_splat( + vt, IntegerAttr.get(et, offset))}).results[0] + mul = Operation.create("arith.muli", results=[vt], operands=[lane, ovec]).results[0] + res.replace_all_uses_with(mul) + op.erase() + + +def lower_text(text: str) -> str: + """Parse `text`, run this pass, return the printed module. CLI/testing helper.""" + if OP_NAME not in text: + return text + from mlir.ir import Context, Module, Location + ctx = Context() + ctx.allow_unregistered_dialects = True + with ctx, Location.unknown(): + m = Module.parse(text) + run(m) + return str(m) + + +if __name__ == "__main__": + import sys + out = lower_text(open(sys.argv[1]).read()) + if len(sys.argv) > 2: + open(sys.argv[2], "w").write(out) + else: + sys.stdout.write(out) diff --git a/docs/linalg-codegen-migration.md b/docs/linalg-codegen-migration.md new file mode 100644 index 00000000..9a60ba31 --- /dev/null +++ b/docs/linalg-codegen-migration.md @@ -0,0 +1,224 @@ +# Linalg-based codegen migration (Plan B) + +Status: **deferred / design only**. This is not scheduled work. It records *why* +a linalg-based rewrite of the MLIR codegen will eventually be worth doing, and a +rough plan, so the decision does not have to be re-derived from scratch later. + +For the near-term padding/dynamic-shape work, see Plan A (graph-level padding) in +the "Relationship to Plan A" section below. Plan A is the one to do first; it does +not depend on this document. + +## TL;DR + +The current MLIR codegen (`PyTorchSimFrontend/mlir/`) does not just emit loops — +it hand-implements the entire hardware mapping (tiling, vectorization, DMA, +scratchpad allocation, vector-lane distribution) as Python string emission. That +works, but it entangles three concerns that should be separable: + +1. **what to compute** (the op math), +2. **how to map it onto the NPU** (tile sizes, vlane layout, DMA/scratchpad), +3. **how to make shapes fit the hardware** (padding / divisibility). + +Concern 3 is currently spread across three places and is the source of the +"padding is heuristic and fragile" pain. Plan B factors concern 1 up to the +`linalg` dialect and rebuilds concern 2 as a set of MLIR lowering passes, so that +concern 3 falls out of the structured representation instead of being patched. + +Plan B is a multi-month, higher-risk effort because the hardware mapping (concern +2) is bespoke and has no upstream equivalent. Do it when the payoff (separation, +reuse of upstream tiling/vectorization/fusion, easier autotuning, lower codegen +maintenance) is worth that cost — not as a means to fix padding alone. + +## Where we are today + +Entry point: PyTorchSim is an **Inductor backend**. Inductor handles capture, +decomposition, lowering to Inductor IR, scheduling, and **fusion**. Our code runs +at the codegen (step-5) stage and turns scheduled `SchedulerNode`s into MLIR. + +`MLIRKernel` (`mlir/mlir_codegen_backend.py`) and friends emit, by hand: + +- explicit DMA: `memref.dma_start` with MVIN/MVOUT encoding, `vlane_split_axis`, + `vlane_stride` (`load`, `store`, `get_dma_code`); +- explicit scratchpad: `.spad` sections and `memref.global @bufN_spad` + (`allocate_sram_buffer`, `get_scratchpad_buffer`); +- explicit vector-lane (vlane) distribution: `vmap.vlane_split_axis` / + `vlane_stride`, `vlane_offset`, `get_used_vlane` (`_index_expr`, + `get_dma_info`); +- explicit tile descriptors: `MLIRMultiDimTile` with per-axis tile sizes; +- explicit reduction loops: manual accumulator/iterator vars + `affine_yield` + (`reduction`, `codegen_loops`); +- explicit vector-tail masking: `get_mask`. + +This is roughly 5,500 lines across `mlir_codegen_backend.py`, `mlir_common.py`, +`mlir_template.py`, and `mlir_ops.py`, plus per-op templates +(`mlir_gemm_template.py`, `mlir_conv_*`, `mlir_bmm_template.py`, +`mlir_sdpa_template.py`, `mlir_sort_template.py`, `mlir_cat_template.py`, +`mlir_maxpool_template.py`). Most of it encodes how this specific NPU's memory +hierarchy and vector unit work. **That accumulated hardware knowledge is the asset +and the cost center for any rewrite.** + +### Padding lives in three places + +This is the key observation motivating the separation. Divisibility / padding is +handled by: + +1. **Python recompile + tile adjustment** — `get_dma_info` FloorDiv / + ModularIndexing handling, `_index_expr`, `convert_indirect_indexing`: "if the + tile size does not divide the dim, bump the tile and raise `RecompileSignal` to + recompile." This is the actual heuristic-padding body, and it is in Python, not + MLIR. +2. **Python `get_mask`** — vector-tail masking for the innermost compute loop. +3. **MLIR `TestLoopPadding`** pass (in the `PSAL-POSTECH/llvm-project` fork) — + rounds affine loop bounds up to a multiple of the step and resizes buffers, + reverse-engineering the loop<->tensor mapping from affine maps. + +Removing only (3) does not fix the fragility; (1) is arguably the larger source of +"the loop and the tensor do not line up." Any real fix has to address all three. + +## Target architecture (Plan B) + +Split the codegen into two layers with a clean boundary: + +``` +L0 ATen / FX logical graph (dynamic dims as SymInt) <- Inductor, unchanged +L1 math layer: Inductor IR -> linalg.generic / named ops (untiled, unvectorized) + = iteration domain + affine indexing maps + scalar body. + Decides nothing about tiles / lanes / DMA. +L2 mapping layer: tiling(+pad) -> vectorize(+vl) -> bufferize + -> DMA / scratchpad / vlane lowering -> leaf replacement + = hardware mapping, parameterized by a target description. +L3 LLVM / RVV / systolic microkernel +``` + +Why this helps: + +- **The loop<->tensor mapping stops being reverse-engineered.** `linalg` ops carry + `indexing_maps` + `iterator_types`, i.e. exactly the information `TestLoopPadding` + tries to recover. Padding becomes a *parameter of the tiling transform* + (`tensor.pad` generated with full knowledge of the maps), not a separate + analysis pass. +- **The padding strategies we want become per-axis policy** in L2: systolic + operand axes -> pad-to-uniform-tile (keeps a single 128x128 microkernel and a + single gem5 latency entry); VPU / vector axes -> RVV `vl` (no padding); reduction + axes -> `affine.min` clamp. One mechanism, selected per axis, instead of three + scattered implementations. +- **Fusion policy stays in Inductor** (its scheduler decides what is one kernel), + while the *mechanism* is upstream `linalg` tile-and-fuse. Pointwise epilogue + fusion is essentially free because Inductor already composes `inner_fn`s into a + single fused body -> one `linalg.generic`. +- **Dynamic shapes** are carried as `?` dims + symbolic affine, uniformly handled + by L2 rather than by the recompile dance. +- **The cost model gets cleaner, not harder**: tile shape becomes an explicit + attribute, so the gem5 latency table / TOG key on it directly instead of on + inferred loop shapes. + +### What is reusable vs bespoke + +- **Reusable from upstream MLIR**: `linalg` ops, tiling + `tensor.pad`, + vectorization, bufferization, the TilingInterface. The L1 translation + (Inductor IR -> `linalg.generic`) is a *generic* translator (one path covers all + regular pointwise/reduction), not per-op work — Inductor IR is already in + structured iteration-domain + scalar-body form. +- **Bespoke, must be (re)written as MLIR passes**: MVIN/MVOUT DMA encoding, + `.spad` scratchpad assignment, and the vlane_split mapping. **These have no + upstream equivalent.** This is the bulk of the effort and the main risk: the + knowledge currently in ~5,500 lines of Python emission must be re-expressed as + custom bufferization-to-DMA / scratchpad / vlane-vectorization lowerings. + +## Expressibility boundary + +A regular (linalg.generic) op needs: a fixed rectangular iteration space; every +operand index an **affine** function of loop vars (no data-dependent indexing); +each axis purely `parallel` or a simple `reduction` (no scan/recurrence); a +statically-determined output shape (dynamic `?` ok, data-dependent shape not); and +a data-independent body (`arith.select` ok, data-dependent branching not). + +Expressible: elementwise, broadcast, transpose, reductions (incl. multiple +reduction axes), matmul / bmm / contractions, direct conv, pooling, and fused +chains of these (matmul+bias+activation, prologue cast/dequant/transpose, +pointwise->reduce). `slice` / `pad` / `cat` are structured `tensor` ops (not +`linalg.generic`) but are supported by the same pipeline. + +Not expressible -> stay as hand-written custom kernels: data-dependent indexing +(gather/scatter/embedding), sort/topk, data-dependent output shape +(nonzero/unique/masked_select), scan/recurrence (cumsum), and online/streaming +algorithms (flash-attention). In our op set this means **`sdpa` (online softmax) +and `sort` remain custom**; gemm/conv/bmm/maxpool are regular; `cat` is a +structured tensor op. + +## Migration strategy (when Plan B is scheduled) + +Incremental, op-by-op, with a numeric and a structural safety net. Do **not** +big-bang. + +1. **Stand up the L2 pipeline for one op (matmul first).** Emit `linalg.matmul` + from the matmul path; wire tiling(+pad, pad_value=0) -> bufferize -> a custom + pass that lowers the 128^3 leaf tile to the existing systolic intrinsic -> + LLVM. Milestone 1 is end-to-end correctness through all three simulators + (Spike functional, gem5 latency, TOGSim cycle) for a single matmul. +2. **Demote `TestLoopPadding` to assert-only** (check, do not modify; fail/log if a + loop bound is not a multiple of its step). Run the full test suite; anything it + flags is a case L1/L2 has not covered yet. +3. **Migrate the remaining regular ops** (conv, bmm, pointwise, reductions, + maxpool). Pointwise/reduction go through the generic L1 translator; VPU + remainder via `vl`. +4. **Delete `TestLoopPadding`** once the assert-only version is silent across the + suite, and retire the Python recompile/tile-adjust dance and most of `get_mask`. +5. Leave `sdpa` and `sort` as custom kernels that bypass L1/L2. + +### Risks + +- **Simulator-facing contract.** The current emission is tuned to produce exactly + the LLVM / TOG shape the three simulators expect. `linalg`'s standard lowering + emits different IR; re-validating the lowered artifact end-to-end (especially TOG + generation, which may assume specific loop/memory patterns) is the real + integration risk. This is why milestone 1 is "one matmul, end-to-end," not "all + ops, emission only." +- **Re-encoding the hardware mapping.** DMA/scratchpad/vlane lowerings are new code + with no upstream reference; budget for them dominating the schedule. +- **Inductor index expressions.** Inductor often collapses dims into one flat index + with `FloorDiv` / `ModularIndexing`, which are not affine; `linalg` indexing_maps + must be affine. Either keep dims uncollapsed or normalize div/mod back to + multi-dim affine. (We already convert these to affine strings today in + `_convert_sympy_to_mlir_expr`, but that path will need to be revisited for the + map-carrying representation.) +- **Fusion seams.** Not everything fuses cleanly (reductions with mismatched axes, + transpose/layout mismatches); expect some barriers, same as any framework. + +## Relationship to Plan A (graph-level padding) — do this first + +Plan A inserts padding at the FX/graph level (via Inductor's +`post_grad_custom_pass`) so that tiled dims arrive at codegen already aligned +(`tile_granule * symbol`). Under the hard constraint that we **keep the Inductor +spine**, Plan A is the high-ROI move: + +- It collapses all three padding sites at once: the recompile/tile-adjust dance (1) + becomes unnecessary (tiles always divide), `get_mask` (2) becomes trivial (no + tail), and `TestLoopPadding` (3) becomes unnecessary. +- It does **not** touch the bespoke DMA/scratchpad/vlane mapper. + +Key correctness facts that make Plan A tractable: + +- Matmul contraction (K) padding with zeros is *exact* (additive identity); weights + are constants, so they can be zero-padded once, offline, at no runtime cost. +- Padding only ever corrupts results when a *non-contraction* padded axis is later + reduced (softmax over keys; layernorm if hidden is padded). Those points need + masking; everything else is pad-transparent. +- Safety rule: default any op to slice-back-to-real-shape; only opt an op into + "propagate padded shape" once it is proven pad-transparent or given a mask + handler. Correct-by-construction; unknown ops cannot silently corrupt. + +Plan A and Plan B are compatible: Plan A's graph-level alignment makes the eventual +Plan B simpler (L2 tiling rarely needs to pad, because dims already divide). + +## Open questions + +- Does the current toolchain (the `PSAL-POSTECH/llvm-project` fork) already ship the + `linalg` + transform/tiling passes, or were they stripped? (Almost certainly + present if it tracks upstream — verify before committing.) +- Can the systolic leaf be expressed cleanly as a match-and-replace on a fixed-size + `linalg.matmul`, or does weight-stationary loading order force a more custom + representation? +- How much of `mlir_ops.py` (the scalar `OpsHandler`) survives? It currently emits + *vectorized* ops (compute_vec_size, broadcast) and is therefore entangled with + vlane; the linalg body should be scalar, with vectorization done in L2. diff --git a/docs/mlir-python-bindings.md b/docs/mlir-python-bindings.md new file mode 100644 index 00000000..6bb03339 --- /dev/null +++ b/docs/mlir-python-bindings.md @@ -0,0 +1,102 @@ +# Enabling MLIR Python bindings + +Goal: ship the MLIR Python bindings (`import mlir`, `mlir.ir`, `mlir.dialects`) +so we can write MLIR passes in Python (imperative IR rewriting via the bindings) +instead of only C++ passes in the `PSAL-POSTECH/llvm-project` fork. See +`dma-transfer-lowering.md` for the first intended use (a Python decompose pass). + +## How LLVM reaches the runtime (why this touches 3 places) + +``` +PSAL-POSTECH/llvm-project (fork, tag vX.Y.Z) + .github/workflows/build-torchsim.yaml -- CI builds + releases riscv-llvm-release.tar.gz + | (release asset) + v +thirdparty/github-releases.json -- pins llvm_project.release_tag + asset + | + v +Dockerfile.base -- downloads asset, extracts to /riscv-llvm, + sets TORCHSIM_LLVM_PATH (+ now PYTHONPATH) +``` + +`scripts/build_from_source.sh` is the alternative source-build path (not the +normal flow, but kept consistent). + +## The one real blocker: Python ABI must match + +The bindings are a native CPython extension (`_mlir.cpython-3XX-*.so`). They only +import under the **same Python minor version** they were built against. The +runtime base image uses **conda Python 3.11**. So the artifact must be built with +**Python 3.11**. Building with the build container's default (ubuntu-22.04 -> +3.10) produces bindings that fail to import at runtime with a confusing error +much later -- hence the fail-fast guard in the CI step. + +Patch version (3.11.x) does not matter; minor version (3.11 vs 3.10) does. + +## What was changed + +- **`scripts/build_from_source.sh`**: cmake gets + `-DMLIR_ENABLE_BINDINGS_PYTHON=ON -DPython3_EXECUTABLE=$(command -v python3)`; + build deps (nanobind/pybind11/numpy/PyYAML) pip-installed; after `make install` + the build-tree `tools/mlir/python_packages` is copied into `/riscv-llvm` + (install does not place it there). PYTHONPATH exported for the current shell. +- **`Dockerfile.base`**: `ENV PYTHONPATH=/riscv-llvm/python_packages/mlir_core:$PYTHONPATH` + after the LLVM artifact is extracted. +- **`llvm-project/.github/workflows/build-torchsim.yaml`** (fork): same cmake + flags + deps; copies `python_packages` into the `riscv-llvm` tree so the + existing `tar` includes it; fail-fast guard requiring `python3.11`. + +## Rollout sequence (must be done in order) + +1. **python3.11 in the build container: done, non-root.** The CI step keeps the + original `-u $(id -u):$(id -g)` (no root assumed) and fetches a standalone + CPython 3.11 with `uv` (`uv venv --python 3.11`), then points + `Python3_EXECUTABLE` at that venv. No apt / no root needed. ubuntu-22.04's + default 3.10 is not used for the bindings. + - ABI note: extensions built against a uv/python-build-standalone CPython 3.11 + are expected to import under the runtime conda CPython 3.11 (same minor + version, standard builds are C-ABI compatible). The verify step below is the + check; if it ever fails, build instead in the runtime image (`python:3.11` or + the pytorch base) so build Python == runtime Python by construction. +2. **Push the fork changes** to `PSAL-POSTECH/llvm-project` and cut a new tag + (e.g. `v1.0.9`). CI builds `riscv-llvm-release.tar.gz` now containing + `python_packages/`. +3. **Bump `thirdparty/github-releases.json`** -> `llvm_project.release_tag` to the + new tag (and `asset_name` unchanged). This triggers a new base image build. +4. **Rebuild the base image** (the fork CI already dispatches `build_base`; or run + the PyTorchSim docker-image workflow) so `Dockerfile.base` produces an image + with the bindings + PYTHONPATH. + +## Verify + +Inside the rebuilt container (or after `build_from_source.sh`): + +```bash +python -c "import mlir; print(mlir.__file__)" # -> /riscv-llvm/python_packages/mlir_core/mlir/__init__.py +python -c "from mlir.ir import Context; c=Context(); c.allow_unregistered_dialects=True; print('ok')" +python -c "from mlir.dialects import scf, affine, arith; print('dialects ok')" +``` + +`allow_unregistered_dialects=True` is what lets us read/write the custom ops +(`togsim.transfer`, the customized `memref.dma_start`) generically without +registering a dialect in the bindings. + +## Notes / gotchas + +- Keep the bindings statically linked (default, i.e. do NOT add + `-DBUILD_SHARED_LIBS=ON` / `-DLLVM_BUILD_LLVM_DYLIB=ON`); otherwise the `.so` + needs libMLIR/libLLVM at runtime and the artifact + LD_LIBRARY_PATH grow. +- Worktrees: add the same `PYTHONPATH` line to the worktree `.envrc` (see + `docs/worktrees.md`) if a worktree overrides paths. +- The bindings are an additive, optional dependency: text emission + C++ passes + keep working unchanged. Only new Python passes require the bindings present. +- This LLVM fork's MLIR bindings use **pybind11** (not nanobind) and require + **pybind11 <= 2.10.3**: newer pybind11 (3.x) fails to compile `IRCore.cpp` with + `def_property family does not currently support keep_alive`. Pin it + (`pybind11>=2.9.0,<=2.10.3`). See `mlir/python/requirements.txt` for the fork's + pins. pybind11 is build-time only; the runtime needs just the built `.so` + numpy. +- numpy: the fork's requirements pin `<=1.26`, but a local build against numpy 2.x + compiled and imported fine, so we keep numpy at the runtime version (2.x) to + avoid a numpy-1-built / numpy-2-runtime ABI mismatch. (Validated locally: + conda 3.11 + pybind11 2.10.3 + numpy 2.x -> `import mlir` and parsing a custom + `togsim.transfer` op with floordiv/mod affine maps both work.) diff --git a/scripts/build_from_source.sh b/scripts/build_from_source.sh index 4e7ff604..f23eab82 100644 --- a/scripts/build_from_source.sh +++ b/scripts/build_from_source.sh @@ -45,12 +45,23 @@ export GEM5_PATH="$home/gem5/build/RISCV/gem5.opt" cd "$home" # LLVM + MLIR (RISCV target) +# MLIR Python bindings are enabled so Python-side MLIR passes can run. The +# bindings are a native extension: they MUST be built against the same Python +# that runs PyTorchSim at runtime (the conda 3.11 here) or `import mlir` will +# fail with an ABI mismatch. nanobind/pybind11/numpy/PyYAML are build-time deps. +python3 -m pip install --user "pybind11>=2.9.0,<=2.10.3" numpy PyYAML git clone --depth 1 --branch "$LLVM_TAG" "https://github.com/${LLVM_REPO}.git" cd llvm-project && mkdir -p build && cd build && \ cmake -DLLVM_ENABLE_PROJECTS=mlir -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/riscv-llvm -DLLVM_TARGETS_TO_BUILD=RISCV \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE="$(command -v python3)" \ -G "Unix Makefiles" ../llvm && \ - make -j && make install + make -j && make install && \ + rm -rf /riscv-llvm/python_packages && \ + cp -r tools/mlir/python_packages /riscv-llvm/python_packages +# Make the bindings importable in this shell (also set in .envrc / Dockerfile.base) +export PYTHONPATH="/riscv-llvm/python_packages/mlir_core:$PYTHONPATH" cd "$home" # Spike Simulator diff --git a/scripts/op_coverage.py b/scripts/op_coverage.py new file mode 100644 index 00000000..1f4567b6 --- /dev/null +++ b/scripts/op_coverage.py @@ -0,0 +1,540 @@ +"""Op-coverage diagnostic for new LLM models on PyTorchSim. + +Runs each model in two phases: + Phase 1 (enumerate): custom torch.compile backend captures the FX graph and + lists every aten op that appears, without touching NPU. + Phase 2 (run): torch.compile(model) on npu:0, real forward. On crash, + parses the traceback to identify the failing op. + +Usage: + python scripts/op_coverage.py # all models + python scripts/op_coverage.py --models qwen2 # subset + python scripts/op_coverage.py --enumerate-only # skip NPU compile (fast) +""" + +import argparse +import datetime as _dt +import os +import re +import sys +import traceback +from contextlib import contextmanager + +import torch + +REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) + + +# --------------------------------------------------------------------------- +# Model registry: each entry returns (model, kwargs_for_forward) on CPU. +# Sizes follow "small but realistic" variants (1-layer) so a forward is cheap +# enough to actually drive through TOGSim. +# --------------------------------------------------------------------------- + +def _causal_mask(batch, seq_len, dtype): + min_v = torch.finfo(dtype).min + m = torch.full((seq_len, seq_len), min_v, dtype=dtype) + if seq_len > 1: + m = torch.triu(m, diagonal=1) + return m[None, None, :, :].expand(batch, 1, -1, -1).contiguous() + + +def build_qwen2(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.qwen2.configuration_qwen2 import Qwen2Config + from transformers.models.qwen2.modeling_qwen2 import Qwen2Model + cfg = Qwen2Config( + vocab_size=4096, + hidden_size=1536, + num_attention_heads=12, + num_key_value_heads=2, + intermediate_size=8960, + num_hidden_layers=2, + max_position_embeddings=4096, + rms_norm_eps=1e-6, + rope_theta=1000000.0, + torch_dtype=dtype, + use_cache=False, + _attn_implementation="eager", + ) + model = Qwen2Model(cfg).eval().to(dtype=dtype) + input_ids = torch.randint(0, cfg.vocab_size, (batch, seq_len)) + attn_mask = _causal_mask(batch, seq_len, dtype) + return model, {"input_ids": input_ids, "attention_mask": attn_mask} + + +def build_gemma(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.gemma.configuration_gemma import GemmaConfig + from transformers.models.gemma.modeling_gemma import GemmaModel + cfg = GemmaConfig( + vocab_size=4096, + hidden_size=2048, + num_attention_heads=8, + num_key_value_heads=1, + intermediate_size=16384, + num_hidden_layers=2, + head_dim=256, + max_position_embeddings=4096, + rms_norm_eps=1e-6, + rope_theta=10000.0, + torch_dtype=dtype, + use_cache=False, + _attn_implementation="eager", + ) + model = GemmaModel(cfg).eval().to(dtype=dtype) + input_ids = torch.randint(0, cfg.vocab_size, (batch, seq_len)) + attn_mask = _causal_mask(batch, seq_len, dtype) + return model, {"input_ids": input_ids, "attention_mask": attn_mask} + + +def build_gemma2(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.gemma2.configuration_gemma2 import Gemma2Config + from transformers.models.gemma2.modeling_gemma2 import Gemma2Model + cfg = Gemma2Config( + vocab_size=4096, + hidden_size=2304, + num_attention_heads=8, + num_key_value_heads=4, + intermediate_size=9216, + num_hidden_layers=2, + head_dim=256, + max_position_embeddings=4096, + rms_norm_eps=1e-6, + rope_theta=10000.0, + torch_dtype=dtype, + use_cache=False, + attn_logit_softcapping=50.0, + final_logit_softcapping=30.0, + sliding_window=16, + _attn_implementation="eager", + ) + model = Gemma2Model(cfg).eval().to(dtype=dtype) + input_ids = torch.randint(0, cfg.vocab_size, (batch, seq_len)) + attn_mask = _causal_mask(batch, seq_len, dtype) + return model, {"input_ids": input_ids, "attention_mask": attn_mask} + + +def build_phi3(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.phi3.configuration_phi3 import Phi3Config + from transformers.models.phi3.modeling_phi3 import Phi3Model + cfg = Phi3Config( + vocab_size=4096, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + hidden_size=3072, + num_attention_heads=32, + num_key_value_heads=32, + intermediate_size=8192, + num_hidden_layers=2, + max_position_embeddings=4096, + rms_norm_eps=1e-5, + rope_theta=10000.0, + torch_dtype=dtype, + use_cache=False, + _attn_implementation="eager", + ) + model = Phi3Model(cfg).eval().to(dtype=dtype) + input_ids = torch.randint(0, cfg.vocab_size, (batch, seq_len)) + attn_mask = _causal_mask(batch, seq_len, dtype) + return model, {"input_ids": input_ids, "attention_mask": attn_mask} + + +def _build_lm(cfg, ModelCls, batch, seq_len, dtype): + """Shared helper: build a causal-LM-style model and matching token+mask inputs.""" + model = ModelCls(cfg).eval().to(dtype=dtype) + input_ids = torch.randint(0, cfg.vocab_size, (batch, seq_len)) + attn_mask = _causal_mask(batch, seq_len, dtype) + return model, {"input_ids": input_ids, "attention_mask": attn_mask} + + +def build_qwen3(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.qwen3.configuration_qwen3 import Qwen3Config + from transformers.models.qwen3.modeling_qwen3 import Qwen3Model + cfg = Qwen3Config( + vocab_size=4096, hidden_size=1024, num_attention_heads=8, num_key_value_heads=4, + intermediate_size=3072, num_hidden_layers=2, max_position_embeddings=4096, + rms_norm_eps=1e-6, rope_theta=1000000.0, torch_dtype=dtype, use_cache=False, + _attn_implementation="eager", + ) + return _build_lm(cfg, Qwen3Model, batch, seq_len, dtype) + + +def build_qwen3_moe(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig + from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel + cfg = Qwen3MoeConfig( + vocab_size=4096, hidden_size=1024, num_attention_heads=8, num_key_value_heads=4, + intermediate_size=3072, moe_intermediate_size=768, num_experts=4, num_experts_per_tok=2, + decoder_sparse_step=1, num_hidden_layers=2, max_position_embeddings=4096, + rms_norm_eps=1e-6, rope_theta=1000000.0, torch_dtype=dtype, use_cache=False, + _attn_implementation="eager", + ) + return _build_lm(cfg, Qwen3MoeModel, batch, seq_len, dtype) + + +def build_gemma3(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig + from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel + cfg = Gemma3TextConfig( + vocab_size=4096, hidden_size=2048, num_attention_heads=8, num_key_value_heads=4, + intermediate_size=8192, head_dim=256, num_hidden_layers=2, + sliding_window=16, sliding_window_pattern=2, + max_position_embeddings=4096, rms_norm_eps=1e-6, rope_theta=10000.0, + torch_dtype=dtype, use_cache=False, _attn_implementation="eager", + ) + return _build_lm(cfg, Gemma3TextModel, batch, seq_len, dtype) + + +def build_deepseek_v3(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config + from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3Model + cfg = DeepseekV3Config( + vocab_size=4096, hidden_size=1024, num_attention_heads=16, num_key_value_heads=16, + intermediate_size=4096, moe_intermediate_size=512, + n_routed_experts=8, num_experts_per_tok=2, n_shared_experts=1, + n_group=2, topk_group=1, + q_lora_rank=512, kv_lora_rank=128, qk_rope_head_dim=32, qk_nope_head_dim=32, v_head_dim=64, + num_hidden_layers=2, first_k_dense_replace=1, + max_position_embeddings=4096, rms_norm_eps=1e-6, rope_theta=10000.0, + torch_dtype=dtype, use_cache=False, _attn_implementation="eager", + ) + return _build_lm(cfg, DeepseekV3Model, batch, seq_len, dtype) + + +def build_llama4(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.llama4.configuration_llama4 import Llama4TextConfig + from transformers.models.llama4.modeling_llama4 import Llama4TextModel + cfg = Llama4TextConfig( + vocab_size=4096, hidden_size=1024, num_attention_heads=8, num_key_value_heads=4, + intermediate_size=3072, intermediate_size_mlp=3072, + num_local_experts=4, num_experts_per_tok=1, num_hidden_layers=2, interleave_moe_layer_step=2, + max_position_embeddings=4096, rms_norm_eps=1e-6, rope_theta=10000.0, + torch_dtype=dtype, use_cache=False, _attn_implementation="eager", + ) + return _build_lm(cfg, Llama4TextModel, batch, seq_len, dtype) + + +def build_glm4(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.glm4.configuration_glm4 import Glm4Config + from transformers.models.glm4.modeling_glm4 import Glm4Model + cfg = Glm4Config( + vocab_size=4096, pad_token_id=0, bos_token_id=1, eos_token_id=2, + hidden_size=1536, num_attention_heads=12, num_key_value_heads=2, + intermediate_size=4096, num_hidden_layers=2, max_position_embeddings=4096, + rms_norm_eps=1e-5, rope_theta=10000.0, torch_dtype=dtype, use_cache=False, + _attn_implementation="eager", + ) + return _build_lm(cfg, Glm4Model, batch, seq_len, dtype) + + +def build_olmo2(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.olmo2.configuration_olmo2 import Olmo2Config + from transformers.models.olmo2.modeling_olmo2 import Olmo2Model + cfg = Olmo2Config( + vocab_size=4096, hidden_size=2048, num_attention_heads=16, num_key_value_heads=16, + intermediate_size=8192, num_hidden_layers=2, max_position_embeddings=4096, + rms_norm_eps=1e-6, rope_theta=10000.0, torch_dtype=dtype, use_cache=False, + _attn_implementation="eager", + ) + return _build_lm(cfg, Olmo2Model, batch, seq_len, dtype) + + +def build_granite(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.granite.configuration_granite import GraniteConfig + from transformers.models.granite.modeling_granite import GraniteModel + cfg = GraniteConfig( + vocab_size=4096, hidden_size=2048, num_attention_heads=16, num_key_value_heads=8, + intermediate_size=8192, num_hidden_layers=2, max_position_embeddings=4096, + rms_norm_eps=1e-5, rope_theta=10000.0, torch_dtype=dtype, use_cache=False, + _attn_implementation="eager", + ) + return _build_lm(cfg, GraniteModel, batch, seq_len, dtype) + + +def build_phimoe(batch=1, seq_len=32, dtype=torch.float32): + from transformers.models.phimoe.configuration_phimoe import PhimoeConfig + from transformers.models.phimoe.modeling_phimoe import PhimoeModel + cfg = PhimoeConfig( + vocab_size=4096, hidden_size=1024, num_attention_heads=8, num_key_value_heads=4, + intermediate_size=3072, num_local_experts=4, num_experts_per_tok=2, + num_hidden_layers=2, max_position_embeddings=4096, + rms_norm_eps=1e-5, rope_theta=10000.0, torch_dtype=dtype, use_cache=False, + _attn_implementation="eager", + ) + return _build_lm(cfg, PhimoeModel, batch, seq_len, dtype) + + +def build_mamba2(batch=1, seq_len=32, dtype=torch.float32): + # State-space model: no attention, no RoPE -- completely different op profile. + # Invariant: num_heads * head_dim == intermediate_size == expand * hidden_size + # (modeling_mamba2.py:171 + the view(B, num_heads*head_dim) at line 365). + from transformers.models.mamba2.configuration_mamba2 import Mamba2Config + from transformers.models.mamba2.modeling_mamba2 import Mamba2Model + cfg = Mamba2Config( + vocab_size=4096, hidden_size=512, + num_heads=16, head_dim=64, + state_size=16, chunk_size=16, + expand=2, n_groups=1, + num_hidden_layers=2, torch_dtype=dtype, use_cache=False, + ) + model = Mamba2Model(cfg).eval().to(dtype=dtype) + input_ids = torch.randint(0, cfg.vocab_size, (batch, seq_len)) + # Mamba has no attention mask; pass none. + return model, {"input_ids": input_ids} + + +def build_mllama(batch=1, seq_len=32, dtype=torch.float32): + # Llama 3.2 Vision -- text branch only (text-only call path). + # MllamaRotaryEmbedding requires config.rope_scaling["rope_type"]; pass default. + from transformers.models.mllama.configuration_mllama import MllamaTextConfig + from transformers.models.mllama.modeling_mllama import MllamaTextModel + cfg = MllamaTextConfig( + vocab_size=4096, pad_token_id=0, bos_token_id=1, eos_token_id=2, + hidden_size=1024, num_attention_heads=8, num_key_value_heads=4, + intermediate_size=3072, num_hidden_layers=2, + cross_attention_layers=[], + max_position_embeddings=4096, rms_norm_eps=1e-5, rope_theta=10000.0, + rope_scaling={"rope_type": "default"}, + torch_dtype=dtype, use_cache=False, _attn_implementation="eager", + ) + return _build_lm(cfg, MllamaTextModel, batch, seq_len, dtype) + + +BUILDERS = { + "qwen2": build_qwen2, + "gemma": build_gemma, + "gemma2": build_gemma2, + "phi3": build_phi3, + # Models newly available with transformers 4.51.3 + "qwen3": build_qwen3, + "qwen3_moe": build_qwen3_moe, + "gemma3": build_gemma3, + "deepseek_v3": build_deepseek_v3, + "llama4": build_llama4, + "glm4": build_glm4, + "olmo2": build_olmo2, + "granite": build_granite, + "phimoe": build_phimoe, + "mamba2": build_mamba2, + "mllama": build_mllama, +} + + +# --------------------------------------------------------------------------- +# Phase 1: enumerate aten ops by intercepting the FX graph from torch.compile. +# --------------------------------------------------------------------------- + +def _node_op_name(target): + # OpOverload / OpOverloadPacket: has a .name() method returning "aten::mm.default" etc. + if hasattr(target, "name") and callable(target.name): + try: + return target.name() + except Exception: + pass + if hasattr(target, "_schema"): + try: + return str(target._schema.name) + ( + "." + target._schema.overload_name if target._schema.overload_name else "" + ) + except Exception: + pass + # torch.* python builtins: use their __module__/__qualname__ + mod = getattr(target, "__module__", "") + qn = getattr(target, "__qualname__", None) or getattr(target, "__name__", "") + if mod and qn: + return f"{mod}.{qn}" + return str(target) + + +@torch.no_grad() +def enumerate_ops(model, inputs): + """Capture the post-AOTAutograd aten graph(s) via aot_module_simplified. + + This is the same level of IR TOGSim/Inductor consumes, so the op set + matches what the NPU backend actually has to lower. + """ + from functorch.compile import aot_module_simplified + + seen = set() + graph_sizes = [] + + def fw_compiler(gm, example_inputs): + graph_sizes.append(sum(1 for _ in gm.graph.nodes)) + for node in gm.graph.nodes: + if node.op == "call_function": + seen.add(_node_op_name(node.target)) + return gm.forward + + def dynamo_backend(gm, example_inputs): + return aot_module_simplified(gm, example_inputs, fw_compiler=fw_compiler) + + torch._dynamo.reset() + compiled = torch.compile(model, backend=dynamo_backend, dynamic=False) + compiled(**inputs) + return sorted(seen), graph_sizes + + +# --------------------------------------------------------------------------- +# Phase 2: real NPU compile + run. Capture and parse failure tracebacks. +# --------------------------------------------------------------------------- + +ATEN_RE = re.compile(r"aten[.:][a-zA-Z_][a-zA-Z0-9_.]*") +NOTIMPL_RE = re.compile(r"NotImplementedError[: ]+(.*)") + + +def parse_failure(tb_text): + aten_hits = [] + for m in ATEN_RE.finditer(tb_text): + op = m.group(0).replace("aten:", "aten.").lstrip(".") + if op not in aten_hits: + aten_hits.append(op) + msg = "" + nm = NOTIMPL_RE.search(tb_text) + if nm: + msg = nm.group(1).strip().splitlines()[0] + return aten_hits, msg + + +@torch.no_grad() +def run_on_npu(model, inputs): + device = torch.device("npu:0") + model = model.to(device) + inputs = {k: v.to(device) for k, v in inputs.items()} + torch._dynamo.reset() + compiled = torch.compile(model, dynamic=False) + out = compiled(**inputs) + # touch the output to force completion + if hasattr(out, "last_hidden_state"): + out.last_hidden_state.cpu() + elif isinstance(out, torch.Tensor): + out.cpu() + return "OK", None, None + + +# --------------------------------------------------------------------------- +# Driver +# --------------------------------------------------------------------------- + +def run_model(name, args, out_dir): + builder = BUILDERS[name] + log_path = os.path.join(out_dir, f"{name}.log") + with open(log_path, "w") as fh: + def w(s=""): + print(s) + fh.write(s + "\n") + + w(f"=== {name} ===") + w(f"batch={args.batch} seq_len={args.seq_len} dtype={args.dtype}") + + try: + model, inputs = builder(args.batch, args.seq_len, _DTYPE_MAP[args.dtype]) + except Exception as e: + w(f"[BUILD FAIL] {type(e).__name__}: {e}") + return {"name": name, "status": "BUILD_FAIL", "ops": [], "fail_op": str(e)} + + # Phase 1 + w("\n[Phase 1] FX op enumeration (eager backend, no NPU)") + try: + ops, graph_sizes = enumerate_ops(model, inputs) + w(f" graphs: {len(graph_sizes)} total_nodes_per_graph: {graph_sizes}") + w(f" unique aten ops: {len(ops)}") + for op in ops: + w(f" {op}") + except Exception: + tb = traceback.format_exc() + w("[Phase 1 FAIL]\n" + tb) + ops = [] + + if args.enumerate_only: + return {"name": name, "status": "ENUM_ONLY", "ops": ops, "fail_op": None} + + # Phase 2 + w("\n[Phase 2] torch.compile on npu:0 + forward") + try: + status, fail_op, msg = run_on_npu(model, inputs) + w(f" status: {status}") + return {"name": name, "status": status, "ops": ops, "fail_op": None} + except Exception: + tb = traceback.format_exc() + hits, msg = parse_failure(tb) + w(" status: FAIL") + if msg: + w(f" NotImplemented message: {msg}") + if hits: + w(f" aten ops in traceback (first = most likely culprit):") + for h in hits[:10]: + w(f" {h}") + w("\n----- traceback -----\n" + tb) + return { + "name": name, + "status": "FAIL", + "ops": ops, + "fail_op": hits[0] if hits else "?", + "msg": msg, + } + + +_DTYPE_MAP = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--models", nargs="+", default=list(BUILDERS.keys()), + choices=list(BUILDERS.keys())) + p.add_argument("--batch", type=int, default=1) + p.add_argument("--seq-len", type=int, default=32) + p.add_argument("--dtype", default="float32", choices=list(_DTYPE_MAP.keys())) + p.add_argument("--enumerate-only", action="store_true", + help="Skip NPU compile; just list aten ops per model (fast).") + p.add_argument("--out-dir", default=None) + args = p.parse_args() + + ts = _dt.datetime.now().strftime("%Y%m%d_%H%M%S") + out_dir = args.out_dir or os.path.join( + os.environ.get("TORCHSIM_LOG_PATH", os.path.join(REPO_ROOT, "togsim_results")), + "op_coverage", ts, + ) + os.makedirs(out_dir, exist_ok=True) + print(f"Output dir: {out_dir}") + + results = [] + for name in args.models: + try: + results.append(run_model(name, args, out_dir)) + except KeyboardInterrupt: + print(f"[interrupt] aborted during {name}") + break + except Exception: + traceback.print_exc() + results.append({"name": name, "status": "DRIVER_ERR", "ops": [], "fail_op": None}) + + # Summary + summary_path = os.path.join(out_dir, "summary.txt") + with open(summary_path, "w") as fh: + def w(s=""): + print(s) + fh.write(s + "\n") + w("\n========== SUMMARY ==========") + w(f"{'model':10s} {'ops':>5s} {'status':10s} first_fail") + for r in results: + w(f"{r['name']:10s} {len(r['ops']):>5d} {r['status']:10s} {r.get('fail_op') or '-'}") + # Union & overlap across models + all_ops = set() + for r in results: + all_ops.update(r["ops"]) + w(f"\nUnion of aten ops across all models: {len(all_ops)}") + w("Per-model op set diff (ops unique to this model):") + for r in results: + others = set().union(*(set(r2["ops"]) for r2 in results if r2 is not r)) + unique = sorted(set(r["ops"]) - others) + w(f" {r['name']}: {len(unique)} unique") + for op in unique: + w(f" {op}") + + print(f"\nWrote: {summary_path}") + + +if __name__ == "__main__": + main() diff --git a/tests/test_mlir_bindings.py b/tests/test_mlir_bindings.py new file mode 100644 index 00000000..a0e5055d --- /dev/null +++ b/tests/test_mlir_bindings.py @@ -0,0 +1,56 @@ +"""Exercise the MLIR Python bindings the way a decompose-transfer pass would: +parse a custom op, read its AffineMap attr, build an scf.for loop with +affine.apply + an inner (unregistered) DMA op, erase the original, re-verify. +""" +from mlir.ir import (Context, Module, Location, InsertionPoint, Operation, + IndexType, IntegerAttr, AffineMap) +from mlir.dialects import scf, affine, arith, func, memref + +ctx = Context() +ctx.allow_unregistered_dialects = True + +with ctx, Location.unknown(): + src = ''' + func.func @kernel(%dram: memref<256x256xf16>, %sram: memref<128x128xf16, 1>) { + "togsim.transfer"(%dram, %sram) { + dma_kind = "MVIN", + src_map = affine_map<(d0, d1) -> (d0, d1 floordiv 16, d1 mod 16)> + } : (memref<256x256xf16>, memref<128x128xf16, 1>) -> () + return + } + ''' + m = Module.parse(src) + print("[1] parsed module ok") + + fn = m.body.operations[0] + blk = fn.regions[0].blocks[0] + transfer = next(op.operation for op in blk.operations + if op.operation.name == "togsim.transfer") + print("[2] found op:", transfer.name) + + src_map = transfer.attributes["src_map"] + print("[3] src_map attr:", src_map) + + idx = IndexType.get() + def cst(v): + return Operation.create("arith.constant", results=[idx], + attributes={"value": IntegerAttr.get(idx, v)}).result + + with InsertionPoint(transfer): + lb, ub, step = cst(0), cst(2), cst(1) + loop = scf.ForOp(lb, ub, step) + with InsertionPoint(loop.body): + iv = loop.induction_variable + base = affine.AffineApplyOp(AffineMap.get_identity(1), [iv]) + Operation.create("togsim.dma_descriptor", + operands=[base.result], results=[]) + scf.YieldOp([]) + print("[4] built scf.for + affine.apply + inner op") + + transfer.erase() + print("[5] erased original transfer") + + print("[6] verify:", m.operation.verify()) + print("----- rewritten IR -----") + print(str(m)) +print("ALL GOOD") diff --git a/thirdparty/github-releases.json b/thirdparty/github-releases.json index 3e836f2e..b641fd9a 100644 --- a/thirdparty/github-releases.json +++ b/thirdparty/github-releases.json @@ -8,7 +8,7 @@ }, "llvm_project": { "repository": "PSAL-POSTECH/llvm-project", - "release_tag": "v1.0.8", + "release_tag": "v1.0.10", "asset_name": "riscv-llvm-release.tar.gz" }, "spike": { From 3da3a7cf38de73cc7c54b3d60cd3718e3d4bce8d Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 17:33:52 +0900 Subject: [PATCH 02/20] [Frontend] decompose togsim.transfer to <=4D dma_start Emit togsim.transfer for >4D DMA and decompose it to a <=4D customized memref.dma_start: unit-collapse fast path, unrolled-subview peel for >4 effective dims. Fix #258 by emitting affine.apply (not arith.addi) for the peeled DRAM offset so the TOG pass can walk the loop index through it. --- .../mlir/passes/decompose_transfer.py | 216 ++++++++ docs/dma-transfer-lowering.md | 478 ++++++++++++++++++ 2 files changed, 694 insertions(+) create mode 100644 PyTorchSimFrontend/mlir/passes/decompose_transfer.py create mode 100644 docs/dma-transfer-lowering.md diff --git a/PyTorchSimFrontend/mlir/passes/decompose_transfer.py b/PyTorchSimFrontend/mlir/passes/decompose_transfer.py new file mode 100644 index 00000000..87b8aadf --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/decompose_transfer.py @@ -0,0 +1,216 @@ +"""Python out-of-line MLIR pass: decompose togsim.transfer -> <=4D memref.dma_start. + +A togsim.transfer carries a per-axis affine DMA whose descriptor rank may exceed +the 4D Gemmini limit. This pass is a **pure mechanical rank peel** of that +already-affine access (see docs/dma-transfer-lowering.md, "aligned-only peel"): + + - drop unit (extent-1) tile dims: they contribute no descriptor axis; + - if the remaining (effective) rank <= 4 -> emit one customized + memref.dma_start, reusing the transfer's operands (fast path); + - if effective rank > 4 -> peel the outer dims into a loop, adjusting the + base index by stride*iv per iteration, inner descriptor <=4D. + +It does NO floor/mod linearization (aligned split happens upstream at the +scheduling layer) and NO relayout (misaligned access is copy-inserted at the +graph level). A transfer whose access is not per-axis affine is a contract +violation -- but by construction codegen only emits affine transfers. + +togsim.transfer operands (see emit_transfer): + (dram, dram_idx, sram, sram_idx, tag, dma_type, vlane_split_axis, vlane_stride) +attrs: dma_kind ("MVIN"/"MVOUT"), dram_stride[], tile_stride[], padding. + +memref.dma_start (customized) operands: + src[idx], dst[idx], dma_type, tag[idx], vlane_split_axis, vlane_stride + : src_memref, dst_memref, memref<1xi32> {dram_stride, sram_stride, padding} + +Pass interface (passes/__init__.py): MARKERS + run(module). +""" + +OP_NAME = "togsim.transfer" +MARKERS = (OP_NAME,) + + +def _iter_ops(block): + for op in list(block.operations): + yield op + for region in op.operation.regions: + for b in region.blocks: + yield from _iter_ops(b) + + +def _int_array(attr): + from mlir.ir import ArrayAttr, IntegerAttr + return [IntegerAttr(a).value for a in ArrayAttr(attr)] + + +def _squeeze_reassociation(shape): + """Group source dims so each group's product is one effective (non-unit) dim; + unit dims attach to a neighbor. Returns (groups, target_shape).""" + groups, cur = [], [] + for i, e in enumerate(shape): + cur.append(i) + if e > 1: + groups.append(cur) + cur = [] + if cur: # trailing unit dims + if groups: + groups[-1] += cur + else: + groups.append(cur) # all-ones -> single dim of size 1 + import math + target = [math.prod(shape[d] for d in g) for g in groups] + return groups, target + + +def run(module): + """Lower every togsim.transfer in `module`, in place. Context must be active.""" + import itertools + from mlir.ir import (InsertionPoint, Operation, MemRefType, ArrayAttr, + IntegerAttr, IntegerType, IndexType, DenseI64ArrayAttr, + DenseI32ArrayAttr, StridedLayoutAttr, AffineMap, AffineMapAttr, + AffineExpr) + i64 = IntegerType.get_signless(64) + idx_ty = IndexType.get() + + targets = [] + for region in module.operation.regions: + for b in region.blocks: + for op in _iter_ops(b): + if op.operation.name == OP_NAME: + targets.append(op.operation) + + for op in targets: + dram, dram_idx, sram, sram_idx, tag, dma_type, vst = op.operands + kind = op.attributes["dma_kind"].value # StringAttr -> "MVIN"/"MVOUT" + vlane_axis = IntegerAttr(op.attributes["vlane_split_axis"]).value + dram_stride = _int_array(op.attributes["dram_stride"]) + tile_stride = _int_array(op.attributes["tile_stride"]) + padding = op.attributes["padding"] + + sram_ty = MemRefType(sram.type) + elem, space = sram_ty.element_type, sram_ty.memory_space + tile_shape = list(sram_ty.shape) + # effective (non-unit) dims carry the descriptor; unit dims drop out. + eff = [i for i, e in enumerate(tile_shape) if e > 1] + + def _const(v): + return Operation.create( + "arith.constant", results=[idx_ty], + attributes={"value": IntegerAttr.get(idx_ty, v)}).results[0] + + def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr): + vsa = _const(vsa_val) + if kind == "MVIN": + operands = [dram, dram_idx_val, sram_mem, *sram_indices, + dma_type, tag, sram_idx, vsa, vst] + else: + operands = [sram_mem, *sram_indices, dram, dram_idx_val, + dma_type, tag, sram_idx, vsa, vst] + Operation.create( + "memref.dma_start", results=[], operands=operands, + attributes={"dram_stride": dr_attr, "sram_stride": tl_attr, + "padding": padding}) + + if len(eff) <= 4: + # Fast path: drop unit dims so the descriptor reaches <=4D. The customized + # dma_start convention requires SRAM rank == #indices == len(sram_stride), + # so collapse the unit tile dims away. DRAM stays flat rank-1 (its N-D + # structure is in dram_stride). + groups, target = _squeeze_reassociation(tile_shape) + reassoc = ArrayAttr.get( + [ArrayAttr.get([IntegerAttr.get(i64, d) for d in g]) for g in groups]) + collapsed_ty = MemRefType.get(target, elem, memory_space=space) + keep = [g[-1] for g in groups] # the non-unit dim in each group + dr_attr = ArrayAttr.get([IntegerAttr.get(i64, dram_stride[i]) for i in keep]) + tl_attr = ArrayAttr.get([IntegerAttr.get(i64, tile_stride[i]) for i in keep]) + # Remap vlane axis to the collapsed-dim index (the group containing it). + new_vlane = next(gi for gi, g in enumerate(groups) if vlane_axis in g) + with InsertionPoint(op): + sram_c = Operation.create( + "memref.collapse_shape", results=[collapsed_ty], operands=[sram], + attributes={"reassociation": reassoc}).results[0] + _emit(sram_c, [sram_idx] * len(target), dram_idx, new_vlane, + dr_attr, tl_attr) + op.erase() + continue + + # Peel path: >4 effective dims. Keep the inner 4 as the <=4D descriptor and + # peel the outer (len-4) effective dims into a fully-unrolled set of slices + # (one descriptor per outer index combo; base advances by stride*idx). The + # SRAM slice is a rank-reduced memref.subview at the slice offset; the DRAM + # base advances by a *constant* per slice. + # + # The constant DRAM offset must be folded into an affine.apply over the + # original dram_idx (NOT arith.addi): the TOG pass reads loop_idx_list by + # walking the DRAM index via processDramIndices, which understands + # affine.apply / block-arg / constant but NOT arith.addi -- an addi yields an + # empty loop_idx_list and the kernel fails ONNX serialization (#258). The + # peeled dim itself is a fixed constant in each unrolled slice (this DMA does + # not iterate it), so it correctly contributes no loop var; the surviving + # loop vars come from the original dram_idx affine.apply, into which + # processDramIndices recurses. + peeled, inner = eff[:-4], eff[-4:] + ndim = len(tile_shape) + inner_shape = [tile_shape[d] for d in inner] + inner_strides = [tile_stride[d] for d in inner] + dr_attr = ArrayAttr.get([IntegerAttr.get(i64, dram_stride[d]) for d in inner]) + tl_attr = ArrayAttr.get([IntegerAttr.get(i64, tile_stride[d]) for d in inner]) + # the vlane axis must survive into the inner descriptor (it is the lane dim). + new_vlane = inner.index(vlane_axis) if vlane_axis in inner else 0 + for combo in itertools.product(*[range(tile_shape[d]) for d in peeled]): + static_offsets = [0] * ndim + static_sizes = [1] * ndim + for k, d in enumerate(peeled): + static_offsets[d] = combo[k] + for d in inner: + static_sizes[d] = tile_shape[d] + sram_off = sum(combo[k] * tile_stride[peeled[k]] for k in range(len(peeled))) + dram_off = sum(combo[k] * dram_stride[peeled[k]] for k in range(len(peeled))) + res_ty = MemRefType.get( + inner_shape, elem, + layout=StridedLayoutAttr.get(sram_off, inner_strides), memory_space=space) + with InsertionPoint(op): + sub = Operation.create( + "memref.subview", results=[res_ty], operands=[sram], + attributes={"static_offsets": DenseI64ArrayAttr.get(static_offsets), + "static_sizes": DenseI64ArrayAttr.get(static_sizes), + "static_strides": DenseI64ArrayAttr.get([1] * ndim), + # operandSegmentSizes is an i32 property: [source, offsets, + # sizes, strides] dynamic-operand counts. All static here -> + # only the source operand. Must be i32, not i64 (i64 silently + # zeroes to [0,0,0,0] and fails verification). + "operandSegmentSizes": DenseI32ArrayAttr.get([1, 0, 0, 0])} + ).results[0] + if dram_off == 0: + dram_idx_val = dram_idx + else: + # affine.apply (d0) -> (d0 + dram_off) so TOG's processDramIndices + # recurses through it into the original dram_idx's loop vars. + amap = AffineMap.get(1, 0, [AffineExpr.get_dim(0) + dram_off]) + dram_idx_val = Operation.create( + "affine.apply", results=[idx_ty], operands=[dram_idx], + attributes={"map": AffineMapAttr.get(amap)}).results[0] + _emit(sub, [sram_idx] * 4, dram_idx_val, new_vlane, dr_attr, tl_attr) + op.erase() + + +def lower_text(text: str) -> str: + """Parse `text`, run this pass, return the printed module. CLI/testing helper.""" + if OP_NAME not in text: + return text + from mlir.ir import Context, Module, Location + ctx = Context() + ctx.allow_unregistered_dialects = True + with ctx, Location.unknown(): + m = Module.parse(text) + run(m) + return str(m) + + +if __name__ == "__main__": + import sys + out = lower_text(open(sys.argv[1]).read()) + if len(sys.argv) > 2: + open(sys.argv[2], "w").write(out) + else: + sys.stdout.write(out) diff --git a/docs/dma-transfer-lowering.md b/docs/dma-transfer-lowering.md new file mode 100644 index 00000000..cbf875c0 --- /dev/null +++ b/docs/dma-transfer-lowering.md @@ -0,0 +1,478 @@ +# DMA transfer op + decomposition lowering + +Status: **design / proposed**. Captures the plan to fix the recompile-dance +fragility by representing DMA as a high-level declarative transfer op and +decomposing it into affine descriptors in a lowering pass. + +Companion docs: `linalg-codegen-migration.md` (Plan B, the full structured-ops +rewrite this is a narrow tactical slice of). The near-term graph-level padding +work is referred to here as Plan A. + +## TL;DR + +The MLIR codegen forces tile sizes so that non-affine index expressions +(`FloorDiv` / `ModularIndexing`, produced by view/reshape/cat) collapse into the +DMA's strictly-affine 4D integer-stride address model. That forcing is a lazy, +greedy, monotonic, restart-based search (the "recompile dance") capped at 5 +retries; when operand constraints conflict or exceed 4D it hard-fails and the +model does not compile. This blocks model coverage, which is the primary goal. + +Proposed fix: stop forcing one affine descriptor. Introduce a **high-level +`togsim.transfer` op** that carries an iteration domain plus `iter->src` / +`iter->dst` affine maps (which may legally contain floordiv/mod), and a +**decomposition pass** that lowers it to a loop of the **existing customized +`memref.dma_start` descriptors** (kept unchanged as the leaf). The non-affine / +high-rank part is peeled into a base-pointer loop instead of being crammed into +one descriptor. This inverts the tile<->DMA dependency (the DMA adapts to the +tile, not the reverse), removes the rank cap, and removes the recompile dance. + +## Problem + +### Root cause: an impedance mismatch + +The DMA address model is `base + sum_i stride_i * idx_i`, with **integer strides, +4D**, i.e. strictly affine/linear. Inductor's index expressions are not: views, +reshapes, `cat`, and broadcasts introduce `FloorDiv` / `ModularIndexing`, which +are non-affine. The codegen copes by searching for a tiling under which the +floor/mod collapses to a linear stride within a tile -- that is exactly what the +ModularIndexing tile constraints ("tile must be a multiple of the floordiv +divisor and a divisor of the modular divisor") encode. + +### The recompile dance (where it breaks) + +`codegen_nodes` (`mlir_common.py`) is a `while True` loop, `max_retry_compile = 5`. +During emission, `get_dma_info` (`mlir_codegen_backend.py`) inspects the index: + +- the **split path** (good): `apply_divisor(axis, divisor, "split")` peels an axis + into two affine dims to represent floor/mod, inserting a `0` into `dram_stride`; +- the **pad path** (fragile): when the tile is not divisible it mutates the tile + (`set_tile_size`, `tile_constraint.fixed = True`) and raises `RecompileSignal` + to restart emission. + +It breaks because: + +1. **One global tile must satisfy every operand's divisibility** on a shared axis. + Fused ops with conflicting constraints (common with reshape/modular indexing) + cannot be satisfied at once -- this is the loop<->tensor mismatch. +2. **Greedy + monotonic + no backtracking + 5-retry cap.** `tile_constraint.fixed` + persists across retries (the tile descriptor lives on `kernel_group`, survives + `reset`), so the search ratchets one way; conflicting fixes oscillate and hit + the cap -> `RuntimeError("Failed to compile kernel after multiple attempts")`. +3. **4D rank cap.** A reshape needing more than 4 affine dims after splitting + raises `NotImplementedError`. +4. **vlane / LMUL entanglement.** Pad-forcing moves `vlane_split_axis` / relaxes + `vlane_stride`, and `compute_vec` must be a power of two; these can be mutually + unsatisfiable with the divisibility constraints. + +Padding logic is currently spread across three places: the Python recompile/tile +-adjust dance (1), the Python `get_mask` vector-tail handling (2), and the MLIR +`TestLoopPadding` pass (3). Removing (3) alone does not fix the fragility; (1) is +the larger source. + +### We already have a de-facto custom DMA op + +`get_dma_code` emits `memref.dma_start` overloaded via string formatting with +extra operands (`dma_type` MVIN/MVOUT, tag, `vlane_split_axis`, `vlane_stride`) +and extra attributes (`dram_stride`, `tile_stride`, padding type). This is a +custom descriptor in all but name -- and it is what Spike / gem5 / TOGSim already +consume. + +## Proposed design + +Two op levels, with a pass bridging them. + +``` +[high] togsim.transfer iteration domain + iter->src / iter->dst affine maps + (maps MAY contain floordiv/mod; rank unbounded) + | decompose-transfer pass (cost-aware peel) + v +[low] scf.for { customized memref.dma_start } <- existing leaf, UNCHANGED +``` + +### Low-level descriptor (keep as-is) + +The existing customized `memref.dma_start` is the lowering target / leaf: affine, +4D, integer stride, simulator-understood. **Do not add maps or floor/mod to it** -- +that would re-create the representational limit and blur the boundary. Optionally +formalize it into a real op (`togsim.dma_descriptor`) with a verifier so the pass +rewrites real ops instead of strings; not required to start. + +### High-level transfer op (new) + +Strawman: + +```mlir +togsim.transfer + ins(%src : memref) // DRAM + outs(%dst : memref<...xf16, 1>) // scratchpad + iter_bounds = [%M, %N, %K] // iteration domain (dynamic via SSA operands) + attributes { + src_map = affine_map<(m,n,k)[s0] -> (m, (n floordiv s0), (n mod s0), k)>, // non-affine lives here + dst_map = affine_map<(m,n,k) -> (m, n, k)>, + vlane_split_axis = 1, vlane_stride = 4, + dma_kind = "MVIN", tag_policy = "async", + peel_plan = [0] // optional: which iter dims to peel (decided in Python; see below) + } +``` + +Design choices: + +- **Iteration domain + two maps, not a single src->dst map.** A direct src->dst + relation only works for bijections; broadcast (`cat([a, a])`) and non-bijective + access need the loop-mediated form. This is the `linalg.generic` model, and peel + becomes "tile the iteration domain." +- **floor/mod ride in the `AffineMap`.** MLIR `AffineMap` supports constant-divisor + `floordiv`/`mod`/`ceildiv` natively; the codegen already produces these as + strings in `convert_index`. Symbolic divisors are semi-affine -- representable, + handled by our pass. +- **`memref`, not raw pointers.** The memref carries base + shape + layout so the + pass can reason about strides; `src_ptr`/`dst_ptr` are inside it. Buffer shape is + the memref type; the transfer region is `iter_bounds` + maps. +- **Closest existing op is `linalg.generic` / `linalg.copy`** (same shape + maps + + body structure) but it lacks DMA/vlane/tag/scratchpad semantics and its tiling + lowers to subview+scf, not our descriptors -- so a custom op modeled on linalg's + design, reusing AffineMap utilities. + +### Decomposition pass (contract): aligned-only mechanical peel + +> **Scope decision (narrowed).** This pass is a **pure mechanical rank peel** of an +> already-affine access. It does **not** linearize floor/mod and does **not** do +> relayout. Those two responsibilities moved upstream (see "Division of labor" +> below): aligned floor/mod is removed by **axis splitting at the Inductor +> scheduling layer** (`axis-split-scheduling.md`), and misaligned access is +> resolved by **graph-level copy insertion**. So every `togsim.transfer` that +> reaches this pass is guaranteed per-axis affine; the only thing left is that its +> rank may exceed the 4D Gemmini descriptor. + +The DMA descriptor is an **affine map of rank <= 4 with integer strides** +(`base + sum_i stride_i * idx_i`). The pass sees affine input (rank `D`) and: + +1. **`D <= 4`** -> emit **one** customized `memref.dma_start`; the dims become the + descriptor's <=4D shape/strides. Identical to today's output (fast path). +2. **`D > 4`** -> peel `D - 4` dims into an outer `affine.for`; each iteration + computes a base with `affine.apply` (the peeled dims' linear contribution) and + issues the inner <=4D affine descriptor. SRAM offsets are computed symmetrically + in the same loop. + +That is the whole pass. There is **no linearization step** (upstream guarantees +affine) and **no relayout fallback** (upstream graph copy handles misalignment). + +**Fail loud, not silent.** If the pass encounters floor/mod that does not reduce to +per-axis affine (misaligned), or a genuinely non-affine / indirect / gather index, +that is a **contract violation** -- upstream did not normalize it. The pass +**asserts/errors** rather than silently inserting a relayout. A silent in-pass copy +would be a hidden performance cliff and would duplicate, at the wrong layer, a +global layout decision only the graph can make correctly. + +The decision point maps onto existing code: `get_dma_info` already raises at >4D. +That exact site becomes "emit `togsim.transfer`" (done, Phase 1), and this pass +consumes it. The recompile/tile-forcing dance is unnecessary because (a) aligned +floor/mod is gone before codegen and (b) the outer peel loop's `ceil` bound absorbs +non-divisible remainders. + +### Division of labor (the affine-only contract) + +| floor/mod source | handled by | cost | layer | +|---|---|---|---| +| aligned (single axis, divisor \| extent; group norm, broadcast) | axis split | free | Inductor scheduling | +| misaligned (uneven cat, non-factor reshape, multi-axis arg) | copy insertion | copy | FX graph | +| affine but rank > 4 (e.g. 5D permute) | mechanical peel | free | **this pass** | +| data-dependent / indirect / gather | indirect-indexing path | -- | out of scope | + +Only the third row is this pass. The first two produce the affine-only invariant +this pass relies on. + +### Relationship to memref-to-gemmini (ISA lowering) -- keep separate + +`memref.dma_start` is the boundary, not the endpoint. The layering is: + + togsim.transfer --[Python decompose]--> memref.dma_start --[Python lower_dma_to_gemmini]--> Gemmini ISA + +decompose-transfer stops at `memref.dma_start` and must **not** emit Gemmini +instructions directly; the ISA encoding is a separate pass +(`passes/lower_dma_to_gemmini.py`, which replaced the C++ test-memref-to-gemmini). +Rationale: + +- **Separation of concerns**: decompose does descriptor decomposition (affine + algebra: rank / peel); gemmini does instruction encoding (hardware). Different + axes; merging couples affine logic with ISA detail. They stay distinct passes. +- **`memref.dma_start` is a shared contract** with multiple consumers + (lower_dma_to_gemmini, dma-fine-grained, the TOG pass). Keeping it as the + interface lets all of them stay unchanged. +- **gemmini is now a Python out-of-line pass too** -- the conversion-framework + coupling (LLVMTypeConverter / getStridedElementPtr) was avoided by working at + the memref level (`memref.extract_aligned_pointer_as_index` + arith for + addresses, `llvm.inline_asm` for instructions; the existing standard lowering + finalizes to LLVM). So both decompose and gemmini live in Python; mlir-opt keeps + only the remaining custom passes. + +One constraint flows the other way: gemmini's ISA limits (max dims / size per MVIN) +set decompose's target inner-descriptor shape (the "<=4D" and max-extent bounds). +decompose must *respect* those limits when it picks what stays inner vs gets peeled +-- but respecting a constraint is not doing the lowering. + +### Cost-aware peeling (this is a cycle-accurate simulator) + +Descriptor count is a **modeled cost** (issue overhead + DRAM burst efficiency in +Ramulator). Rules: + +1. Peel the **outermost, lowest-trip-count** dims (descriptor count = product of + peeled extents). +2. Keep the inner descriptor **as large and contiguous as possible** (maximize + bytes per descriptor). + +(A pathological peel is not this pass's problem to fix: it means the operand's +layout is bad, which is a graph-level layout/copy decision, not an in-pass +relayout.) + +### Placement: hybrid (least burden) + +Keep the decision in Python (where shape/sympy info is available and iteration is +fast); keep the C++ pass purely mechanical. + +| Step | Where | +|---|---| +| peel-plan decision (which dims to peel, count estimate) | Python | +| encode plan as op attributes | Python -> MLIR | +| emit `scf.for { customized dma_start }` per the plan | C++ pass | + +The cost model can migrate into C++ later if desired. + +## Expected effects + +- **Removes recompile-dance hard-fails.** The `max_retry` `RuntimeError` path + disappears: access that does not linearize is peeled, not retried-then-killed. + This directly increases model coverage (the primary goal). +- **Removes the 4D rank cap.** Arbitrary-rank reshapes become expressible via the + base-pointer loop; the `NotImplementedError` for >4D goes away. +- **Inverts the tile<->DMA dependency.** Tile size is chosen for compute / vlane + efficiency only; the DMA conforms to whatever access results. No divisibility + forcing, no oscillation. Tile selection simplifies. +- **Shrinks the codegen.** The `FloorDiv` / `ModularIndexing` recompile branches in + `get_dma_info`, and the in-emission `RecompileSignal` paths, leave Python; the + codegen emits one declarative op instead of procedurally forcing tiles. +- **Collapses two of the three padding sites for the DMA case.** Once divisibility + is no longer required to represent access, the Python tile-adjust dance (1) is + unnecessary for DMA, and `get_mask` (2) shrinks. `TestLoopPadding` (3) is + addressed by Plan A. (Compute-side vectorization remainder is separate; see + Plan A.) +- **Behavior-preserving for the common case.** Access without floor/mod still emits + a single `dma_start` identical to today -> low-risk, incremental rollout. +- **Preserves the simulator contract.** The leaf is the existing customized + `dma_start`; Spike / gem5 / TOGSim see the same descriptor kind, just more of + them in a loop. +- **A clean tactical slice toward Plan B.** This factors out exactly the one piece + that is actually broken (DMA decomposition) into a lowering pass, without the + full linalg rewrite. +- **Cost-aware, so modeled performance is protected.** Peel small/outer, keep inner + contiguous. Pathological layouts are fixed upstream (graph copy), not by an + in-pass relayout. + +## Migration strategy + +1. Define `togsim.transfer` (op + verifier) above the existing descriptor. + Optionally formalize the descriptor as `togsim.dma_descriptor`. +2. Make the codegen emit `togsim.transfer` for loads/stores, carrying the access + maps and vlane attributes it already computes. +3. Implement `decompose-transfer` with the fast path first (<=4D affine -> one + `dma_start`), proving **bit-identical output** to today on a smoke test. +4. Add the **affine** peel path for >4D; validate end-to-end through all three + simulators (the loop-of-descriptors must satisfy the TOG / Spike / gem5 + contract). Make the pass **assert** on any non-affine residue (contract guard). +5. Land the upstream producers of the affine-only invariant: aligned axis split at + scheduling (`axis-split-scheduling.md`) and misaligned graph copy insertion. +6. Remove the `get_dma_info` recompile branches once the pass + upstream cover their + cases; use the failure ledger + assert-only `TestLoopPadding` to confirm nothing + regresses before deleting. + +## Relationship to Plan A and Plan B + +- **Plan A (graph-level padding)** reduces how often peeling/relayout is needed by + making dims granule-aligned, and retires `TestLoopPadding`. Complementary: this + op makes representation robust; Plan A reduces constraint frequency. +- **Plan B (linalg)** is the full structured-ops rewrite; `expand_shape` / + `collapse_shape` are the principled home for reshape, and the framework would + generate the same peel/relayout under the hood. This transfer op is the narrow, + now-achievable slice of that idea. + +## Risks / open questions + +- **C++ pass in the `PSAL-POSTECH/llvm-project` fork**: heavier iteration + (rebuild), logic split across two repos. Mitigated by the hybrid split (smarts in + Python, pass is mechanical). +- **TOG / Spike / gem5 contract on a loop of descriptors.** If TOG generation + assumes "one DMA = one node," the loop form needs handling. Validate at step 4. +- **Cost model accuracy** for the peel plan; start with a simple descriptor-count + threshold and refine against measured cycles. +- **Dynamic shapes**: `iter_bounds` as SSA operands; affine peel must handle + symbolic outer extents. (Symbolic-divisor floor/mod normalization is an upstream + concern, not this pass.) +- **Upstream completeness.** The pass's fail-loud contract is only safe if the + upstream producers (axis split + graph copy) actually normalize every misaligned + case. Until they do, the assert may fire on real models -- track which ops trip it + as the work-list for the upstream passes. +- **Async / tag management across the peel loop**: double-buffering / compute + overlap must survive decomposition (e.g. keep the inner large DMA async, sequence + the outer peel). + +## Appendix: alignment theory (when floor/mod is statically decomposable) + +This section records the math that decides, for a given DMA access, whether the +non-affine `FloorDiv`/`ModularIndexing` terms can be peeled into a *static* loop +of affine descriptors (free) or require a data movement (relayout / copy). + +### Setup + +A Gemmini-style descriptor addresses an element as + + addr(idx) = base + Σ_k stride_k · idx_k (integer strides, rank <= 4) + +i.e. each loop index `idx_k` contributes a **constant** stride. A DMA is +statically decomposable iff every index term it reads has constant stride over +the rectangular tile domain. Inductor index expressions, after fusion/view, carry +`FloorDiv(x, y)` and `ModularIndexing(x, y, z)` of the *flattened* loop variable +`x`. The question is when those reduce to constant-stride axes. + +### Mixed-radix decomposition + +Write the flattened index `x` (extent `E`) in mixed radix. For a `ModularIndexing` +with inner period `y` and modulus `z`, decompose uniquely as + + x = o·(y·z) + m·y + r, with 0 <= r < y, 0 <= m < z, o >= 0 + +Then `FloorDiv(x, y) = o·z + m`, and `ModularIndexing(x, y, z) = m`. Each of +`o, m, r` is a separate **implicit axis** with a constant per-axis stride — +*provided the axis boundaries do not move across the tile*. That holds iff the +period divides the extent it partitions: + +- `ModularIndexing(x, y, z)` is a valid rectangular axis **iff y·z | E**. +- `FloorDiv(x, y)` is a valid rectangular axis **iff y | E**. + +**Aligned** = the divisor (and modular period `y·z`) divides the extent, so the +wrap point lands on a fixed axis boundary -> constant stride -> peelable for free. +**Misaligned** = the wrap point falls at a loop-value-dependent position inside the +descriptor (e.g. uneven `cat`, ragged split) -> the stride is not constant -> +**not** statically decomposable; only a relayout (physical copy) fixes it. + +### One loop axis -> several implicit axes (complex fusion) + +When fusion merges many dims into one flattened loop variable, a *single* loop +axis can expand into **several** implicit axes through nested floor/mod, e.g. + + x in [0, D0·D1·D2): + a = FloorDiv(x, D1·D2) # outer + b = ModularIndexing(x, D2, D1) # middle + c = ModularIndexing(x, 1, D2) # inner + +That is three implicit descriptor axes coming from one loop axis. This is the +general case the un-flatten must handle: it is **not** limited to splitting one +axis into two. Key consequences: + +1. **The loop's own factorization is always aligned.** When the implicit axes + come from re-reading the loop's *own* contiguous factorization (the common + fusion case -- Inductor flattens contiguous dims then a consumer reads them + back via floor/mod), every period divides by construction (`D1·D2 | D0·D1·D2`, + etc.). So these un-flatten splits are **free** -- they just add descriptor + axes, never a copy. +2. **Rank blows past 4 fast.** k implicit axes per loop axis, across multiple + operands, means the descriptor rank exceeds the 4D Gemmini limit very quickly. + This is exactly why `togsim.transfer` + the peel pass matters *more* under + complex fusion, independent of any misalignment. The >4D branch in + `get_dma_info` already routes these to `togsim.transfer`. +3. **Misalignment is still only from non-factor views.** An implicit axis is + misaligned only when its period does not divide the extent -- i.e. the view + does not factor along the loop's factorization (uneven `cat`, ragged split, + group sizes that don't divide the channel count). Those, and only those, need + relayout. + +### Case-handling summary + +| Source of floor/mod | Aligned? | Handling | Cost | +|------------------------------------------------|----------|-----------------------------------|------| +| Broadcast / dim-merge (`[N,1]->[N,M]`, `i//M`) | always | un-merge (split loop axis back) | free | +| Reshape along the loop's own factorization | yes (`y·z\|E`) | un-flatten split, then peel for rank | free | +| >4D logical tile from complex fusion | yes | `togsim.transfer` -> peel into <=4D loop | free (extra DMA nodes) | +| Uneven `cat`, ragged split, non-dividing group | no | graph copy insertion (relayout, upstream) | copy = TPU `concatenate` | + +The TPU/XLA model is the reference: express only aligned views as +descriptor/bitcast (free reshape); never put a misaligned access in the +descriptor -- insert a copy (relayout) instead. Plan A (graph-level +force-contiguous / pad-to-granule, like XLA copy-insertion) is the upstream lever +that *reduces how often* the misaligned branch fires, keeping codegen affine-only. + +## Implementation status (Phase 1: codegen emission) + +Landed on branch `dma-transfer/codegen` (worktree), emission only -- the +decompose pass is deferred until explicitly signalled. A >4D access now emits a +`togsim.transfer` instead of hard-failing; without the pass it does not yet run +end-to-end (expected). + +- **`mlir_common.py` `init_tile_size`** generalized to any rank. Logical tile is + separated from the physical (<=4D) descriptor: only the innermost dims carry the + vectorized tile, all further-outer dims stay 1, and there is no rank cap. The + `nr_dim >= 3` formula reproduces the old 3D/4D values exactly (the old `[-4]=1` + is subsumed by "outer dims stay 1"); scalar/1D/2D keep their special cases. This + removes the old `raise NotImplementedError("dummy tile size fail!")` that + conflated logical and physical tile rank. +- **`mlir_codegen_backend.py`**: + - `__init__` adds `self._dma_needs_transfer = False`. + - `get_dma_info` >4D `else` branch (was + `raise NotImplementedError("Currently not implemented... ;)")`) now builds the + full N-D tile (`set_tile_size`, vlane split/stride) and sets + `self._dma_needs_transfer = True`. + - `emit_transfer(...)` emits the generic-form `"togsim.transfer"(...)` op + carrying `dma_kind`, `vlane_split_axis`, `vlane_stride`, `dram_stride`, + `tile_stride`, `padding`, with operands `(dram, dram_idx, sram, 0, tag)`. + `togsim` is an unregistered dialect, hence generic form. + - `load()` (MVIN) and `store()` (MVOUT) check the flag: if set, reset it and + call `emit_transfer`; otherwise the existing `get_dma_code` path is unchanged. + So aligned <=4D DMAs are **bit-identical** to before; only >4D accesses change. + +Validated: the 5D permute smoke test (`x.permute(4,3,2,1,0).contiguous() + 1.0`) +now emits MVIN/MVOUT `togsim.transfer` with 5D `dram_stride [1,6,30,120,360]` and a +`memref<1x1x2x4x2xf32,1>` tile, instead of crashing in `init_tile_size` or the +`get_dma_info` >4D branch. + +### Phase 2: aligned-only peel pass (landed: unit-collapse path) + +`passes/decompose_transfer.py` (registered in `passes/__init__.py`, runs before +`lower_vlane_idx`) lowers each `togsim.transfer` to a customized `memref.dma_start`: + +- **Unit-dim collapse (done, validated).** Drop extent-1 tile dims so the + descriptor reaches <=4D. The SRAM (spad) memref is collapsed to the effective + rank via `memref.collapse_shape` (the customized `dma_start` convention requires + SRAM rank == #indices == len(sram_stride)); DRAM stays flat rank-1 with its N-D + structure in `dram_stride`. The `vlane_split_axis` is **remapped** from the + original tile-dim index to the collapsed-dim index and rematerialized as a const + (carried as a value attr precisely so the pass can remap it). +- Supporting changes: `emit_transfer` now carries the SSA operands a `dma_start` + needs (`dma_type`, `vlane_stride`) + the `vlane_split_axis` value attr, so the + pass is mechanical. `lower_to_llvm.py` gains `expand-strided-metadata` to lower + `collapse_shape`. + +Validated end-to-end (Gem5 + Spike + TOGSim, `allclose=True`) on the 5D permute +`x.permute(4,3,2,1,0).contiguous() + 1.0`; no regression on 2D/3D/elementwise. + +- **Genuine >4 effective rank (isolation-only; INCOMPATIBLE with TOG -- see TODO).** + When >4 *non-unit* dims survive, the pass keeps the inner 4 as the <=4D descriptor + and peels the outer dims by **full unrolling**: one descriptor per outer-index + combo, the SRAM slice a rank-reduced `memref.subview` at the static slice offset, + the DRAM base `dram_idx + constant`. This passes `lower_text` / mlir-opt in + isolation, but **fails the full pipeline**: the C++ TOG generation pass cannot read + `memref.subview` + unrolled (constant-offset) DMAs and produces an empty + `loop_idx_list` (ValueError in `onnx_utility.py`). Surfaced once aligned axis-split + made the path reachable (pixel_shuffle -> 5D); axis-split now has a rank guard that + avoids triggering it. + +> **TODO (peel rework, tracked as GitHub issue #258).** Rewrite the >4D peel to emit +> a real `affine.for` over the peeled dims (so each DMA keeps an enclosing loop index +> the TOG pass can read) and index the spad directly instead of via `memref.subview`. +> Alternatively teach the C++ TOG pass to handle `subview` + unrolled DMAs. Until +> then the unroll path is isolation-only and the axis-split rank guard keeps it +> unreached. + +The input stays per-axis affine by upstream guarantee, so both paths are pure +mechanical peeling. A non-affine residue is a contract violation (aligned floor/mod +removal lives in `axis-split-scheduling.md`, misaligned relayout in graph copy +insertion -- see "Division of labor"); a genuinely non-affine / indirect index +would surface as a build failure here rather than being silently relaid out. From 1d0c888f272296c18f19c5e0400b1e95927106f1 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 17:33:52 +0900 Subject: [PATCH 03/20] [Frontend] axis-split: remove aligned floor/mod at the scheduling layer Split a loop axis so aligned FloorDiv/ModularIndexing collapse to per-axis affine indices. Mixed-radix split over a divisibility chain; integer-typed split symbols; r-prefix innermost reduce dims. Reindex the collapsed LoopBody instead of re-tracing; fold residual floor/mod via tensor range info. Shared boundary helpers, rank guard, and an uncovered floor/mod ledger. Enabled by default with the recompile fallback instrumented. --- PyTorchSimFrontend/mlir/axis_split.py | 301 +++++++++++++++++++++ PyTorchSimFrontend/mlir/mlir_common.py | 35 +-- PyTorchSimFrontend/mlir/mlir_scheduling.py | 44 +++ docs/axis-split-scheduling.md | 211 +++++++++++++++ 4 files changed, 574 insertions(+), 17 deletions(-) create mode 100644 PyTorchSimFrontend/mlir/axis_split.py create mode 100644 docs/axis-split-scheduling.md diff --git a/PyTorchSimFrontend/mlir/axis_split.py b/PyTorchSimFrontend/mlir/axis_split.py new file mode 100644 index 00000000..1c33e021 --- /dev/null +++ b/PyTorchSimFrontend/mlir/axis_split.py @@ -0,0 +1,301 @@ +"""Aligned axis splitting at the Inductor scheduling layer. + +Goal: guarantee the MLIR codegen sees only per-axis affine index expressions +(no FloorDiv / ModularIndexing). When an index expr contains FloorDiv(v, k) or +ModularIndexing(v, k, m) where `v` is a single iteration variable of extent E +and the divisor (resp. k*m) divides E, the floor/mod is *aligned*: splitting the +loop axis v into (outer, inner) with v = outer*k + inner makes it collapse to a +plain affine term (outer), at zero data-movement cost. + +This is the cheap upstream tool of the affine-only contract. The misaligned case +(cat / non-factor reshape, divisor does not divide the extent) is NOT handled +here -- that needs graph-level copy insertion. + +The rebuild reuses Inductor's own LoopBody machinery, exactly like +MLIRScheduling.revert_group: feed a split var_ranges + iter_vars and re-trace the +node's store function so the index expressions are regenerated over the new +iteration domain. +""" +import sympy +from torch._inductor.ir import LoopBody +from torch._inductor.utils import sympy_index_symbol +from torch.utils._sympy.functions import FloorDiv, ModularIndexing + + +def _as_int(x): + try: + return int(x) + except (TypeError, ValueError): + return None + + +def collect_boundaries(exprs, var_to_axis, var_ranges): + """{axis_index: set(boundary cut points)} for the given index expressions. + + A FloorDiv(v, k) contributes boundary k; ModularIndexing(v, k, m) contributes + k and k*m. Only aligned terms count (boundary divides the var extent). Shared + by find_split_plan (fused LoopBody) and graph_copy (operand loaders). + """ + import collections + bset = collections.defaultdict(set) + for expr in exprs: + for fd in expr.atoms(FloorDiv): + base, div = fd.args + k = _as_int(div) + if base in var_to_axis and k and k > 1: + E = _as_int(var_ranges.get(base)) + if E and E % k == 0: + bset[var_to_axis[base]].add(k) + for mi in expr.atoms(ModularIndexing): + base, div, mod = mi.args + k, m = _as_int(div), _as_int(mod) + if base in var_to_axis and k and m: + E = _as_int(var_ranges.get(base)) + if E and E % (k * m) == 0: + ax = var_to_axis[base] + if k > 1: + bset[ax].add(k) + if k * m < E: + bset[ax].add(k * m) + return bset + + +def _is_chain(boundaries, E): + """True iff [1, sorted(boundaries in (1,E)), E] is a divisibility chain.""" + chain = [1] + sorted(b for b in boundaries if 1 < b < E) + [E] + return all(chain[i + 1] % chain[i] == 0 for i in range(len(chain) - 1)) + + +def ledger(nodes, plan): + """Classify every FloorDiv/ModularIndexing in the kernel against `plan`. + + Returns a list of (op_name, reason, term_str) for the terms NOT covered by + axis-split, so we can measure how often the graph-copy cases (incompatible + radix / non-dividing / multi-axis / dynamic) actually reach codegen. Read-only. + Reasons: covered terms are omitted; uncovered ones are + multi_axis_arg - floor/mod argument is not a single iter var (case 7) + non_dividing - divisor (or k*m) does not divide the extent (case 6) + incompatible_radix - single var, divides, but boundaries did not form a + divisibility chain so the axis was left unsplit (case 5) + dynamic - symbolic divisor/extent + """ + rows = [] + + def classify(base, k, m, var_to_axis, var_ranges): + if not (isinstance(base, sympy.Symbol) and base in var_to_axis): + return None if False else "multi_axis_arg" + ax = var_to_axis[base] + E = _as_int(var_ranges.get(base)) + if k is None or E is None or (m is not None and _as_int(m) is None): + return "dynamic" + if ax in plan: + return "covered" + period = k if m is None else k * _as_int(m) + if period and E % period != 0: + return "non_dividing" + return "incompatible_radix" + + for n in nodes: + body = getattr(n, "_body", None) + if body is None: + continue + op = n.get_name() if hasattr(n, "get_name") else "?" + var_to_axis = {v: i for i, v in enumerate(body.iter_vars)} + for expr in body.indexing_exprs.values(): + for fd in expr.atoms(FloorDiv): + r = classify(fd.args[0], _as_int(fd.args[1]), None, var_to_axis, body.var_ranges) + if r and r != "covered": + rows.append((op, r, str(fd))) + for mi in expr.atoms(ModularIndexing): + r = classify(mi.args[0], _as_int(mi.args[1]), mi.args[2], var_to_axis, body.var_ranges) + if r and r != "covered": + rows.append((op, r, str(mi))) + return rows + + +def find_split_plan(nodes): + """Inspect a group of scheduler nodes and return {axis_index: boundaries}. + + `boundaries` is an ascending divisibility chain [1, b1, ..., E] of cut points + for that axis: splitting the axis at these boundaries (mixed radix, + `v = sum_i d_i * b_i`) makes every FloorDiv/ModularIndexing on it collapse to + an affine combination of the split sub-vars. The cut points are gathered from + the terms on the axis: + - FloorDiv(v, k) -> boundary k + - ModularIndexing(v, k, m) -> boundaries k and k*m (the digit lives in [k, k*m)) + Only aligned terms count (the boundary must divide the extent E). If the + collected boundaries for an axis do NOT form a divisibility chain (e.g. + floor-by-2 and mod-by-3 on extent 6), the radices are incompatible -> the axis + is left unsplit (its floor/mod stays for the misaligned/recompile path). + + axis_index is positional in the group's iteration space, so the same plan + applies to every fused node sharing that space. + """ + import collections + bset = collections.defaultdict(set) # axis -> set of boundary cut points + ext_of = {} # axis -> extent + for n in nodes: + body = getattr(n, "_body", None) + if body is None: + continue + var_to_axis = {v: i for i, v in enumerate(body.iter_vars)} + nb = collect_boundaries(body.indexing_exprs.values(), var_to_axis, body.var_ranges) + for ax, bs in nb.items(): + bset[ax] |= bs + ext_of[ax] = _as_int(body.var_ranges[body.iter_vars[ax]]) + + plan = {} + for ax, bs in bset.items(): + E = ext_of[ax] + # require a real, divisibility-chain split (incompatible radices -> skip). + if E and any(1 < b < E for b in bs) and _is_chain(bs, E): + plan[ax] = [1] + sorted(b for b in bs if 1 < b < E) + [E] + + # Validation aid: force-split the first even index axis even without floor/mod. + # A floor-free index split is an identity transformation, so allclose must hold; + # used to exercise the reduction pass-through path (no natural op produces a + # floor on a reduction kernel's index axis). Off unless TORCHSIM_AXIS_SPLIT_FORCE. + import os as _os + if _os.environ.get("TORCHSIM_AXIS_SPLIT_FORCE"): + for n in nodes: + body = getattr(n, "_body", None) + if body is None or not body.reduce_vars: + continue + for ax, v in enumerate(body.iter_vars): + E = _as_int(body.var_ranges.get(v)) + if ax not in plan and E and E % 2 == 0 and E > 2: + plan[ax] = [1, 2, E] + break + + # Rank guard: if the split would push the index rank past 4, skip it and fall + # back to baseline. The >4D logical tile is *meant* to be peeled into <=4D + # physical descriptors by the decompose-transfer pass, and the #258 TOG crash + # (arith.addi DRAM offset) is now fixed -- but the peel still has a numerical + # correctness bug (pixel_shuffle -> MISMATCH; the peel was only ever isolation- + # validated for MLIR structure, never run end-to-end). Keep the guard until the + # peel numerics are fixed; then this guard can be removed and the recompile-dance + # retired for pixel. + base_rank = next((len(b.iter_vars) for n in nodes + for b in (getattr(n, "_body", None),) if b is not None), 0) + extra = sum(len(ch) - 2 for ch in plan.values()) + if base_rank + extra > 4: + return {} + return plan + + +def build_split_body(node, plan, prefix="z"): + """Rebuild node._body / sizes for the given split plan. + + Returns (body, (index_size, reduce_size)). Reindexes the EXISTING (already + collapsed/reordered) node._body via LoopBody's copy path instead of re-tracing + from the raw store function: pass the body as `fn` so LoopBody.__init__ takes + _init_with_copy, which substitutes each original iter var with our expression + and runs simplify_with_ranges. For a split axis the substitution + v -> sum_i d_i * b_i (mixed radix over the boundary chain) makes every + FloorDiv/ModularIndexing on it collapse to an affine combination of the d_i, + and reindexing the collapsed body keeps already-merged dims merged (no rank + blow-up). indexing_from_args requires exactly one replacement expr per original + var (index dims then reduce dims), flattened to len(body.var_ranges). + """ + body = node._body + orig_index_vars = list(body.iter_vars) + orig_reduce_vars = list(body.reduce_vars) + + iter_vars = [] + index_args = [] # one expr per ORIGINAL index dim (substituted in) + var_ranges = {} + index_size = [] + ctr = 0 + + for ax, v in enumerate(orig_index_vars): + ext = body.var_ranges[v] + if ax in plan: + bounds = plan[ax] # ascending chain [1, b1, ..., E] + # one sub-var per segment: d_i has extent b_{i+1}/b_i, significance b_i. + subs = [] # (symbol, extent, significance) low->high + expr = sympy.Integer(0) + for i in range(len(bounds) - 1): + seg_ext = bounds[i + 1] // bounds[i] + nv = sympy_index_symbol(f"{prefix}{ctr}"); ctr += 1 + subs.append((nv, seg_ext, bounds[i])) + expr = expr + nv * bounds[i] + # iteration nest: most-significant (outermost) dim first. + for nv, seg_ext, _sig in reversed(subs): + iter_vars.append(nv) + var_ranges[nv] = sympy.Integer(seg_ext) + index_size.append(sympy.Integer(seg_ext)) + index_args.append(expr) + else: + nv = sympy_index_symbol(f"{prefix}{ctr}"); ctr += 1 + iter_vars.append(nv) + var_ranges[nv] = ext + index_size.append(ext) + index_args.append(nv) + + # Reduction dims pass through unchanged (a fresh symbol with the same range), + # using the "r" prefix and kept after the index dims so the reduction axis + # stays innermost (var_ranges is ordered iter-then-reduce; sizes splits on + # len(iter_vars)). We do not split reduction dims here. + reduce_vars = [] + reduce_size = [] + reduce_args = [] + for rctr, v in enumerate(orig_reduce_vars): + ext = body.var_ranges[v] + nv = sympy_index_symbol(f"r{rctr}") + reduce_vars.append(nv) + var_ranges[nv] = ext + reduce_size.append(ext) + reduce_args.append(nv) + + args = [index_args, reduce_args] if orig_reduce_vars else [index_args] + new_body = LoopBody(body, args, var_ranges, iter_vars, reduce_vars) + new_body.indexing_exprs = { + name: _fold_with_ranges(e, var_ranges) + for name, e in new_body.indexing_exprs.items() + } + return new_body, (index_size, reduce_size) + + +def _fold_with_ranges(expr, var_ranges): + """Fold residual FloorDiv/ModularIndexing that simplify_with_ranges missed. + + A mixed-radix split leaves terms like FloorDiv(z1 + 4*z2, 12); these are 0 by + construction (the lower digits sum below the boundary), but the Inductor + simplifier cannot prove a multi-term numerator < divisor. We prove it directly + from the split sub-var ranges via bound_sympy: + FloorDiv(num, d) -> 0 if 0 <= num < d + ModularIndexing(num, k, m) -> num // k if 0 <= num < k*m (mod is a no-op) + Iterated to a fixpoint (folding a mod can expose a foldable floor). + """ + from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + ranges = {} + for v, sz in var_ranges.items(): + e = _as_int(sz) + if e is not None and e >= 1: + ranges[v] = ValueRanges(0, e - 1) + if not ranges: + return expr + + def vr(num): + try: + return bound_sympy(num, ranges) + except Exception: + return None + + for _ in range(8): + changed = False + for fd in list(expr.atoms(FloorDiv)): + num, div = fd.args + d = _as_int(div) + b = vr(num) if d else None + if b is not None and b.lower >= 0 and b.upper < d: + expr = expr.subs(fd, sympy.Integer(0)); changed = True + for mi in list(expr.atoms(ModularIndexing)): + num, k, m = mi.args + ki, mi_ = _as_int(k), _as_int(m) + b = vr(num) if (ki and mi_) else None + if b is not None and b.lower >= 0 and b.upper < ki * mi_: + expr = expr.subs(mi, FloorDiv(num, k)); changed = True + if not changed: + break + return expr diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 734ca967..45bb144a 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -1,5 +1,6 @@ import dataclasses import math +import os import contextvars from contextlib import contextmanager from dataclasses import dataclass @@ -472,29 +473,25 @@ def apply_constraints(self, constraints, ranges): @staticmethod def init_tile_size(ranges, vlane_stride, vector_lane): + # Logical tile init for ANY rank. Only the innermost dims carry the + # vectorized tile; all further-outer dims stay 1. The physical Gemmini DMA + # descriptor is <=4D -- a higher-rank logical tile is mapped onto <=4D + # descriptors by togsim.transfer + the decompose pass (logical/physical + # tile split), so no rank cap here. nr_dim = len(ranges) + if nr_dim == 0: # scalar + return [1] tile_size = [1] * nr_dim - if len(tile_size) == 2: + if nr_dim == 1: + tile_size[0] = 1 if ranges[0] == 1 else 2 * vlane_stride * vector_lane + elif nr_dim == 2: tile_size[-1] = vlane_stride * vector_lane tile_size[-2] = 2 * vector_lane - elif len(tile_size) == 0: # Scalar - tile_size = [1] - ranges = [1] - elif len(tile_size) == 1 and ranges[0]==1: - tile_size[0] = 1 - elif len(tile_size) == 1: - tile_size[0] = 2 * vlane_stride * vector_lane - elif len(tile_size) == 3: + else: # 3D and up (general) tile_size[-1] = vector_lane tile_size[-2] = 4 * vector_lane tile_size[-3] = 2 - elif len(tile_size) == 4: - tile_size[-1] = vector_lane - tile_size[-2] = 4 * vector_lane - tile_size[-3] = 2 - tile_size[-4] = 1 - else: - raise NotImplementedError("dummy tile size fail!") + # tile_size[:-3] stay 1 (subsumes the old 4D [-4]=1 and any higher rank) return tile_size @staticmethod @@ -840,10 +837,14 @@ def codegen_nodes(self, nodes, kernel_name): node.run(vars, reduction_vars) except RecompileSignal as e: recompile_try += 1 + # Measure what still depends on the recompile-dance once axis-split + + # graph-copy are on by default (set TORCHSIM_RECOMPILE_LOG=1). + if os.environ.get("TORCHSIM_RECOMPILE_LOG"): + import sys as _sys + print(f"[RECOMPILE {recompile_try}/{max_retry_compile}] {e}", file=_sys.stderr) if recompile_try > max_retry_compile: raise RuntimeError("Failed to compile kernel after multiple attempts.") # Retry compile nodes - #print(f"Try recompile({recompile_try}/{max_retry_compile}). Reason: {e}") continue V.graph.removed_buffers |= self.removed_buffers # V.graph.inplaced_to_remove |= self.inplaced_to_remove diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 22d1011b..48eead47 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -249,6 +249,44 @@ def codegen_node(self, _node): nodes, key=lambda x: int(x.is_reduction()) ).group + def _dump_axis(tag): + import sys as _sys + print(f"\n[AXIS_SPLIT:{tag}] group={group} reduction_group={reduction_group}", file=_sys.stderr) + for _n in nodes: + _body = getattr(_n, "_body", None) + if _body is None: + continue + print(f"[AXIS_SPLIT:{tag}] node={_n.get_name()} var_ranges={getattr(_body, 'var_ranges', None)}", file=_sys.stderr) + for _k, _e in getattr(_body, "indexing_exprs", {}).items(): + print(f"[AXIS_SPLIT:{tag}] idx[{_k}] = {_e}", file=_sys.stderr) + + if os.environ.get("TORCHSIM_DEBUG_AXIS_SPLIT"): + _dump_axis("before") + + if os.environ.get("TORCHSIM_AXIS_LEDGER"): + from . import axis_split + import sys as _sys + _plan = axis_split.find_split_plan(nodes) + for _op, _reason, _term in axis_split.ledger(nodes, _plan): + print(f"[AXIS_LEDGER] op={_op} reason={_reason} term={_term}", file=_sys.stderr) + + # axis-split is ON by default; set TORCHSIM_AXIS_SPLIT=0 to disable. + if os.environ.get("TORCHSIM_AXIS_SPLIT", "1") != "0": + from . import axis_split + plan = axis_split.find_split_plan(nodes) + if plan: + for _n in nodes: + if getattr(_n, "_body", None) is None: + continue + _body, _ranges = axis_split.build_split_body(_n, plan) + _n._sizes, _n._body, _n.group = _ranges, _body, (_n.get_device(), self.group_fn(_ranges)) + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + if os.environ.get("TORCHSIM_DEBUG_AXIS_SPLIT"): + print(f"[AXIS_SPLIT] applied plan={plan}", file=__import__("sys").stderr) + _dump_axis("after") + # Note: We assume that there is at least one loop in the nodes # But, inductor simplifies the group, there could be no loop # In that case, we add dummy loop(size=1) to the group @@ -353,3 +391,9 @@ def get_order(n): if origins: _, _, last = max(origins) V.graph.wrapper_code.enter_context(last) + + +# Install the graph-copy (incompatible-radix relayout) lowering hook once at import. +# No-op unless TORCHSIM_GRAPH_COPY is set; see graph_copy.py. +from . import graph_copy as _graph_copy +_graph_copy.install() diff --git a/docs/axis-split-scheduling.md b/docs/axis-split-scheduling.md new file mode 100644 index 00000000..10171ab4 --- /dev/null +++ b/docs/axis-split-scheduling.md @@ -0,0 +1,211 @@ +# Aligned axis splitting at the Inductor scheduling layer + +Status: **prototype / proposed**. Companion to `dma-transfer-lowering.md`. This doc +covers the *upstream* half of the affine-only contract: removing aligned +`FloorDiv` / `ModularIndexing` from index expressions before they reach MLIR +codegen, by splitting loop axes at the Inductor scheduling layer. + +## Goal: the affine-only contract + +We want the MLIR codegen (`get_dma_info` in `mlir_codegen_backend.py`) to receive +only per-axis affine index expressions: + + off(i,j,k,...) = base + Sum_k stride_k * loop_var_k (stride_k constant int) + +with **zero** `FloorDiv` / `ModularIndexing`. If that invariant holds, codegen no +longer fights non-affine indices: the recompile dance (RecompileSignal, forced +tile sizes, max_retry_compile), the heuristic `TestLoopPadding` pass, and the +hard-fail-on-conflict path all become unnecessary. Codegen's only remaining job +is the *mechanical* rank<=4 peel for the Gemmini descriptor (orthogonal; see +`dma-transfer-lowering.md`), which operates on already-affine input. + +Two tools produce this invariant, matching the alignment theory: + +- **aligned floor/mod -> axis split** (this doc): loop transformation, free, no + data movement. +- **misaligned floor/mod -> graph copy insertion** (XLA-style): genuine data + movement; out of scope here. + +"Aligned" means the floor/mod argument is a *single* iteration variable `v` of +extent `E` and the divisor `k` (resp. `k*m` for ModularIndexing) divides `E`, so +splitting `v = outer*k + inner` lands the wrap point on a fixed axis boundary. + +## Where: the scheduling layer already rebuilds LoopBody + +`mlir_scheduling.py` already does loop-IR surgery at the scheduling layer: + +- `revert_group` (line ~219) rebuilds a `LoopBody` from `get_store_function()` + with a chosen `var_ranges` -- it undoes Inductor's `simplify_and_reorder`. +- `codegen_node` (line ~246) injects dummy size-1 loops when Inductor + over-simplified the group. + +Axis splitting is the same operation with a different `var_ranges`: split the +axes carrying aligned floor/mod, then rebuild. No new infrastructure -- reuse +`LoopBody`. This is "upstream" of MLIR codegen and native to Inductor's IR (sympy +ranges + index exprs), so we are not reverse-engineering MLIR text. + +## How: detect / rebuild / hook + +Implemented in `PyTorchSimFrontend/mlir/axis_split.py`, wired into +`codegen_node` behind `TORCHSIM_AXIS_SPLIT=1` (dump with +`TORCHSIM_DEBUG_AXIS_SPLIT=1`). + +1. **Detect -- `find_split_plan(nodes)`**: scan each node's + `_body.indexing_exprs` for `FloorDiv(v, k)` / `ModularIndexing(v, k, m)` where + `v` is a single iter var and the divisor divides `v`'s extent. Return + `{axis_index: divisor}`, keyed positionally so it applies to every fused node + sharing the iteration space. +2. **Rebuild -- `build_split_body(node, plan)`**: rebuild `node._body` / + `_sizes` with the split var_ranges; feed the store function the index + expression `outer*k + inner` at the split dim so the floor/mod collapses. +3. **Hook -- `codegen_node`**: apply the plan to every node + (`_sizes, _body, group = ...`), then recompute the group. + +## Empirical validation (group norm) + +`group_norm(x[2,6,4,4], num_groups=3)` normalize kernel, before vs after split: + + before var_ranges={p0:2, p1:6, p2:16} + idx0 = 96*p0 + 16*p1 + p2 # x input, affine + idx1 = 3*p0 + (p1//2) # mean/rstd <- FloorDiv(p1,2), 2|6 aligned + idx2 = p1 # weight/bias, affine + + after plan={1: 2}, var_ranges={s0:2, s1:3, s2:2, ...} + idx1 = 3*s0 + (s1//1) # FloorDiv collapsed to identity -> s1 + ... # mean now affine; s2/spatial broadcast (stride 0) + +The FloorDiv is eliminated. group `(2,6,16) -> (2,3,2,...)`. + +## Coverage (what this framework can and cannot do) + +| Case | Example | Status | +|---|---|---| +| aligned FloorDiv, single var | group norm `c//2` (2\|6) | DONE (prototype) | +| aligned ModularIndexing | `(v//k)%m`, k*m\|E | needs mixed-radix multi-split | +| multiple radices on one axis | `//2` + `%3`, E=6 | needs nested split (now: first divisor only) | +| reduction-axis floor/mod | `r//k` inside reduce | needs reduction-var splitting | +| divisor does not divide extent | C=8 groups of 3; uneven cat | IMPOSSIBLE by split -> graph copy | +| multi-axis argument | `(4p+q)//6` non-factor reshape | IMPOSSIBLE by split -> graph copy | +| dynamic / symbolic | `v//s`, symbolic extent | separate symbolic/guard path | + +The aligned class is the framework's domain (currently only single-split +FloorDiv); the misaligned class is structurally a graph-copy problem. + +## Resolved + +- **5D blow-up (fixed).** `build_split_body` now reindexes the already-collapsed + `node._body` via `LoopBody`'s copy path (pass the body as `fn` -> + `_init_with_copy`), instead of re-tracing the raw store function over + `inode.data.get_size()`. This keeps merged dims merged (spatial stays `16`), + so group_norm goes `(2,6,16) -> (2,3,2,16)` (4D, no cap hit), and + `_init_with_copy`'s `simplify_with_ranges` folds the split floor. +- **`floor//1` residue (fixed).** The fold only happened once the new symbols + carried integer/non-negative assumptions: build them with + `torch._inductor.utils.sympy_index_symbol` (not bare `sympy.Symbol`), which is + also why the index prefix must not be `s` (reserved for shape symbols). With + this, `idx1 = 3*p0 + (p1//2)` becomes `3*z0 + z1` -- the channel FloorDiv is + gone, not left as `z1//1`. +- **Symbol conventions.** Index dims use the `z` prefix; reduction dims use the + `r` prefix and are kept after the index dims so the reduction axis stays + innermost (`var_ranges` is ordered iter-then-reduce; `LoopBody.sizes` splits on + `len(iter_vars)`). LoopBody var names are remapped to `index` during MLIR + codegen, so the prefix is internal -- but it must not collide with the original + body's names (those are `p`/`q`, so `z`/`r` are safe). + +## Resolved (cont.) + +- **`floor//1` / residual floor on multi-level split (fixed).** `simplify_with_ranges` + cannot prove a *multi-term* numerator is below the divisor (e.g. + `FloorDiv(z1 + 4*z2, 12)` with `z1<4, z2<3`), so a 3-level mixed-radix split left + a residual floor that codegen rejected ("Not supporting this view operation"). + `_fold_with_ranges` now proves it directly from the split sub-var ranges via + `bound_sympy`: `FloorDiv(num,d)->0` when `0<=numnum//k` + when `0<=num 5D), which triggers the nascent + decompose-transfer peel + TOG path (see below). `find_split_plan` now has a rank + guard: if applying the plan would make the index rank exceed 4, the whole plan is + dropped and the kernel falls back to baseline. pixel_shuffle now passes (via + baseline); 3D group_norm still splits (rank 4, allowed). + +## Known issues / open + +- **decompose-transfer peel <-> TOG incompatibility**: the >4D peel emits + `memref.subview` + unrolled constant-offset `dma_start`, which the C++ TOG + generation pass cannot read (empty `loop_idx_list`). The rank guard above + side-steps it; the real fix is to rewrite the peel as an `affine.for` loop + (keeping a loop index TOG can read) instead of unrolling. **Tracked as a GitHub + issue + the `dma-transfer-lowering.md` TODO.** + +## Done + +- **Mixed-radix (ModularIndexing + multi-radix)**: `find_split_plan` returns a + per-axis divisibility-chain of boundaries; `build_split_body` splits into one + sub-var per segment (`v = sum_i d_i*b_i`). Validated allclose=True on group_norm + (FloorDiv, `[1,2,6]`) and `x.repeat(1,2)` (single-axis ModularIndexing, + `[1,8,16]`); pixel_shuffle (floor+mod on two axes) linearizes correctly. +- **Reduction pass-through**: reduction dims keep the `r` prefix and stay innermost + (after the index dims). Exercised via the `TORCHSIM_AXIS_SPLIT_FORCE` validation + gate (force-split a reduction kernel's index axis even without floor -- an + identity transform, so allclose must hold): layernorm `(512)->(256,2)` and + reduce `(68)->(34,2)` keep their reduction groups and pass. +- **Graph-copy for incompatible radices (case 5)** -- `graph_copy.py`, + `TORCHSIM_GRAPH_COPY`. When two operands of an elementwise consumer carry + incompatible-radix groupings on a shared axis (e.g. `a[c//2] + b[c%3]`, floor-by-2 + vs mod-by-3 on extent 6 -- not a divisibility chain), neither axis-split nor the + recompile-dance can express it. We wrap the registered lowering entries (the + make_pointwise results = every elementwise consumer, one place), trace each + operand's loader with `extract_read_writes` to get its read indices, run the same + `collect_boundaries` analysis, and if the union is not a chain, `realize()` the + cheaper operand. realize() (not clone -- Inductor inlines clone, confirmed) forces + a buffer: the consumer then reads it affine and the remaining single grouping is + handled by axis-split. Validated: `incompat` (`a.repeat_interleave(2)+b.repeat(2)`) + goes ERR -> allclose=True with `GRAPH_COPY+AXIS_SPLIT` (still ERR on default, + confirming graph-copy is the fix); no regression on the pattern battery, + test_add, resnet (compile overhead negligible). +- **Graph-copy for cross-axis floor/mod (case 7)** -- same hook. A transpose+reshape + feeding a consumer that keeps the output dims separate (broadcast / softmax / + layernorm / reduce-one-dim) produces a floor/mod whose argument spans *two* loop + vars, e.g. `(3*p0+p1)//4`; axis-split cannot split a multi-var argument. We detect + an operand whose read index has a floor/mod argument with >1 free symbol and + replace it with `ExternKernel.copy_input` (a realized identity Pointwise). This is + why copy_input and not `realize()`: `StorageBox.realize()` is a no-op on a + ReinterpretView (a reshape), so it does not materialize view operands; copy_input + forces the copy. The copy kernel iterates the operand's own contiguous shape, so + its index collapses to single-var for axis-split, and the consumer reads the copy + affine. Also covers single-operand consumers (a reduction reading a multi-var + view). Validated allclose=True: reshape+broadcast, softmax(reshape), + layernorm(reshape) (all ERR on default). NOTE the empirical correction: case 7 is + NOT rare -- it is the common attention/norm "reshape then reduce/broadcast" + shape; Inductor only avoids it when it can collapse the output to 1D (then the + floor is single-var). + +## Default-on + recompile-dance status + +axis-split and graph-copy are **ON by default** (disable with `TORCHSIM_AXIS_SPLIT=0` +/ `TORCHSIM_GRAPH_COPY=0`). With them on, the codegen recompile-dance (tile-forcing +for floor/mod divisibility) is demoted from primary mechanism to a rarely-hit +fallback. + +Measured under default-on (`TORCHSIM_RECOMPILE_LOG=1`), 33 tests, all pass: +- 16 core (elementwise/gemm/reduce/conv/view/fusion + mlp/resnet/transformer/vit): 0 recompiles. +- 7 broader families (cnn/pool/group_conv/sort/indirect_access/exponent/conv_fusion): 0 recompiles. +- 10 floor/mod patterns: 1 recompile total (an unrelated tile-divisibility in the + 3-level mixed-radix case). + +**Full retirement of the dance is deferred** (it is still a real dependency, not +just a safety net): removing the floor/mod recompile branches would break the +3-level mixed-radix case (1 recompile) and any case axis-split/graph-copy do not +yet cover (case 6, >4D rank-guard skips). attention/sdpa families were not run here +(too slow locally) and need CI validation before retirement. + +## Next steps + +1. Eliminate the last recompile dependency (the 3-level mixed-radix sub-kernel) so + the dance reaches 0/all -> then retire the floor/mod recompile branches (keep the + non-floor/mod ones: non-power-of-2 vec size, indirect). +2. Graph-copy coverage: case 6 (non-dividing divisor / uneven cat -> pad or gather), + and conflicts internal to templates (gemm/conv/sdpa). +3. High-rank interaction: cap split-induced rank or harden decompose-peel + TOG for + high-rank tiles (pixel_shuffle end-to-end, #258). +4. Dynamic shapes -> symbolic divisibility / guards. From 5ee55e44a6cfb26ce7681c6afa59c5e5807952af Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 17:33:52 +0900 Subject: [PATCH 04/20] [Frontend] graph-copy: relayout operands on incompatible / cross-axis floor/mod Insert a copy to relayout an operand whose floor/mod cannot be removed by axis-split: incompatible-radix shared-axis access and cross-axis multi-variable arguments. Enabled by default alongside axis-split. --- PyTorchSimFrontend/mlir/graph_copy.py | 163 ++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 PyTorchSimFrontend/mlir/graph_copy.py diff --git a/PyTorchSimFrontend/mlir/graph_copy.py b/PyTorchSimFrontend/mlir/graph_copy.py new file mode 100644 index 00000000..51c2e9b6 --- /dev/null +++ b/PyTorchSimFrontend/mlir/graph_copy.py @@ -0,0 +1,163 @@ +"""Graph-copy (relayout) for incompatible-radix operands. + +When an elementwise consumer reads two operands whose floor/mod groupings on a +shared axis are incompatible (the boundary cut points do not form a divisibility +chain, e.g. floor-by-2 and mod-by-3 on extent 6), axis-split cannot linearize the +fused index. We `realize()` the cheaper operand at the consumer's lowering, which +materializes it as a contiguous buffer; the consumer then reads it affine and only +the other (single, compatible) grouping remains for axis-split to handle. + +Detection reuses axis_split.collect_boundaries on each operand's loader index, so +it is the same precise radix analysis used at the scheduling layer -- not an FX +view-chain heuristic. The hook wraps the already-registered lowering entries (the +make_pointwise results), so it sees every elementwise consumer in one place. The +realize() (not a clone, which Inductor inlines) is what actually forces the buffer +boundary; see the PoC notes in docs. + +Gated by TORCHSIM_GRAPH_COPY (install() is a no-op otherwise). Behavior-neutral +unless a genuine incompatible-radix conflict is detected. +""" +import os +from torch._inductor import lowering as L +from torch._inductor import dependencies +from torch._inductor import ir +from torch._inductor.ir import TensorBox +from torch.utils._sympy.functions import FloorDiv, ModularIndexing + +from . import axis_split + + +def _has_multivar_floormod(exprs): + """True if any FloorDiv/ModularIndexing argument spans >1 loop variable + (case 7: cross-axis floor/mod that axis-split cannot split).""" + for e in exprs: + for f in list(e.atoms(FloorDiv)) + list(e.atoms(ModularIndexing)): + if len(f.args[0].free_symbols) > 1: + return True + return False + + +def _numel(tb): + n = 1 + for s in tb.get_size(): + v = axis_split._as_int(s) + if v is None: + return float("inf") + n *= v + return n + + +def _relayout_args(args): + """Return a modified args list with one operand replaced by a forced copy when + it needs relayout, or None to leave args unchanged. The copy uses + ExternKernel.copy_input (a realized identity Pointwise) -- this materializes + *views* too, unlike StorageBox.realize() which is a no-op on a ReinterpretView. + The copy kernel iterates the operand's own (contiguous) shape, so its index + collapses to single-var and axis-split handles it; the consumer then reads the + copy affine.""" + pos = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + if not pos: + return None + tbs = [args[i] for i in pos] + # Output/iteration shape = the broadcast of all operands (the largest rank, + # max per dim). For a single-operand consumer (e.g. a reduction reading a + # multi-var-view input) this is just that operand's shape -- still enough to + # detect a multi-var floor and copy_input it (case 7); the 2-operand radix + # conflict (case 5) naturally needs >=2 operands. + ranges = max((t.get_size() for t in tbs), key=len) + extents = [axis_split._as_int(s) for s in ranges] + dbg = os.environ.get("TORCHSIM_GRAPH_COPY_DEBUG") + if dbg: + print(f"[GC] consumer ntbs={len(tbs)} ranges={extents} " + f"sizes={[[axis_split._as_int(s) for s in t.get_size()] for t in tbs]}") + if not extents or any(e is None for e in extents): + return None # scalar / dynamic -> skip + + # Only true elementwise consumers: each operand is broadcast-compatible with the + # output (same rank, every dim is 1 or == the output extent). This admits + # broadcasting operands (e.g. y[8,1] into [8,3]) while excluding mm/bmm/cat-style + # ops whose operands differ in a non-broadcast way. + for tb in tbs: + sz = [axis_split._as_int(s) for s in tb.get_size()] + if len(sz) != len(extents) or any( + d is not None and d != 1 and d != e for d, e in zip(sz, extents) + ): + return None + + # Trace each operand's loader to get its read indices (sympy) over the shared + # output iteration; make_loader returns a value, so extract_read_writes is what + # gives the index expressions. range_vars are positional per output axis, so the + # axis numbering is consistent across operands. + per_bnd = [] # [{axis: boundary set}] per operand + per_mv = [] # [bool] operand has multi-var floor/mod + for tb in tbs: + try: + rw = dependencies.extract_read_writes(tb.make_loader(), list(ranges)) + except Exception as e: + if dbg: + print(f"[GC] extract fail {type(e).__name__}: {repr(e)[:60]}") + per_bnd.append({}) + per_mv.append(False) + continue + v2a = {v: i for i, v in enumerate(rw.range_vars)} + exprs = [r.index for r in rw.reads if hasattr(r, "index")] + b = axis_split.collect_boundaries(exprs, v2a, rw.var_ranges) + mv = _has_multivar_floormod(exprs) + if dbg: + print(f"[GC] operand reads={[str(e) for e in exprs]} boundaries={dict(b)} multivar={mv}") + per_bnd.append(b) + per_mv.append(mv) + + victim = None + + # Case 5 -- incompatible radices on a shared axis between two operands. + for axis, E in enumerate(extents): + contrib = [(i, per_bnd[i][axis]) for i in range(len(tbs)) if per_bnd[i].get(axis)] + if len(contrib) < 2: + continue # single grouping -> axis-split handles + union = {b for _, s in contrib for b in s} + if axis_split._is_chain(union, E): + continue # compatible -> axis-split handles + victim = min(contrib, key=lambda c: _numel(tbs[c[0]]))[0] + break + + # Case 7 -- an operand whose floor/mod argument spans multiple consumer axes + # (e.g. (3*p0+p1)//4 from a transpose+reshape feeding a broadcast/softmax that + # keeps the dims separate). axis-split cannot split a multi-var argument. + if victim is None: + mv_ops = [i for i in range(len(tbs)) if per_mv[i]] + if mv_ops: + victim = min(mv_ops, key=lambda i: _numel(tbs[i])) + + if victim is None: + return None + new = list(args) + p = pos[victim] + new[p] = ir.ExternKernel.copy_input(args[p]) + if dbg: + print(f"[GC] relayout: copy_input operand #{victim} (arg {p})") + return new + + +def install(): + """Wrap registered lowering entries to insert relayout. Idempotent; ON by + default (set TORCHSIM_GRAPH_COPY=0 to disable). Call once at backend import + (after torch._inductor.lowering is populated -- make_pointwise runs at import + to build the entries, so we wrap the entries, not the factory).""" + if os.environ.get("TORCHSIM_GRAPH_COPY", "1") == "0": + return + if getattr(L, "_torchsim_relayout_installed", False): + return + for key, fn in list(L.lowerings.items()): + def wrap(orig): + def wrapped(*a, **k): + try: + na = _relayout_args(a) + except Exception: + na = None # detection must never break lowering + if na is not None: + a = na + return orig(*a, **k) + return wrapped + L.lowerings[key] = wrap(fn) + L._torchsim_relayout_installed = True From e39bfc08e8359b5f7fb09111ee40bb2558a0212e Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 17:33:52 +0900 Subject: [PATCH 05/20] [Frontend] build_tog: port test-tile-operation-graph to Python Port the analysis and IR-mutation halves of the C++ test-tile-operation-graph pass to Python, wire build_tog into the gem5 path, and drop the C++ pass. Node-id counter is thread-local for concurrent compilation. --- PyTorchSimFrontend/mlir/passes/build_tog.py | 1139 +++++++++++++++++++ 1 file changed, 1139 insertions(+) create mode 100644 PyTorchSimFrontend/mlir/passes/build_tog.py diff --git a/PyTorchSimFrontend/mlir/passes/build_tog.py b/PyTorchSimFrontend/mlir/passes/build_tog.py new file mode 100644 index 00000000..ae515010 --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/build_tog.py @@ -0,0 +1,1139 @@ +"""Python port of the C++ `-test-tile-operation-graph` analysis pass. + +Reads a post-vcix MLIR file, walks `func.func @kernel`, and prints the Tile +Operation Graph (TOG) as a Python `graph = { id: {dict}, ... }` literal to +stdout. The output is byte-exact with the upstream C++ pass +(mlir/test/lib/Analysis/TestTileOperationGraph.cpp + the display() methods in +mlir/include/mlir/Analysis/TileOperationGraph.h). + +Usage: + python3 build_tog.py + +Requires the MLIR Python bindings on PYTHONPATH +(/riscv-llvm/python_packages/mlir_core). +""" + +import os +import sys + +# Allow running standalone without PYTHONPATH set. +_DEFAULT_BINDINGS = "/riscv-llvm/python_packages/mlir_core" +if os.path.isdir(_DEFAULT_BINDINGS) and _DEFAULT_BINDINGS not in sys.path: + sys.path.insert(0, _DEFAULT_BINDINGS) + +import mlir.ir as ir # noqa: E402 + + +# --------------------------------------------------------------------------- +# Node kinds / compute types (mirror TileOperationGraph.h enums). +# --------------------------------------------------------------------------- +BASE_NODE = 0 +COMPUTE_NODE = 1 +LOOP_NODE = 2 +DMA_NODE = 3 +DMA_WAIT_NODE = 4 + +VECTOR_COMPUTE = 0 +MATMUL_COMPUTE = 1 +MATMUL_PRELOAD = 2 + +# Printed compute_type string for each compute-node type (mirrors the C++ +# inline-asm marker emission in TestTileOperationGraph.cpp). +_COMPUTE_TYPE_NAME = { + VECTOR_COMPUTE: "VectorCompute", + MATMUL_COMPUTE: "MatmulCompute", + MATMUL_PRELOAD: "MatmulPreload", +} + + +# Unique-id counter (TOGNode::unique_id), incremented at construction. The C++ +# pass keeps this as a process global, but each kernel runs in its own mlir-opt +# PROCESS so the counters never collide. Here the pass runs in-process and +# extension_codecache compiles kernels CONCURRENTLY in a thread pool, so a single +# module global would race across threads (one kernel's reset/increments interleave +# with another's, leaving nodes -- including the root -- with wrong ids). Keep the +# counter thread-local so each compile thread has an isolated counter, mirroring +# the per-process isolation of the C++ pass. +import threading # noqa: E402 + +_ids = threading.local() + + +def _new_id(): + i = getattr(_ids, "counter", 0) + _ids.counter = i + 1 + return i + + +def _reset_ids(): + _ids.counter = 0 + + +# --------------------------------------------------------------------------- +# Node classes with display() mirroring the C++ header byte-for-byte. +# --------------------------------------------------------------------------- +class TOGNode: + kind = BASE_NODE + + def __init__(self, name, node_id=None): + self.node_id = _new_id() if node_id is None else node_id + self.node_name = name + self.parents = [] + self.children = [] + self.op = None + + def get_last_child(self): + return self.children[-1] if self.children else None + + def add_child(self, c): + self.children.append(c) + + def add_parent(self, p): + self.parents.append(p) + + @staticmethod + def _print_list(name, vec, out): + out.append('\t"%s": [' % name) + out.append(",".join(str(v) for v in vec)) + out.append("]") + + @staticmethod + def _print_str_list(name, vec, out): + out.append('\t"%s": [' % name) + out.append(",".join('"%s"' % v for v in vec)) + out.append("]") + + def _print_node_list(self, name, vec, out): + out.append('\t"%s": [' % name) + out.append(",".join(str(n.node_id) for n in vec)) + out.append("]") + + def display(self, out): + out.append("%d : {\n" % self.node_id) + out.append('\t"node_id": %d,\n' % self.node_id) + out.append('\t"node_name": "%s",\n' % self.node_name) + out.append('\t"node_type": %d,\n' % self.kind) + self._print_node_list("parents", self.parents, out) + out.append(",\n") + self._print_node_list("children", self.children, out) + if self.kind == BASE_NODE: + out.append("\n}\n") + + def bfs(self, out): + if self.node_name == "root": + out.append("graph = {\n") + self.display(out) + out.append(",") + for child in self.children: + child.bfs(out) + if self.node_name == "root": + out.append("}") + + +class TOGLoopNode(TOGNode): + kind = LOOP_NODE + + def __init__(self, name, idx, start, end, step, loop_type): + super().__init__(name) + self.loop_idx = idx + self.loop_start = start + self.loop_end = end + self.loop_step = step + self.loop_type = loop_type + + def display(self, out): + TOGNode.display(self, out) + out.append(",\n") + out.append('\t"loop_index": "%s",\n' % self.loop_idx) + out.append('\t"loop_start": %d,\n' % self.loop_start) + out.append('\t"loop_end": %d,\n' % self.loop_end) + out.append('\t"loop_step": %d,\n' % self.loop_step) + out.append('\t"loop_type": "%s"\n' % self.loop_type) + out.append("}\n") + + +class TOGDMANode(TOGNode): + kind = DMA_NODE + + def __init__(self, name, addr, tile_size, tile_stride, elem_size, is_write, + is_async, tag_idx_list, tag_stride_list, loop_idx_list, + loop_stride_list, indirect_mode): + super().__init__(name) + self.base_addr = addr + self.tile_size = tile_size + self.tile_stride = tile_stride + self.element_size = elem_size + self.is_write = is_write + self.is_async = is_async + self.tag_idx_list = tag_idx_list + self.tag_stride_list = tag_stride_list + self.loop_idx_list = loop_idx_list + self.loop_stride_list = loop_stride_list + self.indirect_mode = indirect_mode + + def display(self, out): + TOGNode.display(self, out) + out.append(",\n") + out.append('\t"is_write": %d,\n' % int(self.is_write)) + out.append('\t"is_async": %d,\n' % int(self.is_async)) + out.append('\t"base_address": "%s",\n' % self.base_addr) + out.append('\t"indirect_mode": %d,\n' % int(self.indirect_mode)) + self._print_list("tile_size", self.tile_size, out) + out.append(",\n") + self._print_list("tile_stride", self.tile_stride, out) + out.append(",\n") + self._print_str_list("tag_idx_list", self.tag_idx_list, out) + out.append(",\n") + self._print_list("tag_stride_list", self.tag_stride_list, out) + out.append(",\n") + self._print_str_list("loop_idx_list", self.loop_idx_list, out) + out.append(",\n") + self._print_list("loop_stride_list", self.loop_stride_list, out) + out.append(",\n") + out.append('\t"element_size": %d\n' % self.element_size) + out.append("}\n") + + +class TOGDMAWaitNode(TOGNode): + kind = DMA_WAIT_NODE + + def __init__(self, name, idx_list, tag_stride_list, tag_divider_list, addr): + super().__init__(name) + self.tag_idx_list = idx_list + self.tag_stride_list = tag_stride_list + self.tag_divider_list = tag_divider_list + self.base_addr = addr + + def display(self, out): + TOGNode.display(self, out) + out.append(",\n") + out.append('\t"base_address": "%s",\n' % self.base_addr) + self._print_str_list("tag_idx_list", self.tag_idx_list, out) + out.append(",\n") + self._print_list("tag_stride_list", self.tag_stride_list, out) + out.append(",\n") + self._print_list("tag_divider_list", self.tag_divider_list, out) + out.append("}\n") + + +class TOGComputeNode(TOGNode): + kind = COMPUTE_NODE + + def __init__(self, name, cycle, ctype): + super().__init__(name) + self.compute_cycle = cycle + self.compute_type = ctype + self.operations = [] + + def display(self, out): + TOGNode.display(self, out) + out.append(",\n") + out.append('\t"compute_cycle": %d,\n' % self.compute_cycle) + out.append('\t"compute_type": %d\n' % self.compute_type) + out.append("}\n") + + +# --------------------------------------------------------------------------- +# MLIR helpers. +# --------------------------------------------------------------------------- +def _op_name(value_or_op): + return value_or_op.operation.name + + +def _is_block_arg(value): + return ir.BlockArgument.isinstance(value) + + +def _defining_op_name(value): + """Return the op name that defines `value`, or None if it is a block arg.""" + if _is_block_arg(value): + return None + return value.owner.name + + +def _const_index_value(value): + """If `value` is an arith.constant of index type, return its int value.""" + if _is_block_arg(value): + return None + owner = value.owner + if owner.name != "arith.constant": + return None + attr = owner.attributes["value"] + try: + return ir.IntegerAttr(attr).value + except Exception: + return None + + +def _value_key(value): + """Identity key for a Value (induction vars map -> loop name).""" + if _is_block_arg(value): + ba = ir.BlockArgument(value) + return ("ba", ba.owner, ba.arg_number) + return ("res", value.owner, 0) + + +def _memref_space(memref_type): + mt = ir.MemRefType(memref_type) + sp = mt.memory_space + if sp is None: + return 0 + return ir.IntegerAttr(sp).value + + +def _int_array_attr(op, key): + if key not in op.attributes: + return None + arr = ir.ArrayAttr(op.attributes[key]) + return [ir.IntegerAttr(x).value for x in arr] + + +# ----- affine expr utilities (mirror the C++ free functions) --------------- +def _is_function_of_dim(expr, dim): + if ir.AffineDimExpr.isinstance(expr): + return ir.AffineDimExpr(expr).position == dim + if ir.AffineBinaryExpr.isinstance(expr): + b = ir.AffineBinaryExpr(expr) + return _is_function_of_dim(b.lhs, dim) or _is_function_of_dim(b.rhs, dim) + return False + + +def _get_coefficient_from_dim(expr, dim): + """Port of getCoefficientFromDim: coefficient of `dim` in expr, or -1.""" + if ir.AffineMulExpr.isinstance(expr): + b = ir.AffineMulExpr(expr) + lhs, rhs = b.lhs, b.rhs + if ir.AffineConstantExpr.isinstance(lhs) and ir.AffineDimExpr.isinstance(rhs): + if _is_function_of_dim(rhs, dim): + return ir.AffineConstantExpr(lhs).value + elif ir.AffineConstantExpr.isinstance(rhs) and ir.AffineDimExpr.isinstance(lhs): + if _is_function_of_dim(lhs, dim): + return ir.AffineConstantExpr(rhs).value + return -1 + if ir.AffineAddExpr.isinstance(expr): + b = ir.AffineAddExpr(expr) + r = _get_coefficient_from_dim(b.lhs, dim) + if r != -1: + return r + r = _get_coefficient_from_dim(b.rhs, dim) + if r != -1: + return r + return -1 + if ir.AffineDimExpr.isinstance(expr): + if _is_function_of_dim(expr, dim): + return 1 + return -1 + + +def _collect_coefficients(expr): + """Port of collectCoefficientsFromAffineExpr.""" + out = [] + + def rec(e): + if ir.AffineMulExpr.isinstance(e): + b = ir.AffineMulExpr(e) + lhs, rhs = b.lhs, b.rhs + if ir.AffineConstantExpr.isinstance(lhs): + if ir.AffineDimExpr.isinstance(rhs): + out.append(ir.AffineConstantExpr(lhs).value) + elif ir.AffineFloorDivExpr.isinstance(rhs): + out.append(ir.AffineConstantExpr(lhs).value) + elif ir.AffineConstantExpr.isinstance(rhs): + if ir.AffineDimExpr.isinstance(lhs): + out.append(ir.AffineConstantExpr(rhs).value) + elif ir.AffineFloorDivExpr.isinstance(lhs): + out.append(ir.AffineConstantExpr(rhs).value) + elif ir.AffineAddExpr.isinstance(e): + b = ir.AffineAddExpr(e) + rec(b.lhs) + rec(b.rhs) + elif ir.AffineFloorDivExpr.isinstance(e): + b = ir.AffineFloorDivExpr(e) + if ir.AffineConstantExpr.isinstance(b.rhs): + out.append(ir.AffineConstantExpr(b.rhs).value) + elif ir.AffineDimExpr.isinstance(e): + out.append(1) + + rec(expr) + return out + + +def _collect_dividers(expr): + """Port of collectDividersFromAffineExpr.""" + out = [] + + def rec(e): + if ir.AffineMulExpr.isinstance(e): + b = ir.AffineMulExpr(e) + lhs, rhs = b.lhs, b.rhs + if ir.AffineConstantExpr.isinstance(lhs): + if ir.AffineDimExpr.isinstance(rhs): + out.append(1) + elif ir.AffineFloorDivExpr.isinstance(rhs): + rec(rhs) + elif ir.AffineConstantExpr.isinstance(rhs): + if ir.AffineDimExpr.isinstance(lhs): + out.append(1) + elif ir.AffineFloorDivExpr.isinstance(lhs): + rec(lhs) + elif ir.AffineAddExpr.isinstance(e): + b = ir.AffineAddExpr(e) + rec(b.lhs) + rec(b.rhs) + elif ir.AffineFloorDivExpr.isinstance(e): + b = ir.AffineFloorDivExpr(e) + if ir.AffineConstantExpr.isinstance(b.rhs): + out.append(ir.AffineConstantExpr(b.rhs).value) + elif ir.AffineDimExpr.isinstance(e): + out.append(1) + + rec(expr) + return out + + +# --------------------------------------------------------------------------- +# The builder (mirrors TestTileOperationGraph members). +# --------------------------------------------------------------------------- +SKIP_OPS = { + "affine.yield", "affine.apply", "memref.get_global", "arith.constant", + "memref.alloc", "memref.reinterpret_cast", +} + + +class MatmulFsm: + Idle = 0 + Preload = 1 + MMPush = 2 + MMVpop = 3 + MMEpilogue = 4 + + +class TogBuilder: + def __init__(self): + self.nr_loop = 0 + self.loop_var_name = {} # value-identity-key -> loop name + self.compute_nodes = [] + self.loop_nodes = [] + self._reset_matmul_fsm() + + # ---- matmul FSM ---- + def _reset_matmul_fsm(self): + self.matmul_fsm = MatmulFsm.Idle + self.mm_iv_op0_count = 0 + self.mm_expected_vpop = 0 + self.mm_vpop_seen = 0 + self.mm_expected_transfer_writes = 0 + self.mm_transfer_writes_seen = 0 + self.current_preload_node = None + self.current_matmul_compute_node = None + + def _finish_matmul_block(self): + self.matmul_fsm = MatmulFsm.Idle + self.current_matmul_compute_node = None + self.mm_iv_op0_count = 0 + self.mm_expected_vpop = 0 + self.mm_vpop_seen = 0 + self.mm_expected_transfer_writes = 0 + self.mm_transfer_writes_seen = 0 + + def _enter_epilogue_or_finish(self): + if self.mm_expected_transfer_writes == 0: + self._finish_matmul_block() + else: + self.matmul_fsm = MatmulFsm.MMEpilogue + + def _append_mm_with_write_count(self, op): + cn = self.current_matmul_compute_node + if cn is None: + return + cn.operations.append(op) + if _op_name(op) != "vector.transfer_write": + return + if self.mm_expected_transfer_writes == 0: + return + self.mm_transfer_writes_seen += 1 + if self.mm_transfer_writes_seen >= self.mm_expected_transfer_writes: + self._finish_matmul_block() + + def _steal_leading_transfer_read(self, parent): + prepend = [] + if parent is None: + return prepend + last = parent.get_last_child() + if last is None or not isinstance(last, TOGComputeNode): + return prepend + if last.compute_type != VECTOR_COMPUTE: + return prepend + ops = last.operations + idx = None + for i, o in enumerate(ops): + if _op_name(o) == "vector.transfer_read": + idx = i + break + if idx is None: + return prepend + prepend.append(ops[idx]) + del ops[idx] + if not ops: + if parent.children and parent.children[-1] is last: + parent.children.pop() + else: + parent.children = [c for c in parent.children if c is not last] + if last in self.compute_nodes: + self.compute_nodes.remove(last) + # node deleted; its id stays consumed (matches C++ unique_id). + return prepend + + def _append_vector_compute(self, parent, op): + if parent is None: + return + last = parent.get_last_child() + if (isinstance(last, TOGComputeNode) + and last.compute_type == VECTOR_COMPUTE): + last.operations.append(op) + return + cn = TOGComputeNode("ComputeNode", 0, VECTOR_COMPUTE) + cn.op = op + cn.operations.append(op) + parent.add_child(cn) + cn.add_parent(parent) + self.compute_nodes.append(cn) + + # ---- vcix classification ---- + @staticmethod + def _vcix_iv_opcode(op): + if _op_name(op) != "vcix.iv": + return None + if "opcode" not in op.operation.attributes: + return None + return ir.IntegerAttr(op.operation.attributes["opcode"]).value + + @staticmethod + def _vcix_vi_opcode(op): + if _op_name(op) != "vcix.v.i": + return None + if "opcode" not in op.operation.attributes: + return None + return ir.IntegerAttr(op.operation.attributes["opcode"]).value + + # ---- affine.for bounds ---- + @staticmethod + def _affine_for_bounds(forop): + oper = forop.operation + step = ir.IntegerAttr(oper.attributes["step"]).value + lb_map = ir.AffineMapAttr(oper.attributes["lowerBoundMap"]).value + ub_map = ir.AffineMapAttr(oper.attributes["upperBoundMap"]).value + # operandSegmentSizes: [lb operands, ub operands, iter operands] + seg_attr = oper.attributes["operandSegmentSizes"] + seg = [seg_attr[i] for i in range(len(seg_attr))] + operands = list(oper.operands) + + def single_const(m): + return len(m.results) == 1 and ir.AffineConstantExpr.isinstance(m.results[0]) + + if single_const(lb_map): + start = ir.AffineConstantExpr(lb_map.results[0]).value + else: + start = _const_index_value(operands[0]) + n_lb = seg[0] if seg else 0 + if single_const(ub_map): + end = ir.AffineConstantExpr(ub_map.results[0]).value + else: + end = _const_index_value(operands[n_lb]) + return start, end, step + + # ---- DRAM index processing ---- + def _process_dram_indices(self, value, loop_index_list, indirect_box): + if not _is_block_arg(value) and value.owner.name == "affine.apply": + apply_op = value.owner + amap = ir.AffineMapAttr(apply_op.attributes["map"]).value + operands = list(apply_op.operands) + if "indirect_access" in apply_op.attributes: + indirect_box[0] = True + for op_v in operands: + if _is_block_arg(op_v): + expr = amap.results[0] + index_pos = operands.index(op_v) + coeff = _get_coefficient_from_dim(expr, index_pos) + key = self.loop_var_name[_value_key(op_v)] + loop_index_list.append((key, coeff)) + else: + self._process_dram_indices(op_v, loop_index_list, indirect_box) + elif _is_block_arg(value): + key = self.loop_var_name[_value_key(value)] + loop_index_list.append((key, 1)) + else: + c = _const_index_value(value) + if c is not None: + loop_index_list.append(("c" + str(c), c)) + + # ---- main recursion ---- + def print_operation(self, op, node): + name = _op_name(op) + if name in SKIP_OPS: + return + + if name == "affine.for": + oper = op.operation + attrs = oper.attributes + loop_type = "" + + def bool_true(k): + return k in attrs and ir.BoolAttr(attrs[k]).value + + if bool_true("outer_loop"): + loop_type = "outer_loop" + elif bool_true("accumulation_loop"): + loop_type = "accumulation_loop" + elif bool_true("inner_loop"): + loop_type = "inner_loop" + + if (bool_true("outer_loop") or bool_true("accumulation_loop") + or bool_true("inner_loop")): + start, end, step = self._affine_for_bounds(op) + loop_index = "loop_arg%03d" % self.nr_loop + self.nr_loop += 1 + iter_var = oper.regions[0].blocks[0].arguments[0] + self.loop_var_name[_value_key(iter_var)] = loop_index + loop_node = TOGLoopNode("loopNode", loop_index, start, end, step, + loop_type) + loop_node.op = op + self.loop_nodes.append(loop_node) + if node is not None: + node.add_child(loop_node) + loop_node.add_parent(node) + for region in oper.regions: + for block in region.blocks: + for inner in block.operations: + self.print_operation(inner, loop_node) + return + + if name == "memref.dma_start": + self._handle_dma_start(op, node) + return + + if name == "memref.dma_wait": + self._handle_dma_wait(op, node) + return + + if node is None: + return + + self._handle_compute(op, node) + + # ---- compute / matmul FSM dispatch ---- + def _handle_compute(self, op, node): + if _op_name(op) == "vcix.iv": + if (self.matmul_fsm == MatmulFsm.MMEpilogue + and self.current_matmul_compute_node): + self._finish_matmul_block() + opc = self._vcix_iv_opcode(op) + if opc is not None: + if opc == 1: + if self.matmul_fsm == MatmulFsm.Idle: + prepend = self._steal_leading_transfer_read(node) + cn = TOGComputeNode("ComputeNode", 0, MATMUL_PRELOAD) + for o in prepend: + cn.operations.append(o) + cn.operations.append(op) + cn.op = cn.operations[0] + node.add_child(cn) + cn.add_parent(node) + self.compute_nodes.append(cn) + self.current_preload_node = cn + self.matmul_fsm = MatmulFsm.Preload + elif (self.matmul_fsm == MatmulFsm.Preload + and self.current_preload_node): + self.current_preload_node.operations.append(op) + else: + raise RuntimeError("vcix.iv opcode=1 invalid state") + return + if opc == 0: + if (self.matmul_fsm == MatmulFsm.Preload + and self.current_preload_node): + cn = TOGComputeNode("ComputeNode", 0, MATMUL_COMPUTE) + cn.op = op + cn.operations.append(op) + node.add_child(cn) + cn.add_parent(node) + self.compute_nodes.append(cn) + self.current_preload_node = None + self.current_matmul_compute_node = cn + self.matmul_fsm = MatmulFsm.MMPush + self.mm_iv_op0_count = 1 + return + if (self.matmul_fsm == MatmulFsm.MMPush + and self.current_matmul_compute_node): + self.current_matmul_compute_node.operations.append(op) + self.mm_iv_op0_count += 1 + return + raise RuntimeError("vcix.iv opcode=0 invalid state") + if (self.matmul_fsm == MatmulFsm.MMEpilogue + and self.current_matmul_compute_node): + self._append_mm_with_write_count(op) + return + self._append_vector_compute(node, op) + return + + if _op_name(op) == "vcix.v.i": + viopc = self._vcix_vi_opcode(op) + if viopc is not None and viopc == 2: + if (self.matmul_fsm == MatmulFsm.MMPush + and self.current_matmul_compute_node): + self.mm_expected_vpop = self.mm_iv_op0_count + self.mm_expected_transfer_writes = self.mm_expected_vpop + self.mm_transfer_writes_seen = 0 + self.matmul_fsm = MatmulFsm.MMVpop + self.mm_vpop_seen = 1 + self.current_matmul_compute_node.operations.append(op) + if self.mm_vpop_seen >= self.mm_expected_vpop: + self._enter_epilogue_or_finish() + return + if (self.matmul_fsm == MatmulFsm.MMVpop + and self.current_matmul_compute_node): + self.current_matmul_compute_node.operations.append(op) + self.mm_vpop_seen += 1 + if self.mm_vpop_seen >= self.mm_expected_vpop: + self._enter_epilogue_or_finish() + return + if (self.matmul_fsm == MatmulFsm.MMEpilogue + and self.current_matmul_compute_node): + self._append_mm_with_write_count(op) + return + self._append_vector_compute(node, op) + return + + if self.matmul_fsm == MatmulFsm.Preload and self.current_preload_node: + self.current_preload_node.operations.append(op) + return + if (self.matmul_fsm == MatmulFsm.MMEpilogue + and self.current_matmul_compute_node): + self._append_mm_with_write_count(op) + return + if (self.matmul_fsm in (MatmulFsm.MMPush, MatmulFsm.MMVpop) + and self.current_matmul_compute_node): + self._append_mm_with_write_count(op) + return + self._append_vector_compute(node, op) + + # ---- dma_start ---- + def _dma_start_fields(self, op): + """Decode memref.dma_start operands by memref ranks. + + Layout: src[srcIdx], dst[dstIdx], numElements, tag[tagIdx], stride, + numElementsPerStride. + """ + operands = list(op.operands) + i = 0 + src = operands[i] + src_type = src.type + src_rank = len(ir.MemRefType(src_type).shape) + i += 1 + src_indices = operands[i:i + src_rank] + i += src_rank + dst = operands[i] + dst_type = dst.type + dst_rank = len(ir.MemRefType(dst_type).shape) + i += 1 + dst_indices = operands[i:i + dst_rank] + i += dst_rank + i += 1 # numElements + tag = operands[i] + tag_rank = len(ir.MemRefType(tag.type).shape) + i += 1 + tag_indices = operands[i:i + tag_rank] + return { + "src": src, "src_type": src_type, "src_indices": src_indices, + "dst": dst, "dst_type": dst_type, "dst_indices": dst_indices, + "tag": tag, "tag_indices": tag_indices, + } + + def _handle_dma_start(self, op, node): + oper = op.operation + f = self._dma_start_fields(op) + dma_async = bool(self._get_async(oper)) + + dst_space = _memref_space(f["dst_type"]) + src_space = _memref_space(f["src_type"]) + + subtile = _int_array_attr(oper, "subtile_size") + + if dst_space == 0 and src_space == 1: + is_write = True + tile_type = f["src_type"] + dram_memref = f["dst"] + dram_indices = f["dst_indices"] + elif dst_space == 1 and src_space == 0: + is_write = False + tile_type = f["dst_type"] + dram_memref = f["src"] + dram_indices = f["src_indices"] + else: + raise RuntimeError("Unexpected memory space") + + tile_mt = ir.MemRefType(tile_type) + tile_shape = subtile if subtile else list(tile_mt.shape) + tile_size = [int(x) for x in tile_shape] + + tile_stride = [] + ds = _int_array_attr(oper, "dram_stride") + if ds: + tile_stride = list(ds) + + loop_index_map = [] + indirect_box = [False] + self._process_dram_indices(dram_indices[0], loop_index_map, indirect_box) + + # std::map => dedup + lexicographic sort by key. + reordered = {} + for key, stride in loop_index_map: + if key not in reordered: + reordered[key] = stride + loop_idx_list = [] + loop_stride_list = [] + for key in sorted(reordered.keys()): + loop_idx_list.append(key) + loop_stride_list.append(reordered[key]) + + # base address + address = "arg" + if _is_block_arg(dram_memref): + address += str(ir.BlockArgument(dram_memref).arg_number) + + # element size + et = tile_mt.element_type + if ir.IntegerType.isinstance(et): + element_size = ir.IntegerType(et).width + elif ir.FloatType.isinstance(et): + element_size = ir.FloatType(et).width + else: + raise RuntimeError("Unsupported element type") + + # tag indices + tag_index_list = [] + tag_stride_list = [] + for tag_idx in f["tag_indices"]: + c = _const_index_value(tag_idx) + if c is not None: + tag_index_list.append(str(c)) + continue + if not _is_block_arg(tag_idx) and tag_idx.owner.name == "affine.apply": + apply_op = tag_idx.owner + for operand in apply_op.operands: + if _is_block_arg(operand): + tag_index_list.append( + self.loop_var_name[_value_key(operand)]) + elif (not _is_block_arg(operand) + and operand.owner.name == "affine.apply"): + nested = operand.owner + for nop in nested.operands: + if _is_block_arg(nop): + tag_index_list.append( + self.loop_var_name[_value_key(nop)]) + else: + nc = _const_index_value(nop) + if nc is not None: + tag_index_list.append(str(nc)) + else: + oc = _const_index_value(operand) + if oc is not None: + tag_index_list.append(str(oc)) + amap = ir.AffineMapAttr(apply_op.attributes["map"]).value + tag_stride_list = _collect_coefficients(amap.results[0]) + + if len(tag_index_list) == 0: + tag_index_list.append("0") + if len(tag_stride_list) == 0: + tag_stride_list.append(1) + + dma_node = TOGDMANode("DMANode", address, tile_size, tile_stride, + element_size, is_write, dma_async, tag_index_list, + tag_stride_list, loop_idx_list, loop_stride_list, + indirect_box[0]) + dma_node.op = op + node.add_child(dma_node) + dma_node.add_parent(node) + + # ---- dma_wait ---- + def _handle_dma_wait(self, op, node): + oper = op.operation + operands = list(oper.operands) + tag = operands[0] + tag_rank = len(ir.MemRefType(tag.type).shape) + tag_indices = operands[1:1 + tag_rank] + + tag_index_list = [] + tag_stride_list = [] + tag_divider_list = [] + for tag_idx in tag_indices: + if not _is_block_arg(tag_idx) and tag_idx.owner.name == "affine.apply": + apply_op = tag_idx.owner + for operand in apply_op.operands: + if _is_block_arg(operand): + tag_index_list.append( + self.loop_var_name[_value_key(operand)]) + else: + c = _const_index_value(operand) + if c is not None: + tag_index_list.append(str(c)) + amap = ir.AffineMapAttr(apply_op.attributes["map"]).value + tag_stride_list = _collect_coefficients(amap.results[0]) + tag_divider_list = _collect_dividers(amap.results[0]) + + # base address: scan users of tag memref for a dma_start. + address = "arg" + for use in tag.uses: + user = use.owner + if user.name == "memref.dma_start": + f = self._dma_start_fields(user) + dst_space = _memref_space(f["dst_type"]) + src_space = _memref_space(f["src_type"]) + dram_memref = None + if dst_space == 0 and src_space == 1: + dram_memref = f["dst"] + elif dst_space == 1 and src_space == 0: + dram_memref = f["src"] + if dram_memref is not None and _is_block_arg(dram_memref): + address += str(ir.BlockArgument(dram_memref).arg_number) + + if len(tag_stride_list) == 0: + tag_stride_list.append(1) + tag_divider_list.append(1) + + wait_node = TOGDMAWaitNode("DMAWaitNode", tag_index_list, tag_stride_list, + tag_divider_list, address) + wait_node.op = op + node.add_child(wait_node) + wait_node.add_parent(node) + + @staticmethod + def _get_async(oper): + if "async" not in oper.attributes: + return 0 + attr = oper.attributes["async"] + try: + return ir.IntegerAttr(attr).value + except Exception: + pass + try: + return 1 if ir.BoolAttr(attr).value else 0 + except Exception: + return 1 + + +# --------------------------------------------------------------------------- +# Empty affine.for erasure (fixed point), mirroring eraseEmptyAffineForLoops. +# --------------------------------------------------------------------------- +def _is_empty_affine_for(op): + if op.operation.name != "affine.for": + return False + body = op.operation.regions[0].blocks[0] + ops = list(body.operations) + # body holds exactly the terminator (affine.yield) and nothing else. + if len(ops) != 1: + return False + return ops[0].operation.name == "affine.yield" + + +def _erase_empty_affine_for(func_op): + changed = True + while changed: + changed = False + to_erase = [] + + def walk(block): + for op in block.operations: + for region in op.operation.regions: + for b in region.blocks: + walk(b) + if _is_empty_affine_for(op): + to_erase.append(op) + + walk(func_op.regions[0].blocks[0]) + for op in to_erase: + op.operation.erase() + changed = True + + +# --------------------------------------------------------------------------- +# IR mutation (mirrors the side effects of -test-tile-operation-graph). +# --------------------------------------------------------------------------- +def _make_inline_asm(ctx, ip, asm_string, compute_type=None): + """Build an llvm.inline_asm compute marker identical to the C++ pass.""" + i64 = ir.IntegerType.get_signless(64) + attrs = { + "has_side_effects": ir.UnitAttr.get(), + "asm_dialect": ir.IntegerAttr.get(i64, 0), + "asm_string": ir.StringAttr.get(asm_string), + "constraints": ir.StringAttr.get("~{a0},~{memory}"), + } + if compute_type is not None: + attrs["compute_type"] = ir.StringAttr.get(compute_type) + return ir.Operation.create( + "llvm.inline_asm", results=[], operands=[], attributes=attrs, + loc=ir.Location.unknown(ctx), ip=ip) + + +def _containing_block(op): + """Return the Block that directly contains `op`.""" + parent = op.operation.parent + for region in parent.regions: + for block in region.blocks: + for sib in block.operations: + if sib.operation == op.operation: + return block + return None + + +def _next_sibling(op): + """Return the op immediately following `op` in its block, or None.""" + block = _containing_block(op) + siblings = list(block.operations) + for i, sib in enumerate(siblings): + if sib.operation == op.operation: + return siblings[i + 1] if i + 1 < len(siblings) else None + return None + + +_ASM_START = ".insn r CUSTOM_3, 0, 0x40, x0, x0, x0" +_ASM_END = ".insn r CUSTOM_3, 0, 0x41, x0, x0, x0" + + +def _rewrite_loop_steps(builder): + """Sample mode: rewrite each loop node's affine.for step to its end. + + Done after the graph is built so the printed graph keeps the original step. + """ + index_t = ir.IndexType.get() + for node in builder.loop_nodes: + forop = node.op.operation + new_step = ir.IntegerAttr.get(index_t, node.loop_end) + forop.attributes["step"] = new_step + + +def _insert_compute_markers(builder): + """Insert inline-asm start/end markers around each compute node's ops.""" + for node in builder.compute_nodes: + ops = node.operations + if not ops: + continue + front = ops[0] + back = ops[-1] + ctx = front.operation.context + ctype_name = _COMPUTE_TYPE_NAME[node.compute_type] + # End marker immediately after `back`: insert before back's next sibling + # (or append to the block when back is the last op). The bindings have no + # "insert after" point, and move_after on a detached op crashes. + after = _next_sibling(back) + if after is not None: + end_ip = ir.InsertionPoint(after) + else: + end_ip = ir.InsertionPoint(_containing_block(back)) + _make_inline_asm(ctx, end_ip, _ASM_END) + # Start marker immediately before `front`. Done after the end marker so + # that the single-op case (front == back) still brackets the op. + _make_inline_asm(ctx, ir.InsertionPoint(front), _ASM_START, ctype_name) + + +# --------------------------------------------------------------------------- +# Driver. +# --------------------------------------------------------------------------- +def _find_kernel(module): + for op in module.body.operations: + if op.operation.name != "func.func": + continue + if ir.StringAttr(op.operation.attributes["sym_name"]).value == "kernel": + return op + return None + + +def _build(module, builder): + """Build the graph and return its display string, populating `builder`.""" + func_op = _find_kernel(module) + if func_op is None: + return "" + + _erase_empty_affine_for(func_op) + + block = func_op.regions[0].blocks[0] + out = [] + for op in block.operations: + if op.operation.name != "affine.for": + continue + root = TOGNode("root") + builder._reset_matmul_fsm() + builder.print_operation(op, root) + root.bfs(out) + return "".join(out) + + +def build_tog_string(module): + # The C++ unique_id is a process global; each mlir-opt invocation starts at + # 0. When this runs in-process per kernel, reset so node ids match byte-exact. + _reset_ids() + return _build(module, TogBuilder()) + + +def build_tog_and_mutate(module, sample_mode=True): + """Build the TOG and apply the C++ pass IR mutation in place. + + Returns the graph display string (byte-exact with build_tog_string). The + caller is responsible for serializing the mutated module (str(module)). + """ + _reset_ids() + builder = TogBuilder() + result = _build(module, builder) + if sample_mode: + _rewrite_loop_steps(builder) + _insert_compute_markers(builder) + return result + + +def run_tog(in_path, tog_out_path, custom_out_path, sample_mode=True, vectorlane=128): + """In-process replacement for the C++ -test-tile-operation-graph pass. + + Reads the post-vcix MLIR at `in_path`, builds the TOG (written as the + `graph = {...}` Python literal to `tog_out_path`) and applies the pass IR + mutation (sample-mode step rewrite + compute markers), writing the mutated + module to `custom_out_path`. `vectorlane` is accepted for parity with the + C++ option; it does not affect the output. Requires the MLIR bindings. + """ + ctx = ir.Context() + ctx.allow_unregistered_dialects = True + with ctx: + module = ir.Module.parse(open(in_path).read(), ctx) + graph = build_tog_and_mutate(module, sample_mode=bool(sample_mode)) + with open(custom_out_path, "w") as fh: + fh.write(str(module)) + with open(tog_out_path, "w") as fh: + fh.write(graph) + + +def main(argv): + import argparse + + parser = argparse.ArgumentParser(prog="build_tog.py") + parser.add_argument("input") + parser.add_argument("--sample-mode", type=int, default=1, choices=(0, 1)) + parser.add_argument("--vectorlane", type=int, default=128) + parser.add_argument("--mutate-out", default=None) + args = parser.parse_args(argv[1:]) + + with open(args.input) as fh: + src = fh.read() + ctx = ir.Context() + ctx.allow_unregistered_dialects = True + with ctx: + module = ir.Module.parse(src, ctx) + if args.mutate_out is not None: + result = build_tog_and_mutate(module, sample_mode=bool(args.sample_mode)) + with open(args.mutate_out, "w") as fh: + fh.write(str(module)) + else: + result = build_tog_string(module) + sys.stdout.write(result) + return 0 + + +if __name__ == "__main__": + sys.exit(main(sys.argv)) From 3c56c6be04111e4963a87d71919d8714b7c1643f Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 17:33:52 +0900 Subject: [PATCH 06/20] [Test] floor/mod axis-split + graph-copy coverage; deterministic deepseek seed Add tests/ops/view/test_floormod_axis_split.py covering axis-split and graph-copy patterns. Seed the global RNG in the deepseek base test so config-random weights are deterministic. --- .../models/DeepSeek/test_deepseek_v3_base.py | 5 + tests/ops/view/test_floormod_axis_split.py | 122 ++++++++++++++++++ 2 files changed, 127 insertions(+) create mode 100644 tests/ops/view/test_floormod_axis_split.py diff --git a/tests/models/DeepSeek/test_deepseek_v3_base.py b/tests/models/DeepSeek/test_deepseek_v3_base.py index 2941e776..79160127 100644 --- a/tests/models/DeepSeek/test_deepseek_v3_base.py +++ b/tests/models/DeepSeek/test_deepseek_v3_base.py @@ -230,6 +230,11 @@ def run_deepseek_v3_base( config.quantization_config = None config = _maybe_scale_config(config, scale=scale, max_layers=max_layers) + # Seed the global RNG so config-random weight init is deterministic. Without + # this every run builds a different network, so the worst-element NPU-vs-CPU + # error randomly crosses the (loose) allclose threshold and the test is flaky. + torch.manual_seed(0) + if init_mode == "config-random": model = AutoModelForCausalLM.from_config( config=config, diff --git a/tests/ops/view/test_floormod_axis_split.py b/tests/ops/view/test_floormod_axis_split.py new file mode 100644 index 00000000..10ebd114 --- /dev/null +++ b/tests/ops/view/test_floormod_axis_split.py @@ -0,0 +1,122 @@ +"""Floor/mod index handling: axis-split (aligned) + graph-copy (incompatible). + +Covers the index-expression shapes that view/reshape/tile/group ops produce and +how the frontend handles them: + + - aligned floor/mod (single iter var, divisor divides extent): removed by + axis-split at the scheduling layer (TORCHSIM_AXIS_SPLIT). group_norm, repeat, + repeat_interleave, permute+reshape (mixed-radix). + - incompatible radices on a shared axis (case 5, e.g. a[c//2] + b[c%3]): the + conflicting operand is realized by graph-copy (TORCHSIM_GRAPH_COPY) so the + consumer reads it affine and the remainder is axis-split's. + - cross-axis / multi-variable floor/mod argument (case 7, e.g. (3*p0+p1)//4 from + a transpose+reshape feeding a broadcast/softmax/layernorm that keeps the dims + separate): graph-copy materializes the multi-var operand with copy_input (which + forces a copy of a view, unlike realize()); the copy kernel iterates the + operand's own shape so its index collapses to single-var for axis-split. + +The features are env-gated; this test turns them on for itself. axis-split is read +per kernel from the env; graph-copy installs its lowering hook at import, so we +re-run install() after setting the flag. + +Not in the CI allowlist (pytorchsim_test.yml) -- local feature/regression test. +""" +import os +import sys + +import torch +import torch.nn.functional as F + +sys.path.insert(0, os.path.join(os.environ.get("TORCHSIM_DIR", default="/workspace/PyTorchSim"), "tests")) +from _pytorchsim_utils import test_result + +os.environ.setdefault("TORCHSIM_AXIS_SPLIT", "1") +os.environ.setdefault("TORCHSIM_GRAPH_COPY", "1") +from PyTorchSimFrontend.mlir import graph_copy +graph_copy.install() + + +def _run(device, name, fn, *inputs): + torch.manual_seed(0) + opt = torch.compile(dynamic=False)(fn) + res = opt(*[t.to(device=device) for t in inputs]) + ref = fn(*[t.cpu() for t in inputs]) + test_result(name, res, ref, rtol=1e-3, atol=1e-3) + + +# --- aligned floor/mod: handled by axis-split --------------------------------- +def test_group_norm(device): + _run(device, "group_norm c//(C/G)", lambda x: F.group_norm(x, 3), torch.randn(2, 6, 4, 4)) + + +def test_repeat(device): + # tile -> ModularIndexing(c, 1, n) + _run(device, "repeat (mod)", lambda x: x.repeat(1, 2) + 1.0, torch.randn(4, 8)) + + +def test_repeat_interleave(device): + # -> FloorDiv(c, k) + _run(device, "repeat_interleave (floor)", + lambda x: torch.repeat_interleave(x, 2, dim=1) + 1.0, torch.randn(2, 4, 8)) + + +def test_permute_reshape(device): + # permute+reshape -> single-var mixed-radix floor/mod + _run(device, "permute+reshape (mixed-radix)", + lambda x: x.permute(0, 2, 1).reshape(2, 12) + 1.0, torch.randn(2, 3, 4)) + + +def test_three_level_mixed_radix(device): + # reshape+permute+reshape -> chain [1,4,12,24]; the 3-level split leaves a + # residual FloorDiv that simplify_with_ranges cannot fold -> _fold_with_ranges. + _run(device, "3-level mixed-radix", + lambda x: x.reshape(2, 3, 2, 4).permute(0, 2, 1, 3).reshape(2, 24) + 1.0, + torch.randn(2, 6, 4)) + + +def test_pixel_shuffle(device): + # splits two spatial axes -> would be 5D; the rank guard skips the split and + # falls back to baseline (the >4D decompose-peel/TOG path is #258). + _run(device, "pixel_shuffle (rank guard)", + lambda x: F.pixel_shuffle(x, 2) + 1.0, torch.randn(1, 8, 4, 4)) + + +# --- incompatible radices (case 5): handled by graph-copy --------------------- +def test_incompatible_radix(device): + # a[c//2] + b[c%3] on axis c=6 : floor-by-2 vs mod-by-3 (not a chain) + _run(device, "incompat a[c//2]+b[c%3]", + lambda a, b: torch.repeat_interleave(a, 2, dim=1) + b.repeat(1, 2), + torch.randn(2, 3), torch.randn(2, 3)) + + +# --- cross-axis multi-var floor/mod (case 7): handled by graph-copy copy_input - +def test_case7_reshape_broadcast(device): + # (3*p0+p1)//4 from transpose+reshape feeding an elementwise broadcast consumer + _run(device, "case7 reshape+broadcast", + lambda x, y: x.t().reshape(8, 3) + y, torch.randn(4, 6), torch.randn(8, 1)) + + +def test_case7_softmax_reshape(device): + # same multi-var floor feeding a reduction (softmax over the kept-separate dim) + _run(device, "case7 softmax(reshape)", + lambda x: F.softmax(x.t().reshape(8, 3), dim=1), torch.randn(4, 6)) + + +def test_case7_layernorm_reshape(device): + _run(device, "case7 layernorm(reshape)", + lambda x: F.layer_norm(x.t().reshape(8, 3), (3,)), torch.randn(4, 6)) + + +if __name__ == "__main__": + device = torch.device("npu:0") + with torch.no_grad(): + test_group_norm(device) + test_repeat(device) + test_repeat_interleave(device) + test_permute_reshape(device) + test_three_level_mixed_radix(device) + test_pixel_shuffle(device) + test_incompatible_radix(device) + test_case7_reshape_broadcast(device) + test_case7_softmax_reshape(device) + test_case7_layernorm_reshape(device) From a6b7ebb9f986b97ca2e03f45c4239c5f43143d15 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 20:24:55 +0900 Subject: [PATCH 07/20] [Frontend] decompose-transfer: affine.for peel with lane-banked physical SRAM offset Rewrite the >4D peel to mirror the C++ -dma-fine-grained subtile loop: wrap the outer dims in an affine.for nest (marked inner_loop so build_tog/TOG registers the induction var) and emit one <=4D memref.dma_start per iteration. The slice SRAM offset is the lane-banked physical offset -- split-outer dims rescaled by the lane coeff (stride/old_size*new_size, the MVIN block_stride / buildSramAffineMap rule) -- delivered as the last SRAM index operand. The previous unrolled subview carried the offset in the subview, which extract_aligned_pointer_as_index strips in the gemmini lowering, so every slice aliased the same spad location (pixel_shuffle MISMATCH). The DRAM offset folds with the original index into one affine.apply so processDramIndices can walk the loop index (#258). Thread vectorlane (systolic-array size) through run_python_passes into the pass for the rescale's nr_outerloop. Drop the axis-split rank guard now that >4D is peeled correctly, and register tests/ops/view/test_floormod_axis_split.py in the CI allowlist. Validated end-to-end (Gem5+Spike+TOGSim): pixel_shuffle (>4D peel) and the full floor/mod suite pass; elementwise/gemm/conv2d/reduce/softmax/MLP regress clean. --- .github/workflows/pytorchsim_test.yml | 19 +++ PyTorchSimFrontend/extension_codecache.py | 2 +- PyTorchSimFrontend/mlir/axis_split.py | 17 +- PyTorchSimFrontend/mlir/passes/__init__.py | 7 +- .../mlir/passes/decompose_transfer.py | 157 +++++++++++------- .../mlir/passes/lower_vlane_idx.py | 2 +- tests/ops/view/test_floormod_axis_split.py | 6 +- 7 files changed, 134 insertions(+), 76 deletions(-) diff --git a/.github/workflows/pytorchsim_test.yml b/.github/workflows/pytorchsim_test.yml index 98dbd791..54a2345b 100644 --- a/.github/workflows/pytorchsim_test.yml +++ b/.github/workflows/pytorchsim_test.yml @@ -172,6 +172,25 @@ jobs: -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/ops/view/test_cat.py + test_floormod_axis_split: + name: Run test_floormod_axis_split.py + runs-on: ubuntu-latest + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_floormod_axis_split.py + run: | + echo "Running test_floormod_axis_split.py" + docker run --rm \ + -e vpu_num_lanes="${{ inputs.vector_lane }}" \ + -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/ops/view/test_floormod_axis_split.py + test_matmul: name: Run test_matmul.py runs-on: ubuntu-latest diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index 7bd43079..0309d587 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -131,7 +131,7 @@ def load(cls, source_code, # .mlir in place, before mlir-opt. Currently lowers torchsim.vlane_idx # (replaces the old C++ -global-idx pass); add more in passes/__init__.py. from PyTorchSimFrontend.mlir.passes import run_python_passes, run_standard_lowering, run_tog - run_python_passes(input_path) + run_python_passes(input_path, vectorlane=vectorlane_size) new_input_path = os.path.splitext(input_path)[0] raw_tog_path = new_input_path + "_tog.py" tog_path = os.path.join(write_path, "tile_graph.onnx") diff --git a/PyTorchSimFrontend/mlir/axis_split.py b/PyTorchSimFrontend/mlir/axis_split.py index 1c33e021..a8253e02 100644 --- a/PyTorchSimFrontend/mlir/axis_split.py +++ b/PyTorchSimFrontend/mlir/axis_split.py @@ -167,19 +167,10 @@ def find_split_plan(nodes): plan[ax] = [1, 2, E] break - # Rank guard: if the split would push the index rank past 4, skip it and fall - # back to baseline. The >4D logical tile is *meant* to be peeled into <=4D - # physical descriptors by the decompose-transfer pass, and the #258 TOG crash - # (arith.addi DRAM offset) is now fixed -- but the peel still has a numerical - # correctness bug (pixel_shuffle -> MISMATCH; the peel was only ever isolation- - # validated for MLIR structure, never run end-to-end). Keep the guard until the - # peel numerics are fixed; then this guard can be removed and the recompile-dance - # retired for pixel. - base_rank = next((len(b.iter_vars) for n in nodes - for b in (getattr(n, "_body", None),) if b is not None), 0) - extra = sum(len(ch) - 2 for ch in plan.values()) - if base_rank + extra > 4: - return {} + # A split may push the per-axis index rank past 4. The resulting >4D logical tile + # is peeled into <=4D physical descriptors by the decompose-transfer pass (an + # affine.for nest carrying the lane-banked physical SRAM offset), so there is no + # rank cap here. return plan diff --git a/PyTorchSimFrontend/mlir/passes/__init__.py b/PyTorchSimFrontend/mlir/passes/__init__.py index a5b81e91..8a6843dc 100644 --- a/PyTorchSimFrontend/mlir/passes/__init__.py +++ b/PyTorchSimFrontend/mlir/passes/__init__.py @@ -44,9 +44,12 @@ def _ensure_mlir_bindings_on_path(): ] -def run_python_passes(mlir_path): +def run_python_passes(mlir_path, vectorlane=128): """Apply all registered Python MLIR passes to the .mlir at `mlir_path`, in place. + `vectorlane` (systolic-array size / number of vector lanes) is forwarded to passes + that need it (e.g. decompose_transfer's lane-banked >4D peel). + Returns True if the file was modified, False otherwise. """ with open(mlir_path) as f: @@ -63,7 +66,7 @@ def run_python_passes(mlir_path): with ctx, Location.unknown(): module = Module.parse(text) for p in active: - p.run(module) + p.run(module, vectorlane=vectorlane) out = str(module) with open(mlir_path, "w") as f: diff --git a/PyTorchSimFrontend/mlir/passes/decompose_transfer.py b/PyTorchSimFrontend/mlir/passes/decompose_transfer.py index 87b8aadf..768b992a 100644 --- a/PyTorchSimFrontend/mlir/passes/decompose_transfer.py +++ b/PyTorchSimFrontend/mlir/passes/decompose_transfer.py @@ -1,14 +1,17 @@ """Python out-of-line MLIR pass: decompose togsim.transfer -> <=4D memref.dma_start. A togsim.transfer carries a per-axis affine DMA whose descriptor rank may exceed -the 4D Gemmini limit. This pass is a **pure mechanical rank peel** of that -already-affine access (see docs/dma-transfer-lowering.md, "aligned-only peel"): +the 4D Gemmini limit. This pass lowers it to <=4D customized memref.dma_start +(see docs/dma-transfer-lowering.md): - drop unit (extent-1) tile dims: they contribute no descriptor axis; - if the remaining (effective) rank <= 4 -> emit one customized memref.dma_start, reusing the transfer's operands (fast path); - - if effective rank > 4 -> peel the outer dims into a loop, adjusting the - base index by stride*iv per iteration, inner descriptor <=4D. + - if effective rank > 4 -> wrap the outer dims in an affine.for nest and emit + one <=4D memref.dma_start in the body, mirroring the C++ -dma-fine-grained + subtile loop. The slice DRAM/SRAM offsets are affine.apply over the loop vars; + the SRAM offset is the lane-banked physical offset (split-outer dims rescaled + by the lane coeff) delivered as the last SRAM index operand. It does NO floor/mod linearization (aligned split happens upstream at the scheduling layer) and NO relayout (misaligned access is copy-inserted at the @@ -43,6 +46,15 @@ def _int_array(attr): return [IntegerAttr(a).value for a in ArrayAttr(attr)] +def _const_int(value, default=None): + """Read an arith.constant index/integer operand's value, else `default`.""" + from mlir.ir import IntegerAttr + try: + return IntegerAttr(value.owner.attributes["value"]).value + except Exception: + return default + + def _squeeze_reassociation(shape): """Group source dims so each group's product is one effective (non-unit) dim; unit dims attach to a neighbor. Returns (groups, target_shape).""" @@ -62,13 +74,18 @@ def _squeeze_reassociation(shape): return groups, target -def run(module): - """Lower every togsim.transfer in `module`, in place. Context must be active.""" - import itertools +def run(module, vectorlane=128, **_): + """Lower every togsim.transfer in `module`, in place. Context must be active. + + vectorlane (= systolic-array size / number of vector lanes) feeds the lane-banked + physical SRAM offset in the >4D peel, matching -dma-fine-grained's + systolic-array-size option. + """ from mlir.ir import (InsertionPoint, Operation, MemRefType, ArrayAttr, IntegerAttr, IntegerType, IndexType, DenseI64ArrayAttr, DenseI32ArrayAttr, StridedLayoutAttr, AffineMap, AffineMapAttr, - AffineExpr) + AffineExpr, BoolAttr) + from mlir.dialects import affine i64 = IntegerType.get_signless(64) idx_ty = IndexType.get() @@ -85,6 +102,7 @@ def run(module): vlane_axis = IntegerAttr(op.attributes["vlane_split_axis"]).value dram_stride = _int_array(op.attributes["dram_stride"]) tile_stride = _int_array(op.attributes["tile_stride"]) + vlane_stride = _const_int(vst, 1) padding = op.attributes["padding"] sram_ty = MemRefType(sram.type) @@ -134,21 +152,20 @@ def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr): op.erase() continue - # Peel path: >4 effective dims. Keep the inner 4 as the <=4D descriptor and - # peel the outer (len-4) effective dims into a fully-unrolled set of slices - # (one descriptor per outer index combo; base advances by stride*idx). The - # SRAM slice is a rank-reduced memref.subview at the slice offset; the DRAM - # base advances by a *constant* per slice. + # Peel path: >4 effective dims. Wrap the outer (len-4) effective dims in an + # affine.for nest (one loop per outer dim, marked inner_loop so build_tog/TOG + # registers the induction var) and emit a single <=4D memref.dma_start in the + # innermost body -- mirroring the C++ -dma-fine-grained subtile loop. # - # The constant DRAM offset must be folded into an affine.apply over the - # original dram_idx (NOT arith.addi): the TOG pass reads loop_idx_list by - # walking the DRAM index via processDramIndices, which understands - # affine.apply / block-arg / constant but NOT arith.addi -- an addi yields an - # empty loop_idx_list and the kernel fails ONNX serialization (#258). The - # peeled dim itself is a fixed constant in each unrolled slice (this DMA does - # not iterate it), so it correctly contributes no loop var; the surviving - # loop vars come from the original dram_idx affine.apply, into which - # processDramIndices recurses. + # The slice SRAM offset is the PHYSICAL lane-banked offset: dims outer than the + # vlane axis are rescaled by the lane coeff (stride/old_size*new_size, the MVIN + # block_stride / buildSramAffineMap rule). It is delivered as the last SRAM index + # operand (row-major stride 1), NOT a subview offset -- the gemmini lowering reads + # the spad base via extract_aligned_pointer_as_index, which strips the subview + # offset, so the slice must be selected through the index. The DRAM offset is the + # flat contiguous offset, folded with the original dram_idx into one affine.apply + # (an arith.addi would be opaque to processDramIndices -- #258); the affine.for + # induction vars feed both maps so TOG reads the loop indices through them. peeled, inner = eff[:-4], eff[-4:] ndim = len(tile_shape) inner_shape = [tile_shape[d] for d in inner] @@ -157,40 +174,68 @@ def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr): tl_attr = ArrayAttr.get([IntegerAttr.get(i64, tile_stride[d]) for d in inner]) # the vlane axis must survive into the inner descriptor (it is the lane dim). new_vlane = inner.index(vlane_axis) if vlane_axis in inner else 0 - for combo in itertools.product(*[range(tile_shape[d]) for d in peeled]): - static_offsets = [0] * ndim - static_sizes = [1] * ndim - for k, d in enumerate(peeled): - static_offsets[d] = combo[k] - for d in inner: - static_sizes[d] = tile_shape[d] - sram_off = sum(combo[k] * tile_stride[peeled[k]] for k in range(len(peeled))) - dram_off = sum(combo[k] * dram_stride[peeled[k]] for k in range(len(peeled))) - res_ty = MemRefType.get( - inner_shape, elem, - layout=StridedLayoutAttr.get(sram_off, inner_strides), memory_space=space) - with InsertionPoint(op): - sub = Operation.create( - "memref.subview", results=[res_ty], operands=[sram], - attributes={"static_offsets": DenseI64ArrayAttr.get(static_offsets), - "static_sizes": DenseI64ArrayAttr.get(static_sizes), - "static_strides": DenseI64ArrayAttr.get([1] * ndim), - # operandSegmentSizes is an i32 property: [source, offsets, - # sizes, strides] dynamic-operand counts. All static here -> - # only the source operand. Must be i32, not i64 (i64 silently - # zeroes to [0,0,0,0] and fails verification). - "operandSegmentSizes": DenseI32ArrayAttr.get([1, 0, 0, 0])} - ).results[0] - if dram_off == 0: - dram_idx_val = dram_idx - else: - # affine.apply (d0) -> (d0 + dram_off) so TOG's processDramIndices - # recurses through it into the original dram_idx's loop vars. - amap = AffineMap.get(1, 0, [AffineExpr.get_dim(0) + dram_off]) - dram_idx_val = Operation.create( - "affine.apply", results=[idx_ty], operands=[dram_idx], - attributes={"map": AffineMapAttr.get(amap)}).results[0] - _emit(sub, [sram_idx] * 4, dram_idx_val, new_vlane, dr_attr, tl_attr) + + # Lane-banked physical stride for split-outer dims (vlane_stride defaults to 1). + split_extent = tile_shape[vlane_axis] + nr_outerloop = max( + (split_extent + vectorlane * vlane_stride - 1) // (vectorlane * vlane_stride), 1) + new_size = nr_outerloop * vlane_stride + target_stride = tile_stride[vlane_axis] + + def _phys(d): + s = tile_stride[d] + return s // split_extent * new_size if s > target_stride else s + + # subview to the inner <=4D block at the buffer start (offset 0); slice selection + # is done through the SRAM index, so the StridedLayout offset stays 0. + static_sizes = [1] * ndim + for d in inner: + static_sizes[d] = tile_shape[d] + res_ty = MemRefType.get( + inner_shape, elem, + layout=StridedLayoutAttr.get(0, inner_strides), memory_space=space) + + # affine.for nest over the peeled (outer) dims. + cur_ip = InsertionPoint(op) + ivs = [] + for d in peeled: + floop = affine.AffineForOp(0, tile_shape[d], 1, ip=cur_ip) + floop.operation.attributes["inner_loop"] = BoolAttr.get(True) + ivs.append(floop.induction_variable) + with InsertionPoint(floop.body): + affine.AffineYieldOp([]) + cur_ip = InsertionPoint.at_block_terminator(floop.body) + + npeel = len(peeled) + with cur_ip: + sub = Operation.create( + "memref.subview", results=[res_ty], operands=[sram], + attributes={"static_offsets": DenseI64ArrayAttr.get([0] * ndim), + "static_sizes": DenseI64ArrayAttr.get(static_sizes), + "static_strides": DenseI64ArrayAttr.get([1] * ndim), + # i32 [source, offsets, sizes, strides] dynamic-operand counts; + # all static -> source only. i64 silently zeroes and fails verify. + "operandSegmentSizes": DenseI32ArrayAttr.get([1, 0, 0, 0])} + ).results[0] + # physical SRAM offset = sum_k iv_k * phys_stride(peeled[k]) + sram_expr = AffineExpr.get_dim(0) * _phys(peeled[0]) + for k in range(1, npeel): + sram_expr = sram_expr + AffineExpr.get_dim(k) * _phys(peeled[k]) + sram_off_val = Operation.create( + "affine.apply", results=[idx_ty], operands=list(ivs), + attributes={"map": AffineMapAttr.get(AffineMap.get(npeel, 0, [sram_expr]))} + ).results[0] + # DRAM index = orig dram_idx + sum_k iv_k * dram_stride(peeled[k]) + dram_expr = AffineExpr.get_dim(0) + for k in range(npeel): + dram_expr = dram_expr + AffineExpr.get_dim(k + 1) * dram_stride[peeled[k]] + dram_idx_val = Operation.create( + "affine.apply", results=[idx_ty], operands=[dram_idx, *ivs], + attributes={"map": AffineMapAttr.get(AffineMap.get(npeel + 1, 0, [dram_expr]))} + ).results[0] + zero = _const(0) + _emit(sub, [zero, zero, zero, sram_off_val], dram_idx_val, new_vlane, + dr_attr, tl_attr) op.erase() diff --git a/PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py b/PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py index c9898f4b..76e30cb3 100644 --- a/PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py +++ b/PyTorchSimFrontend/mlir/passes/lower_vlane_idx.py @@ -33,7 +33,7 @@ def _iter_ops(block): yield from _iter_ops(b) -def run(module): +def run(module, **_): """Lower every torchsim.vlane_idx op in `module`, in place. Must be called with the module's Context active (the orchestrator provides it). diff --git a/tests/ops/view/test_floormod_axis_split.py b/tests/ops/view/test_floormod_axis_split.py index 10ebd114..365ee788 100644 --- a/tests/ops/view/test_floormod_axis_split.py +++ b/tests/ops/view/test_floormod_axis_split.py @@ -75,9 +75,9 @@ def test_three_level_mixed_radix(device): def test_pixel_shuffle(device): - # splits two spatial axes -> would be 5D; the rank guard skips the split and - # falls back to baseline (the >4D decompose-peel/TOG path is #258). - _run(device, "pixel_shuffle (rank guard)", + # splits two spatial axes -> 5D logical tile; the decompose-transfer pass peels + # the outer dims into an affine.for nest with the lane-banked physical SRAM offset. + _run(device, "pixel_shuffle (>4D peel)", lambda x: F.pixel_shuffle(x, 2) + 1.0, torch.randn(1, 8, 4, 4)) From 04e425690b2f4081dd32ba20e80b568ade3889c1 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 20:41:21 +0900 Subject: [PATCH 08/20] [Frontend] Unify all DMA codegen on togsim.transfer Route every MVIN/MVOUT -- both the MLIRKernel load/store backend path and the template path (gemm/conv/bmm/maxpool/cat) -- through emit_transfer, so a single decompose-transfer pass lowers all DMAs to memref.dma_start. This drops the get_dma_code emitter, the _dma_needs_transfer instance flag, and format_dma_op_attributes. togsim.transfer now also carries subtile_size and async, which decompose propagates onto the lowered dma_start (subtile filtered to the kept axes when unit dims collapse). For <=4D tiles decompose emits the descriptor directly on the original SRAM buffer (no collapse_shape) so the C++ -dma-fine-grained subtile split, which walks the SRAM operand, sees a direct buffer as before. Validated end-to-end (Spike + TOGSim) on elementwise, gemm (matmul/addmm), bmm, conv2d, group_conv, pool, cat, reduce, softmax, layernorm, batchnorm. Co-Authored-By: Claude Opus 4.8 --- .../mlir/mlir_codegen_backend.py | 83 ++++--------------- PyTorchSimFrontend/mlir/mlir_common.py | 21 ----- PyTorchSimFrontend/mlir/mlir_template.py | 30 ++----- .../mlir/passes/decompose_transfer.py | 37 +++++++-- 4 files changed, 54 insertions(+), 117 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 19ae3af5..a001c861 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -322,10 +322,6 @@ def __init__(self, kernel_group, reason=None): self.spad_buffer_dict = dict() self.base_vector_initialized = False self.loop_size = None - # Set by get_dma_info when a DMA access cannot fit one <=4D Gemmini - # descriptor; load()/store() then emit a togsim.transfer for the - # decompose pass to peel into a loop of <=4D dma_start. - self._dma_needs_transfer = False def reset(self, reason): save = self.exit_stack, self._nested_context_depth @@ -540,15 +536,8 @@ def load(self, name: str, index: sympy.Expr): sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) - # MVIN Encoding - if self._dma_needs_transfer: - self._dma_needs_transfer = False - code = self.emit_transfer("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, dram_stride, tile_stride, int(padding)) - else: - attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, int(padding)) - code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) + code = self.emit_transfer("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, dram_stride, tile_stride, int(padding)) self.cse.generate(dma_buffer, code, assignment = False) # FIXME: assignment = False does not support caching if not comptute_depedency: @@ -617,14 +606,8 @@ def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs) sram_index_var = self.spad_buffer_dict[str(value)][3] # Generate DMA instruction - if self._dma_needs_transfer: - self._dma_needs_transfer = False - code = self.emit_transfer("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, dram_stride, tile_stride, 0) - else: - attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, 0) - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) + code = self.emit_transfer("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, dram_stride, tile_stride, 0) self.dma_stores.writeline(common.DeferredLine(name, code)) def reduction(self, dtype, src_dtype, reduction_type, value): @@ -751,9 +734,8 @@ def store_reduction(self, name, index, value): ops._store(value, sram_var, sram_index_var, tile_shape, buffer_name=name) # Generate DMA instruction - attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, 0) - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) + code = self.emit_transfer("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, dram_stride, tile_stride, 0) self.reductions_suffix.writeline(common.DeferredLine(name, code)) def indirect_indexing(self, index_var, size, check=True, wrap_neg=True): @@ -1257,13 +1239,9 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc.vmap.vlane_split_axis = local_vlane_split_axis local_tile_desc.vmap.vlane_stride = kg_tile_desc.vmap.vlane_stride else: - # >4D access: one Gemmini DMA descriptor (<=4D) cannot represent this. - # Build the full N-D tile and flag it for togsim.transfer; the decompose - # pass peels the excess dims into a loop of <=4D memref.dma_start. local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(dim) for dim in local_dims]) local_tile_desc.vmap.vlane_split_axis = local_vlane_split_axis local_tile_desc.vmap.vlane_stride = kg_tile_desc.vmap.vlane_stride - self._dma_needs_transfer = True if len(implicit_local_dims)!=0 and len(local_dims) != len(implicit_local_dims) and self.is_modular_indexing(index): for axis_constraints in self.kernel_group.tile_desc.implicit_dim_size.values(): @@ -1409,46 +1387,10 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe dram_stride = [0] + dram_stride[:-1] return local_tile_desc, index_var, dram_stride - def get_dma_code(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, dram_index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute): - dma_key = (vlane_split_axis, vlane_stride, mlir_dtype) - if dma_type_name == "MVIN" and dma_key in self.dma_read_cache: - dma_type, vlane_split_axis, vlane_stride = self.dma_read_cache[dma_key] - elif dma_type_name == "MVOUT" and dma_key in self.dma_write_cache: - dma_type, vlane_split_axis, vlane_stride = self.dma_write_cache[dma_key] - else: - vlane_split_axis = self.get_const_cse(vlane_split_axis) - vlane_stride = self.get_const_cse(vlane_stride) - if dma_type_name == "MVIN": - dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_read_counter}"]) - self.dma_read_counter += 1 - self.dma_read_cache[dma_key] = [dma_type, vlane_split_axis, vlane_stride] - else: - dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_write_counter}"]) - self.dma_write_cache[dma_key] = [dma_type, vlane_split_axis, vlane_stride] - tag = self.get_tag_cse() - zero_cse = self.get_const_cse(0) - - # Prepare opearnds and attributes - dram_operand = f"%{dram_var}[%{dram_index_var}]" - sram_operand = f"%{sram_var}[{sram_index_var}]" # Use string - tag_var = f"%{tag}[%{zero_cse}]" - dma_attribute = f"%{vlane_split_axis}, %{vlane_stride}" - sram_shape = tile_shape - tag_shape = "memref<1xi32>" - - if dma_type_name == "MVIN": - src_operand, dst_operand = dram_operand, sram_operand - src_shape, dst_shape = dram_shape, sram_shape - else: - src_operand, dst_operand = sram_operand, dram_operand - src_shape, dst_shape = sram_shape, dram_shape - - return f"memref.dma_start {src_operand}, {dst_operand}, %{dma_type}, {tag_var}, {dma_attribute} : {src_shape}, {dst_shape}, {tag_shape} {attribute}" - def emit_transfer(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, dram_index_var, sram_var, sram_index_var, - dram_shape, tile_shape, dram_stride, tile_stride, padding): + dram_shape, tile_shape, dram_stride, tile_stride, padding, + subtile_size=None, async_type=None): """Emit a generic togsim.transfer op for a DMA whose access exceeds the 4D Gemmini descriptor limit. Carries the full N-D access (dram/tile strides + shapes) plus the SSA operands a memref.dma_start needs @@ -1456,9 +1398,9 @@ def emit_transfer(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtyp (passes/decompose_transfer.py) is purely mechanical: it peels the excess dims into a loop of <=4D memref.dma_start, reusing these operands. - The operand prep mirrors get_dma_code (dma_type enum via the read/write - cache+counter, vlane consts via CSE) so the transfer is self-contained; - togsim is an unregistered dialect -> generic form. + Operand prep uses the read/write cache+counter for the dma_type enum and + CSE'd vlane consts so the transfer is self-contained; togsim is an + unregistered dialect -> generic form. """ dma_key = (vlane_split_axis, vlane_stride, mlir_dtype) if dma_type_name == "MVIN" and dma_key in self.dma_read_cache: @@ -1486,6 +1428,9 @@ def emit_transfer(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtyp f'dram_stride = {dram_stride}, tile_stride = {tile_stride}, ' f'padding = {int(padding)} : i64' ) + if subtile_size: + av = int(async_type) if async_type is not None else 1 + attrs += f', subtile_size = {list(subtile_size)}, async = {av} : i64' # operands: dram, dram_idx, sram, sram_idx, tag, dma_type, vlane_stride return ( f'"togsim.transfer"(%{dram_var}, %{dram_index_var}, %{sram_var}, %{zero_cse}, ' diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 45bb144a..a7921463 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -120,27 +120,6 @@ def get_dtype_nbytes(dtype): } } -def format_dma_op_attributes( - dram_stride: Sequence, - sram_stride: Sequence, - padding: int = 0, - *, - subtile_size: Optional[Sequence] = None, - async_type: Optional[int] = None, -) -> str: - """Attribute dict for memref.dma_start; stride lists as bracketed integer lists.""" - parts = [ - f"dram_stride = {dram_stride}", - f"sram_stride = {sram_stride}", - f"padding = {int(padding)}", - ] - if subtile_size: - parts.append(f"subtile_size = {subtile_size}") - av = int(async_type) if async_type is not None else 1 - parts.append(f"async = {av} : i64") - return "{" + ", ".join(parts) + "}" - - class ParallelLoopBuffer(IndentedBuffer): def indent(self, offset=1, attribute="", suffix=""): @contextlib.contextmanager diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index c8fc036f..529a49b5 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -952,18 +952,9 @@ def generate_dma_code(): zero_cse = self.get_const_cse(0, "index") sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) - if subtile_size: - attribute = mlir_common.format_dma_op_attributes( - _dram_stride, - sram_strides, - int(padding), - subtile_size=subtile_size, - async_type=int(async_type) if async_type is not None else None, - ) - else: - attribute = mlir_common.format_dma_op_attributes(_dram_stride, sram_strides, int(padding)) - code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) + code = self.emit_transfer(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, _dram_stride, sram_strides, int(padding), + subtile_size=subtile_size if subtile_size else None, async_type=async_type) local_code.writeline(code) return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() @@ -1030,9 +1021,8 @@ def load_epilogue(self, name: str, index: sympy.Expr): # Allocate sram buffer dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index) - attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, 0) - code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) + code = self.emit_transfer("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, dram_stride, tile_stride, 0) self.cse.generate(self.dma_loads, code, assignment = False) self.buffer_names[name] = sram_var else: @@ -1098,9 +1088,8 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): ops._store(value, sram_var, compute_index_var, tile_shape, buffer_name=buffer_name) # Generate DMA instruction - attribute = mlir_common.format_dma_op_attributes(dram_stride, tile_stride, 0) - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, attribute) + code = self.emit_transfer("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, dram_stride, tile_stride, 0) self.dma_stores.writeline(DeferredLine(name, code)) def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): @@ -1249,9 +1238,8 @@ def store_reduction_epilogue(self, name, index, value): # MVOUT Encoding # Generate DMA instruction - attribute = mlir_common.format_dma_op_attributes(dram_stride, final_tile_stride, 0) - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, final_tile_shape, attribute) + code = self.emit_transfer("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, final_tile_shape, dram_stride, final_tile_stride, 0) self.reductions_suffix.writeline(DeferredLine(name, code)) def set_tile_size(self, template_fusion_info, prologue=False): diff --git a/PyTorchSimFrontend/mlir/passes/decompose_transfer.py b/PyTorchSimFrontend/mlir/passes/decompose_transfer.py index 768b992a..29841f85 100644 --- a/PyTorchSimFrontend/mlir/passes/decompose_transfer.py +++ b/PyTorchSimFrontend/mlir/passes/decompose_transfer.py @@ -104,6 +104,11 @@ def run(module, vectorlane=128, **_): tile_stride = _int_array(op.attributes["tile_stride"]) vlane_stride = _const_int(vst, 1) padding = op.attributes["padding"] + try: + subtile = _int_array(op.attributes["subtile_size"]) + async_attr = op.attributes["async"] + except KeyError: + subtile, async_attr = None, None sram_ty = MemRefType(sram.type) elem, space = sram_ty.element_type, sram_ty.memory_space @@ -116,7 +121,7 @@ def _const(v): "arith.constant", results=[idx_ty], attributes={"value": IntegerAttr.get(idx_ty, v)}).results[0] - def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr): + def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr, st_attr=None): vsa = _const(vsa_val) if kind == "MVIN": operands = [dram, dram_idx_val, sram_mem, *sram_indices, @@ -124,10 +129,26 @@ def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr): else: operands = [sram_mem, *sram_indices, dram, dram_idx_val, dma_type, tag, sram_idx, vsa, vst] + attrs = {"dram_stride": dr_attr, "sram_stride": tl_attr, "padding": padding} + if st_attr is not None: + attrs["subtile_size"] = st_attr + attrs["async"] = async_attr Operation.create( - "memref.dma_start", results=[], operands=operands, - attributes={"dram_stride": dr_attr, "sram_stride": tl_attr, - "padding": padding}) + "memref.dma_start", results=[], operands=operands, attributes=attrs) + + if len(tile_shape) <= 4: + # Already <=4D: emit the descriptor directly on the original SRAM, no + # collapse_shape. The C++ -dma-fine-grained subtile split walks the SRAM + # operand and chokes on a collapse_shape result, so keep it a direct buffer. + dr_attr = ArrayAttr.get([IntegerAttr.get(i64, s) for s in dram_stride]) + tl_attr = ArrayAttr.get([IntegerAttr.get(i64, s) for s in tile_stride]) + st_attr = (ArrayAttr.get([IntegerAttr.get(i64, s) for s in subtile]) + if subtile is not None else None) + with InsertionPoint(op): + _emit(sram, [sram_idx] * len(tile_shape), dram_idx, vlane_axis, + dr_attr, tl_attr, st_attr) + op.erase() + continue if len(eff) <= 4: # Fast path: drop unit dims so the descriptor reaches <=4D. The customized @@ -141,6 +162,8 @@ def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr): keep = [g[-1] for g in groups] # the non-unit dim in each group dr_attr = ArrayAttr.get([IntegerAttr.get(i64, dram_stride[i]) for i in keep]) tl_attr = ArrayAttr.get([IntegerAttr.get(i64, tile_stride[i]) for i in keep]) + st_attr = (ArrayAttr.get([IntegerAttr.get(i64, subtile[i]) for i in keep]) + if subtile is not None else None) # Remap vlane axis to the collapsed-dim index (the group containing it). new_vlane = next(gi for gi, g in enumerate(groups) if vlane_axis in g) with InsertionPoint(op): @@ -148,7 +171,7 @@ def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr): "memref.collapse_shape", results=[collapsed_ty], operands=[sram], attributes={"reassociation": reassoc}).results[0] _emit(sram_c, [sram_idx] * len(target), dram_idx, new_vlane, - dr_attr, tl_attr) + dr_attr, tl_attr, st_attr) op.erase() continue @@ -172,6 +195,8 @@ def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr): inner_strides = [tile_stride[d] for d in inner] dr_attr = ArrayAttr.get([IntegerAttr.get(i64, dram_stride[d]) for d in inner]) tl_attr = ArrayAttr.get([IntegerAttr.get(i64, tile_stride[d]) for d in inner]) + st_attr = (ArrayAttr.get([IntegerAttr.get(i64, subtile[d]) for d in inner]) + if subtile is not None else None) # the vlane axis must survive into the inner descriptor (it is the lane dim). new_vlane = inner.index(vlane_axis) if vlane_axis in inner else 0 @@ -235,7 +260,7 @@ def _phys(d): ).results[0] zero = _const(0) _emit(sub, [zero, zero, zero, sram_off_val], dram_idx_val, new_vlane, - dr_attr, tl_attr) + dr_attr, tl_attr, st_attr) op.erase() From 80633c7c135ab94876cc05ebae3b8ec923c5acf1 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 21:44:19 +0900 Subject: [PATCH 09/20] [Docs] axis-split: rank guard removed, >4D peel via affine.for Reflect a6b7ebb9: the find_split_plan rank guard is gone (>4D index now lowers through the decompose-transfer affine.for peel, pixel_shuffle end-to-end), and the decompose-transfer peel <-> TOG incompatibility is resolved. Move it from Known-issues to Done; drop the >4D rank-guard caveat and the high-rank next-step. Co-Authored-By: Claude Opus 4.8 --- docs/axis-split-scheduling.md | 38 ++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/docs/axis-split-scheduling.md b/docs/axis-split-scheduling.md index 10171ab4..48e7db2f 100644 --- a/docs/axis-split-scheduling.md +++ b/docs/axis-split-scheduling.md @@ -121,24 +121,30 @@ FloorDiv); the misaligned class is structurally a graph-copy problem. `_fold_with_ranges` now proves it directly from the split sub-var ranges via `bound_sympy`: `FloorDiv(num,d)->0` when `0<=numnum//k` when `0<=num 5D), which triggers the nascent - decompose-transfer peel + TOG path (see below). `find_split_plan` now has a rank - guard: if applying the plan would make the index rank exceed 4, the whole plan is - dropped and the kernel falls back to baseline. pixel_shuffle now passes (via - baseline); 3D group_norm still splits (rank 4, allowed). +- **High-rank blow-up (fixed via peel).** Splitting several axes can push the index + rank past 4 (pixel_shuffle -> 5D). This now lowers through the decompose-transfer + >4D peel (affine.for nest, one <=4D dma_start per iteration), so the earlier + `find_split_plan` rank guard was removed (a6b7ebb9): plans are no longer dropped to + baseline for exceeding rank 4. pixel_shuffle splits to >4D and passes end-to-end + (Gem5+Spike+TOGSim); 3D group_norm still splits (rank 4). ## Known issues / open -- **decompose-transfer peel <-> TOG incompatibility**: the >4D peel emits - `memref.subview` + unrolled constant-offset `dma_start`, which the C++ TOG - generation pass cannot read (empty `loop_idx_list`). The rank guard above - side-steps it; the real fix is to rewrite the peel as an `affine.for` loop - (keeping a loop index TOG can read) instead of unrolling. **Tracked as a GitHub - issue + the `dma-transfer-lowering.md` TODO.** +- None open here. The decompose-transfer peel <-> TOG incompatibility (>4D peel + unreadable by TOG) was resolved by rewriting the peel as an affine.for nest -- see + Done below. ## Done +- **>4D peel via affine.for (fixed)** -- a6b7ebb9. The earlier peel emitted + `memref.subview` + unrolled constant-offset `dma_start` that the TOG pass could not + read (empty `loop_idx_list`) and that aliased one spad slot + (extract_aligned_pointer_as_index strips the subview offset -> pixel_shuffle + MISMATCH). Rewritten to wrap the outer dims in an `affine.for` nest (marked + inner_loop so build_tog registers the induction var), with the lane-banked physical + SRAM offset carried as the last SRAM index operand and the DRAM offset folded into + one affine.apply (#258). The axis-split rank guard was removed; pixel_shuffle passes + end-to-end. - **Mixed-radix (ModularIndexing + multi-radix)**: `find_split_plan` returns a per-axis divisibility-chain of boundaries; `build_split_body` splits into one sub-var per segment (`v = sum_i d_i*b_i`). Validated allclose=True on group_norm @@ -196,8 +202,9 @@ Measured under default-on (`TORCHSIM_RECOMPILE_LOG=1`), 33 tests, all pass: **Full retirement of the dance is deferred** (it is still a real dependency, not just a safety net): removing the floor/mod recompile branches would break the 3-level mixed-radix case (1 recompile) and any case axis-split/graph-copy do not -yet cover (case 6, >4D rank-guard skips). attention/sdpa families were not run here -(too slow locally) and need CI validation before retirement. +yet cover (case 6: non-dividing divisor / uneven cat; reduction-axis floor/mod). +attention/sdpa families were not run here (too slow locally) and need CI validation +before retirement. ## Next steps @@ -206,6 +213,5 @@ yet cover (case 6, >4D rank-guard skips). attention/sdpa families were not run h non-floor/mod ones: non-power-of-2 vec size, indirect). 2. Graph-copy coverage: case 6 (non-dividing divisor / uneven cat -> pad or gather), and conflicts internal to templates (gemm/conv/sdpa). -3. High-rank interaction: cap split-induced rank or harden decompose-peel + TOG for - high-rank tiles (pixel_shuffle end-to-end, #258). +3. Reduction-axis floor/mod (`r//k` inside a reduce): needs reduction-var splitting. 4. Dynamic shapes -> symbolic divisibility / guards. From ab0beaabe39c4576e4071a3494090c45f6d0c78d Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 22:07:33 +0900 Subject: [PATCH 10/20] [Frontend] dma-fine-grained: port the C++ pass to Python (MLIR bindings) Port mlir/test/lib/Analysis/TestDmaFineGrained.cpp to a Python out-of-line pass (passes/dma_fine_grained.py): split the matmul MVIN DMAs (input/weight/bias) into subtile affine.for nests and fuse the input/weight nests, replacing the C++ -dma-fine-grained pass. The MLIR Python bindings expose no IRMapping, so the fused nest is built directly (each DMA emitted with the fused induction vars) instead of cloning bodies -- structurally equivalent, not byte-exact SSA text. Pipeline: the single mlir-opt invocation is split around the Python pass (loop-padding -> run_fine_grained in place -> pytorchsim-to-vcix) in both the functional and gem5 paths (extension_codecache); vectorlane (systolic-array size) is threaded in for the lane-banked SRAM offset rescale. Validated against mlir-opt -dma-fine-grained on rank 2/3/4 fixtures (matmul / bmm / conv: same vcix dma_start and line counts) and end-to-end (Gem5+Spike+TOGSim): gemm/bmm/conv2d plus the resnet/transformer/vit/mlp models pass. Docs: dma-transfer-lowering.md -- >4D peel is affine.for + lane-banked physical SRAM offset via the last index operand; dma_fine_grained / build_tog are now Python passes; the #258 appendix is marked resolved. --- PyTorchSimFrontend/extension_codecache.py | 56 ++- PyTorchSimFrontend/mlir/passes/__init__.py | 1 + .../mlir/passes/dma_fine_grained.py | 404 ++++++++++++++++++ docs/dma-transfer-lowering.md | 58 +-- 4 files changed, 479 insertions(+), 40 deletions(-) create mode 100644 PyTorchSimFrontend/mlir/passes/dma_fine_grained.py diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index 0309d587..68d02a34 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -38,14 +38,24 @@ def dump_metadata(args, arg_attributes, path): return def mlir_compile_command(filename, vectorlane_size, vlen=256): + # The C++ -dma-fine-grained pass is ported to Python (passes/dma_fine_grained.py): + # it runs in-process between loop-padding and pytorchsim-to-vcix, so the single + # mlir-opt invocation is split around it (loop-padding -> _padded.mlir, then the + # Python pass in place, then vcix). return [re.sub(r"[ \n]+", " ", f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ -test-loop-padding \ - -dma-fine-grained='systolic-array-size={vectorlane_size}' \ + {'--mlir-print-ir-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_MLIR_IR else ''} \ + {filename}.mlir -o {filename}_padded.mlir + """, + ).strip(), + re.sub(r"[ \n]+", " ", + f""" + {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen={vlen}' \ {'--mlir-print-ir-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_MLIR_IR else ''} \ - {filename}.mlir -o {filename}_custom.mlir + {filename}_padded.mlir -o {filename}_custom.mlir """, ).strip(), re.sub(r"[ \n]+", " ", @@ -73,14 +83,22 @@ def mlir_compile_command(filename, vectorlane_size, vlen=256): ).strip()] def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_size, vlen=256): + # See mlir_compile_command: -dma-fine-grained is the Python pass, run in-process + # between loop-padding and vcix, so the opt invocation is split around it. return [re.sub(r"[ \n]+", " ", f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ -test-loop-padding='timing_mode=1' \ - -dma-fine-grained='systolic-array-size={vectorlane_size}' \ + {'--mlir-print-ir-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_MLIR_IR else ''} \ + {filename}.mlir -o {sample_filename}_padded.mlir + """, + ).strip(), + re.sub(r"[ \n]+", " ", + f""" + {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen={vlen}' \ {'--mlir-print-ir-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_MLIR_IR else ''} \ - {filename}.mlir -o {sample_filename}_postvcix.mlir + {sample_filename}_padded.mlir -o {sample_filename}_postvcix.mlir """, ).strip(), re.sub(r"[ \n]+", " ", @@ -130,7 +148,7 @@ def load(cls, source_code, # Run the Python out-of-line MLIR passes (MLIR bindings) on the kernel # .mlir in place, before mlir-opt. Currently lowers torchsim.vlane_idx # (replaces the old C++ -global-idx pass); add more in passes/__init__.py. - from PyTorchSimFrontend.mlir.passes import run_python_passes, run_standard_lowering, run_tog + from PyTorchSimFrontend.mlir.passes import run_python_passes, run_standard_lowering, run_tog, run_fine_grained run_python_passes(input_path, vectorlane=vectorlane_size) new_input_path = os.path.splitext(input_path)[0] raw_tog_path = new_input_path + "_tog.py" @@ -152,13 +170,18 @@ def load(cls, source_code, # Use custom malloc to avoid size error new_link_option = link_option + " -Wl,--wrap=malloc -Wl,--wrap=free" cmds = mlir_compile_command(new_input_path, vectorlane_size, vlen=vlen) - opt_cmd = shlex.split(cmds[0]) - translate_cmd = shlex.split(cmds[1]) - llc_cmd = shlex.split(cmds[2]) - llc_asm_cmd = shlex.split(cmds[3]) + opt_pad_cmd = shlex.split(cmds[0]) + opt_vcix_cmd = shlex.split(cmds[1]) + translate_cmd = shlex.split(cmds[2]) + llc_cmd = shlex.split(cmds[3]) + llc_asm_cmd = shlex.split(cmds[4]) with lock: try: - subprocess.check_call(opt_cmd) + # loop-padding -> Python -dma-fine-grained (in place) -> vcix + subprocess.check_call(opt_pad_cmd) + run_fine_grained(new_input_path + "_padded.mlir", + new_input_path + "_padded.mlir", vectorlane_size) + subprocess.check_call(opt_vcix_cmd) # Standard MLIR -> LLVM-dialect lowering (registered upstream # passes) runs in-process via the bindings PassManager, picking # up after the custom mlir-opt passes (memref-to-gemmini). @@ -191,9 +214,10 @@ def load(cls, source_code, return key # Launch tile graph generator - gem5_sample_cmd = shlex.split(gem5_cmds[0]) - gem5_translate_cmd = shlex.split(gem5_cmds[1]) - gem5_llc_cmd = shlex.split(gem5_cmds[2]) + gem5_pad_cmd = shlex.split(gem5_cmds[0]) + gem5_vcix_cmd = shlex.split(gem5_cmds[1]) + gem5_translate_cmd = shlex.split(gem5_cmds[2]) + gem5_llc_cmd = shlex.split(gem5_cmds[3]) lock = FileLock(get_lock_path(write_path), timeout=LOCK_TIMEOUT) with lock: @@ -203,7 +227,11 @@ def load(cls, source_code, # to Python: run_tog reads that IR, writes the TOG (_tog.py) and the # mutated IR (_custom.mlir: sample-mode step rewrite + compute markers), # replacing the C++ -test-tile-operation-graph pass. - subprocess.check_call(gem5_sample_cmd) + # loop-padding(timing) -> Python -dma-fine-grained (in place) -> vcix + subprocess.check_call(gem5_pad_cmd) + run_fine_grained(sample_mlir_path + "_padded.mlir", + sample_mlir_path + "_padded.mlir", vectorlane_size) + subprocess.check_call(gem5_vcix_cmd) run_tog(sample_mlir_path + "_postvcix.mlir", raw_tog_path, sample_mlir_path + "_custom.mlir", sample_mode=extension_config.CONFIG_TLS_MODE, diff --git a/PyTorchSimFrontend/mlir/passes/__init__.py b/PyTorchSimFrontend/mlir/passes/__init__.py index 8a6843dc..310b0c84 100644 --- a/PyTorchSimFrontend/mlir/passes/__init__.py +++ b/PyTorchSimFrontend/mlir/passes/__init__.py @@ -34,6 +34,7 @@ def _ensure_mlir_bindings_on_path(): from . import decompose_transfer from .lower_to_llvm import run_standard_lowering # noqa: F401 (re-exported) from .build_tog import run_tog # noqa: F401 (re-exported; replaces C++ test-tile-operation-graph) +from .dma_fine_grained import run_fine_grained # noqa: F401 (replaces C++ -dma-fine-grained) # Ordered passes applied to each kernel .mlir before mlir-opt. # decompose_transfer first: it lowers togsim.transfer -> memref.dma_start, which diff --git a/PyTorchSimFrontend/mlir/passes/dma_fine_grained.py b/PyTorchSimFrontend/mlir/passes/dma_fine_grained.py new file mode 100644 index 00000000..ff49aea8 --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/dma_fine_grained.py @@ -0,0 +1,404 @@ +"""Python port of the C++ `-dma-fine-grained` MLIR pass (TestDmaFineGrained.cpp). + +Splits the matmul MVIN DMAs (input / weight / optional bias) into subtile loops +(affine.for nests carrying per-subtile DRAM/SRAM offset affine.apply maps) and +fuses the input and weight loop nests, mirroring the C++ pass structurally. Runs +AFTER -test-loop-padding (it reads the padded tile shapes / loop bounds) and +BEFORE -test-pytorchsim-to-vcix, so extension_codecache splits the single mlir-opt +invocation around this pass. + +The C++ pass fuses the two subtile loop nests by cloning their bodies with an +IRMapping; the MLIR Python bindings expose no IRMapping, so this port builds the +fused nest directly and emits each DMA inside it using the fused induction vars +(equivalence target: same loop structure / counts, same offset maps, same +dma_start operands+attrs -- validated against mlir-opt -dma-fine-grained and the +end-to-end gemm/conv/model tests, not byte-exact SSA text). + +Operates on the customized memref.dma_start convention (see lower_dma_to_gemmini): +operands = src, *src_idx, dst, *dst_idx, num_elements(dma_type), tag, *tag_idx, +stride(=vlane_split_axis), num_elements_per_stride(=vlane_stride). MVIN dma_type in +{2,1,14}; tile shape = dst shape for MVIN. + +Pipeline entry point: run_fine_grained(in_path, out_path, vectorlane). +""" +import os +import sys + +_DEFAULT_BINDINGS = "/riscv-llvm/python_packages/mlir_core" +if os.path.isdir(_DEFAULT_BINDINGS) and _DEFAULT_BINDINGS not in sys.path: + sys.path.insert(0, _DEFAULT_BINDINGS) + +import mlir.ir as ir # noqa: E402 + +MVIN, MVIN2, MVIN3, MVOUT = 2, 1, 14, 3 + +# Per-rank subtile loop order and the fused-loop layout (mirror the C++ loopGroups). +# in_to_fused[d] / w_to_fused[d] give the fused-loop index that the input/weight +# DMA's dim d iterates; n_fused is the number of fused affine.for loops. +_FUSE = { + 2: dict(n_fused=3, in_to_fused=[0, 1], w_to_fused=[1, 2]), + 3: dict(n_fused=4, in_to_fused=[0, 1, 2], w_to_fused=[0, 2, 3]), + 4: dict(n_fused=7, in_to_fused=[0, 1, 4, 5], w_to_fused=[2, 3, 5, 6]), +} + + +# --------------------------------------------------------------------------- +# Small readers (mirror CustomDMAAttribute.h) +# --------------------------------------------------------------------------- +def _const_int(value, default=-1): + try: + return ir.IntegerAttr(value.owner.attributes["value"]).value + except Exception: + return default + + +def _int_array_attr(op, key): + if key not in op.attributes: + return [] + return [ir.IntegerAttr(a).value for a in ir.ArrayAttr(op.attributes[key])] + + +def _is_block_arg(v): + return isinstance(v, ir.BlockArgument) + + +class _Dma: + """Positional view of a customized memref.dma_start op.""" + + def __init__(self, op): + self.op = op + operands = list(op.operands) + src_rank = len(ir.MemRefType(operands[0].type).shape) + i = 0 + self.src = operands[i]; i += 1 + self.src_idx = operands[i:i + src_rank]; i += src_rank + self.dst = operands[i]; i += 1 + dst_rank = len(ir.MemRefType(self.dst.type).shape) + self.dst_idx = operands[i:i + dst_rank]; i += dst_rank + self.num_elements = operands[i]; i += 1 + self.tag = operands[i]; i += 1 + tag_rank = len(ir.MemRefType(self.tag.type).shape) + self.tag_idx = operands[i:i + tag_rank]; i += tag_rank + self.stride = operands[i]; i += 1 # = vlane_split_axis + self.num_elements_per_stride = operands[i] # = vlane_stride + self.src_rank, self.dst_rank, self.tag_rank = src_rank, dst_rank, tag_rank + + @property + def dma_type(self): + return _const_int(self.num_elements) + + @property + def is_mvin(self): + return self.dma_type in (MVIN, MVIN2, MVIN3) + + @property + def vlane_split_axis(self): + return _const_int(self.stride) + + @property + def vlane_stride(self): + return _const_int(self.num_elements_per_stride) & 0x7FFF + + def tile_shape(self): + mt = ir.MemRefType((self.dst if self.is_mvin else self.src).type) + return list(mt.shape) + + def subtile_size(self): + return _int_array_attr(self.op, "subtile_size") + + def sram_stride(self): + return _int_array_attr(self.op, "sram_stride") + + def dram_stride(self): + return _int_array_attr(self.op, "dram_stride") + + def is_async(self): + a = self.op.attributes + if "async" not in a: + return False + try: + return bool(ir.IntegerAttr(a["async"]).value) + except Exception: + return True + + +# --------------------------------------------------------------------------- +# Affine map builders (mirror buildDramAffineMap / buildSramAffineMap) +# --------------------------------------------------------------------------- +def _ceil_div(a, b): + return (a + b - 1) // b + + +def _build_dram_map(dma): + dram = dma.dram_stride() + sub = dma.subtile_size() + rank = len(dram) + expr = ir.AffineConstantExpr.get(0) + for i in range(rank): + expr = expr + ir.AffineDimExpr.get(i) * (dram[i] * sub[i]) + return ir.AffineMap.get(rank, 0, [expr]) + + +def _build_sram_map(dma, vectorlane): + tile_shape = dma.tile_shape() + tile_stride = dma.sram_stride() + sub = dma.subtile_size() + split = dma.vlane_split_axis + vstride = dma.vlane_stride + + target_stride = tile_stride[split] + old_size = tile_shape[split] + nr_outerloop = _ceil_div(old_size, vectorlane * vstride) + new_size = nr_outerloop * vstride + + expr = None + for i in range(len(tile_stride)): + subtilesize = sub[i] + stride = tile_stride[i] + if stride > target_stride: + stride = stride // old_size * new_size + d = ir.AffineDimExpr.get(i) + if i != split: + term = d * (subtilesize * stride) + else: + term = ir.AffineExpr.get_floor_div(d * subtilesize, vectorlane) * stride + expr = term if expr is None else expr + term + return ir.AffineMap.get(len(tile_stride), 0, [expr]) + + +def _build_tag_map(dma, loop_order): + """Mirror the tag stride map built inside buildSubtileLoop.""" + tile_sizes = dma.tile_shape() + sub = dma.subtile_size() + rank = len(tile_sizes) + strides = [1] * rank + for i in range(rank - 2, -1, -1): + cur, nxt = loop_order[i], loop_order[i + 1] + strides[cur] = strides[nxt] * _ceil_div(tile_sizes[nxt], sub[nxt]) + expr = ir.AffineConstantExpr.get(0) + for i in range(rank): + expr = expr + ir.AffineDimExpr.get(i) * strides[i] + return ir.AffineMap.get(rank, 0, [expr]) + + +def _loop_counts(dma, loop_order): + tile_sizes = dma.tile_shape() + sub = dma.subtile_size() + return [_ceil_div(tile_sizes[d], sub[d]) for d in range(len(tile_sizes))] + + +# --------------------------------------------------------------------------- +# DMA emission inside a body +# --------------------------------------------------------------------------- +def _sum_map(): + d0, d1 = ir.AffineDimExpr.get(0), ir.AffineDimExpr.get(1) + return ir.AffineMap.get(2, 0, [d0 + d1]) + + +def _apply(map_, operands, ip): + from mlir.dialects import affine + return affine.AffineApplyOp(map_, list(operands), ip=ip).result + + +def _dma_attrs(dma): + """Mirror getDmaAttrs: keep subtile/sram/dram strides, set async + fine_grained.""" + attrs = {} + op = dma.op + for k in ("subtile_size", "sram_stride", "dram_stride"): + if k in op.attributes: + attrs[k] = op.attributes[k] + attrs["async"] = ir.BoolAttr.get(dma.is_async()) + attrs["fine_grained"] = ir.BoolAttr.get(True) + return attrs + + +def _emit_dma(dma, ivs, vectorlane, ip): + """Emit one fine-grained memref.dma_start at `ip`, indexed by `ivs` (the fused + induction vars for this DMA's dims, in dim order).""" + idx_ty = ir.IndexType.get() + zero = _const_index(0, ip) + + dram_off = _apply(_build_dram_map(dma), ivs, ip) + src_idx0 = dma.src_idx[0] + dram_idx = _apply(_sum_map(), [dram_off, src_idx0], ip) + + sram_off = _apply(_build_sram_map(dma, vectorlane), ivs, ip) + tag_idx = _apply(_build_tag_map(dma, list(range(len(dma.tile_shape())))), ivs, ip) + + # SRAM indices: zeros except the last = sram offset (mirror sramIndices.back()). + sram_indices = [zero] * dma.dst_rank + sram_indices[-1] = sram_off + + operands = [dma.src, dram_idx, dma.dst, *sram_indices, + dma.num_elements, dma.tag, tag_idx, + dma.stride, dma.num_elements_per_stride] + ir.Operation.create("memref.dma_start", results=[], operands=operands, + attributes=_dma_attrs(dma), ip=ip) + + +def _const_index(v, ip): + from mlir.dialects import arith + return arith.ConstantOp(ir.IndexType.get(), + ir.IntegerAttr.get(ir.IndexType.get(), v), ip=ip).result + + +# --------------------------------------------------------------------------- +# Loop-nest construction +# --------------------------------------------------------------------------- +def _build_for_nest(bounds, ip): + """Create a nested affine.for over `bounds` (step 1, marked inner_loop). Returns + (induction_vars, innermost_body_ip_before_yield).""" + from mlir.dialects import affine + ivs = [] + cur_ip = ip + for b in bounds: + floop = affine.AffineForOp(0, b, 1, ip=cur_ip) + floop.operation.attributes["inner_loop"] = ir.BoolAttr.get(True) + ivs.append(floop.induction_variable) + with ir.InsertionPoint(floop.body): + affine.AffineYieldOp([]) + cur_ip = ir.InsertionPoint.at_block_terminator(floop.body) + return ivs, cur_ip + + +def _create_subtile_dma(dma, loop_order, ip, vectorlane): + """Standalone subtile loop for one DMA (used for bias). Mirrors createSubtileDMA.""" + counts = _loop_counts(dma, loop_order) + bounds = [counts[d] for d in loop_order] + ivs_in_order, body_ip = _build_for_nest(bounds, ip) + # map dim -> its induction var (loop_order[k] is the dim of the k-th loop) + iv_by_dim = [None] * len(counts) + for k, d in enumerate(loop_order): + iv_by_dim[d] = ivs_in_order[k] + _emit_dma(dma, iv_by_dim, vectorlane, body_ip) + + +# --------------------------------------------------------------------------- +# Operand reachability (mirror traverseOperands) +# --------------------------------------------------------------------------- +def _reaches(value, target): + if value == target: + return True + owner = value.owner + if isinstance(owner, ir.Block): # block argument: no defining op to walk + return False + for operand in owner.operands: + if _reaches(operand, target): + return True + return False + + +# --------------------------------------------------------------------------- +# Pass driver +# --------------------------------------------------------------------------- +def _iter_ops(block): + for op in list(block.operations): + yield op + for region in op.operation.regions: + for b in region.blocks: + yield from _iter_ops(b) + + +def _run_func(func, vectorlane): + from mlir.dialects import linalg + # First matmul only. + matmul = None + dmas = [] + for op in _iter_ops(func.regions[0].blocks[0]): + name = op.operation.name + if name == "linalg.matmul" and matmul is None: + matmul = op + elif name == "memref.dma_start": + dmas.append(op) + if matmul is None: + return + + m_in = matmul.operands[0] + m_w = matmul.operands[1] + m_res = list(matmul.operands)[-1] # output (init) operand + + mvin_input = mvin_weight = mvin_bias = None + for op in dmas: + d = _Dma(op) + if d.dma_type == MVOUT: + continue + if _reaches(m_in, d.dst): + mvin_input = d + elif _reaches(m_w, d.dst): + mvin_weight = d + elif _reaches(m_res, d.dst) and len(d.subtile_size()) > 1: + mvin_bias = d + + in_async = mvin_input is not None and mvin_input.is_async() + w_async = mvin_weight is not None and mvin_weight.is_async() + if not (in_async or w_async): + return + if mvin_input is None or mvin_weight is None: + return + + rank = len(mvin_input.tile_shape()) + if rank not in _FUSE: + return + fuse = _FUSE[rank] + loop_order = list(range(rank)) + + # Bias first (standalone), inserted before its own op. + if mvin_bias is not None: + brank = len(mvin_bias.tile_shape()) + border = {2: [0, 1], 4: [2, 3, 0, 1]}.get(brank) + if border is not None: + _create_subtile_dma(mvin_bias, border, + ir.InsertionPoint(mvin_bias.op), vectorlane) + mvin_bias.op.erase() + + # Fused input + weight nest. Fused loop bounds: take each fused loop's count + # from whichever DMA dim maps onto it. + in_counts = _loop_counts(mvin_input, loop_order) + w_counts = _loop_counts(mvin_weight, loop_order) + bounds = [None] * fuse["n_fused"] + for d, f in enumerate(fuse["in_to_fused"]): + bounds[f] = in_counts[d] + for d, f in enumerate(fuse["w_to_fused"]): + bounds[f] = w_counts[d] + + # Insert the fused nest at the weight DMA (the later of the two): both DMAs' + # original DRAM base indices (src_idx[0], computed in the enclosing loops) must + # dominate the nest. Codegen emits input before weight, matching the C++ pass + # which fuses after the weight subtile loop. + ip = ir.InsertionPoint(mvin_weight.op) + fused_ivs, body_ip = _build_for_nest(bounds, ip) + in_ivs = [fused_ivs[fuse["in_to_fused"][d]] for d in range(rank)] + w_ivs = [fused_ivs[fuse["w_to_fused"][d]] for d in range(rank)] + _emit_dma(mvin_input, in_ivs, vectorlane, body_ip) + _emit_dma(mvin_weight, w_ivs, vectorlane, body_ip) + mvin_input.op.erase() + mvin_weight.op.erase() + + +def run(module, vectorlane=128, **_): + """Apply fine-grained DMA subtiling to every func in `module`, in place.""" + from mlir.dialects import func as func_d # noqa: F401 (ensure dialect loaded) + for region in module.operation.regions: + for b in region.blocks: + for op in list(b.operations): + if op.operation.name == "func.func" and len(op.operation.regions[0].blocks): + _run_func(op, vectorlane) + + +def run_fine_grained(in_path, out_path, vectorlane=128): + """Parse `in_path`, run the pass, write `out_path`. Pipeline entry point.""" + with open(in_path) as f: + text = f.read() + ctx = ir.Context() + ctx.allow_unregistered_dialects = True + with ctx, ir.Location.unknown(): + module = ir.Module.parse(text) + run(module, vectorlane=vectorlane) + out = str(module) + with open(out_path, "w") as f: + f.write(out) + + +if __name__ == "__main__": + vl = int(sys.argv[3]) if len(sys.argv) > 3 else 128 + run_fine_grained(sys.argv[1], sys.argv[2], vl) diff --git a/docs/dma-transfer-lowering.md b/docs/dma-transfer-lowering.md index cbf875c0..1ab9be45 100644 --- a/docs/dma-transfer-lowering.md +++ b/docs/dma-transfer-lowering.md @@ -149,10 +149,15 @@ The DMA descriptor is an **affine map of rank <= 4 with integer strides** 1. **`D <= 4`** -> emit **one** customized `memref.dma_start`; the dims become the descriptor's <=4D shape/strides. Identical to today's output (fast path). -2. **`D > 4`** -> peel `D - 4` dims into an outer `affine.for`; each iteration - computes a base with `affine.apply` (the peeled dims' linear contribution) and - issues the inner <=4D affine descriptor. SRAM offsets are computed symmetrically - in the same loop. +2. **`D > 4`** -> peel `D - 4` dims into an outer `affine.for` (marked `inner_loop` + so the TOG pass reads the induction var); each iteration computes the DRAM base + with one `affine.apply` (the peeled dims' linear contribution folded with the + original index) and the **lane-banked physical** SRAM offset (dims outer than the + vlane axis rescaled by the lane coeff -- the MVIN `block_stride` / + `-dma-fine-grained` `buildSramAffineMap` rule, which needs the vector-lane count), + delivered as the **last SRAM index operand**. The offset must go through the index, + not a subview offset: the gemmini lowering reads the spad base via + `extract_aligned_pointer_as_index`, which strips a subview offset. That is the whole pass. There is **no linearization step** (upstream guarantees affine) and **no relayout fallback** (upstream graph copy handles misalignment). @@ -197,8 +202,9 @@ Rationale: algebra: rank / peel); gemmini does instruction encoding (hardware). Different axes; merging couples affine logic with ISA detail. They stay distinct passes. - **`memref.dma_start` is a shared contract** with multiple consumers - (lower_dma_to_gemmini, dma-fine-grained, the TOG pass). Keeping it as the - interface lets all of them stay unchanged. + (`lower_dma_to_gemmini`, `dma_fine_grained`, `build_tog` -- all now Python + out-of-line passes; the C++ `-dma-fine-grained` / `-test-tile-operation-graph` + are ported). Keeping it as the interface lets all of them stay unchanged. - **gemmini is now a Python out-of-line pass too** -- the conversion-framework coupling (LLVMTypeConverter / getStridedElementPtr) was avoided by working at the memref level (`memref.extract_aligned_pointer_as_index` + arith for @@ -453,26 +459,26 @@ now emits MVIN/MVOUT `togsim.transfer` with 5D `dram_stride [1,6,30,120,360]` an Validated end-to-end (Gem5 + Spike + TOGSim, `allclose=True`) on the 5D permute `x.permute(4,3,2,1,0).contiguous() + 1.0`; no regression on 2D/3D/elementwise. -- **Genuine >4 effective rank (isolation-only; INCOMPATIBLE with TOG -- see TODO).** - When >4 *non-unit* dims survive, the pass keeps the inner 4 as the <=4D descriptor - and peels the outer dims by **full unrolling**: one descriptor per outer-index - combo, the SRAM slice a rank-reduced `memref.subview` at the static slice offset, - the DRAM base `dram_idx + constant`. This passes `lower_text` / mlir-opt in - isolation, but **fails the full pipeline**: the C++ TOG generation pass cannot read - `memref.subview` + unrolled (constant-offset) DMAs and produces an empty - `loop_idx_list` (ValueError in `onnx_utility.py`). Surfaced once aligned axis-split - made the path reachable (pixel_shuffle -> 5D); axis-split now has a rank guard that - avoids triggering it. - -> **TODO (peel rework, tracked as GitHub issue #258).** Rewrite the >4D peel to emit -> a real `affine.for` over the peeled dims (so each DMA keeps an enclosing loop index -> the TOG pass can read) and index the spad directly instead of via `memref.subview`. -> Alternatively teach the C++ TOG pass to handle `subview` + unrolled DMAs. Until -> then the unroll path is isolation-only and the axis-split rank guard keeps it -> unreached. - -The input stays per-axis affine by upstream guarantee, so both paths are pure -mechanical peeling. A non-affine residue is a contract violation (aligned floor/mod +- **Genuine >4 effective rank (affine.for peel; #258 resolved).** When >4 *non-unit* + dims survive, the pass keeps the inner 4 as the <=4D descriptor and peels the outer + dims into an `affine.for` nest (marked `inner_loop`), emitting one inner descriptor + per iteration -- mirroring the `-dma-fine-grained` subtile loop. The DRAM base is + `affine.apply(dram_idx + sum_k iv_k * dram_stride_k)` (one apply, not `arith.addi`, + so the TOG pass walks the loop index through it). The SRAM slice offset is the + **lane-banked physical** offset (split-outer dims rescaled by the lane coeff) + delivered as the **last SRAM index operand**, *not* a `memref.subview` offset -- + `extract_aligned_pointer_as_index` in the gemmini lowering strips a subview offset, + which is why the earlier full-unroll + subview attempt produced wrong data and the + C++ TOG read an empty `loop_idx_list` (#258). + + The earlier full-unroll + subview form was isolation-only and INCOMPATIBLE with the + TOG; the `affine.for` rework (this is exactly the #258 TODO) fixed both the TOG + read and the numerics, so the axis-split rank guard was removed. Validated + end-to-end (Gem5 + Spike + TOGSim, `allclose=True`) on `pixel_shuffle(x, 2) + 1.0` + (5D tile) plus the gemm/bmm/conv/model suite. + +The input stays per-axis affine by upstream guarantee. A non-affine residue is a +contract violation (aligned floor/mod removal lives in `axis-split-scheduling.md`, misaligned relayout in graph copy insertion -- see "Division of labor"); a genuinely non-affine / indirect index would surface as a build failure here rather than being silently relaid out. From 41a365c17238719c29d954b4b4c37bc3831bb01e Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 23:15:17 +0900 Subject: [PATCH 11/20] [Frontend] pytorchsim-to-vcix: port the C++ pass to Python (MLIR bindings) Port mlir/test/lib/Conversion/PyTorchSimToVCIX/TestPyTorchSimToVCIXConversion.cpp to a Python out-of-line pass (passes/lower_to_vcix.py): lower linalg.matmul (gemm and conv2d) and the transcendental math ops (exp/erf/tanh/sin/cos) to VCIX dialect ops (RISC-V vector custom instructions), replacing the C++ -test-pytorchsim-to-vcix. The C++ pass is a dialect conversion (applyPartialConversion); the bindings expose no conversion framework, so each matchAndRewrite is reimplemented as imperative IR rewriting. The VCIX dialect is not in the Python bindings, so vcix ops are created as unregistered generic ops -- mlir-opt / mlir-translate (vcix registered) re-parse the {}-attr generic form fine, and run_standard_lowering already consumes vcix output via allow_unregistered_dialects, so this matches the existing pipeline. Pipeline: the vcix mlir-opt invocation is dropped; run_to_vcix runs in-process after the Python fine-grained pass and before the standard lowering (both functional and gem5 paths in extension_codecache). mlir-opt now runs only -test-loop-padding. Validated structurally against mlir-opt -test-pytorchsim-to-vcix (non-constant ops byte-identical including the dma_wait tag maps, on gemm and conv2d fixtures) and numerically end-to-end (Gem5+Spike+TOGSim allclose): gemm/bmm/conv2d (incl. large N/K), softmax, exp/erf/sin/cos, and the resnet18/vit/transformer/mlp models. --- PyTorchSimFrontend/extension_codecache.py | 52 +- PyTorchSimFrontend/mlir/passes/__init__.py | 1 + .../mlir/passes/lower_to_vcix.py | 619 ++++++++++++++++++ 3 files changed, 638 insertions(+), 34 deletions(-) create mode 100644 PyTorchSimFrontend/mlir/passes/lower_to_vcix.py diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index 68d02a34..a6a213ce 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -38,10 +38,10 @@ def dump_metadata(args, arg_attributes, path): return def mlir_compile_command(filename, vectorlane_size, vlen=256): - # The C++ -dma-fine-grained pass is ported to Python (passes/dma_fine_grained.py): - # it runs in-process between loop-padding and pytorchsim-to-vcix, so the single - # mlir-opt invocation is split around it (loop-padding -> _padded.mlir, then the - # Python pass in place, then vcix). + # The C++ -dma-fine-grained and -test-pytorchsim-to-vcix passes are ported to + # Python (passes/dma_fine_grained.py, lower_to_vcix.py), run in-process between + # loop-padding and the standard lowering. So mlir-opt now runs only loop-padding + # (-> _padded.mlir); the Python fine-grained + vcix passes produce _custom.mlir. return [re.sub(r"[ \n]+", " ", f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ @@ -49,14 +49,6 @@ def mlir_compile_command(filename, vectorlane_size, vlen=256): {'--mlir-print-ir-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_MLIR_IR else ''} \ {filename}.mlir -o {filename}_padded.mlir """, - ).strip(), - re.sub(r"[ \n]+", " ", - f""" - {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ - -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen={vlen}' \ - {'--mlir-print-ir-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_MLIR_IR else ''} \ - {filename}_padded.mlir -o {filename}_custom.mlir - """, ).strip(), re.sub(r"[ \n]+", " ", f""" @@ -83,8 +75,8 @@ def mlir_compile_command(filename, vectorlane_size, vlen=256): ).strip()] def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_size, vlen=256): - # See mlir_compile_command: -dma-fine-grained is the Python pass, run in-process - # between loop-padding and vcix, so the opt invocation is split around it. + # See mlir_compile_command: -dma-fine-grained and -test-pytorchsim-to-vcix are + # Python passes run in-process; mlir-opt runs only loop-padding here. return [re.sub(r"[ \n]+", " ", f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ @@ -92,14 +84,6 @@ def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_si {'--mlir-print-ir-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_MLIR_IR else ''} \ {filename}.mlir -o {sample_filename}_padded.mlir """, - ).strip(), - re.sub(r"[ \n]+", " ", - f""" - {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ - -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen={vlen}' \ - {'--mlir-print-ir-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_MLIR_IR else ''} \ - {sample_filename}_padded.mlir -o {sample_filename}_postvcix.mlir - """, ).strip(), re.sub(r"[ \n]+", " ", f""" @@ -148,7 +132,7 @@ def load(cls, source_code, # Run the Python out-of-line MLIR passes (MLIR bindings) on the kernel # .mlir in place, before mlir-opt. Currently lowers torchsim.vlane_idx # (replaces the old C++ -global-idx pass); add more in passes/__init__.py. - from PyTorchSimFrontend.mlir.passes import run_python_passes, run_standard_lowering, run_tog, run_fine_grained + from PyTorchSimFrontend.mlir.passes import run_python_passes, run_standard_lowering, run_tog, run_fine_grained, run_to_vcix run_python_passes(input_path, vectorlane=vectorlane_size) new_input_path = os.path.splitext(input_path)[0] raw_tog_path = new_input_path + "_tog.py" @@ -171,17 +155,17 @@ def load(cls, source_code, new_link_option = link_option + " -Wl,--wrap=malloc -Wl,--wrap=free" cmds = mlir_compile_command(new_input_path, vectorlane_size, vlen=vlen) opt_pad_cmd = shlex.split(cmds[0]) - opt_vcix_cmd = shlex.split(cmds[1]) - translate_cmd = shlex.split(cmds[2]) - llc_cmd = shlex.split(cmds[3]) - llc_asm_cmd = shlex.split(cmds[4]) + translate_cmd = shlex.split(cmds[1]) + llc_cmd = shlex.split(cmds[2]) + llc_asm_cmd = shlex.split(cmds[3]) with lock: try: - # loop-padding -> Python -dma-fine-grained (in place) -> vcix + # loop-padding (mlir-opt) -> Python fine-grained -> Python vcix subprocess.check_call(opt_pad_cmd) run_fine_grained(new_input_path + "_padded.mlir", new_input_path + "_padded.mlir", vectorlane_size) - subprocess.check_call(opt_vcix_cmd) + run_to_vcix(new_input_path + "_padded.mlir", + new_input_path + "_custom.mlir", vectorlane_size, vlen) # Standard MLIR -> LLVM-dialect lowering (registered upstream # passes) runs in-process via the bindings PassManager, picking # up after the custom mlir-opt passes (memref-to-gemmini). @@ -215,9 +199,8 @@ def load(cls, source_code, # Launch tile graph generator gem5_pad_cmd = shlex.split(gem5_cmds[0]) - gem5_vcix_cmd = shlex.split(gem5_cmds[1]) - gem5_translate_cmd = shlex.split(gem5_cmds[2]) - gem5_llc_cmd = shlex.split(gem5_cmds[3]) + gem5_translate_cmd = shlex.split(gem5_cmds[1]) + gem5_llc_cmd = shlex.split(gem5_cmds[2]) lock = FileLock(get_lock_path(write_path), timeout=LOCK_TIMEOUT) with lock: @@ -227,11 +210,12 @@ def load(cls, source_code, # to Python: run_tog reads that IR, writes the TOG (_tog.py) and the # mutated IR (_custom.mlir: sample-mode step rewrite + compute markers), # replacing the C++ -test-tile-operation-graph pass. - # loop-padding(timing) -> Python -dma-fine-grained (in place) -> vcix + # loop-padding(timing, mlir-opt) -> Python fine-grained -> Python vcix subprocess.check_call(gem5_pad_cmd) run_fine_grained(sample_mlir_path + "_padded.mlir", sample_mlir_path + "_padded.mlir", vectorlane_size) - subprocess.check_call(gem5_vcix_cmd) + run_to_vcix(sample_mlir_path + "_padded.mlir", + sample_mlir_path + "_postvcix.mlir", vectorlane_size, vlen) run_tog(sample_mlir_path + "_postvcix.mlir", raw_tog_path, sample_mlir_path + "_custom.mlir", sample_mode=extension_config.CONFIG_TLS_MODE, diff --git a/PyTorchSimFrontend/mlir/passes/__init__.py b/PyTorchSimFrontend/mlir/passes/__init__.py index 310b0c84..543d0b40 100644 --- a/PyTorchSimFrontend/mlir/passes/__init__.py +++ b/PyTorchSimFrontend/mlir/passes/__init__.py @@ -35,6 +35,7 @@ def _ensure_mlir_bindings_on_path(): from .lower_to_llvm import run_standard_lowering # noqa: F401 (re-exported) from .build_tog import run_tog # noqa: F401 (re-exported; replaces C++ test-tile-operation-graph) from .dma_fine_grained import run_fine_grained # noqa: F401 (replaces C++ -dma-fine-grained) +from .lower_to_vcix import run_to_vcix # noqa: F401 (replaces C++ -test-pytorchsim-to-vcix) # Ordered passes applied to each kernel .mlir before mlir-opt. # decompose_transfer first: it lowers togsim.transfer -> memref.dma_start, which diff --git a/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py b/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py new file mode 100644 index 00000000..4286ba19 --- /dev/null +++ b/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py @@ -0,0 +1,619 @@ +"""Python port of the C++ `-test-pytorchsim-to-vcix` conversion pass +(TestPyTorchSimToVCIXConversion.cpp). + +Lowers `linalg.matmul` and the transcendental math ops (exp/erf/tanh/sin/cos) to +VCIX dialect ops (RISC-V vector custom instructions). The C++ pass is a +dialect-conversion (`applyPartialConversion`); the MLIR Python bindings expose no +conversion framework, so each matchAndRewrite is reimplemented as imperative IR +rewriting (walk + build replacement + replace uses + erase). + +The VCIX dialect is NOT registered in the Python bindings, so vcix ops are created +as unregistered generic ops. This round-trips: mlir-opt / mlir-translate (which do +have vcix registered) re-parse the `{}`-attr generic form fine, and the existing +`run_standard_lowering` already consumes the C++ vcix output via +`allow_unregistered_dialects` -- so emitting generic vcix ops here is consistent +with the current pipeline. + +Covers all 6 C++ patterns: linalg.matmul (gemm + conv2d) and exp/erf/tanh/sin/cos. +Wired into extension_codecache (run_to_vcix) after fine-grained, before the standard +lowering; mlir-opt then runs only -test-loop-padding. Validated structurally against +`mlir-opt -test-pytorchsim-to-vcix` (non-constant ops byte-identical incl. dma_wait tag +maps) and numerically end-to-end (gemm/bmm/conv2d/transcendental, Spike+gem5 allclose). +""" +import os +import sys + +_DEFAULT_BINDINGS = "/riscv-llvm/python_packages/mlir_core" +if os.path.isdir(_DEFAULT_BINDINGS) and _DEFAULT_BINDINGS not in sys.path: + sys.path.insert(0, _DEFAULT_BINDINGS) + +import mlir.ir as ir # noqa: E402 + +# math op name -> (opcode, imm) for the vcix.v.iv lowering (mirror Math*ToVCIX). +_MATH_VIV = { + "math.exp": (0b000011, 0), + "math.erf": (0b000000, 0), + "math.tanh": (0b000001, 0), + "math.sin": (0b000010, 0), + "math.cos": (0b000010, 1), +} + + +def _sew(elt_ty): + if ir.F16Type.isinstance(elt_ty) or ir.BF16Type.isinstance(elt_ty): + return 16 + if ir.F32Type.isinstance(elt_ty): + return 32 + if ir.F64Type.isinstance(elt_ty): + return 64 + if ir.IntegerType.isinstance(elt_ty): + return ir.IntegerType(elt_ty).width + if ir.IndexType.isinstance(elt_ty): + return 64 + return 0 + + +def _log2(x): + return x.bit_length() - 1 + + +def _legalize_vector_type(vt, vlen): + """Mirror legalizeVectorType: return (n, legal_vector_type) or (0, None).""" + elt_ty = vt.element_type + sew = _sew(elt_ty) + if sew == 0: + return 0, None + elt_count = vt.shape[0] + lmul = elt_count * sew // 64 + scalable = vt.scalable + if not scalable: + n = (_log2(lmul) - 2) if lmul > 32 else 1 + if n == 1: + return 1, vt + return n, ir.VectorType.get([vlen // (sew // 8)], elt_ty) + n = (_log2(lmul) - 2) if lmul > 8 else 1 + return n, ir.VectorType.get([elt_count >> (n - 1)], elt_ty, scalable=[True]) + + +def _i64(v): + return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), v) + + +def _i32(v): + return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), v) + + +def _viv(operand, result_ty, opcode, imm, rvl=None): + """Create an unregistered vcix.v.iv (vcix::BinaryImmOp) op at the current IP.""" + operands = [operand] if rvl is None else [operand, rvl] + return ir.Operation.create( + "vcix.v.iv", results=[result_ty], operands=operands, + attributes={"opcode": _i64(opcode), "imm": _i32(imm)}).results[0] + + +def _make_sf_vc_v_iv(vec, op_vt, n, legal_ty, opcode, imm): + """Mirror make_sf_vc_v_iv: chunk `vec` (type op_vt) into legal-width vcix.v.iv.""" + from mlir.dialects import arith, vector + total = op_vt.shape[0] + elt_count = legal_ty.shape[0] + scalable = legal_ty.scalable + rvl = None + if scalable: + rvl = arith.ConstantOp(ir.IntegerType.get_signless(64), _i64(9)).result + if n == 1: + return _viv(vec, legal_ty, opcode, imm, rvl) + elt_ty = legal_ty.element_type + zero = ir.DenseElementsAttr.get_splat(op_vt, ir.FloatAttr.get(elt_ty, 0.0)) + res = arith.ConstantOp(op_vt, zero).result + if scalable: + for i in range(n): + ext = vector.ScalableExtractOp(legal_ty, vec, i * elt_count).result + v = _viv(ext, legal_ty, opcode, imm, rvl) + res = vector.ScalableInsertOp(v, res, i * elt_count).result + else: + for i in range(total // elt_count): + ext = vector.ExtractStridedSliceOp( + ir.ArrayAttr.get([_i64(i * elt_count)]), + ir.ArrayAttr.get([_i64(elt_count)]), + ir.ArrayAttr.get([_i64(1)]), vec).result + v = _viv(ext, legal_ty, opcode, imm, rvl) + res = vector.InsertStridedSliceOp( + v, res, ir.ArrayAttr.get([_i64(i * elt_count)]), + ir.ArrayAttr.get([_i64(1)])).result + return res + + +def _iter_ops(block): + for op in list(block.operations): + yield op + for region in op.operation.regions: + for b in region.blocks: + yield from _iter_ops(b) + + +# --------------------------------------------------------------------------- +# matmul lowering helpers (mirror MatmulOpLowering) +# --------------------------------------------------------------------------- +def _elt_bits(elt_ty): + if ir.IntegerType.isinstance(elt_ty): + return ir.IntegerType(elt_ty).width + return ir.FloatType(elt_ty).width + + +def _bool_attr_true(op, key): + a = op.attributes + return key in a and ir.BoolAttr(a[key]).value + + +def _enclosing_loops(op): + """Walk ancestor ops; return (accumulation, outer, inner) affine.for lists, + outermost-first (mirror the C++ insert-at-begin).""" + acc, outer, inner = [], [], [] + parent = op.operation.parent + while parent is not None: + if parent.name == "affine.for": + if _bool_attr_true(parent, "accumulation_loop"): + acc.insert(0, parent) + if _bool_attr_true(parent, "outer_loop"): + outer.insert(0, parent) + if _bool_attr_true(parent, "inner_loop"): + inner.insert(0, parent) + parent = parent.parent + return acc, outer, inner + + +def _loop_iv(forop): + return forop.regions[0].blocks[0].arguments[0] + + +def _loop_ub(forop): + # single constant upper bound + m = ir.AffineMapAttr(forop.attributes["upperBoundMap"]).value + return ir.AffineConstantExpr(m.results[0]).value + + +def _block_terminator(forop): + blk = forop.regions[0].blocks[0] + ops = list(blk.operations) + return ops[-1] + + +def _affine_consts(expr): + """All AffineConstantExpr values reachable in `expr` (recursive).""" + out = [] + if ir.AffineConstantExpr.isinstance(expr): + out.append(ir.AffineConstantExpr(expr).value) + elif ir.AffineBinaryExpr.isinstance(expr): + be = ir.AffineBinaryExpr(expr) + out += _affine_consts(be.lhs) + out += _affine_consts(be.rhs) + return out + + +def _scan_conv_offsets(ow_loop, o_h, k_h, o_w, k_w): + """Mirror the heuristic offset scan: find affine.apply(o_h,k_h)/(o_w,k_w) in the + o_w loop and read the constant in its map (default 1).""" + offset_h = offset_w = 1 + for o in _iter_ops(ow_loop.regions[0].blocks[0]): + if o.operation.name != "affine.apply": + continue + ops = list(o.operation.operands) + if len(ops) < 2: + continue + m = ir.AffineMapAttr(o.operation.attributes["map"]).value + consts = _affine_consts(m.results[0]) + if ops[0] == o_h and ops[1] == k_h and consts: + offset_h = consts[-1] + if ops[0] == o_w and ops[1] == k_w and consts: + offset_w = consts[-1] + return offset_h, offset_w + + +def _mem_space(v): + mt = ir.MemRefType(v.type) + ms = mt.memory_space + return ir.IntegerAttr(ms).value if ms is not None else 0 + + +def _dram_is_write(src, dst): + """(dram_memref, is_write) by memory space, mirror getDramMemRef.""" + ss, ds = _mem_space(src), _mem_space(dst) + if ds == 0 and ss == 1: + return dst, True + if ds == 1 and ss == 0: + return src, False + return None, False + + +def _idx(v): + return ir.IntegerAttr.get(ir.IndexType.get(), v) + + +def _const_index(v): + from mlir.dialects import arith + return arith.ConstantOp(ir.IndexType.get(), _idx(v)).result + + +def _apply(map_, operands): + from mlir.dialects import affine + return affine.AffineApplyOp(map_, list(operands)).result + + +def _spad_maps(): + d0, d1, d2 = (ir.AffineDimExpr.get(i) for i in range(3)) + s0, s1 = (ir.AffineSymbolExpr.get(i) for i in range(2)) + spad = ir.AffineMap.get(3, 2, [d0 * s0 + d1 * s1 + d2]) + x = ir.AffineMap.get(1, 1, [ir.AffineExpr.get_floor_div(ir.AffineDimExpr.get(0), + ir.AffineSymbolExpr.get(0))]) + y = ir.AffineMap.get(1, 1, [ir.AffineExpr.get_mod(ir.AffineDimExpr.get(0), + ir.AffineSymbolExpr.get(0))]) + return spad, x, y + + +def _transfer_read(vec_ty, source, indices, padding): + from mlir.dialects import vector + src_rank = len(ir.MemRefType(source.type).shape) + vec_rank = len(ir.VectorType(vec_ty).shape) + perm = ir.AffineMap.get_minor_identity(src_rank, vec_rank) + return vector.TransferReadOp(vec_ty, source, list(indices), perm, padding, + [False] * vec_rank).result + + +def _transfer_write(value, dest, indices): + from mlir.dialects import vector + dst_rank = len(ir.MemRefType(dest.type).shape) + vec_rank = len(ir.VectorType(value.type).shape) + perm = ir.AffineMap.get_minor_identity(dst_rank, vec_rank) + vector.TransferWriteOp(None, value, dest, list(indices), perm, [False] * vec_rank) + + +def _dma_wait(tag, idx, num_elements): + from mlir.dialects import memref + memref.DmaWaitOp(tag, [idx], num_elements) + + +def _vcix(name, operands, result_tys, attrs): + return ir.Operation.create(name, results=result_tys, operands=list(operands), + attributes=attrs) + + +def _reaches(value, target): + if value == target: + return True + owner = value.owner + if isinstance(owner, ir.Block): + return False + for operand in owner.operands: + if _reaches(operand, target): + return True + return False + + +class _DmaView: + """Positional view of a customized memref.dma_start (see lower_dma_to_gemmini).""" + + def __init__(self, op): + self.op = op + operands = list(op.operands) + src_rank = len(ir.MemRefType(operands[0].type).shape) + i = 0 + self.src = operands[i]; i += 1 + i += src_rank + self.dst = operands[i]; i += 1 + dst_rank = len(ir.MemRefType(self.dst.type).shape) + i += dst_rank + i += 1 # num_elements + self.tag = operands[i]; i += 1 + tag_rank = len(ir.MemRefType(self.tag.type).shape) + self.tag_idx = operands[i:i + tag_rank] + + def subtile_size(self): + a = self.op.attributes + if "subtile_size" not in a: + return [] + return [ir.IntegerAttr(x).value for x in ir.ArrayAttr(a["subtile_size"])] + + def is_async(self): + a = self.op.attributes + if "async" not in a: + return False + try: + return bool(ir.IntegerAttr(a["async"]).value) + except Exception: + return True + + +def _ceil_div(a, b): + return (a + b - 1) // b + + +def _lower_matmul(op, SS, vlen): + """Lower one linalg.matmul (gemm path) to the vcix push/compute/pop sequence. + Returns True if lowered, False if skipped (conv2d / unexpected nesting -> left + for the C++ pass / a later port). Mirrors MatmulOpLowering (gemm branch).""" + from mlir.dialects import arith + + A, B, C = op.operands[0], op.operands[1], op.operands[2] + mtA, mtB = ir.MemRefType(A.type), ir.MemRefType(B.type) + elt = mtA.element_type + M, K, N = mtA.shape[0], mtA.shape[1], mtB.shape[1] + elen = _elt_bits(elt) + nr_element = vlen // elen + i64 = ir.IntegerType.get_signless(64) + def a64(v): return ir.IntegerAttr.get(i64, v) + + acc, outer, inner = _enclosing_loops(op) + is_conv2d = len(inner) == 4 + if not acc or len(outer) < 2: + return False + tile_kw = tile_oh = tile_ow = None + if is_conv2d: # inner = [k_h, k_w, o_h, o_w] + tile_kw, tile_oh, tile_ow = inner[1], inner[2], inner[3] + + vectorType = ir.VectorType.get([nr_element], elt) + nr_m = max(min(M, nr_element), 2) + vectorMType = ir.VectorType.get([nr_m], elt) + spad_map, spadX, spadY = _spad_maps() + + idxMap = [0, 1, 2] + if "idx_map" in op.attributes: + idxMap = list(ir.DenseI32ArrayAttr(op.attributes["idx_map"])) + + # Scan the outermost loop for the A/B/Bias load DMAs (tags + subtile). + ATag = BTag = BiasTag = None + AAsync = BAsync = BiasAsync = 0 + BiasIdx = None + subtileM, subtileN, subtileK = M, N, K + for o in _iter_ops(outer[-1].regions[0].blocks[0]): + if o.operation.name != "memref.dma_start": + continue + d = _DmaView(o.operation) + dram, is_write = _dram_is_write(d.src, d.dst) + if dram is None or is_write: + continue + sram = d.dst # MVIN: dst is the spad + if not any(_reaches(opnd, sram) for opnd in op.operands): + continue + if not isinstance(dram.owner, ir.Block): # must be a block argument + continue + argn = ir.BlockArgument(dram).arg_number + sub = d.subtile_size() + if argn == idxMap[0]: + ATag, AAsync = d.tag, d.is_async() + if len(sub) >= 2: + subtileM, subtileK = sub[-2], sub[-1] + elif argn == idxMap[1]: + BTag, BAsync = d.tag, d.is_async() + if len(sub) >= 2: + subtileK, subtileN = sub[-2], sub[-1] + elif argn == idxMap[2]: + BiasTag, BiasAsync = d.tag, d.is_async() + BiasIdx = d.tag_idx + if ATag is None or BTag is None: + return False + + KStep = subtileK + push_length = min(subtileM, SS) + MStep = min(M, push_length) + NStep = min(subtileN, SS) + M_LOOP = min(M, push_length) + + # conv2d builds inside the existing k_w loop; gemm builds at the matmul site. + ip = (ir.InsertionPoint.at_block_terminator(tile_kw.regions[0].blocks[0]) + if is_conv2d else ir.InsertionPoint(op)) + with ip: + c0 = _const_index(0) + rvl = arith.ConstantOp(i64, a64(nr_element)).result + K_val, N_val, M_val = _const_index(K), _const_index(N), _const_index(M) + push_val = _const_index(push_length) + num1 = _const_index(1) + zero_pad = arith.ConstantOp(elt, ir.FloatAttr.get(elt, 0.0)).result + + # --- inner N / K loops --- + from mlir.dialects import affine + body_ip = ip + n_idx = c0 + k_idx = c0 + nk_inner = None # innermost n/k loop created (conv2d hosts o_h/o_w here) + if N > SS: + with body_ip: + nl = affine.AffineForOp(0, N // SS, 1) + nl.operation.attributes["inner_loop"] = ir.BoolAttr.get(True) + n_idx = nl.induction_variable + with ir.InsertionPoint(nl.body): + affine.AffineYieldOp([]) + body_ip = ir.InsertionPoint.at_block_terminator(nl.body) + nk_inner = nl + zero_vector = None + if K > SS: + with body_ip: + kl = affine.AffineForOp(0, K // SS, 1) + kl.operation.attributes["inner_loop"] = ir.BoolAttr.get(True) + k_idx = kl.induction_variable + with ir.InsertionPoint(kl.body): + affine.AffineYieldOp([]) + body_ip = ir.InsertionPoint.at_block_terminator(kl.body) + nk_inner = kl + else: + with body_ip: + zv = ir.DenseElementsAttr.get_splat(vectorType, ir.FloatAttr.get(elt, 0.0)) + zero_vector = arith.ConstantOp(vectorType, zv).result + + n_tag = c0 if N == subtileN else n_idx + k_tag = c0 if K == subtileK else k_idx + + with body_ip: + # --- B dma_wait --- + nacc = len(acc) + acc_ivs = [_loop_iv(l) for l in acc] + bexpr = ir.AffineDimExpr.get(0) * -1 + for i in range(1, nacc): + bexpr = bexpr + ir.AffineDimExpr.get(i) * -1 + b_extra = [] + bdo = nacc + if is_conv2d: + kW = _loop_ub(tile_kw) + bdo = nacc + 2 + bexpr = (bexpr + + ir.AffineDimExpr.get(bdo - 2) * ((N // subtileN) * (K // subtileK) * kW) + + ir.AffineDimExpr.get(bdo - 1) * ((N // subtileN) * (K // subtileK))) + b_extra = [_loop_iv(inner[0]), _loop_iv(inner[1])] # k_h, k_w + bexpr = (bexpr + + ir.AffineExpr.get_floor_div(ir.AffineDimExpr.get(bdo), _ceil_div(NStep, SS)) * (K // KStep) + + ir.AffineExpr.get_floor_div(ir.AffineDimExpr.get(bdo + 1), _ceil_div(KStep, SS)) * 1) + bmap = ir.AffineMap.get(bdo + 2, 0, [bexpr]) + btag_idx = _apply(bmap, acc_ivs + b_extra + [n_tag, k_tag]) + if BAsync: + _dma_wait(BTag, btag_idx, num1) + + # --- weight push loop (K x N) --- + for i in range(0, SS, nr_element): + if i < K: + sp = _apply(spad_map, [n_idx, k_idx, _const_index(i), K_val, _const_index(SS)]) + wx = _apply(spadX, [sp, N_val]) + wy = _apply(spadY, [sp, N_val]) + wv = _transfer_read(vectorType, B, [wx, wy], zero_pad) + else: + wv = zero_vector + _vcix("vcix.iv", [wv, rvl], [], + {"opcode": a64(1), "imm": a64(0), "rd": a64(0)}) + + # conv2d: move the o_h/o_w spatial loops after the weight push and continue the + # input-push/compute/pop inside the o_w loop (heuristic, mirrors the C++ branch + # for the no-extra-inner-loop case). + if is_conv2d: + # host the o_h/o_w spatial loops inside the innermost n/k loop (so n_idx/k_idx + # stay in scope) or directly in the k_w loop when no n/k loop was created. + host = nk_inner if nk_inner is not None else tile_kw + tile_oh.operation.move_before(_block_terminator(host)) + body_ip = ir.InsertionPoint.at_block_terminator(tile_ow.regions[0].blocks[0]) + + # --- M loop --- + m_idx = c0 + if M > push_length: + with body_ip: + ml = affine.AffineForOp(0, M // push_length, 1) + ml.operation.attributes["inner_loop"] = ir.BoolAttr.get(True) + m_idx = ml.induction_variable + with ir.InsertionPoint(ml.body): + affine.AffineYieldOp([]) + body_ip = ir.InsertionPoint.at_block_terminator(ml.body) + m_tag = c0 if M == subtileM else m_idx + + with body_ip: + # --- A dma_wait --- + aexpr = ir.AffineDimExpr.get(0) * -1 + for i in range(1, nacc): + aexpr = aexpr + ir.AffineDimExpr.get(i) * -1 + a_extra = [] + ado = nacc + if is_conv2d: + k_h, k_w, o_h, o_w = (_loop_iv(inner[j]) for j in range(4)) + kW, oW = _loop_ub(tile_kw), _loop_ub(tile_ow) + offset_h, offset_w = _scan_conv_offsets(tile_ow, o_h, k_h, o_w, k_w) + coeff_h = 1 + (oW - 1) * offset_w + (kW - 1) + ado = nacc + 2 + aexpr = (aexpr + + ir.AffineDimExpr.get(ado - 2) * ((K // subtileK) * (M // subtileM) * offset_h * coeff_h) + + ir.AffineDimExpr.get(ado - 1) * ((K // subtileK) * (M // subtileM) * offset_w)) + a_extra = [o_h, o_w] + aexpr = (aexpr + + ir.AffineDimExpr.get(ado) * (M // MStep) + + ir.AffineExpr.get_floor_div(ir.AffineDimExpr.get(ado + 1), _ceil_div(MStep, SS))) + amap = ir.AffineMap.get(ado + 2, 0, [aexpr]) + atag_idx = _apply(amap, acc_ivs + a_extra + [k_tag, m_tag]) + if AAsync: + _dma_wait(ATag, atag_idx, num1) + + # --- Bias dma_wait --- + if BiasTag is not None: + bias_is_const = BiasIdx and BiasIdx[0].owner.name == "arith.constant" + first_i = c0 if bias_is_const else n_tag + third_i = c0 if bias_is_const else m_tag + d0, d1 = ir.AffineDimExpr.get(0), ir.AffineDimExpr.get(1) + bias_expr = (ir.AffineExpr.get_floor_div(d0, _ceil_div(NStep, SS)) * (M // MStep) + + ir.AffineExpr.get_floor_div(d1, _ceil_div(MStep, SS))) + bias_map = ir.AffineMap.get(2, 0, [bias_expr]) + bias_tag_idx = _apply(bias_map, [first_i, third_i]) + if BiasAsync: + _dma_wait(BiasTag, bias_tag_idx, num1) + + # --- input push loop (M x K) --- + for i in range(0, M_LOOP, nr_element): + sp = _apply(spad_map, [k_idx, m_idx, _const_index(i), M_val, push_val]) + x = _apply(spadX, [sp, K_val]) + y = _apply(spadY, [sp, K_val]) + iv = _transfer_read(vectorMType, A, [x, y], zero_pad) + _vcix("vcix.iv", [iv, rvl], [], + {"opcode": a64(0), "imm": a64(0), "rd": a64(0)}) + + # --- compute --- + _vcix("vcix.i", [rvl], [], + {"opcode": a64(1), "imm": a64(4), "rd": a64(0), "rs2": a64(0), + "sew": a64(elen), "lmul": a64(0)}) + + # --- pop loop (M x N) --- + for i in range(0, M_LOOP, nr_element): + sp = _apply(spad_map, [n_idx, m_idx, _const_index(i), M_val, push_val]) + vpop = _vcix("vcix.v.i", [rvl], [vectorMType], + {"opcode": a64(2), "imm": a64(0), "rs2": a64(0)}).results[0] + x = _apply(spadX, [sp, N_val]) + y = _apply(spadY, [sp, N_val]) + prev = _transfer_read(vectorMType, C, [x, y], zero_pad) + if ir.IntegerType.isinstance(elt): + out = arith.AddIOp(prev, vpop).result + else: + out = arith.AddFOp(prev, vpop).result + _transfer_write(out, C, [x, y]) + op.erase() + return True + + +def run(module, vectorlane=128, vlen=128, **_): + """Lower linalg.matmul (gemm) + transcendental math ops to VCIX ops, in place.""" + # matmul first (uses surrounding loop structure before any rewrites) + mms = [] + for region in module.operation.regions: + for b in region.blocks: + for o in _iter_ops(b): + if o.operation.name == "linalg.matmul": + mms.append(o.operation) + for o in mms: + _lower_matmul(o, vectorlane, vlen) + targets = [] + for region in module.operation.regions: + for b in region.blocks: + for op in _iter_ops(b): + if op.operation.name in _MATH_VIV: + targets.append(op.operation) + for op in targets: + opcode, imm = _MATH_VIV[op.name] + vec = op.operands[0] + res_ty = op.results[0].type + vt = ir.VectorType(res_ty) + n, legal_ty = _legalize_vector_type(vt, vlen) + if legal_ty is None: + continue + with ir.InsertionPoint(op): + new = _make_sf_vc_v_iv(vec, vt, n, legal_ty, opcode, imm) + op.results[0].replace_all_uses_with(new) + op.erase() + + +def run_to_vcix(in_path, out_path, vectorlane=128, vlen=128): + with open(in_path) as f: + text = f.read() + ctx = ir.Context() + ctx.allow_unregistered_dialects = True + with ctx, ir.Location.unknown(): + module = ir.Module.parse(text) + run(module, vectorlane=vectorlane, vlen=vlen) + out = str(module) + with open(out_path, "w") as f: + f.write(out) + + +if __name__ == "__main__": + vl = int(sys.argv[3]) if len(sys.argv) > 3 else 128 + vlen_ = int(sys.argv[4]) if len(sys.argv) > 4 else 128 + run_to_vcix(sys.argv[1], sys.argv[2], vl, vlen_) From e5cc6a519e5088847bc1c937dba3af6f5e3075aa Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 23:17:00 +0900 Subject: [PATCH 12/20] [Docs] lower_to_llvm: only test-loop-padding remains in mlir-opt dma-fine-grained and pytorchsim-to-vcix are now Python passes (dma_fine_grained, lower_to_vcix); update the docstring listing -- only test-loop-padding still runs in mlir-opt. --- PyTorchSimFrontend/mlir/passes/lower_to_llvm.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/PyTorchSimFrontend/mlir/passes/lower_to_llvm.py b/PyTorchSimFrontend/mlir/passes/lower_to_llvm.py index ce6e081a..ad287499 100644 --- a/PyTorchSimFrontend/mlir/passes/lower_to_llvm.py +++ b/PyTorchSimFrontend/mlir/passes/lower_to_llvm.py @@ -2,11 +2,12 @@ Runs the upstream *registered* lowering passes (convert-*-to-llvm, lower-affine, reconcile-unrealized-casts, ...) in-process on the post-custom-pass IR, replacing -the tail of the mlir-opt pipeline. The remaining custom passes (test-loop-padding, -dma-fine-grained, test-pytorchsim-to-vcix) still run in mlir-opt; the gem5 path's -test-tile-operation-graph is now the Python build_tog pass, and memref-to-gemmini -is the Python lower_dma_to_gemmini pass (run inside this lowering). As the custom -passes migrate to Python, mlir-opt shrinks toward an all-in-process flow. +the tail of the mlir-opt pipeline. Most custom passes are now Python out-of-line +passes: dma-fine-grained -> dma_fine_grained, test-pytorchsim-to-vcix -> +lower_to_vcix, test-tile-operation-graph -> build_tog (gem5 path), memref-to-gemmini +-> lower_dma_to_gemmini (run inside this lowering), global-idx -> lower_vlane_idx. +Only test-loop-padding still runs in mlir-opt; once it migrates, mlir-opt drops out +entirely and the flow is fully in-process. Validated to produce byte-identical LLVM IR to running the same passes inside mlir-opt. Note: only lower-vector-multi-reduction is func.func-scoped (the From 190dd045336b8f8c3d8146517bb0be294e73f1f9 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 23:22:39 +0900 Subject: [PATCH 13/20] [Frontend] Retire dead floor/mod recompile branches in codegen axis-split + graph-copy (on by default) linearize aligned floor/mod at the scheduling layer, so the index reaching get_dma_info is affine and the FloorDiv/ModularIndexing tile-divisibility branches there are never entered (measured: 0 entries across elementwise, gemm, bmm, conv, cat, floor/mod, reduce, attention). Remove those dead branches and their orphans: - the FloorDiv and ModularIndexing tile-forcing + RecompileSignal blocks - the implicit-ModularIndexing index rewrite and implicit_local_dims - the dead ModularIndexing branch in the dram_stride computation - is_modular_indexing, the write-only implicit_dim_size, unused import sys Kept: the non-floor/mod recompile paths (index-divisibility, indirect access, non-power-of-2 vec size), RecompileSignal, and the retry loop. The upstream implicit_dim_ops tile-forcing is left untouched (separate change). Validated end-to-end (Spike + TOGSim): elementwise, gemm, bmm, conv2d, group_conv, pool, cat, floor/mod suite, reduce, softmax, layernorm, batchnorm, gqa -- all pass, 0 recompiles. Co-Authored-By: Claude Opus 4.8 --- .../mlir/mlir_codegen_backend.py | 113 +----------------- PyTorchSimFrontend/mlir/mlir_common.py | 5 - 2 files changed, 1 insertion(+), 117 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index a001c861..6529d8d9 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1,6 +1,5 @@ import contextlib import sympy -import sys import time import re import os @@ -1166,7 +1165,6 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe # Note: index could contain symbols that represent dynamic axies # Extract dimension of index(e.g, index0, index1) local_dims = [int(str(i)[5:]) for i in index.free_symbols if "index" in str(i)] - implicit_local_dims = list(index.args) total_dims = [int(str(i)[5:]) for i in self.itervars] local_tile_desc = mlir_common.MLIRMultiDimTile([1], self.vector_lane) local_dims.sort() # Assume that smaller index is placed in the outer loop @@ -1243,14 +1241,6 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc.vmap.vlane_split_axis = local_vlane_split_axis local_tile_desc.vmap.vlane_stride = kg_tile_desc.vmap.vlane_stride - if len(implicit_local_dims)!=0 and len(local_dims) != len(implicit_local_dims) and self.is_modular_indexing(index): - for axis_constraints in self.kernel_group.tile_desc.implicit_dim_size.values(): - if len(axis_constraints) <= 1: - continue - sorted_constraints = sorted(axis_constraints, key=lambda c: int(c.args[1])) - for constraint in sorted_constraints[1:]: - index = index.replace(constraint.original_expr, 0) - # Calculate dram stride in local tile-dim order. # This keeps dram/sram stride rank aligned with tile rank. local_dim_to_axis = {dim: axis for axis, dim in enumerate(local_dims)} @@ -1264,19 +1254,12 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe else: dram_dict = defaultdict(list) - implicit_dim_divisors = defaultdict(lambda: sys.maxsize) - # Assume that div will have high priority than mod for arg in index.as_ordered_terms(): coeff, dim = arg.as_coeff_mul() if len(dim) == 0: continue real_dim = list(dim[0].free_symbols)[0] - if dim[0].has(ModularIndexing): - if dim[0].args[1] < implicit_dim_divisors[str(real_dim)]: - implicit_dim_divisors[str(real_dim)] = dim[0].args[1] - dram_dict[str(real_dim)] = [coeff] - else: - dram_dict[str(real_dim)].append(coeff) + dram_dict[str(real_dim)].append(coeff) # Add missing dims if not added max_dim = len(self.ranges) if not store_reduction else len(self.ranges) - 1 @@ -1287,100 +1270,6 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe sorted_keys = sorted(dram_dict.keys()) dram_stride = sum((dram_dict[key] for key in sorted_keys), []) - # Support floordiv pattern - # FIXME. How to integrate implicit dims and floordiv? - # This was introduced to support GroupNorm - if index.has(FloorDiv) and not index.has(ModularIndexing): - dim_divisor = [1] * len(local_dims) - for sub in sympy.preorder_traversal(index): - if isinstance(sub, FloorDiv): - if not str(sub.args[0]).startswith("index"): - continue - dim_idx = int((str(sub.args[0])[5:])) - if dim_idx not in local_dim_to_axis: - continue - local_dim_idx = local_dim_to_axis[dim_idx] - if int(self.kernel_group.tile_desc.get_tile_size()[dim_idx] % sub.args[1]) != 0: - # In this case, need to recompile - original_tile = self.kernel_group.tile_desc.get_tile_size() - original_size = original_tile[dim_idx] - divisor = sub.args[1] * self.kernel_group.tile_desc.vmap.vlane_stride - new_size = ((original_size + divisor - 1) // divisor) * divisor - new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) - new_tile_sizes[dim_idx] = new_size - self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) - self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True - - # Can't use dim_idx as vlane_split_axis - if dim_idx == self.kernel_group.tile_desc.vmap.vlane_split_axis: - self.kernel_group.tile_desc.vmap.vlane_split_axis = (dim_idx + 1) % len(original_tile) - - # Send recompile signal - self.reset("recompile") - raise mlir_common.RecompileSignal(f"Tile size {self.kernel_group.tile_desc.get_tile_size()[dim_idx]} is not divisible by {sub.args[1]}") - dim_divisor[local_dim_idx] = sub.args[1] - - # Update dram_stride, just insert 0 next to target dim - offset = 0 - for dim_idx, divisor in enumerate(dim_divisor): - if divisor == 1: - continue - dram_stride.insert(dim_idx+offset+1, 0) - local_tile_desc.apply_divisor(dim_idx+offset, divisor, "pad") - local_tile_desc.apply_divisor(dim_idx+offset, divisor, "split") - offset = offset+1 - - # Support ModularIndexing pattern - # This pattern can be used to broadcast ex) torch.cat([a,a]) - # ModularIndexing(x, y, z) means (x // y) % z - # tile_size must be: multiple of y (floorDiv divisor) and divisor of z (modular divisor) - if index.has(ModularIndexing): - for sub in sympy.preorder_traversal(index): - if isinstance(sub, ModularIndexing): - if not str(sub.args[0]).startswith("index"): - continue - dim_idx = int((str(list(sub.args[0].free_symbols)[0])[5:])) - floor_divisor = sub.args[1] # y: floorDiv divisor - mod_divisor = sub.args[2] # z: modular divisor - current_tile_size = self.kernel_group.tile_desc.get_tile_size()[dim_idx] - - # Check if tile_size is multiple of floorDiv divisor - if int(current_tile_size % floor_divisor) != 0: - original_tile = self.kernel_group.tile_desc.get_tile_size() - original_size = original_tile[dim_idx] - divisor = floor_divisor * self.kernel_group.tile_desc.vmap.vlane_stride - new_size = ((original_size + divisor - 1) // divisor) * divisor - new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) - new_tile_sizes[dim_idx] = new_size - self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) - self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True - - self.reset("recompile") - raise mlir_common.RecompileSignal(f"Tile size {current_tile_size} is not a multiple of floorDiv divisor {floor_divisor} in ModularIndexing") - - # Check if tile_size is a divisor of modular divisor - if int((mod_divisor * floor_divisor) % current_tile_size) != 0: - original_tile = self.kernel_group.tile_desc.get_tile_size() - original_size = original_tile[dim_idx] - # Find the largest divisor of mod_divisor that is <= original_size - # and is a multiple of floor_divisor - new_size = original_size - while new_size > 0: - if mod_divisor % new_size == 0 and new_size % floor_divisor == 0: - break - new_size -= floor_divisor - - if new_size <= 0: - new_size = mod_divisor * floor_divisor - - new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) - new_tile_sizes[dim_idx] = new_size - self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) - self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True - - self.reset("recompile") - raise mlir_common.RecompileSignal(f"Tile size {current_tile_size} is not a divisor of modular divisor {mod_divisor} in ModularIndexing") - # FIXME. It will be nice to modify node instead of this exception handling... if len(self.itervars) == 1 and self.reduction_depth == 0: # In case of reduction loop only case, we will add dummy loop so shift it once diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index a7921463..38a77293 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -520,7 +520,6 @@ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=N vlane_stride=vlane_stride ) - self.implicit_dim_size = {} self.nr_rdim = 0 self.offset = sympy.Integer(0) # Dram offset @@ -686,9 +685,6 @@ def call_kernel(self, kernel_name): # generate the code to call this wrapper.generate_kernel_call(kernel_name, call_args, triton=False) - def is_modular_indexing(self, expr): - return "ModularIndexing" in str(expr) - def implicit_dim_ops(self, nodes): target_patterns = (ModularIndexing, FloorDiv, Mod) target_operands = [] @@ -764,7 +760,6 @@ def compute_tile_size(self, nodes, vars, reduction_vars): if implicit_ops: tile_constraints = self.extract_dividers(implicit_ops) self.kernel_group.tile_desc.apply_constraints(tile_constraints, self.ranges) - self.kernel_group.tile_desc.implicit_dim_size = tile_constraints # Check recodegen reason if self.recodegen is not None: From 82d47f90fd754557dcae835d45ebcdcd3db9cab6 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 23:40:11 +0900 Subject: [PATCH 14/20] [Frontend] Retire implicit_dim_ops tile-forcing (redundant under axis-split) implicit_dim_ops/extract_dividers/apply_constraints forced the initial tile size to match a view's floor/mod divider, up front in compute_tile_size. axis-split now linearizes those views at the scheduling layer, so the forcing is redundant: disabling it leaves every test allclose-correct and, on the affected kernels, slightly faster (the forced tile was over-constrained -- batchnorm 1189->1114, layernorm 4092->3947 cycles; non-floor/mod kernels unchanged). Remove the machinery and its now-unused imports (ModularIndexing, FloorDiv, Mod, MemoryDep, StarDep, WeakDep). Validated end-to-end (Spike + TOGSim): elementwise, gemm, bmm, conv2d, group_conv, pool, cat, floor/mod suite, reduce, softmax, layernorm, batchnorm, gqa -- all pass, 0 recompiles. Co-Authored-By: Claude Opus 4.8 --- PyTorchSimFrontend/mlir/mlir_common.py | 74 +------------------------- 1 file changed, 1 insertion(+), 73 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 38a77293..748c389c 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -15,9 +15,8 @@ from torch._inductor.codegen import cpp from torch._inductor.virtualized import V from torch._inductor.ir import MultiOutputLayout -from torch._inductor.dependencies import MemoryDep, StarDep, WeakDep from torch._inductor.codegen.wrapper import KernelDefinitionLine -from torch.utils._sympy.functions import ModularIndexing, FloorDiv, Mod, Identity +from torch.utils._sympy.functions import Identity import sympy import contextlib @@ -436,20 +435,6 @@ def pad_vlane_tile(self): padded_size = used_vlane * vlane_stride self._tile_size[vlane_split_axis] = math.ceil(self._tile_size[vlane_split_axis] / padded_size) * padded_size - def apply_constraints(self, constraints, ranges): - for idx, (axis_constraints, axis_size) in enumerate(zip(constraints.values(), ranges)): - for const in axis_constraints: - if const.args[1] == 1: - continue - divider = int(const.args[1]) - - if not self.tile_constraint[idx].fixed: - self.tile_constraint[idx].fixed = True - self._tile_size[idx] = divider - elif self.tile_constraint[idx].fixed and self._tile_size[idx] > divider: - self._tile_size[idx] = divider - self.update_tile_stride() - @staticmethod def init_tile_size(ranges, vlane_stride, vector_lane): # Logical tile init for ANY rank. Only the innermost dims carry the @@ -685,56 +670,6 @@ def call_kernel(self, kernel_name): # generate the code to call this wrapper.generate_kernel_call(kernel_name, call_args, triton=False) - def implicit_dim_ops(self, nodes): - target_patterns = (ModularIndexing, FloorDiv, Mod) - target_operands = [] - for target_node in nodes: - for read_operand in target_node.read_writes.reads: - read_operand: MemoryDep - if isinstance(read_operand, StarDep) or isinstance(read_operand, WeakDep): - continue - read_index = read_operand.index - for arg_expr in read_index.args: - if arg_expr.atoms(*target_patterns): - target_operands.append(read_operand) - return target_operands - - def extract_dividers(self, implicit_ops): - # When a specific axis is processed, the key constraint to verify is the divider. - # The tile size must be forced to match the divider size. - dim_dividers = defaultdict(set) - for operand in implicit_ops: - subs_map = { - s: sympy.symbols(s.name.replace("c", "index", 1)) - for s in operand.index.free_symbols - } - rev_subs_map = { - sympy.symbols(s.name.replace("c", "index", 1)) : s - for s in operand.index.free_symbols - } - new_index = operand.index.subs(subs_map) - for arg in new_index.args: - if arg.is_number: - continue - if len(arg.free_symbols) > 1: - raise NotImplementedError("Not supporting this view operation...!") - if arg.is_Mul and arg.args[0].is_number: - arg = arg.args[1] - - if isinstance(arg, ModularIndexing): - modular_expr = ModularIndexing(arg.args[0], arg.args[1], arg.args[2]) - modular_expr.original_expr = arg - elif arg.is_symbol: - modular_expr = ModularIndexing(arg, 1, operand.ranges[rev_subs_map[arg]]) - modular_expr.original_expr = arg - elif "//" in str(arg): - modular_expr = ModularIndexing(arg.args[0], arg.args[1], operand.ranges[rev_subs_map[arg.args[0]]]//arg.args[1]) - modular_expr.original_expr = arg - else: - raise NotImplementedError("What is this case?") - dim_dividers[modular_expr.args[0]].add(modular_expr) - return dim_dividers - def compute_tile_size(self, nodes, vars, reduction_vars): vlane_split_axis = len(vars) - 1 vlane_stride = 2 # Set minimum vlane stride @@ -754,13 +689,6 @@ def compute_tile_size(self, nodes, vars, reduction_vars): self.kernel_group.tile_desc.vmap.vlane_split_axis = 0 self.kernel_group.tile_desc.vmap.vlane_stride = self.kernel_group.tile_desc.get_tile_size()[0] - # Handle implict dims. Input operand could be high dimension tensor. - # Note: https://github.com/PSAL-POSTECH/PyTorchSim/issues/173 - implicit_ops = self.implicit_dim_ops(nodes) - if implicit_ops: - tile_constraints = self.extract_dividers(implicit_ops) - self.kernel_group.tile_desc.apply_constraints(tile_constraints, self.ranges) - # Check recodegen reason if self.recodegen is not None: if self.recodegen == "spad_overflow": From 093a591ed15026e401ca50deac29b1ff400247e8 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 18 Jun 2026 15:35:30 +0900 Subject: [PATCH 15/20] [Docs] padding model + test-loop-padding plan; TPU layout/padding report Add the TPU layout-assignment & padding investigation (docs/tpu_layout_padding_report.md) and the loop-padding design doc. Settled model: padding is two layers -- (A) lane/sublane 8x128 alignment is materialized (footprint + DMA traffic), (B) the compute-block (MXU tile) boundary tail is masked (compute-utilization only, not traffic). test-loop-padding's post-codegen heuristic is to be replaced by informed emission at the scheduling/codegen layer (decide early, materialize late); the two costs must be modeled by separate functions (do not double-count the compute-block tail as traffic). --- docs/loop-padding-elimination.md | 344 ++++++++++++++++++++++++++++++ docs/tpu_layout_padding_report.md | 182 ++++++++++++++++ 2 files changed, 526 insertions(+) create mode 100644 docs/loop-padding-elimination.md create mode 100644 docs/tpu_layout_padding_report.md diff --git a/docs/loop-padding-elimination.md b/docs/loop-padding-elimination.md new file mode 100644 index 00000000..cfb9ce3b --- /dev/null +++ b/docs/loop-padding-elimination.md @@ -0,0 +1,344 @@ +# Padding model + retiring test-loop-padding (two-layer: alignment vs compute-tile) + +Status: **decided**. Earlier drafts argued for *eliminating* padding via variable-extent +DMA (rejected), then for *porting* the pass as-is (also wrong -- it over-materializes). +The settled answer (grounded in `docs/tpu_layout_padding_report.md`): padding is **two +layers** -- (A) lane/sublane alignment is materialized traffic, (B) compute-block tail is +masked compute-util -- and `test-loop-padding`'s post-codegen heuristic is replaced by +informed emission at the scheduling/codegen layer. See "RESOLVED MODEL" below for the +authoritative conclusion; the earlier sections are the analysis trail. + +## DECISION: the modeled NPU has no partial-extent DMA -> padding is fundamental + +We model an NPU whose DMA **always moves full tiles** (TPU-class dense movement); it +does **not** do partial / variable-extent transfers. Therefore: + +- Padding (full-tile DMA over padded buffers) is a **real architectural cost**, not a + simulator convenience. Moving the padding bytes is what the hardware actually does. +- "Handle the tail instead of padding" (variable-extent DMA, boundary clamp) would model + a *different machine* (one with partial-transfer DMA) and would under-count traffic / + cycles. **Rejected.** +- You cannot have "logical DRAM + full-tile boundary traffic" -- moving a full tile + requires the padded bytes to exist in DRAM. So eliminating the padded buffer is + incompatible with the full-tile-DMA model. The two are linked. + +Consequence: **keep padding, but do NOT port the current mechanism** -- reimplement it +at the layer that has the information. Why the current C++ pass is fundamentally wrong +(not just buggy): + +- It runs **after codegen** and **reverse-engineers** the padding need from the emitted + IR: it walks `affine.for` step sizes, `dram_stride`, and `affine.apply` maps to *guess* + which memrefs to grow, by how much, and how to rewrite addressing. The info it needs + (tile size, tensor shape, which dims are reduction vs parallel, the access map) was + **known at codegen time and thrown away**, then heuristically reconstructed. +- That reconstruction is inherently partial: hardcoded conv geometry (k_h/k_w/o_h/o_w), + "find the stride", coefficient-from-dim guessing, etc. New op patterns / multi-dim / + edge cases break it. **It cannot be shown to cover all cases** -- it is a heuristic + retrofit, not a derivation. + +Correct plan: **decide padding at the scheduling / codegen layer**, where the tile `T` +and extent `E` are *known*. When `T` does not divide `E`, the codegen knows +`E' = ceil(E/T)*T` directly and emits the padded buffer + full-tile loops/DMA **by +construction** -- no post-hoc IR analysis, no guessing. This eliminates the heuristic +pass (`test-loop-padding` -> gone, mlir-opt drops out) while **preserving padding** +(padded buffers + full-tile DMA traffic) and being robust by derivation rather than +inference. + +This is "scheduling-level padding" -- but to *produce* padding correctly, NOT to +eliminate it (the earlier tail-handling framing). Padding stays; only the mechanism +moves from a fragile post-codegen heuristic to direct emission from the tiling decision. + +Also resolved for free: the CI robustness bug (current pass uses `emitError` without +`signalPassFailure` -> error paths exit 0 and silently drop `@wrapper_kernel` -> +cryptic `undefined reference to wrapper_kernel` at link under autotune + unseeded +Poisson). Direct emission has no such silent-guess failure mode. + +## RESOLVED MODEL: padding is TWO layers (see tpu_layout_padding_report.md) + +The TPU layout/padding investigation (`docs/tpu_layout_padding_report.md`) settles the +debate: padding is **not one thing**. There are two layers with *different cost +semantics*, and conflating them was the source of the back-and-forth above. + +| layer | what | on real TPU | PyTorchSim cost | +|---|---|---|---| +| **(A) lane/sublane alignment** (8x128; T(2,128)/T(4,128) for small 2nd-minor; bf16 T(8,128)(2,1)) | tile must be address-aligned -> tensor stored padded in HBM | **materialized** (`nofold`); tensor physically bigger, unavoidable | **footprint + DMA traffic** (padded bytes are stored AND moved) | +| **(B) compute-block (MXU tile) boundary** (>8x128) | the contraction/output block tail when a dim isn't a multiple of the MXU block | **masking / peeling** (mostly NOT materialized); MXU computes zeros then masks output | **compute utilization only** (wasted MXU cycles) -- **NOT traffic** | + +This corrects both earlier extremes: +- "eliminate all padding" (tail-handling) -- WRONG: (A) is materialized real traffic. +- "materialize all padding" (current loop-padding) -- WRONG: it buffer-grows (B) too, + so it **double-counts (B) as traffic**; TPU masks (B). loop-padding over-materializes. + +### The two-cost-function rule (report 7.1 -- the key modeling constraint) +- **footprint / HBM traffic** function: count ONLY (A) lane/sublane alignment padding as + physical size (e.g. extent 100 -> stored/moved as 128). Reflect bf16 packing and the + small-2nd-minor T(2,128)/T(4,128) variants. +- **compute-utilization** function: (B) compute-block tail lowers MXU utilization via + masking; **do not add it as traffic** (would over-estimate bandwidth). Only the rare + alignment-forced `tensor.pad` materialization adds copy traffic. +- Pipeline ordering (report 1): the layout **decision** (which axis is lane, how much + alignment padding) is early/metadata; **materialization** is late. Matches "decide at + scheduling, materialize at codegen." + +### Corrected plan for test-loop-padding +Reimplement at the scheduling/codegen layer (informed by `tile_desc`, which already has +the vlane axis + tile sizes), splitting the two layers: +- **(A)** materialize lane/sublane alignment padding -> the padded staging buffer + + full-tile DMA (this is the `wrapper_kernel`-style staging; structurally necessary + because the real DRAM tensor is logical -- you cannot pad it in place). Counts as + footprint + traffic. +- **(B)** handle the compute-block tail by **masking** (`get_mask` already exists) -> + compute-util only, NOT a buffer grow, NOT extra traffic. +Then the post-codegen heuristic `test-loop-padding` is gone, padding is faithful per +layer, and the modeled hardware is unchanged. (Open: whether to force (A) to fixed +(8,128)/T(packing,128) granularity -- robust, matches TPU -- vs minimal `ceil`.) + +Validation (report 2.3 / 7.4): dump real XLA layouts with +`XLA_FLAGS="--xla_dump_to=... --xla_dump_hlo_as_text=true"` and read the `:T(...)` +annotations to ground the lane-axis + alignment-padding model against the compiler, +rather than guessing. + +--- +(Below: the earlier elimination analysis, retained as the record of WHY tail-handling +was rejected. The "fundamental vs not" framing still correctly explains the *mechanics*; +the conclusion -- that padding is eliminable -- is overturned by the DECISION above, +because on a full-tile-DMA NPU the padding traffic is a real cost that must be modeled.) + +## Original goal (SUPERSEDED -- kept for the analysis trail) + +`-test-loop-padding` is the only C++ MLIR pass still invoked in `mlir-opt` (after +build_tog, dma_fine_grained, lower_to_vcix, lower_dma_to_gemmini, lower_vlane_idx were +ported to Python). The original goal was to **eliminate it** by handling +tile-vs-extent misalignment at the scheduling layer. (Superseded: see DECISION -- we +port, not eliminate.) + +## What test-loop-padding does today (fact, from TestLoopPadding.cpp) + +Runs on the post-codegen MLIR of `@kernel`: + +1. For each `affine.for`, round the upper bound **up to a multiple of its step** + (= tile size): `paddedUpperBound = roundUpToMultiple(upperBound, stepSize)`. +2. For every DRAM `memref` indexed by a padded loop: **resize the memref** to the + padded extent (`modifyMemrefWithPadding`), **update the func signature** + (`updateFunctionSignatureWithMemRef`), and **rewrite `dram_stride` + the + `affine.apply` addressing maps** so the addressing matches the larger buffer. + Has a conv2d-specific path (nested `affine.apply` over k_h/k_w/o_h/o_w). +3. `timing_mode=1`: skip copying the padding region (cycles only, no real data). + +Net: loop trip counts and the **DRAM-side buffers** are grown to aligned sizes after +codegen, with addressing rewritten to match. + +## Where padding lives today (the split = the problem) + +| layer | mechanism | what it pads | +|---|---|---| +| Python tile selection (`mlir_common`: `apply_divisor("pad")`, `pad_vlane_tile`, `roundup_vectorlane`) + recompile-dance (`RecompileSignal`) | pads the **tile size** to vlane/divisor multiples; forces tiles via restart | the tile shape | +| Python `get_mask` (vector tail) | masks the unaligned tail **within a tile** for vector ops | partial-tile compute | +| MLIR `test-loop-padding` | rounds **loop trip count** + grows **DRAM buffer** + rewrites strides/maps | the iteration domain + DRAM side | + +Three mechanisms, three layers, one underlying concern (tile does not divide extent). + +## The mismatch model + +A dim has logical extent `E` and a chosen tile `T`. If `T | E`, everything is +aligned and loop-padding is a no-op. If `T does not divide E`, the last tile is +partial (`E mod T` elements). Two ways to make the hardware see full tiles: + +- **Pad**: treat the extent as `E' = ceil(E/T)*T`; the loop runs `E'/T` full tiles; + the tail tile covers padding (garbage / zero) that must not corrupt results. +- **Mask**: keep `E` and mask the partial tile so padding lanes/rows are inert. + +Today: tiles are padded in Python, the within-tile tail is masked (`get_mask`), and +the loop+DRAM are padded in MLIR. We want one coherent story. + +## Why padding exists -- what is fundamental vs not + +Three *distinct* things happen at a tile boundary; only one is layout-fundamental, and +it is not what loop-padding does. + +1. **Parallel-dim tail** (M, N, output spatial): the last tile just produces fewer + outputs. Process fewer -- no value-fill, no padding; only "don't read/write past the + real extent." +2. **Reduction-dim tail** (matmul K, reduce axis): the inactive elements must contribute + the **reduction identity** (0 for sum, -inf for max) or the result is corrupted. This + is real -- but it is satisfied either by buffer value-fill (loop-padding) OR by + masked/identity-fill at compute/push granularity. Evidence the latter already exists: + the matmul vcix lowering pushes `zero_vector` for the K-tail (`i >= K`), and + `get_mask` masks vector-reduction tails. So no *buffer* padding is required for this. +3. **DMA boundary**: a full-tile transfer would run past the real tensor -> **clamp the + transfer extent** (Q1=(c)). No buffer growth. + +loop-padding bundles #2 (value-fill) + #3 (buffer grow) into one post-codegen step to +avoid per-tile tail logic. **None of that is fundamental** -- (c) replaces it with tail +handling (extent clamp + mask/push-zero). So: "handle the tail well and you don't need +separate padding" is correct, for everything loop-padding does. + +### The one genuinely-fundamental padding: lane / sublane granularity + +TPU pads to **(8, 128)**. These are two *different* kinds of padding: +- **128** = the MXU systolic dimension. Full tiles are a *throughput* choice, and the + contraction tail is identity-fillable (see #2). Tail-handleable; not fundamental as + buffer padding. +- **8** = the VMEM **sublane** -- memory is physically tiled in (8, 128) banks. This is a + **physical layout granularity**: a dimension laid across lanes/sublanes must be a whole + number of lane-tiles. You cannot have a "partial lane" in a lane-banked memory, and + **masking does not help** -- the issue is the data's physical placement, not which + compute lanes are active. + +PyTorchSim mirrors this: the SRAM scratchpad is lane-banked (`vlane_idx * vu_sram_byte`; +see the MVIN layout). `pad_vlane_tile` / `roundup_vectorlane` pad the vlane-split tile dim +to a multiple of the lane count. **This padding is layout-fundamental and is RETAINED in +(c)** -- but it is *internal SRAM* padding (the spad is over-allocated per lane), cheap, +and never touches DRAM. + +Conclusion: padding splits into +- (a) **layout-fundamental** lane/sublane padding (the TPU "8") -> KEEP, internal SRAM, + cheap; and +- (b) **compute/bound** padding (reduction identity + DMA bounds; the TPU "128" + DRAM + buffer grow) -> NOT fundamental, replaced by tail handling. + +`test-loop-padding` is entirely in category (b) -> eliminable. So TPU pads to 128/8 not +because tail handling is *impossible*, but because (i) for 128 it is a dense-throughput +design choice (TPU avoids runtime masking), and (ii) for 8 it is a real lane-banked +*memory layout* constraint -- which we already satisfy with cheap internal SRAM padding, +independent of loop-padding. + +## Proposed design (sketch -- to refine together) + +At the scheduling / codegen-prep layer, when a dim's extent is not a multiple of its +chosen tile: + +1. Compute `padded_extent = roundup(extent, tile)` at tile-selection time (the layer + already knows the tile). +2. Emit the loop nest with **padded bounds** and size the spad/DRAM tile descriptors + to `padded_extent` -- i.e., produce what loop-padding produces, but at emit time, + so the addressing maps / `dram_stride` are correct by construction. +3. Reuse `get_mask` for boundary correctness (no real data movement / compute on the + padding region). +4. `test-loop-padding` then has nothing to do -> delete it; drop `-test-loop-padding` + from `extension_codecache`; `mlir-opt` is gone. + +## Design decisions + +**Q1 (crux): DRAM buffer resize vs the real tensor. RESOLVED -> (c) hybrid.** +loop-padding grows the DRAM function-arg `memref`; the real `npu` tensor is +logical-sized. Decision: **pad the SRAM (spad) tile fully** (cheap, internal -- the +spad is already over-allocated per-lane), **keep DRAM logical**, and **clamp the +boundary tile's DMA** to the real tail so no OOB DRAM access happens. No device +buffer growth, no func-signature/stride rewrite -> genuinely scheduling-level. +(Rejected: (a) over-allocating the device DRAM buffer -- leaks into PyTorchSimDevice +allocation and needs a logical/padded shape bridge; (b) was the same as (c) minus the +explicit SRAM-full-padding framing.) + +Consequence: the loop still iterates `ceil(E/T)` tiles (padded trip count) and the +spad tile is full `T`, but for the **last tile** the DMA moves only `E - (ceil(E/T)-1)*T` +real rows; the spad tail rows are garbage, kept inert in compute by `get_mask`. + +### PRECONDITION: index/data-dependent tail handling + +(c) means the **last (tail) tile must behave differently** from the rest: move only +`E - (ceil(E/T)-1)*T` real rows from logical DRAM, leave the spad tail inert. That is +an **index-dependent operation** -- behavior varies with the loop induction variable. + +- **Compute side already supports this.** `get_mask` builds a per-iteration predicate + `step_vec < (upper_bound - compute_idx)` (depends on the loop index), masking the + tail lanes. Index-dependent masking already exists for vector compute. +- **DMA side does NOT.** The DMA transfer length today = the spad tile shape, a + **compile-time constant**. loop-padding exists precisely to avoid a variable-length + DMA: it grows DRAM so even the boundary reads a full `T`. Remove loop-padding and the + boundary DMA must transfer a **variable (index-dependent) length** -- a capability the + customized `memref.dma_start` / Spike MVIN does not have today. + +**So the gating precondition is: the DMA must support an index-dependent transfer +extent** (move `min(T, E - i*T)` rows for tile `i`). Establishing that is the real +foundation of this work; without it, scheduling-level (c) padding cannot be expressed. + +### Q1a: how to satisfy the precondition + + - **Variable-extent DMA (data-dependent).** Extend the customized `memref.dma_start` + + lowering + Spike MVIN to accept a runtime transfer length, and emit + `affine.min(T, E - i*T)` per tile. General (scales to multi-dim, any extent), + uniform loop body. Cost: a real hardware-model + descriptor extension. This is the + capability the precondition names. + - **Static tail-peel.** `floor(E/T)` full-tile iterations + a separate compile-time + partial DMA. No variable-length DMA needed, but **combinatorial in the number of + unaligned dims** (2^k corner DMAs) -- the same blow-up that made the decompose + unroll-peel a dead end (#258). Does not scale; rejected as the general path. + Leaning: the variable-extent DMA is the principled answer -- it is the missing + capability, and it generalizes. + +### Finding: variable extent is a codegen change, not a Spike change + +The transfer dim sizes are packed into the CONFIG instruction's **rs1 register** +(`lower_dma_to_gemmini.py:144`): today `cfg_rs1 = i64_const((shape4[0]&0xFFFF)<<48 | +... )` -- a compile-time constant. But `asm(CONFIG, cfg_rs1, cfg_rs2)` passes rs1 as a +**register operand**, and Spike reads the dim sizes from it into `P.VU.dma_dim_size`. +So the hardware model already takes the extent at runtime; today's codegen merely feeds +a constant. + +Implication: the variable-extent precondition is satisfiable **at the codegen / +lower_dma_to_gemmini layer, with no Spike change** -- emit `cfg_rs1` as a *computed* +value (pack a runtime dim size `min(T, E - i*T)` into the 16-bit field via arith) +instead of `i64_const`. The boundary tile then moves only the real tail; the spad stays +fully padded; DRAM stays logical. This is exactly (c). + +Evidence (strong): the CONFIG instruction is `CUSTOM_1` funct7=0 with rs1/rs2 as +**register operands** (`asm(func7, rs1, rs2)` -> `.insn r CUSTOM_1, 0x3, func7, x0, +$0, $1`); the MVIN reads the dims from `P.VU.dma_dim_size`, which is runtime VU state +set by that CONFIG insn. A config insn reading rs1 register bits into dma_dim_size is +the only sensible implementation -- so runtime-variable extent is already supported by +the model; today's codegen just feeds a constant rs1. (One file unread: the exact +torchsim config insn in riscv-isa-sim -- verify it stores all four 16-bit dim fields +from rs1.) + +Remaining work to thread it: (1) carry a dynamic per-axis transfer extent on the +customized `memref.dma_start` (a length operand, like the existing dynamic indices); +(2) `lower_dma_to_gemmini` builds `cfg_rs1` from those (constant when static, arith when +dynamic); (3) the scheduling layer computes `affine.min(T, E - i*T)` for unaligned dims +and passes it. + +**Q2: where is the transfer shape threaded? RESOLVED -> pass shape as an operand, +computed separately.** Compute the real (boundary-clamped) per-axis transfer extent in +a separate step (the scheduling layer, via `affine.min(T, E - i*T)` / arith) and pass +it to the customized `memref.dma_start` as an explicit **shape operand** (alongside the +existing dynamic index operands). `lower_dma_to_gemmini` then packs `cfg_rs1` from that +operand (constant-folds when static). Keeps the extent an explicit, separately-computed +value rather than implicit in the descriptor's static memref shape. + +**Q3: timing-mode skip-copy equivalent.** Padding iterations must cost cycles but move +no real data: loop = padded, DMA size = real-tail (the Q2 shape operand). Confirm +`get_mask` covers the compute side and the shape operand covers the DMA side. + +**Q4: conv2d -- NOT a special case under (c).** loop-padding has a conv-specific branch +(it walks the nested `affine.apply` over k_h/k_w/o_h/o_w and rewrites the maps + +`dram_stride`). That complexity exists because loop-padding **grows the DRAM buffer and +rewrites addressing** -- and conv's input address is a nested affine +(`input_row = o_h*stride + k_h - pad`), so growing the buffer forces rewriting those +nested maps. + +Under (c) we do neither: we keep DRAM logical and only **clamp the boundary tile's DMA +extent**, leaving the address computation untouched. A DMA is a rectangular DRAM<->SRAM +block transfer (base address + per-axis extents); "don't read past the DRAM end" is a +per-axis `min(tile, remaining)` regardless of gemm vs conv. The conv composite indexing +affects *where* the block starts (address) and *how* compute consumes it (handled by +`get_mask`), not *how many* elements move (the per-axis extent clamp). So conv reduces to +the same flat per-axis clamp as gemm; loop-padding's conv branch has no counterpart here. +(Residual check, not a design fork: conv tiles overlap (halo from stride/kernel) so input +tiles are not a clean DRAM partition -- confirm the per-axis "remaining" is still just +`E - base_on_that_axis`, which it is, since each tile's address is independent.) + +**Q5: incremental path.** Start by deleting loop-padding for kernels where no dim +needs padding (it is already a no-op there) and confirm zero diff; then handle the +real padding cases, possibly behind a flag, validating each. + +## Validation + +e2e suite (gemm/bmm/conv2d with deliberately non-multiple extents, + models), +`allclose` through Gem5+Spike+TOGSim. Where feasible, structurally compare the emitted +MLIR against the current `-test-loop-padding` output on the same kernels. + +## Relation to prior work + +Same move as floor/mod -> axis-split + graph-copy: handle the misalignment upstream so +the MLIR layer sees only the clean (aligned) case and the pass disappears. This is the +padding instance of that principle (Plan A in `dma-transfer-lowering.md`). diff --git a/docs/tpu_layout_padding_report.md b/docs/tpu_layout_padding_report.md new file mode 100644 index 00000000..68cbacbb --- /dev/null +++ b/docs/tpu_layout_padding_report.md @@ -0,0 +1,182 @@ +# TPU Layout Assignment & Padding 메커니즘 조사 보고서 + +> **목적**: PyTorchSim에서 TPU 워크로드의 메모리 footprint / compute utilization을 정확히 모델링하기 위해, XLA/Mosaic 컴파일 파이프라인에서 (1) 어느 축이 lane/sublane으로 선택되는지, (2) 패딩이 언제·어떻게 일어나는지, (3) 그 패딩이 물리적으로 물질화되는지 마스킹으로 처리되는지를 정리한 핸드오프 문서. +> +> **수신자**: PyTorchSim 모델링/구현 담당 agent +> **작성 기준일**: 2026-06-18 +> **신뢰도 표기**: [확정]=공개 문서/소스로 검증됨, [추론]=문서 기반 합리적 추론, [미확인]=공개 자료로 닿지 못함 + +--- + +## 0. 한 줄 요약 + +TPU에서 **lane(128) 축 선택과 패딩은 XLA의 layout assignment pass에서 동시에 결정**되며, 패딩은 두 층위로 나뉜다: **(A) 8×128 레이아웃 정렬 패딩은 주소 정렬상 강제 물질화**(실제 텐서가 HBM에서 커짐), **(B) 그보다 큰 연산 블록 크기의 경계(tail)는 masking/peeling 등으로 처리**(대체로 비물질화). 모델링 시 (A)는 footprint+traffic, (B)는 compute utilization만 반영해야 한다. + +--- + +## 1. 컴파일 파이프라인 순서 [확정] + +``` +프론트엔드(JAX/PyTorch) + → JAXPR/FX + → HLO (DotGeneral 등, layout 미확정) + → [layout assignment] ← lane/sublane 축 + 패딩 결정 + → [fusion] ← op들을 커널로 묶음 + → LLO (TPU-specific IR) + → VLIW bundles +``` + +- 출처: JAX→VLIW 컴파일러 추적 (patricktoulme.substack.com), OpenXLA 공식 문서. +- 핵심 함의: **layout 결정이 fusion보다 먼저**다. 따라서 "어느 축이 lane인가 / 얼마나 패딩되는가"는 fusion을 몰라도 결정 가능하지만, **실제 pad/relayout op의 삽입(물질화)은 fusion이 보이는 단계(LLO)에서** 해야 minimal하게 된다. +- PyTorchSim 관점: 이 분리를 그대로 따를 것. FX/상위 단계에서는 layout **결정**(메타데이터)만, 실제 padding/relayout **물질화**는 하위 단계에서. + +--- + +## 2. lane/sublane 축 선택 메커니즘 [확정] + +### 2.1 결정 시점과 표현 +layout assignment pass에서 HLO 텐서에 `{minor_to_major : T(tile)}` 어노테이션이 부착되는 순간 확정. + +관측된 예시 (layout assignment 전후): +``` +// Before +%dot.10 = f32[16,64]{1,0} dot(...) +%reduce.16 = f32[16]{0} reduce(...) +// After +%dot.10 = f32[16,64]{1,0:T(8,128)} dot(...) +%reduce.16 = f32[16]{0:T(128)} reduce(...) +``` +- `{1,0}` = minor_to_major 순서 (row-major: 마지막 차원이 메모리 연속). +- `:T(8,128)` = TPU 타일링, VPU의 **8 sublane × 128 lane**에 대응. + +### 2.2 규칙 +- **minor_to_major 리스트의 첫 원소(가장 minor한 차원) = lane(128) 방향**, 그 다음 = sublane(8) 방향. +- 타일링은 **항상 most-minor 두 축에만** 적용. 나머지 major 차원은 타일링 없이 그대로 (rank 무관). +- 누가 결정하나: **matmul/dot이 anchor로 layout 강제** → elementwise는 통과 → reduce는 cross-lane 비용 때문에 lane에서 빠지려는 압력 → graph 위로 전파(propagation). + +### 2.3 검증 방법 (PyTorchSim ground truth) +```bash +XLA_FLAGS="--xla_dump_to=/path --xla_dump_hlo_as_text=true" python model.py +``` +덤프된 HLO의 `:T(...)` 어노테이션으로 실제 lane 축/패딩을 추측 아닌 컴파일러 출력으로 확인 가능. +(주의: `--xla_enable_hlo_passes_only=layout-assignment` 단독은 후속 buffer assignment에서 에러 가능 → 전체 덤프 권장. 출처: openxla/xla issue #12850) + +--- + +## 3. 타일 크기 규칙 [확정] + +| 조건 | 타일 | 비고 | +|---|---|---| +| f32, 일반 | `T(8,128)` | 32-bit 8×128 벡터 레지스터에 대응 | +| bf16 | `T(8,128)(2,1)` | 2단계 타일링 = BF16 packing. 짝/홀수 행 16-bit 둘을 묶어 32-bit 하나로 | +| 2nd-minor 차원 = 1 or 2 | `T(2,128)` | "Compact 2nd Minor Layout" — 메모리 절약 | +| 2nd-minor 차원 = 3 or 4 | `T(4,128)` | 동일 목적 | + +- **중요 정정**: sublane 패딩이 항상 8은 아니다. 작은 2nd-minor 차원이면 2 또는 4로 줄어든다. + - 함의: **LLM 디코딩 token=1의 sublane 패딩은 8배가 아니라 2배** (`T(2,128)`). +- bf16 packing 이유: TPU는 32-bit 네이티브. most-minor보다 2nd-minor 가로지르는 데이터 이동이 효율적이라 같은 column에서 16-bit 둘을 모음. +- 출처: OpenXLA tiled_layout 문서, gdymind 블로그. + +### 3.1 MXU 크기 (세대별, 연산 단위 — 레이아웃 타일과 별개) [확정] +- v6e, TPU7x(Ironwood): **256×256** +- v6e 이전: **128×128** +- peak FLOPs 위해선 matmul 차원이 해당 세대 MXU 크기보다 커야 함. +- ⚠️ 이건 *연산* 단위. *메모리 레이아웃* 타일은 세대 무관 8×128 유지. 둘을 섞지 말 것. +- 출처: Google Cloud TPU performance guide. + +--- + +## 4. 패딩 처리: 두 층위 [핵심 — 확정] + +### 4.1 (A) 8×128 레이아웃 정렬 패딩 = 강제 물질화 +- **실제 텐서가 HBM에서 패딩된 크기로 저장됨.** 회피 불가. +- 이유: 타일이 HBM에 연속으로 깔리려면 각 타일이 꽉 찬 8×128이어야 함 → 주소 정렬(address alignment) 필수. +- MLIR 코드생성 일반론에서 이는 `nofold` 패딩에 해당: "address alignment가 필수인 경우 강제로 패딩". value padding이 불필요해 보여도 정렬 때문에 fold되지 않고 물질화됨. +- Google 공식 확인: 128×8 청크를 못 채우면 XLA가 텐서를 패딩하고, 이는 "on-chip 메모리 저장량을 늘리고 OOM 유발 가능" = 물리적 공간 점유. +- 패딩량 = `⌈d/tile⌉ × tile − d` (lane/sublane 각각). +- 출처: OpenXLA, Google Cloud performance guide, MLIR codegen 논문(arxiv 2202.03293). + +### 4.2 (B) 연산 블록 크기(>8×128) 경계 = tail 처리 +컴파일러가 비용 보고 세 전략 중 선택 (MLIR codegen 논문 §3.2): + +1. **Loop peeling / versioning** (비물질화): main loop는 정적 상수 부분, 경계는 cleanup loop(타일=1로 축소). 텐서 안 키움. +2. **실제 패딩** (물질화): 동적 타일을 정적 크기로 패딩, 값은 소비 연산의 **neutral(항등원)** — matmul이면 0. `tensor.pad` op으로 물질화, 크기 = 정적 타일 − 동적 타일. 추가 복사 비용 발생. +3. **명시적 masking** (비물질화): MXU가 0 포함 전체 계산 후 출력에서 마스킹. "MXU엔 '하지 마라' 신호가 없어 패딩 0을 실데이터와 곱하고 출력에서 마스킹 — 정확하나 느림(버려지는 work에 MXU 비용 지불)". + +- 기본은 비물질화(masking/peeling)가 흔함. pad 물질화는 정렬 필수 케이스. +- 출처: MLIR codegen 논문, Pallas matmul 튜토리얼(neuropurrfectai). + +### 4.3 두 층위의 관계 [추론 — 합리적] +- 연산 블록 tail이 8×128 정렬과도 안 맞으면: 이미 물질화된 레이아웃 패딩(A)을 그대로 읽고, 블록 크기와 정렬 사이의 간격(B의 순수분)만 masking/peeling. +- 즉 "레이아웃 패딩은 물리적으로 이미 존재, 그걸 넘어서는 블록 경계분만 tail 전략"으로 두 층이 포개짐. + +--- + +## 5. TPU 특유 제약: tiled 축 경계가 비싼 이유 [확정] + +- Ragged Paged Attention 논문: 논리/물리 레이아웃 불일치 + narrow dtype packing 때문에 **tiled 차원(특히 lane)에서 임의 메모리 슬라이스가 근본적으로 어렵다** (VREG blending 없이 ragged 입력을 메모리에 직접 쓰는 경우 특히). +- 실전 해법: **ragged/동적 차원을 non-tiled(major) 축에 배치**, packing을 2nd-minor에 삽입해 XLA가 최소 타일 `T(packing,128)`을 쓰도록 강제 → 임의 동적 슬라이스 가능. +- Pallas `pl.BoundedSlice` / `pl.ds`로 동적 크기 청크 처리 가능 (패딩 대신 정확 크기). +- 함의: 컴파일러/커널이 작은·동적 축을 lane에서 빼려 애쓰는 이유가 정량적으로 설명됨. +- 출처: Ragged Paged Attention (arxiv 2604.15464), JAX Pallas pipelining 문서. + +--- + +## 6. LLM 디코딩 특이사항 [확정 + 추론] + +- 디코딩 GEMV(`[1,hidden]×[hidden,out]`)에서 token=1 차원: + - 보통 hidden이 lane(128)에, token=1은 sublane으로 → **sublane 2배 패딩**(`T(2,128)`), lane까지 가지 않음. [추론, §3 규칙 기반] +- **activation 패딩 traffic은 디코딩에서 무시 가능**: activation(`[1,hidden]`, 수 KB)은 weight/KV cache(수십~수백 GB)보다 3~5 자릿수 작음. 8배든 2배든 전체 traffic에서 미미. [확정 — 디코딩 memory-bound 특성] +- 디코딩 traffic 모델은 **weight 전체 재읽기 + KV cache 읽기에 집중**할 것. activation 패딩 오차의 영향은 작음. +- 패딩의 실질 페널티는 traffic이 아니라 MXU utilization 저하인데, memory-bound regime에선 wall-clock 비결정적. +- → self-spec decoding 연구 동기(한 번 읽은 weight당 토큰 더 뽑기)와 직결. + +--- + +## 7. PyTorchSim 모델링 권고 [실행 항목] + +### 7.1 두 비용 함수를 분리하라 (가장 중요) +- **footprint / HBM traffic 함수**: (A) 8×128 레이아웃 정렬 패딩만 물리 크기로 계산. 예: 길이 100 → 128로 저장·전송. bf16 packing, small-tile(2/4×128) 변형 반영. +- **compute utilization 함수**: (B) 연산 블록 경계 패딩 처리. + - 기본 masking: MXU 사이클 낭비로 utilization ↓, **traffic 중복 계상 금지**. + - pad 물질화 케이스(정렬 필수)만 추가 복사 traffic. + - peeling 케이스는 작은 cleanup 커널로 별도. +- **함정**: 연산 경계 패딩을 traffic으로 또 더하면 대역폭 과대평가. 물리 패딩(A)만 traffic+compute, 연산 경계(B)는 compute만. + +### 7.2 layout 결정 로직 +- minor_to_major 축 선택 + 타일 크기 선택을 §2~§3 규칙으로 모델링. +- matmul anchor → 전파 → reduce는 lane에서 빼기 선호. +- 세대별 MXU 크기(128 vs 256)를 파라미터화. + +### 7.3 비대칭 반영 +- tiled 축(lane) 경계 처리 비용 > non-tiled 축. 이 비대칭을 넣으면 실제 커널이 동적 축을 major로 미는 동작 재현됨(§5). + +### 7.4 검증 +- §2.3의 `XLA_FLAGS` 덤프로 실제 `:T(...)` 어노테이션 떠서 모델 출력과 대조. +- 가능하면 LLO 덤프까지 떠서 경계 타일이 `tensor.pad`(물질화) vs masking/peeling 중 무엇으로 처리되는지 확인. + +--- + +## 8. 미확인 / 후속 조사 필요 [미확인] + +1. **Mosaic의 tail 전략 선택 휴리스틱**: peeling vs pad vs masking을 언제 고르는지의 내부 규칙. Mosaic이 상당 부분 비공개(Google 내부 컴파일러)라 공개 자료로 닿지 못함. → LLO 덤프 실측이 유일한 확실한 길. +2. **연산 블록 경계 0의 정확한 주입 위치**: VREG / VMEM / MXU 입력 중 어디서 0이 주입되고 VMEM 점유에 잡히는지. → Mosaic 소스 또는 LLO 덤프 필요. +3. **layout_assignment.cc의 minor_to_major 선택 휴리스틱 코드 레벨**: matmul이 정확히 어떤 layout을 강제하고 어떻게 전파하는지. OpenXLA 공개 소스(layout_assignment.cc, instruction_fusion.cc)에서 추적 가능 — 아직 코드 레벨로 파지 않음. +4. **DMA의 strided/partial transfer가 레이아웃 패딩을 정확히 어떻게 처리하는지** (세대별): v4부터 512B granularity striding 지원은 확인됨. 패딩 영역 전송 회피 가능 여부의 정밀 동작은 세대·XLA 버전 의존. + +--- + +## 9. 출처 목록 + +- OpenXLA — Tiled layout: https://openxla.org/xla/tiled_layout +- OpenXLA — Shapes and layout: https://openxla.org/xla/shapes +- Google Cloud — TPU performance guide: https://docs.cloud.google.com/tpu/docs/performance-guide +- Google Cloud — Intro to Cloud TPU: https://docs.cloud.google.com/tpu/docs/intro-to-tpu +- Google Cloud — TPU v4: https://docs.cloud.google.com/tpu/docs/v4 +- From JAX to VLIW (컴파일러 추적, layout assignment 전후 HLO): https://patricktoulme.substack.com/p/from-jax-to-vliw-tracing-a-computation +- Pallas matmul 튜토리얼 (MXU 마스킹, tail 패딩): https://neuropurrfectai.substack.com/p/part-2-your-first-pallas-kernel-tiled +- Composable/Modular Code Generation in MLIR (경계 타일 3전략, nofold): https://arxiv.org/pdf/2202.03293 +- Ragged Paged Attention (tiled 축 슬라이스 제약): https://arxiv.org/html/2604.15464 +- JAX — TPU pipelining (동적 슬라이스): https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html +- openxla/xla issue #12850 (pass 덤프 방법): https://github.com/openxla/xla/issues/12850 +- gdymind 블로그 (small-tile, packing 정리): https://gdymind.com/2026/02/26/XLA02-shapes-layout-tiling/ From af79689c50f0d87286bb9e84bc3e77ad70490335 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 18 Jun 2026 14:25:19 +0900 Subject: [PATCH 16/20] [Frontend] Fix floor/mod guard + decompose unit-dim/vlane-axis edge cases Three fixes from the max-effort review of this branch: - get_dma_info: after retiring the floor/mod recompile branches, a residual floor/mod (store-side ModularIndexing, reduction-axis floor/mod, incompatible radix) that axis-split/graph-copy did not linearize was silently bucketed by its base symbol in the dram_stride loop, emitting a wrong DRAM descriptor. Raise NotImplementedError instead of mis-striding silently. No test triggers it (0 floor/mod reach get_dma_info in the suite) -- it is a safety net. - decompose_transfer collapse fast path: keep=[g[-1]] picked the last dim of each reassociation group, which is a unit dim when trailing unit dims attach after the non-unit one (e.g. [..,4,1,1]); strides/subtile were read from the wrong axis. Pick the non-unit dim in each group. - decompose_transfer >4D peel: new_vlane fell back to 0 whenever the vlane split axis was not among the inner 4 dims, conflating peeled-into-the-outer-loop (genuinely unrepresentable -> raise) with a unit lane axis (default 0 is fine). Validated: elementwise, gemm, conv2d, cat, floor/mod suite (incl. pixel_shuffle >4D peel), softmax, layernorm, batchnorm -- all pass, no spurious raise. Co-Authored-By: Claude Opus 4.8 --- PyTorchSimFrontend/mlir/mlir_codegen_backend.py | 11 +++++++++++ .../mlir/passes/decompose_transfer.py | 15 +++++++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 6529d8d9..725e0dc6 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1172,6 +1172,17 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe index = index.subs({s: 0 for s in indirect_syms}, simultaneous=True) indirect_dims = [f"{i}" for i in indirect_syms] + # axis-split + graph-copy linearize aligned floor/mod upstream. Anything that + # reaches here still carrying floor/mod (store-side ModularIndexing, + # reduction-axis floor/mod, incompatible-radix views) would be silently + # mis-strided in the dram_stride computation below, so fail loudly instead. + if index.has(FloorDiv) or index.has(ModularIndexing): + raise NotImplementedError( + f"Unlinearized floor/mod in DMA index: {index}. axis-split/graph-copy " + f"did not eliminate it; this view is unsupported " + f"(see docs/axis-split-scheduling.md)." + ) + # Reduction can have two type of tile size if broadcast and (total_dims != local_dims or (self.reduction_depth!=len(total_dims) and total_dims[:self.reduction_depth] == local_dims)): local_dims = total_dims # Brodatcast tile shape diff --git a/PyTorchSimFrontend/mlir/passes/decompose_transfer.py b/PyTorchSimFrontend/mlir/passes/decompose_transfer.py index 29841f85..c0e82b66 100644 --- a/PyTorchSimFrontend/mlir/passes/decompose_transfer.py +++ b/PyTorchSimFrontend/mlir/passes/decompose_transfer.py @@ -159,7 +159,9 @@ def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr, st_at reassoc = ArrayAttr.get( [ArrayAttr.get([IntegerAttr.get(i64, d) for d in g]) for g in groups]) collapsed_ty = MemRefType.get(target, elem, memory_space=space) - keep = [g[-1] for g in groups] # the non-unit dim in each group + # the non-unit dim in each group (g[-1] is wrong when trailing unit dims + # attach after it, e.g. [..,4,1,1] -> the kept dim must be the extent-4 one). + keep = [next((d for d in g if tile_shape[d] > 1), g[-1]) for g in groups] dr_attr = ArrayAttr.get([IntegerAttr.get(i64, dram_stride[i]) for i in keep]) tl_attr = ArrayAttr.get([IntegerAttr.get(i64, tile_stride[i]) for i in keep]) st_attr = (ArrayAttr.get([IntegerAttr.get(i64, subtile[i]) for i in keep]) @@ -198,7 +200,16 @@ def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr, st_at st_attr = (ArrayAttr.get([IntegerAttr.get(i64, subtile[d]) for d in inner]) if subtile is not None else None) # the vlane axis must survive into the inner descriptor (it is the lane dim). - new_vlane = inner.index(vlane_axis) if vlane_axis in inner else 0 + if vlane_axis in inner: + new_vlane = inner.index(vlane_axis) + elif vlane_axis in peeled: + # lane dim peeled into the outer loop nest: it cannot be expressed in the + # <=4D descriptor, and _phys's lane-banking assumes it is a real axis. + raise NotImplementedError( + f"vlane split axis {vlane_axis} peeled into the outer loop nest; " + f">4D DMA peel cannot place the lane dim in the <=4D descriptor") + else: + new_vlane = 0 # vlane axis is a unit dim; lane on the first inner dim # Lane-banked physical stride for split-outer dims (vlane_stride defaults to 1). split_extent = tile_shape[vlane_axis] From a6ef5f8dba772534c994a230071007e469905529 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 18 Jun 2026 14:53:34 +0900 Subject: [PATCH 17/20] [Frontend] graph-copy per-dim ranges + vcix C++-parity guards (review) More fixes from the max-effort review, after verifying each against the C++ reference and reachability: - graph_copy _relayout_args: ranges picked the consumer iteration shape by rank alone (max key=len), so for two equal-rank operands with different per-dim extents the broadcast-from operand's smaller shape could win and the real incompatible-radix conflict on the broadcast-to dim was missed (order-dependent: a commutative reorder flipped correct relayout into a silent miss). Use per-dim max extent over the max-rank operands. - lower_to_vcix _sew/_legalize_vector_type: mirror the C++ legalizeVectorType -- F16/BF16 return sew 0 (transcendentals stay unlowered for -convert-math-to-llvm, as in the validated path) instead of being lowered to VCIX, and add the missing rank != 1 guard. - lower_to_vcix matmul: port the C++ guards as loud failures -- M/N/K must be a multiple of the systolic size when > SS (else the N//SS / K//SS loops drop the tail tile), and A vs B must agree on the K subtile (last-writer-wins would pick one silently). Latent today (heuristic/autotune only emit SS-multiple tiles). - Doc-only: graph-copy is default-on (TORCHSIM_GRAPH_COPY=0 to disable); fixed the two stale 'no-op unless set' comments. Validated: elementwise, gemm, bmm, conv2d, group_conv, cat, floor/mod suite, softmax, layernorm -- all pass. Co-Authored-By: Claude Opus 4.8 --- PyTorchSimFrontend/mlir/graph_copy.py | 13 +++++++++--- PyTorchSimFrontend/mlir/mlir_scheduling.py | 2 +- .../mlir/passes/lower_to_vcix.py | 20 +++++++++++++++++-- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/PyTorchSimFrontend/mlir/graph_copy.py b/PyTorchSimFrontend/mlir/graph_copy.py index 51c2e9b6..c58fab59 100644 --- a/PyTorchSimFrontend/mlir/graph_copy.py +++ b/PyTorchSimFrontend/mlir/graph_copy.py @@ -14,8 +14,8 @@ realize() (not a clone, which Inductor inlines) is what actually forces the buffer boundary; see the PoC notes in docs. -Gated by TORCHSIM_GRAPH_COPY (install() is a no-op otherwise). Behavior-neutral -unless a genuine incompatible-radix conflict is detected. +Default-on; set TORCHSIM_GRAPH_COPY=0 to disable (install() is then a no-op). +Behavior-neutral unless a genuine incompatible-radix conflict is detected. """ import os from torch._inductor import lowering as L @@ -64,7 +64,14 @@ def _relayout_args(args): # multi-var-view input) this is just that operand's shape -- still enough to # detect a multi-var floor and copy_input it (case 7); the 2-operand radix # conflict (case 5) naturally needs >=2 operands. - ranges = max((t.get_size() for t in tbs), key=len) + # Per-dim max extent over the max-rank operands (order-independent). Picking a + # single operand by rank alone (max key=len) would, for two equal-rank operands + # with different per-dim extents, take a broadcast-from operand's smaller shape + # and then miss the genuine conflict on the broadcast-to dim. + maxrank = max(len(t.get_size()) for t in tbs) + full = [t.get_size() for t in tbs if len(t.get_size()) == maxrank] + ranges = [max((s[d] for s in full), key=lambda v: (axis_split._as_int(v) or -1)) + for d in range(maxrank)] extents = [axis_split._as_int(s) for s in ranges] dbg = os.environ.get("TORCHSIM_GRAPH_COPY_DEBUG") if dbg: diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 48eead47..c78d4c53 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -394,6 +394,6 @@ def get_order(n): # Install the graph-copy (incompatible-radix relayout) lowering hook once at import. -# No-op unless TORCHSIM_GRAPH_COPY is set; see graph_copy.py. +# Default-on; set TORCHSIM_GRAPH_COPY=0 to disable. See graph_copy.py. from . import graph_copy as _graph_copy _graph_copy.install() diff --git a/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py b/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py index 4286ba19..1aa31e96 100644 --- a/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py +++ b/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py @@ -40,8 +40,9 @@ def _sew(elt_ty): - if ir.F16Type.isinstance(elt_ty) or ir.BF16Type.isinstance(elt_ty): - return 16 + # Mirror C++ legalizeVectorType: only F32/F64/integer/index get a sew. F16/BF16 + # return 0 so transcendental math ops stay unlowered (-convert-math-to-llvm), + # matching the validated path -- do NOT emit VCIX for them here. if ir.F32Type.isinstance(elt_ty): return 32 if ir.F64Type.isinstance(elt_ty): @@ -59,6 +60,8 @@ def _log2(x): def _legalize_vector_type(vt, vlen): """Mirror legalizeVectorType: return (n, legal_vector_type) or (0, None).""" + if len(vt.shape) != 1: # C++ guards getRank() != 1 + return 0, None elt_ty = vt.element_type sew = _sew(elt_ty) if sew == 0: @@ -337,6 +340,12 @@ def _lower_matmul(op, SS, vlen): mtA, mtB = ir.MemRefType(A.type), ir.MemRefType(B.type) elt = mtA.element_type M, K, N = mtA.shape[0], mtA.shape[1], mtB.shape[1] + # Mirror the C++ guard: a dimension > SS must be an exact multiple, else the + # N//SS / K//SS loop trip counts below silently drop the tail tile. + for _dim, _name in ((M, "M"), (N, "N"), (K, "K")): + if _dim > SS and _dim % SS != 0: + raise NotImplementedError( + f"matmul {_name}={_dim} must be a multiple of systolic size {SS} when > {SS}") elen = _elt_bits(elt) nr_element = vlen // elen i64 = ir.IntegerType.get_signless(64) @@ -364,6 +373,7 @@ def a64(v): return ir.IntegerAttr.get(i64, v) AAsync = BAsync = BiasAsync = 0 BiasIdx = None subtileM, subtileN, subtileK = M, N, K + a_subk = b_subk = None for o in _iter_ops(outer[-1].regions[0].blocks[0]): if o.operation.name != "memref.dma_start": continue @@ -382,15 +392,21 @@ def a64(v): return ir.IntegerAttr.get(i64, v) ATag, AAsync = d.tag, d.is_async() if len(sub) >= 2: subtileM, subtileK = sub[-2], sub[-1] + a_subk = sub[-1] elif argn == idxMap[1]: BTag, BAsync = d.tag, d.is_async() if len(sub) >= 2: subtileK, subtileN = sub[-2], sub[-1] + b_subk = sub[-2] elif argn == idxMap[2]: BiasTag, BiasAsync = d.tag, d.is_async() BiasIdx = d.tag_idx if ATag is None or BTag is None: return False + # A and B must agree on the K subtile (last-writer-wins would otherwise pick one silently). + if a_subk is not None and b_subk is not None and a_subk != b_subk: + raise NotImplementedError( + f"Mismatched subtile K between A ({a_subk}) and B ({b_subk}) matmul operands") KStep = subtileK push_length = min(subtileM, SS) From 417e4f20494ee74704bbff1e388dffb557306a83 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 18 Jun 2026 16:10:40 +0900 Subject: [PATCH 18/20] [Frontend] vcix: lower fused matmul whose operand is vector_store'd, fix exp chunk Two fixes to the C++->Python vcix port (lower_to_vcix.py) that SDPA exercises but the gemm/bmm/conv tests do not: - _lower_matmul bailed with 'if ATag is None or BTag is None: return False', gating on an MVIN dma_start tag for both operands. In SDPA's fused scores.V matmul, operand B is the softmax output produced in place by affine.vector_store, not DMAed, so BTag stayed None and the matmul was left un-lowered -> wrong attention output. Mirror the C++ MatmulOpLowering: an operand is initialized by either a dma_start OR a preceding affine.vector_store into its root memref; bail only when an operand is truly uninitialized. BTag/BAsync stay None/0 and are only read under 'if BAsync:', so the B dma_wait is correctly skipped (as in C++). - _make_sf_vc_v_iv n>1 transcendental chunking called vector.ExtractStridedSliceOp(offsets, sizes, strides, vec) -- wrong arg order, missing the result type and vector operand, raising TypeError under these MLIR bindings. Pass (result=legal_ty, vector=vec, offsets, sizes, strides). Only reached by large transcendentals (n>1), e.g. SDPA softmax exp, so CI's small-tile (n==1) tests never hit it. Validated end-to-end (Spike+TOGSim allclose): SDPA 56 cases pass (was crash/wrong); matmul/bmm/conv2d regress clean. Bisected: C++ vcix passes SDPA, Python vcix did not; exp chunking and fine-grained ruled out separately. Co-Authored-By: Claude Opus 4.8 --- .../mlir/passes/lower_to_vcix.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py b/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py index 1aa31e96..404a8ab4 100644 --- a/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py +++ b/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py @@ -116,9 +116,10 @@ def _make_sf_vc_v_iv(vec, op_vt, n, legal_ty, opcode, imm): else: for i in range(total // elt_count): ext = vector.ExtractStridedSliceOp( + legal_ty, vec, ir.ArrayAttr.get([_i64(i * elt_count)]), ir.ArrayAttr.get([_i64(elt_count)]), - ir.ArrayAttr.get([_i64(1)]), vec).result + ir.ArrayAttr.get([_i64(1)])).result v = _viv(ext, legal_ty, opcode, imm, rvl) res = vector.InsertStridedSliceOp( v, res, ir.ArrayAttr.get([_i64(i * elt_count)]), @@ -374,7 +375,28 @@ def a64(v): return ir.IntegerAttr.get(i64, v) BiasIdx = None subtileM, subtileN, subtileK = M, N, K a_subk = b_subk = None + # Mirror the C++ isAInitialized / isBInitialized flags: an operand is + # "initialized" either by an MVIN dma_start (tag found below) or by a + # preceding affine.vector_store into its root memref (the fused case, e.g. + # SDPA scores.V where B is the softmax output produced in-place, not DMAed). + isAInit = isBInit = False + + def _root(v): + owner = v.owner + if not isinstance(owner, ir.Block): + nm = owner.name + if nm in ("memref.reinterpret_cast", "memref.cast"): + return owner.operands[0] + return v + rootA, rootB = _root(A), _root(B) for o in _iter_ops(outer[-1].regions[0].blocks[0]): + if o.operation.name == "affine.vector_store": + dest = _root(o.operation.operands[1]) + if dest == rootA: + isAInit = True + elif dest == rootB: + isBInit = True + continue if o.operation.name != "memref.dma_start": continue d = _DmaView(o.operation) @@ -390,18 +412,20 @@ def a64(v): return ir.IntegerAttr.get(i64, v) sub = d.subtile_size() if argn == idxMap[0]: ATag, AAsync = d.tag, d.is_async() + isAInit = True if len(sub) >= 2: subtileM, subtileK = sub[-2], sub[-1] a_subk = sub[-1] elif argn == idxMap[1]: BTag, BAsync = d.tag, d.is_async() + isBInit = True if len(sub) >= 2: subtileK, subtileN = sub[-2], sub[-1] b_subk = sub[-2] elif argn == idxMap[2]: BiasTag, BiasAsync = d.tag, d.is_async() BiasIdx = d.tag_idx - if ATag is None or BTag is None: + if not isAInit or not isBInit: return False # A and B must agree on the K subtile (last-writer-wins would otherwise pick one silently). if a_subk is not None and b_subk is not None and a_subk != b_subk: From 716654b0f086b2c6904c227a9a111344e4084e0a Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 18 Jun 2026 19:25:54 +0900 Subject: [PATCH 19/20] [Frontend] floor/mod: make axis-split + graph-copy always-on, drop debug env vars axis-split and graph-copy are the floor/mod handling path and were default-on but still gated behind TORCHSIM_AXIS_SPLIT / TORCHSIM_GRAPH_COPY. Remove the gates so they run unconditionally, and delete the env vars that were only introduced for validation/debug during development: TORCHSIM_AXIS_SPLIT, TORCHSIM_GRAPH_COPY - default-on toggles TORCHSIM_AXIS_SPLIT_FORCE - force-split validation aid TORCHSIM_AXIS_LEDGER + axis_split.ledger() - coverage measurement TORCHSIM_DEBUG_AXIS_SPLIT + _dump_axis() - debug dump TORCHSIM_GRAPH_COPY_DEBUG - graph-copy debug prints TORCHSIM_RECOMPILE_LOG - vestigial recompile log Also drop the now-dead ledger() function, the _dump_axis() helper, and the unused os import in graph_copy.py. The floor/mod regression test no longer sets the removed env vars. Behavior is unchanged (the toggles were already on). Co-Authored-By: Claude Opus 4.8 --- PyTorchSimFrontend/mlir/axis_split.py | 63 ---------------------- PyTorchSimFrontend/mlir/graph_copy.py | 23 ++------ PyTorchSimFrontend/mlir/mlir_common.py | 7 +-- PyTorchSimFrontend/mlir/mlir_scheduling.py | 47 ++++------------ tests/ops/view/test_floormod_axis_split.py | 14 ++--- 5 files changed, 21 insertions(+), 133 deletions(-) diff --git a/PyTorchSimFrontend/mlir/axis_split.py b/PyTorchSimFrontend/mlir/axis_split.py index a8253e02..71ec4809 100644 --- a/PyTorchSimFrontend/mlir/axis_split.py +++ b/PyTorchSimFrontend/mlir/axis_split.py @@ -66,53 +66,6 @@ def _is_chain(boundaries, E): return all(chain[i + 1] % chain[i] == 0 for i in range(len(chain) - 1)) -def ledger(nodes, plan): - """Classify every FloorDiv/ModularIndexing in the kernel against `plan`. - - Returns a list of (op_name, reason, term_str) for the terms NOT covered by - axis-split, so we can measure how often the graph-copy cases (incompatible - radix / non-dividing / multi-axis / dynamic) actually reach codegen. Read-only. - Reasons: covered terms are omitted; uncovered ones are - multi_axis_arg - floor/mod argument is not a single iter var (case 7) - non_dividing - divisor (or k*m) does not divide the extent (case 6) - incompatible_radix - single var, divides, but boundaries did not form a - divisibility chain so the axis was left unsplit (case 5) - dynamic - symbolic divisor/extent - """ - rows = [] - - def classify(base, k, m, var_to_axis, var_ranges): - if not (isinstance(base, sympy.Symbol) and base in var_to_axis): - return None if False else "multi_axis_arg" - ax = var_to_axis[base] - E = _as_int(var_ranges.get(base)) - if k is None or E is None or (m is not None and _as_int(m) is None): - return "dynamic" - if ax in plan: - return "covered" - period = k if m is None else k * _as_int(m) - if period and E % period != 0: - return "non_dividing" - return "incompatible_radix" - - for n in nodes: - body = getattr(n, "_body", None) - if body is None: - continue - op = n.get_name() if hasattr(n, "get_name") else "?" - var_to_axis = {v: i for i, v in enumerate(body.iter_vars)} - for expr in body.indexing_exprs.values(): - for fd in expr.atoms(FloorDiv): - r = classify(fd.args[0], _as_int(fd.args[1]), None, var_to_axis, body.var_ranges) - if r and r != "covered": - rows.append((op, r, str(fd))) - for mi in expr.atoms(ModularIndexing): - r = classify(mi.args[0], _as_int(mi.args[1]), mi.args[2], var_to_axis, body.var_ranges) - if r and r != "covered": - rows.append((op, r, str(mi))) - return rows - - def find_split_plan(nodes): """Inspect a group of scheduler nodes and return {axis_index: boundaries}. @@ -151,22 +104,6 @@ def find_split_plan(nodes): if E and any(1 < b < E for b in bs) and _is_chain(bs, E): plan[ax] = [1] + sorted(b for b in bs if 1 < b < E) + [E] - # Validation aid: force-split the first even index axis even without floor/mod. - # A floor-free index split is an identity transformation, so allclose must hold; - # used to exercise the reduction pass-through path (no natural op produces a - # floor on a reduction kernel's index axis). Off unless TORCHSIM_AXIS_SPLIT_FORCE. - import os as _os - if _os.environ.get("TORCHSIM_AXIS_SPLIT_FORCE"): - for n in nodes: - body = getattr(n, "_body", None) - if body is None or not body.reduce_vars: - continue - for ax, v in enumerate(body.iter_vars): - E = _as_int(body.var_ranges.get(v)) - if ax not in plan and E and E % 2 == 0 and E > 2: - plan[ax] = [1, 2, E] - break - # A split may push the per-axis index rank past 4. The resulting >4D logical tile # is peeled into <=4D physical descriptors by the decompose-transfer pass (an # affine.for nest carrying the lane-banked physical SRAM offset), so there is no diff --git a/PyTorchSimFrontend/mlir/graph_copy.py b/PyTorchSimFrontend/mlir/graph_copy.py index c58fab59..0c49b86f 100644 --- a/PyTorchSimFrontend/mlir/graph_copy.py +++ b/PyTorchSimFrontend/mlir/graph_copy.py @@ -14,10 +14,8 @@ realize() (not a clone, which Inductor inlines) is what actually forces the buffer boundary; see the PoC notes in docs. -Default-on; set TORCHSIM_GRAPH_COPY=0 to disable (install() is then a no-op). Behavior-neutral unless a genuine incompatible-radix conflict is detected. """ -import os from torch._inductor import lowering as L from torch._inductor import dependencies from torch._inductor import ir @@ -73,10 +71,6 @@ def _relayout_args(args): ranges = [max((s[d] for s in full), key=lambda v: (axis_split._as_int(v) or -1)) for d in range(maxrank)] extents = [axis_split._as_int(s) for s in ranges] - dbg = os.environ.get("TORCHSIM_GRAPH_COPY_DEBUG") - if dbg: - print(f"[GC] consumer ntbs={len(tbs)} ranges={extents} " - f"sizes={[[axis_split._as_int(s) for s in t.get_size()] for t in tbs]}") if not extents or any(e is None for e in extents): return None # scalar / dynamic -> skip @@ -100,9 +94,7 @@ def _relayout_args(args): for tb in tbs: try: rw = dependencies.extract_read_writes(tb.make_loader(), list(ranges)) - except Exception as e: - if dbg: - print(f"[GC] extract fail {type(e).__name__}: {repr(e)[:60]}") + except Exception: per_bnd.append({}) per_mv.append(False) continue @@ -110,8 +102,6 @@ def _relayout_args(args): exprs = [r.index for r in rw.reads if hasattr(r, "index")] b = axis_split.collect_boundaries(exprs, v2a, rw.var_ranges) mv = _has_multivar_floormod(exprs) - if dbg: - print(f"[GC] operand reads={[str(e) for e in exprs]} boundaries={dict(b)} multivar={mv}") per_bnd.append(b) per_mv.append(mv) @@ -141,18 +131,13 @@ def _relayout_args(args): new = list(args) p = pos[victim] new[p] = ir.ExternKernel.copy_input(args[p]) - if dbg: - print(f"[GC] relayout: copy_input operand #{victim} (arg {p})") return new def install(): - """Wrap registered lowering entries to insert relayout. Idempotent; ON by - default (set TORCHSIM_GRAPH_COPY=0 to disable). Call once at backend import - (after torch._inductor.lowering is populated -- make_pointwise runs at import - to build the entries, so we wrap the entries, not the factory).""" - if os.environ.get("TORCHSIM_GRAPH_COPY", "1") == "0": - return + """Wrap registered lowering entries to insert relayout. Idempotent. Call once + at backend import (after torch._inductor.lowering is populated -- make_pointwise + runs at import to build the entries, so we wrap the entries, not the factory).""" if getattr(L, "_torchsim_relayout_installed", False): return for key, fn in list(L.lowerings.items()): diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 748c389c..a70d1c7d 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -737,13 +737,8 @@ def codegen_nodes(self, nodes, kernel_name): with self as kernel: for node in nodes: node.run(vars, reduction_vars) - except RecompileSignal as e: + except RecompileSignal: recompile_try += 1 - # Measure what still depends on the recompile-dance once axis-split + - # graph-copy are on by default (set TORCHSIM_RECOMPILE_LOG=1). - if os.environ.get("TORCHSIM_RECOMPILE_LOG"): - import sys as _sys - print(f"[RECOMPILE {recompile_try}/{max_retry_compile}] {e}", file=_sys.stderr) if recompile_try > max_retry_compile: raise RuntimeError("Failed to compile kernel after multiple attempts.") # Retry compile nodes diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index c78d4c53..41ec61af 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -249,43 +249,18 @@ def codegen_node(self, _node): nodes, key=lambda x: int(x.is_reduction()) ).group - def _dump_axis(tag): - import sys as _sys - print(f"\n[AXIS_SPLIT:{tag}] group={group} reduction_group={reduction_group}", file=_sys.stderr) + # axis-split: linearize compatible floor/mod radices at the scheduling layer. + from . import axis_split + plan = axis_split.find_split_plan(nodes) + if plan: for _n in nodes: - _body = getattr(_n, "_body", None) - if _body is None: + if getattr(_n, "_body", None) is None: continue - print(f"[AXIS_SPLIT:{tag}] node={_n.get_name()} var_ranges={getattr(_body, 'var_ranges', None)}", file=_sys.stderr) - for _k, _e in getattr(_body, "indexing_exprs", {}).items(): - print(f"[AXIS_SPLIT:{tag}] idx[{_k}] = {_e}", file=_sys.stderr) - - if os.environ.get("TORCHSIM_DEBUG_AXIS_SPLIT"): - _dump_axis("before") - - if os.environ.get("TORCHSIM_AXIS_LEDGER"): - from . import axis_split - import sys as _sys - _plan = axis_split.find_split_plan(nodes) - for _op, _reason, _term in axis_split.ledger(nodes, _plan): - print(f"[AXIS_LEDGER] op={_op} reason={_reason} term={_term}", file=_sys.stderr) - - # axis-split is ON by default; set TORCHSIM_AXIS_SPLIT=0 to disable. - if os.environ.get("TORCHSIM_AXIS_SPLIT", "1") != "0": - from . import axis_split - plan = axis_split.find_split_plan(nodes) - if plan: - for _n in nodes: - if getattr(_n, "_body", None) is None: - continue - _body, _ranges = axis_split.build_split_body(_n, plan) - _n._sizes, _n._body, _n.group = _ranges, _body, (_n.get_device(), self.group_fn(_ranges)) - _, (group, reduction_group) = max( - nodes, key=lambda x: int(x.is_reduction()) - ).group - if os.environ.get("TORCHSIM_DEBUG_AXIS_SPLIT"): - print(f"[AXIS_SPLIT] applied plan={plan}", file=__import__("sys").stderr) - _dump_axis("after") + _body, _ranges = axis_split.build_split_body(_n, plan) + _n._sizes, _n._body, _n.group = _ranges, _body, (_n.get_device(), self.group_fn(_ranges)) + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group # Note: We assume that there is at least one loop in the nodes # But, inductor simplifies the group, there could be no loop @@ -394,6 +369,6 @@ def get_order(n): # Install the graph-copy (incompatible-radix relayout) lowering hook once at import. -# Default-on; set TORCHSIM_GRAPH_COPY=0 to disable. See graph_copy.py. +# See graph_copy.py. from . import graph_copy as _graph_copy _graph_copy.install() diff --git a/tests/ops/view/test_floormod_axis_split.py b/tests/ops/view/test_floormod_axis_split.py index 365ee788..19e32e69 100644 --- a/tests/ops/view/test_floormod_axis_split.py +++ b/tests/ops/view/test_floormod_axis_split.py @@ -4,20 +4,18 @@ how the frontend handles them: - aligned floor/mod (single iter var, divisor divides extent): removed by - axis-split at the scheduling layer (TORCHSIM_AXIS_SPLIT). group_norm, repeat, - repeat_interleave, permute+reshape (mixed-radix). + axis-split at the scheduling layer. group_norm, repeat, repeat_interleave, + permute+reshape (mixed-radix). - incompatible radices on a shared axis (case 5, e.g. a[c//2] + b[c%3]): the - conflicting operand is realized by graph-copy (TORCHSIM_GRAPH_COPY) so the - consumer reads it affine and the remainder is axis-split's. + conflicting operand is realized by graph-copy so the consumer reads it affine + and the remainder is axis-split's. - cross-axis / multi-variable floor/mod argument (case 7, e.g. (3*p0+p1)//4 from a transpose+reshape feeding a broadcast/softmax/layernorm that keeps the dims separate): graph-copy materializes the multi-var operand with copy_input (which forces a copy of a view, unlike realize()); the copy kernel iterates the operand's own shape so its index collapses to single-var for axis-split. -The features are env-gated; this test turns them on for itself. axis-split is read -per kernel from the env; graph-copy installs its lowering hook at import, so we -re-run install() after setting the flag. +Both features are always on; graph-copy installs its lowering hook at import. Not in the CI allowlist (pytorchsim_test.yml) -- local feature/regression test. """ @@ -30,8 +28,6 @@ sys.path.insert(0, os.path.join(os.environ.get("TORCHSIM_DIR", default="/workspace/PyTorchSim"), "tests")) from _pytorchsim_utils import test_result -os.environ.setdefault("TORCHSIM_AXIS_SPLIT", "1") -os.environ.setdefault("TORCHSIM_GRAPH_COPY", "1") from PyTorchSimFrontend.mlir import graph_copy graph_copy.install() From 0161c64dadb21b5bf4d68180f58d85382c78a468 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 18 Jun 2026 20:53:43 +0900 Subject: [PATCH 20/20] [Frontend] Unify in-process MLIR passes under one phase-list driver dma_fine_grained and lower_to_vcix already exposed run(module, **opts) like the registered decompose_transfer/lower_vlane_idx, but were called file-based and directly from extension_codecache, so the .mlir was parsed+printed twice (once per pass) between loop-padding and the standard lowering, and the pipeline was hardcoded+duplicated across the functional and gem5 paths. Give both passes MARKERS and group the four rewrite passes into PRE_OPT_PASSES / POST_OPT_PASSES around the one remaining mlir-opt pass (-test-loop-padding). A single driver run_module_passes(in, out, passes, **opts) parses once, runs each marker-matched pass on the shared Module in order, prints once (copies through when no marker matches). run_python_passes is now PRE_OPT via that driver; the functional/gem5 fine-grained+vcix calls each become one run_module_passes. run_fine_grained / run_to_vcix stay re-exported for standalone/CLI use. Validated (Spike+TOGSim): elementwise, gemm, conv2d, softmax, floor/mod suite, SDPA -- all pass. Co-Authored-By: Claude Opus 4.8 --- PyTorchSimFrontend/extension_codecache.py | 23 ++++----- PyTorchSimFrontend/mlir/passes/__init__.py | 48 +++++++++++-------- .../mlir/passes/dma_fine_grained.py | 2 + .../mlir/passes/lower_to_vcix.py | 2 + 4 files changed, 45 insertions(+), 30 deletions(-) diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index a6a213ce..492133a3 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -132,7 +132,10 @@ def load(cls, source_code, # Run the Python out-of-line MLIR passes (MLIR bindings) on the kernel # .mlir in place, before mlir-opt. Currently lowers torchsim.vlane_idx # (replaces the old C++ -global-idx pass); add more in passes/__init__.py. - from PyTorchSimFrontend.mlir.passes import run_python_passes, run_standard_lowering, run_tog, run_fine_grained, run_to_vcix + from PyTorchSimFrontend.mlir.passes import ( + run_python_passes, run_module_passes, POST_OPT_PASSES, + run_standard_lowering, run_tog, + ) run_python_passes(input_path, vectorlane=vectorlane_size) new_input_path = os.path.splitext(input_path)[0] raw_tog_path = new_input_path + "_tog.py" @@ -160,12 +163,11 @@ def load(cls, source_code, llc_asm_cmd = shlex.split(cmds[3]) with lock: try: - # loop-padding (mlir-opt) -> Python fine-grained -> Python vcix + # loop-padding (mlir-opt) -> Python fine-grained + vcix (one parse/print) subprocess.check_call(opt_pad_cmd) - run_fine_grained(new_input_path + "_padded.mlir", - new_input_path + "_padded.mlir", vectorlane_size) - run_to_vcix(new_input_path + "_padded.mlir", - new_input_path + "_custom.mlir", vectorlane_size, vlen) + run_module_passes(new_input_path + "_padded.mlir", + new_input_path + "_custom.mlir", + POST_OPT_PASSES, vectorlane=vectorlane_size, vlen=vlen) # Standard MLIR -> LLVM-dialect lowering (registered upstream # passes) runs in-process via the bindings PassManager, picking # up after the custom mlir-opt passes (memref-to-gemmini). @@ -210,12 +212,11 @@ def load(cls, source_code, # to Python: run_tog reads that IR, writes the TOG (_tog.py) and the # mutated IR (_custom.mlir: sample-mode step rewrite + compute markers), # replacing the C++ -test-tile-operation-graph pass. - # loop-padding(timing, mlir-opt) -> Python fine-grained -> Python vcix + # loop-padding(timing, mlir-opt) -> Python fine-grained + vcix (one parse/print) subprocess.check_call(gem5_pad_cmd) - run_fine_grained(sample_mlir_path + "_padded.mlir", - sample_mlir_path + "_padded.mlir", vectorlane_size) - run_to_vcix(sample_mlir_path + "_padded.mlir", - sample_mlir_path + "_postvcix.mlir", vectorlane_size, vlen) + run_module_passes(sample_mlir_path + "_padded.mlir", + sample_mlir_path + "_postvcix.mlir", + POST_OPT_PASSES, vectorlane=vectorlane_size, vlen=vlen) run_tog(sample_mlir_path + "_postvcix.mlir", raw_tog_path, sample_mlir_path + "_custom.mlir", sample_mode=extension_config.CONFIG_TLS_MODE, diff --git a/PyTorchSimFrontend/mlir/passes/__init__.py b/PyTorchSimFrontend/mlir/passes/__init__.py index 543d0b40..82cadc2f 100644 --- a/PyTorchSimFrontend/mlir/passes/__init__.py +++ b/PyTorchSimFrontend/mlir/passes/__init__.py @@ -32,34 +32,39 @@ def _ensure_mlir_bindings_on_path(): from . import lower_vlane_idx from . import decompose_transfer +from . import dma_fine_grained +from . import lower_to_vcix from .lower_to_llvm import run_standard_lowering # noqa: F401 (re-exported) from .build_tog import run_tog # noqa: F401 (re-exported; replaces C++ test-tile-operation-graph) -from .dma_fine_grained import run_fine_grained # noqa: F401 (replaces C++ -dma-fine-grained) -from .lower_to_vcix import run_to_vcix # noqa: F401 (replaces C++ -test-pytorchsim-to-vcix) +from .dma_fine_grained import run_fine_grained # noqa: F401 (re-exported; standalone/CLI) +from .lower_to_vcix import run_to_vcix # noqa: F401 (re-exported; standalone/CLI) -# Ordered passes applied to each kernel .mlir before mlir-opt. -# decompose_transfer first: it lowers togsim.transfer -> memref.dma_start, which -# downstream passes (and the gemmini lowering) expect. -PASSES = [ +# Module rewrite passes around the one remaining mlir-opt pass (-test-loop-padding). +# Each exposes MARKERS + run(module, **opts); run_module_passes parses once per phase. +# decompose_transfer first: togsim.transfer -> memref.dma_start (downstream expects it). +PRE_OPT_PASSES = [ decompose_transfer, lower_vlane_idx, ] +# fine-grained first: splits the matmul DMAs that the vcix lowering then reads. +POST_OPT_PASSES = [ + dma_fine_grained, + lower_to_vcix, +] -def run_python_passes(mlir_path, vectorlane=128): - """Apply all registered Python MLIR passes to the .mlir at `mlir_path`, in place. - - `vectorlane` (systolic-array size / number of vector lanes) is forwarded to passes - that need it (e.g. decompose_transfer's lane-banked >4D peel). - - Returns True if the file was modified, False otherwise. - """ - with open(mlir_path) as f: +def run_module_passes(in_path, out_path, passes, **opts): + """Parse `in_path` once, run each marker-matched pass on the shared Module in + order, print once to `out_path` (in place if equal). `opts` forwarded to each + run(module, **opts). Returns True if any pass ran.""" + with open(in_path) as f: text = f.read() - # Fast path: nothing to do if no pass's target op appears in the text. - active = [p for p in PASSES if any(mk in text for mk in p.MARKERS)] + active = [p for p in passes if any(mk in text for mk in p.MARKERS)] if not active: + if out_path != in_path: + import shutil + shutil.copyfile(in_path, out_path) return False from mlir.ir import Context, Module, Location @@ -68,9 +73,14 @@ def run_python_passes(mlir_path, vectorlane=128): with ctx, Location.unknown(): module = Module.parse(text) for p in active: - p.run(module, vectorlane=vectorlane) + p.run(module, **opts) out = str(module) - with open(mlir_path, "w") as f: + with open(out_path, "w") as f: f.write(out) return True + + +def run_python_passes(mlir_path, vectorlane=128): + """Run the pre-mlir-opt Module passes (PRE_OPT_PASSES) on `mlir_path`, in place.""" + return run_module_passes(mlir_path, mlir_path, PRE_OPT_PASSES, vectorlane=vectorlane) diff --git a/PyTorchSimFrontend/mlir/passes/dma_fine_grained.py b/PyTorchSimFrontend/mlir/passes/dma_fine_grained.py index ff49aea8..3f583ef2 100644 --- a/PyTorchSimFrontend/mlir/passes/dma_fine_grained.py +++ b/PyTorchSimFrontend/mlir/passes/dma_fine_grained.py @@ -30,6 +30,8 @@ import mlir.ir as ir # noqa: E402 +MARKERS = ("subtile_size",) # only subtile DMAs are split + MVIN, MVIN2, MVIN3, MVOUT = 2, 1, 14, 3 # Per-rank subtile loop order and the fused-loop layout (mirror the C++ loopGroups). diff --git a/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py b/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py index 404a8ab4..ac93ebc8 100644 --- a/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py +++ b/PyTorchSimFrontend/mlir/passes/lower_to_vcix.py @@ -29,6 +29,8 @@ import mlir.ir as ir # noqa: E402 +MARKERS = ("linalg.matmul", "math.exp", "math.erf", "math.tanh", "math.sin", "math.cos") + # math op name -> (opcode, imm) for the vcix.v.iv lowering (mirror Math*ToVCIX). _MATH_VIV = { "math.exp": (0b000011, 0),