diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index 785a3d95..c3395ec2 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -41,9 +41,26 @@ def dump_metadata(args, arg_attributes, path): with open(meta_path, "a") as file: for (arg_name, arg_attribute), arg in zip(arg_attributes, args): - file.write(f'{arg_name}=({arg_attribute[0]}, {arg.dtype}, {arg.shape})\n') + if isinstance(arg, torch.Tensor): + file.write(f'{arg_name}=({arg_attribute[0]}, {arg.dtype}, {arg.shape})\n') + else: + # Dynamic shape: a scalar size argument (e.g. s52) -- not a tensor. + file.write(f'{arg_name}=({arg_attribute[0]}, {type(arg).__name__}, {arg})\n') return +def _concretize_attrs_for_sampling(arg_attributes, tile): + """Size the cycle-sampling host buffers to one tile. Under dynamic shape the + arg_attributes carry stringified symbolic extents (e.g. 's52'); the one-tile + sampling kernel only touches [0, tile) of each tensor, so replace any symbolic + numel/size with `tile` (a static int). Non-symbolic entries (e.g. the size + arg, numel 1) are left as is.""" + cz = lambda v: tile if isinstance(v, str) else v + out = [] + for name, (atype, dtype, numel, sizes, stride) in arg_attributes: + out.append([name, [atype, dtype, cz(numel), [cz(s) for s in sizes], stride]]) + return out + + def mlir_compile_command(filename, vectorlane_size, vlen=256): # The C++ -dma-fine-grained and -test-pytorchsim-to-vcix passes are ported to # Python (passes/dma_fine_grained.py, lower_to_vcix.py), run in-process between @@ -172,7 +189,15 @@ def load(cls, source_code, link_option = f"-Wl,--section-start=.spad=0x{spad_info['spad_vaddr']:x}" else: link_option = "" - # Generate LLVM kernel calller and binary for validation + # Generate LLVM kernel calller and binary for validation. 
The validation + # binary is shape-agnostic: under dynamic shape it reads the runtime extent + # from the size-arg buffer and sizes its host buffers from it + # (mlir_caller_codegen), so one binary serves any size -- like the producer. + # Dynamic shape: a kernel has a size-symbol arg (MLIR_ARGS_VAR) iff some dim + # is a runtime extent. Use that flag (authoritative) rather than sniffing the + # IR text. + from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs + is_dynamic_shape = any(MLIRKernelArgs.is_mlir_arg_var(attr[0]) for _, attr in arg_attributes) if extension_config.pytorchsim_functional_mode: # Use custom malloc to avoid size error new_link_option = link_option + " -Wl,--wrap=malloc -Wl,--wrap=free" @@ -230,7 +255,29 @@ def load(cls, source_code, run_module_passes(sample_mlir_path + "_padded.mlir", sample_mlir_path + "_postvcix.mlir", POST_OPT_PASSES, vectorlane=vectorlane_size, vlen=vlen) - run_tog(sample_mlir_path + "_postvcix.mlir", raw_tog_path, + # Dynamic shape: gem5 measures per-tile compute cost, which is + # shape-invariant. Sample it on a one-tile copy (each symbolic loop + # bound pinned to its step) so the legacy cycle machinery runs on a + # concrete kernel, while the symbolic _postvcix.mlir is kept for the + # producer .so / cycle_table below. + # pin_loops_to_one_tile is general (static + dynamic); today it is + # wired only for dynamic, where the legacy full TOG cannot be built + # (symbolic trip count) and is skipped anyway. Driving the trace + # path's cycle sampling through it for STATIC too is the intended + # direction, but needs the sampling decoupled from run_tog first + # (run_tog also builds the legacy full TOG, which needs full loops). 
+ tog_input = sample_mlir_path + "_postvcix.mlir" + sample_tile = None + if is_dynamic_shape: + import mlir.ir as _ir + from PyTorchSimFrontend.mlir.passes.cycle_table import pin_loops_to_one_tile + _ctx = _ir.Context(); _ctx.allow_unregistered_dialects = True + with _ctx: + _pm = _ir.Module.parse(open(tog_input).read(), _ctx) + sample_tile = pin_loops_to_one_tile(_pm) + tog_input = sample_mlir_path + "_pinned.mlir" + open(tog_input, "w").write(str(_pm)) + run_tog(tog_input, raw_tog_path, sample_mlir_path + "_custom.mlir", sample_mode=extension_config.CONFIG_TLS_MODE, vectorlane=vectorlane_size) @@ -246,8 +293,13 @@ def load(cls, source_code, if not extension_config.pytorchsim_timing_mode: return key - # Generate MLIR kernel calller and binary for cycle calculation - cycle_llvm_caller = MLIRKernelCallerCodeGen(False, arg_attributes, cycle_sim=True) + # Generate MLIR kernel calller and binary for cycle calculation. + # Dynamic shape: size the host buffers to one tile (the sampling kernel + # was pinned to a single tile above); arg_attributes carry symbolic + # extents that cannot size a buffer. + sample_attrs = (_concretize_attrs_for_sampling(arg_attributes, sample_tile) + if is_dynamic_shape else arg_attributes) + cycle_llvm_caller = MLIRKernelCallerCodeGen(False, sample_attrs, cycle_sim=True) cycle_llvm_caller.generate_wrapper_file(write_path, cycle_wrapper_name) cycle_llvm_caller.compile_wih_kernel(write_path, key + "_sample", cycle_wrapper_name, cycle_binary_name, link_option) @@ -273,15 +325,20 @@ def load(cls, source_code, if kwargs['loop_size'] is not None and kwargs['loop_size'][-1] < vectorlane_size: w_offset = kwargs['loop_size'][-1] w_offset = 0 # max(w_offset - x_offset, 0) - tile_graph_generator = tog_generator(origins) - tile_graph_generator.load_file(raw_tog_path) - tile_graph_generator.generate_tile_graph( - tog_path, - cycle_list=cycle_list, - x_offset=x_offset, # FIXME. - w_offset=w_offset, # FIXME. 
- vector_lane=vectorlane_size - ) + # DEPRECATED legacy ONNX-TOG output (tile_graph.onnx); unused when the + # trace pipeline is the default sim path. It enumerates tiles statically, + # so it cannot be built for a dynamic (runtime-extent) kernel -- skip it. + # x_offset/w_offset above are still needed by the trace cycle_table. + if not is_dynamic_shape: + tile_graph_generator = tog_generator(origins) + tile_graph_generator.load_file(raw_tog_path) + tile_graph_generator.generate_tile_graph( + tog_path, + cycle_list=cycle_list, + x_offset=x_offset, # FIXME. + w_offset=w_offset, # FIXME. + vector_lane=vectorlane_size + ) # Trace pipeline (DEFAULT): emit the compiled trace producer .so + the # cycle-table TSV from the post-vcix IR and gem5 cycle_list/offsets. This @@ -341,6 +398,8 @@ def run_kernel_simulation(*args, autotune_subprocess_timeout_sec=None, **kwargs) # Dump arguments and meta data dump_metadata(args, arg_attributes, result_path) runtime_path = FunctionalSimulator.get_runtime_dump_path(result_path) + # The runtime extents reach the simulator via the attribute YAML + # (write_kernel_attribute_file -> shape_args), not from here. if extension_config.pytorchsim_functional_mode and not autotune: funcsim = FunctionalSimulator(result_path, key) funcsim.run_spike(args, arg_attributes, diff --git a/PyTorchSimFrontend/mlir/axis_split.py b/PyTorchSimFrontend/mlir/axis_split.py index 71ec4809..15404bd0 100644 --- a/PyTorchSimFrontend/mlir/axis_split.py +++ b/PyTorchSimFrontend/mlir/axis_split.py @@ -29,43 +29,130 @@ def _as_int(x): return None +# --- symbolic-aware boundary arithmetic ------------------------------------ +# These reduce EXACTLY to the integer case when their operands are concrete, so +# static axis splitting is unchanged; they additionally accept symbolic size +# expressions (e.g. a flattened reshape extent E = M*N with divisor N), where a +# boundary that is a genuine product of dims divides the extent by construction. 
+# A dynamic dim symbol is created integer/positive, so sympy proves the +# divisibility (Mod(M*N, N) -> 0) and the quotient (cancel(M*N/N) -> M). + +def _divides(d, E): + """True iff d divides E. For concrete ints this is `E % d == 0`.""" + di, Ei = _as_int(d), _as_int(E) + if di is not None and Ei is not None: + return di != 0 and Ei % di == 0 + try: + return bool(sympy.simplify(sympy.Mod(E, d)) == 0) + except Exception: + return False + + +def _eq(a, b): + """Provable equality of two size exprs (structural for ints).""" + ai, bi = _as_int(a), _as_int(b) + if ai is not None and bi is not None: + return ai == bi + try: + return bool(sympy.simplify(a - b) == 0) + except Exception: + return a == b + + +def _gt1(x): + """True iff x is a non-trivial boundary (> 1). A symbolic dim is assumed > 1.""" + xi = _as_int(x) + if xi is not None: + return xi > 1 + return not _eq(x, sympy.Integer(1)) + + +def _proper(b, E): + """True iff b is a proper interior divisor of E: 1 < b < E and b | E.""" + bi, Ei = _as_int(b), _as_int(E) + if bi is not None and Ei is not None: + return 1 < bi < Ei and Ei % bi == 0 + return _gt1(b) and not _eq(b, E) and _divides(b, E) + + +def _quotient(a, b): + """a / b as an exact int (concrete) or simplified sympy expr (symbolic).""" + ai, bi = _as_int(a), _as_int(b) + if ai is not None and bi is not None: + return ai // bi + return sympy.cancel(a / b) + + +def _as_size(x): + """Wrap a concrete int as sympy.Integer; pass a sympy expr through unchanged + (preserving its integer/positive assumptions).""" + xi = _as_int(x) + return sympy.Integer(xi) if xi is not None else x + + +def _ordered_chain(boundaries, E): + """Order the proper divisors of E into a divisibility chain [1, ..., E], else None. + + Generalises the old `_is_chain` + numeric `sorted`: orders by the divisibility + partial order (b_i precedes b_j iff b_i | b_j) rather than by numeric value, so + symbolic boundaries (suffix-products of dims, e.g. N | M*N) chain correctly. 
For + concrete ints this yields exactly the old ascending divisibility chain. Returns + None when the boundaries do not form a TOTAL divisibility chain (the + incompatible-radix / misaligned case), so the axis is left unsplit. + """ + bs = [] + for b in boundaries: + if _proper(b, E) and not any(_eq(b, x) for x in bs): + bs.append(b) + ordered = [] + remaining = list(bs) + while remaining: + # the divisibility-minimum is the unique element that divides all others. + mins = [b for b in remaining + if all(_divides(b, o) for o in remaining if not _eq(b, o))] + if len(mins) != 1: + return None # no unique minimum -> incomparable -> not a chain + ordered.append(mins[0]) + remaining = [o for o in remaining if not _eq(o, mins[0])] + chain = [sympy.Integer(1)] + ordered + [_as_size(E)] + for i in range(len(chain) - 1): + if not _divides(chain[i], chain[i + 1]): + return None + return chain + + def collect_boundaries(exprs, var_to_axis, var_ranges): """{axis_index: set(boundary cut points)} for the given index expressions. A FloorDiv(v, k) contributes boundary k; ModularIndexing(v, k, m) contributes k and k*m. Only aligned terms count (boundary divides the var extent). Shared - by find_split_plan (fused LoopBody) and graph_copy (operand loaders). + by find_split_plan (fused LoopBody) and graph_copy (operand loaders). Boundaries + and extents may be symbolic (dynamic reshape); divisibility is checked via + `_divides`, so a symbolic divisor that is a genuine factor of the extent counts. 
""" import collections bset = collections.defaultdict(set) for expr in exprs: for fd in expr.atoms(FloorDiv): base, div = fd.args - k = _as_int(div) - if base in var_to_axis and k and k > 1: - E = _as_int(var_ranges.get(base)) - if E and E % k == 0: - bset[var_to_axis[base]].add(k) + if base in var_to_axis and _gt1(div): + E = var_ranges.get(base) + if E is not None and _divides(div, E): + bset[var_to_axis[base]].add(div) for mi in expr.atoms(ModularIndexing): base, div, mod = mi.args - k, m = _as_int(div), _as_int(mod) - if base in var_to_axis and k and m: - E = _as_int(var_ranges.get(base)) - if E and E % (k * m) == 0: + if base in var_to_axis: + E = var_ranges.get(base) + km = div * mod + if E is not None and _divides(km, E): ax = var_to_axis[base] - if k > 1: - bset[ax].add(k) - if k * m < E: - bset[ax].add(k * m) + if _gt1(div): + bset[ax].add(div) + if _proper(km, E): + bset[ax].add(km) return bset -def _is_chain(boundaries, E): - """True iff [1, sorted(boundaries in (1,E)), E] is a divisibility chain.""" - chain = [1] + sorted(b for b in boundaries if 1 < b < E) + [E] - return all(chain[i + 1] % chain[i] == 0 for i in range(len(chain) - 1)) - - def find_split_plan(nodes): """Inspect a group of scheduler nodes and return {axis_index: boundaries}. @@ -80,13 +167,14 @@ def find_split_plan(nodes): collected boundaries for an axis do NOT form a divisibility chain (e.g. floor-by-2 and mod-by-3 on extent 6), the radices are incompatible -> the axis is left unsplit (its floor/mod stays for the misaligned/recompile path). + Boundaries/extents may be symbolic (see _ordered_chain). axis_index is positional in the group's iteration space, so the same plan applies to every fused node sharing that space. 
""" import collections bset = collections.defaultdict(set) # axis -> set of boundary cut points - ext_of = {} # axis -> extent + ext_of = {} # axis -> extent (int or symbolic) for n in nodes: body = getattr(n, "_body", None) if body is None: @@ -95,14 +183,17 @@ def find_split_plan(nodes): nb = collect_boundaries(body.indexing_exprs.values(), var_to_axis, body.var_ranges) for ax, bs in nb.items(): bset[ax] |= bs - ext_of[ax] = _as_int(body.var_ranges[body.iter_vars[ax]]) + ext_of[ax] = body.var_ranges[body.iter_vars[ax]] plan = {} for ax, bs in bset.items(): - E = ext_of[ax] + E = ext_of.get(ax) + if E is None: + continue # require a real, divisibility-chain split (incompatible radices -> skip). - if E and any(1 < b < E for b in bs) and _is_chain(bs, E): - plan[ax] = [1] + sorted(b for b in bs if 1 < b < E) + [E] + chain = _ordered_chain(bs, E) + if chain is not None and len(chain) > 2: + plan[ax] = chain # A split may push the per-axis index rank past 4. The resulting >4D logical tile # is peeled into <=4D physical descriptors by the decompose-transfer pass (an @@ -143,15 +234,15 @@ def build_split_body(node, plan, prefix="z"): subs = [] # (symbol, extent, significance) low->high expr = sympy.Integer(0) for i in range(len(bounds) - 1): - seg_ext = bounds[i + 1] // bounds[i] + seg_ext = _quotient(bounds[i + 1], bounds[i]) nv = sympy_index_symbol(f"{prefix}{ctr}"); ctr += 1 subs.append((nv, seg_ext, bounds[i])) expr = expr + nv * bounds[i] # iteration nest: most-significant (outermost) dim first. 
for nv, seg_ext, _sig in reversed(subs): iter_vars.append(nv) - var_ranges[nv] = sympy.Integer(seg_ext) - index_size.append(sympy.Integer(seg_ext)) + var_ranges[nv] = _as_size(seg_ext) + index_size.append(_as_size(seg_ext)) index_args.append(expr) else: nv = sympy_index_symbol(f"{prefix}{ctr}"); ctr += 1 diff --git a/PyTorchSimFrontend/mlir/mlir_caller_codegen.py b/PyTorchSimFrontend/mlir/mlir_caller_codegen.py index 7c842272..bdb71be5 100644 --- a/PyTorchSimFrontend/mlir/mlir_caller_codegen.py +++ b/PyTorchSimFrontend/mlir/mlir_caller_codegen.py @@ -34,6 +34,32 @@ def get_argv_idx(self): self.arg_use_count += 1 return self.arg_use_count-1 + def _is_var(self, flag): + return MLIRKernelArgs.is_mlir_arg_var(flag) + + @staticmethod + def _is_symbol(numel): + """A numel that is a size SYMBOL (e.g. 's52'), not a concrete value. Concrete + sizes may also be strings here (the meta stringifies sympy.Integer, e.g. + '128'); those are numeric, a symbol is not.""" + return isinstance(numel, str) and not numel.isdigit() + + def _numel_c_expr(self, numel): + """C expression for an arg's element count. Dynamic shape: a size SYMBOL is + the runtime extent, read into `N_` from its size buffer (see + generate_args_define); a concrete numel (int or numeric string) is a literal.""" + return f"N_{numel}" if self._is_symbol(numel) else str(numel) + + def _assign_argv_indices(self): + """Assign each loaded/dumped arg an argv slot in arg_attributes order, the + same order Simulator.dump_args writes the .raw paths. 
Size (VAR) args get a + slot too (they are kernel inputs).""" + for arg_name, arg_attribute in self.arg_attributes: + flag = arg_attribute[0] + if (self.is_in_arg(flag) or self.is_out_arg(flag) or self._is_var(flag)) \ + and arg_name not in self.load_args: + self.load_args[arg_name] = self.get_argv_idx() + def write_header(self): self.writeline('#include ') self.writeline('#include ') @@ -56,12 +82,12 @@ def is_inout_arg(self, value): def load_arg(self): for arg_name, arg_attribute in self.arg_attributes: - if self.is_in_arg(arg_attribute[0]): - argv_idx = self.get_argv_idx() if arg_name not in self.load_args else self.load_args[arg_name] - self.load_args[arg_name] = argv_idx + # VAR (size) args are loaded in generate_args_define (before the tensor + # buffers they size); skip them here. + if self.is_in_arg(arg_attribute[0]) and not self._is_var(arg_attribute[0]): + argv_idx = self.load_args[arg_name] ctype = DTYPE_TO_C[arg_attribute[1]] - elem_count = arg_attribute[2] - size_expr = f'({elem_count}ULL * sizeof({ctype}))' + size_expr = f'((uint64_t)({self._numel_c_expr(arg_attribute[2])}) * sizeof({ctype}))' self.writeline(f'if(load_arg(c_{arg_name}, {size_expr}, argv[{argv_idx}]) == -1){self.open_bracket}') with self.code.indent(): @@ -71,10 +97,9 @@ def load_arg(self): def dump_arg(self): for arg_name, arg_attribute in self.arg_attributes: if self.is_out_arg(arg_attribute[0]): - argv_idx = self.get_argv_idx() if not self.is_inout_arg(arg_attribute[0]) else self.load_args[arg_name] + argv_idx = self.load_args[arg_name] ctype = DTYPE_TO_C[arg_attribute[1]] - elem_count = arg_attribute[2] - size_expr = f'({elem_count}ULL * sizeof({ctype}))' + size_expr = f'((uint64_t)({self._numel_c_expr(arg_attribute[2])}) * sizeof({ctype}))' self.writeline(f'if(dump_arg(c_{arg_name}, {size_expr}, argv[{argv_idx}]) == -1){self.open_bracket}') with self.code.indent(): self.writeline(f'return -1{self.ending}') @@ -93,30 +118,53 @@ def generate_args_define(self): name_set = set() if 
self.validation: self.writeline(f"int* padding = malloc(0x100000ULL * sizeof(int)){self.ending}") - for arg_name, (_, arg_type, arg_size, arg_sizes, arg_stride) in self.arg_attributes: - if not arg_name in name_set: - if torch.is_floating_point(torch.tensor([], dtype=arg_type)): - bits = torch.finfo(arg_type).bits - elif arg_type == torch.bool: - bits = 8 - else: - bits = torch.iinfo(arg_type).bits - buffer_size = int(math.ceil(arg_size * bits // 8 / 64) * 64) * 2 # Round up to 64 bytes + Add some padding for safety - self.writeline(f'{DTYPE_TO_C[arg_type]}* c_{arg_name} = malloc({buffer_size}ULL){self.ending}') - name_set.add(arg_name) + # Dynamic shape: handle size (VAR) args first -- malloc, load from argv, and + # read the runtime extent into N_, BEFORE the tensor buffers, which are + # sized from it. + for arg_name, (flag, arg_type, arg_size, _, _) in self.arg_attributes: + if not self._is_var(flag) or arg_name in name_set: + continue + ctype = DTYPE_TO_C[arg_type] + self.writeline(f'{ctype}* c_{arg_name} = malloc(64ULL){self.ending}') + if self.validation: + self.writeline(f'if(load_arg(c_{arg_name}, sizeof(int64_t), argv[{self.load_args[arg_name]}]) == -1){self.open_bracket}') + with self.code.indent(): + self.writeline(f'return -1{self.ending}') + self.writeline(self.closed_bracket) + self.writeline(f'int64_t N_{arg_name} = ((int64_t*)c_{arg_name})[0]{self.ending}') + name_set.add(arg_name) + for arg_name, (flag, arg_type, arg_size, arg_sizes, arg_stride) in self.arg_attributes: + if self._is_var(flag) or arg_name in name_set: + continue + if torch.is_floating_point(torch.tensor([], dtype=arg_type)): + bits = torch.finfo(arg_type).bits + elif arg_type == torch.bool: + bits = 8 + else: + bits = torch.iinfo(arg_type).bits + ctype = DTYPE_TO_C[arg_type] + if self._is_symbol(arg_size): + # runtime extent: round bytes up to 64 and double, computed in C. 
+ nbytes = f"(N_{arg_size} * {bits} / 8)" + buffer_size = f"((({nbytes} + 63) / 64) * 64) * 2" + else: + buffer_size = f"{int(math.ceil(int(arg_size) * bits // 8 / 64) * 64) * 2}ULL" # round up to 64 bytes + safety pad + self.writeline(f'{ctype}* c_{arg_name} = malloc({buffer_size}){self.ending}') + name_set.add(arg_name) self.writeline(self.newline) def generate_main(self): self.writeline(f'{self.newline}int main(int argc, char *argv[]) {self.open_bracket}{self.newline}') with self.code.indent(): if self.validation: + self._assign_argv_indices() # argv slots in arg order (incl. size args) self.generate_args_define() self.load_arg() self.writeline(self.newline) else: self.generate_args_define() - func_arguments = [f"c_{arg_name}, c_{arg_name}, 0, {arg_shape}, 1" for arg_name, (_, arg_type, arg_shape, _, _) in self.arg_attributes] + func_arguments = [f"c_{arg_name}, c_{arg_name}, 0, {self._numel_c_expr(arg_shape)}, 1" for arg_name, (_, arg_type, arg_shape, _, _) in self.arg_attributes] self.writeline(f"wrapper_{self.kernel_name}({', '.join(func_arguments)}){self.ending}{self.newline}") if self.validation: diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 8f695395..48513249 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -903,6 +903,19 @@ def codegen_loops(self): code.splice(self.const_buffer) code.splice(self.alloc_buffer) code.splice(self.spad_buffer) + # Dynamic shape: materialize each symbolic loop bound as an index SSA at the + # function top level (a valid affine symbol). The extent arrives as a + # memref<1xi64> arg named after the symbol (mlir_argdefs sizevars); load and + # cast it once before the loop nest. LoopLevel._bound_str emits %_bound. 
+ dyn_syms = [] + for lp in loops.loops + reductions.loops: + if isinstance(lp.size, sympy.Symbol) and lp.size.name not in dyn_syms: + dyn_syms.append(lp.size.name) + if dyn_syms: + code.writeline("%dyn_zero = arith.constant 0 : index") + for nm in dyn_syms: + code.writeline(f"%{nm}_val = memref.load %{nm}[%dyn_zero] : memref<1xi64>") + code.writeline(f"%{nm}_bound = arith.index_cast %{nm}_val : i64 to index") # Outerloop with contextlib.ExitStack() as stack: for loop in loops.loops: @@ -980,8 +993,12 @@ def make_choices(self, nodes, kernel_name): for axis in list(candidate_axes): prev_tile_sz = self.kernel_group.tile_desc.get_tile_size() - # If tile size is maximized for this axis, remove from candidate axes - if prev_tile_sz[axis] >= prev_ranges[axis] * 2 or prev_tile_sz[axis] >= 2 ** 13: + # If tile size is maximized for this axis, remove from candidate axes. + # Dynamic shape: a symbolic dim has no compile-time bound to grow the + # tile toward, so drop the axis (keep the fixed tile) rather than + # comparing tile >= sympy*2 (cannot determine truth value). + if mlir_common.is_symbolic_dim(prev_ranges[axis]) or \ + prev_tile_sz[axis] >= prev_ranges[axis] * 2 or prev_tile_sz[axis] >= 2 ** 13: candidate_axes.remove(axis) self.reset(None) continue diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index a70d1c7d..61f0058e 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -96,6 +96,14 @@ def get_dtype_nbytes(dtype): raise NotImplementedError(f"Unsupported dtype for precision calculation: {dtype}") return MLIR_TO_BIT[mlir_dtype] // 8 +def is_symbolic_dim(x): + """True if `x` is a runtime (symbolic) dimension -- a sympy expression that is + not a compile-time constant. Dynamic shape (torch.compile(dynamic=True)) makes a + loop range / dim such a symbol (e.g. ks0); the tiling and bound-emission paths + must skip their concrete-int arithmetic for it. 
Single predicate for every such + guard (a concrete sympy.Integer is NOT symbolic).""" + return isinstance(x, sympy.Expr) and not x.is_number + DTYPE_LOWP_FP = [ torch.bfloat16, torch.float16, @@ -174,10 +182,20 @@ def is_mlir_arg_out(value): def is_mlir_arg_inout(value): return MLIRKernelArgs.MLIR_ARGS_INOUT & value + @staticmethod + def is_mlir_arg_var(value): + # A size-symbol arg (a dynamic extent passed as a scalar), not a tensor. + return bool(MLIRKernelArgs.MLIR_ARGS_VAR & value) + @staticmethod def get_mlir_shape(info): tensor_type = DTYPE_TO_MLIR[info[0]] - return f"memref<{info[1]}x{tensor_type}>" + numel = info[1] + # Dynamic shape: a symbolic numel becomes a dynamic memref dim ("?"); the + # actual extent arrives at runtime via the size-symbol arg (mlir_argdefs + # sizevars) and is materialized as the loop bound (codegen_loops). + dim = "?" if isinstance(numel, sympy.Expr) else numel + return f"memref<{dim}x{tensor_type}>" def mlir_argdefs(self, extra_node=dict()): buffer_types = {} @@ -224,7 +242,15 @@ def set_info(outer, inner, arg_type): continue set_info(outer, inner, self.MLIR_ARGS_OUT) for outer, inner in self.sizevars.items(): - set_info(outer, inner, self.MLIR_ARGS_VAR) + # Dynamic shape: a size symbol (e.g. s52) is not a buffer/graph_input/ + # constant, so buffer_types has no entry for it. Key it by its NAME (str) + # like a buffer -- the symbol's name is also the host-side SymInt variable + # the wrapper passes at the call site -- and describe it as a scalar int + # (-> memref<1x i64>), mirroring the sympy graph_input case above. 
+ name = str(outer) + if name not in buffer_types: + buffer_types[name] = [get_sympy_Expr_dtype(outer), 1, [1], [1]] + set_info(name, inner, self.MLIR_ARGS_VAR) return arg_defs, call_args, arg_attributes, buffer_types class VectorLaneMapping(): @@ -328,6 +354,16 @@ def is_dim_dividable(self, dim_sizes: list[int]) -> bool: if len(dim_sizes) != len(self._tile_size): raise ValueError("dim_sizes must match the tile size dimensions") + # Dynamic shape: divisibility cannot be proven at compile time, and the + # recompile-to-divisible path (adjust_tile_to_divisible -> RecompileSignal) + # has no symbolic equivalent -- it would loop forever shrinking the tile to 1. + # index_expr / indirect indexing under dynamic shape is Step 2 (B3); fail + # clearly here instead of a sympy "cannot determine truth value" crash. + if any(is_symbolic_dim(d) for d in dim_sizes): + raise NotImplementedError( + "index_expr/indirect indexing under dynamic shape is not supported " + "yet (symbolic dim reached is_dim_dividable)") + dim_sizes_cpy = list(dim_sizes) axis, stride = self.vmap.vlane_split_axis, self.vmap.vlane_stride remain = dim_sizes_cpy[axis] % stride @@ -395,6 +431,13 @@ def trim_large_tail(self, ranges: list[int]): constraint = self.tile_constraint[i] if constraint.fixed: continue + # Dynamic shape: the tail-padding heuristic exists only to shave the tile + # to a KNOWN dim and minimize wasted tail. With a symbolic dim the tail + # extent is unknown, so keep the fixed init tile and let the tail become a + # runtime remainder tile (masked). Skipping also avoids %/comparison on a + # sympy symbol (cannot determine truth value). 
+ if is_symbolic_dim(dim_range): + continue elif constraint.must_divide_dim: BETA = 0 @@ -460,6 +503,10 @@ def init_tile_size(ranges, vlane_stride, vector_lane): @staticmethod def get_padding_ratio(tile_range: int, dim_range: int) -> float: + # Dynamic shape: a symbolic dim has no compile-time tail, so report zero + # padding waste ("nothing to trim") rather than doing %/<= on a sympy symbol. + if is_symbolic_dim(dim_range) or is_symbolic_dim(tile_range): + return 0.0 if tile_range <= 0 or dim_range <= 0: raise ValueError("tile_range and dim_range must be positive integers") tail = dim_range % tile_range @@ -1019,14 +1066,26 @@ class LoopLevel: reduction_vars: Dict[str, str] = dataclasses.field(default_factory=dict) affine_yield: Dict[str, str] = dataclasses.field(default_factory=dict) + def _bound_str(self): + # Dynamic shape: a symbolic upper bound is emitted as an index SSA value + # (%_bound, materialized at the function top level by codegen_loops), + # which is a valid affine symbol; a concrete bound stays an integer literal. 
+ if is_symbolic_dim(self.size): + if not isinstance(self.size, sympy.Symbol): + raise NotImplementedError( + f"dynamic loop bound must be a single size symbol, got {self.size}") + return f"%{self.size.name}_bound" + return f"{self.size}" + def lines(self): + bound = self._bound_str() if len(self.reduction_vars): acc = ', '.join([f"%{acc.name}" for acc in self.reduction_vars.keys()]) args = ', '.join([f"%{iter.name} = %{init.name}" for (_, iter, init, _) in self.reduction_vars.values()]) dtype = ', '.join([f"{dtype}" for (_, _, _, dtype) in self.reduction_vars.values()]) - line = f"{acc} = affine.for %{self.var} = {self.start} to {self.size} step {self.step} iter_args({args}) -> ({dtype})" + line = f"{acc} = affine.for %{self.var} = {self.start} to {bound} step {self.step} iter_args({args}) -> ({dtype})" else: - line = f"affine.for %{self.var} = {self.start} to {self.size} step {self.step}" + line = f"affine.for %{self.var} = {self.start} to {bound} step {self.step}" return [line] diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 8520596c..cb73c23e 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -321,6 +321,20 @@ def define_function(self, kernel): wrapper.header.writeline(code) self.outer_function.add(function_name) + @staticmethod + def _literalize_meta(obj): + """Render meta (arg_attributes) as a valid Python literal for the generated + wrapper. Dynamic shapes put sympy symbols (e.g. s52) in the shape/stride + fields; emitted bare they are undefined at module scope -> NameError on + import. 
Stringify them ('s52'); the real extent arrives as a runtime kernel + arg (see the wrapper's call() body), so the compile-time descriptor only + needs to be import-safe and shape-agnostic.""" + if isinstance(obj, sympy.Expr): + return str(obj) + if isinstance(obj, (list, tuple)): + return type(obj)(MLIRScheduling._literalize_meta(x) for x in obj) + return obj + def define_kernel(self, src_code, meta_code, kernel_name, vector_lane, spad_info, loop_size=None, origins={}): wrapper = V.graph.wrapper_code if src_code in wrapper.src_to_kernel: @@ -333,7 +347,7 @@ def define_kernel(self, src_code, meta_code, kernel_name, vector_lane, spad_info codecache_def.writeline(f"loop_size={loop_size},") codecache_def.writeline(f"spad_info={spad_info},") codecache_def.writeline(f"origins={origins},") - codecache_def.writeline(f"arg_attributes={meta_code},") + codecache_def.writeline(f"arg_attributes={self._literalize_meta(meta_code)},") headers = extension_codecache.get_header(src_code) if headers is not None: codecache_def.writeline(f"global_var_header='''{headers[0]}''',") diff --git a/PyTorchSimFrontend/mlir/passes/build_skeleton.py b/PyTorchSimFrontend/mlir/passes/build_skeleton.py index 4c3d89cb..cb011137 100644 --- a/PyTorchSimFrontend/mlir/passes/build_skeleton.py +++ b/PyTorchSimFrontend/mlir/passes/build_skeleton.py @@ -499,7 +499,10 @@ def build_skeleton(module): """ _reset_ids() builder = TogBuilder() - _build(module, builder) # populates loop/compute nodes + op back-pointers + # serialize=False: we only need the builder side effects (loop/compute/DMA + # nodes), not the TOG string -- and display() needs a constant loop_end, which + # is None for a dynamic loop. The loop bound stays on the affine.for in the IR. 
def pin_loops_to_one_tile(module):
    """Pin every multi-iteration ``affine.for`` in ``module`` to a single tile.

    The cpp-TOG cycle sampling needs only per-tile compute cost, which is
    shape-invariant, so one tile is enough for both static and dynamic kernels:

    * constant bound C > step S -> the upper-bound map is rewritten to the
      constant S (was ceil(C/S) iterations). A multi-result map whose results
      are all constants is treated as the minimum of those constants.
    * symbolic bound (a runtime extent passed as upper-bound operand(s)) ->
      THIS loop's upper-bound operands are re-pointed at a fresh
      ``arith.constant S`` created immediately before the loop.
    * bound already <= step (e.g. the innermost compute loop) -> left as is.

    Only the loop's own operands are rewritten. The SSA value that used to
    carry the bound is left untouched, so other users of the same extent
    (another loop with a different step, or any other use of the value) are
    not rewritten as a side effect of pinning this loop.

    NOTE(review): the symbolic case yields exactly one iteration only when the
    lower bound is 0 and the upper-bound map is an identity/min over its
    operands -- that is the form this pipeline is expected to emit; confirm
    before relying on it for other map shapes.

    Run this on a COPY used only for gem5 sampling; the original module is
    kept for the producer .so / cycle_table. Mutates ``module`` in place.

    Returns:
        The largest step among the loops that were pinned (1 if none were),
        used to size the sampling host buffers.
    """
    tile = 1
    idx_t = ir.IndexType.get()
    # Snapshot the walk: constants are inserted into the IR while iterating.
    for op in list(walk_ops(module.body)):
        o = op.operation
        if o.name != "affine.for":
            continue
        step = ir.IntegerAttr(o.attributes["step"]).value
        ub_map = ir.AffineMapAttr(o.attributes["upperBoundMap"]).value
        results = list(ub_map.results)
        if all(ir.AffineConstantExpr.isinstance(r) for r in results):
            # affine.for takes the minimum over a multi-result upper bound.
            ub = min(ir.AffineConstantExpr(r).value for r in results)
            if ub <= step:
                continue  # already a single iteration
            # Keep the map's dim/symbol arity so it still matches whatever
            # bound operands the op carries (normally none).
            o.attributes["upperBoundMap"] = ir.AffineMapAttr.get(
                ir.AffineMap.get(ub_map.n_dims, ub_map.n_symbols,
                                 [ir.AffineConstantExpr.get(step)]))
        else:
            # operandSegmentSizes = [lb operands, ub operands, iter operands]
            seg = o.attributes["operandSegmentSizes"]
            n_lb, n_ub = seg[0], seg[1]
            # One constant per loop, so loops with different steps never
            # share (and later clobber) each other's pinned bound.
            cst = ir.Operation.create(
                "arith.constant", results=[idx_t],
                attributes={"value": ir.IntegerAttr.get(idx_t, step)},
                ip=ir.InsertionPoint(op), loc=ir.Location.unknown())
            for i in range(n_lb, n_lb + n_ub):
                o.operands[i] = cst.results[0]
        tile = max(tile, step)
    return tile
"""tile_id-ordered list of compute_type ints, from the skeleton's togsim.compute ops.""" diff --git a/PyTorchSimFrontend/mlir/passes/lower_to_emitc.py b/PyTorchSimFrontend/mlir/passes/lower_to_emitc.py index 3d1f7cde..a6c3b1a8 100644 --- a/PyTorchSimFrontend/mlir/passes/lower_to_emitc.py +++ b/PyTorchSimFrontend/mlir/passes/lower_to_emitc.py @@ -117,23 +117,57 @@ def _strip_aux(module): def _rewrite_signature(kernel, ctx): """Replace @kernel's memref tensor args with the ABI args (EmitCtx*, int64_t* shape_args, int32_t n) and rename it to togsim_kernel. - Returns the ctx Value.""" + Returns the ctx Value. + + Dynamic shape: any original arg still USED after build_skeleton's DCE is a size + symbol (memref<1xi64>) whose load feeds a loop bound -- tensor args are + referenced by name in the togsim.dma attrs, not by SSA value, so they DCE to + unused. Re-source each such `memref.load %argSize[..]` from `shape_args[k]` + (k = the size arg's order; the runtime fills shape_args in the same order), so + the producer's loop bound reads the runtime extent and the arg can be dropped. + """ block = kernel.regions[0].blocks[0] - for arg in block.arguments: - if len(list(arg.uses)) > 0: - raise ValueError( - "kernel arg still used after build_skeleton; cannot drop it " - "(expected the DCE to have removed all tensor-data ops)") - # erase existing (memref) args high-to-low, then append the ABI args. - for i in reversed(range(len(block.arguments))): - block.erase_argument(i) + orig_args = list(block.arguments) + loc = ir.Location.unknown(ctx) ptr = ir.Type.parse(CTX_TYPE, ctx) i64ptr = ir.Type.parse("!emitc.ptr", ctx) i32 = ir.IntegerType.get_signless(32) - loc = ir.Location.unknown(ctx) + # Append the ABI args first so shape_args exists to re-source size reads from. 
block.add_argument(ptr, loc) block.add_argument(i64ptr, loc) block.add_argument(i32, loc) + shape_args = block.arguments[len(orig_args) + 1] + + idx_t = ir.IndexType.get() + i64_t = ir.IntegerType.get_signless(64) + k = 0 + for a in orig_args: + if not list(a.uses): + continue + for use in list(a.uses): + ld = use.owner + if ld.name != "memref.load": + raise ValueError( + "kernel arg still used after build_skeleton by %s; only a size " + "load (memref.load) is expected under dynamic shape" % ld.name) + ip = ir.InsertionPoint(ld) + kc = ir.Operation.create( + "arith.constant", results=[idx_t], + attributes={"value": ir.IntegerAttr.get(idx_t, k)}, ip=ip, loc=loc) + sub = ir.Operation.create( + "emitc.subscript", results=[i64_t], + operands=[shape_args, kc.results[0]], ip=ip, loc=loc) + ld.results[0].replace_all_uses_with(sub.results[0]) + ld.erase() + k += 1 + + # every original arg is unused now -> drop them, leaving only the ABI args. + for a in orig_args: + if len(list(a.uses)) > 0: + raise ValueError( + "kernel arg still used after the shape rewrite; cannot drop it") + for i in reversed(range(len(orig_args))): + block.erase_argument(i) kernel.operation.attributes["function_type"] = ir.TypeAttr.get( ir.FunctionType.get([ptr, i64ptr, i32], [])) kernel.operation.attributes["sym_name"] = ir.StringAttr.get(ENTRY) diff --git a/Simulator/simulator.py b/Simulator/simulator.py index a4517285..75bc0205 100644 --- a/Simulator/simulator.py +++ b/Simulator/simulator.py @@ -91,15 +91,17 @@ def write_arg(self, arg, path, name): os.makedirs(dump_path, exist_ok=True) index = self.get_biggest_filename(dump_path) + data_path = os.path.join(dump_path, f'{index}.raw') if (isinstance(arg, torch.Tensor)): - data_path = os.path.join(dump_path, f'{index}.raw') tensor = arg.cpu().detach() buffer_size = tensor.untyped_storage().size() buffer = (ctypes.c_char * buffer_size).from_address(tensor.data_ptr()) t_arr = np.frombuffer(buffer, dtype=TORCH_TO_NUMPY[tensor.dtype], count=buffer_size // 
tensor.element_size()) t_arr.tofile(data_path) else: - assert(0) + # Dynamic shape: a scalar size argument (a runtime extent, e.g. s52). + # The kernel reads it from a memref<1xi64> buffer, so write one int64. + np.array([int(arg)], dtype=np.int64).tofile(data_path) return index def dump_args(self, args, arg_attributes, load_path, dump_path): @@ -108,7 +110,9 @@ def dump_args(self, args, arg_attributes, load_path, dump_path): for (arg_name, arg_attribute), arg in zip(arg_attributes, args): size = arg_attribute[2] if arg_attribute[1] != torch.bool else (arg_attribute[2] + 7) // 8 array_size.append(size) - if MLIRKernelArgs.is_mlir_arg_in(arg_attribute[0]): + # A size symbol arg (MLIR_ARGS_VAR, e.g. a dynamic extent s52) is a kernel + # INPUT: the kernel loads it for its loop bound, so dump it like an input. + if MLIRKernelArgs.is_mlir_arg_in(arg_attribute[0]) or MLIRKernelArgs.is_mlir_arg_var(arg_attribute[0]): index = self.write_arg(arg, load_path, arg_name) file_path.append(os.path.join(load_path, arg_name, f'{index}.raw')) elif MLIRKernelArgs.is_mlir_arg_out(arg_attribute[0]): @@ -467,9 +471,18 @@ def write_kernel_attribute_file(attribute_dir, inputs, alloc_pool=None): index = str(len(os.listdir(attribute_dir))) attribute_file = os.path.join(attribute_dir, index) + # Tensors carry an address; a scalar (e.g. a dynamic-shape size arg s52) + # carries a runtime extent -- collect those into shape_args, in arg order, + # which is the order the trace producer reads shape_args[k]. 
+ shape_args = [] for idx, tensor in enumerate(inputs): - address_info[f"arg{idx}"] = tensor.data_ptr() + if isinstance(tensor, torch.Tensor): + address_info[f"arg{idx}"] = tensor.data_ptr() + else: + shape_args.append(int(tensor)) yaml_content["address_info"] = address_info + if shape_args: + yaml_content["shape_args"] = shape_args for buf_name, range in alloc_pool.items(): sram_buffer[buf_name] = range @@ -575,6 +588,11 @@ def run_standalone( logger.warning("TORCHSIM_LEGACY_TOG=1 selects the DEPRECATED legacy ONNX TOG path") if use_trace: cmd = f"{base_cmd} --trace_so {trace_so} --cycle_table {cycle_tsv}" + # Carry the per-kernel attribute YAML (address_info + a dynamic + # kernel's shape_args) into the trace path, the same file the legacy + # path passes via the models_list command. + if attribute_path: + cmd += f" --attribute {attribute_path}" else: # DEPRECATED: legacy ONNX TOG path cmd = f"{base_cmd} --models_list {trace_file_path}" if extension_config.CONFIG_TOGSIM_DEBUG_LEVEL: diff --git a/TOGSim/src/main.cc b/TOGSim/src/main.cc index 274d63da..d0bf9a9f 100644 --- a/TOGSim/src/main.cc +++ b/TOGSim/src/main.cc @@ -25,6 +25,7 @@ namespace po = boost::program_options; std::unique_ptr build_trace_tilegraph(Simulator* simulator, const std::string& trace_so_path, const std::string& cycle_table_path, + const std::string& attribute_path, int partition_id) { const auto& cfg = simulator->get_hardware_config_yaml(); int num_cores = cfg["num_cores"] ? cfg["num_cores"].as() : 1; @@ -43,7 +44,21 @@ std::unique_ptr build_trace_tilegraph(Simulator* simulator, while (ct >> c >> o) { cyc.push_back(c); ovl.push_back(o); } } if (cyc.empty()) { cyc.assign(256, 128); ovl.assign(256, 0); } - auto run = togsim::run_producer(trace_so_path.c_str(), nullptr, 0, + // Dynamic shape: the producer reads its loop bounds from shape_args[k]. Read + // them from the per-kernel attribute YAML (the same file that carries + // address_info for the legacy path), under the `shape_args` sequence. 
+ std::vector shape_args; + if (!attribute_path.empty()) { + YAML::Node attr = YAML::LoadFile(attribute_path); + if (attr["shape_args"]) { + for (const auto& v : attr["shape_args"]) shape_args.push_back(v.as()); + spdlog::info("[TOGSim-trace] shape_args: {} values from {}", + shape_args.size(), attribute_path); + } + } + auto run = togsim::run_producer(trace_so_path.c_str(), + shape_args.empty() ? nullptr : shape_args.data(), + (int)shape_args.size(), bases.data(), (int)bases.size(), cyc.data(), ovl.data(), (int)cyc.size(), partition_cores.data(), (int32_t)partition_cores.size()); @@ -62,7 +77,7 @@ void launchKernel(Simulator* simulator, unsigned int kernel_id, std::string onnx std::string trace_so = dir + "/trace.so"; std::string cycle_tsv = dir + "/trace_cycles.tsv"; if ((!legacy || std::string(legacy) != "1") && fs::exists(trace_so)) { - tile_graph = build_trace_tilegraph(simulator, trace_so, cycle_tsv, partition_id); + tile_graph = build_trace_tilegraph(simulator, trace_so, cycle_tsv, attribute_path, partition_id); if (tile_graph) tog_path = trace_so; else spdlog::warn("[TOGSim] trace.so run failed for {}; falling back to ONNX", trace_so); } @@ -164,6 +179,10 @@ int main(int argc, char** argv) { cmd_parser.add_command_line_option( "cycle_table", "Path to a 'cycleoverlapping' per-tile_id sidecar (TSV) " "for --trace_so; falls back to a flat stub if omitted"); + cmd_parser.add_command_line_option( + "attribute", "Path to the per-kernel attribute YAML (address_info, " + "shape_args) for --trace_so; carries a dynamic kernel's runtime " + "shape the same way the legacy path carries address_info"); try { cmd_parser.parse(argc, argv); } catch (const CommandLineParser::ParsingError& e) { @@ -216,7 +235,12 @@ int main(int argc, char** argv) { // round-robin over partition 0's cores only; see build_trace_tilegraph). 
std::string cycle_table_path; cmd_parser.set_if_defined("cycle_table", &cycle_table_path); - auto tg = build_trace_tilegraph(simulator, trace_so_path, cycle_table_path, 0); + // Dynamic shape: the producer reads its loop bounds from shape_args[k], which + // build_trace_tilegraph loads from the per-kernel attribute YAML (the same + // file that carries address_info for the legacy path). + std::string attribute_path; + cmd_parser.set_if_defined("attribute", &attribute_path); + auto tg = build_trace_tilegraph(simulator, trace_so_path, cycle_table_path, attribute_path, 0); if (!tg) { spdlog::error("[TOGSim] trace producer run failed"); exit(1); } tg->set_arrival_time(simulator->get_core_cycle()); tg->set_kernel_id(0); diff --git a/docs/dynamic-shape-plan.md b/docs/dynamic-shape-plan.md new file mode 100644 index 00000000..7582f6c7 --- /dev/null +++ b/docs/dynamic-shape-plan.md @@ -0,0 +1,231 @@ +# C++ trace 경로의 dynamic shape — 구현 계획 (전체) + +목표: trace 경로(C++ TOG)를 확장해 `torch.compile(dynamic=True)`가 1D contiguous +elementwise뿐 아니라 **전체 op**에서 동작하게 한다. 현재 상태부터 일반 dynamic shape +지원까지의 빈틈없는 로드맵이다. + +--- + +## 0. 현재 상태 (완료, PR #269) + +1D 단일 심볼 **elementwise add**가 **하나의 컴파일된 `trace.so`** 로 임의 크기에서 +**timing + functional 출력** 둘 다 e2e 동작한다: + +- symbolic dim용 tile-sizing 가드 (`is_symbolic_dim`) +- symbolic MLIR loop bound (`affine.for ... to %_bound`, `memref`) +- import-safe wrapper meta +- 한-타일 cycle 샘플링 (`pin_loops_to_one_tile`) +- producer `.so`가 loop bound를 `shape_args[k]`에서 읽음 (`emitc.subscript`) +- 런타임 shape를 attribute YAML로 전달 (`--attribute` → `run_producer`) +- shape-agnostic Spike 검증 바이너리 (size 버퍼에서 런타임 extent를 읽음) + +검증: 1024/2048/1536 + 1D tail 1000 이 한 바이너리에서 정확한 값. + +**add가 되고 나머지가 안 되는 이유:** add는 *contiguous* → DRAM 접근 stride가 상수 +`[1]`, 타일 dims도 상수 `[512]`; 심볼은 오직 loop **trip count**(`to %s_bound`)에만 등장. +*주소 산술*이나 *stride*에는 심볼이 안 들어간다. 남은 모든 op는 이 가정을 깬다. + +--- + +## 1. 단 하나의 핵심 통찰 (나머지를 다루기 쉽게 만드는 것) + +**런타임은 이미 일반적이다. 
작업은 전부 codegen에 있다.** + +trace DMA ABI/런타임은 이미 런타임·다차원·strided descriptor를 실어 나른다: + +``` +togsim_dma(ctx, dir, arg_id, offset, ndim, dims[], strides[], elem_bits, ...) // 런타임 int64* + → TraceRec {addr, dims, strides, elem_bits} // 런타임에 기록 + → make_dma → Instruction(dram_addr, tile_size=dims, tile_stride=strides, ...) + → DMA / DRAM(Ramulator) 모델: dims/strides로 strided 주소 스트림 + 비용 + → SRAM throttle footprint = prod(dims) * elem_bytes +``` + +즉 `offset`, `dims`, `strides`가 *이미* 런타임 int64 값으로 스택 전체를 흐른다. +**TOGSim/런타임/DRAM 모델 재작성은 불필요하다.** add가 제한적인 유일한 이유는 codegen이 +dims/strides를 컴파일타임 상수 attr로 *굳히고* 사소한 주소 형태만 다루기 때문이다. + +따라서 아래 모든 것은 codegen 능력 두 가지로 귀결된다: + +- **(C1) 일반 symbolic index-식 lowering** — 임의의 affine/sympy index를 런타임 C로 lowering. + leaf를 `itervar → 루프 변수`, `size 심볼 → shape_args[k]`, `정수 → literal` 로 해소하고, + 모든 연산자(`+`, `*`(symbol×symbol 포함), `//`, `%`)를 emit. +- **(C2) 런타임 `togsim.dma` descriptor** — `dims`/`strides`를 상수 attr뿐 아니라 **런타임 + operand**로도 실을 수 있게 하고, `lower_to_emitc`가 (C1) 값으로 배열을 채움. + +그러면 동적 **offset**(전체 index 식), 동적 **stride**(한 계수), **tail-trim dim**(`min` 식)이 +전부 (C1)+(C2)의 특수형이 된다. + +--- + +## 2. 빌드 순서 (단계별) + +각 단계: 목표 / touch point / 변경 / 검증 / 위험. 각 단계가 이전 단계 위에 쌓이도록 정렬. + +### Phase 1 — 일반 symbolic index-식 lowering [토대, P0] + +- **목표:** `_index_expr`가 itervar + size 심볼 + 정수로 된 임의의 index 식을 lowering. +- **touch:** `mlir_codegen_backend.py:798-837`(`_index_expr`), `:883`(`index_expr` 리네임), + `mlir_common.py`(leaf 분류기 + 공유 `shape_args[k]` 인덱스 맵). +- **변경:** `const_coeff * itervar` 패턴매칭을 **재귀 sympy walk**로 교체: + - leaf `itervar(indexN)` → 루프 induction 값 (기존 `dim_list`/`itervar_cses`); + - leaf `size 심볼(ks/s)` → `emitc.subscript(shape_args, k)` (`_rewrite_signature` 메커니즘 + 재사용; 심볼→`shape_args` 인덱스 맵 공유 필요, §3); + - leaf `Integer` → literal; + - 연산자: `Add`/`Mul`(symbol×symbol 포함)/`FloorDiv`/`Mod` → 런타임 vector/scalar op. + - 현재의 `int(str(arg)[1:])`(모든 심볼을 `indexN`으로 가정)와 + `renamed = {s: "d"+str(s)[5:]}` 제거/일반화 — `ks0`에서 크래시함. 
+- **검증:** strided 2D 접근(transpose 또는 matmul-타일 주소 `i*K + j`)이 유효한 MLIR로 + lowering되고 producer가 올바른 주소를 기록; 1D add 불변. +- **위험:** 중-상 — 중심 재작성, blast radius 큼(모든 load/store index). static은 상수계수 + fast path를 byte 단위로 보존. + +#### Phase 1 보강 — FloorDiv/Mod 는 axis-split의 일 (convert_index 아님) +floor/mod 경로를 끝까지 확인한 결과 — 정정: **codegen의 floor/mod 처리(`convert_index`)는 +확장 대상이 아니라 은퇴 대상이다.** + +- 설계 의도(`docs/axis-split-scheduling.md`)는 **"affine-only contract"**: codegen이 + FloorDiv/ModularIndexing이 **전혀 없는** affine 인덱스만 받게, `axis_split.py`가 상류에서 + floor/mod를 다차원 strided 접근으로 제거한다. `convert_index`(`:342`)/ + `_convert_sympy_to_mlir_expr`(`:370`)가 `(x floordiv y) mod z` 를 affine map으로 emit하던 + 것은 그 이전의 **legacy codegen-내부 처리**이고, axis-split(현재 prototype) 전환이 + 끝나면 사라진다. → **convert_index의 floor/mod 분기는 동적용으로 일반화하지 말 것.** +- **그런데 동적에선 axis-split도 지금은 못 한다.** `collect_boundaries`(`axis_split.py:44-54`) + 가 divisor `k`와 extent `E` 를 둘 다 **concrete int 로 요구**한다(`_as_int(div)`, + `_as_int(var_ranges[base])`, `E % k == 0`). symbolic divisor/extent면 `_as_int`→None → + split 안 됨 → floor/mod 가 살아남아 codegen reject(`:1200` "Unlinearized floor/mod") 또는 + convert_index raise(free symbol 2개 / invalid affine `floordiv s0`)로 간다. +- **그래서 동적 floor/mod 작업 = `axis_split.py`를 symbolic-aware로**: divisor가 원본 shape의 + *진짜 dim* 이면 symbolic extent를 **construction상 나눠떨어진다** (예: `[M,N]` flatten, + divisor N → `E = M*N`, `E % N == 0`). 이 "symbolic 정렬"을 인식해 **symbolic split** 을 + 내면, 그 결과가 Phase 2의 동적 strided 접근으로 흐른다. (정렬 안 되는 view는 graph-copy + 영역 — 범위 밖.) **convert_index/affine-divisor 경로는 손대지 않는다.** +- 참고: affine `floordiv`/`mod` 는 어차피 divisor가 상수여야 유효(MLIR 규칙)하므로, 동적 + divisor를 affine으로 표현하는 길은 처음부터 없다 — 그래서 답은 "affine화"가 아니라 + "axis-split이 strided로 미리 없애기"다. + +### Phase 2 — 런타임 `togsim.dma` dims/strides [P0, Phase 1 필요] + +- **목표:** dim/stride가 심볼에 의존하는 DMA가 그 값을 런타임으로 실음. 
+- **touch:** + - `passes/togsim_ops.py` — `togsim.dma` op: dims/strides에 런타임 operand 허용 + (attr entry가 sentinel(예 `-1`)이면 "런타임: operand m 참조"). + - `passes/build_skeleton.py:98-99`(`_emit_dma`가 dims/strides를 `i64_array`로), + `:204`(`n_symbols != 0` bail) — dim/stride가 심볼이면 런타임 operand로 emit((C1) index + 값) + attr엔 sentinel; 심볼 bail 완화. + - `passes/lower_to_emitc.py:418-419`(`_arr(ctx, dims/strides)`) — 런타임-aware 배열 fill: + sentinel entry → `dims[i] = ;`, 아니면 literal. +- **변경:** 위와 같음. `offset` operand 경로는 이미 런타임(add가 증명); Phase 1이 그 + *계수*를 런타임화. +- **검증:** matmul 타일의 row-stride = 동적 K가 `strides=[K_runtime, 1]`로 기록; + Ramulator가 strided 접근 비용; SRAM footprint 정확. +- **위험:** 중 — op 스키마 변경 + 두 패스; all-constant 경로는 동일하게 유지. + +### Phase 3 — tail-trim DMA (padding/masking 교체) [TODO A, P0/P1] + +- **목표:** 경계(부분) 타일이 유효 remainder만 전송; 패딩/마스킹된 full 타일 없음. + 동적 >1D/tail 정확성 + 정적 홀수 크기 실패 해결. +- **touch:** `passes/decompose_transfer.py`, `passes/dma_fine_grained.py`, + `togsim.transfer`/dma emission; 이 경로에서 loop-padding 패스는 빠짐(레거시 메커니즘 — + `docs/loop-padding-elimination.md`). +- **변경:** 타일 dim을 따라 마지막 타일은 DMA `dims`를 `min(tile, extent - offset)`로 emit + — (C1) 런타임 식, (C2)로 런타임 dim operand로 전달. masked-compute tail은 COMPUTE용으로 + 남길 수 있음; DMA는 유효 바이트만 옮김. +- **타이밍:** 자동으로 정확 — producer가 trimmed `dims`를 기록하므로 trace 비용이 trimmed + 전송 반영(레거시 "full-tile DMA 비용" 우려는 옛 모델 얘기). +- **검증:** tail 크기(예 1000, 2D 47×10)가 정확한 값 + 마지막-타일 DMA 크기가 remainder인 + trace. +- **위험:** 중 — masked-compute 경로와 상호작용; compute는 마스킹, DMA는 trim 확인. + +### Phase 4 — dynamic shape op 템플릿 [TODO B, P1] + +- **목표:** matmul / conv / bmm / sdpa 가 `dynamic=True`로 컴파일. +- **touch:** `mlir_gemm_template.py`, `mlir_conv_template.py`, `mlir_bmm_template.py`, + `mlir_sdpa_template.py`, `mlir_template.py`(`gemmini_gemm_mapping`, + `gemm_combination_mapping`: `math.ceil(M/...)`, `sympy.divisors`, divisor 루프). 
+- **변경:** elementwise처럼 symbolic-aware MLIR emit — symbolic loop bound(`%_bound`), + `memref`, stride가 동적 dim인 strided 접근(Phase 1/2 feed). 상수-int tiling 수학 + (`math.ceil` 등)을 `is_symbolic_dim`으로 가드. +- **검증:** 동적 matmul(M 동적, 이후 K/N 동적)이 정확한 값 + 한 `.so`에서 스케일링 trace; + 런타임 seq_len을 가진 decode-style 커널. +- **위험:** 상 — op별, 각 템플릿이 고유한 concrete-shape 가정. + +### Phase 5 — 다중 심볼 정확성 + 계약 [Phase 4 내 P0] + +size-arg ↔ `shape_args[k]` 순서는 e2e 단일 계약이어야 함: + +- **A-1** `lower_to_emitc._rewrite_signature`가 `k`를 *uses* 기준 배정(미사용 size 심볼 + 건너뜀); 런타임은 `shape_args`를 *arg 순서*(모든 비텐서)로 채움. 둘을 **같은 기준**(arg- + attributes 순서; 미사용 심볼도 슬롯 유지 또는 양쪽에서 드롭)으로. 단일 계약: + size-arg 위치 == `shape_args[k]`. +- **A-2** 복합 numel `'128*s52'` → `_is_symbol(isdigit)`이 잘못된 C `N_128*s52` 생성. + Phase 1이면 numel을 문자열 휴리스틱 아닌 식으로 lowering. 그 전까지는 loud + `NotImplementedError`(조용히 잘못 emit 금지). +- **A-3** `_concretize_attrs_for_sampling`의 `cz = isinstance(str)`가 stringify된 정적 + `'128'`을 `tile`로 변환; `_is_symbol`/`is_symbolic_dim`과 같은 술어로 통일(숫자 문자열= + concrete). +- **검증:** 서로 다른 동적 dim 2개(예 M, N)인 커널이 정확한 값 + trace가 각 extent를 맞게. + +### Phase 6 — loud-fail 가드 (중간 안전망) [P1, 일찍] + +Phase 1-5가 안착하기 전, 아직 미지원인 동적 케이스를 전부 **큰 소리로 실패**(명확한 +`NotImplementedError`)하게, 절대 조용히 틀리지 않게: 복합 numel(A-2), 다중 심볼 어긋남 +(A-1), 공유-bound pin(A-4), bool-동적 dump(A-6). 일반 경로 구축 중 "단일 심볼 우연" 부류를 +방어. + +### Phase 7 — 인프라 / 검증 / 정리 + +- **loop-padding 배포:** `TestLoopPadding.cpp`의 symbolic-skip이 LLVM 포크에만 있음. 배포 + 결정(재빌드 + 툴체인 반영) — 아래 둘을 게이트. +- **CI (C-1):** `tests/ops/elementwise/test_dynamic_add.py`(+ 새 동적 테스트)를 + `pytorchsim_test.yml`에 등록 — 단 **loop-padding fix가 CI 툴체인에 들어간 후**, 안 그러면 + CI 실패(loop-padding이 symbolic bound를 2^32로 클로버). +- **결과 파싱:** `TOGSimulator.get_result_from_file`가 trace-경로 로그를 파싱 안 함 + ("Unable to parse the output file"); cycle은 로그에 정확 — 파싱 배선. +- **static cost 샘플링:** static 경로도 `pin_loops_to_one_tile` 경유(`run_tog`에서 분리, + 레거시 full TOG도 만들기 때문). 
+- **A-4** `pin_loops_to_one_tile`: 루프별 상수 + 그 루프의 bound operand만 교체(전역 + `replace_all_uses_with` 금지); ub operand >1 처리. +- **A-5** `write_kernel_attribute_file`: 텐서 전용 카운터로 `arg{idx}` 부여 → 앞선 scalar가 + `address_info` 인덱스에 구멍 안 내게(`main.cc`가 stub 대신 실제 base 읽을 때 중요). +- **A-6** `dump_args` bool 분기: symbolic numel일 때 `+7//8` 산술 스킵. + +--- + +## 3. Cross-cutting 계약 (단일로 유지) + +- **size 심볼 ↔ `shape_args[k]`:** 단일 순서, `_rewrite_signature`(producer), + `write_kernel_attribute_file`(Spike + attribute YAML), `main.cc`가 공유. `k` = size arg의 + arg-attributes 순서상 위치; 런타임도 같은 순서로 채움. +- **`is_symbolic_dim(x)`** (`mlir_common`): "런타임 dim"의 단일 술어(sympy.Expr, `is_number` + 아님). 모든 tiling/bound/dma 가드가 사용. 숫자 문자열은 concrete. +- **런타임 DMA ABI** (`togsim_runtime.h`): `offset`, `dims[]`, `strides[]`가 계약; codegen이 + 채우고 모델이 소비. 병행 채널 추가 금지. + +--- + +## 4. 테스트 매트릭스 (단계별 추가; Phase 7에서 CI 등록) + +| 테스트 | 검증 대상 | +|---|---| +| 1D add, 다중 크기 (완료) | trip count, functional, 한 .so | +| 2D add, 한 dim 동적 | strided 접근(Phase 1), tail-trim(Phase 3) | +| 2D add, 두 dim 동적 | 다중 심볼 계약(Phase 5) | +| matmul, M 동적 | 템플릿(Phase 4), 동적 stride(Phase 2) | +| matmul, M+K+N 동적 | 다중 심볼 + strided | +| decode (런타임 seq_len) | 동기가 된 실제 케이스 | +| tail / 비배수 크기 | remainder 정확성(Phase 3) | + +--- + +## 5. 위험 / 열린 질문 + +- Phase 1 blast radius: 모든 memref index가 `_index_expr`를 거침. static 상수 경로를 동일 + 유지해야(정적 matmul/conv 회귀 테스트). +- 심볼에 의한 FloorDiv/Mod(view/reshape/broadcast) — Inductor가 주는 index 식에 실제로 + 등장하는지, producer에서 런타임 `//`/`%` 비용이 수용 가능한지 확인. +- 동적 stride 하의 cost-table 유효성: per-tile COMPUTE 비용은 shape-invariant(타일 크기 + 고정)라 테이블 유효; DMA 비용은 trace 주소(Phase 2)에서. compute 비용이 stride에 + 의존하지 않음을 확인. +- loop-padding: Phase 3가 동적 의존을 제거; 패스를 완전히 은퇴할지 + (`docs/loop-padding-elimination.md`) static용으로 남길지 결정. 
def test_dynamic_add(device, sizes=(1024, 2048)):
    """Check a dynamically compiled elementwise add against eager CPU results.

    ``a + b`` is wrapped with ``torch.compile(dynamic=True)`` once, then called
    for every entry of ``sizes`` with 1-D random inputs placed on ``device``.
    Each result is handed to ``test_result`` together with the eager CPU sum.
    """
    def add(a, b):
        return a + b

    # A single dynamic compile; every size below goes through the same wrapper.
    compiled_add = torch.compile(dynamic=True)(add)
    for length in sizes:
        lhs = torch.randn(length).to(device=device)
        rhs = torch.randn(length).to(device=device)
        # Mark dim 0 of both operands as dynamic before calling the kernel.
        for operand in (lhs, rhs):
            torch._dynamo.mark_dynamic(operand, 0)
        actual = compiled_add(lhs, rhs)
        expected = add(lhs.cpu(), rhs.cpu())
        test_result(f"DynamicAdd(N={length})", actual, expected)
def I(x):
    """Wrap ``x`` as a ``sympy.Integer``."""
    as_integer = sympy.Integer(x)
    return as_integer


def _chain_vals(chain):
    """Render a chain for comparison.

    ``None`` passes through; a chain whose entries are all numbers becomes a
    list of ints; any other chain becomes a list of strings.
    """
    if chain is None:
        return None
    render = int if all(link.is_number for link in chain) else str
    return [render(link) for link in chain]


def _boundaries(exprs, E):
    """Return the boundaries collected for ``v`` (key 0) with extent ``E``."""
    collected = ax.collect_boundaries(exprs, {v: 0}, {v: E})
    return collected.get(0, set())


# Messages of every failed check; inspected at the end of main().
_failures = []


def check(name, got, exp):
    """Print PASS/FAIL for one comparison and record a failure message."""
    if got != exp:
        _failures.append(f"{name}: got {got}, expected {exp}")
        print("FAIL", name, "->", got, f"(expected {exp})")
        return
    print("PASS", name, "->", got)


def main():
    """Run the static and symbolic axis-split checks; exit non-zero on failure."""
    # ---- static (must match legacy behaviour) ----
    two_level = [FloorDiv(v, I(3)), ModularIndexing(v, I(1), I(3))]
    cuts = _boundaries(two_level, I(12))
    check("static reshape [4,3] boundaries", {int(c) for c in cuts}, {3})
    chain = ax._ordered_chain(cuts, I(12))
    check("static reshape [4,3] chain", _chain_vals(chain), [1, 3, 12])

    chain = ax._ordered_chain({I(2), I(3)}, I(6))
    check("static incompatible {2,3} E=6", _chain_vals(chain), None)

    three_level = [
        FloorDiv(v, I(12)),
        ModularIndexing(v, I(4), I(3)),
        ModularIndexing(v, I(1), I(4)),
    ]
    cuts = _boundaries(three_level, I(24))
    check("static 3-level boundaries", {int(c) for c in cuts}, {4, 12})
    chain = ax._ordered_chain(cuts, I(24))
    check("static 3-level chain", _chain_vals(chain), [1, 4, 12, 24])

    # ---- symbolic (new) ----
    M, N, A, B, C, P = sympy.symbols("M N A B C P", integer=True, positive=True)

    cuts = _boundaries([FloorDiv(v, N), ModularIndexing(v, I(1), N)], M * N)
    check("sym reshape [M,N] boundaries", {str(c) for c in cuts}, {"N"})
    chain = ax._ordered_chain(cuts, M * N)
    check("sym reshape [M,N] chain", _chain_vals(chain), ["1", "N", "M*N"])
    check("sym seg_ext M*N/N", str(ax._quotient(M * N, N)), "M")

    sym_three_level = [
        FloorDiv(v, B * C),
        ModularIndexing(v, C, B),
        ModularIndexing(v, I(1), C),
    ]
    cuts = _boundaries(sym_three_level, A * B * C)
    check("sym 3-level boundaries", {str(c) for c in cuts}, {"C", "B*C"})
    chain = ax._ordered_chain(cuts, A * B * C)
    check("sym 3-level chain", _chain_vals(chain), ["1", "C", "B*C", "A*B*C"])

    # incomparable symbolic divisors -> bail (misaligned)
    chain = ax._ordered_chain({N, P}, N * P)
    check("sym incomparable {N,P} E=N*P", _chain_vals(chain), None)
    # non-divisor symbolic -> no boundary collected
    collected = ax.collect_boundaries([FloorDiv(v, N)], {v: 0}, {v: M * N + 1})
    check("sym non-divisor E=M*N+1", dict(collected), {})

    if _failures:
        report = "Axis-split symbolic unit test FAILED:\n " + "\n ".join(_failures)
        raise SystemExit(report)
    print("\nAxis-split symbolic unit test: ALL PASS")


if __name__ == "__main__":
    main()