Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 73 additions & 14 deletions PyTorchSimFrontend/extension_codecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,26 @@ def dump_metadata(args, arg_attributes, path):

with open(meta_path, "a") as file:
for (arg_name, arg_attribute), arg in zip(arg_attributes, args):
file.write(f'{arg_name}=({arg_attribute[0]}, {arg.dtype}, {arg.shape})\n')
if isinstance(arg, torch.Tensor):
file.write(f'{arg_name}=({arg_attribute[0]}, {arg.dtype}, {arg.shape})\n')
else:
# Dynamic shape: a scalar size argument (e.g. s52) -- not a tensor.
file.write(f'{arg_name}=({arg_attribute[0]}, {type(arg).__name__}, {arg})\n')
return

def _concretize_attrs_for_sampling(arg_attributes, tile):
"""Size the cycle-sampling host buffers to one tile. Under dynamic shape the
arg_attributes carry stringified symbolic extents (e.g. 's52'); the one-tile
sampling kernel only touches [0, tile) of each tensor, so replace any symbolic
numel/size with `tile` (a static int). Non-symbolic entries (e.g. the size
arg, numel 1) are left as is."""
cz = lambda v: tile if isinstance(v, str) else v
out = []
for name, (atype, dtype, numel, sizes, stride) in arg_attributes:
out.append([name, [atype, dtype, cz(numel), [cz(s) for s in sizes], stride]])
return out


def mlir_compile_command(filename, vectorlane_size, vlen=256):
# The C++ -dma-fine-grained and -test-pytorchsim-to-vcix passes are ported to
# Python (passes/dma_fine_grained.py, lower_to_vcix.py), run in-process between
Expand Down Expand Up @@ -172,7 +189,15 @@ def load(cls, source_code,
link_option = f"-Wl,--section-start=.spad=0x{spad_info['spad_vaddr']:x}"
else:
link_option = ""
# Generate LLVM kernel calller and binary for validation
# Generate LLVM kernel calller and binary for validation. The validation
# binary is shape-agnostic: under dynamic shape it reads the runtime extent
# from the size-arg buffer and sizes its host buffers from it
# (mlir_caller_codegen), so one binary serves any size -- like the producer.
# Dynamic shape: a kernel has a size-symbol arg (MLIR_ARGS_VAR) iff some dim
# is a runtime extent. Use that flag (authoritative) rather than sniffing the
# IR text.
from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs
is_dynamic_shape = any(MLIRKernelArgs.is_mlir_arg_var(attr[0]) for _, attr in arg_attributes)
if extension_config.pytorchsim_functional_mode:
# Use custom malloc to avoid size error
new_link_option = link_option + " -Wl,--wrap=malloc -Wl,--wrap=free"
Expand Down Expand Up @@ -230,7 +255,29 @@ def load(cls, source_code,
run_module_passes(sample_mlir_path + "_padded.mlir",
sample_mlir_path + "_postvcix.mlir",
POST_OPT_PASSES, vectorlane=vectorlane_size, vlen=vlen)
run_tog(sample_mlir_path + "_postvcix.mlir", raw_tog_path,
# Dynamic shape: gem5 measures per-tile compute cost, which is
# shape-invariant. Sample it on a one-tile copy (each symbolic loop
# bound pinned to its step) so the legacy cycle machinery runs on a
# concrete kernel, while the symbolic _postvcix.mlir is kept for the
# producer .so / cycle_table below.
# pin_loops_to_one_tile is general (static + dynamic); today it is
# wired only for dynamic, where the legacy full TOG cannot be built
# (symbolic trip count) and is skipped anyway. Driving the trace
# path's cycle sampling through it for STATIC too is the intended
# direction, but needs the sampling decoupled from run_tog first
# (run_tog also builds the legacy full TOG, which needs full loops).
tog_input = sample_mlir_path + "_postvcix.mlir"
sample_tile = None
if is_dynamic_shape:
import mlir.ir as _ir
from PyTorchSimFrontend.mlir.passes.cycle_table import pin_loops_to_one_tile
_ctx = _ir.Context(); _ctx.allow_unregistered_dialects = True
with _ctx:
_pm = _ir.Module.parse(open(tog_input).read(), _ctx)
sample_tile = pin_loops_to_one_tile(_pm)
tog_input = sample_mlir_path + "_pinned.mlir"
open(tog_input, "w").write(str(_pm))
run_tog(tog_input, raw_tog_path,
sample_mlir_path + "_custom.mlir",
sample_mode=extension_config.CONFIG_TLS_MODE,
vectorlane=vectorlane_size)
Expand All @@ -246,8 +293,13 @@ def load(cls, source_code,
if not extension_config.pytorchsim_timing_mode:
return key

# Generate MLIR kernel calller and binary for cycle calculation
cycle_llvm_caller = MLIRKernelCallerCodeGen(False, arg_attributes, cycle_sim=True)
# Generate MLIR kernel calller and binary for cycle calculation.
# Dynamic shape: size the host buffers to one tile (the sampling kernel
# was pinned to a single tile above); arg_attributes carry symbolic
# extents that cannot size a buffer.
sample_attrs = (_concretize_attrs_for_sampling(arg_attributes, sample_tile)
if is_dynamic_shape else arg_attributes)
cycle_llvm_caller = MLIRKernelCallerCodeGen(False, sample_attrs, cycle_sim=True)
cycle_llvm_caller.generate_wrapper_file(write_path, cycle_wrapper_name)
cycle_llvm_caller.compile_wih_kernel(write_path, key + "_sample", cycle_wrapper_name, cycle_binary_name, link_option)

Expand All @@ -273,15 +325,20 @@ def load(cls, source_code,
if kwargs['loop_size'] is not None and kwargs['loop_size'][-1] < vectorlane_size:
w_offset = kwargs['loop_size'][-1]
w_offset = 0 # max(w_offset - x_offset, 0)
tile_graph_generator = tog_generator(origins)
tile_graph_generator.load_file(raw_tog_path)
tile_graph_generator.generate_tile_graph(
tog_path,
cycle_list=cycle_list,
x_offset=x_offset, # FIXME.
w_offset=w_offset, # FIXME.
vector_lane=vectorlane_size
)
# DEPRECATED legacy ONNX-TOG output (tile_graph.onnx); unused when the
# trace pipeline is the default sim path. It enumerates tiles statically,
# so it cannot be built for a dynamic (runtime-extent) kernel -- skip it.
# x_offset/w_offset above are still needed by the trace cycle_table.
if not is_dynamic_shape:
tile_graph_generator = tog_generator(origins)
tile_graph_generator.load_file(raw_tog_path)
tile_graph_generator.generate_tile_graph(
tog_path,
cycle_list=cycle_list,
x_offset=x_offset, # FIXME.
w_offset=w_offset, # FIXME.
vector_lane=vectorlane_size
)

# Trace pipeline (DEFAULT): emit the compiled trace producer .so + the
# cycle-table TSV from the post-vcix IR and gem5 cycle_list/offsets. This
Expand Down Expand Up @@ -341,6 +398,8 @@ def run_kernel_simulation(*args, autotune_subprocess_timeout_sec=None, **kwargs)
# Dump arguments and meta data
dump_metadata(args, arg_attributes, result_path)
runtime_path = FunctionalSimulator.get_runtime_dump_path(result_path)
# The runtime extents reach the simulator via the attribute YAML
# (write_kernel_attribute_file -> shape_args), not from here.
if extension_config.pytorchsim_functional_mode and not autotune:
funcsim = FunctionalSimulator(result_path, key)
funcsim.run_spike(args, arg_attributes,
Expand Down
147 changes: 119 additions & 28 deletions PyTorchSimFrontend/mlir/axis_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,43 +29,130 @@ def _as_int(x):
return None


# --- symbolic-aware boundary arithmetic ------------------------------------
# These reduce EXACTLY to the integer case when their operands are concrete, so
# static axis splitting is unchanged; they additionally accept symbolic size
# expressions (e.g. a flattened reshape extent E = M*N with divisor N), where a
# boundary that is a genuine product of dims divides the extent by construction.
# A dynamic dim symbol is created integer/positive, so sympy proves the
# divisibility (Mod(M*N, N) -> 0) and the quotient (cancel(M*N/N) -> M).

def _divides(d, E):
"""True iff d divides E. For concrete ints this is `E % d == 0`."""
di, Ei = _as_int(d), _as_int(E)
if di is not None and Ei is not None:
return di != 0 and Ei % di == 0
try:
return bool(sympy.simplify(sympy.Mod(E, d)) == 0)
except Exception:
return False


def _eq(a, b):
"""Provable equality of two size exprs (structural for ints)."""
ai, bi = _as_int(a), _as_int(b)
if ai is not None and bi is not None:
return ai == bi
try:
return bool(sympy.simplify(a - b) == 0)
except Exception:
return a == b


def _gt1(x):
"""True iff x is a non-trivial boundary (> 1). A symbolic dim is assumed > 1."""
xi = _as_int(x)
if xi is not None:
return xi > 1
return not _eq(x, sympy.Integer(1))


def _proper(b, E):
"""True iff b is a proper interior divisor of E: 1 < b < E and b | E."""
bi, Ei = _as_int(b), _as_int(E)
if bi is not None and Ei is not None:
return 1 < bi < Ei and Ei % bi == 0
return _gt1(b) and not _eq(b, E) and _divides(b, E)


def _quotient(a, b):
"""a / b as an exact int (concrete) or simplified sympy expr (symbolic)."""
ai, bi = _as_int(a), _as_int(b)
if ai is not None and bi is not None:
return ai // bi
return sympy.cancel(a / b)


def _as_size(x):
"""Wrap a concrete int as sympy.Integer; pass a sympy expr through unchanged
(preserving its integer/positive assumptions)."""
xi = _as_int(x)
return sympy.Integer(xi) if xi is not None else x


def _ordered_chain(boundaries, E):
"""Order the proper divisors of E into a divisibility chain [1, ..., E], else None.

Generalises the old `_is_chain` + numeric `sorted`: orders by the divisibility
partial order (b_i precedes b_j iff b_i | b_j) rather than by numeric value, so
symbolic boundaries (suffix-products of dims, e.g. N | M*N) chain correctly. For
concrete ints this yields exactly the old ascending divisibility chain. Returns
None when the boundaries do not form a TOTAL divisibility chain (the
incompatible-radix / misaligned case), so the axis is left unsplit.
"""
bs = []
for b in boundaries:
if _proper(b, E) and not any(_eq(b, x) for x in bs):
bs.append(b)
ordered = []
remaining = list(bs)
while remaining:
# the divisibility-minimum is the unique element that divides all others.
mins = [b for b in remaining
if all(_divides(b, o) for o in remaining if not _eq(b, o))]
if len(mins) != 1:
return None # no unique minimum -> incomparable -> not a chain
ordered.append(mins[0])
remaining = [o for o in remaining if not _eq(o, mins[0])]
chain = [sympy.Integer(1)] + ordered + [_as_size(E)]
for i in range(len(chain) - 1):
if not _divides(chain[i], chain[i + 1]):
return None
return chain


def collect_boundaries(exprs, var_to_axis, var_ranges):
"""{axis_index: set(boundary cut points)} for the given index expressions.

A FloorDiv(v, k) contributes boundary k; ModularIndexing(v, k, m) contributes
k and k*m. Only aligned terms count (boundary divides the var extent). Shared
by find_split_plan (fused LoopBody) and graph_copy (operand loaders).
by find_split_plan (fused LoopBody) and graph_copy (operand loaders). Boundaries
and extents may be symbolic (dynamic reshape); divisibility is checked via
`_divides`, so a symbolic divisor that is a genuine factor of the extent counts.
"""
import collections
bset = collections.defaultdict(set)
for expr in exprs:
for fd in expr.atoms(FloorDiv):
base, div = fd.args
k = _as_int(div)
if base in var_to_axis and k and k > 1:
E = _as_int(var_ranges.get(base))
if E and E % k == 0:
bset[var_to_axis[base]].add(k)
if base in var_to_axis and _gt1(div):
E = var_ranges.get(base)
if E is not None and _divides(div, E):
bset[var_to_axis[base]].add(div)
for mi in expr.atoms(ModularIndexing):
base, div, mod = mi.args
k, m = _as_int(div), _as_int(mod)
if base in var_to_axis and k and m:
E = _as_int(var_ranges.get(base))
if E and E % (k * m) == 0:
if base in var_to_axis:
E = var_ranges.get(base)
km = div * mod
if E is not None and _divides(km, E):
ax = var_to_axis[base]
if k > 1:
bset[ax].add(k)
if k * m < E:
bset[ax].add(k * m)
if _gt1(div):
bset[ax].add(div)
if _proper(km, E):
bset[ax].add(km)
return bset


def _is_chain(boundaries, E):
"""True iff [1, sorted(boundaries in (1,E)), E] is a divisibility chain."""
chain = [1] + sorted(b for b in boundaries if 1 < b < E) + [E]
return all(chain[i + 1] % chain[i] == 0 for i in range(len(chain) - 1))


def find_split_plan(nodes):
"""Inspect a group of scheduler nodes and return {axis_index: boundaries}.

Expand All @@ -80,13 +167,14 @@ def find_split_plan(nodes):
collected boundaries for an axis do NOT form a divisibility chain (e.g.
floor-by-2 and mod-by-3 on extent 6), the radices are incompatible -> the axis
is left unsplit (its floor/mod stays for the misaligned/recompile path).
Boundaries/extents may be symbolic (see _ordered_chain).

axis_index is positional in the group's iteration space, so the same plan
applies to every fused node sharing that space.
"""
import collections
bset = collections.defaultdict(set) # axis -> set of boundary cut points
ext_of = {} # axis -> extent
ext_of = {} # axis -> extent (int or symbolic)
for n in nodes:
body = getattr(n, "_body", None)
if body is None:
Expand All @@ -95,14 +183,17 @@ def find_split_plan(nodes):
nb = collect_boundaries(body.indexing_exprs.values(), var_to_axis, body.var_ranges)
for ax, bs in nb.items():
bset[ax] |= bs
ext_of[ax] = _as_int(body.var_ranges[body.iter_vars[ax]])
ext_of[ax] = body.var_ranges[body.iter_vars[ax]]

plan = {}
for ax, bs in bset.items():
E = ext_of[ax]
E = ext_of.get(ax)
if E is None:
continue
# require a real, divisibility-chain split (incompatible radices -> skip).
if E and any(1 < b < E for b in bs) and _is_chain(bs, E):
plan[ax] = [1] + sorted(b for b in bs if 1 < b < E) + [E]
chain = _ordered_chain(bs, E)
if chain is not None and len(chain) > 2:
plan[ax] = chain

# A split may push the per-axis index rank past 4. The resulting >4D logical tile
# is peeled into <=4D physical descriptors by the decompose-transfer pass (an
Expand Down Expand Up @@ -143,15 +234,15 @@ def build_split_body(node, plan, prefix="z"):
subs = [] # (symbol, extent, significance) low->high
expr = sympy.Integer(0)
for i in range(len(bounds) - 1):
seg_ext = bounds[i + 1] // bounds[i]
seg_ext = _quotient(bounds[i + 1], bounds[i])
nv = sympy_index_symbol(f"{prefix}{ctr}"); ctr += 1
subs.append((nv, seg_ext, bounds[i]))
expr = expr + nv * bounds[i]
# iteration nest: most-significant (outermost) dim first.
for nv, seg_ext, _sig in reversed(subs):
iter_vars.append(nv)
var_ranges[nv] = sympy.Integer(seg_ext)
index_size.append(sympy.Integer(seg_ext))
var_ranges[nv] = _as_size(seg_ext)
index_size.append(_as_size(seg_ext))
index_args.append(expr)
else:
nv = sympy_index_symbol(f"{prefix}{ctr}"); ctr += 1
Expand Down
Loading