From 3bd23c660c13339e4cb05592c9b6c060ae7d9551 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 23:22:39 +0900 Subject: [PATCH 1/2] [Frontend] Retire dead floor/mod recompile branches in codegen axis-split + graph-copy (on by default) linearize aligned floor/mod at the scheduling layer, so the index reaching get_dma_info is affine and the FloorDiv/ModularIndexing tile-divisibility branches there are never entered (measured: 0 entries across elementwise, gemm, bmm, conv, cat, floor/mod, reduce, attention). Remove those dead branches and their orphans: - the FloorDiv and ModularIndexing tile-forcing + RecompileSignal blocks - the implicit-ModularIndexing index rewrite and implicit_local_dims - the dead ModularIndexing branch in the dram_stride computation - is_modular_indexing, the write-only implicit_dim_size, unused import sys Kept: the non-floor/mod recompile paths (index-divisibility, indirect access, non-power-of-2 vec size), RecompileSignal, and the retry loop. The upstream implicit_dim_ops tile-forcing is left untouched (separate change). Validated end-to-end (Spike + TOGSim): elementwise, gemm, bmm, conv2d, group_conv, pool, cat, floor/mod suite, reduce, softmax, layernorm, batchnorm, gqa -- all pass, 0 recompiles. Co-Authored-By: Claude Opus 4.8 --- .../mlir/mlir_codegen_backend.py | 113 +----------------- PyTorchSimFrontend/mlir/mlir_common.py | 5 - 2 files changed, 1 insertion(+), 117 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index a001c861..6529d8d9 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1,6 +1,5 @@ import contextlib import sympy -import sys import time import re import os @@ -1166,7 +1165,6 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe # Note: index could contain symbols that represent dynamic axies # Extract dimension of index(e.g, index0, index1) local_dims = [int(str(i)[5:]) for i in index.free_symbols if "index" in str(i)] - implicit_local_dims = list(index.args) total_dims = [int(str(i)[5:]) for i in self.itervars] local_tile_desc = mlir_common.MLIRMultiDimTile([1], self.vector_lane) local_dims.sort() # Assume that smaller index is placed in the outer loop @@ -1243,14 +1241,6 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc.vmap.vlane_split_axis = local_vlane_split_axis local_tile_desc.vmap.vlane_stride = kg_tile_desc.vmap.vlane_stride - if len(implicit_local_dims)!=0 and len(local_dims) != len(implicit_local_dims) and self.is_modular_indexing(index): - for axis_constraints in self.kernel_group.tile_desc.implicit_dim_size.values(): - if len(axis_constraints) <= 1: - continue - sorted_constraints = sorted(axis_constraints, key=lambda c: int(c.args[1])) - for constraint in sorted_constraints[1:]: - index = index.replace(constraint.original_expr, 0) - # Calculate dram stride in local tile-dim order. # This keeps dram/sram stride rank aligned with tile rank. local_dim_to_axis = {dim: axis for axis, dim in enumerate(local_dims)} @@ -1264,19 +1254,12 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe else: dram_dict = defaultdict(list) - implicit_dim_divisors = defaultdict(lambda: sys.maxsize) - # Assume that div will have high priority than mod for arg in index.as_ordered_terms(): coeff, dim = arg.as_coeff_mul() if len(dim) == 0: continue real_dim = list(dim[0].free_symbols)[0] - if dim[0].has(ModularIndexing): - if dim[0].args[1] < implicit_dim_divisors[str(real_dim)]: - implicit_dim_divisors[str(real_dim)] = dim[0].args[1] - dram_dict[str(real_dim)] = [coeff] - else: - dram_dict[str(real_dim)].append(coeff) + dram_dict[str(real_dim)].append(coeff) # Add missing dims if not added max_dim = len(self.ranges) if not store_reduction else len(self.ranges) - 1 @@ -1287,100 +1270,6 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe sorted_keys = sorted(dram_dict.keys()) dram_stride = sum((dram_dict[key] for key in sorted_keys), []) - # Support floordiv pattern - # FIXME. How to integrate implicit dims and floordiv? - # This was introduced to support GroupNorm - if index.has(FloorDiv) and not index.has(ModularIndexing): - dim_divisor = [1] * len(local_dims) - for sub in sympy.preorder_traversal(index): - if isinstance(sub, FloorDiv): - if not str(sub.args[0]).startswith("index"): - continue - dim_idx = int((str(sub.args[0])[5:])) - if dim_idx not in local_dim_to_axis: - continue - local_dim_idx = local_dim_to_axis[dim_idx] - if int(self.kernel_group.tile_desc.get_tile_size()[dim_idx] % sub.args[1]) != 0: - # In this case, need to recompile - original_tile = self.kernel_group.tile_desc.get_tile_size() - original_size = original_tile[dim_idx] - divisor = sub.args[1] * self.kernel_group.tile_desc.vmap.vlane_stride - new_size = ((original_size + divisor - 1) // divisor) * divisor - new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) - new_tile_sizes[dim_idx] = new_size - self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) - self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True - - # Can't use dim_idx as vlane_split_axis - if dim_idx == self.kernel_group.tile_desc.vmap.vlane_split_axis: - self.kernel_group.tile_desc.vmap.vlane_split_axis = (dim_idx + 1) % len(original_tile) - - # Send recompile signal - self.reset("recompile") - raise mlir_common.RecompileSignal(f"Tile size {self.kernel_group.tile_desc.get_tile_size()[dim_idx]} is not divisible by {sub.args[1]}") - dim_divisor[local_dim_idx] = sub.args[1] - - # Update dram_stride, just insert 0 next to target dim - offset = 0 - for dim_idx, divisor in enumerate(dim_divisor): - if divisor == 1: - continue - dram_stride.insert(dim_idx+offset+1, 0) - local_tile_desc.apply_divisor(dim_idx+offset, divisor, "pad") - local_tile_desc.apply_divisor(dim_idx+offset, divisor, "split") - offset = offset+1 - - # Support ModularIndexing pattern - # This pattern can be used to broadcast ex) torch.cat([a,a]) - # ModularIndexing(x, y, z) means (x // y) % z - # tile_size must be: multiple of y (floorDiv divisor) and divisor of z (modular divisor) - if index.has(ModularIndexing): - for sub in sympy.preorder_traversal(index): - if isinstance(sub, ModularIndexing): - if not str(sub.args[0]).startswith("index"): - continue - dim_idx = int((str(list(sub.args[0].free_symbols)[0])[5:])) - floor_divisor = sub.args[1] # y: floorDiv divisor - mod_divisor = sub.args[2] # z: modular divisor - current_tile_size = self.kernel_group.tile_desc.get_tile_size()[dim_idx] - - # Check if tile_size is multiple of floorDiv divisor - if int(current_tile_size % floor_divisor) != 0: - original_tile = self.kernel_group.tile_desc.get_tile_size() - original_size = original_tile[dim_idx] - divisor = floor_divisor * self.kernel_group.tile_desc.vmap.vlane_stride - new_size = ((original_size + divisor - 1) // divisor) * divisor - new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) - new_tile_sizes[dim_idx] = new_size - self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) - self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True - - self.reset("recompile") - raise mlir_common.RecompileSignal(f"Tile size {current_tile_size} is not a multiple of floorDiv divisor {floor_divisor} in ModularIndexing") - - # Check if tile_size is a divisor of modular divisor - if int((mod_divisor * floor_divisor) % current_tile_size) != 0: - original_tile = self.kernel_group.tile_desc.get_tile_size() - original_size = original_tile[dim_idx] - # Find the largest divisor of mod_divisor that is <= original_size - # and is a multiple of floor_divisor - new_size = original_size - while new_size > 0: - if mod_divisor % new_size == 0 and new_size % floor_divisor == 0: - break - new_size -= floor_divisor - - if new_size <= 0: - new_size = mod_divisor * floor_divisor - - new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) - new_tile_sizes[dim_idx] = new_size - self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) - self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True - - self.reset("recompile") - raise mlir_common.RecompileSignal(f"Tile size {current_tile_size} is not a divisor of modular divisor {mod_divisor} in ModularIndexing") - # FIXME. It will be nice to modify node instead of this exception handling... if len(self.itervars) == 1 and self.reduction_depth == 0: # In case of reduction loop only case, we will add dummy loop so shift it once diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index a7921463..38a77293 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -520,7 +520,6 @@ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=N vlane_stride=vlane_stride ) - self.implicit_dim_size = {} self.nr_rdim = 0 self.offset = sympy.Integer(0) # Dram offset @@ -686,9 +685,6 @@ def call_kernel(self, kernel_name): # generate the code to call this wrapper.generate_kernel_call(kernel_name, call_args, triton=False) - def is_modular_indexing(self, expr): - return "ModularIndexing" in str(expr) - def implicit_dim_ops(self, nodes): target_patterns = (ModularIndexing, FloorDiv, Mod) target_operands = [] @@ -764,7 +760,6 @@ def compute_tile_size(self, nodes, vars, reduction_vars): if implicit_ops: tile_constraints = self.extract_dividers(implicit_ops) self.kernel_group.tile_desc.apply_constraints(tile_constraints, self.ranges) - self.kernel_group.tile_desc.implicit_dim_size = tile_constraints # Check recodegen reason if self.recodegen is not None: From 2287c29496e17bb0f1a93130cd06189bd17a35d1 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 23:40:11 +0900 Subject: [PATCH 2/2] [Frontend] Retire implicit_dim_ops tile-forcing (redundant under axis-split) implicit_dim_ops/extract_dividers/apply_constraints forced the initial tile size to match a view's floor/mod divider, up front in compute_tile_size. axis-split now linearizes those views at the scheduling layer, so the forcing is redundant: disabling it leaves every test allclose-correct and, on the affected kernels, slightly faster (the forced tile was over-constrained -- batchnorm 1189->1114, layernorm 4092->3947 cycles; non-floor/mod kernels unchanged). Remove the machinery and its now-unused imports (ModularIndexing, FloorDiv, Mod, MemoryDep, StarDep, WeakDep). Validated end-to-end (Spike + TOGSim): elementwise, gemm, bmm, conv2d, group_conv, pool, cat, floor/mod suite, reduce, softmax, layernorm, batchnorm, gqa -- all pass, 0 recompiles. Co-Authored-By: Claude Opus 4.8 --- PyTorchSimFrontend/mlir/mlir_common.py | 74 +------------------------- 1 file changed, 1 insertion(+), 73 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 38a77293..748c389c 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -15,9 +15,8 @@ from torch._inductor.codegen import cpp from torch._inductor.virtualized import V from torch._inductor.ir import MultiOutputLayout -from torch._inductor.dependencies import MemoryDep, StarDep, WeakDep from torch._inductor.codegen.wrapper import KernelDefinitionLine -from torch.utils._sympy.functions import ModularIndexing, FloorDiv, Mod, Identity +from torch.utils._sympy.functions import Identity import sympy import contextlib @@ -436,20 +435,6 @@ def pad_vlane_tile(self): padded_size = used_vlane * vlane_stride self._tile_size[vlane_split_axis] = math.ceil(self._tile_size[vlane_split_axis] / padded_size) * padded_size - def apply_constraints(self, constraints, ranges): - for idx, (axis_constraints, axis_size) in enumerate(zip(constraints.values(), ranges)): - for const in axis_constraints: - if const.args[1] == 1: - continue - divider = int(const.args[1]) - - if not self.tile_constraint[idx].fixed: - self.tile_constraint[idx].fixed = True - self._tile_size[idx] = divider - elif self.tile_constraint[idx].fixed and self._tile_size[idx] > divider: - self._tile_size[idx] = divider - self.update_tile_stride() - @staticmethod def init_tile_size(ranges, vlane_stride, vector_lane): # Logical tile init for ANY rank. Only the innermost dims carry the @@ -685,56 +670,6 @@ def call_kernel(self, kernel_name): # generate the code to call this wrapper.generate_kernel_call(kernel_name, call_args, triton=False) - def implicit_dim_ops(self, nodes): - target_patterns = (ModularIndexing, FloorDiv, Mod) - target_operands = [] - for target_node in nodes: - for read_operand in target_node.read_writes.reads: - read_operand: MemoryDep - if isinstance(read_operand, StarDep) or isinstance(read_operand, WeakDep): - continue - read_index = read_operand.index - for arg_expr in read_index.args: - if arg_expr.atoms(*target_patterns): - target_operands.append(read_operand) - return target_operands - - def extract_dividers(self, implicit_ops): - # When a specific axis is processed, the key constraint to verify is the divider. - # The tile size must be forced to match the divider size. - dim_dividers = defaultdict(set) - for operand in implicit_ops: - subs_map = { - s: sympy.symbols(s.name.replace("c", "index", 1)) - for s in operand.index.free_symbols - } - rev_subs_map = { - sympy.symbols(s.name.replace("c", "index", 1)) : s - for s in operand.index.free_symbols - } - new_index = operand.index.subs(subs_map) - for arg in new_index.args: - if arg.is_number: - continue - if len(arg.free_symbols) > 1: - raise NotImplementedError("Not supporting this view operation...!") - if arg.is_Mul and arg.args[0].is_number: - arg = arg.args[1] - - if isinstance(arg, ModularIndexing): - modular_expr = ModularIndexing(arg.args[0], arg.args[1], arg.args[2]) - modular_expr.original_expr = arg - elif arg.is_symbol: - modular_expr = ModularIndexing(arg, 1, operand.ranges[rev_subs_map[arg]]) - modular_expr.original_expr = arg - elif "//" in str(arg): - modular_expr = ModularIndexing(arg.args[0], arg.args[1], operand.ranges[rev_subs_map[arg.args[0]]]//arg.args[1]) - modular_expr.original_expr = arg - else: - raise NotImplementedError("What is this case?") - dim_dividers[modular_expr.args[0]].add(modular_expr) - return dim_dividers - def compute_tile_size(self, nodes, vars, reduction_vars): vlane_split_axis = len(vars) - 1 vlane_stride = 2 # Set minimum vlane stride @@ -754,13 +689,6 @@ def compute_tile_size(self, nodes, vars, reduction_vars): self.kernel_group.tile_desc.vmap.vlane_split_axis = 0 self.kernel_group.tile_desc.vmap.vlane_stride = self.kernel_group.tile_desc.get_tile_size()[0] - # Handle implict dims. Input operand could be high dimension tensor. - # Note: https://github.com/PSAL-POSTECH/PyTorchSim/issues/173 - implicit_ops = self.implicit_dim_ops(nodes) - if implicit_ops: - tile_constraints = self.extract_dividers(implicit_ops) - self.kernel_group.tile_desc.apply_constraints(tile_constraints, self.ranges) - # Check recodegen reason if self.recodegen is not None: if self.recodegen == "spad_overflow":