Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 1 addition & 112 deletions PyTorchSimFrontend/mlir/mlir_codegen_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import contextlib
import sympy
import sys
import time
import re
import os
Expand Down Expand Up @@ -1166,7 +1165,6 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
# Note: index could contain symbols that represent dynamic axies
# Extract dimension of index(e.g, index0, index1)
local_dims = [int(str(i)[5:]) for i in index.free_symbols if "index" in str(i)]
implicit_local_dims = list(index.args)
total_dims = [int(str(i)[5:]) for i in self.itervars]
local_tile_desc = mlir_common.MLIRMultiDimTile([1], self.vector_lane)
local_dims.sort() # Assume that smaller index is placed in the outer loop
Expand Down Expand Up @@ -1243,14 +1241,6 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
local_tile_desc.vmap.vlane_split_axis = local_vlane_split_axis
local_tile_desc.vmap.vlane_stride = kg_tile_desc.vmap.vlane_stride

if len(implicit_local_dims)!=0 and len(local_dims) != len(implicit_local_dims) and self.is_modular_indexing(index):
for axis_constraints in self.kernel_group.tile_desc.implicit_dim_size.values():
if len(axis_constraints) <= 1:
continue
sorted_constraints = sorted(axis_constraints, key=lambda c: int(c.args[1]))
for constraint in sorted_constraints[1:]:
index = index.replace(constraint.original_expr, 0)

# Calculate dram stride in local tile-dim order.
# This keeps dram/sram stride rank aligned with tile rank.
local_dim_to_axis = {dim: axis for axis, dim in enumerate(local_dims)}
Expand All @@ -1264,19 +1254,12 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
else:

dram_dict = defaultdict(list)
implicit_dim_divisors = defaultdict(lambda: sys.maxsize)
# Assume that div will have high priority than mod
for arg in index.as_ordered_terms():
coeff, dim = arg.as_coeff_mul()
if len(dim) == 0:
continue
real_dim = list(dim[0].free_symbols)[0]
if dim[0].has(ModularIndexing):
if dim[0].args[1] < implicit_dim_divisors[str(real_dim)]:
implicit_dim_divisors[str(real_dim)] = dim[0].args[1]
dram_dict[str(real_dim)] = [coeff]
else:
dram_dict[str(real_dim)].append(coeff)
dram_dict[str(real_dim)].append(coeff)

# Add missing dims if not added
max_dim = len(self.ranges) if not store_reduction else len(self.ranges) - 1
Expand All @@ -1287,100 +1270,6 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
sorted_keys = sorted(dram_dict.keys())
dram_stride = sum((dram_dict[key] for key in sorted_keys), [])

# Support floordiv pattern
# FIXME. How to integrate implicit dims and floordiv?
# This was introduced to support GroupNorm
if index.has(FloorDiv) and not index.has(ModularIndexing):
dim_divisor = [1] * len(local_dims)
for sub in sympy.preorder_traversal(index):
if isinstance(sub, FloorDiv):
if not str(sub.args[0]).startswith("index"):
continue
dim_idx = int((str(sub.args[0])[5:]))
if dim_idx not in local_dim_to_axis:
continue
local_dim_idx = local_dim_to_axis[dim_idx]
if int(self.kernel_group.tile_desc.get_tile_size()[dim_idx] % sub.args[1]) != 0:
# In this case, need to recompile
original_tile = self.kernel_group.tile_desc.get_tile_size()
original_size = original_tile[dim_idx]
divisor = sub.args[1] * self.kernel_group.tile_desc.vmap.vlane_stride
new_size = ((original_size + divisor - 1) // divisor) * divisor
new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size())
new_tile_sizes[dim_idx] = new_size
self.kernel_group.tile_desc.set_tile_size(new_tile_sizes)
self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True

# Can't use dim_idx as vlane_split_axis
if dim_idx == self.kernel_group.tile_desc.vmap.vlane_split_axis:
self.kernel_group.tile_desc.vmap.vlane_split_axis = (dim_idx + 1) % len(original_tile)

# Send recompile signal
self.reset("recompile")
raise mlir_common.RecompileSignal(f"Tile size {self.kernel_group.tile_desc.get_tile_size()[dim_idx]} is not divisible by {sub.args[1]}")
dim_divisor[local_dim_idx] = sub.args[1]

# Update dram_stride, just insert 0 next to target dim
offset = 0
for dim_idx, divisor in enumerate(dim_divisor):
if divisor == 1:
continue
dram_stride.insert(dim_idx+offset+1, 0)
local_tile_desc.apply_divisor(dim_idx+offset, divisor, "pad")
local_tile_desc.apply_divisor(dim_idx+offset, divisor, "split")
offset = offset+1

# Support ModularIndexing pattern
# This pattern can be used to broadcast ex) torch.cat([a,a])
# ModularIndexing(x, y, z) means (x // y) % z
# tile_size must be: multiple of y (floorDiv divisor) and divisor of z (modular divisor)
if index.has(ModularIndexing):
for sub in sympy.preorder_traversal(index):
if isinstance(sub, ModularIndexing):
if not str(sub.args[0]).startswith("index"):
continue
dim_idx = int((str(list(sub.args[0].free_symbols)[0])[5:]))
floor_divisor = sub.args[1] # y: floorDiv divisor
mod_divisor = sub.args[2] # z: modular divisor
current_tile_size = self.kernel_group.tile_desc.get_tile_size()[dim_idx]

# Check if tile_size is multiple of floorDiv divisor
if int(current_tile_size % floor_divisor) != 0:
original_tile = self.kernel_group.tile_desc.get_tile_size()
original_size = original_tile[dim_idx]
divisor = floor_divisor * self.kernel_group.tile_desc.vmap.vlane_stride
new_size = ((original_size + divisor - 1) // divisor) * divisor
new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size())
new_tile_sizes[dim_idx] = new_size
self.kernel_group.tile_desc.set_tile_size(new_tile_sizes)
self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True

self.reset("recompile")
raise mlir_common.RecompileSignal(f"Tile size {current_tile_size} is not a multiple of floorDiv divisor {floor_divisor} in ModularIndexing")

# Check if tile_size is a divisor of modular divisor
if int((mod_divisor * floor_divisor) % current_tile_size) != 0:
original_tile = self.kernel_group.tile_desc.get_tile_size()
original_size = original_tile[dim_idx]
# Find the largest divisor of mod_divisor that is <= original_size
# and is a multiple of floor_divisor
new_size = original_size
while new_size > 0:
if mod_divisor % new_size == 0 and new_size % floor_divisor == 0:
break
new_size -= floor_divisor

if new_size <= 0:
new_size = mod_divisor * floor_divisor

new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size())
new_tile_sizes[dim_idx] = new_size
self.kernel_group.tile_desc.set_tile_size(new_tile_sizes)
self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True

self.reset("recompile")
raise mlir_common.RecompileSignal(f"Tile size {current_tile_size} is not a divisor of modular divisor {mod_divisor} in ModularIndexing")

# FIXME. It will be nice to modify node instead of this exception handling...
if len(self.itervars) == 1 and self.reduction_depth == 0:
# In case of reduction loop only case, we will add dummy loop so shift it once
Expand Down
79 changes: 1 addition & 78 deletions PyTorchSimFrontend/mlir/mlir_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
from torch._inductor.codegen import cpp
from torch._inductor.virtualized import V
from torch._inductor.ir import MultiOutputLayout
from torch._inductor.dependencies import MemoryDep, StarDep, WeakDep
from torch._inductor.codegen.wrapper import KernelDefinitionLine
from torch.utils._sympy.functions import ModularIndexing, FloorDiv, Mod, Identity
from torch.utils._sympy.functions import Identity
import sympy
import contextlib

Expand Down Expand Up @@ -436,20 +435,6 @@ def pad_vlane_tile(self):
padded_size = used_vlane * vlane_stride
self._tile_size[vlane_split_axis] = math.ceil(self._tile_size[vlane_split_axis] / padded_size) * padded_size

def apply_constraints(self, constraints, ranges):
for idx, (axis_constraints, axis_size) in enumerate(zip(constraints.values(), ranges)):
for const in axis_constraints:
if const.args[1] == 1:
continue
divider = int(const.args[1])

if not self.tile_constraint[idx].fixed:
self.tile_constraint[idx].fixed = True
self._tile_size[idx] = divider
elif self.tile_constraint[idx].fixed and self._tile_size[idx] > divider:
self._tile_size[idx] = divider
self.update_tile_stride()

@staticmethod
def init_tile_size(ranges, vlane_stride, vector_lane):
# Logical tile init for ANY rank. Only the innermost dims carry the
Expand Down Expand Up @@ -520,7 +505,6 @@ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=N
vlane_stride=vlane_stride
)

self.implicit_dim_size = {}
self.nr_rdim = 0
self.offset = sympy.Integer(0) # Dram offset

Expand Down Expand Up @@ -686,59 +670,6 @@ def call_kernel(self, kernel_name):
# generate the code to call this
wrapper.generate_kernel_call(kernel_name, call_args, triton=False)

def is_modular_indexing(self, expr):
return "ModularIndexing" in str(expr)

def implicit_dim_ops(self, nodes):
target_patterns = (ModularIndexing, FloorDiv, Mod)
target_operands = []
for target_node in nodes:
for read_operand in target_node.read_writes.reads:
read_operand: MemoryDep
if isinstance(read_operand, StarDep) or isinstance(read_operand, WeakDep):
continue
read_index = read_operand.index
for arg_expr in read_index.args:
if arg_expr.atoms(*target_patterns):
target_operands.append(read_operand)
return target_operands

def extract_dividers(self, implicit_ops):
# When a specific axis is processed, the key constraint to verify is the divider.
# The tile size must be forced to match the divider size.
dim_dividers = defaultdict(set)
for operand in implicit_ops:
subs_map = {
s: sympy.symbols(s.name.replace("c", "index", 1))
for s in operand.index.free_symbols
}
rev_subs_map = {
sympy.symbols(s.name.replace("c", "index", 1)) : s
for s in operand.index.free_symbols
}
new_index = operand.index.subs(subs_map)
for arg in new_index.args:
if arg.is_number:
continue
if len(arg.free_symbols) > 1:
raise NotImplementedError("Not supporting this view operation...!")
if arg.is_Mul and arg.args[0].is_number:
arg = arg.args[1]

if isinstance(arg, ModularIndexing):
modular_expr = ModularIndexing(arg.args[0], arg.args[1], arg.args[2])
modular_expr.original_expr = arg
elif arg.is_symbol:
modular_expr = ModularIndexing(arg, 1, operand.ranges[rev_subs_map[arg]])
modular_expr.original_expr = arg
elif "//" in str(arg):
modular_expr = ModularIndexing(arg.args[0], arg.args[1], operand.ranges[rev_subs_map[arg.args[0]]]//arg.args[1])
modular_expr.original_expr = arg
else:
raise NotImplementedError("What is this case?")
dim_dividers[modular_expr.args[0]].add(modular_expr)
return dim_dividers

def compute_tile_size(self, nodes, vars, reduction_vars):
vlane_split_axis = len(vars) - 1
vlane_stride = 2 # Set minimum vlane stride
Expand All @@ -758,14 +689,6 @@ def compute_tile_size(self, nodes, vars, reduction_vars):
self.kernel_group.tile_desc.vmap.vlane_split_axis = 0
self.kernel_group.tile_desc.vmap.vlane_stride = self.kernel_group.tile_desc.get_tile_size()[0]

# Handle implict dims. Input operand could be high dimension tensor.
# Note: https://github.com/PSAL-POSTECH/PyTorchSim/issues/173
implicit_ops = self.implicit_dim_ops(nodes)
if implicit_ops:
tile_constraints = self.extract_dividers(implicit_ops)
self.kernel_group.tile_desc.apply_constraints(tile_constraints, self.ranges)
self.kernel_group.tile_desc.implicit_dim_size = tile_constraints

# Check recodegen reason
if self.recodegen is not None:
if self.recodegen == "spad_overflow":
Expand Down