diff --git a/gimmik/__init__.py b/gimmik/__init__.py index b32ebdc..cd21134 100644 --- a/gimmik/__init__.py +++ b/gimmik/__init__.py @@ -8,6 +8,7 @@ from gimmik.hip import HIPMatMul from gimmik.metal import MetalMatMul from gimmik.opencl import OpenCLMatMul +from gimmik.ptx import PTXMatMul def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm', @@ -22,7 +23,8 @@ def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm', 'cuda': CUDAMatMul, 'ispc': ISPCMatMul, 'hip': HIPMatMul, - 'opencl': OpenCLMatMul + 'opencl': OpenCLMatMul, + 'ptx': PTXMatMul } mm = platmap[platform](alpha*mat, beta, None, n, ldb, ldc) diff --git a/gimmik/base.py b/gimmik/base.py index f547afc..f806b05 100644 --- a/gimmik/base.py +++ b/gimmik/base.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import itertools as it +import json import pkgutil import re @@ -90,6 +91,9 @@ def __init__(self, A, beta=0.0, aligne=None, n=None, ldb=None, ldc=None): self.bix = np.nonzero(np.any(A != 0, axis=0))[0] self.bix = {kx: k for k, kx in enumerate(self.bix)} + # Create config cache + self._config_cache = {} + def kernels(self, dtype, kname='gimmik_mm', **kwargs): basemeta = self.basemeta @@ -103,14 +107,7 @@ def kernels(self, dtype, kname='gimmik_mm', **kwargs): raise ValueError('Invalid floating point data type') # Common template arguments - baseargs = { - 'dtype': dtype, 'kname': kname, - 'A': self.A, 'beta': self.beta, 'width': 1, - 'm': self.m, 'n': self.n, 'k': self.k, - 'ldb': self.ldb, 'ldc': self.ldc, - 'afix': self.afix, 'alix': self.alix, 'bix': self.bix, - 'dot': _dot, 'partition': _partition, 'chunk': _chunk - } + baseargs = self._base_template_args(dtype, kname) # Incrementally generate and render the kernels gen = self._kernel_generators(dtype, dsize, **kwargs) @@ -136,15 +133,76 @@ def kernels(self, dtype, kname='gimmik_mm', **kwargs): except StopIteration: pass + def _base_template_args(self, dtype, kname): + return { + 'dtype': dtype, 'kname': kname, + 
'A': self.A, 'beta': self.beta, 'width': 1, + 'm': self.m, 'n': self.n, 'k': self.k, + 'ldb': self.ldb, 'ldc': self.ldc, + 'afix': self.afix, 'alix': self.alix, 'bix': self.bix, + 'dot': _dot, 'partition': _partition, 'chunk': _chunk + } + def _process_meta(self, meta): pass + def _get_config(self, key): + if key not in self._config_cache: + cfgdir = f'kernels/{self.platform}/config' + path = f'{cfgdir}/{key}.json' + default_path = f'{cfgdir}/default.json' + try: + cfgdata = pkgutil.get_data('gimmik', path) + except FileNotFoundError: + cfgdata = pkgutil.get_data('gimmik', default_path) + self._config_cache[key] = json.loads(cfgdata.decode('utf-8')) + return self._config_cache[key] + + def _eval_condition(self, condition, stats): + if 'all' in condition: + return all(self._eval_condition(c, stats) for c in condition['all']) + if 'any' in condition: + return any(self._eval_condition(c, stats) for c in condition['any']) + if 'not' in condition: + return not self._eval_condition(condition['not'], stats) + + value = stats[condition['field']] + op = next(k for k in condition if k != 'field') + expected = condition[op] + + match op: + case 'eq': + return value == expected + case 'ne': + return value != expected + case 'lt': + return value is not None and value < expected + case 'lte': + return value is not None and value <= expected + case 'gt': + return value is not None and value > expected + case 'gte': + return value is not None and value >= expected + case 'in': + return value in expected + case 'is_null': + return value is None + case 'is_not': + return value is not None + case 'divisible_by': + return value is not None and value % expected == 0 + case 'is_null_or_divisible_by': + return (value is None or value % expected == 0) + case _: + raise ValueError(f'op `{op}` not supported') + def _render_kernel(self, dtype, tplname, tplargs): tpl = _PlatformTemplateLookup(self.platform).get_template(tplname) src = tpl.render(**tplargs) # At single precision suffix all 
floating point constants by 'f' - if dtype == 'float': + # (PTX doesn't use an 'f' suffix for FP literals) + if dtype == 'float' and self.platform != 'ptx': src = re.sub(r'(?=\d*[.eE])(?=\.?\d)\d*\.?\d*(?:[eE][+-]?\d+)?', r'\g<0>f', src) diff --git a/gimmik/cuda.py b/gimmik/cuda.py index b18c509..9e1da43 100644 --- a/gimmik/cuda.py +++ b/gimmik/cuda.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +import numpy as np + from gimmik.base import MatMul @@ -8,7 +10,15 @@ class CUDAMatMul(MatMul): basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, 'dynamic_shared': 0} - def _kernel_generators(self, dtype, dsize, *, compute_capability=None): + @staticmethod + def is_suitable(arr): + nnz = np.count_nonzero(arr) + nuq = len(np.unique(np.abs(arr))) + density = nnz / arr.size + return (nuq <= 28) or (density <= 0.15) + + def _kernel_generators(self, dtype, dsize, *, compute_capability=None, + **kwargs): # B loading, C streaming kernel yield ('cstream', {}, {}) diff --git a/gimmik/kernels/ptx/base.mako b/gimmik/kernels/ptx/base.mako new file mode 100644 index 0000000..dbd8433 --- /dev/null +++ b/gimmik/kernels/ptx/base.mako @@ -0,0 +1,4 @@ +.version ${ptx[0]}.${ptx[1]} +.target sm_${cc[0]}${cc[1]}${'a' if cc[0] >= 9 else ''} +.address_size 64 +${next.body()} diff --git a/gimmik/kernels/ptx/bstream-msplit-v2.mako b/gimmik/kernels/ptx/bstream-msplit-v2.mako new file mode 100644 index 0000000..bc8dd38 --- /dev/null +++ b/gimmik/kernels/ptx/bstream-msplit-v2.mako @@ -0,0 +1,193 @@ +<%inherit file='base'/> + +<% +mx = partition(A, into=msplit, by='rows') +bchunks = chunk(bix, bsz) +m_per_group = max(len(mcx) for mcx in mx) +bsub_bytes = 2 * bsz * blockx * 2 * dwidth_i +def bsub_off(buf, idx): + return (buf * bsz + idx) * blockx * 2 * dwidth_i +%> + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 n, id, tid_x, tid_y; + .reg .u64 b, c, b_base, c_base, bsub_thread; +% if use_cpasync: + .reg .u32 bsub_sm_thread; +% endif + .reg .${pftype} bv_a, bv_b, 
csub_a<${m_per_group}>, csub_b<${m_per_group}>; + .reg .pred p1, p_skip; + .shared .align 16 .b8 _bsub[${bsub_bytes}]; + + mov.u32 n, ${-(-n // 2)}; + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${2*dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${2*dwidth_i}, c; + } + + { + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${2*dwidth_i}; + mov.u64 bsub_thread, _bsub; + add.u64 bsub_thread, bsub_thread, _tx_off; + } +% if use_cpasync: + { + .reg .u64 _sm64; + cvta.to.shared.u64 _sm64, bsub_thread; + cvt.u32.u64 bsub_sm_thread, _sm64; + } +% endif + +% for cid, mcx in enumerate(mx): +## cid = ${cid}, rows ${mcx} + setp.ne.u32 p_skip, tid_y, ${cid}; + @p_skip bra $L_END_CID_${cid}; + +% if beta_zero or not preload_c: +## Zero accumulators +% for j, row_j in enumerate(mcx): +% if afix[row_j] != -1: + mov.${pftype} csub_a${j}, ${fzero}; + mov.${pftype} csub_b${j}, ${fzero}; +% endif +% endfor +% else: +## Pre-load C and scale by beta so per-row completion is a plain store +% for j, row_j in enumerate(mcx): +% if afix[row_j] != -1: + ld.weak.global.cg.v2.${pftype} {csub_a${j}, csub_b${j}}, [c_base + ${ldc*row_j*dwidth_i}]; + mul.${pftype} csub_a${j}, csub_a${j}, ${float(beta)}; + mul.${pftype} csub_b${j}, csub_b${j}, ${float(beta)}; +% endif +% endfor +% endif + +## Pre-fill double buffer +% if use_cpasync: +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[0]) if i % msplit == cid]: + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${2*dwidth_i}; +% endfor + cp.async.commit_group; + cp.async.wait_all; + bar.sync 0; +% else: +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[0]) 
if i % msplit == cid]: + { + .reg .${pftype} _bva, _bvb; + ld.weak.global.cg.v2.${pftype} {_bva, _bvb}, [b_base + ${ldb*kx*dwidth_i}]; + st.shared.v2.${pftype} [bsub_thread + ${bsub_off(0, idx)}], {_bva, _bvb}; + } +% endfor + bar.sync 0; +% endif + +## Main loop over B-chunks (double-buffered) +% for bb in range(len(bchunks)): +<% + buf_cur = bb % 2 + buf_next = (bb + 1) % 2 +%> +% if not loop.last: +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[bb + 1]) if i % msplit == cid]: +% if use_cpasync: + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${2*dwidth_i}; +% else: + { + .reg .${pftype} _bva, _bvb; + ld.weak.global.cg.v2.${pftype} {_bva, _bvb}, [b_base + ${ldb*kx*dwidth_i}]; + st.shared.v2.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], {_bva, _bvb}; + } +% endif +% endfor +% if use_cpasync: + cp.async.commit_group; +% endif +% endif + +% for idx, kx in enumerate(bchunks[bb]): +% if any(A[row_j, kx] for row_j in mcx): + ld.shared.v2.${pftype} {bv_a, bv_b}, [bsub_thread + ${bsub_off(buf_cur, idx)}]; +% endif +% for j, row_j in enumerate(mcx): +% if A[row_j, kx] != 0: + fma.rn.${pftype} csub_a${j}, bv_a, ${A[row_j, kx]}, csub_a${j}; + fma.rn.${pftype} csub_b${j}, bv_b, ${A[row_j, kx]}, csub_b${j}; +% endif +% endfor +% for j, row_j in enumerate(mcx): +% if kx == alix[row_j]: +% if beta_zero: + st.weak.global.cg.v2.${pftype} [c_base + ${ldc*row_j*dwidth_i}], {csub_a${j}, csub_b${j}}; +% elif preload_c: + st.weak.global.v2.${pftype} [c_base + ${ldc*row_j*dwidth_i}], {csub_a${j}, csub_b${j}}; +% else: + { + .reg .${pftype} _ca, _cb; + ld.weak.global.cg.v2.${pftype} {_ca, _cb}, [c_base + ${ldc*row_j*dwidth_i}]; + fma.rn.${pftype} _ca, _ca, ${float(beta)}, csub_a${j}; + fma.rn.${pftype} _cb, _cb, ${float(beta)}, csub_b${j}; + st.weak.global.v2.${pftype} [c_base + ${ldc*row_j*dwidth_i}], {_ca, _cb}; + } +% endif +% endif +% endfor +% endfor +% if use_cpasync: +% if not loop.last: + 
cp.async.wait_all; +% endif +% endif + bar.sync 0; +% endfor + +## Handle zero rows in this cid's group +% if has_zero_rows: +% for row_j in mcx: +% if afix[row_j] == -1: +% if beta_zero: + { + .reg .${pftype} _z; + mov.${pftype} _z, ${fzero}; + st.weak.global.cg.v2.${pftype} [c_base + ${ldc*row_j*dwidth_i}], {_z, _z}; + } +% elif beta != 1: + { + .reg .${pftype} _ca, _cb; + ld.weak.global.cg.v2.${pftype} {_ca, _cb}, [c_base + ${ldc*row_j*dwidth_i}]; + mul.${pftype} _ca, _ca, ${float(beta)}; + mul.${pftype} _cb, _cb, ${float(beta)}; + st.weak.global.v2.${pftype} [c_base + ${ldc*row_j*dwidth_i}], {_ca, _cb}; + } +% endif +% endif +% endfor +% endif + +$L_END_CID_${cid}: +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/bstream-msplit.mako b/gimmik/kernels/ptx/bstream-msplit.mako new file mode 100644 index 0000000..989735b --- /dev/null +++ b/gimmik/kernels/ptx/bstream-msplit.mako @@ -0,0 +1,281 @@ +<%inherit file='base'/> + +<% +mx = partition(A, into=msplit, by='rows') +bchunks = chunk(bix, bsz) +m_per_group = max(len(mcx) for mcx in mx) +bsub_bytes = 2 * bsz * blockx * dwidth_i +def bsub_off(buf, idx): + return (buf * bsz + idx) * blockx * dwidth_i +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id, tid_x, tid_y; + .reg .u64 b, c, b_base, c_base, bsub_thread; +% if use_cpasync: + .reg .u32 bsub_sm_thread; +% endif + .reg .${pftype} bv, csub<${m_per_group}>; + .reg .pred p1, p_skip; + .shared .align 8 .b8 _bsub[${bsub_bytes}]; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, 
${blockx}, tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + + { + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; + mov.u64 bsub_thread, _bsub; + add.u64 bsub_thread, bsub_thread, _tx_off; + } +% if use_cpasync: + { + .reg .u64 _sm64; + cvta.to.shared.u64 _sm64, bsub_thread; + cvt.u32.u64 bsub_sm_thread, _sm64; + } +% endif + +% for cid, mcx in enumerate(mx): +## cid = ${cid}, rows ${mcx} + setp.ne.u32 p_skip, tid_y, ${cid}; + @p_skip bra $L_END_CID_${cid}; + +% if beta_zero or not preload_c: +## Zero accumulators +% for j, row_j in enumerate(mcx): +% if afix[row_j] != -1: + mov.${pftype} csub${j}, ${fzero}; +% endif +% endfor +% else: +## Pre-load C and scale by beta so per-row completion is a plain store +% for j, row_j in enumerate(mcx): +% if afix[row_j] != -1: +% if n is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${row_j * dwidth_i}, c_base; + ld.weak.global.cg.${pftype} csub${j}, [_cptr]; + } +% else: + ld.weak.global.cg.${pftype} csub${j}, [c_base + ${ldc*row_j*dwidth_i}]; +% endif + mul.${pftype} csub${j}, csub${j}, ${float(beta)}; +% endif +% endfor +% endif + +## Pre-fill double buffer +% if use_cpasync: +## Async fill of chunk 0 +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[0]) if i % msplit == cid]: +% if n is None: + { + .reg .u64 _bptr; + mad.wide.u32 _bptr, ldb, ${kx * dwidth_i}, b_base; + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [_bptr], ${dwidth_i}; + } +% else: + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i}; +% endif +% endfor + cp.async.commit_group; + cp.async.wait_all; + bar.sync 0; +% else: +## Sync fill of chunk 0 +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[0]) if i % msplit == cid]: + { + 
.reg .${pftype} _bv; +% if n is None: + .reg .u64 _bptr; + mad.wide.u32 _bptr, ldb, ${kx * dwidth_i}, b_base; + ld.weak.global.cg.${pftype} _bv, [_bptr]; +% else: + ld.weak.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; +% endif + st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv; + } +% endfor + bar.sync 0; +% endif + +## Main loop over B-chunks (double-buffered) +% for bb in range(len(bchunks)): +<% + buf_cur = bb % 2 + buf_next = (bb + 1) % 2 +%> +% if not loop.last: +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[bb + 1]) if i % msplit == cid]: +% if use_cpasync: +% if n is None: + { + .reg .u64 _bptr; + mad.wide.u32 _bptr, ldb, ${kx * dwidth_i}, b_base; + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [_bptr], ${dwidth_i}; + } +% else: + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i}; +% endif +% else: + { + .reg .${pftype} _bv; +% if n is None: + .reg .u64 _bptr; + mad.wide.u32 _bptr, ldb, ${kx * dwidth_i}, b_base; + ld.weak.global.cg.${pftype} _bv, [_bptr]; +% else: + ld.weak.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; +% endif + st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv; + } +% endif +% endfor +% if use_cpasync: + cp.async.commit_group; +% endif +% endif + +% for idx, kx in enumerate(bchunks[bb]): +% if any(A[row_j, kx] for row_j in mcx): + ld.shared.${pftype} bv, [bsub_thread + ${bsub_off(buf_cur, idx)}]; +% endif +% for j, row_j in enumerate(mcx): +% if A[row_j, kx] != 0: + fma.rn.${pftype} csub${j}, bv, ${A[row_j, kx]}, csub${j}; +% endif +% endfor +% for j, row_j in enumerate(mcx): +% if kx == alix[row_j]: +% if beta_zero: +% if n is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${row_j * dwidth_i}, c_base; + st.weak.global.cg.${pftype} [_cptr], csub${j}; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], csub${j}; +% endif +% elif preload_c: +% if n 
is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${row_j * dwidth_i}, c_base; + st.weak.global.${pftype} [_cptr], csub${j}; + } +% else: + st.weak.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], csub${j}; +% endif +% else: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${row_j * dwidth_i}, c_base; + ld.weak.global.cg.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; + st.weak.global.${pftype} [_cptr], _ctmp; +% else: + ld.weak.global.cg.${pftype} _ctmp, [c_base + ${ldc*row_j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; + st.weak.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _ctmp; +% endif + } +% endif +% endif +% endfor +% endfor +% if use_cpasync: +% if not loop.last: + cp.async.wait_all; +% endif +% endif + bar.sync 0; +% endfor +## End of Main loop over B-chunks + +## Handle zero rows in this cid's group +% if has_zero_rows: +% for row_j in mcx: +% if afix[row_j] == -1: +% if beta_zero: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% if n is None: + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${row_j * dwidth_i}, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; +% endif + } +% elif beta != 1: + { + .reg .${pftype} _tmp; +% if n is None: + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${row_j * dwidth_i}, c_base; + ld.weak.global.cg.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [_cptr], _tmp; +% else: + ld.weak.global.cg.${pftype} _tmp, [c_base + ${ldc*row_j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; +% endif + } +% endif +% endif +% endfor +% endif + +$L_END_CID_${cid}: +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/bstream.mako b/gimmik/kernels/ptx/bstream.mako new file mode 100644 index 
0000000..ac4d0a6 --- /dev/null +++ b/gimmik/kernels/ptx/bstream.mako @@ -0,0 +1,150 @@ +<%inherit file='base'/> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .${pftype} csub<${m}>, bv<${len(bix)}>; + .reg .pred p1; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _grd<3>; + mov.u32 _grd0, %ntid.x; + mov.u32 _grd1, %ctaid.x; + mov.u32 _grd2, %tid.x; + mad.lo.u32 id, _grd0, _grd1, _grd2; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + +## Batch-load active B columns +% for i, kx in enumerate(bix): +% if n is None: + { + .reg .u64 _bptr; + mad.wide.u32 _bptr, ldb, ${kx * dwidth_i}, b_base; + ld.weak.global.cg.${pftype} bv${i}, [_bptr]; + } +% else: + ld.weak.global.cg.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +% endfor + +% if beta_zero: +## Zero accumulators +% for j in range(m): +% if afix[j] != -1: + mov.${pftype} csub${j}, ${fzero}; +% endif +% endfor +% else: +## Pre-load C and scale by beta so per-row completion is a plain store +% for j in range(m): +% if afix[j] != -1: +% if n is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + ld.weak.global.cg.${pftype} csub${j}, [_cptr]; + } +% else: + ld.weak.global.cg.${pftype} csub${j}, [c_base + ${ldc*j*dwidth_i}]; +% endif + mul.${pftype} csub${j}, csub${j}, ${float(beta)}; +% endif +% endfor +% endif + +## Main compute +% for kx in bix: +% for j, jx in enumerate(A[:, 
kx]): +% if jx != 0: + fma.rn.${pftype} csub${j}, bv${bix[kx]}, ${jx}, csub${j}; +% endif +% endfor +% for j in range(m): +% if kx == alix[j]: +% if n is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + st.weak.global.cg.${pftype} [_cptr], csub${j}; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], csub${j}; +% endif + +% endif +% endfor +% endfor + +% if has_zero_rows: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% for j, jx in enumerate(afix): +% if jx == -1 and beta_zero: +% if n is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif + +% elif jx == -1: +% if n is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + ld.weak.global.cg.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.cg.${pftype} [_cptr], _tmp; + } +% else: + ld.weak.global.cg.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif +% endif +% endfor + } +% endif + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/config/default.json b/gimmik/kernels/ptx/config/default.json new file mode 100644 index 0000000..846a84f --- /dev/null +++ b/gimmik/kernels/ptx/config/default.json @@ -0,0 +1,41 @@ +{ + "schema": 1, + "cc": [7, 0], + "ptx": [7, 0], + "kernels": [ + { + "template": "cstream", + "family": "sparse", + "block": [128, 1, 1], + "width": 1, + "descriptor": "cstream/x128" + }, + { + "template": "bstream", + "family": "sparse", + "block": [128, 1, 1], + "width": 1, + "descriptor": "bstream/x128" + }, + { + "template": "bstream-msplit", + "family": "sparse", + "block": [32, 4, 1], + "width": 1, + "params": { + "bsz": 24 + }, + "descriptor": "bstream-msplit/m4-b24-x32" + }, + { + "template": 
"cstream-ksplit", + "family": "sparse", + "block": [32, 2, 1], + "width": 1, + "params": { + "csz": 24 + }, + "descriptor": "cstream-ksplit/k2-c24-x32" + } + ] +} diff --git a/gimmik/kernels/ptx/config/sm100_double.json b/gimmik/kernels/ptx/config/sm100_double.json new file mode 100644 index 0000000..c84ae65 --- /dev/null +++ b/gimmik/kernels/ptx/config/sm100_double.json @@ -0,0 +1,424 @@ +{ + "schema": 1, + "cc": [ + 10, + 0 + ], + "ptx": [ + 8, + 7 + ], + "kernels": [ + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 20, + "preload_c": true + }, + "descriptor": "cstream-ksplit/preload-c/k2-c20-x32" + }, + { + "template": "bstream-msplit", + "family": "sparse", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "bsz": 32, + "preload_c": true + }, + "descriptor": "bstream-msplit/preload-c/m2-b32-x32" + }, + { + "template": "bstream-msplit", + "family": "sparse", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "bsz": 32, + "preload_c": true + }, + "descriptor": "bstream-msplit/preload-c/m4-b32-x32" + }, + { + "template": "bstream-msplit", + "family": "sparse", + "block": [ + 32, + 1, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "descriptor": "bstream-msplit/m1-b32-x32" + }, + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 16, + "preload_c": true + }, + "descriptor": "cstream-ksplit/preload-c/k2-c16-x32" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 32, + "preload_c": true + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "double" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": 
"cstream-ksplit-v2/preload-c/k2-c32-x64" + }, + { + "template": "cstream", + "family": "sparse", + "block": [ + 256, + 1, + 1 + ], + "width": 1, + "descriptor": "cstream/x256" + }, + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 32, + "preload_c": true + }, + "descriptor": "cstream-ksplit/preload-c/k2-c32-x32" + }, + { + "template": "dmma-steal-ws", + "family": "dense-ws", + "tile": { + "m": 8, + "n": 8, + "k": 4 + }, + "block": [ + 192, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 2 + }, + "warp_map": { + "compute_start": 0, + "compute_count": 4, + "producer": 4, + "stealer": 5 + }, + "conditions": { + "field": "n", + "is_not": null + }, + "descriptor": "dmma-steal-ws/nn2-w4" + }, + { + "template": "dmma-astream-msplit-v2", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + }, + "block": [ + 128, + 1, + 1 + ], + "width": 2, + "params": { + "nn": 2, + "warps": 2, + "msplit": 2 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-astream-msplit/v2/nn2-w2-m2" + }, + { + "template": "dmma-steal-ws", + "family": "dense-ws", + "tile": { + "m": 8, + "n": 8, + "k": 4 + }, + "block": [ + 192, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 1 + }, + "warp_map": { + "compute_start": 0, + "compute_count": 4, + "producer": 4, + "stealer": 5 + }, + "conditions": { + "field": "n", + "is_not": null + }, + "descriptor": "dmma-steal-ws/nn1-w4" + }, + { + "template": "dmma-astream-v2", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + }, + "block": [ + 32, + 1, + 1 + ], + "width": 2, + "params": { + "nn": 4, + "warps": 1 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": 
"dmma-astream/v2/nn4-w1" + }, + { + "template": "dmma-astream-msplit-v2", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + }, + "block": [ + 64, + 1, + 1 + ], + "width": 2, + "params": { + "nn": 4, + "warps": 1, + "msplit": 2 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-astream-msplit/v2/nn4-w1-m2" + }, + { + "template": "dmma-steal-ws", + "family": "dense-ws", + "tile": { + "m": 8, + "n": 8, + "k": 4 + }, + "block": [ + 192, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 4 + }, + "warp_map": { + "compute_start": 0, + "compute_count": 4, + "producer": 4, + "stealer": 5 + }, + "conditions": { + "field": "n", + "is_not": null + }, + "descriptor": "dmma-steal-ws/nn4-w4" + }, + { + "template": "dmma-astream-msplit-v2", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + }, + "block": [ + 96, + 1, + 1 + ], + "width": 2, + "params": { + "nn": 4, + "warps": 1, + "msplit": 3 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-astream-msplit/v2/nn4-w1-m3" + }, + { + "template": "dmma-astream-v2", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + }, + "block": [ + 256, + 1, + 1 + ], + "width": 2, + "params": { + "nn": 2, + "warps": 8 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-astream/v2/nn2-w8" + } + ] +} diff --git a/gimmik/kernels/ptx/config/sm100_float.json b/gimmik/kernels/ptx/config/sm100_float.json new file mode 100644 index 0000000..daf15df --- /dev/null +++ b/gimmik/kernels/ptx/config/sm100_float.json @@ -0,0 +1,229 @@ +{ + "schema": 1, + "cc": [ + 10, + 0 + ], + 
"ptx": [ + 8, + 7 + ], + "kernels": [ + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 32, + "preload_c": true + }, + "descriptor": "cstream-ksplit/preload-c/k2-c32-x64" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 16, + "preload_c": true + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/preload-c/k2-c16-x64" + }, + { + "template": "cstream", + "family": "sparse", + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "descriptor": "cstream/x128" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 24, + "preload_c": true + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/preload-c/k2-c24-x64" + }, + { + "template": "bstream", + "family": "sparse", + "block": [ + 256, + 1, + 1 + ], + "width": 1, + "descriptor": "bstream/x256" + }, + { + "template": "bstream-msplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "bsz": 32, + "preload_c": true + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": 
"bstream-msplit-v2/preload-c/m2-b32-x64" + }, + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 24, + "preload_c": true + }, + "descriptor": "cstream-ksplit/preload-c/k2-c24-x64" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 32, + "preload_c": true + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/preload-c/k2-c32-x64" + } + ] +} diff --git a/gimmik/kernels/ptx/config/sm80_double.json b/gimmik/kernels/ptx/config/sm80_double.json new file mode 100644 index 0000000..83d9cd9 --- /dev/null +++ b/gimmik/kernels/ptx/config/sm80_double.json @@ -0,0 +1,222 @@ +{ + "schema": 1, + "cc": [ + 8, + 0 + ], + "ptx": [ + 7, + 0 + ], + "kernels": [ + { + "template": "cstream", + "family": "sparse", + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "descriptor": "cstream" + }, + { + "template": "bstream", + "family": "sparse", + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "descriptor": "bstream" + }, + { + "template": "bstream-msplit", + "family": "sparse", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "bsz": 24 + }, + "descriptor": "bstream-msplit/m4-b24-x32" + }, + { + "template": "bstream-msplit-v2", + "family": "sparse", + "block": [ + 32, + 4, + 1 + ], + "width": 2, + "params": { + "bsz": 16 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "bstream-msplit-v2/m4-b16-x32" + }, + { + "template": 
"bstream-msplit", + "family": "sparse", + "block": [ + 64, + 1, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "conditions": { + "all": [ + { + "field": "beta_zero", + "eq": true + }, + { + "field": "m", + "lte": 320 + }, + { + "field": "k_used", + "gte": 64 + } + ] + }, + "descriptor": "bstream-msplit/m1-b32-x64" + }, + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 24 + }, + "descriptor": "cstream-ksplit/k2-c24-x32" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 32, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 24 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/k2-c24-x32" + }, + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "csz": 20 + }, + "conditions": { + "field": "k_used", + "gt": 500 + }, + "descriptor": "cstream-ksplit/k4-c20-x32" + }, + { + "template": "cstream-v2", + "family": "sparse", + "block": [ + 128, + 1, + 1 + ], + "width": 2, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "k_used", + "lte": 100 + }, + { + "field": "aligne", + "is_null_or_divisible_by": 2 + } + ] + }, + "descriptor": "cstream-v2/x128" + } + ] +} diff --git a/gimmik/kernels/ptx/config/sm80_float.json b/gimmik/kernels/ptx/config/sm80_float.json new file mode 100644 index 0000000..23656b2 --- /dev/null +++ b/gimmik/kernels/ptx/config/sm80_float.json @@ -0,0 +1,229 @@ +{ + "schema": 1, + "cc": [ + 8, + 0 + ], + "ptx": [ + 7, + 0 + ], + "kernels": [ + { + "template": "cstream-ksplit", + "family": "sparse", + 
"block": [ + 64, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 32, + "preload_c": true + }, + "descriptor": "cstream-ksplit/preload-c/k2-c32-x64" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 16, + "preload_c": true + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/preload-c/k2-c16-x64" + }, + { + "template": "cstream", + "family": "sparse", + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "descriptor": "cstream/x128" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 24, + "preload_c": true + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/preload-c/k2-c24-x64" + }, + { + "template": "bstream", + "family": "sparse", + "block": [ + 256, + 1, + 1 + ], + "width": 1, + "descriptor": "bstream/x256" + }, + { + "template": "bstream-msplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "bsz": 32, + "preload_c": true + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "bstream-msplit-v2/preload-c/m2-b32-x64" + }, + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 64, + 2, + 1 
+ ], + "width": 1, + "params": { + "csz": 24, + "preload_c": true + }, + "descriptor": "cstream-ksplit/preload-c/k2-c24-x64" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 32, + "preload_c": true + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/preload-c/k2-c32-x64" + } + ] +} diff --git a/gimmik/kernels/ptx/config/sm90_double.json b/gimmik/kernels/ptx/config/sm90_double.json new file mode 100644 index 0000000..bd135e2 --- /dev/null +++ b/gimmik/kernels/ptx/config/sm90_double.json @@ -0,0 +1,500 @@ +{ + "schema": 1, + "cc": [ + 9, + 0 + ], + "ptx": [ + 8, + 6 + ], + "kernels": [ + { + "template": "cstream-ksplit", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 20 + }, + "descriptor": "cstream-ksplit/k2-c20-x32", + "family": "sparse" + }, + { + "template": "bstream-msplit", + "block": [ + 32, + 8, + 1 + ], + "width": 1, + "params": { + "bsz": 24 + }, + "descriptor": "bstream-msplit/m8-b24-x32", + "family": "sparse" + }, + { + "template": "cstream-ksplit", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "csz": 24 + }, + "descriptor": "cstream-ksplit/k4-c24-x32", + "family": "sparse" + }, + { + "template": "bstream-msplit", + "block": [ + 64, + 2, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "descriptor": "bstream-msplit/m2-b32-x64", + "family": "sparse" + }, + { + "template": "bstream", + "block": [ + 64, + 1, + 1 + ], + "width": 1, + "descriptor": "bstream/x64", + "family": "sparse" + }, + { + "template": "bstream-msplit-v2", + "block": [ + 32, + 4, + 1 + ], + "width": 2, + "params": { + "bsz": 16 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" 
+ }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "bstream-msplit-v2/m4-b16-x32", + "family": "sparse" + }, + { + "template": "cstream-ksplit-v2", + "block": [ + 32, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 24 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/k2-c24-x32", + "family": "sparse" + }, + { + "template": "bstream-msplit", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "descriptor": "bstream-msplit/m2-b32-x32", + "family": "sparse" + }, + { + "template": "bstream-msplit", + "block": [ + 64, + 1, + 1 + ], + "width": 1, + "params": { + "bsz": 24 + }, + "descriptor": "bstream-msplit/m1-b24-x64", + "family": "sparse" + }, + { + "template": "bstream-msplit", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "descriptor": "bstream-msplit/m4-b32-x32", + "family": "sparse" + }, + { + "template": "cstream-v2", + "block": [ + 128, + 1, + 1 + ], + "width": 2, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "k_used", + "lte": 100 + }, + { + "field": "aligne", + "is_null_or_divisible_by": 2 + } + ] + }, + "descriptor": "cstream-v2/x128", + "family": "sparse" + }, + { + "template": "dmma-asmem-v2", + "block": [ + 256, + 1, + 1 + ], + "width": 2, + "params": { + "nn": 1, + "warps": 8 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + 
}, + "descriptor": "dmma-asmem/v2/nn1-w8", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } + }, + { + "template": "dmma-astream-v2", + "block": [ + 64, + 1, + 1 + ], + "width": 2, + "params": { + "nn": 2, + "warps": 2 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-astream/v2/nn2-w2", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } + }, + { + "template": "dmma-stride-ws", + "block": [ + 160, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 1, + "iters": 8 + }, + "warp_map": { + "compute_start": 0, + "compute_count": 4, + "producer": 4 + }, + "conditions": { + "field": "n", + "is_not": null + }, + "descriptor": "dmma-stride-ws/nn1-w4-i8", + "family": "dense-ws", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } + }, + { + "template": "dmma-astream-v2", + "block": [ + 32, + 1, + 1 + ], + "width": 2, + "params": { + "nn": 4, + "warps": 1 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-astream/v2/nn4-w1", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } + }, + { + "template": "dmma-asmem-v2", + "block": [ + 128, + 1, + 1 + ], + "width": 2, + "params": { + "nn": 2, + "warps": 4 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-asmem/v2/nn2-w4", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } + }, + { + "template": "dmma-stride-ws", + "block": [ + 160, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 2, + "iters": 2 + }, + "warp_map": { + "compute_start": 0, + "compute_count": 4, + "producer": 4 + }, + "conditions": { + "field": "n", + "is_not": null 
+ }, + "descriptor": "dmma-stride-ws/nn2-w4-i2", + "family": "dense-ws", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } + }, + { + "template": "dmma-astream-v2", + "block": [ + 128, + 1, + 1 + ], + "width": 2, + "params": { + "nn": 1, + "warps": 4 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-astream/v2/nn1-w4", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } + }, + { + "template": "dmma-stride-ws", + "block": [ + 160, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 2, + "iters": 8 + }, + "warp_map": { + "compute_start": 0, + "compute_count": 4, + "producer": 4 + }, + "conditions": { + "field": "n", + "is_not": null + }, + "descriptor": "dmma-stride-ws/nn2-w4-i8", + "family": "dense-ws", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } + } + ] +} diff --git a/gimmik/kernels/ptx/config/sm90_float.json b/gimmik/kernels/ptx/config/sm90_float.json new file mode 100644 index 0000000..ebfd5f3 --- /dev/null +++ b/gimmik/kernels/ptx/config/sm90_float.json @@ -0,0 +1,229 @@ +{ + "schema": 1, + "cc": [ + 9, + 0 + ], + "ptx": [ + 8, + 6 + ], + "kernels": [ + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 32, + "preload_c": true + }, + "descriptor": "cstream-ksplit/preload-c/k2-c32-x64" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 16, + "preload_c": true + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/preload-c/k2-c16-x64" + }, + { + "template": "cstream", + "family": "sparse", + 
"block": [ + 128, + 1, + 1 + ], + "width": 1, + "descriptor": "cstream/x128" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 24, + "preload_c": true + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/preload-c/k2-c24-x64" + }, + { + "template": "bstream", + "family": "sparse", + "block": [ + 256, + 1, + 1 + ], + "width": 1, + "descriptor": "bstream/x256" + }, + { + "template": "bstream-msplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "bsz": 32, + "preload_c": true + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "bstream-msplit-v2/preload-c/m2-b32-x64" + }, + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 24, + "preload_c": true + }, + "descriptor": "cstream-ksplit/preload-c/k2-c24-x64" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 32, + "preload_c": true + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/preload-c/k2-c32-x64" + } + ] +} diff --git a/gimmik/kernels/ptx/cstream-ksplit-v2.mako 
b/gimmik/kernels/ptx/cstream-ksplit-v2.mako new file mode 100644 index 0000000..b10dc3a --- /dev/null +++ b/gimmik/kernels/ptx/cstream-ksplit-v2.mako @@ -0,0 +1,147 @@ +<%inherit file='base'/> + +<% +kparts = partition(A, ksplit, by='cols') +cchunks = chunk(list(range(m)), csz) +cv_per_thread = -(-csz // ksplit) +bv_per_thread = max(len(kbx) for kbx in kparts) +csub_bytes = (ksplit - 1) * csz * blockx * 2 * dwidth_i +%> + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 n, id, tid_x, tid_y; + .reg .u64 b, c, b_base, c_base, csub_thread; + .reg .${pftype} bv_a<${bv_per_thread}>, bv_b<${bv_per_thread}>; + .reg .${pftype} cv_a<${cv_per_thread}>, cv_b<${cv_per_thread}>; + .reg .${pftype} dotp_a, dotp_b; + .reg .pred p1, p_skip; + .shared .align 16 .b8 _csub[${csub_bytes}]; + + mov.u32 n, ${-(-n // 2)}; + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${2*dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${2*dwidth_i}, c; + } + + { + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${2*dwidth_i}; + mov.u64 csub_thread, _csub; + add.u64 csub_thread, csub_thread, _tx_off; + } + +% for bid, kbx in enumerate(kparts): +## bid = ${bid}: ${len(kbx)} B columns, ksplit=${ksplit} + setp.ne.u32 p_skip, tid_y, ${bid}; + @p_skip bra $L_END_BID_${bid}; + +<% loaded = set() %> + +% for cchunk_i, cchunk in enumerate(cchunks): +## Chunk ${cchunk_i}: partial dot-product +% for row_idx, j in enumerate(cchunk): +<% owner_bid = row_idx % ksplit %> +<% has_dotp = bool(A[j].any()) %> +% for kxi, kx in enumerate(kbx): +% if A[j, kx] != 0 and kx not in loaded: + ld.weak.global.cg.v2.${pftype} {bv_a${kxi}, bv_b${kxi}}, [b_base + 
${ldb*kx*dwidth_i}]; +<% loaded.add(kx) %> +% endif +% endfor + mov.${pftype} dotp_a, ${fzero}; + mov.${pftype} dotp_b, ${fzero}; +% for kxi, kx in enumerate(kbx): +% if A[j, kx] != 0: + fma.rn.${pftype} dotp_a, bv_a${kxi}, ${A[j, kx]}, dotp_a; + fma.rn.${pftype} dotp_b, bv_b${kxi}, ${A[j, kx]}, dotp_b; +% endif +% endfor +% if owner_bid == bid: +% if beta_zero or not preload_c: + mov.${pftype} cv_a${row_idx // ksplit}, dotp_a; + mov.${pftype} cv_b${row_idx // ksplit}, dotp_b; +% elif has_dotp: + ld.weak.global.cg.v2.${pftype} {cv_a${row_idx // ksplit}, cv_b${row_idx // ksplit}}, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} cv_a${row_idx // ksplit}, cv_a${row_idx // ksplit}, ${float(beta)}, dotp_a; + fma.rn.${pftype} cv_b${row_idx // ksplit}, cv_b${row_idx // ksplit}, ${float(beta)}, dotp_b; +% endif +% else: +<% csub_idx = bid - (1 if bid > owner_bid else 0) %> + st.shared.v2.${pftype} [csub_thread + ${(csub_idx * csz + row_idx) * blockx * 2 * dwidth_i}], {dotp_a, dotp_b}; +% endif +% endfor + bar.sync 0; + +## Combine phase (owned rows only) +% for row_idx, j in enumerate(cchunk): +% if row_idx % ksplit == bid: +<% has_dotp = bool(A[j].any()) %> +% if not preload_c or beta_zero or has_dotp: + mov.${pftype} dotp_a, cv_a${row_idx // ksplit}; + mov.${pftype} dotp_b, cv_b${row_idx // ksplit}; +% for other_bid in range(ksplit): +% if other_bid != bid: +<% csub_idx = other_bid - (1 if other_bid > (row_idx % ksplit) else 0) %> + { + .reg .${pftype} _ta, _tb; + ld.shared.v2.${pftype} {_ta, _tb}, [csub_thread + ${(csub_idx * csz + row_idx) * blockx * 2 * dwidth_i}]; + add.${pftype} dotp_a, dotp_a, _ta; + add.${pftype} dotp_b, dotp_b, _tb; + } +% endif +% endfor +% endif +% if beta_zero: + st.weak.global.cg.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; +% elif preload_c and has_dotp: + st.weak.global.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; +% elif preload_c and beta != 1: + { + .reg .${pftype} _ca, _cb; + 
ld.weak.global.cg.v2.${pftype} {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _ca, _ca, ${float(beta)}; + mul.${pftype} _cb, _cb, ${float(beta)}; + st.weak.global.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% elif not preload_c:  # with preload_c a zero row and beta == 1 must leave C untouched (dotp regs are stale there) + { + .reg .${pftype} _ca, _cb; + ld.weak.global.cg.v2.${pftype} {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ca, _ca, ${float(beta)}, dotp_a; + fma.rn.${pftype} _cb, _cb, ${float(beta)}, dotp_b; + st.weak.global.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% endif + +% endif +% endfor + bar.sync 0; +% endfor + +$L_END_BID_${bid}: +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream-ksplit.mako b/gimmik/kernels/ptx/cstream-ksplit.mako new file mode 100644 index 0000000..700d6a3 --- /dev/null +++ b/gimmik/kernels/ptx/cstream-ksplit.mako @@ -0,0 +1,201 @@ +<%inherit file='base'/> + +<% +kparts = partition(A, ksplit, by='cols') +cchunks = chunk(list(range(m)), csz) +cv_per_thread = -(-csz // ksplit) +bv_per_thread = max(len(kbx) for kbx in kparts) +csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id, tid_x, tid_y; + .reg .u64 b, c, b_base, c_base, csub_thread; + .reg .${pftype} bv<${bv_per_thread}>, cv<${cv_per_thread}>, dotp; + .reg .pred p1, p_skip; + .shared .align 8 .b8 _csub[${csub_bytes}]; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; +
cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + + { + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; + mov.u64 csub_thread, _csub; + add.u64 csub_thread, csub_thread, _tx_off; + } + +% for bid, kbx in enumerate(kparts): +## bid = ${bid}: ${len(kbx)} B columns, ksplit=${ksplit} + setp.ne.u32 p_skip, tid_y, ${bid}; + @p_skip bra $L_END_BID_${bid}; + +<% loaded = set() %> + +% for cchunk_i, cchunk in enumerate(cchunks): +## Chunk ${cchunk_i}: partial dot-product +% for row_idx, j in enumerate(cchunk): +<% owner_bid = row_idx % ksplit %> +<% has_dotp = bool(A[j].any()) %> +% for kxi, kx in enumerate(kbx): +% if A[j, kx] != 0 and kx not in loaded: +% if n is None: + { + .reg .u64 _bptr; + mad.wide.u32 _bptr, ldb, ${kx * dwidth_i}, b_base; + ld.weak.global.cg.${pftype} bv${kxi}, [_bptr]; + } +% else: + ld.weak.global.cg.${pftype} bv${kxi}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +<% loaded.add(kx) %> +% endif +% endfor + mov.${pftype} dotp, ${fzero}; +% for kxi, kx in enumerate(kbx): +% if A[j, kx] != 0: + fma.rn.${pftype} dotp, bv${kxi}, ${A[j, kx]}, dotp; +% endif +% endfor +% if owner_bid == bid: +% if beta_zero or not preload_c: + mov.${pftype} cv${row_idx // ksplit}, dotp; +% elif has_dotp: +% if n is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + ld.weak.global.cg.${pftype} cv${row_idx // ksplit}, [_cptr]; + } +% else: + ld.weak.global.cg.${pftype} cv${row_idx // ksplit}, [c_base + ${ldc*j*dwidth_i}]; +% endif + fma.rn.${pftype} cv${row_idx // ksplit}, cv${row_idx // ksplit}, ${float(beta)}, dotp; +% endif +% else: +<% csub_idx = bid - (1 if bid > owner_bid else 0) %> + st.shared.${pftype} [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}], dotp; +% endif +% endfor + bar.sync 0; + +## Combine phase (owned rows only) +% for row_idx, j in enumerate(cchunk): +% if row_idx % ksplit == 
bid: +<% has_dotp = bool(A[j].any()) %> +% if not preload_c or beta_zero or has_dotp: + mov.${pftype} dotp, cv${row_idx // ksplit}; +% for other_bid in range(ksplit): +% if other_bid != bid: +<% csub_idx = other_bid - (1 if other_bid > (row_idx % ksplit) else 0) %> + { + .reg .${pftype} _tmp; + ld.shared.${pftype} _tmp, [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}]; + add.${pftype} dotp, dotp, _tmp; + } +% endif +% endfor +% endif +% if beta_zero: +% if n is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + st.weak.global.cg.${pftype} [_cptr], dotp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; +% endif +% elif preload_c and has_dotp: +% if n is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + st.weak.global.${pftype} [_cptr], dotp; + } +% else: + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; +% endif +% elif preload_c and beta != 1: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + ld.weak.global.cg.${pftype} _ctmp, [_cptr]; + mul.${pftype} _ctmp, _ctmp, ${float(beta)}; + st.weak.global.${pftype} [_cptr], _ctmp; +% else: + ld.weak.global.cg.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _ctmp, _ctmp, ${float(beta)}; + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif + } +% elif not preload_c:  # with preload_c a zero row and beta == 1 must leave C untouched (dotp is stale there) + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + ld.weak.global.cg.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [_cptr], _ctmp; +% else: + ld.weak.global.cg.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif + } +% endif + +% endif +% endfor + bar.sync 0; +% endfor + +$L_END_BID_${bid}: +% endfor +
+$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream-v2.mako b/gimmik/kernels/ptx/cstream-v2.mako new file mode 100644 index 0000000..9949045 --- /dev/null +++ b/gimmik/kernels/ptx/cstream-v2.mako @@ -0,0 +1,83 @@ +<%inherit file='base'/> + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .${pftype} bv_a<${len(bix)}>, bv_b<${len(bix)}>, dotp_a, dotp_b; + .reg .pred p1; + + mov.u32 n, ${-(-n // 2)}; + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x, _tid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 _tid_x, %tid.x; + mad.lo.u32 id, _ctaid_x, ${blockx}, _tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${2*dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${2*dwidth_i}, c; + } + +## Batch-load B column pairs +% for i, kx in enumerate(bix): + ld.weak.global.cg.v2.${pftype} {bv_a${i}, bv_b${i}}, [b_base + ${ldb*kx*dwidth_i}]; +% endfor + +## Main compute: two parallel dot-product streams per thread +% for j in range(m): +% if row_nz[j]: + mov.${pftype} dotp_a, ${fzero}; + mov.${pftype} dotp_b, ${fzero}; +% for kx, jx in row_nz[j]: + fma.rn.${pftype} dotp_a, bv_a${bix[kx]}, ${jx}, dotp_a; + fma.rn.${pftype} dotp_b, bv_b${bix[kx]}, ${jx}, dotp_b; +% endfor +% if beta_zero: + st.weak.global.cg.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; +% else: + { + .reg .${pftype} _ca, _cb; + ld.weak.global.cg.v2.${pftype} {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ca, _ca, ${float(beta)}, dotp_a; + fma.rn.${pftype} _cb, _cb, ${float(beta)}, dotp_b; + st.weak.global.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% endif + +% else: +## Zero row of A +% if beta_zero: + { + .reg .${pftype} _z; + mov.${pftype} _z, ${fzero}; + st.weak.global.cg.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {_z, _z}; + } +% elif 
beta != 1: + { + .reg .${pftype} _ca, _cb; + ld.weak.global.cg.v2.${pftype} {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _ca, _ca, ${float(beta)}; + mul.${pftype} _cb, _cb, ${float(beta)}; + st.weak.global.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% endif +% endif +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream.mako b/gimmik/kernels/ptx/cstream.mako new file mode 100644 index 0000000..297d402 --- /dev/null +++ b/gimmik/kernels/ptx/cstream.mako @@ -0,0 +1,134 @@ +<%inherit file='base'/> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .${pftype} bv<${len(bix)}>, dotp; + .reg .pred p1; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _grd<3>; + mov.u32 _grd0, %ntid.x; + mov.u32 _grd1, %ctaid.x; + mov.u32 _grd2, %tid.x; + mad.lo.u32 id, _grd0, _grd1, _grd2; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + +## Batch-load active B columns +% for i, kx in enumerate(bix): +% if n is None: + { + .reg .u64 _bptr; + mad.wide.u32 _bptr, ldb, ${kx * dwidth_i}, b_base; + ld.weak.global.cg.${pftype} bv${i}, [_bptr]; + } +% else: + ld.weak.global.cg.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +% endfor + +## Compute and store each output row +% for j in range(m): +% if row_nz[j]: + mov.${pftype} dotp, ${fzero}; +% for kx, jx in row_nz[j]: + fma.rn.${pftype} dotp, bv${bix[kx]}, ${jx}, dotp; +% endfor +% if beta_zero: 
+% if n is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + st.weak.global.cg.${pftype} [_cptr], dotp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; +% endif +% else: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + ld.weak.global.cg.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [_cptr], _ctmp; +% else: + ld.weak.global.cg.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif + } +% endif + +% else: +## Zero row of A +% if beta_zero: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% if n is None: + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif + } +% elif beta != 1: + { + .reg .${pftype} _tmp; +% if n is None: + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + ld.weak.global.cg.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [_cptr], _tmp; +% else: + ld.weak.global.cg.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif + } +% endif +% endif +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dmma-asmem-v2.mako b/gimmik/kernels/ptx/dmma-asmem-v2.mako new file mode 100644 index 0000000..9ecfbef --- /dev/null +++ b/gimmik/kernels/ptx/dmma-asmem-v2.mako @@ -0,0 +1,311 @@ +<%inherit file='base'/> + +<% +blockx = a_copy_threads +a_pairs = a_elems // 2 +a_pairs_tail = a_elems % 2 +copy_v2_iters = (a_pairs + blockx - 1) // blockx +bs = bool(block_stealing) +%> + +% if bs: +.shared .align 8 .b64 ${kname}_mbar; 
+.shared .align 16 .b8 ${kname}_workid[16]; +% endif +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { + ${', '.join(a_u64)} +}; +.shared .align 16 .b64 ${kname}_As[${a_elems}]; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 as_thr_base, b_thr_base, c_thr_base; + .reg .pred pwarp_exit; + .reg .${pftype} a_frag_<${a_regs}>; +% if bs: + .reg .u32 ctaid; + .reg .u32 mbar_a, work_a; + .reg .pred p_root, p_done, p_have; +% endif +% for nt in range(nn): + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; +% if not n_col_aligned: + .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; +% endif + .reg .${pftype} b_frag_${nt}_<${b_regs}>; +% for mt in range(m_tiles): + .reg .${pftype} c_${nt}_${mt}_<${c_regs}>; +% endfor +% endfor + + ld.param.u64 b_ptr, [_b]; + ld.param.u64 c_ptr, [_c]; + cvta.to.global.u64 b_ptr, b_ptr; + cvta.to.global.u64 c_ptr, c_ptr; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + shr.u32 r_div4, lane, 2; + and.b32 r_mod4, lane, 3; + +% if bs: + setp.eq.u32 p_root, tid, 0; + mov.u32 mbar_a, ${kname}_mbar; + mov.u32 work_a, ${kname}_workid; + @p_root mbarrier.init.shared::cta.b64 [mbar_a], 1; + bar.sync 0; +% endif + + // Cooperative copy A from .global to .shared + { + .reg .u64 a_glb_base, a_smem_base; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, a_glb_base; + mov.u64 a_smem_base, ${kname}_As; +% for ci in range(copy_v2_iters): +<% + base_pair = ci * blockx + pairs_this = min(blockx, a_pairs - base_pair) +%> + { + .reg .u32 pidx; + .reg .u64 off64, gaddr, saddr; + .reg .${pftype} v0, v1; +% if loop.last and pairs_this < blockx: + .reg .pred plast; + add.u32 pidx, tid, ${base_pair}; + setp.lt.u32 plast, pidx, ${a_pairs}; + mul.wide.u32 off64, pidx, ${2 * dwidth_i}; + add.u64 gaddr, a_glb_base, off64; + add.u64 saddr, a_smem_base, off64; + @plast 
ld.weak.global.cg.v2.${pftype} {v0, v1}, [gaddr]; + @plast st.shared.v2.${pftype} [saddr], {v0, v1}; +% else: + add.u32 pidx, tid, ${base_pair}; + mul.wide.u32 off64, pidx, ${2 * dwidth_i}; + add.u64 gaddr, a_glb_base, off64; + add.u64 saddr, a_smem_base, off64; + ld.weak.global.cg.v2.${pftype} {v0, v1}, [gaddr]; + st.shared.v2.${pftype} [saddr], {v0, v1}; +% endif + } +% endfor +% if a_pairs_tail: + { + .reg .pred plast; + .reg .u64 gaddr, saddr; + .reg .${pftype} v; + setp.eq.u32 plast, tid, 0; + add.u64 gaddr, a_glb_base, ${(a_elems - 1) * dwidth_i}; + add.u64 saddr, a_smem_base, ${(a_elems - 1) * dwidth_i}; + @plast ld.weak.global.cg.${pftype} v, [gaddr]; + @plast st.shared.${pftype} [saddr], v; + } +% endif + } + bar.sync 0; + + // Lane-only base; lifted out of the optional steal loop + { + .reg .u64 t64, a_smem_base, lane64; + mov.u64 a_smem_base, ${kname}_As; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 as_thr_base, a_smem_base, t64; + } + +% for mt in range(m_tiles): +% for mg in range(m_groups): +% if pm_runtime(mt, mg): + .reg .pred pm_${mt}_${mg}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${tile_m * mt + 8 * mg}; + setp.lt.u32 pm_${mt}_${mg}, crow, ${m}; + } +% endif +% endfor +% endfor + +% if bs: + mov.u32 ctaid, %ctaid.x; +$L_LOOP: +% endif + + { + .reg .u32 cta; +% if bs: + mov.u32 cta, ctaid; +% else: + mov.u32 cta, %ctaid.x; +% endif + mul.lo.u32 cta, cta, ${n_per_cta}; + mul.lo.u32 warp_n_base, warp, ${n_per_warp}; + add.u32 warp_n_base, warp_n_base, cta; + } + setp.ge.u32 pwarp_exit, warp_n_base, ${n}; +% if bs: + @pwarp_exit bra $L_STEAL; +% else: + @pwarp_exit bra $L_EXIT; +% endif + +% for nt in range(nn): + add.u32 b_col_${nt}, warp_n_base, ${tile_n * nt}; + add.u32 b_col_${nt}, b_col_${nt}, r_div4; + { + .reg .u32 t; + shl.b32 t, r_mod4, 1; + add.u32 c_col0_${nt}, warp_n_base, ${tile_n * nt}; + add.u32 c_col0_${nt}, c_col0_${nt}, t; + add.u32 c_col1_${nt}, c_col0_${nt}, 1; + } +% if not n_col_aligned: + setp.lt.u32 
pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; + setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endif +% endfor + + { + .reg .u64 t64, bcol64; + mul.wide.u32 t64, r_mod4, ${ldb}; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; + } + + { + .reg .u64 t64, ccol64; + mul.wide.u32 t64, r_div4, ${ldc}; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; + } + +% for nt in range(nn): +% for mt in range(m_tiles): +% if beta_zero: +% for ci in range(c_regs): + mov.${pftype} c_${nt}_${mt}_${ci}, ${fzero}; +% endfor +% else: +% for mg in range(m_groups): +<% + pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None + c0 = f'c_{nt}_{mt}_{2*mg}' + c1 = f'c_{nt}_{mt}_{2*mg + 1}' + cpair = f'{c0}, {c1}' +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride}; +% if pm is not None: + mov.${pftype} ${c0}, ${fzero}; + mov.${pftype} ${c1}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.v2.{pftype} {{{cpair}}}, [caddr];', pm, pred_reg=f'p01_{nt}_{mt}_{mg}')} + } +% endfor +% endif +% endfor +% endfor + +% for ki in range(k_tiles): +% for nt in range(nn): +% for kg in range(k_groups): +<% + pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None + k_tail = (k_rem != 0 and loop.parent.parent.last) + needs_zero = pvb is not None or k_tail + pbrow = f'pbrow_{kg}' if k_tail else None +%> + { + .reg .u64 baddr; + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + kg * b_kgroup_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.${pftype} b_frag_${nt}_${kg}, ${fzero}; +% endif +% if k_tail: + .reg .pred ${pbrow}; + { + .reg .u32 brow; + add.u32 brow, r_mod4, ${tile_k * ki + 4 * kg}; + setp.lt.u32 ${pbrow}, brow, ${k}; + } +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}_{kg}, [baddr];', pbrow, pvb, 
pred_reg=f'pb_{ki}_{nt}_{kg}')} + } +% endfor +% endfor +% for mt in range(m_tiles): +% for ai in range(a_regs): + ld.shared.${pftype} a_frag_${ai}, [as_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes + 32 * ai * dwidth_i}]; +% endfor +% for nt in range(nn): + mma.sync.aligned.${ptx_mma_shape}.row.col.${pftype}.${pftype}.${pftype}.${pftype} + ${reg_list(f'c_{nt}_{mt}', c_regs)}, + ${reg_list('a_frag', a_regs)}, + ${reg_list(f'b_frag_{nt}', b_regs)}, + ${reg_list(f'c_{nt}_{mt}', c_regs)}; +% endfor +% endfor +% endfor + +% for mt in range(m_tiles): +% for nt in range(nn): +% for mg in range(m_groups): +<% + pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None + c0 = f'c_{nt}_{mt}_{2*mg}' + c1 = f'c_{nt}_{mt}_{2*mg + 1}' + cpair = f'{c0}, {c1}' +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.weak.global.v2.{pftype} [caddr], {{{cpair}}};', pm, pred_reg=f'p01s_{nt}_{mt}_{mg}')} + } +% endfor +% endfor +% endfor + +% if bs: +$L_STEAL: + // Root issues async try_cancel + waits; bar.sync orders the workid load + @!p_root bra $L_AFTER_WAIT; + { + .reg .u64 state; + mbarrier.arrive.expect_tx.shared::cta.b64 state, [mbar_a], 16; + clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 [work_a], [mbar_a]; +$L_WAIT: + mbarrier.try_wait.shared::cta.b64 p_done, [mbar_a], state, 10000000; + @!p_done bra $L_WAIT; + } +$L_AFTER_WAIT: + bar.sync 0; + + { + .reg .b128 resp; + ld.shared::cta.b128 resp, [work_a]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_have, resp; + @!p_have bra $L_FIN; + clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 ctaid, resp; + } + bra.uni $L_LOOP; + +$L_FIN: + bar.sync 0; + @p_root mbarrier.inval.shared::cta.b64 [mbar_a]; +% endif + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dmma-asmem.mako b/gimmik/kernels/ptx/dmma-asmem.mako new file mode 100644 index 0000000..b163a42 --- /dev/null +++ 
b/gimmik/kernels/ptx/dmma-asmem.mako
@@ -0,0 +1,302 @@
+<%inherit file='base'/>
+
+<%
+blockx = a_copy_threads  # threads that take part in the A copy
+copy_v1_iters = (a_elems + blockx - 1) // blockx  # ceil(a_elems / blockx) passes, one element per thread
+bs = bool(block_stealing)  # emit the work-stealing loop when set
+%>
+
+% if bs:
+.shared .align 8 .b64 ${kname}_mbar;  // mbarrier the root thread waits on for the try_cancel response
+.shared .align 16 .b8 ${kname}_workid[16];  // 16-byte try_cancel response buffer
+% endif
+.global .align 16 .b64 ${kname}_Ag[${a_elems}] = {
+    ${', '.join(a_u64)}
+};
+.shared .align 16 .b64 ${kname}_As[${a_elems}];  // shared-memory copy of Ag, filled at kernel entry
+
+.visible .entry ${kname}(.param .u64 _b,
+                         .param .u64 _c)
+{
+    .reg .u32 tid, warp, lane, r_mod4, r_div4;
+    .reg .u64 b_ptr, c_ptr;
+    .reg .u32 warp_n_base;
+    .reg .u64 as_thr_base, b_thr_base, c_thr_base;
+    .reg .pred pwarp_exit;
+    .reg .${pftype} a_frag_<${a_regs}>;
+% if bs:
+    .reg .u32 ctaid;
+    .reg .u32 mbar_a, work_a;
+    .reg .pred p_root, p_done, p_have;
+% endif
+% for nt in range(nn):
+    .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt};
+% if not n_col_aligned:
+    .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt};
+% endif
+    .reg .${pftype} b_frag_${nt}_<${b_regs}>;
+% for mt in range(m_tiles):
+    .reg .${pftype} c_${nt}_${mt}_<${c_regs}>;
+% endfor
+% endfor
+
+    ld.param.u64 b_ptr, [_b];
+    ld.param.u64 c_ptr, [_c];
+    cvta.to.global.u64 b_ptr, b_ptr;  // generic -> global address
+    cvta.to.global.u64 c_ptr, c_ptr;
+
+    mov.u32 tid, %tid.x;
+    shr.u32 warp, tid, 5;  // warp = tid / 32
+    and.b32 lane, tid, 31;  // lane = tid % 32
+    shr.u32 r_div4, lane, 2;  // lane / 4
+    and.b32 r_mod4, lane, 3;  // lane % 4
+
+% if bs:
+    setp.eq.u32 p_root, tid, 0;
+    mov.u32 mbar_a, ${kname}_mbar;
+    mov.u32 work_a, ${kname}_workid;
+    @p_root mbarrier.init.shared::cta.b64 [mbar_a], 1;  // one expected arrival: the root thread
+    bar.sync 0;
+% endif
+
+    // Cooperative copy A from .global to .shared
+    {
+        .reg .u64 a_glb_base, a_smem_base;
+        mov.u64 a_glb_base, ${kname}_Ag;
+        cvta.to.global.u64 a_glb_base, a_glb_base;
+        mov.u64 a_smem_base, ${kname}_As;
+% for ci in range(copy_v1_iters):
+<%
+    base_elem = ci * blockx  # first element handled in this pass
+    elems_this = min(blockx, a_elems - base_elem)  # elements left for this pass
+%>
+        {
+            .reg .u32 eidx;
+            .reg .u64 off64, gaddr, saddr;
+            .reg .${pftype} v;
+% if loop.last and elems_this < blockx:
+            .reg .pred plast;
+            add.u32 eidx, tid, ${base_elem};
+            setp.lt.u32 plast, eidx, ${a_elems};  // guard the final, partial pass
+            mul.wide.u32 off64, eidx, ${dwidth_i};  // byte offset of element eidx
+            add.u64 gaddr, a_glb_base, off64;
+            add.u64 saddr, a_smem_base, off64;
+            @plast ld.weak.global.cg.${pftype} v, [gaddr];
+            @plast st.shared.${pftype} [saddr], v;
+% else:
+            add.u32 eidx, tid, ${base_elem};
+            mul.wide.u32 off64, eidx, ${dwidth_i};  // byte offset of element eidx
+            add.u64 gaddr, a_glb_base, off64;
+            add.u64 saddr, a_smem_base, off64;
+            ld.weak.global.cg.${pftype} v, [gaddr];
+            st.shared.${pftype} [saddr], v;
+% endif
+        }
+% endfor
+    }
+    bar.sync 0;  // all of As is written before any thread reads it
+
+    // Lane-only base; lifted out of the optional steal loop
+    {
+        .reg .u64 t64, a_smem_base, lane64;
+        mov.u64 a_smem_base, ${kname}_As;
+        cvt.u64.u32 lane64, lane;
+        shl.b64 t64, lane64, 3;  // lane * 8 bytes
+        add.u64 as_thr_base, a_smem_base, t64;
+    }
+
+% for mt in range(m_tiles):
+% for mg in range(m_groups):
+% if pm_runtime(mt, mg):
+    .reg .pred pm_${mt}_${mg};
+    {
+        .reg .u32 crow;
+        add.u32 crow, r_div4, ${tile_m * mt + 8 * mg};
+        setp.lt.u32 pm_${mt}_${mg}, crow, ${m};  // true when this thread's row is below m
+    }
+% endif
+% endfor
+% endfor
+
+% if bs:
+    mov.u32 ctaid, %ctaid.x;  // first work item is this CTA's own index
+$L_LOOP:
+% endif
+
+    {
+        .reg .u32 cta;
+% if bs:
+        mov.u32 cta, ctaid;
+% else:
+        mov.u32 cta, %ctaid.x;
+% endif
+        mul.lo.u32 cta, cta, ${n_per_cta};  // first column of the current work item
+        mul.lo.u32 warp_n_base, warp, ${n_per_warp};
+        add.u32 warp_n_base, warp_n_base, cta;  // first column handled by this warp
+    }
+    setp.ge.u32 pwarp_exit, warp_n_base, ${n};
+% if bs:
+    @pwarp_exit bra $L_STEAL;  // nothing to compute, but still take part in the steal barriers
+% else:
+    @pwarp_exit bra $L_EXIT;
+% endif
+
+% for nt in range(nn):
+    add.u32 b_col_${nt}, warp_n_base, ${tile_n * nt};
+    add.u32 b_col_${nt}, b_col_${nt}, r_div4;
+    {
+        .reg .u32 t;
+        shl.b32 t, r_mod4, 1;  // 2 * r_mod4
+        add.u32 c_col0_${nt}, warp_n_base, ${tile_n * nt};
+        add.u32 c_col0_${nt}, c_col0_${nt}, t;
+        add.u32 c_col1_${nt}, c_col0_${nt}, 1;  // the adjacent column
+    }
+% if not n_col_aligned:
+    setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n};
+    setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n};
+    setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n};
+% endif
+% endfor
+
+    {
+        .reg .u64 t64, bcol64;
+        mul.wide.u32 t64, r_mod4, ${ldb};  // r_mod4 * ldb
+        cvt.u64.u32 bcol64, b_col_0;
+        add.u64 t64, t64, bcol64;
+        shl.b64 t64, t64, 3;  // element index -> bytes (8 per element)
+        add.u64 b_thr_base, b_ptr, t64;
+    }
+
+    {
+        .reg .u64 t64, ccol64;
+        mul.wide.u32 t64, r_div4, ${ldc};  // r_div4 * ldc
+        cvt.u64.u32 ccol64, c_col0_0;
+        add.u64 t64, t64, ccol64;
+        shl.b64 t64, t64, 3;  // element index -> bytes (8 per element)
+        add.u64 c_thr_base, c_ptr, t64;
+    }
+
+% for nt in range(nn):
+% for mt in range(m_tiles):
+% if beta_zero:
+% for ci in range(c_regs):
+    mov.${pftype} c_${nt}_${mt}_${ci}, ${fzero};
+% endfor
+% else:
+% for mg in range(m_groups):
+<%
+    pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None
+    pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None
+    pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None
+    needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None  # zero first when a load may be predicated
+    c0 = f'c_{nt}_{mt}_{2*mg}'
+    c1 = f'c_{nt}_{mt}_{2*mg + 1}'
+%>
+    {
+        .reg .u64 caddr;
+        add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride};
+% if needs_zero_init:
+        mov.${pftype} ${c0}, ${fzero};
+        mov.${pftype} ${c1}, ${fzero};
+% endif
+        ${pred_emit(f'ld.weak.global.cg.{pftype} {c0}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}_{mg}')}
+        ${pred_emit(f'ld.weak.global.cg.{pftype} {c1}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}_{mg}')}
+    }
+% endfor
+% endif
+% endfor
+% endfor
+
+% for ki in range(k_tiles):
+% for nt in range(nn):
+% for kg in range(k_groups):
+<%
+    pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None
+    k_tail = (k_rem != 0 and loop.parent.parent.last)  # loop.parent.parent is the ki loop: last k tile with k_rem != 0
+    needs_zero = pvb is not None or k_tail
+    pbrow = f'pbrow_{kg}' if k_tail else None
+%>
+    {
+        .reg .u64 baddr;
+        add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + kg * b_kgroup_stride + nt * b_ntile_stride};
+% if needs_zero:
+        mov.${pftype} b_frag_${nt}_${kg}, ${fzero};
+% endif
+% if k_tail:
+        .reg .pred ${pbrow};
+        {
+            .reg .u32 brow;
+            add.u32 brow, r_mod4, ${tile_k * ki + 4 * kg};
+            setp.lt.u32 ${pbrow}, brow, ${k};  // true when this thread's B row is below k
+        }
+% endif
+        ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}_{kg}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}_{kg}')}
+    }
+% endfor
+% endfor
+% for mt in range(m_tiles):
+% for ai in range(a_regs):
+    ld.shared.${pftype} a_frag_${ai}, [as_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes + 32 * ai * dwidth_i}];  // fragment (mt, ki); registers are 32 elements apart
+% endfor
+% for nt in range(nn):
+    mma.sync.aligned.${ptx_mma_shape}.row.col.${pftype}.${pftype}.${pftype}.${pftype}
+        ${reg_list(f'c_{nt}_{mt}', c_regs)},
+        ${reg_list('a_frag', a_regs)},
+        ${reg_list(f'b_frag_{nt}', b_regs)},
+        ${reg_list(f'c_{nt}_{mt}', c_regs)};
+% endfor
+% endfor
+% endfor
+
+% for mt in range(m_tiles):
+% for nt in range(nn):
+% for mg in range(m_groups):
+<%
+    pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None
+    pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None
+    pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None
+    c0 = f'c_{nt}_{mt}_{2*mg}'
+    c1 = f'c_{nt}_{mt}_{2*mg + 1}'
+%>
+    {
+        .reg .u64 caddr;
+        add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride};
+        ${pred_emit(f'st.weak.global.{pftype} [caddr], {c0};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}_{mg}')}
+        ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], {c1};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}_{mg}')}
+    }
+% endfor
+% endfor
+% endfor
+
+% if bs:
+$L_STEAL:
+    // Root issues async try_cancel + waits; bar.sync orders the workid load
+    @!p_root bra $L_AFTER_WAIT;
+    {
+        .reg .u64 state;
+        mbarrier.arrive.expect_tx.shared::cta.b64 state, [mbar_a], 16;  // expect the 16-byte response
+        clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 [work_a], [mbar_a];  // NOTE(review): async write to work_a; no barrier separates it from other threads' load of the previous response - confirm this cannot race
+$L_WAIT:
+        mbarrier.try_wait.shared::cta.b64 p_done, [mbar_a], state, 10000000;  // last operand is the try_wait time-limit hint
+        @!p_done bra $L_WAIT;
+    }
+$L_AFTER_WAIT:
+    bar.sync 0;
+
+    {
+        .reg .b128 resp;
+        ld.shared::cta.b128 resp, [work_a];
+        clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_have, resp;  // p_have: another CTA's work was claimed
+        @!p_have bra $L_FIN;
+
clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 ctaid, resp;  // x index of the claimed CTA becomes the next work item
+    }
+    bra.uni $L_LOOP;
+
+$L_FIN:
+    bar.sync 0;
+    @p_root mbarrier.inval.shared::cta.b64 [mbar_a];
+% endif
+
+$L_EXIT:
+    ret;
+}
diff --git a/gimmik/kernels/ptx/dmma-astream-msplit-v2.mako b/gimmik/kernels/ptx/dmma-astream-msplit-v2.mako
new file mode 100644
index 0000000..d22704f
--- /dev/null
+++ b/gimmik/kernels/ptx/dmma-astream-msplit-v2.mako
@@ -0,0 +1,258 @@
+<%inherit file='base'/>
+
+.global .align 16 .b64 ${kname}_Ag[${a_elems}] = {
+    ${', '.join(a_u64)}
+};
+.extern .shared .align 128 .b8 ${kname}_dynm[];  // dynamic shared memory; B tile at b_off, its mbarrier at tma_mbar_off
+
+.visible .entry ${kname}(.param .u64 b_desc,
+                         .param .u64 _c)
+.maxntid ${blockx_total}, 1, 1
+{
+    .reg .u32 tid, warp, lane, r_mod4, r_div4;
+    .reg .u32 ctaid_x, n_start_cta, warp_n, warp_m, warp_n_base;
+    .reg .u64 bdesc_addr, c_ptr;
+    .reg .u64 ag_thr_base, c_thr_base;
+    .reg .u32 b_smem, b_thr_base, tma_mbar;
+    .reg .pred p_tid0, pwarp_exit, p_load_warp, p_warp_lead;
+% for nt in range(nn):
+    .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt};
+% if not n_col_aligned:
+    .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt};
+% endif
+% endfor
+
+    ld.param.u64 bdesc_addr, [b_desc];
+    ld.param.u64 c_ptr, [_c];
+    cvta.to.global.u64 c_ptr, c_ptr;  // generic -> global address
+
+    mov.u32 tid, %tid.x;
+    shr.u32 warp, tid, 5;  // warp = tid / 32
+    and.b32 lane, tid, 31;  // lane = tid % 32
+    shr.u32 r_div4, lane, 2;  // lane / 4
+    and.b32 r_mod4, lane, 3;  // lane % 4
+    mov.u32 ctaid_x, %ctaid.x;
+    mul.lo.u32 n_start_cta, ctaid_x, ${n_per_cta};  // first column handled by this CTA
+
+    // Decompose the warp index as warp = warp_n * msplit + warp_m:
+    //   warp_n picks the n_per_warp-wide column slab inside this CTA,
+    //   warp_m picks the M tiles this warp owns (mt % msplit == warp_m).
+    // rem.u32 yields warp_m directly; no multiply/subtract scratch register.
+    div.u32 warp_n, warp, ${msplit};
+    rem.u32 warp_m, warp, ${msplit};
+
+    {
+        .reg .u32 dynm_base;
+        mov.u32 dynm_base, ${kname}_dynm;
+        add.u32 b_smem, dynm_base, ${b_off};
+        add.u32 tma_mbar, dynm_base, ${tma_mbar_off};
+    }
+
+    setp.eq.u32 p_tid0, tid, 0;
+    setp.eq.u32 p_load_warp, warp, 0;  // warp 0 issues and waits for the B copy
+    {
+        .reg .b32 _elect_lane;
+        elect.sync _elect_lane|p_warp_lead, 0xffffffff;  // p_warp_lead: the one elected lane of each warp
+    }
+
+    @p_tid0 mbarrier.init.shared::cta.b64 [tma_mbar], 32;  // 32 arrivals: every lane of warp 0 arrives below
+    @p_tid0 fence.proxy.async.shared::cta;
+    bar.sync 0;
+
+    @!p_load_warp bra $L_AFTER_B_TMA;
+    {
+        @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes
+            [b_smem], [bdesc_addr, {n_start_cta, 0}], [tma_mbar];  // bulk tensor copy into b_smem at coordinates {n_start_cta, 0}
+        @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64
+            [tma_mbar], ${b_tile_bytes};  // bytes the copy is expected to deliver
+        bar.warp.sync 0xffffffff;
+        .reg .b64 state;
+        .reg .pred p1;
+        mbarrier.arrive.shared::cta.b64 state, [tma_mbar];
+$L_TMA_WAIT:
+        mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], state, ${mbar_maxwait};
+        @!p1 bra.uni $L_TMA_WAIT;
+    }
+$L_AFTER_B_TMA:
+    bar.sync 0;  // warp 0 has seen the copy complete; release the other warps
+
+    {
+        .reg .u32 t;
+        mul.lo.u32 t, warp_n, ${n_per_warp};
+        add.u32 warp_n_base, n_start_cta, t;  // first column handled by this warp
+    }
+    setp.ge.u32 pwarp_exit, warp_n_base, ${n};
+    @pwarp_exit bra $L_EXIT;
+
+% for nt in range(nn):
+    add.u32 b_col_${nt}, warp_n_base, ${tile_n * nt};
+    add.u32 b_col_${nt}, b_col_${nt}, r_div4;
+    {
+        .reg .u32 t;
+        shl.b32 t, r_mod4, 1;  // 2 * r_mod4
+        add.u32 c_col0_${nt}, warp_n_base, ${tile_n * nt};
+        add.u32 c_col0_${nt}, c_col0_${nt}, t;
+        add.u32 c_col1_${nt}, c_col0_${nt}, 1;  // the adjacent column
+    }
+% if not n_col_aligned:
+    setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n};
+    setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n};
+    setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n};
+% endif
+% endfor
+
+    // A thread base: &Ag[0] + lane*sizeof(f64)
+    {
+        .reg .u64 t64, a_glb_base, lane64;
+        mov.u64 a_glb_base, ${kname}_Ag;
+        cvta.to.global.u64 a_glb_base, a_glb_base;
+        cvt.u64.u32 lane64, lane;
+        shl.b64 t64, lane64, 3;  // lane * 8 bytes
+        add.u64 ag_thr_base, a_glb_base, t64;
+    }
+
+    {
+        .reg .u32 bcol_local, t, row_off;
+        mad.lo.u32 bcol_local, warp_n, ${n_per_warp}, r_div4;  // column relative to the CTA's shared tile
+        mul.lo.u32 t, bcol_local, ${dwidth_i};
+        mul.lo.u32 row_off, r_mod4, ${n_per_cta * dwidth_i};  // shared tile row pitch is n_per_cta elements
+        add.u32 t, t, row_off;
+        add.u32 b_thr_base, b_smem, t;
+    }
+
+    {
+        .reg .u64 t64, ccol64;
+        mul.wide.u32 t64, r_div4, ${ldc};  // r_div4 * ldc
+        cvt.u64.u32 ccol64, c_col0_0;
+        add.u64 t64, t64, ccol64;
+        shl.b64 t64, t64, 3;  // element index -> bytes (8 per element)
+        add.u64 c_thr_base, c_ptr, t64;
+    }
+
+% for wm in range(msplit):
+<%
+    owned_mts = [mt for mt in range(m_tiles) if mt % msplit == wm]  # M tiles handled by warps with warp_m == wm
+%>
+% if owned_mts:
+    {
+        .reg .pred p_this_msplit;
+        setp.ne.u32 p_this_msplit, warp_m, ${wm};  // true when this warp does NOT own slice wm
+        @p_this_msplit bra $L_SKIP_MS_${wm};
+    }
+    {
+        .reg .${pftype} a_frag_<${a_regs}>;
+% for nt in range(nn):
+        .reg .${pftype} b_frag_${nt}_<${b_regs}>;
+% endfor
+% for nt in range(nn):
+% for mt in owned_mts:
+        .reg .${pftype} c_${nt}_${mt}_<${c_regs}>;
+% endfor
+% endfor
+% for mt in owned_mts:
+% for mg in range(m_groups):
+% if pm_runtime(mt, mg):
+        .reg .pred pm_${mt}_${mg};
+        {
+            .reg .u32 crow;
+            add.u32 crow, r_div4, ${tile_m * mt + 8 * mg};
+            setp.lt.u32 pm_${mt}_${mg}, crow, ${m};  // true when this thread's row is below m
+        }
+% endif
+% endfor
+% endfor
+
+% for nt in range(nn):
+% for mt in owned_mts:
+% if beta_zero:
+% for ci in range(c_regs):
+        mov.${pftype} c_${nt}_${mt}_${ci}, ${fzero};
+% endfor
+% else:
+% for mg in range(m_groups):
+<%
+    pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None
+    c0 = f'c_{nt}_{mt}_{2*mg}'
+    c1 = f'c_{nt}_{mt}_{2*mg + 1}'
+    cpair = f'{c0}, {c1}'  # NOTE(review): the v2 access has no pvalid_c*col guard; presumably only used when n_col_aligned - confirm
+%>
+        {
+            .reg .u64 caddr;
+            add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride};
+% if pm is not None:
+            mov.${pftype} ${c0}, ${fzero};
+            mov.${pftype} ${c1}, ${fzero};
+% endif
+            ${pred_emit(f'ld.weak.global.cg.v2.{pftype} {{{cpair}}}, [caddr];', pm, pred_reg=f'p01_{wm}_{nt}_{mt}_{mg}', indent=' ' * 12)}
+        }
+% endfor
+% endif
+% endfor
+% endfor
+
+% for ki in range(k_tiles):
+% for nt in range(nn):
+% for kg in range(k_groups):
+<%
+    pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None
+    k_tail = (k_rem != 0 and loop.parent.parent.last)  # loop.parent.parent is the ki loop: last k tile with k_rem != 0
+    needs_zero = pvb is not None or k_tail
+    pbrow = f'pbrow_{kg}' if k_tail else None
+%>
+        {
+            .reg .u32 baddr;
+            add.u32 baddr, b_thr_base, ${ki * b_smem_kiter_stride + kg * b_smem_kgroup_stride + nt * b_smem_ntile_stride};
+% if needs_zero:
+            mov.${pftype} b_frag_${nt}_${kg}, ${fzero};
+% endif
+% if k_tail:
+            .reg .pred ${pbrow};
+            {
+                .reg
.u32 brow;
+                add.u32 brow, r_mod4, ${tile_k * ki + 4 * kg};
+                setp.lt.u32 ${pbrow}, brow, ${k};
+            }
+% endif
+            ${pred_emit(f'ld.shared.{pftype} b_frag_{nt}_{kg}, [baddr];', pbrow, pvb, pred_reg=f'pb_{wm}_{ki}_{nt}_{kg}', indent=' ' * 12)}
+        }
+% endfor
+% endfor
+% for mt in owned_mts:
+% for ai in range(a_regs):
+        ld.weak.global.${pftype} a_frag_${ai}, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes + 32 * ai * dwidth_i}];
+% endfor
+% for nt in range(nn):
+        mma.sync.aligned.${ptx_mma_shape}.row.col.${pftype}.${pftype}.${pftype}.${pftype}
+            ${reg_list(f'c_{nt}_{mt}', c_regs)},
+            ${reg_list('a_frag', a_regs)},
+            ${reg_list(f'b_frag_{nt}', b_regs)},
+            ${reg_list(f'c_{nt}_{mt}', c_regs)};
+% endfor
+% endfor
+% endfor
+
+% for mt in owned_mts:
+% for nt in range(nn):
+% for mg in range(m_groups):
+<%
+    pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None
+    c0 = f'c_{nt}_{mt}_{2*mg}'
+    c1 = f'c_{nt}_{mt}_{2*mg + 1}'
+    cpair = f'{c0}, {c1}'
+%>
+        {
+            .reg .u64 caddr;
+            add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride};
+            ${pred_emit(f'st.weak.global.v2.{pftype} [caddr], {{{cpair}}};', pm, pred_reg=f'p01s_{wm}_{nt}_{mt}_{mg}', indent=' ' * 12)}
+        }
+% endfor
+% endfor
+% endfor
+    }
+$L_SKIP_MS_${wm}:
+% endif
+% endfor
+
+$L_EXIT:
+    ret;
+}
diff --git a/gimmik/kernels/ptx/dmma-astream-msplit.mako b/gimmik/kernels/ptx/dmma-astream-msplit.mako
new file mode 100644
index 0000000..50e684e
--- /dev/null
+++ b/gimmik/kernels/ptx/dmma-astream-msplit.mako
@@ -0,0 +1,263 @@
+<%inherit file='base'/>
+
+.global .align 16 .b64 ${kname}_Ag[${a_elems}] = {
+    ${', '.join(a_u64)}
+};
+.extern .shared .align 128 .b8 ${kname}_dynm[];  // dynamic shared memory; B tile at b_off, its mbarrier at tma_mbar_off
+
+.visible .entry ${kname}(.param .u64 b_desc,
+                         .param .u64 _c)
+.maxntid ${blockx_total}, 1, 1
+{
+    .reg .u32 tid, warp, lane, r_mod4, r_div4;
+    .reg .u32 ctaid_x, n_start_cta, warp_n, warp_m, warp_n_base;
+    .reg .u64 bdesc_addr, c_ptr;
+    .reg .u64 ag_thr_base, c_thr_base;
+    .reg .u32 b_smem, b_thr_base, tma_mbar;
+    .reg .pred p_tid0, pwarp_exit, p_load_warp, p_warp_lead;
+% for nt in range(nn):
+    .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt};
+% if not n_col_aligned:
+    .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt};
+% endif
+% endfor
+
+    ld.param.u64 bdesc_addr, [b_desc];
+    ld.param.u64 c_ptr, [_c];
+    cvta.to.global.u64 c_ptr, c_ptr;  // generic -> global address
+
+    mov.u32 tid, %tid.x;
+    shr.u32 warp, tid, 5;  // warp = tid / 32
+    and.b32 lane, tid, 31;  // lane = tid % 32
+    shr.u32 r_div4, lane, 2;  // lane / 4
+    and.b32 r_mod4, lane, 3;  // lane % 4
+    mov.u32 ctaid_x, %ctaid.x;
+    mul.lo.u32 n_start_cta, ctaid_x, ${n_per_cta};  // first column handled by this CTA
+
+    // Decompose the warp index as warp = warp_n * msplit + warp_m:
+    //   warp_n picks the n_per_warp-wide column slab inside this CTA,
+    //   warp_m picks the M tiles this warp owns (mt % msplit == warp_m).
+    // rem.u32 yields warp_m directly; no multiply/subtract scratch register.
+    div.u32 warp_n, warp, ${msplit};
+    rem.u32 warp_m, warp, ${msplit};
+
+    {
+        .reg .u32 dynm_base;
+        mov.u32 dynm_base, ${kname}_dynm;
+        add.u32 b_smem, dynm_base, ${b_off};
+        add.u32 tma_mbar, dynm_base, ${tma_mbar_off};
+    }
+
+    setp.eq.u32 p_tid0, tid, 0;
+    setp.eq.u32 p_load_warp, warp, 0;  // warp 0 issues and waits for the B copy
+    {
+        .reg .b32 _elect_lane;
+        elect.sync _elect_lane|p_warp_lead, 0xffffffff;  // p_warp_lead: the one elected lane of each warp
+    }
+
+    @p_tid0 mbarrier.init.shared::cta.b64 [tma_mbar], 32;  // 32 arrivals: every lane of warp 0 arrives below
+    @p_tid0 fence.proxy.async.shared::cta;
+    bar.sync 0;
+
+    @!p_load_warp bra $L_AFTER_B_TMA;
+    {
+        @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes
+            [b_smem], [bdesc_addr, {n_start_cta, 0}], [tma_mbar];  // bulk tensor copy into b_smem at coordinates {n_start_cta, 0}
+        @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64
+            [tma_mbar], ${b_tile_bytes};  // bytes the copy is expected to deliver
+        bar.warp.sync 0xffffffff;
+        .reg .b64 state;
+        .reg .pred p1;
+        mbarrier.arrive.shared::cta.b64 state, [tma_mbar];
+$L_TMA_WAIT:
+        mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], state, ${mbar_maxwait};
+        @!p1 bra.uni $L_TMA_WAIT;
+    }
+$L_AFTER_B_TMA:
+    bar.sync 0;  // warp 0 has seen the copy complete; release the other warps
+
+    {
+        .reg .u32 t;
+        mul.lo.u32 t, warp_n, ${n_per_warp};
+        add.u32 warp_n_base, n_start_cta, t;  // first column handled by this warp
+    }
+    setp.ge.u32 pwarp_exit, warp_n_base, ${n};
+    @pwarp_exit bra $L_EXIT;
+
+% for nt in range(nn):
+    add.u32 b_col_${nt}, warp_n_base, ${tile_n * nt};
+    add.u32 b_col_${nt}, b_col_${nt}, r_div4;
+    {
+        .reg .u32 t;
+        shl.b32 t, r_mod4, 1;  // 2 * r_mod4
+        add.u32 c_col0_${nt}, warp_n_base, ${tile_n * nt};
+        add.u32 c_col0_${nt}, c_col0_${nt}, t;
+        add.u32 c_col1_${nt}, c_col0_${nt}, 1;  // the adjacent column
+    }
+% if not n_col_aligned:
+    setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n};
+    setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n};
+    setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n};
+% endif
+% endfor
+
+    // A thread base: &Ag[0] + lane*sizeof(f64)
+    {
+        .reg .u64 t64, a_glb_base, lane64;
+        mov.u64 a_glb_base, ${kname}_Ag;
+        cvta.to.global.u64 a_glb_base, a_glb_base;
+        cvt.u64.u32 lane64, lane;
+        shl.b64 t64, lane64, 3;  // lane * 8 bytes
+        add.u64 ag_thr_base, a_glb_base, t64;
+    }
+
+    {
+        .reg .u32 bcol_local, t, row_off;
+        mad.lo.u32 bcol_local, warp_n, ${n_per_warp}, r_div4;  // column relative to the CTA's shared tile
+        mul.lo.u32 t, bcol_local, ${dwidth_i};
+        mul.lo.u32 row_off, r_mod4, ${n_per_cta * dwidth_i};  // shared tile row pitch is n_per_cta elements
+        add.u32 t, t, row_off;
+        add.u32 b_thr_base, b_smem, t;
+    }
+
+    {
+        .reg .u64 t64, ccol64;
+        mul.wide.u32 t64, r_div4, ${ldc};  // r_div4 * ldc
+        cvt.u64.u32 ccol64, c_col0_0;
+        add.u64 t64, t64, ccol64;
+        shl.b64 t64, t64, 3;  // element index -> bytes (8 per element)
+        add.u64 c_thr_base, c_ptr, t64;
+    }
+
+% for wm in range(msplit):
+<%
+    owned_mts = [mt for mt in range(m_tiles) if mt % msplit == wm]  # M tiles handled by warps with warp_m == wm
+%>
+% if owned_mts:
+    {
+        .reg .pred p_this_msplit;
+        setp.ne.u32 p_this_msplit, warp_m, ${wm};  // true when this warp does NOT own slice wm
+        @p_this_msplit bra $L_SKIP_MS_${wm};
+    }
+    {
+        .reg .${pftype} a_frag_<${a_regs}>;
+% for nt in range(nn):
+        .reg .${pftype} b_frag_${nt}_<${b_regs}>;
+% endfor
+% for nt in range(nn):
+% for mt in owned_mts:
+        .reg .${pftype} c_${nt}_${mt}_<${c_regs}>;
+% endfor
+% endfor
+% for mt in owned_mts:
+% for mg in range(m_groups):
+% if pm_runtime(mt, mg):
+        .reg .pred pm_${mt}_${mg};
+        {
+            .reg .u32 crow;
+            add.u32 crow, r_div4, ${tile_m * mt + 8 * mg};
+            setp.lt.u32 pm_${mt}_${mg}, crow, ${m};  // true when this thread's row is below m
+        }
+% endif
+% endfor
+% endfor
+
+% for nt in range(nn):
+% for mt in owned_mts:
+% if beta_zero:
+% for ci in range(c_regs):
+        mov.${pftype} c_${nt}_${mt}_${ci}, ${fzero};
+% endfor
+% else:
+% for mg in range(m_groups):
+<%
+    pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None
+    pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None
+    pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None
+    needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None  # zero first when a load may be predicated
+    c0 = f'c_{nt}_{mt}_{2*mg}'
+    c1 = f'c_{nt}_{mt}_{2*mg + 1}'
+%>
+        {
+            .reg .u64 caddr;
+            add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride};
+% if needs_zero_init:
+            mov.${pftype} ${c0}, ${fzero};
+            mov.${pftype} ${c1}, ${fzero};
+% endif
+            ${pred_emit(f'ld.weak.global.cg.{pftype} {c0}, [caddr];', pm, pvc0, pred_reg=f'p0_{wm}_{nt}_{mt}_{mg}', indent=' ' * 12)}
+            ${pred_emit(f'ld.weak.global.cg.{pftype} {c1}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{wm}_{nt}_{mt}_{mg}', indent=' ' * 12)}
+        }
+% endfor
+% endif
+% endfor
+% endfor
+
+% for ki in range(k_tiles):
+% for nt in range(nn):
+% for kg in range(k_groups):
+<%
+    pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None
+    k_tail = (k_rem != 0 and loop.parent.parent.last)  # loop.parent.parent is the ki loop: last k tile with k_rem != 0
+    needs_zero = pvb is not None or k_tail
+    pbrow = f'pbrow_{kg}' if k_tail else None
+%>
+        {
+            .reg .u32 baddr;
+            add.u32 baddr, b_thr_base, ${ki * b_smem_kiter_stride + kg * b_smem_kgroup_stride + nt * b_smem_ntile_stride};
+% if needs_zero:
+            mov.${pftype} b_frag_${nt}_${kg}, ${fzero};
+% endif
+% if k_tail:
+            .reg .pred ${pbrow};
+            {
+                .reg .u32 brow;
+                add.u32 brow, r_mod4, ${tile_k * ki + 4 * kg};
+                setp.lt.u32 ${pbrow}, brow, ${k};  // true when this thread's B row is below k
+            }
+% endif
+            ${pred_emit(f'ld.shared.{pftype} b_frag_{nt}_{kg}, [baddr];', pbrow, pvb, pred_reg=f'pb_{wm}_{ki}_{nt}_{kg}', indent=' ' * 12)}
+        }
+% endfor
+% endfor
+% for mt in owned_mts:
+% for ai in range(a_regs):
+        ld.weak.global.${pftype} a_frag_${ai}, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes + 32 * ai * dwidth_i}];  // fragment (mt, ki); registers are 32 elements apart
+% endfor
+% for nt in range(nn):
+        mma.sync.aligned.${ptx_mma_shape}.row.col.${pftype}.${pftype}.${pftype}.${pftype}
+            ${reg_list(f'c_{nt}_{mt}', c_regs)},
+
${reg_list('a_frag', a_regs)},
+            ${reg_list(f'b_frag_{nt}', b_regs)},
+            ${reg_list(f'c_{nt}_{mt}', c_regs)};
+% endfor
+% endfor
+% endfor
+
+% for mt in owned_mts:
+% for nt in range(nn):
+% for mg in range(m_groups):
+<%
+    pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None
+    pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None
+    pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None
+    c0 = f'c_{nt}_{mt}_{2*mg}'
+    c1 = f'c_{nt}_{mt}_{2*mg + 1}'
+%>
+        {
+            .reg .u64 caddr;
+            add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride};
+            ${pred_emit(f'st.weak.global.{pftype} [caddr], {c0};', pm, pvc0, pred_reg=f'p0s_{wm}_{nt}_{mt}_{mg}', indent=' ' * 12)}
+            ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], {c1};', pm, pvc1, pred_reg=f'p1s_{wm}_{nt}_{mt}_{mg}', indent=' ' * 12)}
+        }
+% endfor
+% endfor
+% endfor
+    }
+$L_SKIP_MS_${wm}:
+% endif
+% endfor
+
+$L_EXIT:
+    ret;
+}
diff --git a/gimmik/kernels/ptx/dmma-astream-v2.mako b/gimmik/kernels/ptx/dmma-astream-v2.mako
new file mode 100644
index 0000000..395640e
--- /dev/null
+++ b/gimmik/kernels/ptx/dmma-astream-v2.mako
@@ -0,0 +1,203 @@
+<%inherit file='base'/>
+/*
+    dmma-astream-v2
+
+    Dense FP64 kernel using configurable warp-level DMMA tiles. The tiles of A
+    are precomputed and put in global memory within this compilation unit. Then
+    tiles of B and A are streamed from global into registers. This kernel uses
+    128-bit loads/stores for C.
+ */
+
+.global .align 16 .b64 ${kname}_Ag[${a_elems}] = {
+    ${', '.join(a_u64)}
+};
+
+.visible .entry ${kname}(.param .u64 _b,
+                         .param .u64 _c)
+{
+    .reg .u32 tid, warp, lane, r_mod4, r_div4;
+    .reg .u64 b_ptr, c_ptr;
+    .reg .u32 warp_n_base;
+    .reg .u64 ag_thr_base, b_thr_base, c_thr_base;
+    .reg .pred pwarp_exit;
+    .reg .${pftype} a_frag_<${a_regs}>;
+% for nt in range(nn):
+    .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt};
+% if not n_col_aligned:
+    .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt};
+% endif
+    .reg .${pftype} b_frag_${nt}_<${b_regs}>;
+% for mt in range(m_tiles):
+    .reg .${pftype} c_${nt}_${mt}_<${c_regs}>;
+% endfor
+% endfor
+
+    ld.param.u64 b_ptr, [_b];
+    ld.param.u64 c_ptr, [_c];
+    cvta.to.global.u64 b_ptr, b_ptr;  // generic -> global address
+    cvta.to.global.u64 c_ptr, c_ptr;
+
+    mov.u32 tid, %tid.x;
+    shr.u32 warp, tid, 5;  // warp = tid / 32
+    and.b32 lane, tid, 31;  // lane = tid % 32
+    shr.u32 r_div4, lane, 2;  // lane / 4
+    and.b32 r_mod4, lane, 3;  // lane % 4
+
+    {
+        .reg .u32 cta;
+        mov.u32 cta, %ctaid.x;
+        mul.lo.u32 cta, cta, ${n_per_cta};  // first column handled by this CTA
+        mul.lo.u32 warp_n_base, warp, ${n_per_warp};
+        add.u32 warp_n_base, warp_n_base, cta;  // first column handled by this warp
+    }
+    setp.ge.u32 pwarp_exit, warp_n_base, ${n};
+    @pwarp_exit bra $L_EXIT;
+
+% for nt in range(nn):
+    add.u32 b_col_${nt}, warp_n_base, ${tile_n * nt};
+    add.u32 b_col_${nt}, b_col_${nt}, r_div4;
+    {
+        .reg .u32 t;
+        shl.b32 t, r_mod4, 1;  // 2 * r_mod4
+        add.u32 c_col0_${nt}, warp_n_base, ${tile_n * nt};
+        add.u32 c_col0_${nt}, c_col0_${nt}, t;
+        add.u32 c_col1_${nt}, c_col0_${nt}, 1;  // the adjacent column
+    }
+% if not n_col_aligned:
+    setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n};
+    setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n};
+    setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n};
+% endif
+% endfor
+
+    // A thread base: &Ag[0] + lane*sizeof(f64)
+    {
+        .reg .u64 t64, a_glb_base, lane64;
+        mov.u64 a_glb_base, ${kname}_Ag;
+        cvta.to.global.u64 a_glb_base, a_glb_base;
+        cvt.u64.u32 lane64, lane;
+        shl.b64 t64, lane64, 3;  // lane * 8 bytes
+        add.u64 ag_thr_base, a_glb_base, t64;
+    }
+
+    {
+        .reg .u64 t64, bcol64;
+        mul.wide.u32 t64, r_mod4, ${ldb};  // r_mod4 * ldb
+        cvt.u64.u32 bcol64, b_col_0;
+        add.u64 t64, t64, bcol64;
+        shl.b64 t64, t64, 3;  // element index -> bytes (8 per element)
+        add.u64 b_thr_base, b_ptr, t64;
+    }
+
+    {
+        .reg .u64 t64, ccol64;
+        mul.wide.u32 t64, r_div4, ${ldc};  // r_div4 * ldc
+        cvt.u64.u32 ccol64, c_col0_0;
+        add.u64 t64, t64, ccol64;
+        shl.b64 t64, t64, 3;  // element index -> bytes (8 per element)
+        add.u64 c_thr_base, c_ptr, t64;
+    }
+
+% for mt in range(m_tiles):
+% for mg in range(m_groups):
+% if pm_runtime(mt, mg):
+    .reg .pred pm_${mt}_${mg};
+    {
+        .reg .u32 crow;
+        add.u32 crow, r_div4, ${tile_m * mt + 8 * mg};
+        setp.lt.u32 pm_${mt}_${mg}, crow, ${m};  // true when this thread's row is below m
+    }
+% endif
+% endfor
+% endfor
+
+% for nt in range(nn):
+% for mt in range(m_tiles):
+% if beta_zero:
+% for ci in range(c_regs):
+    mov.${pftype} c_${nt}_${mt}_${ci}, ${fzero};
+% endfor
+% else:
+% for mg in range(m_groups):
+<%
+    pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None
+    c0 = f'c_{nt}_{mt}_{2*mg}'
+    c1 = f'c_{nt}_{mt}_{2*mg + 1}'
+    cpair = f'{c0}, {c1}'  # NOTE(review): the v2 access has no pvalid_c*col guard; presumably only used when n_col_aligned - confirm
+%>
+    {
+        .reg .u64 caddr;
+        add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride};
+% if pm is not None:
+        mov.${pftype} ${c0}, ${fzero};
+        mov.${pftype} ${c1}, ${fzero};
+% endif
+        ${pred_emit(f'ld.weak.global.cg.v2.{pftype} {{{cpair}}}, [caddr];', pm, pred_reg=f'p01_{nt}_{mt}_{mg}')}
+    }
+% endfor
+% endif
+% endfor
+% endfor
+
+% for ki in range(k_tiles):
+% for nt in range(nn):
+% for kg in range(k_groups):
+<%
+    pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None
+    k_tail = (k_rem != 0 and loop.parent.parent.last)  # loop.parent.parent is the ki loop: last k tile with k_rem != 0
+    needs_zero = pvb is not None or k_tail
+    pbrow = f'pbrow_{kg}' if k_tail else None
+%>
+    {
+        .reg .u64 baddr;
+        add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + kg * b_kgroup_stride + nt * b_ntile_stride};
+% if needs_zero:
+        mov.${pftype} b_frag_${nt}_${kg}, ${fzero};
+% endif
+% if k_tail:
+        .reg .pred ${pbrow};
+        {
+            .reg .u32 brow;
+            add.u32 brow, r_mod4, ${tile_k * ki + 4 * kg};
+            setp.lt.u32 ${pbrow}, brow, ${k};  // true when this thread's B row is below k
+        }
+% endif
+        ${pred_emit(f'ld.weak.global.cg.{pftype} 
b_frag_{nt}_{kg}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}_{kg}')}
+    }
+% endfor
+% endfor
+% for mt in range(m_tiles):
+% for ai in range(a_regs):
+    ld.weak.global.${pftype} a_frag_${ai}, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes + 32 * ai * dwidth_i}];  // fragment (mt, ki); registers are 32 elements apart
+% endfor
+% for nt in range(nn):
+    mma.sync.aligned.${ptx_mma_shape}.row.col.${pftype}.${pftype}.${pftype}.${pftype}
+        ${reg_list(f'c_{nt}_{mt}', c_regs)},
+        ${reg_list('a_frag', a_regs)},
+        ${reg_list(f'b_frag_{nt}', b_regs)},
+        ${reg_list(f'c_{nt}_{mt}', c_regs)};
+% endfor
+% endfor
+% endfor
+
+% for mt in range(m_tiles):
+% for nt in range(nn):
+% for mg in range(m_groups):
+<%
+    pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None
+    c0 = f'c_{nt}_{mt}_{2*mg}'
+    c1 = f'c_{nt}_{mt}_{2*mg + 1}'
+    cpair = f'{c0}, {c1}'
+%>
+    {
+        .reg .u64 caddr;
+        add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride};
+        ${pred_emit(f'st.weak.global.v2.{pftype} [caddr], {{{cpair}}};', pm, pred_reg=f'p01s_{nt}_{mt}_{mg}')}
+    }
+% endfor
+% endfor
+% endfor
+
+$L_EXIT:
+    ret;
+}
diff --git a/gimmik/kernels/ptx/dmma-astream.mako b/gimmik/kernels/ptx/dmma-astream.mako
new file mode 100644
index 0000000..4ab05a9
--- /dev/null
+++ b/gimmik/kernels/ptx/dmma-astream.mako
@@ -0,0 +1,208 @@
+<%inherit file='base'/>
+/*
+    dmma-astream-v1
+
+    Dense FP64 kernel using configurable warp-level DMMA tiles. The tiles of A
+    are precomputed and put in global memory within this compilation unit. Then
+    tiles of B and A are streamed from global into registers. This kernel uses
+    scalar loads/stores for C.
+ */
+
+.global .align 16 .b64 ${kname}_Ag[${a_elems}] = {
+    ${', '.join(a_u64)}
+};
+
+.visible .entry ${kname}(.param .u64 _b,
+                         .param .u64 _c)
+{
+    .reg .u32 tid, warp, lane, r_mod4, r_div4;
+    .reg .u64 b_ptr, c_ptr;
+    .reg .u32 warp_n_base;
+    .reg .u64 ag_thr_base, b_thr_base, c_thr_base;
+    .reg .pred pwarp_exit;
+    .reg .${pftype} a_frag_<${a_regs}>;
+% for nt in range(nn):
+    .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt};
+% if not n_col_aligned:
+    .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt};
+% endif
+    .reg .${pftype} b_frag_${nt}_<${b_regs}>;
+% for mt in range(m_tiles):
+    .reg .${pftype} c_${nt}_${mt}_<${c_regs}>;
+% endfor
+% endfor
+
+    ld.param.u64 b_ptr, [_b];
+    ld.param.u64 c_ptr, [_c];
+    cvta.to.global.u64 b_ptr, b_ptr;  // generic -> global address
+    cvta.to.global.u64 c_ptr, c_ptr;
+
+    mov.u32 tid, %tid.x;
+    shr.u32 warp, tid, 5;  // warp = tid / 32
+    and.b32 lane, tid, 31;  // lane = tid % 32
+    shr.u32 r_div4, lane, 2;  // lane / 4
+    and.b32 r_mod4, lane, 3;  // lane % 4
+
+    {
+        .reg .u32 cta;
+        mov.u32 cta, %ctaid.x;
+        mul.lo.u32 cta, cta, ${n_per_cta};  // first column handled by this CTA
+        mul.lo.u32 warp_n_base, warp, ${n_per_warp};
+        add.u32 warp_n_base, warp_n_base, cta;  // first column handled by this warp
+    }
+    setp.ge.u32 pwarp_exit, warp_n_base, ${n};
+    @pwarp_exit bra $L_EXIT;
+
+% for nt in range(nn):
+    add.u32 b_col_${nt}, warp_n_base, ${tile_n * nt};
+    add.u32 b_col_${nt}, b_col_${nt}, r_div4;
+    {
+        .reg .u32 t;
+        shl.b32 t, r_mod4, 1;  // 2 * r_mod4
+        add.u32 c_col0_${nt}, warp_n_base, ${tile_n * nt};
+        add.u32 c_col0_${nt}, c_col0_${nt}, t;
+        add.u32 c_col1_${nt}, c_col0_${nt}, 1;  // the adjacent column
+    }
+% if not n_col_aligned:
+    setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n};
+    setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n};
+    setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n};
+% endif
+% endfor
+
+    // A thread base: &Ag[0] + lane*sizeof(f64)
+    {
+        .reg .u64 t64, a_glb_base, lane64;
+        mov.u64 a_glb_base, ${kname}_Ag;
+        cvta.to.global.u64 a_glb_base, a_glb_base;
+        cvt.u64.u32 lane64, lane;
+        shl.b64 t64, lane64, 3;  // lane * 8 bytes
+        add.u64 ag_thr_base, a_glb_base, t64;
+    }
+
+    {
+        .reg .u64 t64, bcol64;
+        mul.wide.u32 t64, r_mod4, ${ldb};  // r_mod4 * ldb
+        cvt.u64.u32 bcol64, b_col_0;
+        add.u64 t64, t64, bcol64;
+        shl.b64 t64, t64, 3;  // element index -> bytes (8 per element)
+        add.u64 b_thr_base, b_ptr, t64;
+    }
+
+    {
+        .reg .u64 t64, ccol64;
+        mul.wide.u32 t64, r_div4, ${ldc};  // r_div4 * ldc
+        cvt.u64.u32 ccol64, c_col0_0;
+        add.u64 t64, t64, ccol64;
+        shl.b64 t64, t64, 3;  // element index -> bytes (8 per element)
+        add.u64 c_thr_base, c_ptr, t64;
+    }
+
+% for mt in range(m_tiles):
+% for mg in range(m_groups):
+% if pm_runtime(mt, mg):
+    .reg .pred pm_${mt}_${mg};
+    {
+        .reg .u32 crow;
+        add.u32 crow, r_div4, ${tile_m * mt + 8 * mg};
+        setp.lt.u32 pm_${mt}_${mg}, crow, ${m};  // true when this thread's row is below m
+    }
+% endif
+% endfor
+% endfor
+
+% for nt in range(nn):
+% for mt in range(m_tiles):
+% if beta_zero:
+% for ci in range(c_regs):
+    mov.${pftype} c_${nt}_${mt}_${ci}, ${fzero};
+% endfor
+% else:
+% for mg in range(m_groups):
+<%
+    pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None
+    pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None
+    pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None
+    needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None  # zero first when a load may be predicated
+    c0 = f'c_{nt}_{mt}_{2*mg}'
+    c1 = f'c_{nt}_{mt}_{2*mg + 1}'
+%>
+    {
+        .reg .u64 caddr;
+        add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride};
+% if needs_zero_init:
+        mov.${pftype} ${c0}, ${fzero};
+        mov.${pftype} ${c1}, ${fzero};
+% endif
+        ${pred_emit(f'ld.weak.global.cg.{pftype} {c0}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}_{mg}')}
+        ${pred_emit(f'ld.weak.global.cg.{pftype} {c1}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}_{mg}')}
+    }
+% endfor
+% endif
+% endfor
+% endfor
+
+% for ki in range(k_tiles):
+% for nt in range(nn):
+% for kg in range(k_groups):
+<%
+    pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None
+    k_tail = (k_rem != 0 and loop.parent.parent.last)  # loop.parent.parent is the ki loop: last k tile with k_rem != 0
+    needs_zero = pvb is not None or k_tail
+    pbrow = f'pbrow_{kg}' if k_tail else None
+%>
+    {
+        .reg .u64 baddr;
+        add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + kg * b_kgroup_stride + nt * b_ntile_stride};
+% if needs_zero:
+        mov.${pftype} b_frag_${nt}_${kg}, ${fzero};
+% endif
+% if k_tail:
+        .reg .pred ${pbrow};
+        {
+            .reg .u32 brow;
+            add.u32 brow, r_mod4, ${tile_k * ki + 4 * kg};
+            setp.lt.u32 ${pbrow}, brow, ${k};  // true when this thread's B row is below k
+        }
+% endif
+        ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}_{kg}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}_{kg}')}
+    }
+% endfor
+% endfor
+% for mt in range(m_tiles):
+% for ai in range(a_regs):
+    ld.weak.global.${pftype} a_frag_${ai}, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes + 32 * ai * dwidth_i}];  // fragment (mt, ki); registers are 32 elements apart
+% endfor
+% for nt in range(nn):
+    mma.sync.aligned.${ptx_mma_shape}.row.col.${pftype}.${pftype}.${pftype}.${pftype}
+        ${reg_list(f'c_{nt}_{mt}', c_regs)},
+        ${reg_list('a_frag', a_regs)},
+        ${reg_list(f'b_frag_{nt}', b_regs)},
+        ${reg_list(f'c_{nt}_{mt}', c_regs)};
+% endfor
+% endfor
+% endfor
+
+% for mt in range(m_tiles):
+% for nt in range(nn):
+% for mg in range(m_groups):
+<%
+    pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None
+    pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None
+    pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None
+    c0 = f'c_{nt}_{mt}_{2*mg}'
+    c1 = f'c_{nt}_{mt}_{2*mg + 1}'
+%>
+    {
+        .reg .u64 caddr;
+        add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride};
+        ${pred_emit(f'st.weak.global.{pftype} [caddr], {c0};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}_{mg}')}
+        ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], {c1};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}_{mg}')}
+    }
+% endfor
+% endfor
+% endfor
+
+$L_EXIT:
+    ret;
+}
diff --git a/gimmik/kernels/ptx/dmma-steal-ws.mako b/gimmik/kernels/ptx/dmma-steal-ws.mako
new file mode 100644
index 0000000..0e97cd5
--- /dev/null
+++ b/gimmik/kernels/ptx/dmma-steal-ws.mako
@@ -0,0 +1,432 @@
+<%inherit file='base'/>
+
+<%def name="producer_init_setup()">
+    // Producer warp: initial A bulk-copy + B load for ctaid_x's work
+    @!p_prod bra.uni $L_AFTER_INIT_B;
+    {
+        .reg .b32 n_start0;
+        .reg
.u64 a_glb; + mul.lo.u32 n_start0, ctaid_x, ${n_per_cta}; + mov.u64 a_glb, ${kname}_Ag; + cvta.to.global.u64 a_glb, a_glb; + @p_warp_lead cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes + [a_smem], [a_glb], ${8 * 32 * m_tiles * k_tiles}, [tma_mbar]; + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b1_smem], [bdesc_addr, {n_start0, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes + 8 * 32 * m_tiles * k_tiles}; + bar.warp.sync 0xffffffff; + .reg .b64 state; + .reg .pred p1; + mbarrier.arrive.shared::cta.b64 state, [tma_mbar]; +$L_TMA_INIT_W: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], state, ${mbar_maxwait}; + @!p1 bra.uni $L_TMA_INIT_W; + .reg .b64 _state2; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state2, [bready_mbar]; + } +$L_AFTER_INIT_B: + + +<%def name="compute_warp_body()"> + // --- Compute Warps + @!p_compute bra.uni $L_AFTER_COMPUTE; + + // Wait on B + { + .reg .pred p1; +$L_WAIT_BRDY: + mbarrier.try_wait.parity.shared::cta.b64 p1, [bready_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_BRDY; + } + + // MMA + { + .reg .b32 b_sm_a; + .reg .pred p_ph; + setp.ne.u32 p_ph, phase, 0; + selp.b32 b_sm_a, b2_smem, b1_smem, p_ph; + + .reg .b32 a_thr_a; + { + .reg .b32 t; + shl.b32 t, lane, 3; + add.u32 a_thr_a, a_smem, t; + } +% for nt in range(nn): + .reg .b32 b_thr_a_${nt}; + { + .reg .b32 bcol_g, t_off; + add.u32 bcol_g, base_bcol, ${8 * nt}; + shl.b32 t_off, bcol_g, 3; + add.u32 b_thr_a_${nt}, b_sm_a, t_off; + } +% endfor + +% if beta_zero: + // beta=0: skip shared-staging entirely; compute warps store MMA + // outputs straight to global C with N-tail predication. 
+ .reg .u64 c_glob_addr; + ld.param.u64 c_glob_addr, [c_desc]; + cvta.to.global.u64 c_glob_addr, c_glob_addr; +% else: + .reg .b32 c_thr_smem; + { + .reg .b32 t1, ccol_b; + mul.lo.u32 t1, base_crow, ${n_per_cta * dwidth_i}; + shl.b32 ccol_b, base_ccol, 3; + add.u32 c_thr_smem, c_smem, t1; + add.u32 c_thr_smem, c_thr_smem, ccol_b; + } +% endif + + // Zero accumulators +% for mt in range(m_tiles): +% for nt in range(nn): + .reg .${pftype} d_x_${mt}_${nt}, d_y_${mt}_${nt}; + mov.${pftype} d_x_${mt}_${nt}, ${fzero}; + mov.${pftype} d_y_${mt}_${nt}, ${fzero}; +% endfor +% endfor + + .reg .${pftype} a_f; +% for mt in range(m_tiles): +% for kt in range(k_tiles): +<% + k_tail = (k_rem != 0 and loop.last) +%> + { + .reg .b32 a_a; + add.u32 a_a, a_thr_a, ${(32 * kt + 32 * mt * k_tiles) * dwidth_i}; + ld.shared.${pftype} a_f, [a_a]; +% if k_tail: + .reg .pred pbrow_${mt}_${kt}; + { + .reg .b32 brow; + add.u32 brow, base_brow, ${4 * kt}; + setp.lt.u32 pbrow_${mt}_${kt}, brow, ${k}; + } +% endif +% for nt in range(nn): + { + .reg .b32 b_a, b_row; + .reg .${pftype} b_f; + add.u32 b_row, base_brow, ${4 * kt}; + mul.lo.u32 b_row, b_row, ${n_per_cta * dwidth_i}; + add.u32 b_a, b_thr_a_${nt}, b_row; +% if k_tail: + mov.${pftype} b_f, ${fzero}; + @pbrow_${mt}_${kt} ld.shared.${pftype} b_f, [b_a]; +% else: + ld.shared.${pftype} b_f, [b_a]; +% endif + mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} + {d_x_${mt}_${nt}, d_y_${mt}_${nt}}, {a_f}, {b_f}, + {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor + } +% endfor +% endfor + +% if beta_zero: + .reg .u64 c_thr_glob_base; + { + .reg .u32 thr_col_off, thr_addr_off_lo; + add.u32 thr_col_off, base_ccol, n_start_curr; + mad.lo.u32 thr_addr_off_lo, base_crow, ${ldc}, thr_col_off; + .reg .u64 thr_byte_off; + mul.wide.u32 thr_byte_off, thr_addr_off_lo, ${dwidth_i}; + add.u64 c_thr_glob_base, c_glob_addr, thr_byte_off; + } +% for mt in range(m_tiles): +<% + row_tail = pm_runtime(mt) +%> +% if row_tail: + .reg .pred 
p_row_${mt}; + { + .reg .b32 crow; + add.u32 crow, base_crow, ${8 * mt}; + setp.lt.u32 p_row_${mt}, crow, ${m}; + } +% endif +% for nt in range(nn): + { + .reg .pred p_st; + .reg .u32 g_ccol; + add.u32 g_ccol, base_ccol, ${8 * nt}; + add.u32 g_ccol, g_ccol, n_start_curr; + setp.lt.u32 p_st, g_ccol, ${n}; +% if row_tail: + and.pred p_st, p_st, p_row_${mt}; +% endif + .reg .u64 _c_addr; + add.u64 _c_addr, c_thr_glob_base, ${(8 * mt * ldc + 8 * nt) * dwidth_i}; + @p_st st.weak.global.v2.${pftype} [_c_addr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor +% endfor +% else: + // Wait until producer's prev-iter TMA-store of C has drained. + { + .reg .pred p1; +$L_WAIT_CSTORE: + mbarrier.try_wait.parity.shared::cta.b64 p1, [cstored_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_CSTORE; + } + + // Vector-store {d_x, d_y} pairs to csmem. M-tail / N-tail OOB rows + // are dropped by the C tensor map. +% for mt in range(m_tiles): +% for nt in range(nn): + { + .reg .b32 csaddr; + add.u32 csaddr, c_thr_smem, ${mt * c_mtile_smem_stride + nt * c_ntile_smem_stride}; + st.shared.v2.${pftype} [csaddr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor +% endfor +% endif + +% if not beta_zero: + bar.sync 1, ${comp_threads}; + fence.proxy.async.shared::cta; + { + .reg .b64 _state; + @p_tid0 mbarrier.arrive.shared::cta.b64 _state, [cready_mbar]; + } +% endif + + // Wait for new work and unpack + { + .reg .pred p1, p_canc; + .reg .b128 resp; +$L_WAIT_WNEW_C: + mbarrier.try_wait.parity.shared::cta.b64 p1, [wid_new_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_WNEW_C; + + ld.shared::cta.b128 resp, [wid_smem]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + @p_canc clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 block_idx_x, resp; + selp.b32 work, 1, 0, p_canc; + + .reg .b64 _state; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_used_mbar]; + } + } +$L_AFTER_COMPUTE: + + +<%def name="data_warp_body()"> + // --- Data 
Movement Warp + @!p_prod bra.uni $L_AFTER_DATA; + { + .reg .b32 n_c_store; + mul.lo.u32 n_c_store, block_idx_x, ${n_per_cta}; + + // Wait for new work and unpack + { + .reg .pred p1, p_canc; + .reg .b128 resp; +$L_WAIT_WNEW_D: + mbarrier.try_wait.parity.shared::cta.b64 p1, [wid_new_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_WNEW_D; + + ld.shared::cta.b128 resp, [wid_smem]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + @p_canc clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 block_idx_x, resp; + selp.b32 work, 1, 0, p_canc; + .reg .b64 _state; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_used_mbar]; + } + + // TMA loads of next B + { + mul.lo.u32 n_start_next, block_idx_x, ${n_per_cta}; + .reg .b32 b_next; + .reg .pred p_ph; + setp.ne.u32 p_ph, phase, 0; + selp.b32 b_next, b1_smem, b2_smem, p_ph; + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b_next], [bdesc_addr, {n_start_next, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes}; + @p_warp_lead cp.async.bulk.commit_group; + } + bar.warp.sync 0xffffffff; + +% if not beta_zero: + // TMA reduce+store of C (beta=1 only; beta=0 uses direct global + // stores from compute warps, so the producer does no C work). 
+ { + .reg .pred p1; + .reg .b64 _c_state; +$L_WAIT_CRDY: + mbarrier.try_wait.parity.shared::cta.b64 p1, [cready_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_CRDY; + @p_warp_lead cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.tile.bulk_group + [cdesc_addr, {n_c_store, 0}], [c_smem]; + @p_warp_lead cp.async.bulk.commit_group; + @p_warp_lead cp.async.bulk.wait_group 0; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _c_state, [cstored_mbar]; + } +% endif + + // Wait for next B to be ready, then signal B and C ready + { + .reg .b64 b_state, _bready_state, _c_state; + .reg .pred p1; + mbarrier.arrive.shared::cta.b64 b_state, [tma_mbar]; +$L_WAIT_TMA: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], b_state, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_TMA; + + @p_warp_lead mbarrier.arrive.shared::cta.b64 _bready_state, [bready_mbar]; + } + } +$L_AFTER_DATA: + + +<%def name="ctrl_warp_body()"> + // --- Controller Warp + @!p_steal bra.uni $L_AFTER_CTRL; + { + .reg .pred p1, p2, p_canc; + .reg .b64 _state; + .reg .b128 resp; + @p_warp_lead fence.proxy.async.shared::cta; + @p_warp_lead clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 + [wid_smem], [steal_mbar]; + @p_warp_lead mbarrier.arrive.expect_tx.shared::cta.b64 + _state, [steal_mbar], 16; + +$L_WAIT_STEAL: + mbarrier.try_wait.parity.shared::cta.b64 p1, [steal_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_STEAL; + + // Signal new work + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_new_mbar]; + + // Query if there's new work + ld.shared::cta.b128 resp, [wid_smem]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + selp.b32 work, 1, 0, p_canc; + + // Wait for old work to be used +$L_WAIT_WUSED: + mbarrier.try_wait.parity.shared::cta.b64 p2, [wid_used_mbar], phase, ${mbar_maxwait}; + @!p2 bra.uni $L_WAIT_WUSED; + } +$L_AFTER_CTRL: + + +.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { + ${', '.join(a_u64)} +}; 
+.extern .shared .align 128 .b8 ${kname}_dynm[]; + +.visible .entry ${kname}(.param .u64 b_desc, + .param .u64 c_desc) +.maxntid ${blockx_total}, 1, 1 +{ + .reg .b32 tid, warp, lane, phase, ctaid_x; + .reg .b32 base_brow, base_bcol, base_crow, base_ccol; + .reg .b32 work, block_idx_x, n_start_curr, n_start_next; + .reg .u64 bdesc_addr, cdesc_addr; + .reg .b32 a_smem, b1_smem, b2_smem, c_smem; + .reg .b32 tma_mbar, wid_new_mbar, bready_mbar, cready_mbar, cstored_mbar, steal_mbar; + .reg .b32 wid_used_mbar, wid_smem; + .reg .pred p_compute, p_prod, p_steal; + .reg .pred p_warp_lead; + .reg .pred p_done; + .reg .pred p_tid0; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + mov.u32 ctaid_x, %ctaid.x; + + .reg .b32 dynm_base; + mov.u32 dynm_base, ${kname}_dynm; + add.u32 b1_smem, dynm_base, ${b1_off}; + add.u32 b2_smem, dynm_base, ${b2_off}; + add.u32 c_smem, dynm_base, ${c_off}; + add.u32 a_smem, dynm_base, ${a_off}; + add.u32 wid_smem, dynm_base, ${wid_off}; + + add.u32 tma_mbar, dynm_base, ${tma_mbar_off}; + add.u32 bready_mbar, dynm_base, ${bready_mbar_off}; + add.u32 cready_mbar, dynm_base, ${cready_mbar_off}; + add.u32 cstored_mbar, dynm_base, ${cstored_mbar_off}; + add.u32 steal_mbar, dynm_base, ${steal_mbar_off}; + add.u32 wid_new_mbar, dynm_base, ${wid_new_mbar_off}; + add.u32 wid_used_mbar, dynm_base, ${wid_used_mbar_off}; + + ld.param.u64 bdesc_addr, [b_desc]; + ld.param.u64 cdesc_addr, [c_desc]; + + setp.eq.u32 p_tid0, tid, 0; + + setp.lt.u32 p_compute, warp, ${n_comp_warps}; + setp.eq.u32 p_prod, warp, ${prod_warp}; + setp.eq.u32 p_steal, warp, ${steal_warp}; + + { + .reg .b32 _elect_lane; + elect.sync _elect_lane|p_warp_lead, 0xffffffff; + } + + // mbarrier init (tid 0 only); pre-arrive csmem_free so compute iter 0 + // can write csmem immediately. 
+ { + .reg .pred p_init; + setp.eq.u32 p_init, tid, 0; + .reg .b64 _state; + @p_init mbarrier.init.shared::cta.b64 [tma_mbar], 32; + @p_init mbarrier.init.shared::cta.b64 [bready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cstored_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [steal_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [wid_used_mbar], ${n_comp_warps + 1}; + @p_init mbarrier.init.shared::cta.b64 [wid_new_mbar], 1; + @p_init mbarrier.arrive.shared::cta.b64 _state, [cstored_mbar]; + @p_init fence.proxy.async.shared::cta; + } + bar.sync 0; + + // Compute-warp lane geometry + { + .reg .b32 t, w_n_base; + and.b32 base_brow, lane, 3; + shr.u32 base_crow, lane, 2; + mul.lo.u32 w_n_base, warp, ${n_per_warp}; + add.u32 base_bcol, base_crow, w_n_base; + shl.b32 t, base_brow, 1; + add.u32 base_ccol, t, w_n_base; + } + + ${producer_init_setup()} + + mov.u32 block_idx_x, ctaid_x; + mov.u32 work, 1; + mov.u32 phase, 0; + +$L_LOOP: + setp.eq.u32 p_done, work, 0; + @p_done bra.uni $L_EXIT; + + mul.lo.u32 n_start_curr, block_idx_x, ${n_per_cta}; + + ${compute_warp_body()} + + ${data_warp_body()} + + ${ctrl_warp_body()} + + xor.b32 phase, phase, 1; + bra.uni $L_LOOP; + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dmma-stride-ws.mako b/gimmik/kernels/ptx/dmma-stride-ws.mako new file mode 100644 index 0000000..87b87d6 --- /dev/null +++ b/gimmik/kernels/ptx/dmma-stride-ws.mako @@ -0,0 +1,401 @@ +<%inherit file='base'/> + +<%def name="producer_init_setup()"> + // Producer warp: initial A bulk-copy + first B load + @!p_prod bra.uni $L_AFTER_INIT_B; + { + .reg .b32 n_start0; + .reg .u64 a_glb; + mul.lo.u32 n_start0, ctaid_x, ${n_per_cta}; + mov.u64 a_glb, ${kname}_Ag; + cvta.to.global.u64 a_glb, a_glb; + @p_warp_lead cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes + [a_smem], [a_glb], ${a_elems * dwidth_i}, [tma_mbar]; + @p_warp_lead 
cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b1_smem], [bdesc_addr, {n_start0, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes + a_elems * dwidth_i}; + bar.warp.sync 0xffffffff; + .reg .b64 state; + .reg .pred p1; + mbarrier.arrive.shared::cta.b64 state, [tma_mbar]; +$L_TMA_INIT_W: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], state, ${mbar_maxwait}; + @!p1 bra.uni $L_TMA_INIT_W; + .reg .b64 _state2; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state2, [bready_mbar]; + } +$L_AFTER_INIT_B: + + +<%def name="compute_warp_body()"> + // --- Compute Warps + @!p_compute bra.uni $L_AFTER_COMPUTE; + + // Wait on the current B tile. + { + .reg .pred p1; +$L_WAIT_BRDY: + mbarrier.try_wait.parity.shared::cta.b64 p1, [bready_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_BRDY; + } + + // MMA + { + .reg .b32 b_sm_a; + .reg .pred p_ph; + setp.ne.u32 p_ph, phase, 0; + selp.b32 b_sm_a, b2_smem, b1_smem, p_ph; + + .reg .b32 a_thr_a; + { + .reg .b32 t; + shl.b32 t, lane, 3; + add.u32 a_thr_a, a_smem, t; + } +% for nt in range(nn): + .reg .b32 b_thr_a_${nt}; + { + .reg .b32 bcol_g, t_off; + add.u32 bcol_g, base_bcol, ${tile_n * nt}; + shl.b32 t_off, bcol_g, 3; + add.u32 b_thr_a_${nt}, b_sm_a, t_off; + } +% endfor + +% if beta_zero: + // beta=0: skip shared-staging entirely; compute warps store MMA + // outputs straight to global C with N-tail predication. 
+ .reg .u64 c_glob_addr; + ld.param.u64 c_glob_addr, [c_desc]; + cvta.to.global.u64 c_glob_addr, c_glob_addr; +% else: + .reg .b32 c_thr_smem; + { + .reg .b32 t1, ccol_b; + mul.lo.u32 t1, base_crow, ${n_per_cta * dwidth_i}; + shl.b32 ccol_b, base_ccol, 3; + add.u32 c_thr_smem, c_smem, t1; + add.u32 c_thr_smem, c_thr_smem, ccol_b; + } +% endif + + // Zero accumulators +% for mt in range(m_tiles): +% for nt in range(nn): + .reg .${pftype} d_${mt}_${nt}_<${c_regs}>; +% for ci in range(c_regs): + mov.${pftype} d_${mt}_${nt}_${ci}, ${fzero}; +% endfor +% endfor +% endfor + + .reg .${pftype} a_frag_<${a_regs}>; +% for nt in range(nn): + .reg .${pftype} b_frag_${nt}_<${b_regs}>; +% endfor +% for kt in range(k_tiles): +% for nt in range(nn): +% for kg in range(k_groups): +<% + k_tail = (k_rem != 0 and loop.parent.parent.last) + pbrow = f'pbrow_{nt}_{kg}' if k_tail else None +%> + { + .reg .b32 b_a, b_row, b_off; + add.u32 b_row, base_brow, ${tile_k * kt + 4 * kg}; + mul.lo.u32 b_off, b_row, ${n_per_cta * dwidth_i}; + add.u32 b_a, b_thr_a_${nt}, b_off; +% if k_tail: + .reg .pred ${pbrow}; + setp.lt.u32 ${pbrow}, b_row, ${k}; + mov.${pftype} b_frag_${nt}_${kg}, ${fzero}; + @${pbrow} ld.shared.${pftype} b_frag_${nt}_${kg}, [b_a]; +% else: + ld.shared.${pftype} b_frag_${nt}_${kg}, [b_a]; +% endif + } +% endfor +% endfor +% for mt in range(m_tiles): +% for ai in range(a_regs): + ld.shared.${pftype} a_frag_${ai}, [a_thr_a + ${(mt * k_tiles + kt) * frag_stride_bytes + 32 * ai * dwidth_i}]; +% endfor +% for nt in range(nn): + mma.sync.aligned.${ptx_mma_shape}.row.col.${pftype}.${pftype}.${pftype}.${pftype} + ${reg_list(f'd_{mt}_{nt}', c_regs)}, + ${reg_list('a_frag', a_regs)}, + ${reg_list(f'b_frag_{nt}', b_regs)}, + ${reg_list(f'd_{mt}_{nt}', c_regs)}; +% endfor +% endfor +% endfor + +% if beta_zero: + .reg .u64 c_thr_glob_base; + { + .reg .u32 thr_col_off, thr_addr_off_lo; + add.u32 thr_col_off, base_ccol, n_start_curr; + mad.lo.u32 thr_addr_off_lo, base_crow, ${ldc}, 
thr_col_off; + .reg .u64 thr_byte_off; + mul.wide.u32 thr_byte_off, thr_addr_off_lo, ${dwidth_i}; + add.u64 c_thr_glob_base, c_glob_addr, thr_byte_off; + } +% for mt in range(m_tiles): +% for mg in range(m_groups): +<% + row_tail = pm_runtime(mt, mg) +%> +% if row_tail: + .reg .pred p_row_${mt}_${mg}; + { + .reg .b32 crow; + add.u32 crow, base_crow, ${tile_m * mt + 8 * mg}; + setp.lt.u32 p_row_${mt}_${mg}, crow, ${m}; + } +% endif +% for nt in range(nn): + { + .reg .pred p_st; + .reg .u32 g_ccol; + add.u32 g_ccol, base_ccol, ${tile_n * nt}; + add.u32 g_ccol, g_ccol, n_start_curr; + setp.lt.u32 p_st, g_ccol, ${n}; +% if row_tail: + and.pred p_st, p_st, p_row_${mt}_${mg}; +% endif + .reg .u64 _c_addr; + add.u64 _c_addr, c_thr_glob_base, ${((tile_m * mt + 8 * mg) * ldc + tile_n * nt) * dwidth_i}; + @p_st st.weak.global.v2.${pftype} [_c_addr], {d_${mt}_${nt}_${2*mg}, d_${mt}_${nt}_${2*mg + 1}}; + } +% endfor +% endfor +% endfor +% else: + // Wait until producer's prev-iter TMA-store of C has drained. + { + .reg .pred p1; +$L_WAIT_CSTORE: + mbarrier.try_wait.parity.shared::cta.b64 p1, [cstored_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_CSTORE; + } + + // Vector-store accumulator pairs to csmem. M-tail / N-tail OOB rows + // are dropped by the C tensor map. +% for mt in range(m_tiles): +% for mg in range(m_groups): +% for nt in range(nn): + { + .reg .b32 csaddr; + add.u32 csaddr, c_thr_smem, ${mt * c_mtile_smem_stride + mg * c_mgroup_smem_stride + nt * c_ntile_smem_stride}; + st.shared.v2.${pftype} [csaddr], {d_${mt}_${nt}_${2*mg}, d_${mt}_${nt}_${2*mg + 1}}; + } +% endfor +% endfor +% endfor +% endif + +% if not beta_zero: + bar.sync 1, ${comp_threads}; + fence.proxy.async.shared::cta; + { + .reg .b64 _state; + @p_tid0 mbarrier.arrive.shared::cta.b64 _state, [cready_mbar]; + } +% endif + + // Match the stealing kernel's "work used" point: retire the current + // phase only after compute-side work for the tile is complete. 
+ { + .reg .b64 _bconsumed_state; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _bconsumed_state, [bconsumed_mbar]; + } + + } +$L_AFTER_COMPUTE: + + +<%def name="data_warp_body()"> + // --- Data Movement Warp + @!p_prod bra.uni $L_AFTER_DATA; + { + .reg .b32 n_c_store; + mul.lo.u32 n_c_store, block_idx_x, ${n_per_cta}; + + .reg .pred p_next_work; + setp.ne.u32 p_next_work, next_work, 0; + + // Issue the next B load into the alternate buffer before retiring + // the current phase. + .reg .b64 b_state; + @!p_next_work bra.uni $L_SKIP_NEXT_B_ISSUE; + { + mul.lo.u32 n_start_next, next_block_idx_x, ${n_per_cta}; + .reg .b32 b_next; + .reg .pred p_ph; + setp.ne.u32 p_ph, phase, 0; + selp.b32 b_next, b1_smem, b2_smem, p_ph; + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b_next], [bdesc_addr, {n_start_next, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes}; + @p_warp_lead cp.async.bulk.commit_group; + bar.warp.sync 0xffffffff; + mbarrier.arrive.shared::cta.b64 b_state, [tma_mbar]; + } +$L_SKIP_NEXT_B_ISSUE: + +% if not beta_zero: + // TMA reduce+store of C (beta=1 only; beta=0 uses direct global + // stores from compute warps, so the producer does no C work). + { + .reg .pred p1; + .reg .b64 _c_state; +$L_WAIT_CRDY: + mbarrier.try_wait.parity.shared::cta.b64 p1, [cready_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_CRDY; + @p_warp_lead cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.tile.bulk_group + [cdesc_addr, {n_c_store, 0}], [c_smem]; + @p_warp_lead cp.async.bulk.commit_group; + @p_warp_lead cp.async.bulk.wait_group 0; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _c_state, [cstored_mbar]; + } +% endif + + // Wait for the next B load to complete, then mark it ready for the + // next compute iteration. 
+ @!p_next_work bra.uni $L_SKIP_NEXT_B_READY; + { + .reg .b64 _bready_state; + .reg .pred p1; +$L_WAIT_TMA: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], b_state, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_TMA; + +$L_WAIT_BCONSUMED: + mbarrier.try_wait.parity.shared::cta.b64 p1, [bconsumed_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_BCONSUMED; + + @p_warp_lead mbarrier.arrive.shared::cta.b64 _bready_state, [bready_mbar]; + } +$L_SKIP_NEXT_B_READY: + } +$L_AFTER_DATA: + + +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { + ${', '.join(a_u64)} +}; +.extern .shared .align 128 .b8 ${kname}_dynm[]; + +.visible .entry ${kname}(.param .u64 b_desc, + .param .u64 c_desc) +.maxntid ${blockx_total}, 1, 1 +{ + .reg .b32 tid, warp, lane, phase, ctaid_x; + .reg .b32 base_brow, base_bcol, base_crow, base_ccol; + .reg .b32 work, next_work, iter, next_iter; + .reg .b32 block_idx_x, next_block_idx_x, n_start_curr, n_start_next; + .reg .u64 bdesc_addr, cdesc_addr; + .reg .b32 a_smem, b1_smem, b2_smem, c_smem; + .reg .b32 tma_mbar, bready_mbar, bconsumed_mbar, cready_mbar, cstored_mbar; + .reg .pred p_compute, p_prod; + .reg .pred p_warp_lead; + .reg .pred p_done; + .reg .pred p_tid0; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + mov.u32 ctaid_x, %ctaid.x; + + .reg .b32 dynm_base; + mov.u32 dynm_base, ${kname}_dynm; + add.u32 b1_smem, dynm_base, ${b1_off}; + add.u32 b2_smem, dynm_base, ${b2_off}; + add.u32 c_smem, dynm_base, ${c_off}; + add.u32 a_smem, dynm_base, ${a_off}; + + add.u32 tma_mbar, dynm_base, ${tma_mbar_off}; + add.u32 bready_mbar, dynm_base, ${bready_mbar_off}; + add.u32 bconsumed_mbar, dynm_base, ${bconsumed_mbar_off}; + add.u32 cready_mbar, dynm_base, ${cready_mbar_off}; + add.u32 cstored_mbar, dynm_base, ${cstored_mbar_off}; + + ld.param.u64 bdesc_addr, [b_desc]; + ld.param.u64 cdesc_addr, [c_desc]; + + setp.eq.u32 p_tid0, tid, 0; + + setp.lt.u32 p_compute, warp, ${n_comp_warps}; + setp.eq.u32 p_prod, warp, ${prod_warp}; + 
+ { + .reg .b32 _elect_lane; + elect.sync _elect_lane|p_warp_lead, 0xffffffff; + } + + // mbarrier init (tid 0 only); pre-arrive cstored so compute iter 0 + // can write csmem immediately when beta != 0. + { + .reg .pred p_init; + setp.eq.u32 p_init, tid, 0; + .reg .b64 _state; + @p_init mbarrier.init.shared::cta.b64 [tma_mbar], 32; + @p_init mbarrier.init.shared::cta.b64 [bready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [bconsumed_mbar], ${n_comp_warps}; + @p_init mbarrier.init.shared::cta.b64 [cready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cstored_mbar], 1; + @p_init mbarrier.arrive.shared::cta.b64 _state, [cstored_mbar]; + @p_init fence.proxy.async.shared::cta; + } + bar.sync 0; + + // Compute-warp lane geometry + { + .reg .b32 t, w_n_base; + and.b32 base_brow, lane, 3; + shr.u32 base_crow, lane, 2; + mul.lo.u32 w_n_base, warp, ${n_per_warp}; + add.u32 base_bcol, base_crow, w_n_base; + shl.b32 t, base_brow, 1; + add.u32 base_ccol, t, w_n_base; + } + + ${producer_init_setup()} + + mov.u32 block_idx_x, ctaid_x; + mov.u32 work, 1; + mov.u32 iter, 0; + mov.u32 phase, 0; + + // Bounded grid-stride loop: block_idx_x = ctaid.x + iter*grid_stride. 
+$L_LOOP: + setp.eq.u32 p_done, work, 0; + @p_done bra.uni $L_EXIT; + + mul.lo.u32 n_start_curr, block_idx_x, ${n_per_cta}; + + { + .reg .pred p_iter, p_block, p_next; + add.u32 next_iter, iter, 1; + add.u32 next_block_idx_x, block_idx_x, ${grid_stride}; + setp.lt.u32 p_iter, next_iter, ${stride_iters}; + setp.lt.u32 p_block, next_block_idx_x, ${work_blocks}; + and.pred p_next, p_iter, p_block; + selp.b32 next_work, 1, 0, p_next; + } + + ${compute_warp_body()} + + ${data_warp_body()} + + mov.u32 block_idx_x, next_block_idx_x; + mov.u32 work, next_work; + mov.u32 iter, next_iter; + xor.b32 phase, phase, 1; + bra.uni $L_LOOP; + +$L_EXIT: + ret; +} diff --git a/gimmik/ptx.py b/gimmik/ptx.py new file mode 100644 index 0000000..69e01c6 --- /dev/null +++ b/gimmik/ptx.py @@ -0,0 +1,423 @@ +import numpy as np + +from gimmik.base import MatMul + + +class PTXMatMul(MatMul): + platform = 'ptx' + basemeta = { + 'block': (128, 1, 1), + 'width': 1, + 'shared': 0, + 'dynamic_shared': 0 + } + + # Map explicitly supported CC to minimum PTX version + PTX_SM = {(8, 0): (7, 0), (9, 0): (8, 6), (10, 0): (8, 7), (10, 3): (8, 7), + (12, 0): (8, 7), (12, 1): (8, 7)} + + FZERO = {'float': '0f00000000', 'double': '0d0000000000000000'} + PFTYPE = {'float': 'f32', 'double': 'f64'} + + @classmethod + def is_sparse_suitable(cls, arr, cc): + cc = cc or (0, 0) + nnz = np.count_nonzero(arr) + nuq = len(np.unique(np.abs(arr))) + density = nnz / arr.size + return ((nuq <= 28) or (density <= 0.15)) and cc >= (7, 0) + + @classmethod + def is_dense_suitable(cls, arr, cc): + cc_appropriate = cc in cls.PTX_SM and cc >= (9, 0) + return (arr.dtype == np.float64 and cc_appropriate + and arr.shape[0] <= 128 and arr.shape[1] <= 128) + + @classmethod + def is_suitable(cls, arr, cc): + return cls.is_sparse_suitable(arr, cc) or cls.is_dense_suitable(arr, cc) + + def _kernel_generators(self, dtype, dsize, *, compute_capability=None, + smem_info=None): + cc = compute_capability or (0, 0) + smem_info = smem_info or 
(48*1024, 48*1024) + config = self._platform_config(dtype, cc) + + # When we know the PTX version but there isn't an SM specific config, + # we can overide the default PTX version + if cc in self.PTX_SM: + target_cc = cc + ptx = self.PTX_SM[cc] + else: + target_cc = tuple(config['cc']) + ptx = tuple(config['ptx']) + + cfgs = config['kernels'] + cfg = [k for k in cfgs if self._usable_config(k, dtype, cc, smem_info)] + + for k in cfg: + if prepared := self._get_render_args( + k, dtype, dsize, target_cc, smem_info, ptx + ): + yield prepared + + def render_config(self, kernel_cfg, dtype, dsize, cc, ptx, *, + kname='gimmik_mm', smem_info=None): + smem_info = smem_info or (48*1024, 48*1024) + + if not self._usable_config(kernel_cfg, dtype, cc, smem_info): + return None + + prepared = self._get_render_args( + kernel_cfg, dtype, dsize, cc, smem_info, ptx + ) + if prepared is None: + return None + + tpl, exargs, exmeta = prepared + + args = self._base_template_args(dtype, kname) | exargs + meta = self.basemeta | exmeta + meta['tplname'] = tpl + self._process_meta(meta) + src = self._render_kernel(dtype, tpl, args) + return src, args, meta + + def _get_render_args(self, kernel_cfg, dtype, dsize, cc, smem_info, ptx): + tpl = kernel_cfg['template'] + family = kernel_cfg['family'] + block = tuple(kernel_cfg['block']) + width = kernel_cfg['width'] + params = kernel_cfg.get('params', {}) + base_args = { + 'ptx': ptx, + 'cc': cc, + 'smem_info': smem_info, + 'pred_emit': self._pred_emit, + 'pftype': self.PFTYPE[dtype], + 'dwidth_i': dsize, + 'fzero': self.FZERO[dtype], + 'beta_zero': self.beta == 0, + 'mbar_maxwait': hex(10000000), + 'use_cpasync': cc >= (8, 0), + 'width': width, + 'reg_list': self._reg_list, + } + base_meta = { + 'block': block, + 'width': width, + 'desc': kernel_cfg['descriptor'], + } + + match family: + case 'sparse': + cfg = self._sparse_args(tpl, params, block, dtype, dsize, + base_args, base_meta) + case 'dense': + cfg = self._dense_args(kernel_cfg, params, 
cc, smem_info, + base_args, base_meta) + case 'dense-ws': + cfg = self._dense_ws_args(kernel_cfg, params, cc, smem_info, + base_args, base_meta) + case _: + raise ValueError(f'Unknown PTX template family for {tpl}') + + return cfg + + def _sparse_args(self, tpl, params, block, dtype, dsize, args, meta): + blockx = block[0] + args |= {'has_zero_rows': bool(self.has_zero_rows), + 'row_nz': [[(kx, self.A[j, kx]) for kx in range(self.k) + if self.A[j, kx] != 0] for j in range(self.m)], + 'preload_c': bool(params.get('preload_c', False)), + } + + match tpl: + case 'cstream' | 'bstream': + pass + case 'bstream-msplit' | 'bstream-msplit-v2': + bsz = params['bsz'] + args |= {'msplit': block[1], 'bsz': bsz, 'blockx': blockx} + meta['shared'] = 2*bsz*blockx*dsize*args['width'] + case 'cstream-ksplit' | 'cstream-ksplit-v2': + csz = params['csz'] + args |= {'ksplit': block[1], 'csz': csz, 'blockx': blockx} + meta['shared'] = (block[1] - 1)*csz*blockx*dsize*args['width'] + case _: + args['blockx'] = blockx + return tpl, args, meta + + def _dense_args(self, kernel_cfg, params, cc, smem_info, args, meta): + tpl = kernel_cfg['template'] + nn = params['nn'] + warps = params['warps'] + tile = kernel_cfg['tile'] + width = kernel_cfg['width'] + + setup = self._dense_common(nn, warps, tile, cc, width) + if setup is None: + return None + + args |= setup + if tpl.startswith('dmma-asmem'): + args |= { + 'a_copy_threads': 32 * warps, + 'block_stealing': bool(params.get('block_stealing', False)), + } + meta['grid'] = (-(-self.n // setup['n_per_cta']), 1, 1) + + if (msplit := params.get('msplit')) is None: + return tpl, args, meta + + n_per_cta = setup['n_per_cta'] + k_pad = setup['k_tiles'] * setup['tile_k'] + b_tile_bytes = k_pad * n_per_cta * args['dwidth_i'] + + offsets, dynm_total_bytes = self._dsmem_alloc( + [('b', b_tile_bytes)], ('tma',) + ) + if dynm_total_bytes > smem_info[1]: + return None + + args |= { + 'msplit': msplit, + 'b_tile_bytes': b_tile_bytes, + 'b_smem_kiter_stride': 
def _dense_common(self, nn, warps_per_cta, tile, cc, width=None):
    """Compute the template arguments shared by the dense kernels.

    From the ``tile`` dimensions (keys ``'m'``, ``'n'``, ``'k'``), ``nn``
    and ``warps_per_cta`` this derives tile and group counts, the number
    of columns handled per warp and per CTA, a set of stride values, and
    a zero-padded copy of ``self.A`` packed into 64-bit hex literals.

    Returns ``None`` when ``n_per_cta`` exceeds ``self.n``, or when
    ``width == 2`` and ``self.aligne`` is ``None`` or odd, or ``self.n``
    is not a multiple of ``n_per_warp``.  Otherwise returns a dict of
    template arguments.

    ``cc`` is accepted but not used in this body.
    """
    tile_m, tile_n, tile_k = tile['m'], tile['n'], tile['k']
    ptx_shape = f'm{tile_m}n{tile_n}k{tile_k}'

    # A tile is viewed as groups of 8 rows and groups of 4 columns
    m_groups, k_groups = tile_m // 8, tile_k // 4
    a_regs = m_groups * k_groups
    b_regs = k_groups
    c_regs = 2 * m_groups

    a = self.A
    m, k = a.shape
    # Ceiling division: number of tiles needed to cover m rows / k columns
    m_tiles, k_tiles = -(-m // tile_m), -(-k // tile_k)
    k_rem = k % tile_k
    n_per_warp = tile_n * nn
    n_per_cta = warps_per_cta * n_per_warp

    # NOTE(review): this comparison raises if self.n is None; callers are
    # presumably expected to have filtered that case out — confirm
    if n_per_cta > self.n:
        return None

    # Width-2 kernels additionally need an even alignment and a column
    # count that is a whole number of per-warp column blocks
    if (width == 2
            and (self.aligne is None or self.aligne % 2
                 or self.n % n_per_warp)):
        return None

    # A in DMMA-fragment layout, packed in PTX A-operand register order.
    # This will handle 8x8x4 as well as additional sm90 sizes.
    # Zero-pad A up to a whole number of tiles in both dimensions
    a_pad = np.zeros((m_tiles*tile_m, k_tiles*tile_k), dtype=a.dtype)
    a_pad[:m, :k] = a
    # Split rows into (m_tile, m_group, 8) and columns into
    # (k_tile, k_group, 4), then reorder the axes to
    # (m_tile, k_tile, k_group, m_group, 8, 4) before flattening
    tile_shape = m_tiles, m_groups, 8, k_tiles, k_groups, 4
    tile_order = 0, 3, 4, 1, 2, 5
    a_tiles = a_pad.reshape(*tile_shape).transpose(*tile_order).ravel()
    # Reinterpret the packed values as uint64 so each can be emitted as a
    # 16 digit hex literal; assumes a.dtype is 8 bytes wide — TODO confirm
    a_u64 = [f'0x{u:016x}' for u in a_tiles.view(np.uint64)]

    # Predicate-elision flags
    # True when self.n is a whole number of per-warp column blocks
    n_col_aligned = (self.n is not None and self.n % n_per_warp == 0)
    # True when 8-row group mg of M-tile mt extends past the last row of A
    def pm_runtime(mt, mg=0):
        return mt*tile_m + 8*(mg + 1) > m

    # NOTE(review): the factors 8, 32 and 64 in the stride entries below
    # look like byte counts for 8-byte elements (1, 4 and 8 of them) —
    # confirm against the templates.  A missing ldb/ldc is treated as 0.
    return {
        'tile_m': tile_m,
        'tile_n': tile_n,
        'tile_k': tile_k,
        'ptx_mma_shape': ptx_shape,
        'm_groups': m_groups,
        'k_groups': k_groups,
        'a_regs': a_regs,
        'b_regs': b_regs,
        'c_regs': c_regs,
        'a_elems': m_tiles * k_tiles * tile_m * tile_k,
        'nn': nn,
        'm_tiles': m_tiles,
        'k_tiles': k_tiles,
        'k_rem': k_rem,
        'a_u64': a_u64,
        'n_per_warp': n_per_warp,
        'n_per_cta': n_per_cta,
        'frag_stride_bytes': 8 * tile_m * tile_k,
        'b_kiter_stride': 8 * tile_k * (self.ldb or 0),
        'b_kgroup_stride': 32 * (self.ldb or 0),
        'b_ntile_stride': 8 * tile_n,
        'c_mtile_stride': 8 * tile_m * (self.ldc or 0),
        'c_mgroup_stride': 64 * (self.ldc or 0),
        'c_ntile_stride': 8 * tile_n,
        'n_col_aligned': n_col_aligned,
        'pm_runtime': pm_runtime,
    }
* n_per_cta + c_tile_bytes = 8 * m_pad * n_per_cta + a_bytes = 8 * setup['a_elems'] + regions = [('b1', b_tile_bytes), ('b2', b_tile_bytes), + ('c', c_tile_bytes), ('a', a_bytes)] + ws_setup = { + 'n_comp_warps': n_comp_warps, + 'blockx_total': 32 * (n_comp_warps + service_warps), + 'prod_warp': warp_map['producer'], + 'comp_threads': 32 * n_comp_warps, + 'b_tile_bytes': b_tile_bytes, + 'c_mtile_smem_stride': 8 * setup['tile_m'] * n_per_cta, + 'c_mgroup_smem_stride': 64 * n_per_cta, + 'c_ntile_smem_stride': 8 * setup['tile_n'], + } + + match tpl: + case 'dmma-steal-ws': + regions.append(('wid', 16)) + mbars = ('tma', 'bready', 'cready', 'cstored', + 'steal', 'wid_new', 'wid_used') + grid = (-(-self.n // n_per_cta), 1, 1) + ws_setup['steal_warp'] = warp_map['stealer'] + case 'dmma-stride-ws': + stride_iters = params['iters'] + work_blocks = -(-self.n // n_per_cta) + grid_stride = -(-work_blocks // stride_iters) + mbars = ('tma', 'bready', 'bconsumed', 'cready', 'cstored') + grid = (grid_stride, 1, 1) + ws_setup |= { + 'stride_iters': stride_iters, + 'grid_stride': grid_stride, + 'work_blocks': work_blocks, + } + + offsets, dynm_total_bytes = self._dsmem_alloc(regions, mbars) + if dynm_total_bytes > dynamic_max: + return None + + args |= setup | ws_setup | offsets + meta |= { + 'grid': grid, + 'ws_b_tile': (n_per_cta, k_pad), + 'dynamic_shared': dynm_total_bytes, + } + if self.beta != 0: + meta['ws_out_tile'] = (n_per_cta, m_pad) + + return tpl, args, meta + + def _usable_config(self, kernel_cfg, dtype, cc, smem_info): + family = kernel_cfg['family'] + + if family == 'sparse' and not self.is_sparse_suitable(self.A, cc): + return False + elif (family in {'dense', 'dense-ws'} + and (dtype != 'double' or self.n is None + or not self.is_dense_suitable(self.A, cc))): + return False + + condition = kernel_cfg.get('conditions') + if condition is None: + return True + else: + stats = self._matmul_stats(dtype, cc, smem_info) + return self._eval_condition(condition, stats) + + 
def _platform_config(self, dtype, cc): + cc = cc or (0, 0) + key = f'sm{cc[0]}{cc[1]}_{dtype}' + return self._get_config(key) + + def _matmul_stats(self, dtype, cc, smem_info): + nnz = np.count_nonzero(self.A) + return { + 'dtype': dtype, + 'm': self.m, + 'k': self.k, + 'n': self.n, + 'beta': self.beta, + 'beta_zero': self.beta == 0, + 'aligne': self.aligne, + 'nnz': nnz, + 'density': nnz / self.A.size, + 'unique_abs': len(np.unique(np.abs(self.A))), + 'k_used': len(self.bix), + 'cc': list(cc), + 'smem_static': smem_info[0], + 'smem_dynamic': smem_info[1], + } + + @staticmethod + def _dsmem_alloc(regions, mbars, align=16): + # For a set of regions and mbars and there sizes, work out dynamic + # shared memory pointers offset for a given alignemnt. + out, off = {}, 0 + for name, size in regions: + off = (off + align - 1) & ~(align - 1) + out[f'{name}_off'] = off + off += size + for name in mbars: + out[f'{name}_mbar_off'] = off + off += 8 + total = (off + align - 1) & ~(align - 1) + return out, total + + @staticmethod + def _reg_list(prefix, n): + regs = ', '.join(f'{prefix}_{i}' for i in range(n)) + return f'{{{regs}}}' + + @staticmethod + def _pred_emit(instr, *preds, pred_reg=None, indent=8 * ' '): + # Handle whether an instruction needs a predicate or not + actual = [p for p in preds if p is not None] + if not actual: + return instr + if len(actual) == 1: + return f'@{actual[0]} {instr}' + if pred_reg is None: + raise ValueError('pred_reg required when combining multiple ' + 'predicates') + lines = [f'.reg .pred {pred_reg};', + f'and.pred {pred_reg}, {actual[0]}, {actual[1]};'] + for p in actual[2:]: + lines.append(f'and.pred {pred_reg}, {pred_reg}, {p};') + lines.append(f'@{pred_reg} {instr}') + return f'\n{indent}'.join(lines) + + def _process_meta(self, meta): + if self.n is not None and 'grid' not in meta: + div = meta['block'][0]*meta['width'] + meta['grid'] = (-(-self.n // div), 1, 1) diff --git a/setup.py b/setup.py index f1e94ff..3e41c60 100755 --- 
a/setup.py +++ b/setup.py @@ -7,8 +7,8 @@ # Python version -if sys.version_info[:2] < (3, 9): - print('GiMMiK requires Python 3.9 or newer') +if sys.version_info[:2] < (3, 10): + print('GiMMiK requires Python 3.10 or newer') sys.exit(-1) # GiMMiK version @@ -22,7 +22,7 @@ # Data package_data = { - 'gimmik': ['kernels/*/*.mako'], + 'gimmik': ['kernels/*/*.mako', 'kernels/ptx/config/*.json'], } # Hard dependencies @@ -34,7 +34,6 @@ # Info classifiers = [ 'License :: OSI Approved :: BSD License', - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Topic :: Scientific/Engineering' @@ -56,6 +55,7 @@ setup(name='gimmik', version=version, + python_requires='>=3.10', # Packages packages=['gimmik'],