From 8239e4f5e778101fd42e2c6d41b3444a9a2b582c Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 17 Jun 2026 23:40:11 +0900 Subject: [PATCH] [Frontend] Retire implicit_dim_ops tile-forcing (redundant under axis-split) implicit_dim_ops/extract_dividers/apply_constraints forced the initial tile size to match a view's floor/mod divider, up front in compute_tile_size. axis-split now linearizes those views at the scheduling layer, so the forcing is redundant: disabling it leaves every test allclose-correct and, on the affected kernels, slightly faster (the forced tile was over-constrained -- batchnorm 1189->1114, layernorm 4092->3947 cycles; non-floor/mod kernels unchanged). Remove the machinery and its now-unused imports (ModularIndexing, FloorDiv, Mod, MemoryDep, StarDep, WeakDep). Validated end-to-end (Spike + TOGSim): elementwise, gemm, bmm, conv2d, group_conv, pool, cat, floor/mod suite, reduce, softmax, layernorm, batchnorm, gqa -- all pass, 0 recompiles. Co-Authored-By: Claude Opus 4.8 --- PyTorchSimFrontend/mlir/mlir_common.py | 74 +------------------------- 1 file changed, 1 insertion(+), 73 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 38a77293..748c389c 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -15,9 +15,8 @@ from torch._inductor.codegen import cpp from torch._inductor.virtualized import V from torch._inductor.ir import MultiOutputLayout -from torch._inductor.dependencies import MemoryDep, StarDep, WeakDep from torch._inductor.codegen.wrapper import KernelDefinitionLine -from torch.utils._sympy.functions import ModularIndexing, FloorDiv, Mod, Identity +from torch.utils._sympy.functions import Identity import sympy import contextlib @@ -436,20 +435,6 @@ def pad_vlane_tile(self): padded_size = used_vlane * vlane_stride self._tile_size[vlane_split_axis] = math.ceil(self._tile_size[vlane_split_axis] / padded_size) * padded_size - def apply_constraints(self, constraints, ranges): - for idx, (axis_constraints, axis_size) in enumerate(zip(constraints.values(), ranges)): - for const in axis_constraints: - if const.args[1] == 1: - continue - divider = int(const.args[1]) - - if not self.tile_constraint[idx].fixed: - self.tile_constraint[idx].fixed = True - self._tile_size[idx] = divider - elif self.tile_constraint[idx].fixed and self._tile_size[idx] > divider: - self._tile_size[idx] = divider - self.update_tile_stride() - @staticmethod def init_tile_size(ranges, vlane_stride, vector_lane): # Logical tile init for ANY rank. Only the innermost dims carry the @@ -685,56 +670,6 @@ def call_kernel(self, kernel_name): # generate the code to call this wrapper.generate_kernel_call(kernel_name, call_args, triton=False) - def implicit_dim_ops(self, nodes): - target_patterns = (ModularIndexing, FloorDiv, Mod) - target_operands = [] - for target_node in nodes: - for read_operand in target_node.read_writes.reads: - read_operand: MemoryDep - if isinstance(read_operand, StarDep) or isinstance(read_operand, WeakDep): - continue - read_index = read_operand.index - for arg_expr in read_index.args: - if arg_expr.atoms(*target_patterns): - target_operands.append(read_operand) - return target_operands - - def extract_dividers(self, implicit_ops): - # When a specific axis is processed, the key constraint to verify is the divider. - # The tile size must be forced to match the divider size. - dim_dividers = defaultdict(set) - for operand in implicit_ops: - subs_map = { - s: sympy.symbols(s.name.replace("c", "index", 1)) - for s in operand.index.free_symbols - } - rev_subs_map = { - sympy.symbols(s.name.replace("c", "index", 1)) : s - for s in operand.index.free_symbols - } - new_index = operand.index.subs(subs_map) - for arg in new_index.args: - if arg.is_number: - continue - if len(arg.free_symbols) > 1: - raise NotImplementedError("Not supporting this view operation...!") - if arg.is_Mul and arg.args[0].is_number: - arg = arg.args[1] - - if isinstance(arg, ModularIndexing): - modular_expr = ModularIndexing(arg.args[0], arg.args[1], arg.args[2]) - modular_expr.original_expr = arg - elif arg.is_symbol: - modular_expr = ModularIndexing(arg, 1, operand.ranges[rev_subs_map[arg]]) - modular_expr.original_expr = arg - elif "//" in str(arg): - modular_expr = ModularIndexing(arg.args[0], arg.args[1], operand.ranges[rev_subs_map[arg.args[0]]]//arg.args[1]) - modular_expr.original_expr = arg - else: - raise NotImplementedError("What is this case?") - dim_dividers[modular_expr.args[0]].add(modular_expr) - return dim_dividers - def compute_tile_size(self, nodes, vars, reduction_vars): vlane_split_axis = len(vars) - 1 vlane_stride = 2 # Set minimum vlane stride @@ -754,13 +689,6 @@ def compute_tile_size(self, nodes, vars, reduction_vars): self.kernel_group.tile_desc.vmap.vlane_split_axis = 0 self.kernel_group.tile_desc.vmap.vlane_stride = self.kernel_group.tile_desc.get_tile_size()[0] - # Handle implict dims. Input operand could be high dimension tensor. - # Note: https://github.com/PSAL-POSTECH/PyTorchSim/issues/173 - implicit_ops = self.implicit_dim_ops(nodes) - if implicit_ops: - tile_constraints = self.extract_dividers(implicit_ops) - self.kernel_group.tile_desc.apply_constraints(tile_constraints, self.ranges) - # Check recodegen reason if self.recodegen is not None: if self.recodegen == "spad_overflow":