Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 30 additions & 67 deletions PyTorchSimFrontend/mlir/mlir_codegen_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import contextlib
import sympy
import time
import re
import os
from functools import reduce
from operator import mul
Expand Down Expand Up @@ -338,34 +337,6 @@ def get_padding_type(self):
# return 1
return 0

def convert_index(self, expr):
if len(expr.free_symbols) != 1:
raise NotImplementedError("Not supporting this view operation...!")

if expr.is_symbol:
return expr

expr_str = str(expr)
if isinstance(expr, ModularIndexing):
dim = list(expr.args[0].free_symbols)[0]
replace_str = f"({expr.args[0]} floordiv {expr.args[1]}) mod {expr.args[2]}"
expr_str = re.sub(r"ModularIndexing\([^)]*\)", replace_str, expr_str)
elif "//" in expr_str:
expr_str = expr_str.replace("//", " floordiv ")
else:
raise NotImplementedError("What is this case?")

first_arg = expr.args[0]
if len(first_arg.free_symbols) != 1:
raise NotImplementedError("What is this case?")

# Create affine.apply operation
indices = [list(first_arg.free_symbols)[0]]
with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse):
map_var = ops.affine_map(indices, expr_str)
index = ops.affine_apply(map_var, indices)
return index

def _convert_sympy_to_mlir_expr(self, expr, sorted_args):
"""
Convert sympy expression to MLIR affine map expression by replacing index variables.
Expand All @@ -379,41 +350,24 @@ def _convert_sympy_to_mlir_expr(self, expr, sorted_args):
target_arg = arg
else:
continue
new_arg = sympy.Symbol(str(self.convert_index(target_arg)))
expr = expr.replace(target_arg, new_arg)
indices.append(str(new_arg))

# Convert ModularIndexing and FloorDiv to sympy expressions
# ModularIndexing(x, y, z) means (x // y) % z -> Mod(FloorDiv(x, y), z)
# FloorDiv(x, y) means x // y -> will be converted to floordiv in string representation
# Use preorder_traversal to find all instances
replacements = {}
for sub in sympy.preorder_traversal(expr):
if isinstance(sub, ModularIndexing):
# Convert ModularIndexing to Mod(FloorDiv(...), ...)
if sub.args[1] != 1:
floor_div = FloorDiv(sub.args[0], sub.args[1])
else:
floor_div = sub.args[0]
mod_expr = sympy.Mod(floor_div, sub.args[2])
replacements[sub] = mod_expr
elif isinstance(sub, FloorDiv):
# Keep FloorDiv as is, will be handled in custom string conversion
# We need to mark it for special handling
pass

# Apply replacements
for old_expr, new_expr in replacements.items():
expr = expr.subs(old_expr, new_expr)
indices.append(str(target_arg))

# axis-split + graph-copy linearize aligned floor/mod upstream (the
# "affine-only contract", see docs/axis-split-scheduling.md), so the
# index reaching codegen must already be pure affine. A residual
# ModularIndexing/FloorDiv means the view was not linearized; fail
# loudly instead of silently mis-lowering it.
if expr.has(ModularIndexing, FloorDiv):
raise NotImplementedError(
f"Unlinearized floor/mod in affine index: {expr}. axis-split/graph-copy "
f"did not eliminate it; this view is unsupported "
f"(see docs/axis-split-scheduling.md)."
)

# Custom string conversion for MLIR affine expressions
def mlir_str(expr):
"""Convert sympy expression to MLIR affine expression string"""
if isinstance(expr, FloorDiv):
return f"({mlir_str(expr.args[0])} floordiv {mlir_str(expr.args[1])})"
elif isinstance(expr, sympy.Mod):
return f"({mlir_str(expr.args[0])} mod {mlir_str(expr.args[1])})"
elif isinstance(expr, sympy.Add):
if isinstance(expr, sympy.Add):
terms = [mlir_str(term) for term in expr.args]
return " + ".join(terms)
elif isinstance(expr, sympy.Mul):
Expand Down Expand Up @@ -469,20 +423,29 @@ def parse_index_list(self, expr_list:list, offset=sympy.Number(0)) -> common.CSE
# Identity case
return expr_list[0]

# axis-split + graph-copy linearize aligned floor/mod upstream (the
# "affine-only contract", see docs/axis-split-scheduling.md). A residual
# ModularIndexing/FloorDiv here would be stringified into a bare dim
# symbol below and silently mis-lowered, so fail loudly instead.
if any(a.has(ModularIndexing, FloorDiv) for a in expr_list):
raise NotImplementedError(
f"Unlinearized floor/mod in affine index: {expr_list}. axis-split/graph-copy "
f"did not eliminate it; this view is unsupported "
f"(see docs/axis-split-scheduling.md)."
)

indices = []
new_expr_list = [0] * len(expr_list)
for idx, arg in enumerate(expr_list):
if arg.is_Mul and arg.args[0].is_number:
new_arg = sympy.Symbol(str(self.convert_index(arg.args[1])))
itervar = arg.args[1]
# Round-trip through a plain Symbol to drop sympy assumptions.
new_arg = sympy.Symbol(str(itervar))
new_expr_list[idx] = arg.subs(arg.args[1], dim_list[idx])
indices.append(str(new_arg))
elif not arg.is_number:
try:
new_arg = sympy.Symbol(str(self.convert_index(arg)))
#not implemented case
except NotImplementedError:
print(f"Not implemented case: {arg}")
raise NotImplementedError(f"Not implemented case: {arg}")
# Round-trip through a plain Symbol to drop sympy assumptions.
new_arg = sympy.Symbol(str(arg))
new_expr_list[idx] = new_arg.subs(new_arg, dim_list[idx])
indices.append(str(new_arg))
else:
Expand Down