diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 8f695395..dfc502b3 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1,7 +1,6 @@ import contextlib import sympy import time -import re import os from functools import reduce from operator import mul @@ -338,34 +337,6 @@ def get_padding_type(self): # return 1 return 0 - def convert_index(self, expr): - if len(expr.free_symbols) != 1: - raise NotImplementedError("Not supporting this view operation...!") - - if expr.is_symbol: - return expr - - expr_str = str(expr) - if isinstance(expr, ModularIndexing): - dim = list(expr.args[0].free_symbols)[0] - replace_str = f"({expr.args[0]} floordiv {expr.args[1]}) mod {expr.args[2]}" - expr_str = re.sub(r"ModularIndexing\([^)]*\)", replace_str, expr_str) - elif "//" in expr_str: - expr_str = expr_str.replace("//", " floordiv ") - else: - raise NotImplementedError("What is this case?") - - first_arg = expr.args[0] - if len(first_arg.free_symbols) != 1: - raise NotImplementedError("What is this case?") - - # Create affine.apply operation - indices = [list(first_arg.free_symbols)[0]] - with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse): - map_var = ops.affine_map(indices, expr_str) - index = ops.affine_apply(map_var, indices) - return index - def _convert_sympy_to_mlir_expr(self, expr, sorted_args): """ Convert sympy expression to MLIR affine map expression by replacing index variables. @@ -379,41 +350,24 @@ def _convert_sympy_to_mlir_expr(self, expr, sorted_args): target_arg = arg else: continue - new_arg = sympy.Symbol(str(self.convert_index(target_arg))) - expr = expr.replace(target_arg, new_arg) - indices.append(str(new_arg)) - - # Convert ModularIndexing and FloorDiv to sympy expressions - # ModularIndexing(x, y, z) means (x // y) % z -> Mod(FloorDiv(x, y), z) - # FloorDiv(x, y) means x // y -> will be converted to floordiv in string representation - # Use preorder_traversal to find all instances - replacements = {} - for sub in sympy.preorder_traversal(expr): - if isinstance(sub, ModularIndexing): - # Convert ModularIndexing to Mod(FloorDiv(...), ...) - if sub.args[1] != 1: - floor_div = FloorDiv(sub.args[0], sub.args[1]) - else: - floor_div = sub.args[0] - mod_expr = sympy.Mod(floor_div, sub.args[2]) - replacements[sub] = mod_expr - elif isinstance(sub, FloorDiv): - # Keep FloorDiv as is, will be handled in custom string conversion - # We need to mark it for special handling - pass - - # Apply replacements - for old_expr, new_expr in replacements.items(): - expr = expr.subs(old_expr, new_expr) + indices.append(str(target_arg)) + + # axis-split + graph-copy linearize aligned floor/mod upstream (the + # "affine-only contract", see docs/axis-split-scheduling.md), so the + # index reaching codegen must already be pure affine. A residual + # ModularIndexing/FloorDiv means the view was not linearized; fail + # loudly instead of silently mis-lowering it. + if expr.has(ModularIndexing, FloorDiv): + raise NotImplementedError( + f"Unlinearized floor/mod in affine index: {expr}. axis-split/graph-copy " + f"did not eliminate it; this view is unsupported " + f"(see docs/axis-split-scheduling.md)." + ) # Custom string conversion for MLIR affine expressions def mlir_str(expr): """Convert sympy expression to MLIR affine expression string""" - if isinstance(expr, FloorDiv): - return f"({mlir_str(expr.args[0])} floordiv {mlir_str(expr.args[1])})" - elif isinstance(expr, sympy.Mod): - return f"({mlir_str(expr.args[0])} mod {mlir_str(expr.args[1])})" - elif isinstance(expr, sympy.Add): + if isinstance(expr, sympy.Add): terms = [mlir_str(term) for term in expr.args] return " + ".join(terms) elif isinstance(expr, sympy.Mul): @@ -469,20 +423,29 @@ def parse_index_list(self, expr_list:list, offset=sympy.Number(0)) -> common.CSE # Identity case return expr_list[0] + # axis-split + graph-copy linearize aligned floor/mod upstream (the + # "affine-only contract", see docs/axis-split-scheduling.md). A residual + # ModularIndexing/FloorDiv here would be stringified into a bare dim + # symbol below and silently mis-lowered, so fail loudly instead. + if any(a.has(ModularIndexing, FloorDiv) for a in expr_list): + raise NotImplementedError( + f"Unlinearized floor/mod in affine index: {expr_list}. axis-split/graph-copy " + f"did not eliminate it; this view is unsupported " + f"(see docs/axis-split-scheduling.md)." + ) + indices = [] new_expr_list = [0] * len(expr_list) for idx, arg in enumerate(expr_list): if arg.is_Mul and arg.args[0].is_number: - new_arg = sympy.Symbol(str(self.convert_index(arg.args[1]))) + itervar = arg.args[1] + # Round-trip through a plain Symbol to drop sympy assumptions. + new_arg = sympy.Symbol(str(itervar)) new_expr_list[idx] = arg.subs(arg.args[1], dim_list[idx]) indices.append(str(new_arg)) elif not arg.is_number: - try: - new_arg = sympy.Symbol(str(self.convert_index(arg))) - #not implemented case - except NotImplementedError: - print(f"Not implemented case: {arg}") - raise NotImplementedError(f"Not implemented case: {arg}") + # Round-trip through a plain Symbol to drop sympy assumptions. + new_arg = sympy.Symbol(str(arg)) new_expr_list[idx] = new_arg.subs(new_arg, dim_list[idx]) indices.append(str(new_arg)) else: