Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 1 addition & 73 deletions PyTorchSimFrontend/mlir/mlir_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
from torch._inductor.codegen import cpp
from torch._inductor.virtualized import V
from torch._inductor.ir import MultiOutputLayout
from torch._inductor.dependencies import MemoryDep, StarDep, WeakDep
from torch._inductor.codegen.wrapper import KernelDefinitionLine
from torch.utils._sympy.functions import ModularIndexing, FloorDiv, Mod, Identity
from torch.utils._sympy.functions import Identity
import sympy
import contextlib

Expand Down Expand Up @@ -436,20 +435,6 @@ def pad_vlane_tile(self):
padded_size = used_vlane * vlane_stride
self._tile_size[vlane_split_axis] = math.ceil(self._tile_size[vlane_split_axis] / padded_size) * padded_size

def apply_constraints(self, constraints, ranges):
for idx, (axis_constraints, axis_size) in enumerate(zip(constraints.values(), ranges)):
for const in axis_constraints:
if const.args[1] == 1:
continue
divider = int(const.args[1])

if not self.tile_constraint[idx].fixed:
self.tile_constraint[idx].fixed = True
self._tile_size[idx] = divider
elif self.tile_constraint[idx].fixed and self._tile_size[idx] > divider:
self._tile_size[idx] = divider
self.update_tile_stride()

@staticmethod
def init_tile_size(ranges, vlane_stride, vector_lane):
# Logical tile init for ANY rank. Only the innermost dims carry the
Expand Down Expand Up @@ -685,56 +670,6 @@ def call_kernel(self, kernel_name):
# generate the code to call this
wrapper.generate_kernel_call(kernel_name, call_args, triton=False)

def implicit_dim_ops(self, nodes):
target_patterns = (ModularIndexing, FloorDiv, Mod)
target_operands = []
for target_node in nodes:
for read_operand in target_node.read_writes.reads:
read_operand: MemoryDep
if isinstance(read_operand, StarDep) or isinstance(read_operand, WeakDep):
continue
read_index = read_operand.index
for arg_expr in read_index.args:
if arg_expr.atoms(*target_patterns):
target_operands.append(read_operand)
return target_operands

def extract_dividers(self, implicit_ops):
# When a specific axis is processed, the key constraint to verify is the divider.
# The tile size must be forced to match the divider size.
dim_dividers = defaultdict(set)
for operand in implicit_ops:
subs_map = {
s: sympy.symbols(s.name.replace("c", "index", 1))
for s in operand.index.free_symbols
}
rev_subs_map = {
sympy.symbols(s.name.replace("c", "index", 1)) : s
for s in operand.index.free_symbols
}
new_index = operand.index.subs(subs_map)
for arg in new_index.args:
if arg.is_number:
continue
if len(arg.free_symbols) > 1:
raise NotImplementedError("Not supporting this view operation...!")
if arg.is_Mul and arg.args[0].is_number:
arg = arg.args[1]

if isinstance(arg, ModularIndexing):
modular_expr = ModularIndexing(arg.args[0], arg.args[1], arg.args[2])
modular_expr.original_expr = arg
elif arg.is_symbol:
modular_expr = ModularIndexing(arg, 1, operand.ranges[rev_subs_map[arg]])
modular_expr.original_expr = arg
elif "//" in str(arg):
modular_expr = ModularIndexing(arg.args[0], arg.args[1], operand.ranges[rev_subs_map[arg.args[0]]]//arg.args[1])
modular_expr.original_expr = arg
else:
raise NotImplementedError("What is this case?")
dim_dividers[modular_expr.args[0]].add(modular_expr)
return dim_dividers

def compute_tile_size(self, nodes, vars, reduction_vars):
vlane_split_axis = len(vars) - 1
vlane_stride = 2 # Set minimum vlane stride
Expand All @@ -754,13 +689,6 @@ def compute_tile_size(self, nodes, vars, reduction_vars):
self.kernel_group.tile_desc.vmap.vlane_split_axis = 0
self.kernel_group.tile_desc.vmap.vlane_stride = self.kernel_group.tile_desc.get_tile_size()[0]

# Handle implict dims. Input operand could be high dimension tensor.
# Note: https://github.com/PSAL-POSTECH/PyTorchSim/issues/173
implicit_ops = self.implicit_dim_ops(nodes)
if implicit_ops:
tile_constraints = self.extract_dividers(implicit_ops)
self.kernel_group.tile_desc.apply_constraints(tile_constraints, self.ranges)

# Check recodegen reason
if self.recodegen is not None:
if self.recodegen == "spad_overflow":
Expand Down
Loading