From 0cd74858b095bae859f14219bba15ce8086057cd Mon Sep 17 00:00:00 2001 From: Will Trojak Date: Tue, 2 Dec 2025 22:13:19 +0000 Subject: [PATCH 01/21] [wip] added ptx generator for bstream --- gimmik/__init__.py | 4 +- gimmik/kernels/ptx/base.mako | 4 + gimmik/kernels/ptx/bstream.mako | 139 ++++++++++++++++++++++++++++++++ gimmik/ptx.py | 63 +++++++++++++++ 4 files changed, 209 insertions(+), 1 deletion(-) create mode 100644 gimmik/kernels/ptx/base.mako create mode 100644 gimmik/kernels/ptx/bstream.mako create mode 100644 gimmik/ptx.py diff --git a/gimmik/__init__.py b/gimmik/__init__.py index b32ebdc..cd21134 100644 --- a/gimmik/__init__.py +++ b/gimmik/__init__.py @@ -8,6 +8,7 @@ from gimmik.hip import HIPMatMul from gimmik.metal import MetalMatMul from gimmik.opencl import OpenCLMatMul +from gimmik.ptx import PTXMatMul def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm', @@ -22,7 +23,8 @@ def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm', 'cuda': CUDAMatMul, 'ispc': ISPCMatMul, 'hip': HIPMatMul, - 'opencl': OpenCLMatMul + 'opencl': OpenCLMatMul, + 'ptx': PTXMatMul } mm = platmap[platform](alpha*mat, beta, None, n, ldb, ldc) diff --git a/gimmik/kernels/ptx/base.mako b/gimmik/kernels/ptx/base.mako new file mode 100644 index 0000000..0521b84 --- /dev/null +++ b/gimmik/kernels/ptx/base.mako @@ -0,0 +1,4 @@ +.version 8.7 +.target sm_${cc} +.address_size 64 +${next.body()} \ No newline at end of file diff --git a/gimmik/kernels/ptx/bstream.mako b/gimmik/kernels/ptx/bstream.mako new file mode 100644 index 0000000..3ba8e93 --- /dev/null +++ b/gimmik/kernels/ptx/bstream.mako @@ -0,0 +1,139 @@ +<%inherit file='base'/> + +<% +pftype = "f32" if dtype == "float" else "f64" +putype = "u32" if dtype == "float" else "u64" +pbtype = "b32" if dtype == "float" else "b64" +rtype = "f" if dtype == "float" else "fd" +dwidth = "4" if dtype == "float" else "8" +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param 
.u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 n, ldb, ldc; + ld.param.u32 n, [_n]; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 n; + mov.u32 n, ${n}; +%endif + .reg .u32 id; + .reg .u64 b, c; + .reg .${pftype} csub<${m}>; + .reg .${pftype} ctmp<${m}>; + .reg .pred p1; + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 grid<3>; + mov.u32 grid0, %ntid.x; + mov.u32 grid1, %ctaid.x; + mov.u32 grid2, %tid.x; + mad.lo.u32 id, grid0, grid1, grid2; + } + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .${pftype} bv; + .reg .u32 boff<${len(bix)}>, coff; + .reg .u64 bptr<${len(bix)}>, cptr; +%for kx in bix: +% if n is None: + mul.lo.u32 boff${kx}, ldb, ${kx}; + ${address((f"bptr{kx}", "u64"), ("b", "u64"), dwidth, (f"boff{kx}", "u32"), ("id", "u32"))} +% else: + ${address((f"bptr{kx}", "u64"), ("b", "u64"), dwidth, (f"{ldb*kx}", "u32"), ("id", "u32"))} +% endif + ld.weak.global.cg.${pftype} bv, [bptr${kx}]; + +% for j, jx in enumerate(A[:, kx]): +% if jx != 0 and kx == afix[j]: + mul.${pftype} csub${j}, bv, ${jx}; +% elif jx != 0: + fma.rn.${pftype} csub${j}, bv, ${jx}, csub${j}; +% endif + +% if kx == alix[j] and beta == 0: +% if n is None: + mul.lo.u32 coff, ldc, ${j}; + ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} +% else: + ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} +% endif: + st.weak.global.cg.${pftype} [cptr], csub${j}; + +% elif kx == alix[j] and beta == 1: +% if n is None: + mul.lo.u32 coff, ldc, ${j}; + ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} +% else: + ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} +% endif: + ld.weak.global.${pftype} ctmp${j}, [cptr]; + add.${pftype} ctmp${j}, ctmp${j}, csub${j}; + 
st.weak.global.cg.${pftype} [cptr], ctmp${j}; + +% elif kx == alix[j]: +% if n is None: + mul.lo.u32 coff, ldc, ${j}; + ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} +% else: + ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} +% endif: + ld.weak.global.${pftype} ctmp${j}, [cptr]; + fma.rn.${pftype} ctmp${j}, ctmp${j}, ${beta}, csub${j}; + st.weak.global.cg.${pftype} [cptr], csub${j}; +% endif +% endfor +%endfor + } + + { + .reg .u32 coff; + .reg .u64 cptr; + .reg .${pftype} fz; + .reg .${putype} uz; + .reg .${pftype} cin, cout; + mov.${putype} uz, 0; + mov.${pbtype} fz, uz; + +%for j, jx in enumerate(afix): +% if jx == -1 and beta == 0: +% if n is None: + mul.lo.u32 coff, ldc, ${j}; + ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} +% else: + ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} +% endif: + st.weak.global.cg.${pftype} [cptr], fz; + +% elif jx == -1 and beta != 1: +% if n is None: + mul.lo.u32 coff, ldc, ${j}; + ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} +% else: + ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} +% endif: + ld.weak.global.cg.${pftype} cin, [cptr]; + mul.${pftype} cout, cin, ${beta}; + st.weak.globla.cg.${pftype} [cptr], cout; +% endif +%endfor + } + +$L_EXIT: + ret; +} \ No newline at end of file diff --git a/gimmik/ptx.py b/gimmik/ptx.py new file mode 100644 index 0000000..bccdb61 --- /dev/null +++ b/gimmik/ptx.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- + +from gimmik.base import MatMul + +class PTXSource: + def __init__(self): + self._src = "" + + def __iadd__(self, other): + self._src = f"{self}\n\t{other}" + return self + + def __str__(self): + return self._src + + def __repr__(self): + return self._src + + +class PTXMatMul(MatMul): + platform = 'ptx' + basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, + 
'dynamic_shared': 0} + + def _address(self, out, base, size, *offs): + src = PTXSource() + out_type = out[1] + if out_type != base[1]: + raise RuntimeError("out and base must have the same type") + + if offs: + off_type = offs[0][1] + if not all(off[1] == off_type for off in offs): + raise RuntimeError("offsets must all have the same tpye") + + if len(offs) == 1: + off = offs[0] + mad_type = "lo" if out_type == off_type else "wide" + src += f"mad.{mad_type}.{off_type} {out[0]}, {size}, {off[0]}, {base[0]};" + else: + src += f".reg .{off_type} _addrs_acum;" + src += f"add.{off_type} _addrs_acum, {offs[0][0]}, {offs[1][0]};" + for off in offs[2:]: + src += f"add.{off_type} _addrs_acum, _addrs_acum, {off[0]};" + mad_type = "lo" if out_type == off_type else "wide" + src += f"mad.{mad_type}.{off_type} {out[0]}, {size}, _addrs_acum, {base[0]};" + else: + src += f"mov.{out_type} {out[0]}, {base[0]};" + return f"{{{src}\n\t}}" + + + def _kernel_generators(self, dtype, dsize, *, compute_capability=None): + base_args = {'address': lambda o, b, s, *off: self._address(o, b, s, + *off), 'cc': compute_capability} + + # B streaming, C accumulation kernel + args = base_args | {} + yield ('bstream', args, {}) + + def _process_meta(self, meta): + if self.n is not None: + div = meta['block'][0]*meta['width'] + meta['grid'] = (-(-self.n // div), 1, 1) From 626c2f5b4b8d41be691fd45044c81600345a3ece Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Fri, 24 Apr 2026 12:40:34 -0700 Subject: [PATCH 02/21] Additional sparse and dense work --- gimmik/base.py | 3 +- gimmik/kernels/ptx/base.mako | 2 +- gimmik/kernels/ptx/bstream-msplit.mako | 281 ++++++++++++++++++++++ gimmik/kernels/ptx/bstream.mako | 223 +++++++++-------- gimmik/kernels/ptx/cstream-ksplit.mako | 179 ++++++++++++++ gimmik/kernels/ptx/cstream-w2.mako | 98 ++++++++ gimmik/kernels/ptx/cstream.mako | 157 ++++++++++++ gimmik/kernels/ptx/dense-mma-gAd.mako | 210 ++++++++++++++++ gimmik/kernels/ptx/dense-mma-smem-gA.mako | 264
++++++++++++++++++++ gimmik/ptx.py | 94 +++++++- 10 files changed, 1410 insertions(+), 101 deletions(-) create mode 100644 gimmik/kernels/ptx/bstream-msplit.mako create mode 100644 gimmik/kernels/ptx/cstream-ksplit.mako create mode 100644 gimmik/kernels/ptx/cstream-w2.mako create mode 100644 gimmik/kernels/ptx/cstream.mako create mode 100644 gimmik/kernels/ptx/dense-mma-gAd.mako create mode 100644 gimmik/kernels/ptx/dense-mma-smem-gA.mako diff --git a/gimmik/base.py b/gimmik/base.py index f547afc..0ecc29a 100644 --- a/gimmik/base.py +++ b/gimmik/base.py @@ -144,7 +144,8 @@ def _render_kernel(self, dtype, tplname, tplargs): src = tpl.render(**tplargs) # At single precision suffix all floating point constants by 'f' - if dtype == 'float': + # (PTX doesn't use an 'f' suffix for FP literals) + if dtype == 'float' and self.platform != 'ptx': src = re.sub(r'(?=\d*[.eE])(?=\.?\d)\d*\.?\d*(?:[eE][+-]?\d+)?', r'\g<0>f', src) diff --git a/gimmik/kernels/ptx/base.mako b/gimmik/kernels/ptx/base.mako index 0521b84..e380f1b 100644 --- a/gimmik/kernels/ptx/base.mako +++ b/gimmik/kernels/ptx/base.mako @@ -1,4 +1,4 @@ .version 8.7 -.target sm_${cc} +.target sm_${cc[0]}${cc[1]}${"a" if cc[0] >= 9 else ""} .address_size 64 ${next.body()} \ No newline at end of file diff --git a/gimmik/kernels/ptx/bstream-msplit.mako b/gimmik/kernels/ptx/bstream-msplit.mako new file mode 100644 index 0000000..77e7ce7 --- /dev/null +++ b/gimmik/kernels/ptx/bstream-msplit.mako @@ -0,0 +1,281 @@ +<%inherit file='base'/> + +<% +pftype = "f32" if dtype == "float" else "f64" +dwidth_i = 4 if dtype == "float" else 8 +fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +has_zero_rows = any(jx == -1 for jx in afix) +mx = partition(A, into=msplit, by='rows') +bix_list = list(bix) +bchunks = chunk(bix_list, bsz) +nchunks = len(bchunks) +m_per_group = max(len(mcx) for mcx in mx) +bsub_bytes = 2 * bsz * blockx * dwidth_i +def bsub_off(buf, idx): + return (buf * bsz + idx) * blockx * dwidth_i 
+use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id, tid_x, tid_y; + .reg .u64 b, c, b_base, c_base, bsub_thread; +% if use_cpasync: + .reg .u32 bsub_sm_thread; +% endif + .reg .${pftype} bv, csub<${m_per_group}>; + .reg .pred p1, p_skip; + .shared .align 8 .b8 _bsub[${bsub_bytes}]; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + + { + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; + mov.u64 bsub_thread, _bsub; + add.u64 bsub_thread, bsub_thread, _tx_off; + } +% if use_cpasync: + { + .reg .u64 _sm64; + cvta.to.shared.u64 _sm64, bsub_thread; + cvt.u32.u64 bsub_sm_thread, _sm64; + } +% endif + +% for cid, mcx in enumerate(mx): +## cid = ${cid}, rows ${mcx} + setp.ne.u32 p_skip, tid_y, ${cid}; + @p_skip bra $L_END_CID_${cid}; + +% if use_cpasync: +## Async fill of chunk 0 +% for idx, kx in enumerate(bchunks[0]): +% if idx % msplit == cid: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [_bptr], ${dwidth_i}; + } +% else: + cp.async.ca.shared::cta.global [bsub_sm_thread + 
${bsub_off(0, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i}; +% endif +% endif +% endfor + cp.async.commit_group; + cp.async.wait_all; + bar.sync 0; +% else: +## Sync fill of chunk 0 +% for idx, kx in enumerate(bchunks[0]): +% if idx % msplit == cid: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + .reg .${pftype} _bv; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.global.cg.${pftype} _bv, [_bptr]; + st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv; + } +% else: + { + .reg .${pftype} _bv; + ld.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; + st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv; + } +% endif +% endif +% endfor + bar.sync 0; +% endif + +## Main loop over B-chunks (double-buffered) +% for bb in range(nchunks): +<% + buf_cur = bb % 2 + buf_next = (bb + 1) % 2 + is_last = (bb == nchunks - 1) +%> +% if not is_last: +% for idx, kx in enumerate(bchunks[bb + 1]): +% if idx % msplit == cid: +% if use_cpasync: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [_bptr], ${dwidth_i}; + } +% else: + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i}; +% endif +% else: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + .reg .${pftype} _bv; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.global.cg.${pftype} _bv, [_bptr]; + st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv; + } +% else: + { + .reg .${pftype} _bv; + ld.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; + st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv; + } +% endif +% endif +% endif +% endfor +% if use_cpasync: + cp.async.commit_group; +% endif +% endif + +% for idx, kx in 
enumerate(bchunks[bb]): + ld.shared.${pftype} bv, [bsub_thread + ${bsub_off(buf_cur, idx)}]; +% for j, row_j in enumerate(mcx): +<% jx = A[row_j, kx] %> +% if jx != 0 and kx == afix[row_j]: + mul.${pftype} csub${j}, bv, ${jx}; +% elif jx != 0: + fma.rn.${pftype} csub${j}, bv, ${jx}, csub${j}; +% endif +% if kx == alix[row_j]: +% if beta == 0: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], csub${j}; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], csub${j}; +% endif +% else: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; + st.global.${pftype} [_cptr], _ctmp; +% else: + ld.global.${pftype} _ctmp, [c_base + ${ldc*row_j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; + st.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _ctmp; +% endif + } +% endif +% endif +% endfor +% endfor +% if use_cpasync: +% if not is_last: + cp.async.wait_all; +% endif +% endif + bar.sync 0; +% endfor + +## Handle zero rows in this cid's group +% if has_zero_rows: +% for row_j in mcx: +% if afix[row_j] == -1: +% if beta == 0: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; +% endif + } +% elif beta != 1: + { + .reg .${pftype} _tmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _tmp, [_cptr]; + mul.${pftype} 
_tmp, _tmp, ${float(beta)}; + st.global.${pftype} [_cptr], _tmp; +% else: + ld.global.${pftype} _tmp, [c_base + ${ldc*row_j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; +% endif + } +% endif +% endif +% endfor +% endif + +$L_END_CID_${cid}: +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/bstream.mako b/gimmik/kernels/ptx/bstream.mako index 3ba8e93..465eac3 100644 --- a/gimmik/kernels/ptx/bstream.mako +++ b/gimmik/kernels/ptx/bstream.mako @@ -2,10 +2,13 @@ <% pftype = "f32" if dtype == "float" else "f64" -putype = "u32" if dtype == "float" else "u64" -pbtype = "b32" if dtype == "float" else "b64" -rtype = "f" if dtype == "float" else "fd" -dwidth = "4" if dtype == "float" else "8" +dwidth_i = 4 if dtype == "float" else 8 +fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +has_zero_rows = any(jx == -1 for jx in afix) +bix_list = list(bix) +bix_idx = {kx: i for i, kx in enumerate(bix_list)} +preload_c = beta != 0 +need_scale = beta != 0 and beta != 1 %> % if n is None: @@ -15,125 +18,157 @@ dwidth = "4" if dtype == "float" else "8" .param .u64 _c, .param .u32 _ldc) { - .reg .u32 n, ldb, ldc; - ld.param.u32 n, [_n]; + .reg .u32 ldb, ldc; ld.param.u32 ldb, [_ldb]; ld.param.u32 ldc, [_ldc]; % else: .visible .entry ${kname}(.param .u64 _b, .param .u64 _c) { - .reg .u32 n; - mov.u32 n, ${n}; -%endif - .reg .u32 id; - .reg .u64 b, c; - .reg .${pftype} csub<${m}>; - .reg .${pftype} ctmp<${m}>; +% endif + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .${pftype} csub<${m}>, bv<${len(bix_list)}>; .reg .pred p1; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif ld.param.u64 b, [_b]; ld.param.u64 c, [_c]; { - .reg .u32 grid<3>; - mov.u32 grid0, %ntid.x; - mov.u32 grid1, %ctaid.x; - mov.u32 grid2, %tid.x; - mad.lo.u32 id, grid0, grid1, grid2; + .reg .u32 _grd<3>; + mov.u32 _grd0, %ntid.x; + mov.u32 _grd1, %ctaid.x; + mov.u32 _grd2, 
%tid.x; + mad.lo.u32 id, _grd0, _grd1, _grd2; } - setp.ge.u32 p1, id, n; + + setp.ge.u32 p1, id, n; @p1 bra $L_EXIT; + cvta.to.global.u64 b, b; cvta.to.global.u64 c, c; { - .reg .${pftype} bv; - .reg .u32 boff<${len(bix)}>, coff; - .reg .u64 bptr<${len(bix)}>, cptr; -%for kx in bix: -% if n is None: - mul.lo.u32 boff${kx}, ldb, ${kx}; - ${address((f"bptr{kx}", "u64"), ("b", "u64"), dwidth, (f"boff{kx}", "u32"), ("id", "u32"))} -% else: - ${address((f"bptr{kx}", "u64"), ("b", "u64"), dwidth, (f"{ldb*kx}", "u32"), ("id", "u32"))} -% endif - ld.weak.global.cg.${pftype} bv, [bptr${kx}]; + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } -% for j, jx in enumerate(A[:, kx]): -% if jx != 0 and kx == afix[j]: - mul.${pftype} csub${j}, bv, ${jx}; -% elif jx != 0: - fma.rn.${pftype} csub${j}, bv, ${jx}, csub${j}; -% endif +## Batch-load active B columns +%for i, kx in enumerate(bix_list): +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} bv${i}, [_bptr]; + } +% else: + ld.weak.global.cg.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +%endfor -% if kx == alix[j] and beta == 0: -% if n is None: - mul.lo.u32 coff, ldc, ${j}; - ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} -% else: - ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} -% endif: - st.weak.global.cg.${pftype} [cptr], csub${j}; +% if preload_c: +## Pre-load C so per-row completion is a plain store +% for j in range(m): +% if afix[j] != -1: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} csub${j}, [_cptr]; + } +% else: + ld.weak.global.cg.${pftype} csub${j}, [c_base + ${ldc*j*dwidth_i}]; +% endif +% endif +% endfor 
+% if need_scale: +% for j in range(m): +% if afix[j] != -1: + mul.${pftype} csub${j}, csub${j}, ${float(beta)}; +% endif +% endfor +% endif +% endif -% elif kx == alix[j] and beta == 1: -% if n is None: - mul.lo.u32 coff, ldc, ${j}; - ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} +## Main compute +%for kx in bix_list: +% for j, jx in enumerate(A[:, kx]): +% if jx != 0: +% if preload_c: + fma.rn.${pftype} csub${j}, bv${bix_idx[kx]}, ${jx}, csub${j}; +% elif kx == afix[j]: + mul.${pftype} csub${j}, bv${bix_idx[kx]}, ${jx}; % else: - ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} -% endif: - ld.weak.global.${pftype} ctmp${j}, [cptr]; - add.${pftype} ctmp${j}, ctmp${j}, csub${j}; - st.weak.global.cg.${pftype} [cptr], ctmp${j}; - -% elif kx == alix[j]: + fma.rn.${pftype} csub${j}, bv${bix_idx[kx]}, ${jx}, csub${j}; +% endif +% endif +% if kx == alix[j]: % if n is None: - mul.lo.u32 coff, ldc, ${j}; - ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], csub${j}; + } % else: - ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} -% endif: - ld.weak.global.${pftype} ctmp${j}, [cptr]; - fma.rn.${pftype} ctmp${j}, ctmp${j}, ${beta}, csub${j}; - st.weak.global.cg.${pftype} [cptr], csub${j}; + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], csub${j}; +% endif + % endif % endfor %endfor - } - { - .reg .u32 coff; - .reg .u64 cptr; - .reg .${pftype} fz; - .reg .${putype} uz; - .reg .${pftype} cin, cout; - mov.${putype} uz, 0; - mov.${pbtype} fz, uz; - -%for j, jx in enumerate(afix): -% if jx == -1 and beta == 0: -% if n is None: - mul.lo.u32 coff, ldc, ${j}; - ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} -% else: - ${address(("cptr", 
"u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} -% endif: - st.weak.global.cg.${pftype} [cptr], fz; +% if has_zero_rows: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% for j, jx in enumerate(afix): +% if jx == -1 and beta == 0: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif -% elif jx == -1 and beta != 1: -% if n is None: - mul.lo.u32 coff, ldc, ${j}; - ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} -% else: - ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} -% endif: - ld.weak.global.cg.${pftype} cin, [cptr]; - mul.${pftype} cout, cin, ${beta}; - st.weak.globla.cg.${pftype} [cptr], cout; +% elif jx == -1: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [_cptr], _tmp; + } +% else: + ld.global.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif % endif -%endfor - } +% endfor + } +% endif $L_EXIT: ret; -} \ No newline at end of file +} diff --git a/gimmik/kernels/ptx/cstream-ksplit.mako b/gimmik/kernels/ptx/cstream-ksplit.mako new file mode 100644 index 0000000..06e8a77 --- /dev/null +++ b/gimmik/kernels/ptx/cstream-ksplit.mako @@ -0,0 +1,179 @@ +<%inherit file='base'/> + +<% +pftype = "f32" if dtype == "float" else "f64" +dwidth_i = 4 if dtype == "float" else 8 +fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +kparts = partition(A, ksplit, by='cols') +cchunks = chunk(list(range(m)), csz) +cv_per_thread = -(-csz // ksplit) 
+bv_per_thread = max(len(kbx) for kbx in kparts) +csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id, tid_x, tid_y; + .reg .u64 b, c, b_base, c_base, csub_thread; + .reg .${pftype} bv<${bv_per_thread}>, cv<${cv_per_thread}>, dotp; + .reg .pred p1, p_skip; + .shared .align 8 .b8 _csub[${csub_bytes}]; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + + { + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; + mov.u64 csub_thread, _csub; + add.u64 csub_thread, csub_thread, _tx_off; + } + +% for bid, kbx in enumerate(kparts): +## bid = ${bid}: ${len(kbx)} B columns, ksplit=${ksplit} + setp.ne.u32 p_skip, tid_y, ${bid}; + @p_skip bra $L_END_BID_${bid}; + +<% + loaded = set() + kbx_idx = {kx: i for i, kx in enumerate(kbx)} +%> + +% for cchunk_i, cchunk in enumerate(cchunks): +## Chunk ${cchunk_i}: partial dot-product +% for row_idx, j in enumerate(cchunk): +<% + nz = [(kbx_idx[kx], kx, A[j, kx]) for kx in kbx if A[j, kx] != 0] + owner_bid = row_idx % ksplit +%> +% for (kxi, kx, jx) in nz: +% if kx not in loaded: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + 
ld.global.nc.${pftype} bv${kxi}, [_bptr]; + } +% else: + ld.global.nc.${pftype} bv${kxi}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +<% loaded.add(kx) %> +% endif +% endfor +% if nz: +% for i, (kxi, kx, jx) in enumerate(nz): +% if i == 0: + mul.${pftype} dotp, bv${kxi}, ${jx}; +% else: + fma.rn.${pftype} dotp, bv${kxi}, ${jx}, dotp; +% endif +% endfor +% else: + mov.${pftype} dotp, ${fzero}; +% endif +% if owner_bid == bid: + mov.${pftype} cv${row_idx // ksplit}, dotp; +% else: +<% csub_idx = bid - (1 if bid > owner_bid else 0) %> + st.shared.${pftype} [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}], dotp; +% endif +% endfor + bar.sync 0; + +## Combine phase (owned rows only) +% for row_idx, j in enumerate(cchunk): +% if row_idx % ksplit == bid: + mov.${pftype} dotp, cv${row_idx // ksplit}; +% for other_bid in range(ksplit): +% if other_bid != bid: +<% csub_idx = other_bid - (1 if other_bid > (row_idx % ksplit) else 0) %> + { + .reg .${pftype} _tmp; + ld.shared.${pftype} _tmp, [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}]; + add.${pftype} dotp, dotp, _tmp; + } +% endif +% endfor +% if beta == 0: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], dotp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; +% endif +% else: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.global.${pftype} [_cptr], _ctmp; +% else: + ld.global.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif + } +% endif + +% endif +% endfor + bar.sync 0; +% endfor + +$L_END_BID_${bid}: 
+% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream-w2.mako b/gimmik/kernels/ptx/cstream-w2.mako new file mode 100644 index 0000000..150cf57 --- /dev/null +++ b/gimmik/kernels/ptx/cstream-w2.mako @@ -0,0 +1,98 @@ +<%inherit file='base'/> + +<% +pftype = "f64" +dwidth_i = 8 +fzero = "0d0000000000000000" +bix_list = list(bix) +bix_pos = {kx: i for i, kx in enumerate(bix_list)} +K_used = len(bix_list) +row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m)] +assert dtype == 'double', 'cstream-w2 is double-precision only' +assert n is not None, 'cstream-w2 requires compile-time n' +%> + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .f64 bv_a<${K_used}>, bv_b<${K_used}>, dotp_a, dotp_b; + .reg .pred p1; + + mov.u32 n, ${-(-n // 2)}; + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x, _tid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 _tid_x, %tid.x; + mad.lo.u32 id, _ctaid_x, ${blockx}, _tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, 16, b; + mad.lo.u64 c_base, _id64, 16, c; + } + +## Batch-load B column pairs +%for i, kx in enumerate(bix_list): + ld.global.nc.v2.f64 {bv_a${i}, bv_b${i}}, [b_base + ${ldb*kx*dwidth_i}]; +%endfor + +## Main compute: two parallel dot-product streams per thread +%for j in range(m): +% if row_nz[j]: +% for i_nz, (kx, jx) in enumerate(row_nz[j]): +% if i_nz == 0: + mul.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}; + mul.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}; +% else: + fma.rn.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}, dotp_a; + fma.rn.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}, dotp_b; +% endif +% endfor +% if beta == 0: + st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; +% else: + { + .reg .f64 _ca, _cb; + ld.global.v2.f64 {_ca, _cb}, [c_base + 
${ldc*j*dwidth_i}]; + fma.rn.f64 _ca, _ca, ${float(beta)}, dotp_a; + fma.rn.f64 _cb, _cb, ${float(beta)}, dotp_b; + st.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% endif + +% else: +## Zero row of A +% if beta == 0: + { + .reg .f64 _z; + mov.f64 _z, ${fzero}; + st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_z, _z}; + } +% elif beta != 1: + { + .reg .f64 _ca, _cb; + ld.global.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + mul.f64 _ca, _ca, ${float(beta)}; + mul.f64 _cb, _cb, ${float(beta)}; + st.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% endif +% endif +%endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream.mako b/gimmik/kernels/ptx/cstream.mako new file mode 100644 index 0000000..f26abeb --- /dev/null +++ b/gimmik/kernels/ptx/cstream.mako @@ -0,0 +1,157 @@ +<%inherit file='base'/> + +<% +pftype = "f32" if dtype == "float" else "f64" +dwidth_i = 4 if dtype == "float" else 8 +fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +bix_list = list(bix) +bix_pos = {kx: i for i, kx in enumerate(bix_list)} +K_used = len(bix_list) +row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m)] +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .${pftype} bv<${K_used}>, dotp; + .reg .pred p1; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _grd<3>; + mov.u32 _grd0, %ntid.x; + mov.u32 _grd1, %ctaid.x; + mov.u32 _grd2, %tid.x; + mad.lo.u32 id, _grd0, _grd1, _grd2; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + 
.reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + +## Batch-load active B columns +%for i, kx in enumerate(bix_list): +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.global.nc.${pftype} bv${i}, [_bptr]; + } +% else: + ld.global.nc.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +%endfor + +## Compute and store each output row +%for j in range(m): +% if row_nz[j]: +% for i_nz, (kx, jx) in enumerate(row_nz[j]): +% if i_nz == 0: + mul.${pftype} dotp, bv${bix_pos[kx]}, ${jx}; +% else: + fma.rn.${pftype} dotp, bv${bix_pos[kx]}, ${jx}, dotp; +% endif +% endfor +% if beta == 0: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], dotp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; +% endif +% else: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.global.${pftype} [_cptr], _ctmp; +% else: + ld.global.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif + } +% endif + +% else: +## Zero row of A +% if beta == 0: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif + } +% elif beta != 1: + { + .reg .${pftype} _tmp; +% if n is None: + .reg 
.u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [_cptr], _tmp; +% else: + ld.global.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif + } +% endif +% endif +%endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako new file mode 100644 index 0000000..dcb8463 --- /dev/null +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -0,0 +1,210 @@ +<%inherit file='base'/> + +<%! +import struct +import math +%> + +<% +assert dtype == "double" +assert n is not None and ldb is not None and ldc is not None + +M, K_ = A.shape +assert K_ == k +M_PAD = -(-M // 8) * 8 +M_TILES = M_PAD // 8 +K_REM = k % 4 +K_PAD = k if K_REM == 0 else k + (4 - K_REM) +K_ITERS = K_PAD // 4 + +# A in fragment-layout (32 contiguous elements per fragment) +a_u64 = [] +for m_tile in range(M_TILES): + for k_iter in range(K_ITERS): + for lane in range(32): + r_div4 = lane // 4 + r_mod4 = lane % 4 + i = m_tile * 8 + r_div4 + j = k_iter * 4 + r_mod4 + v = float(A[i, j]) if (i < M and j < k) else 0.0 + u = struct.unpack(' + +.global .align 16 .b64 ${kname}_Ag[${A_ELEMS}] = { + ${', '.join(a_u64)} +}; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 ag_thr_base, b_thr_base, c_thr_base; + .reg .pred pwarp_exit; + .reg .f64 a_frag; +% for nt in range(NN): + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; + .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; + .reg .f64 b_frag_${nt}; + .reg .f64 c0_${nt}_<${M_TILES}>, c1_${nt}_<${M_TILES}>; +% endfor + + ld.param.u64 b_ptr, [_b]; + ld.param.u64 c_ptr, [_c]; + cvta.to.global.u64 b_ptr, b_ptr; + 
cvta.to.global.u64 c_ptr, c_ptr; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + shr.u32 r_div4, lane, 2; + and.b32 r_mod4, lane, 3; + + { + .reg .u32 cta; + mov.u32 cta, %ctaid.x; + mul.lo.u32 cta, cta, ${N_PER_CTA}; + mul.lo.u32 warp_n_base, warp, ${N_PER_WARP}; + add.u32 warp_n_base, warp_n_base, cta; + } + setp.ge.u32 pwarp_exit, warp_n_base, ${n}; + @pwarp_exit bra $L_EXIT; + +% for nt in range(NN): + add.u32 b_col_${nt}, warp_n_base, ${nt * 8}; + add.u32 b_col_${nt}, b_col_${nt}, r_div4; + { + .reg .u32 t; + shl.b32 t, r_mod4, 1; + add.u32 c_col0_${nt}, warp_n_base, ${nt * 8}; + add.u32 c_col0_${nt}, c_col0_${nt}, t; + add.u32 c_col1_${nt}, c_col0_${nt}, 1; + } + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; + setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endfor + + // A global thread base: &Ag[0] (generic -> global) + lane*8 + { + .reg .u64 t64, a_glb_base, lane64; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, a_glb_base; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 ag_thr_base, a_glb_base, t64; + } + + { + .reg .u64 t64, bcol64; + mul.wide.u32 t64, r_mod4, ${ldb}; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; + } + + { + .reg .u64 t64, ccol64; + mul.wide.u32 t64, r_div4, ${ldc}; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; + } + +% for mt in range(M_TILES): + .reg .pred pm_${mt}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${mt * 8}; + setp.lt.u32 pm_${mt}, crow, ${M}; + } +% endfor + +% for nt in range(NN): +% for mt in range(M_TILES): +% if beta == 0: + mov.f64 c0_${nt}_${mt}, 0d0000000000000000; + mov.f64 c1_${nt}_${mt}, 0d0000000000000000; +% else: + { + .reg .u64 caddr; + .reg .pred p0, p1; + add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; + and.pred 
p0, pm_${mt}, pvalid_c0col_${nt}; + and.pred p1, pm_${mt}, pvalid_c1col_${nt}; + mov.f64 c0_${nt}_${mt}, 0d0000000000000000; + mov.f64 c1_${nt}_${mt}, 0d0000000000000000; + @p0 ld.global.f64 c0_${nt}_${mt}, [caddr]; + @p1 ld.global.f64 c1_${nt}_${mt}, [caddr + 8]; + } +% endif +% endfor +% endfor + +% for ki in range(K_ITERS): +% for nt in range(NN): + { + .reg .u64 baddr; + .reg .pred pb_load; + add.u64 baddr, b_thr_base, ${ki * B_KITER_STRIDE + nt * B_NTILE_STRIDE}; +% if K_REM != 0 and ki == K_ITERS - 1: + { + .reg .u32 brow; + .reg .pred pbrow; + add.u32 brow, r_mod4, ${ki * 4}; + setp.lt.u32 pbrow, brow, ${k}; + and.pred pb_load, pbrow, pvalid_bcol_${nt}; + } +% else: + and.pred pb_load, pvalid_bcol_${nt}, pvalid_bcol_${nt}; +% endif + mov.f64 b_frag_${nt}, 0d0000000000000000; + @pb_load ld.global.nc.f64 b_frag_${nt}, [baddr]; + } +% endfor +% for mt in range(M_TILES): + ld.global.nc.f64 a_frag, [ag_thr_base + ${(mt * K_ITERS + ki) * FRAG_STRIDE_BYTES}]; +% for nt in range(NN): + mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 + {c0_${nt}_${mt}, c1_${nt}_${mt}}, + {a_frag}, + {b_frag_${nt}}, + {c0_${nt}_${mt}, c1_${nt}_${mt}}; +% endfor +% endfor +% endfor + +% for nt in range(NN): +% for mt in range(M_TILES): + { + .reg .u64 caddr; + .reg .pred p0, p1; + add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; + and.pred p0, pm_${mt}, pvalid_c0col_${nt}; + and.pred p1, pm_${mt}, pvalid_c1col_${nt}; + @p0 st.global.f64 [caddr], c0_${nt}_${mt}; + @p1 st.global.f64 [caddr + 8], c1_${nt}_${mt}; + } +% endfor +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako new file mode 100644 index 0000000..d395b2e --- /dev/null +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -0,0 +1,264 @@ +<%inherit file='base'/> + +<%! 
+import struct +import math +%> + +<% +assert dtype == "double" +assert n is not None and ldb is not None and ldc is not None + +M, K_ = A.shape +assert K_ == k +M_PAD = -(-M // 8) * 8 +M_TILES = M_PAD // 8 +K_REM = k % 4 +K_PAD = k if K_REM == 0 else k + (4 - K_REM) +K_ITERS = K_PAD // 4 + +# A in fragment-layout (same as dense-mma-smem-nn) +a_u64 = [] +for m_tile in range(M_TILES): + for k_iter in range(K_ITERS): + for lane in range(32): + r_div4 = lane // 4 + r_mod4 = lane % 4 + i = m_tile * 8 + r_div4 + j = k_iter * 4 + r_mod4 + v = float(A[i, j]) if (i < M and j < k) else 0.0 + u = struct.unpack(' 2*BLOCKX elements per copy iter +A_PAIRS = A_ELEMS // 2 # number of f64x2 pairs +A_PAIRS_TAIL = A_ELEMS % 2 # 0 if even, 1 if odd +COPY_V2_ITERS = math.ceil(A_PAIRS / BLOCKX) + +FRAG_STRIDE_BYTES = 32 * 8 +B_KITER_STRIDE = 4 * ldb * 8 +B_NTILE_STRIDE = 8 * 8 +C_MTILE_STRIDE = 8 * ldc * 8 +C_NTILE_STRIDE = 8 * 8 +%> + +.global .align 16 .b64 ${kname}_Ag[${A_ELEMS}] = { + ${', '.join(a_u64)} +}; +.shared .align 16 .b64 ${kname}_As[${A_ELEMS}]; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 as_thr_base, b_thr_base, c_thr_base; + .reg .pred pwarp_exit; + .reg .f64 a_frag; +% for nt in range(NN): + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; + .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; + .reg .f64 b_frag_${nt}; + .reg .f64 c0_${nt}_<${M_TILES}>, c1_${nt}_<${M_TILES}>; +% endfor + + ld.param.u64 b_ptr, [_b]; + ld.param.u64 c_ptr, [_c]; + cvta.to.global.u64 b_ptr, b_ptr; + cvta.to.global.u64 c_ptr, c_ptr; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + shr.u32 r_div4, lane, 2; + and.b32 r_mod4, lane, 3; + + // ---- Cooperative copy A from .global to .shared using v2 loads ---- + { + .reg .u64 a_glb_base, a_smem_base; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, 
a_glb_base; + mov.u64 a_smem_base, ${kname}_As; +% for ci in range(COPY_V2_ITERS): +<% + base_pair = ci * BLOCKX + is_last = ci == COPY_V2_ITERS - 1 + pairs_this = min(BLOCKX, A_PAIRS - base_pair) +%> + { + .reg .u32 pidx; + .reg .u64 off64, gaddr, saddr; + .reg .f64 v0, v1; +% if is_last and pairs_this < BLOCKX: + .reg .pred plast; + add.u32 pidx, tid, ${base_pair}; + setp.lt.u32 plast, pidx, ${A_PAIRS}; + mul.wide.u32 off64, pidx, 16; + add.u64 gaddr, a_glb_base, off64; + add.u64 saddr, a_smem_base, off64; + @plast ld.global.nc.v2.f64 {v0, v1}, [gaddr]; + @plast st.shared.v2.f64 [saddr], {v0, v1}; +% else: + add.u32 pidx, tid, ${base_pair}; + mul.wide.u32 off64, pidx, 16; + add.u64 gaddr, a_glb_base, off64; + add.u64 saddr, a_smem_base, off64; + ld.global.nc.v2.f64 {v0, v1}, [gaddr]; + st.shared.v2.f64 [saddr], {v0, v1}; +% endif + } +% endfor +% if A_PAIRS_TAIL: + // Odd element at the very end (rare; A_ELEMS odd) + { + .reg .pred plast; + .reg .u64 gaddr, saddr; + .reg .f64 v; + setp.eq.u32 plast, tid, 0; + add.u64 gaddr, a_glb_base, ${(A_ELEMS-1) * 8}; + add.u64 saddr, a_smem_base, ${(A_ELEMS-1) * 8}; + @plast ld.global.nc.f64 v, [gaddr]; + @plast st.shared.f64 [saddr], v; + } +% endif + } + bar.sync 0; + + { + .reg .u32 cta; + mov.u32 cta, %ctaid.x; + mul.lo.u32 cta, cta, ${N_PER_CTA}; + mul.lo.u32 warp_n_base, warp, ${N_PER_WARP}; + add.u32 warp_n_base, warp_n_base, cta; + } + setp.ge.u32 pwarp_exit, warp_n_base, ${n}; + @pwarp_exit bra $L_EXIT; + +% for nt in range(NN): + add.u32 b_col_${nt}, warp_n_base, ${nt * 8}; + add.u32 b_col_${nt}, b_col_${nt}, r_div4; + { + .reg .u32 t; + shl.b32 t, r_mod4, 1; + add.u32 c_col0_${nt}, warp_n_base, ${nt * 8}; + add.u32 c_col0_${nt}, c_col0_${nt}, t; + add.u32 c_col1_${nt}, c_col0_${nt}, 1; + } + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; + setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endfor + + { + .reg .u64 t64, a_smem_base, lane64; + mov.u64 
a_smem_base, ${kname}_As; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 as_thr_base, a_smem_base, t64; + } + + { + .reg .u64 t64, bcol64; + mul.wide.u32 t64, r_mod4, ${ldb}; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; + } + + { + .reg .u64 t64, ccol64; + mul.wide.u32 t64, r_div4, ${ldc}; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; + } + +% for mt in range(M_TILES): + .reg .pred pm_${mt}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${mt * 8}; + setp.lt.u32 pm_${mt}, crow, ${M}; + } +% endfor + +% for nt in range(NN): +% for mt in range(M_TILES): +% if beta == 0: + mov.f64 c0_${nt}_${mt}, 0d0000000000000000; + mov.f64 c1_${nt}_${mt}, 0d0000000000000000; +% else: + { + .reg .u64 caddr; + .reg .pred p0, p1; + add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; + and.pred p0, pm_${mt}, pvalid_c0col_${nt}; + and.pred p1, pm_${mt}, pvalid_c1col_${nt}; + mov.f64 c0_${nt}_${mt}, 0d0000000000000000; + mov.f64 c1_${nt}_${mt}, 0d0000000000000000; + @p0 ld.global.f64 c0_${nt}_${mt}, [caddr]; + @p1 ld.global.f64 c1_${nt}_${mt}, [caddr + 8]; + } +% endif +% endfor +% endfor + +% for ki in range(K_ITERS): +% for nt in range(NN): + { + .reg .u64 baddr; + .reg .pred pb_load; + add.u64 baddr, b_thr_base, ${ki * B_KITER_STRIDE + nt * B_NTILE_STRIDE}; +% if K_REM != 0 and ki == K_ITERS - 1: + { + .reg .u32 brow; + .reg .pred pbrow; + add.u32 brow, r_mod4, ${ki * 4}; + setp.lt.u32 pbrow, brow, ${k}; + and.pred pb_load, pbrow, pvalid_bcol_${nt}; + } +% else: + and.pred pb_load, pvalid_bcol_${nt}, pvalid_bcol_${nt}; +% endif + mov.f64 b_frag_${nt}, 0d0000000000000000; + @pb_load ld.global.nc.f64 b_frag_${nt}, [baddr]; + } +% endfor +% for mt in range(M_TILES): + ld.shared.f64 a_frag, [as_thr_base + ${(mt * K_ITERS + ki) * FRAG_STRIDE_BYTES}]; +% for nt in range(NN): + 
mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 + {c0_${nt}_${mt}, c1_${nt}_${mt}}, + {a_frag}, + {b_frag_${nt}}, + {c0_${nt}_${mt}, c1_${nt}_${mt}}; +% endfor +% endfor +% endfor + +% for nt in range(NN): +% for mt in range(M_TILES): + { + .reg .u64 caddr; + .reg .pred p0, p1; + add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; + and.pred p0, pm_${mt}, pvalid_c0col_${nt}; + and.pred p1, pm_${mt}, pvalid_c1col_${nt}; + @p0 st.global.f64 [caddr], c0_${nt}_${mt}; + @p1 st.global.f64 [caddr + 8], c1_${nt}_${mt}; + } +% endfor +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/ptx.py b/gimmik/ptx.py index bccdb61..dd3b259 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -1,7 +1,10 @@ # -*- coding: utf-8 -*- +import numpy as np + from gimmik.base import MatMul + class PTXSource: def __init__(self): self._src = "" @@ -48,16 +51,97 @@ def _address(self, out, base, size, *offs): src += f"mov.{out_type} {out[0]}, {base[0]};" return f"{{{src}\n\t}}" - def _kernel_generators(self, dtype, dsize, *, compute_capability=None): base_args = {'address': lambda o, b, s, *off: self._address(o, b, s, *off), 'cc': compute_capability} - # B streaming, C accumulation kernel - args = base_args | {} - yield ('bstream', args, {}) + # Matrix-property gates + arr = self.A + nnz = int(np.count_nonzero(arr)) + nuq = int(len(np.unique(np.abs(arr)))) + density = nnz / arr.size + sparse_suitable = (nuq <= 28) or (density <= 0.15) + + cc = compute_capability or (0, 0) + dense_suitable = ( + dtype == 'double' + and cc >= (9, 0) + and self.n is not None + and self.m <= 128 + and self.k <= 128 + ) + + if sparse_suitable: + yield ('cstream', base_args | {}, {}) + + yield ('bstream', base_args | {}, {}) + + ms, bsz, blkx = 4, 24, 32 + args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} + meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize} + yield ('bstream-msplit', args, meta) + + ms, bsz, blkx = 1, 16, 128 + args = base_args | {'msplit': ms, 'bsz': bsz, 
'blockx': blkx} + meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize} + yield ('bstream-msplit', args, meta) + + ks, csz, blkx = 2, 24, 32 + args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} + meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize} + yield ('cstream-ksplit', args, meta) + + K_used = len(self.bix) + if K_used > 500: + ks, csz, blkx = 4, 20, 32 + args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} + meta = {'block': (blkx, ks, 1), + 'shared': (ks - 1)*csz*blkx*dsize} + yield ('cstream-ksplit', args, meta) + + if (dtype == 'double' and self.n is not None and self.n % 2 == 0 + and K_used <= 100 + and (self.aligne is None or self.aligne % 2 == 0)): + blkx = 128 + args = base_args | {'blockx': blkx} + meta = {'block': (blkx, 1, 1), 'width': 2} + yield ('cstream-w2', args, meta) + + if dense_suitable: + # Dense DMMA m8n8k4 templates. Yields a small cover of the nn × w + # space that empirically spans the autotune winners seen on tet + # p=3,4 at N=500k. The PyFR wrapper's _benchmark picks the fastest. + for tpl in ('dense-mma-smem-gA', 'dense-mma-gAd'): + for nn in (1, 2, 4): + for w in (2, 4, 8): + blkx = 32 * w + n_per_cta = 8 * nn * w + if n_per_cta > self.n: + continue + args = base_args | {'warps_per_cta': w, 'nn': nn} + meta = { + 'block': (blkx, 1, 1), + 'grid': (-(-self.n // n_per_cta), 1, 1), + } + yield (tpl, args, meta) + + # Extra fine-grained nn for shapes where a specific nn usually + # wins (p3/tet/m132, p4/tet/m132). 
+ for tpl in ('dense-mma-smem-gA', 'dense-mma-gAd'): + for nn in (6,): + for w in (1, 4): + blkx = 32 * w + n_per_cta = 8 * nn * w + if n_per_cta > self.n: + continue + args = base_args | {'warps_per_cta': w, 'nn': nn} + meta = { + 'block': (blkx, 1, 1), + 'grid': (-(-self.n // n_per_cta), 1, 1), + } + yield (tpl, args, meta) def _process_meta(self, meta): - if self.n is not None: + if self.n is not None and 'grid' not in meta: div = meta['block'][0]*meta['width'] meta['grid'] = (-(-self.n // div), 1, 1) From bbbb8ef94f7803d5449cbcc3c51a1409d8efa630 Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Wed, 29 Apr 2026 06:49:21 -0700 Subject: [PATCH 03/21] Dense and sparse optimisation --- gimmik/kernels/ptx/dense-mma-gAd.mako | 134 +++++------ gimmik/kernels/ptx/dense-mma-smem-gA.mako | 257 ++++++++++++--------- gimmik/ptx.py | 259 ++++++++++++---------- 3 files changed, 348 insertions(+), 302 deletions(-) diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako index dcb8463..7fc6572 100644 --- a/gimmik/kernels/ptx/dense-mma-gAd.mako +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -1,50 +1,11 @@ <%inherit file='base'/> -<%! 
-import struct -import math -%> - <% assert dtype == "double" assert n is not None and ldb is not None and ldc is not None - -M, K_ = A.shape -assert K_ == k -M_PAD = -(-M // 8) * 8 -M_TILES = M_PAD // 8 -K_REM = k % 4 -K_PAD = k if K_REM == 0 else k + (4 - K_REM) -K_ITERS = K_PAD // 4 - -# A in fragment-layout (32 contiguous elements per fragment) -a_u64 = [] -for m_tile in range(M_TILES): - for k_iter in range(K_ITERS): - for lane in range(32): - r_div4 = lane // 4 - r_mod4 = lane % 4 - i = m_tile * 8 + r_div4 - j = k_iter * 4 + r_mod4 - v = float(A[i, j]) if (i < M and j < k) else 0.0 - u = struct.unpack(' -.global .align 16 .b64 ${kname}_Ag[${A_ELEMS}] = { +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { ${', '.join(a_u64)} }; @@ -57,11 +18,13 @@ C_NTILE_STRIDE = 8 * 8 .reg .u64 ag_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; .reg .f64 a_frag; -% for nt in range(NN): +% for nt in range(nn): .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; +% if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; +% endif .reg .f64 b_frag_${nt}; - .reg .f64 c0_${nt}_<${M_TILES}>, c1_${nt}_<${M_TILES}>; + .reg .f64 c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; % endfor ld.param.u64 b_ptr, [_b]; @@ -78,14 +41,14 @@ C_NTILE_STRIDE = 8 * 8 { .reg .u32 cta; mov.u32 cta, %ctaid.x; - mul.lo.u32 cta, cta, ${N_PER_CTA}; - mul.lo.u32 warp_n_base, warp, ${N_PER_WARP}; + mul.lo.u32 cta, cta, ${n_per_cta}; + mul.lo.u32 warp_n_base, warp, ${n_per_warp}; add.u32 warp_n_base, warp_n_base, cta; } setp.ge.u32 pwarp_exit, warp_n_base, ${n}; @pwarp_exit bra $L_EXIT; -% for nt in range(NN): +% for nt in range(nn): add.u32 b_col_${nt}, warp_n_base, ${nt * 8}; add.u32 b_col_${nt}, b_col_${nt}, r_div4; { @@ -95,12 +58,14 @@ C_NTILE_STRIDE = 8 * 8 add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } +% if not n_col_aligned: setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; 
setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endif % endfor - // A global thread base: &Ag[0] (generic -> global) + lane*8 + // A thread base: &Ag[0] + lane*8 { .reg .u64 t64, a_glb_base, lane64; mov.u64 a_glb_base, ${kname}_Ag; @@ -128,60 +93,71 @@ C_NTILE_STRIDE = 8 * 8 add.u64 c_thr_base, c_ptr, t64; } -% for mt in range(M_TILES): +% for mt in range(m_tiles): +% if pm_runtime(mt): .reg .pred pm_${mt}; { .reg .u32 crow; add.u32 crow, r_div4, ${mt * 8}; - setp.lt.u32 pm_${mt}, crow, ${M}; + setp.lt.u32 pm_${mt}, crow, ${m}; } +% endif % endfor -% for nt in range(NN): -% for mt in range(M_TILES): +% for nt in range(nn): +% for mt in range(m_tiles): % if beta == 0: mov.f64 c0_${nt}_${mt}, 0d0000000000000000; mov.f64 c1_${nt}_${mt}, 0d0000000000000000; % else: +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None + needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None +%> { .reg .u64 caddr; - .reg .pred p0, p1; - add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; - and.pred p0, pm_${mt}, pvalid_c0col_${nt}; - and.pred p1, pm_${mt}, pvalid_c1col_${nt}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; +% if needs_zero_init: mov.f64 c0_${nt}_${mt}, 0d0000000000000000; mov.f64 c1_${nt}_${mt}, 0d0000000000000000; - @p0 ld.global.f64 c0_${nt}_${mt}, [caddr]; - @p1 ld.global.f64 c1_${nt}_${mt}, [caddr + 8]; +% endif + ${pred_emit(f'ld.global.f64 c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} + ${pred_emit(f'ld.global.f64 c1_{nt}_{mt}, [caddr + 8];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} } % endif % endfor % endfor -% for ki in range(K_ITERS): -% for nt in range(NN): +% for ki in range(k_iters): +% for nt in range(nn): +<% + pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None + k_tail = (k_rem != 0 and ki == k_iters - 1) + needs_zero = pvb is not None or k_tail + 
pbrow = 'pbrow' if k_tail else None +%> { .reg .u64 baddr; - .reg .pred pb_load; - add.u64 baddr, b_thr_base, ${ki * B_KITER_STRIDE + nt * B_NTILE_STRIDE}; -% if K_REM != 0 and ki == K_ITERS - 1: + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.f64 b_frag_${nt}, 0d0000000000000000; +% endif +% if k_tail: + .reg .pred pbrow; { .reg .u32 brow; - .reg .pred pbrow; add.u32 brow, r_mod4, ${ki * 4}; setp.lt.u32 pbrow, brow, ${k}; - and.pred pb_load, pbrow, pvalid_bcol_${nt}; } -% else: - and.pred pb_load, pvalid_bcol_${nt}, pvalid_bcol_${nt}; % endif - mov.f64 b_frag_${nt}, 0d0000000000000000; - @pb_load ld.global.nc.f64 b_frag_${nt}, [baddr]; + ${pred_emit(f'ld.global.nc.f64 b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} } % endfor -% for mt in range(M_TILES): - ld.global.nc.f64 a_frag, [ag_thr_base + ${(mt * K_ITERS + ki) * FRAG_STRIDE_BYTES}]; -% for nt in range(NN): +% for mt in range(m_tiles): + ld.global.nc.f64 a_frag, [ag_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; +% for nt in range(nn): mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {c0_${nt}_${mt}, c1_${nt}_${mt}}, {a_frag}, @@ -191,16 +167,18 @@ C_NTILE_STRIDE = 8 * 8 % endfor % endfor -% for nt in range(NN): -% for mt in range(M_TILES): +% for nt in range(nn): +% for mt in range(m_tiles): +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None +%> { .reg .u64 caddr; - .reg .pred p0, p1; - add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; - and.pred p0, pm_${mt}, pvalid_c0col_${nt}; - and.pred p1, pm_${mt}, pvalid_c1col_${nt}; - @p0 st.global.f64 [caddr], c0_${nt}_${mt}; - @p1 st.global.f64 [caddr + 8], c1_${nt}_${mt}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.global.f64 [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} + 
${pred_emit(f'st.global.f64 [caddr + 8], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} } % endfor % endfor diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako index d395b2e..8451e06 100644 --- a/gimmik/kernels/ptx/dense-mma-smem-gA.mako +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -1,57 +1,24 @@ <%inherit file='base'/> -<%! -import struct -import math -%> - <% assert dtype == "double" assert n is not None and ldb is not None and ldc is not None - -M, K_ = A.shape -assert K_ == k -M_PAD = -(-M // 8) * 8 -M_TILES = M_PAD // 8 -K_REM = k % 4 -K_PAD = k if K_REM == 0 else k + (4 - K_REM) -K_ITERS = K_PAD // 4 - -# A in fragment-layout (same as dense-mma-smem-nn) -a_u64 = [] -for m_tile in range(M_TILES): - for k_iter in range(K_ITERS): - for lane in range(32): - r_div4 = lane // 4 - r_mod4 = lane % 4 - i = m_tile * 8 + r_div4 - j = k_iter * 4 + r_mod4 - v = float(A[i, j]) if (i < M and j < k) else 0.0 - u = struct.unpack(' 2*BLOCKX elements per copy iter -A_PAIRS = A_ELEMS // 2 # number of f64x2 pairs -A_PAIRS_TAIL = A_ELEMS % 2 # 0 if even, 1 if odd -COPY_V2_ITERS = math.ceil(A_PAIRS / BLOCKX) - -FRAG_STRIDE_BYTES = 32 * 8 -B_KITER_STRIDE = 4 * ldb * 8 -B_NTILE_STRIDE = 8 * 8 -C_MTILE_STRIDE = 8 * ldc * 8 -C_NTILE_STRIDE = 8 * 8 +# Cooperative-copy params (gA-only) +blockx = 32 * warps_per_cta +a_pairs = a_elems // 2 +a_pairs_tail = a_elems % 2 +copy_v2_iters = (a_pairs + blockx - 1) // blockx +bs = bool(context.get('block_stealing', False)) %> -.global .align 16 .b64 ${kname}_Ag[${A_ELEMS}] = { +% if bs: +.shared .align 8 .b64 ${kname}_mbar; +.shared .align 16 .b8 ${kname}_workid[16]; +% endif +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { ${', '.join(a_u64)} }; -.shared .align 16 .b64 ${kname}_As[${A_ELEMS}]; +.shared .align 16 .b64 ${kname}_As[${a_elems}]; .visible .entry ${kname}(.param .u64 _b, .param .u64 _c) @@ -62,11 +29,18 @@ C_NTILE_STRIDE = 8 * 8 .reg .u64 as_thr_base, b_thr_base, c_thr_base; 
.reg .pred pwarp_exit; .reg .f64 a_frag; -% for nt in range(NN): +% if bs: + .reg .u32 ctaid; + .reg .u32 mbar_a, work_a; + .reg .pred p_root, p_done, p_have; +% endif +% for nt in range(nn): .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; +% if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; +% endif .reg .f64 b_frag_${nt}; - .reg .f64 c0_${nt}_<${M_TILES}>, c1_${nt}_<${M_TILES}>; + .reg .f64 c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; % endfor ld.param.u64 b_ptr, [_b]; @@ -80,26 +54,34 @@ C_NTILE_STRIDE = 8 * 8 shr.u32 r_div4, lane, 2; and.b32 r_mod4, lane, 3; - // ---- Cooperative copy A from .global to .shared using v2 loads ---- +% if bs: + setp.eq.u32 p_root, tid, 0; + mov.u32 mbar_a, ${kname}_mbar; + mov.u32 work_a, ${kname}_workid; + @p_root mbarrier.init.shared::cta.b64 [mbar_a], 1; + bar.sync 0; +% endif + + // Cooperative copy A from .global to .shared via v2 loads { .reg .u64 a_glb_base, a_smem_base; mov.u64 a_glb_base, ${kname}_Ag; cvta.to.global.u64 a_glb_base, a_glb_base; mov.u64 a_smem_base, ${kname}_As; -% for ci in range(COPY_V2_ITERS): +% for ci in range(copy_v2_iters): <% - base_pair = ci * BLOCKX - is_last = ci == COPY_V2_ITERS - 1 - pairs_this = min(BLOCKX, A_PAIRS - base_pair) + base_pair = ci * blockx + is_last = ci == copy_v2_iters - 1 + pairs_this = min(blockx, a_pairs - base_pair) %> { .reg .u32 pidx; .reg .u64 off64, gaddr, saddr; .reg .f64 v0, v1; -% if is_last and pairs_this < BLOCKX: +% if is_last and pairs_this < blockx: .reg .pred plast; add.u32 pidx, tid, ${base_pair}; - setp.lt.u32 plast, pidx, ${A_PAIRS}; + setp.lt.u32 plast, pidx, ${a_pairs}; mul.wide.u32 off64, pidx, 16; add.u64 gaddr, a_glb_base, off64; add.u64 saddr, a_smem_base, off64; @@ -115,15 +97,15 @@ C_NTILE_STRIDE = 8 * 8 % endif } % endfor -% if A_PAIRS_TAIL: - // Odd element at the very end (rare; A_ELEMS odd) +% if a_pairs_tail: + // Tail element (only when a_elems is odd) { .reg .pred plast; .reg .u64 gaddr, saddr; 
.reg .f64 v; setp.eq.u32 plast, tid, 0; - add.u64 gaddr, a_glb_base, ${(A_ELEMS-1) * 8}; - add.u64 saddr, a_smem_base, ${(A_ELEMS-1) * 8}; + add.u64 gaddr, a_glb_base, ${(a_elems-1) * 8}; + add.u64 saddr, a_smem_base, ${(a_elems-1) * 8}; @plast ld.global.nc.f64 v, [gaddr]; @plast st.shared.f64 [saddr], v; } @@ -131,17 +113,50 @@ C_NTILE_STRIDE = 8 * 8 } bar.sync 0; + // Lane-only base; lifted out of the optional steal loop + { + .reg .u64 t64, a_smem_base, lane64; + mov.u64 a_smem_base, ${kname}_As; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 as_thr_base, a_smem_base, t64; + } + +% for mt in range(m_tiles): +% if pm_runtime(mt): + .reg .pred pm_${mt}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${mt * 8}; + setp.lt.u32 pm_${mt}, crow, ${m}; + } +% endif +% endfor + +% if bs: + mov.u32 ctaid, %ctaid.x; +$L_LOOP: +% endif + { .reg .u32 cta; +% if bs: + mov.u32 cta, ctaid; +% else: mov.u32 cta, %ctaid.x; - mul.lo.u32 cta, cta, ${N_PER_CTA}; - mul.lo.u32 warp_n_base, warp, ${N_PER_WARP}; +% endif + mul.lo.u32 cta, cta, ${n_per_cta}; + mul.lo.u32 warp_n_base, warp, ${n_per_warp}; add.u32 warp_n_base, warp_n_base, cta; } setp.ge.u32 pwarp_exit, warp_n_base, ${n}; +% if bs: + @pwarp_exit bra $L_STEAL; +% else: @pwarp_exit bra $L_EXIT; +% endif -% for nt in range(NN): +% for nt in range(nn): add.u32 b_col_${nt}, warp_n_base, ${nt * 8}; add.u32 b_col_${nt}, b_col_${nt}, r_div4; { @@ -151,19 +166,13 @@ C_NTILE_STRIDE = 8 * 8 add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } +% if not n_col_aligned: setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endif % endfor - { - .reg .u64 t64, a_smem_base, lane64; - mov.u64 a_smem_base, ${kname}_As; - cvt.u64.u32 lane64, lane; - shl.b64 t64, lane64, 3; - add.u64 as_thr_base, a_smem_base, t64; - } - { .reg .u64 t64, bcol64; mul.wide.u32 t64, r_mod4, ${ldb}; @@ -182,60 +191,60 @@ 
C_NTILE_STRIDE = 8 * 8 add.u64 c_thr_base, c_ptr, t64; } -% for mt in range(M_TILES): - .reg .pred pm_${mt}; - { - .reg .u32 crow; - add.u32 crow, r_div4, ${mt * 8}; - setp.lt.u32 pm_${mt}, crow, ${M}; - } -% endfor - -% for nt in range(NN): -% for mt in range(M_TILES): +% for nt in range(nn): +% for mt in range(m_tiles): % if beta == 0: mov.f64 c0_${nt}_${mt}, 0d0000000000000000; mov.f64 c1_${nt}_${mt}, 0d0000000000000000; % else: +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None + needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None +%> { .reg .u64 caddr; - .reg .pred p0, p1; - add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; - and.pred p0, pm_${mt}, pvalid_c0col_${nt}; - and.pred p1, pm_${mt}, pvalid_c1col_${nt}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; +% if needs_zero_init: mov.f64 c0_${nt}_${mt}, 0d0000000000000000; mov.f64 c1_${nt}_${mt}, 0d0000000000000000; - @p0 ld.global.f64 c0_${nt}_${mt}, [caddr]; - @p1 ld.global.f64 c1_${nt}_${mt}, [caddr + 8]; +% endif + ${pred_emit(f'ld.global.f64 c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} + ${pred_emit(f'ld.global.f64 c1_{nt}_{mt}, [caddr + 8];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} } % endif % endfor % endfor -% for ki in range(K_ITERS): -% for nt in range(NN): +% for ki in range(k_iters): +% for nt in range(nn): +<% + pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None + k_tail = (k_rem != 0 and ki == k_iters - 1) + needs_zero = pvb is not None or k_tail + pbrow = 'pbrow' if k_tail else None +%> { .reg .u64 baddr; - .reg .pred pb_load; - add.u64 baddr, b_thr_base, ${ki * B_KITER_STRIDE + nt * B_NTILE_STRIDE}; -% if K_REM != 0 and ki == K_ITERS - 1: + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.f64 b_frag_${nt}, 0d0000000000000000; +% endif +% 
if k_tail: + .reg .pred pbrow; { .reg .u32 brow; - .reg .pred pbrow; add.u32 brow, r_mod4, ${ki * 4}; setp.lt.u32 pbrow, brow, ${k}; - and.pred pb_load, pbrow, pvalid_bcol_${nt}; } -% else: - and.pred pb_load, pvalid_bcol_${nt}, pvalid_bcol_${nt}; % endif - mov.f64 b_frag_${nt}, 0d0000000000000000; - @pb_load ld.global.nc.f64 b_frag_${nt}, [baddr]; + ${pred_emit(f'ld.global.nc.f64 b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} } % endfor -% for mt in range(M_TILES): - ld.shared.f64 a_frag, [as_thr_base + ${(mt * K_ITERS + ki) * FRAG_STRIDE_BYTES}]; -% for nt in range(NN): +% for mt in range(m_tiles): + ld.shared.f64 a_frag, [as_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; +% for nt in range(nn): mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {c0_${nt}_${mt}, c1_${nt}_${mt}}, {a_frag}, @@ -245,20 +254,52 @@ C_NTILE_STRIDE = 8 * 8 % endfor % endfor -% for nt in range(NN): -% for mt in range(M_TILES): +% for nt in range(nn): +% for mt in range(m_tiles): +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None +%> { .reg .u64 caddr; - .reg .pred p0, p1; - add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; - and.pred p0, pm_${mt}, pvalid_c0col_${nt}; - and.pred p1, pm_${mt}, pvalid_c1col_${nt}; - @p0 st.global.f64 [caddr], c0_${nt}_${mt}; - @p1 st.global.f64 [caddr + 8], c1_${nt}_${mt}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.global.f64 [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} + ${pred_emit(f'st.global.f64 [caddr + 8], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} } % endfor % endfor +% if bs: +$L_STEAL: + // Root issues async try_cancel + waits; bar.sync orders the workid load + @!p_root bra $L_AFTER_WAIT; + { + .reg .u64 state; + mbarrier.arrive.expect_tx.shared::cta.b64 state, [mbar_a], 16; + 
clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 [work_a], [mbar_a]; +$L_WAIT: + mbarrier.try_wait.shared::cta.b64 p_done, [mbar_a], state, 10000000; + @!p_done bra $L_WAIT; + } +$L_AFTER_WAIT: + bar.sync 0; + + { + .reg .b128 resp; + ld.shared::cta.b128 resp, [work_a]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_have, resp; + @!p_have bra $L_FIN; + // 1D grid: extract just x + clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 ctaid, resp; + } + bra.uni $L_LOOP; + +$L_FIN: + bar.sync 0; + @p_root mbarrier.inval.shared::cta.b64 [mbar_a]; +% endif + $L_EXIT: ret; } diff --git a/gimmik/ptx.py b/gimmik/ptx.py index dd3b259..d2c6894 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -1,145 +1,172 @@ # -*- coding: utf-8 -*- +import struct + import numpy as np from gimmik.base import MatMul -class PTXSource: - def __init__(self): - self._src = "" - - def __iadd__(self, other): - self._src = f"{self}\n\t{other}" - return self - - def __str__(self): - return self._src - - def __repr__(self): - return self._src - - class PTXMatMul(MatMul): platform = 'ptx' basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, 'dynamic_shared': 0} - def _address(self, out, base, size, *offs): - src = PTXSource() - out_type = out[1] - if out_type != base[1]: - raise RuntimeError("out and base must have the same type") - - if offs: - off_type = offs[0][1] - if not all(off[1] == off_type for off in offs): - raise RuntimeError("offsets must all have the same tpye") - - if len(offs) == 1: - off = offs[0] - mad_type = "lo" if out_type == off_type else "wide" - src += f"mad.{mad_type}.{off_type} {out[0]}, {size}, {off[0]}, {base[0]};" - else: - src += f".reg .{off_type} _addrs_acum;" - src += f"add.{off_type} _addrs_acum, {offs[0][0]}, {offs[1][0]};" - for off in offs[2:]: - src += f"add.{off_type} _addrs_acum, _addrs_acum, {off[0]};" - mad_type = "lo" if out_type == off_type else "wide" - src += f"mad.{mad_type}.{off_type} {out[0]}, 
{size}, _addrs_acum, {base[0]};" - else: - src += f"mov.{out_type} {out[0]}, {base[0]};" - return f"{{{src}\n\t}}" - - def _kernel_generators(self, dtype, dsize, *, compute_capability=None): - base_args = {'address': lambda o, b, s, *off: self._address(o, b, s, - *off), 'cc': compute_capability} - - # Matrix-property gates + def _kernel_generators(self, dtype, dsize, *, compute_capability=None, + trim_a=False): + base_args = {'cc': compute_capability, + 'pred_emit': self._pred_emit, + 'trim_a': bool(trim_a) and dtype == 'double'} + + yield from self._sparse_kernel_generators(dtype, dsize, base_args) + yield from self._dense_kernel_generators(dtype, dsize, base_args) + + def _sparse_kernel_generators(self, dtype, dsize, base_args): arr = self.A nnz = int(np.count_nonzero(arr)) nuq = int(len(np.unique(np.abs(arr)))) density = nnz / arr.size - sparse_suitable = (nuq <= 28) or (density <= 0.15) - - cc = compute_capability or (0, 0) - dense_suitable = ( - dtype == 'double' - and cc >= (9, 0) - and self.n is not None - and self.m <= 128 - and self.k <= 128 - ) + if not ((nuq <= 28) or (density <= 0.15)): + return - if sparse_suitable: - yield ('cstream', base_args | {}, {}) + # B loading, C streaming kernel + yield ('cstream', base_args | {}, {'desc': 'cstream'}) - yield ('bstream', base_args | {}, {}) + # B streaming, C accumulation kernel + yield ('bstream', base_args | {}, {'desc': 'bstream'}) - ms, bsz, blkx = 4, 24, 32 - args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize} - yield ('bstream-msplit', args, meta) + # Four-way m-split B streaming, C accumulation kernel + ms, bsz, blkx = 4, 24, 32 + args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} + meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'} + yield ('bstream-msplit', args, meta) - ms, bsz, blkx = 1, 16, 128 + # Single-warp LDGSTS variant for medium-M beta=0 large-K cases 
+ if self.beta == 0 and self.m <= 320 and len(self.bix) >= 64: + ms, bsz, blkx = 1, 32, 64 args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize} + meta = {'block': (blkx, ms, 1), + 'shared': 2*bsz*blkx*dsize, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'} yield ('bstream-msplit', args, meta) - ks, csz, blkx = 2, 24, 32 + # Two-way k-split B loading, C streaming kernel + ks, csz, blkx = 2, 24, 32 + args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} + meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'} + yield ('cstream-ksplit', args, meta) + + # Four-way k-split for large K + K_used = len(self.bix) + if K_used > 500: + ks, csz, blkx = 4, 20, 32 args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} - meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize} + meta = {'block': (blkx, ks, 1), + 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'} yield ('cstream-ksplit', args, meta) - K_used = len(self.bix) - if K_used > 500: - ks, csz, blkx = 4, 20, 32 - args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} - meta = {'block': (blkx, ks, 1), - 'shared': (ks - 1)*csz*blkx*dsize} - yield ('cstream-ksplit', args, meta) - - if (dtype == 'double' and self.n is not None and self.n % 2 == 0 - and K_used <= 100 - and (self.aligne is None or self.aligne % 2 == 0)): - blkx = 128 - args = base_args | {'blockx': blkx} - meta = {'block': (blkx, 1, 1), 'width': 2} - yield ('cstream-w2', args, meta) - - if dense_suitable: - # Dense DMMA m8n8k4 templates. Yields a small cover of the nn × w - # space that empirically spans the autotune winners seen on tet - # p=3,4 at N=500k. The PyFR wrapper's _benchmark picks the fastest. 
- for tpl in ('dense-mma-smem-gA', 'dense-mma-gAd'): - for nn in (1, 2, 4): - for w in (2, 4, 8): - blkx = 32 * w - n_per_cta = 8 * nn * w - if n_per_cta > self.n: - continue - args = base_args | {'warps_per_cta': w, 'nn': nn} - meta = { - 'block': (blkx, 1, 1), - 'grid': (-(-self.n // n_per_cta), 1, 1), - } - yield (tpl, args, meta) - - # Extra fine-grained nn for shapes where a specific nn usually - # wins (p3/tet/m132, p4/tet/m132). - for tpl in ('dense-mma-smem-gA', 'dense-mma-gAd'): - for nn in (6,): - for w in (1, 4): - blkx = 32 * w - n_per_cta = 8 * nn * w - if n_per_cta > self.n: - continue - args = base_args | {'warps_per_cta': w, 'nn': nn} - meta = { - 'block': (blkx, 1, 1), - 'grid': (-(-self.n // n_per_cta), 1, 1), - } - yield (tpl, args, meta) + # Width-2 vector cstream for fp64 small-K + if (dtype == 'double' and self.n is not None and self.n % 2 == 0 + and K_used <= 100 + and (self.aligne is None or self.aligne % 2 == 0)): + blkx = 128 + args = base_args | {'blockx': blkx} + meta = {'block': (blkx, 1, 1), 'width': 2, + 'desc': f'cstream-w2/x{blkx}'} + yield ('cstream-w2', args, meta) + + def _dense_kernel_generators(self, dtype, dsize, base_args): + cc = base_args['cc'] or (0, 0) + if not (dtype == 'double' and cc >= (9, 0) and self.n is not None + and self.m <= 128 and self.k <= 128): + return + + # Dense DMMA m8n8k4; block stealing default on sm_100+ for gA + bs_default = cc >= (10, 0) + dense_configs = [ + ('dense-mma-smem-gA', 1, 8), + ('dense-mma-smem-gA', 2, 4), + ('dense-mma-smem-gA', 4, 4), + ('dense-mma-gAd', 2, 2), + ('dense-mma-gAd', 4, 2), + ] + for tpl, nn, w in dense_configs: + blkx = 32 * w + n_per_cta = 8 * nn * w + if n_per_cta > self.n: + continue + bs = (tpl == 'dense-mma-smem-gA') and bs_default + setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) + args = (base_args | {'warps_per_cta': w, 'nn': nn, + 'block_stealing': bs} | setup) + meta = { + 'block': (blkx, 1, 1), + 'grid': (-(-self.n // n_per_cta), 1, 1), + 'desc': 
f'{tpl}/nn{nn}-w{w}{"-bs" if bs else ""}', + } + yield (tpl, args, meta) + + def _dense_mma_setup(self, *, nn, warps_per_cta): + a = self.A + m, k = a.shape + m_tiles = -(-m // 8) + k_rem = k % 4 + k_iters = (k + (4 - k_rem if k_rem else 0)) // 4 + + # A in fragment layout: lane l -> A[m_tile*8 + l/4][k_iter*4 + l%4] + a_u64 = [] + for m_tile in range(m_tiles): + for k_iter in range(k_iters): + for lane in range(32): + i = m_tile * 8 + lane // 4 + j = k_iter * 4 + lane % 4 + v = float(a[i, j]) if (i < m and j < k) else 0.0 + u = struct.unpack(' m + + return { + 'm_tiles': m_tiles, + 'k_rem': k_rem, 'k_iters': k_iters, + 'a_u64': a_u64, + 'n_per_warp': n_per_warp, 'n_per_cta': n_per_cta, + 'a_elems': a_elems, + 'frag_stride_bytes': 32 * 8, + 'b_kiter_stride': 4 * (self.ldb or 0) * 8, + 'b_ntile_stride': 8 * 8, + 'c_mtile_stride': 8 * (self.ldc or 0) * 8, + 'c_ntile_stride': 8 * 8, + 'n_col_aligned': n_col_aligned, + 'pm_runtime': pm_runtime, + } + + @staticmethod + def _pred_emit(instr, *preds, pred_reg=None, indent=' ' * 8): + actual = [p for p in preds if p is not None] + if not actual: + return instr + if len(actual) == 1: + return f'@{actual[0]} {instr}' + if pred_reg is None: + raise ValueError('pred_reg required when combining multiple ' + 'predicates') + lines = [f'.reg .pred {pred_reg};', + f'and.pred {pred_reg}, {actual[0]}, {actual[1]};'] + for p in actual[2:]: + lines.append(f'and.pred {pred_reg}, {pred_reg}, {p};') + lines.append(f'@{pred_reg} {instr}') + return f'\n{indent}'.join(lines) def _process_meta(self, meta): if self.n is not None and 'grid' not in meta: From 393b4095a3985d6de32551b5d6daa0de4cd312c4 Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Mon, 11 May 2026 05:35:20 -0700 Subject: [PATCH 04/21] Added warp specialised dense kernel --- gimmik/kernels/ptx/dense-mma-ws.mako | 422 +++++++++++++++++++++++++++ gimmik/ptx.py | 120 +++++++- 2 files changed, 535 insertions(+), 7 deletions(-) create mode 100644 gimmik/kernels/ptx/dense-mma-ws.mako 
diff --git a/gimmik/kernels/ptx/dense-mma-ws.mako b/gimmik/kernels/ptx/dense-mma-ws.mako new file mode 100644 index 0000000..514837f --- /dev/null +++ b/gimmik/kernels/ptx/dense-mma-ws.mako @@ -0,0 +1,422 @@ +<%inherit file='base'/> +<% +assert dtype == "double" +assert n is not None and ldb is not None and ldc is not None +mbar_maxwait = '0x989680' +%> + +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { + ${', '.join(a_u64)} +}; +.extern .shared .align 128 .b8 ${kname}_dynm[]; +.const .align 64 .b8 ${kname}_bdesc[128]; +.const .align 64 .b8 ${kname}_cdesc[128]; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +.maxntid ${blockx_total}, 1, 1 +{ + .reg .b32 tid, warp, lane, phase, ctaid_x; + .reg .b32 base_brow, base_bcol, base_crow, base_ccol; + .reg .b32 work, block_idx_x, n_start_curr, n_start_next; + .reg .u64 bdesc_addr, cdesc_addr; + .reg .b32 a_smem, b1_smem, b2_smem, c_smem; + .reg .b32 tma_mbar, wid_new_mbar, bready_mbar, cready_mbar, cstored_mbar, steal_mbar; + .reg .b32 wid_used_mbar, wid_smem; + .reg .pred p_compute, p_prod, p_steal; + .reg .pred p_warp_lead; + .reg .pred p_done; + .reg .pred p_tid0; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + mov.u32 ctaid_x, %ctaid.x; + + .reg .b32 dynm_base; + mov.u32 dynm_base, ${kname}_dynm; + add.u32 b1_smem, dynm_base, ${b1_off}; + add.u32 b2_smem, dynm_base, ${b2_off}; + add.u32 c_smem, dynm_base, ${c_off}; + add.u32 a_smem, dynm_base, ${a_off}; + add.u32 wid_smem, dynm_base, ${wid_off}; + + add.u32 tma_mbar, dynm_base, ${tma_mbar_off}; + add.u32 bready_mbar, dynm_base, ${bready_mbar_off}; + add.u32 cready_mbar, dynm_base, ${cready_mbar_off}; + add.u32 cstored_mbar, dynm_base, ${cstored_mbar_off}; + add.u32 steal_mbar, dynm_base, ${steal_mbar_off}; + add.u32 wid_new_mbar, dynm_base, ${wid_new_mbar_off}; + add.u32 wid_used_mbar, dynm_base, ${wid_used_mbar_off}; + + cvta.const.u64 bdesc_addr, ${kname}_bdesc; + cvta.const.u64 cdesc_addr, ${kname}_cdesc; + + setp.eq.u32 
p_tid0, tid, 0; + + setp.lt.u32 p_compute, warp, ${n_comp_warps}; + setp.eq.u32 p_prod, warp, ${prod_warp}; + setp.eq.u32 p_steal, warp, ${steal_warp}; + + { + .reg .b32 _elect_lane; + elect.sync _elect_lane|p_warp_lead, 0xffffffff; + } + + // mbarrier init (tid 0 only); pre-arrive csmem_free so compute iter 0 + // can write csmem immediately. + { + .reg .pred p_init; + setp.eq.u32 p_init, tid, 0; + .reg .b64 _state; + @p_init mbarrier.init.shared::cta.b64 [tma_mbar], 32; + @p_init mbarrier.init.shared::cta.b64 [bready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cstored_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [steal_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [wid_used_mbar], ${n_comp_warps + 1}; + @p_init mbarrier.init.shared::cta.b64 [wid_new_mbar], 1; + @p_init mbarrier.arrive.shared::cta.b64 _state, [cstored_mbar]; + @p_init fence.proxy.async.shared::cta; + } + bar.sync 0; + + // Cooperative copy A: .global -> a_smem (ld.global.nc.v2.f64) + { + .reg .u64 a_glb_base; + .reg .b32 pidx; + .reg .f64 av0, av1; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, a_glb_base; +% for ci in range(copy_v2_iters): +<% + base_pair = ci * blockx_total + is_last = ci == copy_v2_iters - 1 + pairs_this = min(blockx_total, a_pairs - base_pair) + needs_guard = is_last and pairs_this < blockx_total +%> + { + .reg .u64 ofs64, gaddr; + .reg .b32 saddr; + add.u32 pidx, tid, ${base_pair}; +% if needs_guard: + .reg .pred p_load; + setp.lt.u32 p_load, pidx, ${a_pairs}; +% endif + mul.wide.u32 ofs64, pidx, 16; + add.u64 gaddr, a_glb_base, ofs64; + cvt.u32.u64 saddr, ofs64; + add.u32 saddr, saddr, a_smem; +% if needs_guard: + @p_load ld.global.nc.v2.f64 {av0, av1}, [gaddr]; + @p_load st.shared.v2.f64 [saddr], {av0, av1}; +% else: + ld.global.nc.v2.f64 {av0, av1}, [gaddr]; + st.shared.v2.f64 [saddr], {av0, av1}; +% endif + } +% endfor +% if a_pairs_tail: + { + .reg .pred p_tail; + .reg .u64 
gaddr; + .reg .b32 saddr; + .reg .f64 v; + setp.eq.u32 p_tail, tid, 0; + add.u64 gaddr, a_glb_base, ${(a_elems - 1) * 8}; + mov.u32 saddr, ${(a_elems - 1) * 8}; + add.u32 saddr, saddr, a_smem; + @p_tail ld.global.nc.f64 v, [gaddr]; + @p_tail st.shared.f64 [saddr], v; + } +% endif + } + bar.sync 0; + + // Compute-warp lane geometry (cheap; all warps execute uniformly) + { + .reg .b32 t, w_n_base; + and.b32 base_brow, lane, 3; + shr.u32 base_crow, lane, 2; + mul.lo.u32 w_n_base, warp, ${n_per_warp}; + add.u32 base_bcol, base_crow, w_n_base; + shl.b32 t, base_brow, 1; + add.u32 base_ccol, t, w_n_base; + } + + // Producer warp: initial B load for ctaid_x's work + @!p_prod bra.uni $L_AFTER_INIT_B; + { + .reg .b32 n_start0; + mul.lo.u32 n_start0, ctaid_x, ${n_per_cta}; + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b1_smem], [bdesc_addr, {n_start0, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes}; + bar.warp.sync 0xffffffff; + .reg .b64 state; + .reg .pred p1; + mbarrier.arrive.shared::cta.b64 state, [tma_mbar]; +$L_TMA_INIT_W: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], state, ${mbar_maxwait}; + @!p1 bra.uni $L_TMA_INIT_W; + .reg .b64 _state2; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state2, [bready_mbar]; + } +$L_AFTER_INIT_B: + + mov.u32 block_idx_x, ctaid_x; + mov.u32 work, 1; + mov.u32 phase, 0; + +$L_LOOP: + setp.eq.u32 p_done, work, 0; + @p_done bra.uni $L_EXIT; + + mul.lo.u32 n_start_curr, block_idx_x, ${n_per_cta}; + + // --- Compute Warps + @!p_compute bra.uni $L_AFTER_COMPUTE; + + // Wait on B + { + .reg .pred p1; +$L_WAIT_BRDY: + mbarrier.try_wait.parity.shared::cta.b64 p1, [bready_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_BRDY; + } + + // MMA + { + .reg .b32 b_sm_a; + .reg .pred p_ph; + setp.ne.u32 p_ph, phase, 0; + selp.b32 b_sm_a, b2_smem, b1_smem, p_ph; + + .reg .b32 a_thr_a; + { + .reg .b32 t; + shl.b32 t, lane, 3; + 
add.u32 a_thr_a, a_smem, t; + } +% for nt in range(nn): + .reg .b32 b_thr_a_${nt}; + { + .reg .b32 bcol_g, t_off; + add.u32 bcol_g, base_bcol, ${8 * nt}; + shl.b32 t_off, bcol_g, 3; + add.u32 b_thr_a_${nt}, b_sm_a, t_off; + } +% endfor + + .reg .b32 c_thr_smem; + { + .reg .b32 t1, ccol_b; + mul.lo.u32 t1, base_crow, ${n_per_cta * 8}; + shl.b32 ccol_b, base_ccol, 3; + add.u32 c_thr_smem, c_smem, t1; + add.u32 c_thr_smem, c_thr_smem, ccol_b; + } + + // Zero accumulators +% for mt in range(m_tiles): +% for nt in range(nn): + .reg .f64 d_x_${mt}_${nt}, d_y_${mt}_${nt}; + mov.f64 d_x_${mt}_${nt}, 0d0000000000000000; + mov.f64 d_y_${mt}_${nt}, 0d0000000000000000; +% endfor +% endfor + + .reg .f64 a_f; +% for mt in range(m_tiles): +% for kt in range(k_iters): +<% + k_tail = (k_rem != 0 and kt == k_iters - 1) +%> + { + .reg .b32 a_a; + add.u32 a_a, a_thr_a, ${(kt * 32 + mt * 32 * k_iters) * 8}; + ld.shared.f64 a_f, [a_a]; +% if k_tail: + .reg .pred pbrow_${mt}_${kt}; + { + .reg .b32 brow; + add.u32 brow, base_brow, ${4 * kt}; + setp.lt.u32 pbrow_${mt}_${kt}, brow, ${k}; + } +% endif +% for nt in range(nn): + { + .reg .b32 b_a, b_row; + .reg .f64 b_f; + add.u32 b_row, base_brow, ${4 * kt}; + mul.lo.u32 b_row, b_row, ${n_per_cta * 8}; + add.u32 b_a, b_thr_a_${nt}, b_row; +% if k_tail: + mov.f64 b_f, 0d0000000000000000; + @pbrow_${mt}_${kt} ld.shared.f64 b_f, [b_a]; +% else: + ld.shared.f64 b_f, [b_a]; +% endif + mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 + {d_x_${mt}_${nt}, d_y_${mt}_${nt}}, {a_f}, {b_f}, + {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor + } +% endfor +% endfor + + // Wait until producer's prev-iter TMA-store of C has drained. + { + .reg .pred p1; +$L_WAIT_CSTORE: + mbarrier.try_wait.parity.shared::cta.b64 p1, [cstored_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_CSTORE; + } + + // Vector-store {d_x, d_y} pairs to csmem. M-tail / N-tail OOB rows + // are dropped by the C tensor map. 
+% for mt in range(m_tiles): +% for nt in range(nn): + { + .reg .b32 csaddr; + add.u32 csaddr, c_thr_smem, ${mt * c_mtile_smem_stride + nt * c_ntile_smem_stride}; + st.shared.v2.f64 [csaddr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor +% endfor + + bar.sync 1, ${comp_threads}; + fence.proxy.async.shared::cta; + { + .reg .b64 _state; + @p_tid0 mbarrier.arrive.shared::cta.b64 _state, [cready_mbar]; + } + + // Wait for new work and unpack + { + .reg .pred p1, p_canc; + .reg .b128 resp; +$L_WAIT_WNEW_C: + mbarrier.try_wait.parity.shared::cta.b64 p1, [wid_new_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_WNEW_C; + + ld.shared::cta.b128 resp, [wid_smem]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + @p_canc clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 block_idx_x, resp; + selp.b32 work, 1, 0, p_canc; + + .reg .b64 _state; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_used_mbar]; + } + } +$L_AFTER_COMPUTE: + + // --- Data Movement Warp + @!p_prod bra.uni $L_AFTER_DATA; + { + .reg .b32 n_c_store; + mul.lo.u32 n_c_store, block_idx_x, ${n_per_cta}; + + // Wait for new work and unpack + { + .reg .pred p1, p_canc; + .reg .b128 resp; +$L_WAIT_WNEW_D: + mbarrier.try_wait.parity.shared::cta.b64 p1, [wid_new_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_WNEW_D; + + ld.shared::cta.b128 resp, [wid_smem]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + @p_canc clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 block_idx_x, resp; + selp.b32 work, 1, 0, p_canc; + .reg .b64 _state; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_used_mbar]; + } + + // TMA loads of next B + { + mul.lo.u32 n_start_next, block_idx_x, ${n_per_cta}; + .reg .b32 b_next; + .reg .pred p_ph; + setp.ne.u32 p_ph, phase, 0; + selp.b32 b_next, b1_smem, b2_smem, p_ph; + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b_next], [bdesc_addr, 
{n_start_next, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes}; + @p_warp_lead cp.async.bulk.commit_group; + } + bar.warp.sync 0xffffffff; + + // TMA store/reduce+store of a C + { + .reg .pred p1; + .reg .b64 _c_state; +$L_WAIT_CRDY: + mbarrier.try_wait.parity.shared::cta.b64 p1, [cready_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_CRDY; +% if beta == 0: + @p_warp_lead cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group + [cdesc_addr, {n_c_store, 0}], [c_smem]; +% else: + @p_warp_lead cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.tile.bulk_group + [cdesc_addr, {n_c_store, 0}], [c_smem]; +% endif + @p_warp_lead cp.async.bulk.commit_group; + @p_warp_lead cp.async.bulk.wait_group 0; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _c_state, [cstored_mbar]; + } + + // Wait for next B to be ready, then signal B and C ready + { + .reg .b64 b_state, _bready_state, _c_state; + .reg .pred p1; + mbarrier.arrive.shared::cta.b64 b_state, [tma_mbar]; +$L_WAIT_TMA: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], b_state, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_TMA; + + @p_warp_lead mbarrier.arrive.shared::cta.b64 _bready_state, [bready_mbar]; + } + } +$L_AFTER_DATA: + + // --- Controller Warp + @!p_steal bra.uni $L_AFTER_CTRL; + { + .reg .pred p1, p2, p_canc; + .reg .b64 _state; + .reg .b128 resp; + @p_warp_lead fence.proxy.async.shared::cta; + @p_warp_lead clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 + [wid_smem], [steal_mbar]; + @p_warp_lead mbarrier.arrive.expect_tx.shared::cta.b64 + _state, [steal_mbar], 16; + +$L_WAIT_STEAL: + mbarrier.try_wait.parity.shared::cta.b64 p1, [steal_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_STEAL; + + // Signal new work + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_new_mbar]; + + // Query if there's new work + ld.shared::cta.b128 resp, [wid_smem]; + 
clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + selp.b32 work, 1, 0, p_canc; + + // Wait for old work to be used +$L_WAIT_WUSED: + mbarrier.try_wait.parity.shared::cta.b64 p2, [wid_used_mbar], phase, ${mbar_maxwait}; + @!p2 bra.uni $L_WAIT_WUSED; + } +$L_AFTER_CTRL: + + xor.b32 phase, phase, 1; + bra.uni $L_LOOP; + +$L_EXIT: + ret; +} diff --git a/gimmik/ptx.py b/gimmik/ptx.py index d2c6894..bf32a62 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -12,6 +12,28 @@ class PTXMatMul(MatMul): basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, 'dynamic_shared': 0} + @staticmethod + def is_sparse_suitable(arr): + nnz = int(np.count_nonzero(arr)) + nuq = int(len(np.unique(np.abs(arr)))) + density = nnz / arr.size + return (nuq <= 28) or (density <= 0.15) + + @staticmethod + def is_dense_suitable(arr, dtype, cc): + """True if A's shape and the target arch support the dense DMMA + template family. Does NOT check runtime args (n, ldb, ldc); those + are validated when the generator runs.""" + return (np.dtype(dtype) == np.float64 + and cc is not None and cc >= (9, 0) + and arr.shape[0] <= 128 and arr.shape[1] <= 128) + + @classmethod + def is_suitable(cls, arr, dtype, cc): + """True if either sparse or dense templates are applicable.""" + return (cls.is_sparse_suitable(arr) + or cls.is_dense_suitable(arr, dtype, cc)) + def _kernel_generators(self, dtype, dsize, *, compute_capability=None, trim_a=False): base_args = {'cc': compute_capability, @@ -22,11 +44,7 @@ def _kernel_generators(self, dtype, dsize, *, compute_capability=None, yield from self._dense_kernel_generators(dtype, dsize, base_args) def _sparse_kernel_generators(self, dtype, dsize, base_args): - arr = self.A - nnz = int(np.count_nonzero(arr)) - nuq = int(len(np.unique(np.abs(arr)))) - density = nnz / arr.size - if not ((nuq <= 28) or (density <= 0.15)): + if not self.is_sparse_suitable(self.A): return # B loading, C streaming kernel @@ -80,8 +98,8 @@ def 
_sparse_kernel_generators(self, dtype, dsize, base_args): def _dense_kernel_generators(self, dtype, dsize, base_args): cc = base_args['cc'] or (0, 0) - if not (dtype == 'double' and cc >= (9, 0) and self.n is not None - and self.m <= 128 and self.k <= 128): + if not (self.is_dense_suitable(self.A, dtype, cc) + and self.n is not None): return # Dense DMMA m8n8k4; block stealing default on sm_100+ for gA @@ -109,6 +127,94 @@ def _dense_kernel_generators(self, dtype, dsize, base_args): } yield (tpl, args, meta) + # Warp-specialised dense DMMA with TMA B-load + TMA C-store. + if cc >= (10, 0): + yield from self._dense_ws_kernel_generators(dtype, dsize, base_args) + + def _dense_ws_kernel_generators(self, dtype, dsize, base_args): + m_pad = -(-self.m // 8) * 8 + k_pad = -(-self.k // 4) * 4 + # (nn, w_compute) -- block has w_compute + 2 warps (producer, stealer) + ws_configs = [(1, 4), (2, 4), (4, 4)] + for nn, w in ws_configs: + n_per_cta = 8 * nn * w + if n_per_cta > self.n: + continue + blkx = 32 * (w + 2) + setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) + ws_layout = self._dense_ws_layout( + n_comp_warps=w, n_per_cta=n_per_cta, + m_pad=m_pad, k_pad=k_pad, a_elems=setup['a_elems'] + ) + # sm_100 supports up to 228 KiB shared per CTA with the + # set_shared_size opt-in. Reserve some headroom for L1 carveout. 
+ if ws_layout['dynm_total_bytes'] > 200 * 1024: + continue + args = (base_args + | {'warps_per_cta': w, 'nn': nn} + | setup | ws_layout) + yield ('dense-mma-ws', args, { + 'block': (blkx, 1, 1), + 'grid': (-(-self.n // n_per_cta), 1, 1), + 'desc': f'dense-mma-ws/nn{nn}-w{w}', + 'ws_tensor_map': True, + 'ws_n_per_cta': n_per_cta, + 'ws_k_pad': k_pad, + 'ws_m_pad': m_pad, + 'dynamic_shared': ws_layout['dynm_total_bytes'], + }) + + @staticmethod + def _dense_ws_layout(*, n_comp_warps, n_per_cta, m_pad, k_pad, a_elems): + """Render-time constants for the dense-mma-ws template: warp roles, + cooperative-copy iteration counts, smem-tile sizes, mbar timeout, + and dynamic-shared byte offsets for each buffer.""" + n_total_warps = n_comp_warps + 2 + blockx_total = 32 * n_total_warps + a_pairs = a_elems // 2 + a_pairs_tail = a_elems % 2 + + b_tile_bytes = k_pad * n_per_cta * 8 + c_tile_bytes = m_pad * n_per_cta * 8 + a_bytes = a_elems * 8 + + smem_size = {'b1': b_tile_bytes, 'b2': b_tile_bytes, 'c': c_tile_bytes, + 'a': a_bytes, 'wid': 16} + smem_off, off = {}, 0 + for k, v in smem_size.items(): + off = (off + 15) & ~15 + smem_off[f'{k}_off'] = off + off += v + + mbar_names = ('tma', 'bready', 'cready', 'cstored', + 'steal', 'wid_new', 'wid_used') + for k in mbar_names: + smem_off[f'{k}_mbar_off'] = off + off += 8 + + # Pad total to 16-byte multiple + dynm_total_bytes = (off + 15) & ~15 + + params = {'n_comp_warps': n_comp_warps, + 'blockx_total': blockx_total, + 'prod_warp': n_comp_warps, + 'steal_warp': n_comp_warps + 1, + 'comp_threads': 32 * n_comp_warps, + 'a_pairs': a_pairs, + 'a_pairs_tail': a_pairs_tail, + 'copy_v2_iters': -(-a_pairs // blockx_total), + 'm_pad': m_pad, + 'k_pad': k_pad, + 'b_tile_doubles': k_pad * n_per_cta, + 'b_tile_bytes': b_tile_bytes, + 'c_tile_doubles': m_pad * n_per_cta, + 'c_mtile_smem_stride': 8 * n_per_cta * 8, + 'c_ntile_smem_stride': 8 * 8, + 'dynm_total_bytes': dynm_total_bytes, + } + params |= smem_off + return params + def 
_dense_mma_setup(self, *, nn, warps_per_cta): a = self.A m, k = a.shape From 67d1bebd516e29b7d3b70460919056a6e534ab0a Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Wed, 13 May 2026 09:27:38 -0700 Subject: [PATCH 05/21] Performance tuning and cleanup --- gimmik/kernels/ptx/bstream-msplit.mako | 12 +- gimmik/kernels/ptx/bstream.mako | 38 ++- gimmik/kernels/ptx/cstream-ksplit.mako | 6 +- gimmik/kernels/ptx/cstream-w2.mako | 36 ++- gimmik/kernels/ptx/cstream.mako | 36 +-- gimmik/kernels/ptx/dense-mma-gAd.mako | 13 +- gimmik/kernels/ptx/dense-mma-smem-gA.mako | 4 +- gimmik/kernels/ptx/dense-mma-ws.mako | 334 ++++++++++++---------- gimmik/ptx.py | 58 ++-- 9 files changed, 271 insertions(+), 266 deletions(-) diff --git a/gimmik/kernels/ptx/bstream-msplit.mako b/gimmik/kernels/ptx/bstream-msplit.mako index 77e7ce7..0af5091 100644 --- a/gimmik/kernels/ptx/bstream-msplit.mako +++ b/gimmik/kernels/ptx/bstream-msplit.mako @@ -1,14 +1,13 @@ <%inherit file='base'/> <% -pftype = "f32" if dtype == "float" else "f64" -dwidth_i = 4 if dtype == "float" else 8 -fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +pftype = 'f32' if dtype == 'float' else 'f64' +dwidth_i = 4 if dtype == 'float' else 8 +fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' has_zero_rows = any(jx == -1 for jx in afix) mx = partition(A, into=msplit, by='rows') bix_list = list(bix) bchunks = chunk(bix_list, bsz) -nchunks = len(bchunks) m_per_group = max(len(mcx) for mcx in mx) bsub_bytes = 2 * bsz * blockx * dwidth_i def bsub_off(buf, idx): @@ -135,11 +134,11 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) % endif ## Main loop over B-chunks (double-buffered) -% for bb in range(nchunks): +% for bb in range(len(bchunks)): <% buf_cur = bb % 2 buf_next = (bb + 1) % 2 - is_last = (bb == nchunks - 1) + is_last = (bb == len(bchunks) - 1) %> % if not is_last: % for idx, kx in enumerate(bchunks[bb + 1]): @@ -232,6 +231,7 @@ use_cpasync = cc is not None 
and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) % endif bar.sync 0; % endfor +## End of Main loop over B-chunks ## Handle zero rows in this cid's group % if has_zero_rows: diff --git a/gimmik/kernels/ptx/bstream.mako b/gimmik/kernels/ptx/bstream.mako index 465eac3..f58e9b3 100644 --- a/gimmik/kernels/ptx/bstream.mako +++ b/gimmik/kernels/ptx/bstream.mako @@ -1,14 +1,12 @@ <%inherit file='base'/> <% -pftype = "f32" if dtype == "float" else "f64" -dwidth_i = 4 if dtype == "float" else 8 -fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +pftype = 'f32' if dtype == 'float' else 'f64' +dwidth_i = 4 if dtype == 'float' else 8 +fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' has_zero_rows = any(jx == -1 for jx in afix) bix_list = list(bix) -bix_idx = {kx: i for i, kx in enumerate(bix_list)} -preload_c = beta != 0 -need_scale = beta != 0 and beta != 1 +bix_pos = {kx: i for i, kx in enumerate(bix_list)} %> % if n is None: @@ -61,7 +59,7 @@ need_scale = beta != 0 and beta != 1 } ## Batch-load active B columns -%for i, kx in enumerate(bix_list): +% for i, kx in enumerate(bix_list): % if n is None: { .reg .u32 _boff; @@ -73,9 +71,9 @@ need_scale = beta != 0 and beta != 1 % else: ld.weak.global.cg.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; % endif -%endfor +% endfor -% if preload_c: +% if beta != 0: ## Pre-load C so per-row completion is a plain store % for j in range(m): % if afix[j] != -1: @@ -92,7 +90,7 @@ need_scale = beta != 0 and beta != 1 % endif % endif % endfor -% if need_scale: +% if beta != 0 and beta != 1: % for j in range(m): % if afix[j] != -1: mul.${pftype} csub${j}, csub${j}, ${float(beta)}; @@ -102,15 +100,15 @@ need_scale = beta != 0 and beta != 1 % endif ## Main compute -%for kx in bix_list: -% for j, jx in enumerate(A[:, kx]): -% if jx != 0: -% if preload_c: - fma.rn.${pftype} csub${j}, bv${bix_idx[kx]}, ${jx}, csub${j}; +% for kx in bix_list: +% for j, jx in enumerate(A[:, kx]): +% if jx != 0: +% if preload_c: 
+ fma.rn.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}, csub${j}; % elif kx == afix[j]: - mul.${pftype} csub${j}, bv${bix_idx[kx]}, ${jx}; + mul.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}; % else: - fma.rn.${pftype} csub${j}, bv${bix_idx[kx]}, ${jx}, csub${j}; + fma.rn.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}, csub${j}; % endif % endif % if kx == alix[j]: @@ -126,9 +124,9 @@ need_scale = beta != 0 and beta != 1 st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], csub${j}; % endif -% endif -% endfor -%endfor +% endif +% endfor +% endfor % if has_zero_rows: { diff --git a/gimmik/kernels/ptx/cstream-ksplit.mako b/gimmik/kernels/ptx/cstream-ksplit.mako index 06e8a77..1ba2491 100644 --- a/gimmik/kernels/ptx/cstream-ksplit.mako +++ b/gimmik/kernels/ptx/cstream-ksplit.mako @@ -1,9 +1,9 @@ <%inherit file='base'/> <% -pftype = "f32" if dtype == "float" else "f64" -dwidth_i = 4 if dtype == "float" else 8 -fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +pftype = 'f32' if dtype == 'float' else 'f64' +dwidth_i = 4 if dtype == 'float' else 8 +fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' kparts = partition(A, ksplit, by='cols') cchunks = chunk(list(range(m)), csz) cv_per_thread = -(-csz // ksplit) diff --git a/gimmik/kernels/ptx/cstream-w2.mako b/gimmik/kernels/ptx/cstream-w2.mako index 150cf57..c82ebab 100644 --- a/gimmik/kernels/ptx/cstream-w2.mako +++ b/gimmik/kernels/ptx/cstream-w2.mako @@ -1,15 +1,13 @@ <%inherit file='base'/> <% -pftype = "f64" +pftype = 'f64' dwidth_i = 8 -fzero = "0d0000000000000000" +fzero = '0d0000000000000000' bix_list = list(bix) bix_pos = {kx: i for i, kx in enumerate(bix_list)} -K_used = len(bix_list) -row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m)] -assert dtype == 'double', 'cstream-w2 is double-precision only' -assert n is not None, 'cstream-w2 requires compile-time n' +row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] + for j in range(m)] %> .visible 
.entry ${kname}(.param .u64 _b, @@ -17,7 +15,7 @@ assert n is not None, 'cstream-w2 requires compile-time n' { .reg .u32 n, id; .reg .u64 b, c, b_base, c_base; - .reg .f64 bv_a<${K_used}>, bv_b<${K_used}>, dotp_a, dotp_b; + .reg .f64 bv_a<${len(bix_list)}>, bv_b<${len(bix_list)}>, dotp_a, dotp_b; .reg .pred p1; mov.u32 n, ${-(-n // 2)}; @@ -45,22 +43,22 @@ assert n is not None, 'cstream-w2 requires compile-time n' } ## Batch-load B column pairs -%for i, kx in enumerate(bix_list): +% for i, kx in enumerate(bix_list): ld.global.nc.v2.f64 {bv_a${i}, bv_b${i}}, [b_base + ${ldb*kx*dwidth_i}]; -%endfor +% endfor ## Main compute: two parallel dot-product streams per thread -%for j in range(m): -% if row_nz[j]: -% for i_nz, (kx, jx) in enumerate(row_nz[j]): -% if i_nz == 0: +% for j in range(m): +% if row_nz[j]: +% for i_nz, (kx, jx) in enumerate(row_nz[j]): +% if i_nz == 0: mul.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}; mul.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}; -% else: +% else: fma.rn.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}, dotp_a; fma.rn.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}, dotp_b; -% endif -% endfor +% endif +% endfor % if beta == 0: st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; % else: @@ -73,7 +71,7 @@ assert n is not None, 'cstream-w2 requires compile-time n' } % endif -% else: +% else: ## Zero row of A % if beta == 0: { @@ -90,8 +88,8 @@ assert n is not None, 'cstream-w2 requires compile-time n' st.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; } % endif -% endif -%endfor +% endif +% endfor $L_EXIT: ret; diff --git a/gimmik/kernels/ptx/cstream.mako b/gimmik/kernels/ptx/cstream.mako index f26abeb..ec46934 100644 --- a/gimmik/kernels/ptx/cstream.mako +++ b/gimmik/kernels/ptx/cstream.mako @@ -1,13 +1,13 @@ <%inherit file='base'/> <% -pftype = "f32" if dtype == "float" else "f64" -dwidth_i = 4 if dtype == "float" else 8 -fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +pftype = 'f32' if dtype == 'float' else 'f64' 
+dwidth_i = 4 if dtype == 'float' else 8 +fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' bix_list = list(bix) bix_pos = {kx: i for i, kx in enumerate(bix_list)} -K_used = len(bix_list) -row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m)] +row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] + for j in range(m)] %> % if n is None: @@ -27,7 +27,7 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m) % endif .reg .u32 n, id; .reg .u64 b, c, b_base, c_base; - .reg .${pftype} bv<${K_used}>, dotp; + .reg .${pftype} bv<${len(bix_list)}>, dotp; .reg .pred p1; % if n is None: @@ -60,7 +60,7 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m) } ## Batch-load active B columns -%for i, kx in enumerate(bix_list): +% for i, kx in enumerate(bix_list): % if n is None: { .reg .u32 _boff; @@ -72,18 +72,18 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m) % else: ld.global.nc.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; % endif -%endfor +% endfor ## Compute and store each output row -%for j in range(m): -% if row_nz[j]: -% for i_nz, (kx, jx) in enumerate(row_nz[j]): -% if i_nz == 0: +% for j in range(m): +% if row_nz[j]: +% for i_nz, (kx, jx) in enumerate(row_nz[j]): +% if i_nz == 0: mul.${pftype} dotp, bv${bix_pos[kx]}, ${jx}; -% else: +% else: fma.rn.${pftype} dotp, bv${bix_pos[kx]}, ${jx}, dotp; -% endif -% endfor +% endif +% endfor % if beta == 0: % if n is None: { @@ -115,7 +115,7 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m) } % endif -% else: +% else: ## Zero row of A % if beta == 0: { @@ -149,8 +149,8 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m) % endif } % endif -% endif -%endfor +% endif +% endfor $L_EXIT: ret; diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako index 7fc6572..ce8066d 100644 --- 
a/gimmik/kernels/ptx/dense-mma-gAd.mako +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -1,8 +1,7 @@ <%inherit file='base'/> <% -assert dtype == "double" -assert n is not None and ldb is not None and ldc is not None +fzero = '0d0000000000000000' %> .global .align 16 .b64 ${kname}_Ag[${a_elems}] = { @@ -107,8 +106,8 @@ assert n is not None and ldb is not None and ldc is not None % for nt in range(nn): % for mt in range(m_tiles): % if beta == 0: - mov.f64 c0_${nt}_${mt}, 0d0000000000000000; - mov.f64 c1_${nt}_${mt}, 0d0000000000000000; + mov.f64 c0_${nt}_${mt}, ${fzero}; + mov.f64 c1_${nt}_${mt}, ${fzero}; % else: <% pm = f'pm_{mt}' if pm_runtime(mt) else None @@ -120,8 +119,8 @@ assert n is not None and ldb is not None and ldc is not None .reg .u64 caddr; add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; % if needs_zero_init: - mov.f64 c0_${nt}_${mt}, 0d0000000000000000; - mov.f64 c1_${nt}_${mt}, 0d0000000000000000; + mov.f64 c0_${nt}_${mt}, ${fzero}; + mov.f64 c1_${nt}_${mt}, ${fzero}; % endif ${pred_emit(f'ld.global.f64 c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} ${pred_emit(f'ld.global.f64 c1_{nt}_{mt}, [caddr + 8];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} @@ -142,7 +141,7 @@ assert n is not None and ldb is not None and ldc is not None .reg .u64 baddr; add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; % if needs_zero: - mov.f64 b_frag_${nt}, 0d0000000000000000; + mov.f64 b_frag_${nt}, ${fzero}; % endif % if k_tail: .reg .pred pbrow; diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako index 8451e06..ec2f013 100644 --- a/gimmik/kernels/ptx/dense-mma-smem-gA.mako +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -1,14 +1,12 @@ <%inherit file='base'/> <% -assert dtype == "double" -assert n is not None and ldb is not None and ldc is not None # Cooperative-copy params (gA-only) blockx = 32 * warps_per_cta a_pairs = a_elems // 2 a_pairs_tail = a_elems % 2 
copy_v2_iters = (a_pairs + blockx - 1) // blockx -bs = bool(context.get('block_stealing', False)) +bs = bool(block_stealing) %> % if bs: diff --git a/gimmik/kernels/ptx/dense-mma-ws.mako b/gimmik/kernels/ptx/dense-mma-ws.mako index 514837f..e4b576a 100644 --- a/gimmik/kernels/ptx/dense-mma-ws.mako +++ b/gimmik/kernels/ptx/dense-mma-ws.mako @@ -1,158 +1,24 @@ <%inherit file='base'/> <% -assert dtype == "double" -assert n is not None and ldb is not None and ldc is not None mbar_maxwait = '0x989680' +direct_store = (beta == 0) %> -.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { - ${', '.join(a_u64)} -}; -.extern .shared .align 128 .b8 ${kname}_dynm[]; -.const .align 64 .b8 ${kname}_bdesc[128]; -.const .align 64 .b8 ${kname}_cdesc[128]; - -.visible .entry ${kname}(.param .u64 _b, - .param .u64 _c) -.maxntid ${blockx_total}, 1, 1 -{ - .reg .b32 tid, warp, lane, phase, ctaid_x; - .reg .b32 base_brow, base_bcol, base_crow, base_ccol; - .reg .b32 work, block_idx_x, n_start_curr, n_start_next; - .reg .u64 bdesc_addr, cdesc_addr; - .reg .b32 a_smem, b1_smem, b2_smem, c_smem; - .reg .b32 tma_mbar, wid_new_mbar, bready_mbar, cready_mbar, cstored_mbar, steal_mbar; - .reg .b32 wid_used_mbar, wid_smem; - .reg .pred p_compute, p_prod, p_steal; - .reg .pred p_warp_lead; - .reg .pred p_done; - .reg .pred p_tid0; - - mov.u32 tid, %tid.x; - shr.u32 warp, tid, 5; - and.b32 lane, tid, 31; - mov.u32 ctaid_x, %ctaid.x; - - .reg .b32 dynm_base; - mov.u32 dynm_base, ${kname}_dynm; - add.u32 b1_smem, dynm_base, ${b1_off}; - add.u32 b2_smem, dynm_base, ${b2_off}; - add.u32 c_smem, dynm_base, ${c_off}; - add.u32 a_smem, dynm_base, ${a_off}; - add.u32 wid_smem, dynm_base, ${wid_off}; - - add.u32 tma_mbar, dynm_base, ${tma_mbar_off}; - add.u32 bready_mbar, dynm_base, ${bready_mbar_off}; - add.u32 cready_mbar, dynm_base, ${cready_mbar_off}; - add.u32 cstored_mbar, dynm_base, ${cstored_mbar_off}; - add.u32 steal_mbar, dynm_base, ${steal_mbar_off}; - add.u32 wid_new_mbar, dynm_base, 
${wid_new_mbar_off}; - add.u32 wid_used_mbar, dynm_base, ${wid_used_mbar_off}; - - cvta.const.u64 bdesc_addr, ${kname}_bdesc; - cvta.const.u64 cdesc_addr, ${kname}_cdesc; - - setp.eq.u32 p_tid0, tid, 0; - - setp.lt.u32 p_compute, warp, ${n_comp_warps}; - setp.eq.u32 p_prod, warp, ${prod_warp}; - setp.eq.u32 p_steal, warp, ${steal_warp}; - - { - .reg .b32 _elect_lane; - elect.sync _elect_lane|p_warp_lead, 0xffffffff; - } - - // mbarrier init (tid 0 only); pre-arrive csmem_free so compute iter 0 - // can write csmem immediately. - { - .reg .pred p_init; - setp.eq.u32 p_init, tid, 0; - .reg .b64 _state; - @p_init mbarrier.init.shared::cta.b64 [tma_mbar], 32; - @p_init mbarrier.init.shared::cta.b64 [bready_mbar], 1; - @p_init mbarrier.init.shared::cta.b64 [cready_mbar], 1; - @p_init mbarrier.init.shared::cta.b64 [cstored_mbar], 1; - @p_init mbarrier.init.shared::cta.b64 [steal_mbar], 1; - @p_init mbarrier.init.shared::cta.b64 [wid_used_mbar], ${n_comp_warps + 1}; - @p_init mbarrier.init.shared::cta.b64 [wid_new_mbar], 1; - @p_init mbarrier.arrive.shared::cta.b64 _state, [cstored_mbar]; - @p_init fence.proxy.async.shared::cta; - } - bar.sync 0; - - // Cooperative copy A: .global -> a_smem (ld.global.nc.v2.f64) - { - .reg .u64 a_glb_base; - .reg .b32 pidx; - .reg .f64 av0, av1; - mov.u64 a_glb_base, ${kname}_Ag; - cvta.to.global.u64 a_glb_base, a_glb_base; -% for ci in range(copy_v2_iters): -<% - base_pair = ci * blockx_total - is_last = ci == copy_v2_iters - 1 - pairs_this = min(blockx_total, a_pairs - base_pair) - needs_guard = is_last and pairs_this < blockx_total -%> - { - .reg .u64 ofs64, gaddr; - .reg .b32 saddr; - add.u32 pidx, tid, ${base_pair}; -% if needs_guard: - .reg .pred p_load; - setp.lt.u32 p_load, pidx, ${a_pairs}; -% endif - mul.wide.u32 ofs64, pidx, 16; - add.u64 gaddr, a_glb_base, ofs64; - cvt.u32.u64 saddr, ofs64; - add.u32 saddr, saddr, a_smem; -% if needs_guard: - @p_load ld.global.nc.v2.f64 {av0, av1}, [gaddr]; - @p_load st.shared.v2.f64 [saddr], 
{av0, av1}; -% else: - ld.global.nc.v2.f64 {av0, av1}, [gaddr]; - st.shared.v2.f64 [saddr], {av0, av1}; -% endif - } -% endfor -% if a_pairs_tail: - { - .reg .pred p_tail; - .reg .u64 gaddr; - .reg .b32 saddr; - .reg .f64 v; - setp.eq.u32 p_tail, tid, 0; - add.u64 gaddr, a_glb_base, ${(a_elems - 1) * 8}; - mov.u32 saddr, ${(a_elems - 1) * 8}; - add.u32 saddr, saddr, a_smem; - @p_tail ld.global.nc.f64 v, [gaddr]; - @p_tail st.shared.f64 [saddr], v; - } -% endif - } - bar.sync 0; - - // Compute-warp lane geometry (cheap; all warps execute uniformly) - { - .reg .b32 t, w_n_base; - and.b32 base_brow, lane, 3; - shr.u32 base_crow, lane, 2; - mul.lo.u32 w_n_base, warp, ${n_per_warp}; - add.u32 base_bcol, base_crow, w_n_base; - shl.b32 t, base_brow, 1; - add.u32 base_ccol, t, w_n_base; - } - - // Producer warp: initial B load for ctaid_x's work +<%def name="producer_init_setup()"> + // Producer warp: initial A bulk-copy + B load for ctaid_x's work @!p_prod bra.uni $L_AFTER_INIT_B; { .reg .b32 n_start0; + .reg .u64 a_glb; mul.lo.u32 n_start0, ctaid_x, ${n_per_cta}; + mov.u64 a_glb, ${kname}_Ag; + cvta.to.global.u64 a_glb, a_glb; + @p_warp_lead cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes + [a_smem], [a_glb], ${a_elems * 8}, [tma_mbar]; @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes [b1_smem], [bdesc_addr, {n_start0, 0}], [tma_mbar]; @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 - [tma_mbar], ${b_tile_bytes}; + [tma_mbar], ${b_tile_bytes + a_elems * 8}; bar.warp.sync 0xffffffff; .reg .b64 state; .reg .pred p1; @@ -164,17 +30,9 @@ $L_TMA_INIT_W: @p_warp_lead mbarrier.arrive.shared::cta.b64 _state2, [bready_mbar]; } $L_AFTER_INIT_B: + - mov.u32 block_idx_x, ctaid_x; - mov.u32 work, 1; - mov.u32 phase, 0; - -$L_LOOP: - setp.eq.u32 p_done, work, 0; - @p_done bra.uni $L_EXIT; - - mul.lo.u32 n_start_curr, block_idx_x, ${n_per_cta}; - +<%def name="compute_warp_body()"> // --- Compute Warps @!p_compute 
bra.uni $L_AFTER_COMPUTE; @@ -209,6 +67,13 @@ $L_WAIT_BRDY: } % endfor +% if direct_store: + // direct_store: skip shared-staging entirely; compute warps store + // MMA outputs straight to global C with N-tail predication. + .reg .u64 c_glob_addr; + ld.param.u64 c_glob_addr, [_c]; + cvta.to.global.u64 c_glob_addr, c_glob_addr; +% else: .reg .b32 c_thr_smem; { .reg .b32 t1, ccol_b; @@ -217,6 +82,7 @@ $L_WAIT_BRDY: add.u32 c_thr_smem, c_smem, t1; add.u32 c_thr_smem, c_thr_smem, ccol_b; } +% endif // Zero accumulators % for mt in range(m_tiles): @@ -267,6 +133,45 @@ $L_WAIT_BRDY: % endfor % endfor +% if direct_store: + .reg .u64 c_thr_glob_base; + { + .reg .u32 thr_col_off, thr_addr_off_lo; + add.u32 thr_col_off, base_ccol, n_start_curr; + mad.lo.u32 thr_addr_off_lo, base_crow, ${ldc}, thr_col_off; + .reg .u64 thr_byte_off; + mul.wide.u32 thr_byte_off, thr_addr_off_lo, 8; + add.u64 c_thr_glob_base, c_glob_addr, thr_byte_off; + } +% for mt in range(m_tiles): +<% + row_tail = (m_pad > m) and ((mt + 1) * 8 > m) +%> +% if row_tail: + .reg .pred p_row_${mt}; + { + .reg .b32 crow; + add.u32 crow, base_crow, ${8 * mt}; + setp.lt.u32 p_row_${mt}, crow, ${m}; + } +% endif +% for nt in range(nn): + { + .reg .pred p_st; + .reg .u32 g_ccol; + add.u32 g_ccol, base_ccol, ${8 * nt}; + add.u32 g_ccol, g_ccol, n_start_curr; + setp.lt.u32 p_st, g_ccol, ${n}; +% if row_tail: + and.pred p_st, p_st, p_row_${mt}; +% endif + .reg .u64 c_addr; + add.u64 c_addr, c_thr_glob_base, ${(mt * 8 * ldc + nt * 8) * 8}; + @p_st st.global.v2.f64 [c_addr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor +% endfor +% else: // Wait until producer's prev-iter TMA-store of C has drained. 
{ .reg .pred p1; @@ -286,13 +191,16 @@ $L_WAIT_CSTORE: } % endfor % endfor +% endif +% if not direct_store: bar.sync 1, ${comp_threads}; fence.proxy.async.shared::cta; { .reg .b64 _state; @p_tid0 mbarrier.arrive.shared::cta.b64 _state, [cready_mbar]; } +% endif // Wait for new work and unpack { @@ -312,7 +220,9 @@ $L_WAIT_WNEW_C: } } $L_AFTER_COMPUTE: + +<%def name="data_warp_body()"> // --- Data Movement Warp @!p_prod bra.uni $L_AFTER_DATA; { @@ -350,24 +260,22 @@ $L_WAIT_WNEW_D: } bar.warp.sync 0xffffffff; - // TMA store/reduce+store of a C +% if not direct_store: + // TMA reduce+store of C (beta=1 only; beta=0 uses direct global + // stores from compute warps, so the producer does no C work). { .reg .pred p1; .reg .b64 _c_state; $L_WAIT_CRDY: mbarrier.try_wait.parity.shared::cta.b64 p1, [cready_mbar], phase, ${mbar_maxwait}; @!p1 bra.uni $L_WAIT_CRDY; -% if beta == 0: - @p_warp_lead cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group - [cdesc_addr, {n_c_store, 0}], [c_smem]; -% else: @p_warp_lead cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.tile.bulk_group [cdesc_addr, {n_c_store, 0}], [c_smem]; -% endif @p_warp_lead cp.async.bulk.commit_group; @p_warp_lead cp.async.bulk.wait_group 0; @p_warp_lead mbarrier.arrive.shared::cta.b64 _c_state, [cstored_mbar]; } +% endif // Wait for next B to be ready, then signal B and C ready { @@ -382,7 +290,9 @@ $L_WAIT_TMA: } } $L_AFTER_DATA: + +<%def name="ctrl_warp_body()"> // --- Controller Warp @!p_steal bra.uni $L_AFTER_CTRL; { @@ -413,6 +323,112 @@ $L_WAIT_WUSED: @!p2 bra.uni $L_WAIT_WUSED; } $L_AFTER_CTRL: + + +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { + ${', '.join(a_u64)} +}; +.extern .shared .align 128 .b8 ${kname}_dynm[]; +.const .align 64 .b8 ${kname}_bdesc[128]; +.const .align 64 .b8 ${kname}_cdesc[128]; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +.maxntid ${blockx_total}, 1, 1 +{ + .reg .b32 tid, warp, lane, phase, ctaid_x; + .reg .b32 base_brow, base_bcol, base_crow, 
base_ccol; + .reg .b32 work, block_idx_x, n_start_curr, n_start_next; + .reg .u64 bdesc_addr, cdesc_addr; + .reg .b32 a_smem, b1_smem, b2_smem, c_smem; + .reg .b32 tma_mbar, wid_new_mbar, bready_mbar, cready_mbar, cstored_mbar, steal_mbar; + .reg .b32 wid_used_mbar, wid_smem; + .reg .pred p_compute, p_prod, p_steal; + .reg .pred p_warp_lead; + .reg .pred p_done; + .reg .pred p_tid0; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + mov.u32 ctaid_x, %ctaid.x; + + .reg .b32 dynm_base; + mov.u32 dynm_base, ${kname}_dynm; + add.u32 b1_smem, dynm_base, ${b1_off}; + add.u32 b2_smem, dynm_base, ${b2_off}; + add.u32 c_smem, dynm_base, ${c_off}; + add.u32 a_smem, dynm_base, ${a_off}; + add.u32 wid_smem, dynm_base, ${wid_off}; + + add.u32 tma_mbar, dynm_base, ${tma_mbar_off}; + add.u32 bready_mbar, dynm_base, ${bready_mbar_off}; + add.u32 cready_mbar, dynm_base, ${cready_mbar_off}; + add.u32 cstored_mbar, dynm_base, ${cstored_mbar_off}; + add.u32 steal_mbar, dynm_base, ${steal_mbar_off}; + add.u32 wid_new_mbar, dynm_base, ${wid_new_mbar_off}; + add.u32 wid_used_mbar, dynm_base, ${wid_used_mbar_off}; + + cvta.const.u64 bdesc_addr, ${kname}_bdesc; + cvta.const.u64 cdesc_addr, ${kname}_cdesc; + + setp.eq.u32 p_tid0, tid, 0; + + setp.lt.u32 p_compute, warp, ${n_comp_warps}; + setp.eq.u32 p_prod, warp, ${prod_warp}; + setp.eq.u32 p_steal, warp, ${steal_warp}; + + { + .reg .b32 _elect_lane; + elect.sync _elect_lane|p_warp_lead, 0xffffffff; + } + + // mbarrier init (tid 0 only); pre-arrive csmem_free so compute iter 0 + // can write csmem immediately. 
+ { + .reg .pred p_init; + setp.eq.u32 p_init, tid, 0; + .reg .b64 _state; + @p_init mbarrier.init.shared::cta.b64 [tma_mbar], 32; + @p_init mbarrier.init.shared::cta.b64 [bready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cstored_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [steal_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [wid_used_mbar], ${n_comp_warps + 1}; + @p_init mbarrier.init.shared::cta.b64 [wid_new_mbar], 1; + @p_init mbarrier.arrive.shared::cta.b64 _state, [cstored_mbar]; + @p_init fence.proxy.async.shared::cta; + } + bar.sync 0; + + // Compute-warp lane geometry (cheap; all warps execute uniformly) + { + .reg .b32 t, w_n_base; + and.b32 base_brow, lane, 3; + shr.u32 base_crow, lane, 2; + mul.lo.u32 w_n_base, warp, ${n_per_warp}; + add.u32 base_bcol, base_crow, w_n_base; + shl.b32 t, base_brow, 1; + add.u32 base_ccol, t, w_n_base; + } + + ${producer_init_setup()} + + mov.u32 block_idx_x, ctaid_x; + mov.u32 work, 1; + mov.u32 phase, 0; + +$L_LOOP: + setp.eq.u32 p_done, work, 0; + @p_done bra.uni $L_EXIT; + + mul.lo.u32 n_start_curr, block_idx_x, ${n_per_cta}; + + ${compute_warp_body()} + + ${data_warp_body()} + + ${ctrl_warp_body()} xor.b32 phase, phase, 1; bra.uni $L_LOOP; diff --git a/gimmik/ptx.py b/gimmik/ptx.py index bf32a62..ad429e8 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -19,26 +19,21 @@ def is_sparse_suitable(arr): density = nnz / arr.size return (nuq <= 28) or (density <= 0.15) + # Shape/arch gate for dense DMMA; n/ldb/ldc are validated at generate time @staticmethod def is_dense_suitable(arr, dtype, cc): - """True if A's shape and the target arch support the dense DMMA - template family. 
Does NOT check runtime args (n, ldb, ldc); those - are validated when the generator runs.""" return (np.dtype(dtype) == np.float64 and cc is not None and cc >= (9, 0) and arr.shape[0] <= 128 and arr.shape[1] <= 128) @classmethod def is_suitable(cls, arr, dtype, cc): - """True if either sparse or dense templates are applicable.""" return (cls.is_sparse_suitable(arr) or cls.is_dense_suitable(arr, dtype, cc)) - def _kernel_generators(self, dtype, dsize, *, compute_capability=None, - trim_a=False): + def _kernel_generators(self, dtype, dsize, *, compute_capability=None): base_args = {'cc': compute_capability, - 'pred_emit': self._pred_emit, - 'trim_a': bool(trim_a) and dtype == 'double'} + 'pred_emit': self._pred_emit} yield from self._sparse_kernel_generators(dtype, dsize, base_args) yield from self._dense_kernel_generators(dtype, dsize, base_args) @@ -48,10 +43,10 @@ def _sparse_kernel_generators(self, dtype, dsize, base_args): return # B loading, C streaming kernel - yield ('cstream', base_args | {}, {'desc': 'cstream'}) + yield ('cstream', base_args, {'desc': 'cstream'}) # B streaming, C accumulation kernel - yield ('bstream', base_args | {}, {'desc': 'bstream'}) + yield ('bstream', base_args, {'desc': 'bstream'}) # Four-way m-split B streaming, C accumulation kernel ms, bsz, blkx = 4, 24, 32 @@ -102,15 +97,22 @@ def _dense_kernel_generators(self, dtype, dsize, base_args): and self.n is not None): return - # Dense DMMA m8n8k4; block stealing default on sm_100+ for gA + # Some kernels can optional steal blocks bs_default = cc >= (10, 0) - dense_configs = [ - ('dense-mma-smem-gA', 1, 8), - ('dense-mma-smem-gA', 2, 4), - ('dense-mma-smem-gA', 4, 4), - ('dense-mma-gAd', 2, 2), - ('dense-mma-gAd', 4, 2), - ] + + if cc >= (10, 0): + # Warp specialised is uniformly better on sm_100+, so no need to JIT + # other versions + dense_configs = [('dense-mma-smem-gA', 4, 4)] + else: + dense_configs = [ + ('dense-mma-smem-gA', 1, 8), + ('dense-mma-smem-gA', 2, 4), + 
('dense-mma-smem-gA', 4, 4), + ('dense-mma-gAd', 2, 2), + ('dense-mma-gAd', 4, 2), + ] + for tpl, nn, w in dense_configs: blkx = 32 * w n_per_cta = 8 * nn * w @@ -127,13 +129,14 @@ def _dense_kernel_generators(self, dtype, dsize, base_args): } yield (tpl, args, meta) - # Warp-specialised dense DMMA with TMA B-load + TMA C-store. + # Warp-specialised dense DMMA if cc >= (10, 0): yield from self._dense_ws_kernel_generators(dtype, dsize, base_args) def _dense_ws_kernel_generators(self, dtype, dsize, base_args): m_pad = -(-self.m // 8) * 8 k_pad = -(-self.k // 4) * 4 + # (nn, w_compute) -- block has w_compute + 2 warps (producer, stealer) ws_configs = [(1, 4), (2, 4), (4, 4)] for nn, w in ws_configs: @@ -146,14 +149,14 @@ def _dense_ws_kernel_generators(self, dtype, dsize, base_args): n_comp_warps=w, n_per_cta=n_per_cta, m_pad=m_pad, k_pad=k_pad, a_elems=setup['a_elems'] ) - # sm_100 supports up to 228 KiB shared per CTA with the - # set_shared_size opt-in. Reserve some headroom for L1 carveout. 
+ if ws_layout['dynm_total_bytes'] > 200 * 1024: continue + args = (base_args | {'warps_per_cta': w, 'nn': nn} | setup | ws_layout) - yield ('dense-mma-ws', args, { + meta = { 'block': (blkx, 1, 1), 'grid': (-(-self.n // n_per_cta), 1, 1), 'desc': f'dense-mma-ws/nn{nn}-w{w}', @@ -162,17 +165,13 @@ def _dense_ws_kernel_generators(self, dtype, dsize, base_args): 'ws_k_pad': k_pad, 'ws_m_pad': m_pad, 'dynamic_shared': ws_layout['dynm_total_bytes'], - }) + } + yield ('dense-mma-ws', args, meta) @staticmethod def _dense_ws_layout(*, n_comp_warps, n_per_cta, m_pad, k_pad, a_elems): - """Render-time constants for the dense-mma-ws template: warp roles, - cooperative-copy iteration counts, smem-tile sizes, mbar timeout, - and dynamic-shared byte offsets for each buffer.""" n_total_warps = n_comp_warps + 2 blockx_total = 32 * n_total_warps - a_pairs = a_elems // 2 - a_pairs_tail = a_elems % 2 b_tile_bytes = k_pad * n_per_cta * 8 c_tile_bytes = m_pad * n_per_cta * 8 @@ -200,9 +199,6 @@ def _dense_ws_layout(*, n_comp_warps, n_per_cta, m_pad, k_pad, a_elems): 'prod_warp': n_comp_warps, 'steal_warp': n_comp_warps + 1, 'comp_threads': 32 * n_comp_warps, - 'a_pairs': a_pairs, - 'a_pairs_tail': a_pairs_tail, - 'copy_v2_iters': -(-a_pairs // blockx_total), 'm_pad': m_pad, 'k_pad': k_pad, 'b_tile_doubles': k_pad * n_per_cta, From e2a818bb9234d5326deda37035635dd6c0ae3129 Mon Sep 17 00:00:00 2001 From: Will Trojak Date: Fri, 15 May 2026 13:19:38 +0100 Subject: [PATCH 06/21] Whitespace --- gimmik/kernels/ptx/base.mako | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gimmik/kernels/ptx/base.mako b/gimmik/kernels/ptx/base.mako index e380f1b..b64ddc1 100644 --- a/gimmik/kernels/ptx/base.mako +++ b/gimmik/kernels/ptx/base.mako @@ -1,4 +1,4 @@ .version 8.7 .target sm_${cc[0]}${cc[1]}${"a" if cc[0] >= 9 else ""} .address_size 64 -${next.body()} \ No newline at end of file +${next.body()} From 7d7299a48bc486883f9b3e933ce7321d9e0e2dc2 Mon Sep 17 00:00:00 2001 From: WillTrojak 
Date: Tue, 19 May 2026 11:31:30 -0700 Subject: [PATCH 07/21] Cleanups, formatting and addressing comments --- gimmik/kernels/ptx/base.mako | 4 +- gimmik/kernels/ptx/bstream-msplit.mako | 306 ++++++++++------------ gimmik/kernels/ptx/bstream.mako | 157 +++++------ gimmik/kernels/ptx/cstream-ksplit.mako | 157 ++++++----- gimmik/kernels/ptx/cstream-w2.mako | 80 +++--- gimmik/kernels/ptx/cstream.mako | 166 ++++++------ gimmik/kernels/ptx/dense-mma-gAd.mako | 86 +++--- gimmik/kernels/ptx/dense-mma-smem-gA.mako | 113 ++++---- gimmik/kernels/ptx/dense-mma-ws.mako | 114 ++++---- gimmik/ptx.py | 217 +++++++-------- 10 files changed, 667 insertions(+), 733 deletions(-) diff --git a/gimmik/kernels/ptx/base.mako b/gimmik/kernels/ptx/base.mako index b64ddc1..71eb414 100644 --- a/gimmik/kernels/ptx/base.mako +++ b/gimmik/kernels/ptx/base.mako @@ -1,4 +1,4 @@ -.version 8.7 -.target sm_${cc[0]}${cc[1]}${"a" if cc[0] >= 9 else ""} +.version 8.6 +.target sm_${cc[0]}${cc[1]}${'a' if cc[0] >= 9 else ''} .address_size 64 ${next.body()} diff --git a/gimmik/kernels/ptx/bstream-msplit.mako b/gimmik/kernels/ptx/bstream-msplit.mako index 0af5091..530b19f 100644 --- a/gimmik/kernels/ptx/bstream-msplit.mako +++ b/gimmik/kernels/ptx/bstream-msplit.mako @@ -1,12 +1,7 @@ <%inherit file='base'/> <% -pftype = 'f32' if dtype == 'float' else 'f64' -dwidth_i = 4 if dtype == 'float' else 8 -fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' -has_zero_rows = any(jx == -1 for jx in afix) mx = partition(A, into=msplit, by='rows') -bix_list = list(bix) bchunks = chunk(bix_list, bsz) m_per_group = max(len(mcx) for mcx in mx) bsub_bytes = 2 * bsz * blockx * dwidth_i @@ -48,11 +43,11 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) ld.param.u64 c, [_c]; { - .reg .u32 _ctaid_x; - mov.u32 _ctaid_x, %ctaid.x; - mov.u32 tid_x, %tid.x; - mov.u32 tid_y, %tid.y; - mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; +
mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; } setp.ge.u32 p1, id, n; @@ -62,23 +57,23 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) cvta.to.global.u64 c, c; { - .reg .u64 _id64; - cvt.u64.u32 _id64, id; - mad.lo.u64 b_base, _id64, ${dwidth_i}, b; - mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; } { - .reg .u64 _tx_off; - mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; - mov.u64 bsub_thread, _bsub; - add.u64 bsub_thread, bsub_thread, _tx_off; + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; + mov.u64 bsub_thread, _bsub; + add.u64 bsub_thread, bsub_thread, _tx_off; } % if use_cpasync: { - .reg .u64 _sm64; - cvta.to.shared.u64 _sm64, bsub_thread; - cvt.u32.u64 bsub_sm_thread, _sm64; + .reg .u64 _sm64; + cvta.to.shared.u64 _sm64, bsub_thread; + cvt.u32.u64 bsub_sm_thread, _sm64; } % endif @@ -87,191 +82,176 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) setp.ne.u32 p_skip, tid_y, ${cid}; @p_skip bra $L_END_CID_${cid}; -% if use_cpasync: +% if use_cpasync: ## Async fill of chunk 0 -% for idx, kx in enumerate(bchunks[0]): -% if idx % msplit == cid: -% if n is None: +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[0]) if i % msplit == cid]: +% if n is None: { - .reg .u32 _boff; - .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; - cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [_bptr], ${dwidth_i}; + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [_bptr], ${dwidth_i}; } -% else: +% else: cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [b_base + ${ldb*kx*dwidth_i}], 
${dwidth_i}; -% endif -% endif +% endif % endfor cp.async.commit_group; cp.async.wait_all; bar.sync 0; -% else: +% else: ## Sync fill of chunk 0 -% for idx, kx in enumerate(bchunks[0]): -% if idx % msplit == cid: -% if n is None: +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[0]) if i % msplit == cid]: { - .reg .u32 _boff; - .reg .u64 _bptr; - .reg .${pftype} _bv; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; - ld.global.cg.${pftype} _bv, [_bptr]; - st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv; + .reg .${pftype} _bv; +% if n is None: + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} _bv, [_bptr]; +% else: + ld.weak.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; +% endif + st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv; } -% else: - { - .reg .${pftype} _bv; - ld.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; - st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv; - } -% endif -% endif % endfor bar.sync 0; -% endif +% endif ## Main loop over B-chunks (double-buffered) -% for bb in range(len(bchunks)): +% for bb in range(len(bchunks)): <% buf_cur = bb % 2 buf_next = (bb + 1) % 2 - is_last = (bb == len(bchunks) - 1) %> -% if not is_last: -% for idx, kx in enumerate(bchunks[bb + 1]): -% if idx % msplit == cid: -% if use_cpasync: -% if n is None: +% if not loop.last: +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[bb + 1]) if i % msplit == cid]: +% if use_cpasync: +% if n is None: { - .reg .u32 _boff; - .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; - cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [_bptr], ${dwidth_i}; + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + cp.async.ca.shared::cta.global [bsub_sm_thread 
+ ${bsub_off(buf_next, idx)}], [_bptr], ${dwidth_i}; } -% else: +% else: cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i}; -% endif -% else: -% if n is None: - { - .reg .u32 _boff; - .reg .u64 _bptr; - .reg .${pftype} _bv; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; - ld.global.cg.${pftype} _bv, [_bptr]; - st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv; - } -% else: +% endif +% else: { - .reg .${pftype} _bv; - ld.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; - st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv; + .reg .${pftype} _bv; +% if n is None: + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} _bv, [_bptr]; +% else: + ld.weak.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; +% endif + st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv; } -% endif -% endif -% endif -% endfor -% if use_cpasync: - cp.async.commit_group; -% endif % endif +% endfor +% if use_cpasync: + cp.async.commit_group; +% endif +% endif -% for idx, kx in enumerate(bchunks[bb]): +% for idx, kx in enumerate(bchunks[bb]): ld.shared.${pftype} bv, [bsub_thread + ${bsub_off(buf_cur, idx)}]; -% for j, row_j in enumerate(mcx): -<% jx = A[row_j, kx] %> -% if jx != 0 and kx == afix[row_j]: +% for j, row_j in enumerate(mcx): +<% jx = A[row_j, kx] %> +% if jx != 0 and kx == afix[row_j]: mul.${pftype} csub${j}, bv, ${jx}; -% elif jx != 0: +% elif jx != 0: fma.rn.${pftype} csub${j}, bv, ${jx}, csub${j}; -% endif -% if kx == alix[row_j]: -% if beta == 0: -% if n is None: +% endif +% if kx == alix[row_j]: +% if beta_zero: +% if n is None: { - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${row_j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - st.weak.global.cg.${pftype} [_cptr], csub${j}; + .reg .u32 _coff; + 
.reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], csub${j}; } -% else: +% else: st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], csub${j}; -% endif -% else: +% endif +% else: { - .reg .${pftype} _ctmp; -% if n is None: - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${row_j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - ld.global.${pftype} _ctmp, [_cptr]; - fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; - st.global.${pftype} [_cptr], _ctmp; -% else: - ld.global.${pftype} _ctmp, [c_base + ${ldc*row_j*dwidth_i}]; - fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; - st.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _ctmp; -% endif + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; + st.weak.global.${pftype} [_cptr], _ctmp; +% else: + ld.weak.global.cg.${pftype} _ctmp, [c_base + ${ldc*row_j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; + st.weak.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _ctmp; +% endif } -% endif -% endif -% endfor -% endfor -% if use_cpasync: -% if not is_last: - cp.async.wait_all; +% endif % endif -% endif - bar.sync 0; +% endfor % endfor +% if use_cpasync: +% if not loop.last: + cp.async.wait_all; +% endif +% endif + bar.sync 0; +% endfor ## End of Main loop over B-chunks ## Handle zero rows in this cid's group -% if has_zero_rows: -% for row_j in mcx: -% if afix[row_j] == -1: -% if beta == 0: +% if has_zero_rows: +% for row_j in mcx: +% if afix[row_j] == -1: +% if beta_zero: { - .reg .${pftype} _tmp; - mov.${pftype} _tmp, ${fzero}; -% if n is None: - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${row_j}; - mad.wide.u32 _cptr, ${dwidth_i}, 
_coff, c_base; - st.weak.global.cg.${pftype} [_cptr], _tmp; -% else: - st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; -% endif + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; +% endif } -% elif beta != 1: +% elif beta != 1: { - .reg .${pftype} _tmp; -% if n is None: - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${row_j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - ld.global.${pftype} _tmp, [_cptr]; - mul.${pftype} _tmp, _tmp, ${float(beta)}; - st.global.${pftype} [_cptr], _tmp; -% else: - ld.global.${pftype} _tmp, [c_base + ${ldc*row_j*dwidth_i}]; - mul.${pftype} _tmp, _tmp, ${float(beta)}; - st.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; -% endif + .reg .${pftype} _tmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [_cptr], _tmp; +% else: + ld.weak.global.cg.${pftype} _tmp, [c_base + ${ldc*row_j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; +% endif } -% endif -% endif -% endfor -% endif +% endif +% endif +% endfor +% endif $L_END_CID_${cid}: % endfor diff --git a/gimmik/kernels/ptx/bstream.mako b/gimmik/kernels/ptx/bstream.mako index f58e9b3..24b0acb 100644 --- a/gimmik/kernels/ptx/bstream.mako +++ b/gimmik/kernels/ptx/bstream.mako @@ -1,14 +1,5 @@ <%inherit file='base'/> -<% -pftype = 'f32' if dtype == 'float' else 'f64' -dwidth_i = 4 if dtype == 'float' else 8 -fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' -has_zero_rows = any(jx == 
-1 for jx in afix) -bix_list = list(bix) -bix_pos = {kx: i for i, kx in enumerate(bix_list)} -%> - % if n is None: .visible .entry ${kname}(.param .u32 _n, .param .u64 _b, @@ -38,11 +29,11 @@ bix_pos = {kx: i for i, kx in enumerate(bix_list)} ld.param.u64 c, [_c]; { - .reg .u32 _grd<3>; - mov.u32 _grd0, %ntid.x; - mov.u32 _grd1, %ctaid.x; - mov.u32 _grd2, %tid.x; - mad.lo.u32 id, _grd0, _grd1, _grd2; + .reg .u32 _grd<3>; + mov.u32 _grd0, %ntid.x; + mov.u32 _grd1, %ctaid.x; + mov.u32 _grd2, %tid.x; + mad.lo.u32 id, _grd0, _grd1, _grd2; } setp.ge.u32 p1, id, n; @@ -52,117 +43,113 @@ bix_pos = {kx: i for i, kx in enumerate(bix_list)} cvta.to.global.u64 c, c; { - .reg .u64 _id64; - cvt.u64.u32 _id64, id; - mad.lo.u64 b_base, _id64, ${dwidth_i}, b; - mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; } ## Batch-load active B columns % for i, kx in enumerate(bix_list): -% if n is None: +% if n is None: { - .reg .u32 _boff; - .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; - ld.weak.global.cg.${pftype} bv${i}, [_bptr]; + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} bv${i}, [_bptr]; } -% else: +% else: ld.weak.global.cg.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; -% endif +% endif % endfor -% if beta != 0: +% if not beta_zero: ## Pre-load C so per-row completion is a plain store % for j in range(m): % if afix[j] != -1: -% if n is None: +% if n is None: { - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - ld.weak.global.cg.${pftype} csub${j}, [_cptr]; + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} csub${j}, [_cptr]; } -% else: 
+% else: ld.weak.global.cg.${pftype} csub${j}, [c_base + ${ldc*j*dwidth_i}]; -% endif +% endif % endif % endfor -% if beta != 0 and beta != 1: % for j in range(m): % if afix[j] != -1: mul.${pftype} csub${j}, csub${j}, ${float(beta)}; % endif % endfor % endif -% endif ## Main compute % for kx in bix_list: -% for j, jx in enumerate(A[:, kx]): -% if jx != 0: -% if preload_c: - fma.rn.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}, csub${j}; -% elif kx == afix[j]: +% for j, jx in enumerate(A[:, kx]): +% if jx != 0: +% if beta_zero and kx == afix[j]: mul.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}; -% else: +% else: fma.rn.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}, csub${j}; -% endif % endif -% if kx == alix[j]: -% if n is None: +% endif +% if kx == alix[j]: +% if n is None: { - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - st.weak.global.cg.${pftype} [_cptr], csub${j}; + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], csub${j}; } -% else: +% else: st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], csub${j}; -% endif +% endif -% endif -% endfor +% endif +% endfor % endfor % if has_zero_rows: { - .reg .${pftype} _tmp; - mov.${pftype} _tmp, ${fzero}; + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; % for j, jx in enumerate(afix): -% if jx == -1 and beta == 0: -% if n is None: - { - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - st.weak.global.cg.${pftype} [_cptr], _tmp; - } -% else: - st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; -% endif +% if jx == -1 and beta_zero: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; + } +% else: + st.weak.global.cg.${pftype} 
[c_base + ${ldc*j*dwidth_i}], _tmp; +% endif -% elif jx == -1: -% if n is None: - { - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - ld.global.${pftype} _tmp, [_cptr]; - mul.${pftype} _tmp, _tmp, ${float(beta)}; - st.global.${pftype} [_cptr], _tmp; - } -% else: - ld.global.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; - mul.${pftype} _tmp, _tmp, ${float(beta)}; - st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; -% endif +% elif jx == -1: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.cg.${pftype} [_cptr], _tmp; + } +% else: + ld.weak.global.cg.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; % endif +% endif % endfor } % endif diff --git a/gimmik/kernels/ptx/cstream-ksplit.mako b/gimmik/kernels/ptx/cstream-ksplit.mako index 1ba2491..5d704de 100644 --- a/gimmik/kernels/ptx/cstream-ksplit.mako +++ b/gimmik/kernels/ptx/cstream-ksplit.mako @@ -1,9 +1,6 @@ <%inherit file='base'/> <% -pftype = 'f32' if dtype == 'float' else 'f64' -dwidth_i = 4 if dtype == 'float' else 8 -fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' kparts = partition(A, ksplit, by='cols') cchunks = chunk(list(range(m)), csz) cv_per_thread = -(-csz // ksplit) @@ -41,11 +38,11 @@ csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i ld.param.u64 c, [_c]; { - .reg .u32 _ctaid_x; - mov.u32 _ctaid_x, %ctaid.x; - mov.u32 tid_x, %tid.x; - mov.u32 tid_y, %tid.y; - mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; } setp.ge.u32 p1, id, n; @@ -55,17 +52,17 @@ csub_bytes = (ksplit 
- 1) * csz * blockx * dwidth_i cvta.to.global.u64 c, c; { - .reg .u64 _id64; - cvt.u64.u32 _id64, id; - mad.lo.u64 b_base, _id64, ${dwidth_i}, b; - mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; } { - .reg .u64 _tx_off; - mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; - mov.u64 csub_thread, _csub; - add.u64 csub_thread, csub_thread, _tx_off; + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; + mov.u64 csub_thread, _csub; + add.u64 csub_thread, csub_thread, _tx_off; } % for bid, kbx in enumerate(kparts): @@ -78,98 +75,98 @@ csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i kbx_idx = {kx: i for i, kx in enumerate(kbx)} %> -% for cchunk_i, cchunk in enumerate(cchunks): +% for cchunk_i, cchunk in enumerate(cchunks): ## Chunk ${cchunk_i}: partial dot-product -% for row_idx, j in enumerate(cchunk): +% for row_idx, j in enumerate(cchunk): <% nz = [(kbx_idx[kx], kx, A[j, kx]) for kx in kbx if A[j, kx] != 0] owner_bid = row_idx % ksplit %> -% for (kxi, kx, jx) in nz: -% if kx not in loaded: -% if n is None: +% for (kxi, kx, jx) in nz: +% if kx not in loaded: +% if n is None: { - .reg .u32 _boff; - .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; - ld.global.nc.${pftype} bv${kxi}, [_bptr]; + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} bv${kxi}, [_bptr]; } -% else: - ld.global.nc.${pftype} bv${kxi}, [b_base + ${ldb*kx*dwidth_i}]; -% endif +% else: + ld.weak.global.cg.${pftype} bv${kxi}, [b_base + ${ldb*kx*dwidth_i}]; +% endif <% loaded.add(kx) %> -% endif -% endfor -% if nz: -% for i, (kxi, kx, jx) in enumerate(nz): -% if i == 0: +% endif +% endfor +% if nz: +% for kxi, kx, jx in nz: +% if loop.first: mul.${pftype} dotp, bv${kxi}, ${jx}; -% else: +% else: fma.rn.${pftype} dotp, bv${kxi}, 
${jx}, dotp; -% endif -% endfor -% else: +% endif +% endfor +% else: mov.${pftype} dotp, ${fzero}; -% endif -% if owner_bid == bid: +% endif +% if owner_bid == bid: mov.${pftype} cv${row_idx // ksplit}, dotp; -% else: +% else: <% csub_idx = bid - (1 if bid > owner_bid else 0) %> st.shared.${pftype} [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}], dotp; -% endif -% endfor +% endif +% endfor bar.sync 0; ## Combine phase (owned rows only) -% for row_idx, j in enumerate(cchunk): -% if row_idx % ksplit == bid: +% for row_idx, j in enumerate(cchunk): +% if row_idx % ksplit == bid: mov.${pftype} dotp, cv${row_idx // ksplit}; -% for other_bid in range(ksplit): -% if other_bid != bid: +% for other_bid in range(ksplit): +% if other_bid != bid: <% csub_idx = other_bid - (1 if other_bid > (row_idx % ksplit) else 0) %> { - .reg .${pftype} _tmp; - ld.shared.${pftype} _tmp, [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}]; - add.${pftype} dotp, dotp, _tmp; + .reg .${pftype} _tmp; + ld.shared.${pftype} _tmp, [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}]; + add.${pftype} dotp, dotp, _tmp; } -% endif -% endfor -% if beta == 0: -% if n is None: +% endif +% endfor +% if beta_zero: +% if n is None: { - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - st.weak.global.cg.${pftype} [_cptr], dotp; + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], dotp; } -% else: +% else: st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; -% endif -% else: +% endif +% else: { - .reg .${pftype} _ctmp; -% if n is None: - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - ld.global.${pftype} _ctmp, [_cptr]; - fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; - st.global.${pftype} [_cptr], _ctmp; -% 
else: - ld.global.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; - fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; - st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; -% endif + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [_cptr], _ctmp; +% else: + ld.weak.global.cg.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif } -% endif +% endif -% endif -% endfor - bar.sync 0; +% endif % endfor + bar.sync 0; +% endfor $L_END_BID_${bid}: % endfor diff --git a/gimmik/kernels/ptx/cstream-w2.mako b/gimmik/kernels/ptx/cstream-w2.mako index c82ebab..e6b4d75 100644 --- a/gimmik/kernels/ptx/cstream-w2.mako +++ b/gimmik/kernels/ptx/cstream-w2.mako @@ -1,15 +1,5 @@ <%inherit file='base'/> -<% -pftype = 'f64' -dwidth_i = 8 -fzero = '0d0000000000000000' -bix_list = list(bix) -bix_pos = {kx: i for i, kx in enumerate(bix_list)} -row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] - for j in range(m)] -%> - .visible .entry ${kname}(.param .u64 _b, .param .u64 _c) { @@ -23,10 +13,10 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] ld.param.u64 c, [_c]; { - .reg .u32 _ctaid_x, _tid_x; - mov.u32 _ctaid_x, %ctaid.x; - mov.u32 _tid_x, %tid.x; - mad.lo.u32 id, _ctaid_x, ${blockx}, _tid_x; + .reg .u32 _ctaid_x, _tid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 _tid_x, %tid.x; + mad.lo.u32 id, _ctaid_x, ${blockx}, _tid_x; } setp.ge.u32 p1, id, n; @@ -36,59 +26,59 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] cvta.to.global.u64 c, c; { - .reg .u64 _id64; - cvt.u64.u32 _id64, id; - mad.lo.u64 b_base, _id64, 16, b; - mad.lo.u64 c_base, _id64, 16, c; + .reg .u64 _id64; + cvt.u64.u32 
_id64, id; + mad.lo.u64 b_base, _id64, 16, b; + mad.lo.u64 c_base, _id64, 16, c; } ## Batch-load B column pairs % for i, kx in enumerate(bix_list): - ld.global.nc.v2.f64 {bv_a${i}, bv_b${i}}, [b_base + ${ldb*kx*dwidth_i}]; + ld.weak.global.cg.v2.f64 {bv_a${i}, bv_b${i}}, [b_base + ${ldb*kx*dwidth_i}]; % endfor ## Main compute: two parallel dot-product streams per thread % for j in range(m): -% if row_nz[j]: -% for i_nz, (kx, jx) in enumerate(row_nz[j]): -% if i_nz == 0: +% if row_nz[j]: +% for kx, jx in row_nz[j]: +% if loop.first: mul.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}; mul.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}; -% else: +% else: fma.rn.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}, dotp_a; fma.rn.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}, dotp_b; -% endif -% endfor -% if beta == 0: +% endif +% endfor +% if beta_zero: st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; -% else: +% else: { - .reg .f64 _ca, _cb; - ld.global.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; - fma.rn.f64 _ca, _ca, ${float(beta)}, dotp_a; - fma.rn.f64 _cb, _cb, ${float(beta)}, dotp_b; - st.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + .reg .f64 _ca, _cb; + ld.weak.global.cg.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.f64 _ca, _ca, ${float(beta)}, dotp_a; + fma.rn.f64 _cb, _cb, ${float(beta)}, dotp_b; + st.weak.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; } -% endif +% endif -% else: +% else: ## Zero row of A -% if beta == 0: +% if beta_zero: { - .reg .f64 _z; - mov.f64 _z, ${fzero}; - st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_z, _z}; + .reg .f64 _z; + mov.f64 _z, ${fzero}; + st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_z, _z}; } -% elif beta != 1: +% elif beta != 1: { - .reg .f64 _ca, _cb; - ld.global.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; - mul.f64 _ca, _ca, ${float(beta)}; - mul.f64 _cb, _cb, ${float(beta)}; - st.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + .reg .f64 _ca, _cb; + 
ld.weak.global.cg.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + mul.f64 _ca, _ca, ${float(beta)}; + mul.f64 _cb, _cb, ${float(beta)}; + st.weak.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; } -% endif % endif +% endif % endfor $L_EXIT: diff --git a/gimmik/kernels/ptx/cstream.mako b/gimmik/kernels/ptx/cstream.mako index ec46934..726fe46 100644 --- a/gimmik/kernels/ptx/cstream.mako +++ b/gimmik/kernels/ptx/cstream.mako @@ -1,15 +1,5 @@ <%inherit file='base'/> -<% -pftype = 'f32' if dtype == 'float' else 'f64' -dwidth_i = 4 if dtype == 'float' else 8 -fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' -bix_list = list(bix) -bix_pos = {kx: i for i, kx in enumerate(bix_list)} -row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] - for j in range(m)] -%> - % if n is None: .visible .entry ${kname}(.param .u32 _n, .param .u64 _b, @@ -39,11 +29,11 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] ld.param.u64 c, [_c]; { - .reg .u32 _grd<3>; - mov.u32 _grd0, %ntid.x; - mov.u32 _grd1, %ctaid.x; - mov.u32 _grd2, %tid.x; - mad.lo.u32 id, _grd0, _grd1, _grd2; + .reg .u32 _grd<3>; + mov.u32 _grd0, %ntid.x; + mov.u32 _grd1, %ctaid.x; + mov.u32 _grd2, %tid.x; + mad.lo.u32 id, _grd0, _grd1, _grd2; } setp.ge.u32 p1, id, n; @@ -53,103 +43,103 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] cvta.to.global.u64 c, c; { - .reg .u64 _id64; - cvt.u64.u32 _id64, id; - mad.lo.u64 b_base, _id64, ${dwidth_i}, b; - mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; } ## Batch-load active B columns % for i, kx in enumerate(bix_list): -% if n is None: +% if n is None: { - .reg .u32 _boff; - .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; - ld.global.nc.${pftype} bv${i}, [_bptr]; + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + 
mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} bv${i}, [_bptr]; } -% else: - ld.global.nc.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; -% endif +% else: + ld.weak.global.cg.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; +% endif % endfor ## Compute and store each output row % for j in range(m): -% if row_nz[j]: -% for i_nz, (kx, jx) in enumerate(row_nz[j]): -% if i_nz == 0: +% if row_nz[j]: +% for kx, jx in row_nz[j]: +% if loop.first: mul.${pftype} dotp, bv${bix_pos[kx]}, ${jx}; -% else: +% else: fma.rn.${pftype} dotp, bv${bix_pos[kx]}, ${jx}, dotp; -% endif -% endfor -% if beta == 0: -% if n is None: +% endif +% endfor +% if beta_zero: +% if n is None: { - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - st.weak.global.cg.${pftype} [_cptr], dotp; + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], dotp; } -% else: +% else: st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; -% endif -% else: +% endif +% else: { - .reg .${pftype} _ctmp; -% if n is None: - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - ld.global.${pftype} _ctmp, [_cptr]; - fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; - st.global.${pftype} [_cptr], _ctmp; -% else: - ld.global.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; - fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; - st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; -% endif + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [_cptr], _ctmp; +% else: + ld.weak.global.cg.${pftype} _ctmp, [c_base + 
${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif } -% endif +% endif -% else: +% else: ## Zero row of A -% if beta == 0: +% if beta_zero: { - .reg .${pftype} _tmp; - mov.${pftype} _tmp, ${fzero}; -% if n is None: - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - st.weak.global.cg.${pftype} [_cptr], _tmp; -% else: - st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; -% endif + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif } -% elif beta != 1: +% elif beta != 1: { - .reg .${pftype} _tmp; -% if n is None: - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - ld.global.${pftype} _tmp, [_cptr]; - mul.${pftype} _tmp, _tmp, ${float(beta)}; - st.global.${pftype} [_cptr], _tmp; -% else: - ld.global.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; - mul.${pftype} _tmp, _tmp, ${float(beta)}; - st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; -% endif + .reg .${pftype} _tmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [_cptr], _tmp; +% else: + ld.weak.global.cg.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif } -% endif % endif +% endif % endfor $L_EXIT: diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako 
index ce8066d..8933e51 100644 --- a/gimmik/kernels/ptx/dense-mma-gAd.mako +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -1,9 +1,5 @@ <%inherit file='base'/> -<% -fzero = '0d0000000000000000' -%> - .global .align 16 .b64 ${kname}_Ag[${a_elems}] = { ${', '.join(a_u64)} }; @@ -16,14 +12,14 @@ fzero = '0d0000000000000000' .reg .u32 warp_n_base; .reg .u64 ag_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; - .reg .f64 a_frag; + .reg .${pftype} a_frag; % for nt in range(nn): .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; -% if not n_col_aligned: +% if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; -% endif - .reg .f64 b_frag_${nt}; - .reg .f64 c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; +% endif + .reg .${pftype} b_frag_${nt}; + .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; % endfor ld.param.u64 b_ptr, [_b]; @@ -57,11 +53,11 @@ fzero = '0d0000000000000000' add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } -% if not n_col_aligned: +% if not n_col_aligned: setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; -% endif +% endif % endfor // A thread base: &Ag[0] + lane*8 @@ -93,22 +89,22 @@ fzero = '0d0000000000000000' } % for mt in range(m_tiles): -% if pm_runtime(mt): +% if pm_runtime(mt): .reg .pred pm_${mt}; { .reg .u32 crow; add.u32 crow, r_div4, ${mt * 8}; setp.lt.u32 pm_${mt}, crow, ${m}; } -% endif +% endif % endfor % for nt in range(nn): -% for mt in range(m_tiles): -% if beta == 0: - mov.f64 c0_${nt}_${mt}, ${fzero}; - mov.f64 c1_${nt}_${mt}, ${fzero}; -% else: +% for mt in range(m_tiles): +% if beta_zero: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% else: <% pm = f'pm_{mt}' if pm_runtime(mt) else None pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None @@ -118,56 +114,56 @@ fzero = '0d0000000000000000' { .reg .u64 
caddr; add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; -% if needs_zero_init: - mov.f64 c0_${nt}_${mt}, ${fzero}; - mov.f64 c1_${nt}_${mt}, ${fzero}; -% endif - ${pred_emit(f'ld.global.f64 c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} - ${pred_emit(f'ld.global.f64 c1_{nt}_{mt}, [caddr + 8];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} +% if needs_zero_init: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} + ${pred_emit(f'ld.weak.global.cg.{pftype} c1_{nt}_{mt}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} } -% endif -% endfor +% endif +% endfor % endfor % for ki in range(k_iters): -% for nt in range(nn): +% for nt in range(nn): <% pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None - k_tail = (k_rem != 0 and ki == k_iters - 1) + k_tail = (k_rem != 0 and loop.parent.last) needs_zero = pvb is not None or k_tail pbrow = 'pbrow' if k_tail else None %> { .reg .u64 baddr; add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; -% if needs_zero: - mov.f64 b_frag_${nt}, ${fzero}; -% endif -% if k_tail: +% if needs_zero: + mov.${pftype} b_frag_${nt}, ${fzero}; +% endif +% if k_tail: .reg .pred pbrow; { .reg .u32 brow; add.u32 brow, r_mod4, ${ki * 4}; setp.lt.u32 pbrow, brow, ${k}; } -% endif - ${pred_emit(f'ld.global.nc.f64 b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} } -% endfor -% for mt in range(m_tiles): - ld.global.nc.f64 a_frag, [ag_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; -% for nt in range(nn): - mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 +% endfor +% for mt in range(m_tiles): + ld.weak.global.${pftype} a_frag, [ag_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; +% for nt in range(nn): + 
mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} {c0_${nt}_${mt}, c1_${nt}_${mt}}, {a_frag}, {b_frag_${nt}}, {c0_${nt}_${mt}, c1_${nt}_${mt}}; -% endfor -% endfor +% endfor +% endfor % endfor % for nt in range(nn): -% for mt in range(m_tiles): +% for mt in range(m_tiles): <% pm = f'pm_{mt}' if pm_runtime(mt) else None pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None @@ -176,10 +172,10 @@ fzero = '0d0000000000000000' { .reg .u64 caddr; add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; - ${pred_emit(f'st.global.f64 [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} - ${pred_emit(f'st.global.f64 [caddr + 8], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} + ${pred_emit(f'st.weak.global.{pftype} [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} + ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} } -% endfor +% endfor % endfor $L_EXIT: diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako index ec2f013..d1b72a8 100644 --- a/gimmik/kernels/ptx/dense-mma-smem-gA.mako +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -26,7 +26,7 @@ bs = bool(block_stealing) .reg .u32 warp_n_base; .reg .u64 as_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; - .reg .f64 a_frag; + .reg .${pftype} a_frag; % if bs: .reg .u32 ctaid; .reg .u32 mbar_a, work_a; @@ -34,11 +34,11 @@ bs = bool(block_stealing) % endif % for nt in range(nn): .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; -% if not n_col_aligned: +% if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; -% endif - .reg .f64 b_frag_${nt}; - .reg .f64 c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; +% endif + .reg .${pftype} b_frag_${nt}; + .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; % endfor ld.param.u64 b_ptr, [_b]; @@ -69,30 +69,29 @@ bs = bool(block_stealing) % for ci in 
range(copy_v2_iters): <% base_pair = ci * blockx - is_last = ci == copy_v2_iters - 1 pairs_this = min(blockx, a_pairs - base_pair) %> { .reg .u32 pidx; .reg .u64 off64, gaddr, saddr; - .reg .f64 v0, v1; -% if is_last and pairs_this < blockx: + .reg .${pftype} v0, v1; +% if loop.last and pairs_this < blockx: .reg .pred plast; add.u32 pidx, tid, ${base_pair}; setp.lt.u32 plast, pidx, ${a_pairs}; - mul.wide.u32 off64, pidx, 16; + mul.wide.u32 off64, pidx, ${2 * dwidth_i}; add.u64 gaddr, a_glb_base, off64; add.u64 saddr, a_smem_base, off64; - @plast ld.global.nc.v2.f64 {v0, v1}, [gaddr]; - @plast st.shared.v2.f64 [saddr], {v0, v1}; -% else: + @plast ld.weak.global.cg.v2.${pftype} {v0, v1}, [gaddr]; + @plast st.shared.v2.${pftype} [saddr], {v0, v1}; +% else: add.u32 pidx, tid, ${base_pair}; - mul.wide.u32 off64, pidx, 16; + mul.wide.u32 off64, pidx, ${2 * dwidth_i}; add.u64 gaddr, a_glb_base, off64; add.u64 saddr, a_smem_base, off64; - ld.global.nc.v2.f64 {v0, v1}, [gaddr]; - st.shared.v2.f64 [saddr], {v0, v1}; -% endif + ld.weak.global.cg.v2.${pftype} {v0, v1}, [gaddr]; + st.shared.v2.${pftype} [saddr], {v0, v1}; +% endif } % endfor % if a_pairs_tail: @@ -100,12 +99,12 @@ bs = bool(block_stealing) { .reg .pred plast; .reg .u64 gaddr, saddr; - .reg .f64 v; + .reg .${pftype} v; setp.eq.u32 plast, tid, 0; - add.u64 gaddr, a_glb_base, ${(a_elems-1) * 8}; - add.u64 saddr, a_smem_base, ${(a_elems-1) * 8}; - @plast ld.global.nc.f64 v, [gaddr]; - @plast st.shared.f64 [saddr], v; + add.u64 gaddr, a_glb_base, ${(a_elems-1) * dwidth_i}; + add.u64 saddr, a_smem_base, ${(a_elems-1) * dwidth_i}; + @plast ld.weak.global.cg.${pftype} v, [gaddr]; + @plast st.shared.${pftype} [saddr], v; } % endif } @@ -121,14 +120,14 @@ bs = bool(block_stealing) } % for mt in range(m_tiles): -% if pm_runtime(mt): +% if pm_runtime(mt): .reg .pred pm_${mt}; { .reg .u32 crow; add.u32 crow, r_div4, ${mt * 8}; setp.lt.u32 pm_${mt}, crow, ${m}; } -% endif +% endif % endfor % if bs: @@ -164,11 +163,11 @@ 
$L_LOOP: add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } -% if not n_col_aligned: +% if not n_col_aligned: setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; -% endif +% endif % endfor { @@ -190,11 +189,11 @@ $L_LOOP: } % for nt in range(nn): -% for mt in range(m_tiles): -% if beta == 0: - mov.f64 c0_${nt}_${mt}, 0d0000000000000000; - mov.f64 c1_${nt}_${mt}, 0d0000000000000000; -% else: +% for mt in range(m_tiles): +% if beta_zero: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% else: <% pm = f'pm_{mt}' if pm_runtime(mt) else None pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None @@ -204,56 +203,56 @@ $L_LOOP: { .reg .u64 caddr; add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; -% if needs_zero_init: - mov.f64 c0_${nt}_${mt}, 0d0000000000000000; - mov.f64 c1_${nt}_${mt}, 0d0000000000000000; -% endif - ${pred_emit(f'ld.global.f64 c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} - ${pred_emit(f'ld.global.f64 c1_{nt}_{mt}, [caddr + 8];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} +% if needs_zero_init: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} + ${pred_emit(f'ld.weak.global.cg.{pftype} c1_{nt}_{mt}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} } -% endif -% endfor +% endif +% endfor % endfor % for ki in range(k_iters): -% for nt in range(nn): +% for nt in range(nn): <% pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None - k_tail = (k_rem != 0 and ki == k_iters - 1) + k_tail = (k_rem != 0 and loop.parent.last) needs_zero = pvb is not None or k_tail pbrow = 'pbrow' if k_tail else None %> { .reg .u64 baddr; add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; -% if 
needs_zero: - mov.f64 b_frag_${nt}, 0d0000000000000000; -% endif -% if k_tail: +% if needs_zero: + mov.${pftype} b_frag_${nt}, ${fzero}; +% endif +% if k_tail: .reg .pred pbrow; { .reg .u32 brow; add.u32 brow, r_mod4, ${ki * 4}; setp.lt.u32 pbrow, brow, ${k}; } -% endif - ${pred_emit(f'ld.global.nc.f64 b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} } -% endfor -% for mt in range(m_tiles): - ld.shared.f64 a_frag, [as_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; -% for nt in range(nn): - mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 +% endfor +% for mt in range(m_tiles): + ld.shared.${pftype} a_frag, [as_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; +% for nt in range(nn): + mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} {c0_${nt}_${mt}, c1_${nt}_${mt}}, {a_frag}, {b_frag_${nt}}, {c0_${nt}_${mt}, c1_${nt}_${mt}}; -% endfor -% endfor +% endfor +% endfor % endfor % for nt in range(nn): -% for mt in range(m_tiles): +% for mt in range(m_tiles): <% pm = f'pm_{mt}' if pm_runtime(mt) else None pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None @@ -262,10 +261,10 @@ $L_LOOP: { .reg .u64 caddr; add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; - ${pred_emit(f'st.global.f64 [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} - ${pred_emit(f'st.global.f64 [caddr + 8], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} + ${pred_emit(f'st.weak.global.{pftype} [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} + ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} } -% endfor +% endfor % endfor % if bs: diff --git a/gimmik/kernels/ptx/dense-mma-ws.mako b/gimmik/kernels/ptx/dense-mma-ws.mako index e4b576a..a7c9f88 100644 --- a/gimmik/kernels/ptx/dense-mma-ws.mako +++ 
b/gimmik/kernels/ptx/dense-mma-ws.mako @@ -1,8 +1,4 @@ <%inherit file='base'/> -<% -mbar_maxwait = '0x989680' -direct_store = (beta == 0) -%> <%def name="producer_init_setup()"> // Producer warp: initial A bulk-copy + B load for ctaid_x's work @@ -67,17 +63,17 @@ $L_WAIT_BRDY: } % endfor -% if direct_store: - // direct_store: skip shared-staging entirely; compute warps store - // MMA outputs straight to global C with N-tail predication. +% if beta_zero: + // beta=0: skip shared-staging entirely; compute warps store MMA + // outputs straight to global C with N-tail predication. .reg .u64 c_glob_addr; - ld.param.u64 c_glob_addr, [_c]; + ld.param.u64 c_glob_addr, [c_desc]; cvta.to.global.u64 c_glob_addr, c_glob_addr; % else: .reg .b32 c_thr_smem; { .reg .b32 t1, ccol_b; - mul.lo.u32 t1, base_crow, ${n_per_cta * 8}; + mul.lo.u32 t1, base_crow, ${n_per_cta * dwidth_i}; shl.b32 ccol_b, base_ccol, 3; add.u32 c_thr_smem, c_smem, t1; add.u32 c_thr_smem, c_thr_smem, ccol_b; @@ -86,91 +82,91 @@ $L_WAIT_BRDY: // Zero accumulators % for mt in range(m_tiles): -% for nt in range(nn): - .reg .f64 d_x_${mt}_${nt}, d_y_${mt}_${nt}; - mov.f64 d_x_${mt}_${nt}, 0d0000000000000000; - mov.f64 d_y_${mt}_${nt}, 0d0000000000000000; -% endfor +% for nt in range(nn): + .reg .${pftype} d_x_${mt}_${nt}, d_y_${mt}_${nt}; + mov.${pftype} d_x_${mt}_${nt}, ${fzero}; + mov.${pftype} d_y_${mt}_${nt}, ${fzero}; +% endfor % endfor - .reg .f64 a_f; + .reg .${pftype} a_f; % for mt in range(m_tiles): -% for kt in range(k_iters): +% for kt in range(k_iters): <% - k_tail = (k_rem != 0 and kt == k_iters - 1) + k_tail = (k_rem != 0 and loop.last) %> { .reg .b32 a_a; - add.u32 a_a, a_thr_a, ${(kt * 32 + mt * 32 * k_iters) * 8}; - ld.shared.f64 a_f, [a_a]; -% if k_tail: + add.u32 a_a, a_thr_a, ${(kt * 32 + mt * 32 * k_iters) * dwidth_i}; + ld.shared.${pftype} a_f, [a_a]; +% if k_tail: .reg .pred pbrow_${mt}_${kt}; { .reg .b32 brow; add.u32 brow, base_brow, ${4 * kt}; setp.lt.u32 pbrow_${mt}_${kt}, brow, ${k}; } 
-% endif -% for nt in range(nn): +% endif +% for nt in range(nn): { .reg .b32 b_a, b_row; - .reg .f64 b_f; + .reg .${pftype} b_f; add.u32 b_row, base_brow, ${4 * kt}; - mul.lo.u32 b_row, b_row, ${n_per_cta * 8}; + mul.lo.u32 b_row, b_row, ${n_per_cta * dwidth_i}; add.u32 b_a, b_thr_a_${nt}, b_row; -% if k_tail: - mov.f64 b_f, 0d0000000000000000; - @pbrow_${mt}_${kt} ld.shared.f64 b_f, [b_a]; -% else: - ld.shared.f64 b_f, [b_a]; -% endif - mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 +% if k_tail: + mov.${pftype} b_f, ${fzero}; + @pbrow_${mt}_${kt} ld.shared.${pftype} b_f, [b_a]; +% else: + ld.shared.${pftype} b_f, [b_a]; +% endif + mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} {d_x_${mt}_${nt}, d_y_${mt}_${nt}}, {a_f}, {b_f}, {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; } -% endfor +% endfor } -% endfor +% endfor % endfor -% if direct_store: +% if beta_zero: .reg .u64 c_thr_glob_base; { .reg .u32 thr_col_off, thr_addr_off_lo; add.u32 thr_col_off, base_ccol, n_start_curr; mad.lo.u32 thr_addr_off_lo, base_crow, ${ldc}, thr_col_off; .reg .u64 thr_byte_off; - mul.wide.u32 thr_byte_off, thr_addr_off_lo, 8; + mul.wide.u32 thr_byte_off, thr_addr_off_lo, ${dwidth_i}; add.u64 c_thr_glob_base, c_glob_addr, thr_byte_off; } -% for mt in range(m_tiles): +% for mt in range(m_tiles): <% row_tail = (m_pad > m) and ((mt + 1) * 8 > m) %> -% if row_tail: +% if row_tail: .reg .pred p_row_${mt}; { .reg .b32 crow; add.u32 crow, base_crow, ${8 * mt}; setp.lt.u32 p_row_${mt}, crow, ${m}; } -% endif -% for nt in range(nn): +% endif +% for nt in range(nn): { .reg .pred p_st; .reg .u32 g_ccol; add.u32 g_ccol, base_ccol, ${8 * nt}; add.u32 g_ccol, g_ccol, n_start_curr; setp.lt.u32 p_st, g_ccol, ${n}; -% if row_tail: +% if row_tail: and.pred p_st, p_st, p_row_${mt}; -% endif - .reg .u64 c_addr; - add.u64 c_addr, c_thr_glob_base, ${(mt * 8 * ldc + nt * 8) * 8}; - @p_st st.global.v2.f64 [c_addr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; +% endif + .reg .u64 _c_addr; + add.u64 
_c_addr, c_thr_glob_base, ${(mt * 8 * ldc + nt * 8) * dwidth_i}; + @p_st st.weak.global.v2.${pftype} [_c_addr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; } -% endfor -% endfor +% endfor +% endfor % else: // Wait until producer's prev-iter TMA-store of C has drained. { @@ -182,18 +178,18 @@ $L_WAIT_CSTORE: // Vector-store {d_x, d_y} pairs to csmem. M-tail / N-tail OOB rows // are dropped by the C tensor map. -% for mt in range(m_tiles): -% for nt in range(nn): +% for mt in range(m_tiles): +% for nt in range(nn): { .reg .b32 csaddr; add.u32 csaddr, c_thr_smem, ${mt * c_mtile_smem_stride + nt * c_ntile_smem_stride}; - st.shared.v2.f64 [csaddr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + st.shared.v2.${pftype} [csaddr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; } -% endfor -% endfor +% endfor +% endfor % endif -% if not direct_store: +% if not beta_zero: bar.sync 1, ${comp_threads}; fence.proxy.async.shared::cta; { @@ -260,7 +256,7 @@ $L_WAIT_WNEW_D: } bar.warp.sync 0xffffffff; -% if not direct_store: +% if not beta_zero: // TMA reduce+store of C (beta=1 only; beta=0 uses direct global // stores from compute warps, so the producer does no C work). 
{ @@ -329,11 +325,9 @@ $L_AFTER_CTRL: ${', '.join(a_u64)} }; .extern .shared .align 128 .b8 ${kname}_dynm[]; -.const .align 64 .b8 ${kname}_bdesc[128]; -.const .align 64 .b8 ${kname}_cdesc[128]; -.visible .entry ${kname}(.param .u64 _b, - .param .u64 _c) +.visible .entry ${kname}(.param .u64 b_desc, + .param .u64 c_desc) .maxntid ${blockx_total}, 1, 1 { .reg .b32 tid, warp, lane, phase, ctaid_x; @@ -369,8 +363,8 @@ $L_AFTER_CTRL: add.u32 wid_new_mbar, dynm_base, ${wid_new_mbar_off}; add.u32 wid_used_mbar, dynm_base, ${wid_used_mbar_off}; - cvta.const.u64 bdesc_addr, ${kname}_bdesc; - cvta.const.u64 cdesc_addr, ${kname}_cdesc; + ld.param.u64 bdesc_addr, [b_desc]; + ld.param.u64 cdesc_addr, [c_desc]; setp.eq.u32 p_tid0, tid, 0; @@ -401,7 +395,7 @@ $L_AFTER_CTRL: } bar.sync 0; - // Compute-warp lane geometry (cheap; all warps execute uniformly) + // Compute-warp lane geometry { .reg .b32 t, w_n_base; and.b32 base_brow, lane, 3; diff --git a/gimmik/ptx.py b/gimmik/ptx.py index ad429e8..1f46384 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -12,35 +12,53 @@ class PTXMatMul(MatMul): basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, 'dynamic_shared': 0} - @staticmethod - def is_sparse_suitable(arr): + DENSE_SMEM_MAX = 200*1024 + PTX_SM = {(8, 0), (9, 0), (10, 0), (10, 3), (12, 0), (12, 1)} + + @classmethod + def is_sparse_suitable(cls, arr, cc): nnz = int(np.count_nonzero(arr)) nuq = int(len(np.unique(np.abs(arr)))) density = nnz / arr.size - return (nuq <= 28) or (density <= 0.15) + return ((nuq <= 28) or (density <= 0.15)) and cc in cls.PTX_SM - # Shape/arch gate for dense DMMA; n/ldb/ldc are validated at generate time - @staticmethod - def is_dense_suitable(arr, dtype, cc): - return (np.dtype(dtype) == np.float64 - and cc is not None and cc >= (9, 0) + @classmethod + def is_dense_suitable(cls, arr, cc): + cc_appropriate = cc in cls.PTX_SM and cc >= (9, 0) + return (arr.dtype == np.float64 and cc_appropriate and arr.shape[0] <= 128 and arr.shape[1] <= 128) 
@classmethod - def is_suitable(cls, arr, dtype, cc): - return (cls.is_sparse_suitable(arr) - or cls.is_dense_suitable(arr, dtype, cc)) + def is_suitable(cls, arr, cc): + return cls.is_sparse_suitable(arr, cc) or cls.is_dense_suitable(arr, cc) def _kernel_generators(self, dtype, dsize, *, compute_capability=None): - base_args = {'cc': compute_capability, - 'pred_emit': self._pred_emit} - - yield from self._sparse_kernel_generators(dtype, dsize, base_args) - yield from self._dense_kernel_generators(dtype, dsize, base_args) + cc = compute_capability or (0, 0) + base_args = {'cc': cc, + 'pred_emit': self._pred_emit, + 'pftype': 'f32' if dtype == 'float' else 'f64', + 'dwidth_i': 4 if dtype == 'float' else 8, + 'fzero': ('0f00000000' if dtype == 'float' + else '0d0000000000000000'), + 'beta_zero': self.beta == 0, + 'mbar_maxwait': '0x989680' + } + + if self.is_sparse_suitable(self.A, cc): + yield from self._sparse_kernel_generators(dtype, dsize, base_args) + + if self.is_dense_suitable(self.A, cc): + yield from self._dense_kernel_generators(dtype, dsize, base_args) def _sparse_kernel_generators(self, dtype, dsize, base_args): - if not self.is_sparse_suitable(self.A): - return + # Sparse-shared template constants + base_args = base_args | { + 'bix_list': list(self.bix), + 'bix_pos': self.bix, + 'has_zero_rows': bool(self.has_zero_rows), + 'row_nz': [[(kx, self.A[j, kx]) for kx in range(self.k) + if self.A[j, kx] != 0] for j in range(self.m)], + } # B loading, C streaming kernel yield ('cstream', base_args, {'desc': 'cstream'}) @@ -93,145 +111,124 @@ def _sparse_kernel_generators(self, dtype, dsize, base_args): def _dense_kernel_generators(self, dtype, dsize, base_args): cc = base_args['cc'] or (0, 0) - if not (self.is_dense_suitable(self.A, dtype, cc) - and self.n is not None): - return - # Some kernels can optional steal blocks - bs_default = cc >= (10, 0) - - if cc >= (10, 0): - # Warp specialised is uniformly better on sm_100+, so no need to JIT - # other versions + # 
Block stealing requires sm_100+ + block_steal = cc >= (10, 0) + if block_steal: dense_configs = [('dense-mma-smem-gA', 4, 4)] else: dense_configs = [ ('dense-mma-smem-gA', 1, 8), ('dense-mma-smem-gA', 2, 4), ('dense-mma-smem-gA', 4, 4), - ('dense-mma-gAd', 2, 2), - ('dense-mma-gAd', 4, 2), + ('dense-mma-gAd', 2, 2), + ('dense-mma-gAd', 4, 2), ] for tpl, nn, w in dense_configs: blkx = 32 * w - n_per_cta = 8 * nn * w - if n_per_cta > self.n: + if (n_per_cta := 8 * nn * w) > self.n: continue - bs = (tpl == 'dense-mma-smem-gA') and bs_default setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) args = (base_args | {'warps_per_cta': w, 'nn': nn, - 'block_stealing': bs} | setup) + 'block_stealing': block_steal} | setup) meta = { 'block': (blkx, 1, 1), 'grid': (-(-self.n // n_per_cta), 1, 1), - 'desc': f'{tpl}/nn{nn}-w{w}{"-bs" if bs else ""}', + 'desc': f'{tpl}/nn{nn}-w{w}{'-bs' if block_steal else ''}', } yield (tpl, args, meta) - # Warp-specialised dense DMMA - if cc >= (10, 0): + # Warp-specialised dense DMMA, required block stealing + if block_steal: yield from self._dense_ws_kernel_generators(dtype, dsize, base_args) def _dense_ws_kernel_generators(self, dtype, dsize, base_args): - m_pad = -(-self.m // 8) * 8 - k_pad = -(-self.k // 4) * 4 - - # (nn, w_compute) -- block has w_compute + 2 warps (producer, stealer) + # (nn, compute) -- block has compute + 2 warps (producer, stealer) ws_configs = [(1, 4), (2, 4), (4, 4)] for nn, w in ws_configs: - n_per_cta = 8 * nn * w - if n_per_cta > self.n: + if (n_per_cta := 8 * nn * w) > self.n: continue - blkx = 32 * (w + 2) + setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) - ws_layout = self._dense_ws_layout( - n_comp_warps=w, n_per_cta=n_per_cta, - m_pad=m_pad, k_pad=k_pad, a_elems=setup['a_elems'] - ) + ws_setup = self._dense_ws_setup(setup, n_comp_warps=w) - if ws_layout['dynm_total_bytes'] > 200 * 1024: + if ws_setup['dynm_total_bytes'] > self.DENSE_SMEM_MAX: continue - args = (base_args - | {'warps_per_cta': w, 'nn': 
nn} - | setup | ws_layout) + args = base_args | {'nn': nn} | setup | ws_setup meta = { - 'block': (blkx, 1, 1), + 'block': (32 * (w + 2), 1, 1), 'grid': (-(-self.n // n_per_cta), 1, 1), 'desc': f'dense-mma-ws/nn{nn}-w{w}', - 'ws_tensor_map': True, - 'ws_n_per_cta': n_per_cta, - 'ws_k_pad': k_pad, - 'ws_m_pad': m_pad, - 'dynamic_shared': ws_layout['dynm_total_bytes'], + 'ws_b_tile': (n_per_cta, setup['k_pad']), + 'dynamic_shared': ws_setup['dynm_total_bytes'], } + if self.beta != 0: + meta |= {'ws_out_tile': (n_per_cta, setup['m_pad'])} yield ('dense-mma-ws', args, meta) @staticmethod - def _dense_ws_layout(*, n_comp_warps, n_per_cta, m_pad, k_pad, a_elems): - n_total_warps = n_comp_warps + 2 - blockx_total = 32 * n_total_warps - - b_tile_bytes = k_pad * n_per_cta * 8 - c_tile_bytes = m_pad * n_per_cta * 8 - a_bytes = a_elems * 8 - - smem_size = {'b1': b_tile_bytes, 'b2': b_tile_bytes, 'c': c_tile_bytes, - 'a': a_bytes, 'wid': 16} - smem_off, off = {}, 0 - for k, v in smem_size.items(): - off = (off + 15) & ~15 - smem_off[f'{k}_off'] = off - off += v - - mbar_names = ('tma', 'bready', 'cready', 'cstored', - 'steal', 'wid_new', 'wid_used') - for k in mbar_names: - smem_off[f'{k}_mbar_off'] = off + def _dsmem_alloc(regions, mbars, align=16): + out, off = {}, 0 + for name, size in regions: + off = (off + align - 1) & ~(align - 1) + out[f'{name}_off'] = off + off += size + for name in mbars: + out[f'{name}_mbar_off'] = off off += 8 + total = (off + align - 1) & ~(align - 1) + return out, total - # Pad total to 16-byte multiple - dynm_total_bytes = (off + 15) & ~15 - - params = {'n_comp_warps': n_comp_warps, - 'blockx_total': blockx_total, - 'prod_warp': n_comp_warps, - 'steal_warp': n_comp_warps + 1, - 'comp_threads': 32 * n_comp_warps, - 'm_pad': m_pad, - 'k_pad': k_pad, - 'b_tile_doubles': k_pad * n_per_cta, - 'b_tile_bytes': b_tile_bytes, - 'c_tile_doubles': m_pad * n_per_cta, - 'c_mtile_smem_stride': 8 * n_per_cta * 8, - 'c_ntile_smem_stride': 8 * 8, - 
'dynm_total_bytes': dynm_total_bytes, - } - params |= smem_off - return params + @classmethod + def _dense_ws_setup(cls, setup, *, n_comp_warps): + n_per_cta = setup['n_per_cta'] + b_tile_bytes = setup['k_pad'] * n_per_cta * 8 + c_tile_bytes = setup['m_pad'] * n_per_cta * 8 + a_bytes = setup['a_elems'] * 8 + + regions = [('b1', b_tile_bytes), ('b2', b_tile_bytes), + ('c', c_tile_bytes), ('a', a_bytes), ('wid', 16)] + mbars = ('tma', 'bready', 'cready', 'cstored', + 'steal', 'wid_new', 'wid_used') + offsets, dynm_total_bytes = cls._dsmem_alloc(regions, mbars) + + return offsets | { + 'n_comp_warps': n_comp_warps, + 'blockx_total': 32 * (n_comp_warps + 2), + 'prod_warp': n_comp_warps, + 'steal_warp': n_comp_warps + 1, + 'comp_threads': 32 * n_comp_warps, + 'b_tile_bytes': b_tile_bytes, + 'c_mtile_smem_stride': 8 * n_per_cta * 8, + 'c_ntile_smem_stride': 8 * 8, + 'dynm_total_bytes': dynm_total_bytes, + } def _dense_mma_setup(self, *, nn, warps_per_cta): a = self.A m, k = a.shape m_tiles = -(-m // 8) - k_rem = k % 4 - k_iters = (k + (4 - k_rem if k_rem else 0)) // 4 + k_iters = -(-k // 4) + k_rem = k % 4 - # A in fragment layout: lane l -> A[m_tile*8 + l/4][k_iter*4 + l%4] + # A in fragment layout: lane l -> A[mt*8 + l//4][kt*4 + l%4] + # DMMA tiles are 8x8x4 so this loads a 8x4 tile, flattens it and + # and packs it as uint64 as a more robust way of storing the values in + # the template a_u64 = [] - for m_tile in range(m_tiles): - for k_iter in range(k_iters): + for mt in range(m_tiles): + for kt in range(k_iters): for lane in range(32): - i = m_tile * 8 + lane // 4 - j = k_iter * 4 + lane % 4 + i = mt * 8 + lane // 4 + j = kt * 4 + lane % 4 v = float(a[i, j]) if (i < m and j < k) else 0.0 - u = struct.unpack(' Date: Thu, 21 May 2026 06:26:38 -0700 Subject: [PATCH 08/21] General cleanups and moved smem to pyfr --- gimmik/cuda.py | 7 +++ gimmik/kernels/ptx/bstream-msplit.mako | 2 +- gimmik/kernels/ptx/bstream.mako | 10 ++--- gimmik/kernels/ptx/cstream-w2.mako | 12 
+++--- gimmik/kernels/ptx/cstream.mako | 8 ++-- gimmik/kernels/ptx/dense-mma-gAd.mako | 4 +- gimmik/kernels/ptx/dense-mma-smem-gA.mako | 4 +- gimmik/kernels/ptx/dense-mma-ws.mako | 4 +- gimmik/ptx.py | 52 +++++++++++------------ 9 files changed, 53 insertions(+), 50 deletions(-) diff --git a/gimmik/cuda.py b/gimmik/cuda.py index b18c509..e40179c 100644 --- a/gimmik/cuda.py +++ b/gimmik/cuda.py @@ -8,6 +8,13 @@ class CUDAMatMul(MatMul): basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, 'dynamic_shared': 0} + @staticmethod + def is_suitable(arr): + nnz = np.count_nonzero(arr) + nuq = len(np.unique(np.abs(arr))) + density = nnz / arr.size + return (nuq <= 28) or (density <= 0.15) + def _kernel_generators(self, dtype, dsize, *, compute_capability=None): # B loading, C streaming kernel yield ('cstream', {}, {}) diff --git a/gimmik/kernels/ptx/bstream-msplit.mako b/gimmik/kernels/ptx/bstream-msplit.mako index 530b19f..2ef85e9 100644 --- a/gimmik/kernels/ptx/bstream-msplit.mako +++ b/gimmik/kernels/ptx/bstream-msplit.mako @@ -2,7 +2,7 @@ <% mx = partition(A, into=msplit, by='rows') -bchunks = chunk(bix_list, bsz) +bchunks = chunk(bix, bsz) m_per_group = max(len(mcx) for mcx in mx) bsub_bytes = 2 * bsz * blockx * dwidth_i def bsub_off(buf, idx): diff --git a/gimmik/kernels/ptx/bstream.mako b/gimmik/kernels/ptx/bstream.mako index 24b0acb..45eb1a7 100644 --- a/gimmik/kernels/ptx/bstream.mako +++ b/gimmik/kernels/ptx/bstream.mako @@ -17,7 +17,7 @@ % endif .reg .u32 n, id; .reg .u64 b, c, b_base, c_base; - .reg .${pftype} csub<${m}>, bv<${len(bix_list)}>; + .reg .${pftype} csub<${m}>, bv<${len(bix)}>; .reg .pred p1; % if n is None: @@ -50,7 +50,7 @@ } ## Batch-load active B columns -% for i, kx in enumerate(bix_list): +% for i, kx in enumerate(bix): % if n is None: { .reg .u32 _boff; @@ -89,13 +89,13 @@ % endif ## Main compute -% for kx in bix_list: +% for kx in bix: % for j, jx in enumerate(A[:, kx]): % if jx != 0: % if beta_zero and kx == afix[j]: - mul.${pftype} 
csub${j}, bv${bix_pos[kx]}, ${jx}; + mul.${pftype} csub${j}, bv${bix[kx]}, ${jx}; % else: - fma.rn.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}, csub${j}; + fma.rn.${pftype} csub${j}, bv${bix[kx]}, ${jx}, csub${j}; % endif % endif % if kx == alix[j]: diff --git a/gimmik/kernels/ptx/cstream-w2.mako b/gimmik/kernels/ptx/cstream-w2.mako index e6b4d75..ce7301d 100644 --- a/gimmik/kernels/ptx/cstream-w2.mako +++ b/gimmik/kernels/ptx/cstream-w2.mako @@ -5,7 +5,7 @@ { .reg .u32 n, id; .reg .u64 b, c, b_base, c_base; - .reg .f64 bv_a<${len(bix_list)}>, bv_b<${len(bix_list)}>, dotp_a, dotp_b; + .reg .f64 bv_a<${len(bix)}>, bv_b<${len(bix)}>, dotp_a, dotp_b; .reg .pred p1; mov.u32 n, ${-(-n // 2)}; @@ -33,7 +33,7 @@ } ## Batch-load B column pairs -% for i, kx in enumerate(bix_list): +% for i, kx in enumerate(bix): ld.weak.global.cg.v2.f64 {bv_a${i}, bv_b${i}}, [b_base + ${ldb*kx*dwidth_i}]; % endfor @@ -42,11 +42,11 @@ % if row_nz[j]: % for kx, jx in row_nz[j]: % if loop.first: - mul.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}; - mul.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}; + mul.f64 dotp_a, bv_a${bix[kx]}, ${jx}; + mul.f64 dotp_b, bv_b${bix[kx]}, ${jx}; % else: - fma.rn.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}, dotp_a; - fma.rn.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}, dotp_b; + fma.rn.f64 dotp_a, bv_a${bix[kx]}, ${jx}, dotp_a; + fma.rn.f64 dotp_b, bv_b${bix[kx]}, ${jx}, dotp_b; % endif % endfor % if beta_zero: diff --git a/gimmik/kernels/ptx/cstream.mako b/gimmik/kernels/ptx/cstream.mako index 726fe46..9ce4c4d 100644 --- a/gimmik/kernels/ptx/cstream.mako +++ b/gimmik/kernels/ptx/cstream.mako @@ -17,7 +17,7 @@ % endif .reg .u32 n, id; .reg .u64 b, c, b_base, c_base; - .reg .${pftype} bv<${len(bix_list)}>, dotp; + .reg .${pftype} bv<${len(bix)}>, dotp; .reg .pred p1; % if n is None: @@ -50,7 +50,7 @@ } ## Batch-load active B columns -% for i, kx in enumerate(bix_list): +% for i, kx in enumerate(bix): % if n is None: { .reg .u32 _boff; @@ -69,9 +69,9 @@ % if row_nz[j]: % for kx, jx in 
row_nz[j]: % if loop.first: - mul.${pftype} dotp, bv${bix_pos[kx]}, ${jx}; + mul.${pftype} dotp, bv${bix[kx]}, ${jx}; % else: - fma.rn.${pftype} dotp, bv${bix_pos[kx]}, ${jx}, dotp; + fma.rn.${pftype} dotp, bv${bix[kx]}, ${jx}, dotp; % endif % endfor % if beta_zero: diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako index 8933e51..0996a17 100644 --- a/gimmik/kernels/ptx/dense-mma-gAd.mako +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -125,7 +125,7 @@ % endfor % endfor -% for ki in range(k_iters): +% for ki in range(k_tiles): % for nt in range(nn): <% pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None @@ -151,7 +151,7 @@ } % endfor % for mt in range(m_tiles): - ld.weak.global.${pftype} a_frag, [ag_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; + ld.weak.global.${pftype} a_frag, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; % for nt in range(nn): mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} {c0_${nt}_${mt}, c1_${nt}_${mt}}, diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako index d1b72a8..4516831 100644 --- a/gimmik/kernels/ptx/dense-mma-smem-gA.mako +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -214,7 +214,7 @@ $L_LOOP: % endfor % endfor -% for ki in range(k_iters): +% for ki in range(k_tiles): % for nt in range(nn): <% pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None @@ -240,7 +240,7 @@ $L_LOOP: } % endfor % for mt in range(m_tiles): - ld.shared.${pftype} a_frag, [as_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; + ld.shared.${pftype} a_frag, [as_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; % for nt in range(nn): mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} {c0_${nt}_${mt}, c1_${nt}_${mt}}, diff --git a/gimmik/kernels/ptx/dense-mma-ws.mako b/gimmik/kernels/ptx/dense-mma-ws.mako index a7c9f88..de4314f 100644 --- a/gimmik/kernels/ptx/dense-mma-ws.mako +++ 
b/gimmik/kernels/ptx/dense-mma-ws.mako @@ -91,13 +91,13 @@ $L_WAIT_BRDY: .reg .${pftype} a_f; % for mt in range(m_tiles): -% for kt in range(k_iters): +% for kt in range(k_tiles): <% k_tail = (k_rem != 0 and loop.last) %> { .reg .b32 a_a; - add.u32 a_a, a_thr_a, ${(kt * 32 + mt * 32 * k_iters) * dwidth_i}; + add.u32 a_a, a_thr_a, ${(kt * 32 + mt * 32 * k_tiles) * dwidth_i}; ld.shared.${pftype} a_f, [a_a]; % if k_tail: .reg .pred pbrow_${mt}_${kt}; diff --git a/gimmik/ptx.py b/gimmik/ptx.py index 1f46384..a9c049d 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -import struct import numpy as np @@ -12,13 +11,12 @@ class PTXMatMul(MatMul): basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, 'dynamic_shared': 0} - DENSE_SMEM_MAX = 200*1024 PTX_SM = {(8, 0), (9, 0), (10, 0), (10, 3), (12, 0), (12, 1)} @classmethod def is_sparse_suitable(cls, arr, cc): - nnz = int(np.count_nonzero(arr)) - nuq = int(len(np.unique(np.abs(arr)))) + nnz = np.count_nonzero(arr) + nuq = len(np.unique(np.abs(arr))) density = nnz / arr.size return ((nuq <= 28) or (density <= 0.15)) and cc in cls.PTX_SM @@ -32,9 +30,12 @@ def is_dense_suitable(cls, arr, cc): def is_suitable(cls, arr, cc): return cls.is_sparse_suitable(arr, cc) or cls.is_dense_suitable(arr, cc) - def _kernel_generators(self, dtype, dsize, *, compute_capability=None): + def _kernel_generators(self, dtype, dsize, *, compute_capability=None, + smem_info=None): cc = compute_capability or (0, 0) + smem_info = smem_info or (48*1024, 48*1024) base_args = {'cc': cc, + 'smem_info': smem_info, 'pred_emit': self._pred_emit, 'pftype': 'f32' if dtype == 'float' else 'f64', 'dwidth_i': 4 if dtype == 'float' else 8, @@ -53,8 +54,6 @@ def _kernel_generators(self, dtype, dsize, *, compute_capability=None): def _sparse_kernel_generators(self, dtype, dsize, base_args): # Sparse-shared template constants base_args = base_args | { - 'bix_list': list(self.bix), - 'bix_pos': self.bix, 'has_zero_rows': 
bool(self.has_zero_rows), 'row_nz': [[(kx, self.A[j, kx]) for kx in range(self.k) if self.A[j, kx] != 0] for j in range(self.m)], @@ -126,10 +125,10 @@ def _dense_kernel_generators(self, dtype, dsize, base_args): ] for tpl, nn, w in dense_configs: - blkx = 32 * w if (n_per_cta := 8 * nn * w) > self.n: continue setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) + blkx = 32 * w args = (base_args | {'warps_per_cta': w, 'nn': nn, 'block_stealing': block_steal} | setup) meta = { @@ -144,6 +143,8 @@ def _dense_kernel_generators(self, dtype, dsize, base_args): yield from self._dense_ws_kernel_generators(dtype, dsize, base_args) def _dense_ws_kernel_generators(self, dtype, dsize, base_args): + static_max, dynamic_max = base_args['smem_info'] + # (nn, compute) -- block has compute + 2 warps (producer, stealer) ws_configs = [(1, 4), (2, 4), (4, 4)] for nn, w in ws_configs: @@ -153,12 +154,13 @@ def _dense_ws_kernel_generators(self, dtype, dsize, base_args): setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) ws_setup = self._dense_ws_setup(setup, n_comp_warps=w) - if ws_setup['dynm_total_bytes'] > self.DENSE_SMEM_MAX: + if ws_setup['dynm_total_bytes'] > dynamic_max: continue + blkx = 32 * (w + 2) args = base_args | {'nn': nn} | setup | ws_setup meta = { - 'block': (32 * (w + 2), 1, 1), + 'block': (blkx, 1, 1), 'grid': (-(-self.n // n_per_cta), 1, 1), 'desc': f'dense-mma-ws/nn{nn}-w{w}', 'ws_b_tile': (n_per_cta, setup['k_pad']), @@ -209,23 +211,17 @@ def _dense_ws_setup(cls, setup, *, n_comp_warps): def _dense_mma_setup(self, *, nn, warps_per_cta): a = self.A m, k = a.shape - m_tiles = -(-m // 8) - k_iters = -(-k // 4) + m_tiles = (m + 7) // 8 + k_tiles = (k + 3) // 4 k_rem = k % 4 - # A in fragment layout: lane l -> A[mt*8 + l//4][kt*4 + l%4] - # DMMA tiles are 8x8x4 so this loads a 8x4 tile, flattens it and - # and packs it as uint64 as a more robust way of storing the values in - # the template - a_u64 = [] - for mt in range(m_tiles): - for kt in range(k_iters): - for 
lane in range(32): - i = mt * 8 + lane // 4 - j = kt * 4 + lane % 4 - v = float(a[i, j]) if (i < m and j < k) else 0.0 - u, = struct.unpack(' A[mt*8 + l//4][kt*4 + l%4] + # i.e. an (m_tiles, k_tiles) grid of row-major 8x4 tiles, packed as + # uint64 + a_pad = np.zeros((m_tiles*8, k_tiles*4), dtype=np.float64) + a_pad[:m, :k] = a + tiles = a_pad.reshape(m_tiles, 8, k_tiles, 4).transpose(0, 2, 1, 3) + a_u64 = [f'0x{u:016x}' for u in tiles.view(np.uint64).ravel()] n_per_warp = 8 * nn n_per_cta = warps_per_cta * n_per_warp @@ -237,12 +233,12 @@ def pm_runtime(mt): return { 'm_tiles': m_tiles, - 'k_iters': k_iters, + 'k_tiles': k_tiles, 'k_rem': k_rem, 'm_pad': m_tiles * 8, - 'k_pad': k_iters * 4, + 'k_pad': k_tiles * 4, 'a_u64': a_u64, - 'a_elems': m_tiles * k_iters * 32, + 'a_elems': m_tiles * k_tiles * 32, 'n_per_warp': n_per_warp, 'n_per_cta': n_per_cta, 'frag_stride_bytes': 32 * 8, From 0e86053d3de9c70cfbcf5b6c7c85d3912aea771e Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Thu, 21 May 2026 09:26:39 -0700 Subject: [PATCH 09/21] Fixed missing import --- gimmik/cuda.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gimmik/cuda.py b/gimmik/cuda.py index e40179c..8afc755 100644 --- a/gimmik/cuda.py +++ b/gimmik/cuda.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +import numpy as np + from gimmik.base import MatMul From 1f62b5f4aa09a5b77973b34ad64ad9251bb55135 Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Thu, 21 May 2026 09:30:40 -0700 Subject: [PATCH 10/21] Fixed additional args --- gimmik/cuda.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gimmik/cuda.py b/gimmik/cuda.py index 8afc755..9e1da43 100644 --- a/gimmik/cuda.py +++ b/gimmik/cuda.py @@ -17,7 +17,8 @@ def is_suitable(arr): density = nnz / arr.size return (nuq <= 28) or (density <= 0.15) - def _kernel_generators(self, dtype, dsize, *, compute_capability=None): + def _kernel_generators(self, dtype, dsize, *, compute_capability=None, + **kwargs): # B loading, C streaming kernel yield 
('cstream', {}, {}) From 79f41cb9a689da9d2f89c8d85143911e4eccf0cf Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Fri, 22 May 2026 07:15:58 -0700 Subject: [PATCH 11/21] Cleanup and added PTX Version to handle older drivers. --- gimmik/kernels/ptx/base.mako | 2 +- gimmik/kernels/ptx/dense-mma-gAd.mako | 2 +- gimmik/kernels/ptx/dense-mma-smem-gA.mako | 14 +-- gimmik/kernels/ptx/dense-mma-ws.mako | 6 +- gimmik/ptx.py | 104 +++++++++++++--------- 5 files changed, 75 insertions(+), 53 deletions(-) diff --git a/gimmik/kernels/ptx/base.mako b/gimmik/kernels/ptx/base.mako index 71eb414..dbd8433 100644 --- a/gimmik/kernels/ptx/base.mako +++ b/gimmik/kernels/ptx/base.mako @@ -1,4 +1,4 @@ -.version 8.6 +.version ${ptx[0]}.${ptx[1]} .target sm_${cc[0]}${cc[1]}${'a' if cc[0] >= 9 else ''} .address_size 64 ${next.body()} diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako index 0996a17..3df43c0 100644 --- a/gimmik/kernels/ptx/dense-mma-gAd.mako +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -1,6 +1,6 @@ <%inherit file='base'/> -.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { +.global .align 16 .b64 ${kname}_Ag[${m_tiles*k_tiles*32}] = { ${', '.join(a_u64)} }; diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako index 4516831..9a88b64 100644 --- a/gimmik/kernels/ptx/dense-mma-smem-gA.mako +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -3,8 +3,8 @@ <% # Cooperative-copy params (gA-only) blockx = 32 * warps_per_cta -a_pairs = a_elems // 2 -a_pairs_tail = a_elems % 2 +a_pairs = m_tiles*k_tiles*32 // 2 +a_pairs_tail = m_tiles*k_tiles*32 % 2 copy_v2_iters = (a_pairs + blockx - 1) // blockx bs = bool(block_stealing) %> @@ -13,10 +13,10 @@ bs = bool(block_stealing) .shared .align 8 .b64 ${kname}_mbar; .shared .align 16 .b8 ${kname}_workid[16]; % endif -.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { +.global .align 16 .b64 ${kname}_Ag[${m_tiles*k_tiles*32}] = { ${', '.join(a_u64)} }; -.shared 
.align 16 .b64 ${kname}_As[${a_elems}]; +.shared .align 16 .b64 ${kname}_As[${m_tiles*k_tiles*32}]; .visible .entry ${kname}(.param .u64 _b, .param .u64 _c) @@ -95,14 +95,14 @@ bs = bool(block_stealing) } % endfor % if a_pairs_tail: - // Tail element (only when a_elems is odd) + // Tail element (only when m_tiles*k_tiles*32 is odd) { .reg .pred plast; .reg .u64 gaddr, saddr; .reg .${pftype} v; setp.eq.u32 plast, tid, 0; - add.u64 gaddr, a_glb_base, ${(a_elems-1) * dwidth_i}; - add.u64 saddr, a_smem_base, ${(a_elems-1) * dwidth_i}; + add.u64 gaddr, a_glb_base, ${(m_tiles*k_tiles*32-1) * dwidth_i}; + add.u64 saddr, a_smem_base, ${(m_tiles*k_tiles*32-1) * dwidth_i}; @plast ld.weak.global.cg.${pftype} v, [gaddr]; @plast st.shared.${pftype} [saddr], v; } diff --git a/gimmik/kernels/ptx/dense-mma-ws.mako b/gimmik/kernels/ptx/dense-mma-ws.mako index de4314f..e151372 100644 --- a/gimmik/kernels/ptx/dense-mma-ws.mako +++ b/gimmik/kernels/ptx/dense-mma-ws.mako @@ -10,11 +10,11 @@ mov.u64 a_glb, ${kname}_Ag; cvta.to.global.u64 a_glb, a_glb; @p_warp_lead cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes - [a_smem], [a_glb], ${a_elems * 8}, [tma_mbar]; + [a_smem], [a_glb], ${m_tiles*k_tiles*32 * 8}, [tma_mbar]; @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes [b1_smem], [bdesc_addr, {n_start0, 0}], [tma_mbar]; @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 - [tma_mbar], ${b_tile_bytes + a_elems * 8}; + [tma_mbar], ${b_tile_bytes + m_tiles*k_tiles*32 * 8}; bar.warp.sync 0xffffffff; .reg .b64 state; .reg .pred p1; @@ -321,7 +321,7 @@ $L_WAIT_WUSED: $L_AFTER_CTRL: -.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { +.global .align 16 .b64 ${kname}_Ag[${m_tiles*k_tiles*32}] = { ${', '.join(a_u64)} }; .extern .shared .align 128 .b8 ${kname}_dynm[]; diff --git a/gimmik/ptx.py b/gimmik/ptx.py index a9c049d..f7be8f4 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -1,6 +1,3 @@ -# -*- coding: utf-8 -*- - - import numpy 
as np from gimmik.base import MatMul @@ -8,10 +5,16 @@ class PTXMatMul(MatMul): platform = 'ptx' - basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, - 'dynamic_shared': 0} + basemeta = { + 'block': (128, 1, 1), + 'width': 1, + 'shared': 0, + 'dynamic_shared': 0 + } - PTX_SM = {(8, 0), (9, 0), (10, 0), (10, 3), (12, 0), (12, 1)} + # Map Supported CC -> Minimum PTX version + PTX_SM = {(8, 0): (7, 0), (9, 0): (8, 0), (10, 0): (8, 7), (10, 3): (8, 7), + (12, 0): (8, 7), (12, 1): (8, 7)} @classmethod def is_sparse_suitable(cls, arr, cc): @@ -33,17 +36,20 @@ def is_suitable(cls, arr, cc): def _kernel_generators(self, dtype, dsize, *, compute_capability=None, smem_info=None): cc = compute_capability or (0, 0) + ptx = self.PTX_SM.get(cc, (0, 0)) smem_info = smem_info or (48*1024, 48*1024) - base_args = {'cc': cc, - 'smem_info': smem_info, - 'pred_emit': self._pred_emit, - 'pftype': 'f32' if dtype == 'float' else 'f64', - 'dwidth_i': 4 if dtype == 'float' else 8, - 'fzero': ('0f00000000' if dtype == 'float' - else '0d0000000000000000'), - 'beta_zero': self.beta == 0, - 'mbar_maxwait': '0x989680' - } + base_args = { + 'ptx': ptx, + 'cc': cc, + 'smem_info': smem_info, + 'pred_emit': self._pred_emit, + 'pftype': 'f32' if dtype == 'float' else 'f64', + 'dwidth_i': 4 if dtype == 'float' else 8, + 'fzero': ('0f00000000' if dtype == 'float' + else '0d0000000000000000'), + 'beta_zero': self.beta == 0, + 'mbar_maxwait': '0x989680', + } if self.is_sparse_suitable(self.A, cc): yield from self._sparse_kernel_generators(dtype, dsize, base_args) @@ -68,24 +74,32 @@ def _sparse_kernel_generators(self, dtype, dsize, base_args): # Four-way m-split B streaming, C accumulation kernel ms, bsz, blkx = 4, 24, 32 args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize, - 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'} + meta = { + 'block': (blkx, ms, 1), + 'shared': 2*bsz*blkx*dsize, + 'desc': 
f'bstream-msplit/m{ms}-b{bsz}-x{blkx}', + } yield ('bstream-msplit', args, meta) # Single-warp LDGSTS variant for medium-M beta=0 large-K cases if self.beta == 0 and self.m <= 320 and len(self.bix) >= 64: ms, bsz, blkx = 1, 32, 64 args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - meta = {'block': (blkx, ms, 1), - 'shared': 2*bsz*blkx*dsize, - 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'} + meta = { + 'block': (blkx, ms, 1), + 'shared': 2*bsz*blkx*dsize, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}', + } yield ('bstream-msplit', args, meta) # Two-way k-split B loading, C streaming kernel ks, csz, blkx = 2, 24, 32 args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} - meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize, - 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'} + meta = { + 'block': (blkx, ks, 1), + 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}', + } yield ('cstream-ksplit', args, meta) # Four-way k-split for large K @@ -93,9 +107,11 @@ def _sparse_kernel_generators(self, dtype, dsize, base_args): if K_used > 500: ks, csz, blkx = 4, 20, 32 args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} - meta = {'block': (blkx, ks, 1), - 'shared': (ks - 1)*csz*blkx*dsize, - 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'} + meta = { + 'block': (blkx, ks, 1), + 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}', + } yield ('cstream-ksplit', args, meta) # Width-2 vector cstream for fp64 small-K @@ -104,8 +120,11 @@ def _sparse_kernel_generators(self, dtype, dsize, base_args): and (self.aligne is None or self.aligne % 2 == 0)): blkx = 128 args = base_args | {'blockx': blkx} - meta = {'block': (blkx, 1, 1), 'width': 2, - 'desc': f'cstream-w2/x{blkx}'} + meta = { + 'block': (blkx, 1, 1), + 'width': 2, + 'desc': f'cstream-w2/x{blkx}', + } yield ('cstream-w2', args, meta) def _dense_kernel_generators(self, dtype, dsize, base_args): @@ -127,10 +146,9 @@ def 
_dense_kernel_generators(self, dtype, dsize, base_args): for tpl, nn, w in dense_configs: if (n_per_cta := 8 * nn * w) > self.n: continue - setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) + setup = self._dense_mma_setup(nn, w, block_steal) blkx = 32 * w - args = (base_args | {'warps_per_cta': w, 'nn': nn, - 'block_stealing': block_steal} | setup) + args = base_args | setup meta = { 'block': (blkx, 1, 1), 'grid': (-(-self.n // n_per_cta), 1, 1), @@ -151,14 +169,14 @@ def _dense_ws_kernel_generators(self, dtype, dsize, base_args): if (n_per_cta := 8 * nn * w) > self.n: continue - setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) - ws_setup = self._dense_ws_setup(setup, n_comp_warps=w) + setup = self._dense_mma_setup(nn, w, True) + ws_setup = self._dense_ws_setup(setup, w) if ws_setup['dynm_total_bytes'] > dynamic_max: continue blkx = 32 * (w + 2) - args = base_args | {'nn': nn} | setup | ws_setup + args = base_args | setup | ws_setup meta = { 'block': (blkx, 1, 1), 'grid': (-(-self.n // n_per_cta), 1, 1), @@ -184,11 +202,11 @@ def _dsmem_alloc(regions, mbars, align=16): return out, total @classmethod - def _dense_ws_setup(cls, setup, *, n_comp_warps): + def _dense_ws_setup(cls, setup, n_comp_warps): n_per_cta = setup['n_per_cta'] b_tile_bytes = setup['k_pad'] * n_per_cta * 8 c_tile_bytes = setup['m_pad'] * n_per_cta * 8 - a_bytes = setup['a_elems'] * 8 + a_bytes = setup['m_tiles'] * setup['k_tiles'] * 32 * 8 regions = [('b1', b_tile_bytes), ('b2', b_tile_bytes), ('c', c_tile_bytes), ('a', a_bytes), ('wid', 16)] @@ -196,7 +214,7 @@ def _dense_ws_setup(cls, setup, *, n_comp_warps): 'steal', 'wid_new', 'wid_used') offsets, dynm_total_bytes = cls._dsmem_alloc(regions, mbars) - return offsets | { + args = { 'n_comp_warps': n_comp_warps, 'blockx_total': 32 * (n_comp_warps + 2), 'prod_warp': n_comp_warps, @@ -208,7 +226,9 @@ def _dense_ws_setup(cls, setup, *, n_comp_warps): 'dynm_total_bytes': dynm_total_bytes, } - def _dense_mma_setup(self, *, nn, warps_per_cta): 
+ return offsets | args + + def _dense_mma_setup(self, nn, warps_per_cta, block_steal): a = self.A m, k = a.shape m_tiles = (m + 7) // 8 @@ -218,9 +238,9 @@ def _dense_mma_setup(self, *, nn, warps_per_cta): # A in DMMA-fragment layout: lane l -> A[mt*8 + l//4][kt*4 + l%4] # i.e. an (m_tiles, k_tiles) grid of row-major 8x4 tiles, packed as # uint64 - a_pad = np.zeros((m_tiles*8, k_tiles*4), dtype=np.float64) + a_pad = np.zeros((m_tiles*8, k_tiles*4)) a_pad[:m, :k] = a - tiles = a_pad.reshape(m_tiles, 8, k_tiles, 4).transpose(0, 2, 1, 3) + tiles = a_pad.reshape(m_tiles, 8, k_tiles, 4).swapaxes(1, 2) a_u64 = [f'0x{u:016x}' for u in tiles.view(np.uint64).ravel()] n_per_warp = 8 * nn @@ -232,13 +252,14 @@ def pm_runtime(mt): return (mt + 1) * 8 > m return { + 'warps_per_cta': warps_per_cta, + 'nn': nn, 'm_tiles': m_tiles, 'k_tiles': k_tiles, 'k_rem': k_rem, 'm_pad': m_tiles * 8, 'k_pad': k_tiles * 4, 'a_u64': a_u64, - 'a_elems': m_tiles * k_tiles * 32, 'n_per_warp': n_per_warp, 'n_per_cta': n_per_cta, 'frag_stride_bytes': 32 * 8, @@ -248,6 +269,7 @@ def pm_runtime(mt): 'c_ntile_stride': 8 * 8, 'n_col_aligned': n_col_aligned, 'pm_runtime': pm_runtime, + 'block_stealing': block_steal, } @staticmethod From 7b59ca4c81c517cb9ea2c4722a409839c01efba1 Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Wed, 27 May 2026 07:11:50 -0700 Subject: [PATCH 12/21] Further cleanup --- gimmik/kernels/ptx/bstream-msplit.mako | 52 ++++++++++------------- gimmik/kernels/ptx/bstream.mako | 41 +++++++----------- gimmik/kernels/ptx/cstream-ksplit.mako | 43 ++++++------------- gimmik/kernels/ptx/cstream-w2.mako | 7 +-- gimmik/kernels/ptx/cstream.mako | 25 +++-------- gimmik/kernels/ptx/dense-mma-gAd.mako | 10 ++--- gimmik/kernels/ptx/dense-mma-smem-gA.mako | 16 +++---- gimmik/kernels/ptx/dense-mma-ws.mako | 10 ++--- gimmik/ptx.py | 4 +- 9 files changed, 80 insertions(+), 128 deletions(-) diff --git a/gimmik/kernels/ptx/bstream-msplit.mako b/gimmik/kernels/ptx/bstream-msplit.mako index 
2ef85e9..c98d357 100644 --- a/gimmik/kernels/ptx/bstream-msplit.mako +++ b/gimmik/kernels/ptx/bstream-msplit.mako @@ -7,7 +7,6 @@ m_per_group = max(len(mcx) for mcx in mx) bsub_bytes = 2 * bsz * blockx * dwidth_i def bsub_off(buf, idx): return (buf * bsz + idx) * blockx * dwidth_i -use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) %> % if n is None: @@ -82,15 +81,21 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) setp.ne.u32 p_skip, tid_y, ${cid}; @p_skip bra $L_END_CID_${cid}; +## Zero accumulators +% for j, row_j in enumerate(mcx): +% if afix[row_j] != -1: + mov.${pftype} csub${j}, ${fzero}; +% endif +% endfor + +## Pre-fill double buffer % if use_cpasync: ## Async fill of chunk 0 % for idx, kx in [(i, k) for i, k in enumerate(bchunks[0]) if i % msplit == cid]: % if n is None: { - .reg .u32 _boff; .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + mad.wide.u32 _bptr, ldb, ${kx * dwidth_i}, b_base; cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [_bptr], ${dwidth_i}; } % else: @@ -106,10 +111,8 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) { .reg .${pftype} _bv; % if n is None: - .reg .u32 _boff; .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + mad.wide.u32 _bptr, ldb, ${kx * dwidth_i}, b_base; ld.weak.global.cg.${pftype} _bv, [_bptr]; % else: ld.weak.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; @@ -131,10 +134,8 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) % if use_cpasync: % if n is None: { - .reg .u32 _boff; .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + mad.wide.u32 _bptr, ldb, ${kx * dwidth_i}, b_base; cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [_bptr], ${dwidth_i}; } % else: @@ -144,10 +145,8 @@ 
use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) { .reg .${pftype} _bv; % if n is None: - .reg .u32 _boff; .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + mad.wide.u32 _bptr, ldb, ${kx * dwidth_i}, b_base; ld.weak.global.cg.${pftype} _bv, [_bptr]; % else: ld.weak.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; @@ -162,22 +161,21 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) % endif % for idx, kx in enumerate(bchunks[bb]): +% if any(A[row_j, kx] for row_j in mcx): ld.shared.${pftype} bv, [bsub_thread + ${bsub_off(buf_cur, idx)}]; +% endif % for j, row_j in enumerate(mcx): -<% jx = A[row_j, kx] %> -% if jx != 0 and kx == afix[row_j]: - mul.${pftype} csub${j}, bv, ${jx}; -% elif jx != 0: - fma.rn.${pftype} csub${j}, bv, ${jx}, csub${j}; +% if A[row_j, kx] != 0: + fma.rn.${pftype} csub${j}, bv, ${A[row_j, kx]}, csub${j}; % endif +% endfor +% for j, row_j in enumerate(mcx): % if kx == alix[row_j]: % if beta_zero: % if n is None: { - .reg .u32 _coff; .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${row_j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + mad.wide.u32 _cptr, ldc, ${row_j * dwidth_i}, c_base; st.weak.global.cg.${pftype} [_cptr], csub${j}; } % else: @@ -187,10 +185,8 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) { .reg .${pftype} _ctmp; % if n is None: - .reg .u32 _coff; .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${row_j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + mad.wide.u32 _cptr, ldc, ${row_j * dwidth_i}, c_base; ld.weak.global.cg.${pftype} _ctmp, [_cptr]; fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; st.weak.global.${pftype} [_cptr], _ctmp; @@ -222,10 +218,8 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) .reg .${pftype} _tmp; mov.${pftype} _tmp, ${fzero}; % if n is None: - .reg .u32 _coff; .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, 
${row_j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + mad.wide.u32 _cptr, ldc, ${row_j * dwidth_i}, c_base; st.weak.global.cg.${pftype} [_cptr], _tmp; % else: st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; @@ -235,10 +229,8 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) { .reg .${pftype} _tmp; % if n is None: - .reg .u32 _coff; .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${row_j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + mad.wide.u32 _cptr, ldc, ${row_j * dwidth_i}, c_base; ld.weak.global.cg.${pftype} _tmp, [_cptr]; mul.${pftype} _tmp, _tmp, ${float(beta)}; st.weak.global.${pftype} [_cptr], _tmp; diff --git a/gimmik/kernels/ptx/bstream.mako b/gimmik/kernels/ptx/bstream.mako index 45eb1a7..ac4d0a6 100644 --- a/gimmik/kernels/ptx/bstream.mako +++ b/gimmik/kernels/ptx/bstream.mako @@ -53,10 +53,8 @@ % for i, kx in enumerate(bix): % if n is None: { - .reg .u32 _boff; .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + mad.wide.u32 _bptr, ldb, ${kx * dwidth_i}, b_base; ld.weak.global.cg.${pftype} bv${i}, [_bptr]; } % else: @@ -64,25 +62,26 @@ % endif % endfor -% if not beta_zero: -## Pre-load C so per-row completion is a plain store +% if beta_zero: +## Zero accumulators +% for j in range(m): +% if afix[j] != -1: + mov.${pftype} csub${j}, ${fzero}; +% endif +% endfor +% else: +## Pre-load C and scale by beta so per-row completion is a plain store % for j in range(m): % if afix[j] != -1: % if n is None: { - .reg .u32 _coff; .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; ld.weak.global.cg.${pftype} csub${j}, [_cptr]; } % else: ld.weak.global.cg.${pftype} csub${j}, [c_base + ${ldc*j*dwidth_i}]; % endif -% endif -% endfor -% for j in range(m): -% if afix[j] != -1: mul.${pftype} csub${j}, csub${j}, ${float(beta)}; % endif % endfor @@ -92,19 +91,15 @@ % 
for kx in bix: % for j, jx in enumerate(A[:, kx]): % if jx != 0: -% if beta_zero and kx == afix[j]: - mul.${pftype} csub${j}, bv${bix[kx]}, ${jx}; -% else: fma.rn.${pftype} csub${j}, bv${bix[kx]}, ${jx}, csub${j}; -% endif % endif +% endfor +% for j in range(m): % if kx == alix[j]: % if n is None: { - .reg .u32 _coff; .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; st.weak.global.cg.${pftype} [_cptr], csub${j}; } % else: @@ -123,10 +118,8 @@ % if jx == -1 and beta_zero: % if n is None: { - .reg .u32 _coff; .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; st.weak.global.cg.${pftype} [_cptr], _tmp; } % else: @@ -136,10 +129,8 @@ % elif jx == -1: % if n is None: { - .reg .u32 _coff; .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; ld.weak.global.cg.${pftype} _tmp, [_cptr]; mul.${pftype} _tmp, _tmp, ${float(beta)}; st.weak.global.cg.${pftype} [_cptr], _tmp; diff --git a/gimmik/kernels/ptx/cstream-ksplit.mako b/gimmik/kernels/ptx/cstream-ksplit.mako index 5d704de..8ec1bd2 100644 --- a/gimmik/kernels/ptx/cstream-ksplit.mako +++ b/gimmik/kernels/ptx/cstream-ksplit.mako @@ -70,45 +70,32 @@ csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i setp.ne.u32 p_skip, tid_y, ${bid}; @p_skip bra $L_END_BID_${bid}; -<% - loaded = set() - kbx_idx = {kx: i for i, kx in enumerate(kbx)} -%> +<% loaded = set() %> % for cchunk_i, cchunk in enumerate(cchunks): ## Chunk ${cchunk_i}: partial dot-product % for row_idx, j in enumerate(cchunk): -<% - nz = [(kbx_idx[kx], kx, A[j, kx]) for kx in kbx if A[j, kx] != 0] - owner_bid = row_idx % ksplit -%> -% for (kxi, kx, jx) in nz: -% if kx not in loaded: +<% owner_bid = row_idx % ksplit %> +% for kxi, kx in enumerate(kbx): +% if A[j, kx] != 0 and kx 
not in loaded: % if n is None: { - .reg .u32 _boff; .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + mad.wide.u32 _bptr, ldb, ${kx * dwidth_i}, b_base; ld.weak.global.cg.${pftype} bv${kxi}, [_bptr]; } % else: ld.weak.global.cg.${pftype} bv${kxi}, [b_base + ${ldb*kx*dwidth_i}]; % endif -<% loaded.add(kx) %> +<% loaded.add(kx) %> % endif % endfor -% if nz: -% for kxi, kx, jx in nz: -% if loop.first: - mul.${pftype} dotp, bv${kxi}, ${jx}; -% else: - fma.rn.${pftype} dotp, bv${kxi}, ${jx}, dotp; -% endif -% endfor -% else: mov.${pftype} dotp, ${fzero}; -% endif +% for kxi, kx in enumerate(kbx): +% if A[j, kx] != 0: + fma.rn.${pftype} dotp, bv${kxi}, ${A[j, kx]}, dotp; +% endif +% endfor % if owner_bid == bid: mov.${pftype} cv${row_idx // ksplit}, dotp; % else: @@ -135,10 +122,8 @@ csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i % if beta_zero: % if n is None: { - .reg .u32 _coff; .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; st.weak.global.cg.${pftype} [_cptr], dotp; } % else: @@ -148,10 +133,8 @@ csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i { .reg .${pftype} _ctmp; % if n is None: - .reg .u32 _coff; .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; ld.weak.global.cg.${pftype} _ctmp, [_cptr]; fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; st.weak.global.${pftype} [_cptr], _ctmp; diff --git a/gimmik/kernels/ptx/cstream-w2.mako b/gimmik/kernels/ptx/cstream-w2.mako index ce7301d..fbb2a0d 100644 --- a/gimmik/kernels/ptx/cstream-w2.mako +++ b/gimmik/kernels/ptx/cstream-w2.mako @@ -40,14 +40,11 @@ ## Main compute: two parallel dot-product streams per thread % for j in range(m): % if row_nz[j]: + mov.f64 dotp_a, ${fzero}; + mov.f64 dotp_b, ${fzero}; % for kx, jx in row_nz[j]: -% if loop.first: - mul.f64 
dotp_a, bv_a${bix[kx]}, ${jx}; - mul.f64 dotp_b, bv_b${bix[kx]}, ${jx}; -% else: fma.rn.f64 dotp_a, bv_a${bix[kx]}, ${jx}, dotp_a; fma.rn.f64 dotp_b, bv_b${bix[kx]}, ${jx}, dotp_b; -% endif % endfor % if beta_zero: st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; diff --git a/gimmik/kernels/ptx/cstream.mako b/gimmik/kernels/ptx/cstream.mako index 9ce4c4d..297d402 100644 --- a/gimmik/kernels/ptx/cstream.mako +++ b/gimmik/kernels/ptx/cstream.mako @@ -53,10 +53,8 @@ % for i, kx in enumerate(bix): % if n is None: { - .reg .u32 _boff; .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + mad.wide.u32 _bptr, ldb, ${kx * dwidth_i}, b_base; ld.weak.global.cg.${pftype} bv${i}, [_bptr]; } % else: @@ -67,20 +65,15 @@ ## Compute and store each output row % for j in range(m): % if row_nz[j]: + mov.${pftype} dotp, ${fzero}; % for kx, jx in row_nz[j]: -% if loop.first: - mul.${pftype} dotp, bv${bix[kx]}, ${jx}; -% else: fma.rn.${pftype} dotp, bv${bix[kx]}, ${jx}, dotp; -% endif % endfor % if beta_zero: % if n is None: { - .reg .u32 _coff; .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; st.weak.global.cg.${pftype} [_cptr], dotp; } % else: @@ -90,10 +83,8 @@ { .reg .${pftype} _ctmp; % if n is None: - .reg .u32 _coff; .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; ld.weak.global.cg.${pftype} _ctmp, [_cptr]; fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; st.weak.global.${pftype} [_cptr], _ctmp; @@ -112,10 +103,8 @@ .reg .${pftype} _tmp; mov.${pftype} _tmp, ${fzero}; % if n is None: - .reg .u32 _coff; .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; st.weak.global.cg.${pftype} [_cptr], _tmp; % else: 
st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; @@ -125,10 +114,8 @@ { .reg .${pftype} _tmp; % if n is None: - .reg .u32 _coff; .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; ld.weak.global.cg.${pftype} _tmp, [_cptr]; mul.${pftype} _tmp, _tmp, ${float(beta)}; st.weak.global.${pftype} [_cptr], _tmp; diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako index 3df43c0..437f80b 100644 --- a/gimmik/kernels/ptx/dense-mma-gAd.mako +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -1,6 +1,6 @@ <%inherit file='base'/> -.global .align 16 .b64 ${kname}_Ag[${m_tiles*k_tiles*32}] = { +.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { ${', '.join(a_u64)} }; @@ -44,12 +44,12 @@ @pwarp_exit bra $L_EXIT; % for nt in range(nn): - add.u32 b_col_${nt}, warp_n_base, ${nt * 8}; + add.u32 b_col_${nt}, warp_n_base, ${8 * nt}; add.u32 b_col_${nt}, b_col_${nt}, r_div4; { .reg .u32 t; shl.b32 t, r_mod4, 1; - add.u32 c_col0_${nt}, warp_n_base, ${nt * 8}; + add.u32 c_col0_${nt}, warp_n_base, ${8 * nt}; add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } @@ -93,7 +93,7 @@ .reg .pred pm_${mt}; { .reg .u32 crow; - add.u32 crow, r_div4, ${mt * 8}; + add.u32 crow, r_div4, ${8 * mt}; setp.lt.u32 pm_${mt}, crow, ${m}; } % endif @@ -143,7 +143,7 @@ .reg .pred pbrow; { .reg .u32 brow; - add.u32 brow, r_mod4, ${ki * 4}; + add.u32 brow, r_mod4, ${4 * ki}; setp.lt.u32 pbrow, brow, ${k}; } % endif diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako index 9a88b64..92265c8 100644 --- a/gimmik/kernels/ptx/dense-mma-smem-gA.mako +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -13,10 +13,10 @@ bs = bool(block_stealing) .shared .align 8 .b64 ${kname}_mbar; .shared .align 16 .b8 ${kname}_workid[16]; % endif -.global .align 16 .b64 ${kname}_Ag[${m_tiles*k_tiles*32}] = { 
+.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { ${', '.join(a_u64)} }; -.shared .align 16 .b64 ${kname}_As[${m_tiles*k_tiles*32}]; +.shared .align 16 .b64 ${kname}_As[${32 * m_tiles * k_tiles}]; .visible .entry ${kname}(.param .u64 _b, .param .u64 _c) @@ -101,8 +101,8 @@ bs = bool(block_stealing) .reg .u64 gaddr, saddr; .reg .${pftype} v; setp.eq.u32 plast, tid, 0; - add.u64 gaddr, a_glb_base, ${(m_tiles*k_tiles*32-1) * dwidth_i}; - add.u64 saddr, a_smem_base, ${(m_tiles*k_tiles*32-1) * dwidth_i}; + add.u64 gaddr, a_glb_base, ${(32 * m_tiles * k_tiles - 1) * dwidth_i}; + add.u64 saddr, a_smem_base, ${(32 * m_tiles * k_tiles - 1) * dwidth_i}; @plast ld.weak.global.cg.${pftype} v, [gaddr]; @plast st.shared.${pftype} [saddr], v; } @@ -124,7 +124,7 @@ bs = bool(block_stealing) .reg .pred pm_${mt}; { .reg .u32 crow; - add.u32 crow, r_div4, ${mt * 8}; + add.u32 crow, r_div4, ${8 * mt}; setp.lt.u32 pm_${mt}, crow, ${m}; } % endif @@ -154,12 +154,12 @@ $L_LOOP: % endif % for nt in range(nn): - add.u32 b_col_${nt}, warp_n_base, ${nt * 8}; + add.u32 b_col_${nt}, warp_n_base, ${8 * nt}; add.u32 b_col_${nt}, b_col_${nt}, r_div4; { .reg .u32 t; shl.b32 t, r_mod4, 1; - add.u32 c_col0_${nt}, warp_n_base, ${nt * 8}; + add.u32 c_col0_${nt}, warp_n_base, ${8 * nt}; add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } @@ -232,7 +232,7 @@ $L_LOOP: .reg .pred pbrow; { .reg .u32 brow; - add.u32 brow, r_mod4, ${ki * 4}; + add.u32 brow, r_mod4, ${4 * ki}; setp.lt.u32 pbrow, brow, ${k}; } % endif diff --git a/gimmik/kernels/ptx/dense-mma-ws.mako b/gimmik/kernels/ptx/dense-mma-ws.mako index e151372..982b1b8 100644 --- a/gimmik/kernels/ptx/dense-mma-ws.mako +++ b/gimmik/kernels/ptx/dense-mma-ws.mako @@ -10,11 +10,11 @@ mov.u64 a_glb, ${kname}_Ag; cvta.to.global.u64 a_glb, a_glb; @p_warp_lead cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes - [a_smem], [a_glb], ${m_tiles*k_tiles*32 * 8}, [tma_mbar]; + [a_smem], [a_glb], ${8 * 32 * 
m_tiles * k_tiles}, [tma_mbar]; @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes [b1_smem], [bdesc_addr, {n_start0, 0}], [tma_mbar]; @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 - [tma_mbar], ${b_tile_bytes + m_tiles*k_tiles*32 * 8}; + [tma_mbar], ${b_tile_bytes + 8 * 32 * m_tiles * k_tiles}; bar.warp.sync 0xffffffff; .reg .b64 state; .reg .pred p1; @@ -97,7 +97,7 @@ $L_WAIT_BRDY: %> { .reg .b32 a_a; - add.u32 a_a, a_thr_a, ${(kt * 32 + mt * 32 * k_tiles) * dwidth_i}; + add.u32 a_a, a_thr_a, ${(32 * kt + 32 * mt * k_tiles) * dwidth_i}; ld.shared.${pftype} a_f, [a_a]; % if k_tail: .reg .pred pbrow_${mt}_${kt}; @@ -162,7 +162,7 @@ $L_WAIT_BRDY: and.pred p_st, p_st, p_row_${mt}; % endif .reg .u64 _c_addr; - add.u64 _c_addr, c_thr_glob_base, ${(mt * 8 * ldc + nt * 8) * dwidth_i}; + add.u64 _c_addr, c_thr_glob_base, ${(8 * mt * ldc + 8 * nt) * dwidth_i}; @p_st st.weak.global.v2.${pftype} [_c_addr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; } % endfor @@ -321,7 +321,7 @@ $L_WAIT_WUSED: $L_AFTER_CTRL: -.global .align 16 .b64 ${kname}_Ag[${m_tiles*k_tiles*32}] = { +.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { ${', '.join(a_u64)} }; .extern .shared .align 128 .b8 ${kname}_dynm[]; diff --git a/gimmik/ptx.py b/gimmik/ptx.py index f7be8f4..bc260ee 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -49,12 +49,14 @@ def _kernel_generators(self, dtype, dsize, *, compute_capability=None, else '0d0000000000000000'), 'beta_zero': self.beta == 0, 'mbar_maxwait': '0x989680', + 'use_cpasync': cc >= (8, 0), } if self.is_sparse_suitable(self.A, cc): yield from self._sparse_kernel_generators(dtype, dsize, base_args) - if self.is_dense_suitable(self.A, cc): + # Dense kernels bake n/ldb/ldc as compile-time constants + if self.n is not None and self.is_dense_suitable(self.A, cc): yield from self._dense_kernel_generators(dtype, dsize, base_args) def _sparse_kernel_generators(self, dtype, dsize, base_args): From 
577208a26f429fcff2e6c35824679cddb22dad08 Mon Sep 17 00:00:00 2001 From: Will Trojak Date: Tue, 2 Jun 2026 09:53:52 +0000 Subject: [PATCH 13/21] Moved the PTX backend to use tuned kernel profiles for different CCs --- gimmik/base.py | 19 +- gimmik/kernels/ptx/config/default.json | 6 + gimmik/kernels/ptx/config/sm100.json | 260 ++++++++++ gimmik/kernels/ptx/config/sm80.json | 81 ++++ gimmik/kernels/ptx/config/sm90.json | 362 ++++++++++++++ gimmik/kernels/ptx/dmma-asmem-v1.mako | 288 +++++++++++ ...se-mma-smem-gA.mako => dmma-asmem-v2.mako} | 19 +- ...ense-mma-gAd.mako => dmma-astream-v1.mako} | 4 +- gimmik/kernels/ptx/dmma-astream-v2.mako | 177 +++++++ .../{dense-mma-ws.mako => dmma-steal-ws.mako} | 0 gimmik/ptx.py | 450 +++++++++++------- setup.py | 2 +- 12 files changed, 1466 insertions(+), 202 deletions(-) create mode 100644 gimmik/kernels/ptx/config/default.json create mode 100644 gimmik/kernels/ptx/config/sm100.json create mode 100644 gimmik/kernels/ptx/config/sm80.json create mode 100644 gimmik/kernels/ptx/config/sm90.json create mode 100644 gimmik/kernels/ptx/dmma-asmem-v1.mako rename gimmik/kernels/ptx/{dense-mma-smem-gA.mako => dmma-asmem-v2.mako} (89%) rename gimmik/kernels/ptx/{dense-mma-gAd.mako => dmma-astream-v1.mako} (99%) create mode 100644 gimmik/kernels/ptx/dmma-astream-v2.mako rename gimmik/kernels/ptx/{dense-mma-ws.mako => dmma-steal-ws.mako} (100%) diff --git a/gimmik/base.py b/gimmik/base.py index 0ecc29a..9db8ee9 100644 --- a/gimmik/base.py +++ b/gimmik/base.py @@ -103,14 +103,7 @@ def kernels(self, dtype, kname='gimmik_mm', **kwargs): raise ValueError('Invalid floating point data type') # Common template arguments - baseargs = { - 'dtype': dtype, 'kname': kname, - 'A': self.A, 'beta': self.beta, 'width': 1, - 'm': self.m, 'n': self.n, 'k': self.k, - 'ldb': self.ldb, 'ldc': self.ldc, - 'afix': self.afix, 'alix': self.alix, 'bix': self.bix, - 'dot': _dot, 'partition': _partition, 'chunk': _chunk - } + baseargs = self._base_template_args(dtype, 
kname) # Incrementally generate and render the kernels gen = self._kernel_generators(dtype, dsize, **kwargs) @@ -136,6 +129,16 @@ def kernels(self, dtype, kname='gimmik_mm', **kwargs): except StopIteration: pass + def _base_template_args(self, dtype, kname): + return { + 'dtype': dtype, 'kname': kname, + 'A': self.A, 'beta': self.beta, 'width': 1, + 'm': self.m, 'n': self.n, 'k': self.k, + 'ldb': self.ldb, 'ldc': self.ldc, + 'afix': self.afix, 'alix': self.alix, 'bix': self.bix, + 'dot': _dot, 'partition': _partition, 'chunk': _chunk + } + def _process_meta(self, meta): pass diff --git a/gimmik/kernels/ptx/config/default.json b/gimmik/kernels/ptx/config/default.json new file mode 100644 index 0000000..13cd31d --- /dev/null +++ b/gimmik/kernels/ptx/config/default.json @@ -0,0 +1,6 @@ +{ + "schema": 1, + "cc": [0, 0], + "ptx": [0, 0], + "kernels": [] +} diff --git a/gimmik/kernels/ptx/config/sm100.json b/gimmik/kernels/ptx/config/sm100.json new file mode 100644 index 0000000..04eb4d7 --- /dev/null +++ b/gimmik/kernels/ptx/config/sm100.json @@ -0,0 +1,260 @@ +{ + "schema": 1, + "cc": [ + 10, + 0 + ], + "ptx": [ + 8, + 7 + ], + "kernels": [ + { + "template": "cstream", + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "descriptor": "cstream" + }, + { + "template": "bstream", + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "descriptor": "bstream" + }, + { + "template": "bstream-msplit", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "bsz": 24 + }, + "descriptor": "bstream-msplit/m4-b24-x32" + }, + { + "template": "bstream-msplit", + "block": [ + 64, + 1, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "conditions": { + "all": [ + { + "field": "beta_zero", + "eq": true + }, + { + "field": "m", + "lte": 320 + }, + { + "field": "k_used", + "gte": 64 + } + ] + }, + "descriptor": "bstream-msplit/m1-b32-x64" + }, + { + "template": "cstream-ksplit", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 24 + }, + "descriptor": 
"cstream-ksplit/k2-c24-x32" + }, + { + "template": "cstream-ksplit", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "csz": 20 + }, + "conditions": { + "field": "k_used", + "gt": 500 + }, + "descriptor": "cstream-ksplit/k4-c20-x32" + }, + { + "template": "cstream-w2", + "block": [ + 128, + 1, + 1 + ], + "width": 2, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "k_used", + "lte": 100 + }, + { + "field": "aligne", + "is_null_or_divisible_by": 2 + } + ] + }, + "descriptor": "cstream-w2/x128" + }, + { + "template": "dmma-asmem", + "vector_width": 1, + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 4, + "warps": 4, + "block_stealing": true + }, + "descriptor": "dmma-asmem/v1/nn4-w4-bs", + "conditions": { + "field": "n", + "is_not": null + } + }, + { + "template": "dmma-asmem", + "vector_width": 2, + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 4, + "warps": 4, + "block_stealing": true + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-asmem/v2/nn4-w4-bs" + }, + { + "template": "dmma-steal-ws", + "block": [ + 192, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 1 + }, + "warp_map": { + "compute_start": 0, + "compute_count": 4, + "producer": 4, + "stealer": 5 + }, + "descriptor": "dmma-steal-ws/nn1-w4", + "conditions": { + "field": "n", + "is_not": null + } + }, + { + "template": "dmma-steal-ws", + "block": [ + 192, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 2 + }, + "warp_map": { + "compute_start": 0, + "compute_count": 4, + "producer": 4, + "stealer": 5 + }, + "descriptor": "dmma-steal-ws/nn2-w4", + "conditions": { + "field": "n", + "is_not": null + } + }, + { + "template": "dmma-steal-ws", + "block": [ + 192, + 1, + 1 + ], + 
"width": 1, + "params": { + "nn": 4 + }, + "warp_map": { + "compute_start": 0, + "compute_count": 4, + "producer": 4, + "stealer": 5 + }, + "descriptor": "dmma-steal-ws/nn4-w4", + "conditions": { + "field": "n", + "is_not": null + } + } + ] +} diff --git a/gimmik/kernels/ptx/config/sm80.json b/gimmik/kernels/ptx/config/sm80.json new file mode 100644 index 0000000..8c607ef --- /dev/null +++ b/gimmik/kernels/ptx/config/sm80.json @@ -0,0 +1,81 @@ +{ + "schema": 1, + "cc": [8, 0], + "ptx": [7, 0], + "kernels": [ + { + "template": "cstream", + "block": [128, 1, 1], + "width": 1, + "descriptor": "cstream" + }, + { + "template": "bstream", + "block": [128, 1, 1], + "width": 1, + "descriptor": "bstream" + }, + { + "template": "bstream-msplit", + "block": [32, 4, 1], + "width": 1, + "params": { + "bsz": 24 + }, + "descriptor": "bstream-msplit/m4-b24-x32" + }, + { + "template": "bstream-msplit", + "block": [64, 1, 1], + "width": 1, + "params": { + "bsz": 32 + }, + "conditions": { + "all": [ + {"field": "beta_zero", "eq": true}, + {"field": "m", "lte": 320}, + {"field": "k_used", "gte": 64} + ] + }, + "descriptor": "bstream-msplit/m1-b32-x64" + }, + { + "template": "cstream-ksplit", + "block": [32, 2, 1], + "width": 1, + "params": { + "csz": 24 + }, + "descriptor": "cstream-ksplit/k2-c24-x32" + }, + { + "template": "cstream-ksplit", + "block": [32, 4, 1], + "width": 1, + "params": { + "csz": 20 + }, + "conditions": { + "field": "k_used", + "gt": 500 + }, + "descriptor": "cstream-ksplit/k4-c20-x32" + }, + { + "template": "cstream-w2", + "block": [128, 1, 1], + "width": 2, + "conditions": { + "all": [ + {"field": "dtype", "eq": "double"}, + {"field": "n", "is_not": null}, + {"field": "n", "divisible_by": 2}, + {"field": "k_used", "lte": 100}, + {"field": "aligne", "is_null_or_divisible_by": 2} + ] + }, + "descriptor": "cstream-w2/x128" + } + ] +} diff --git a/gimmik/kernels/ptx/config/sm90.json b/gimmik/kernels/ptx/config/sm90.json new file mode 100644 index 0000000..3b3c64d 
--- /dev/null +++ b/gimmik/kernels/ptx/config/sm90.json @@ -0,0 +1,362 @@ +{ + "schema": 1, + "cc": [ + 9, + 0 + ], + "ptx": [ + 8, + 0 + ], + "kernels": [ + { + "template": "cstream-ksplit", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 20 + }, + "descriptor": "cstream-ksplit/k2-c20-x32" + }, + { + "template": "bstream-msplit", + "block": [ + 32, + 8, + 1 + ], + "width": 1, + "params": { + "bsz": 24 + }, + "descriptor": "bstream-msplit/m8-b24-x32" + }, + { + "template": "cstream-ksplit", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "csz": 24 + }, + "descriptor": "cstream-ksplit/k4-c24-x32" + }, + { + "template": "bstream-msplit", + "block": [ + 64, + 2, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "descriptor": "bstream-msplit/m2-b32-x64" + }, + { + "template": "bstream", + "block": [ + 64, + 1, + 1 + ], + "width": 1, + "descriptor": "bstream/x64" + }, + { + "template": "bstream-msplit", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "descriptor": "bstream-msplit/m2-b32-x32" + }, + { + "template": "bstream-msplit", + "block": [ + 32, + 1, + 1 + ], + "width": 1, + "params": { + "bsz": 24 + }, + "descriptor": "bstream-msplit/m1-b24-x32" + }, + { + "template": "cstream-ksplit", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "csz": 16 + }, + "descriptor": "cstream-ksplit/k4-c16-x32" + }, + { + "template": "dmma-asmem", + "vector_width": 2, + "block": [ + 256, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 1, + "warps": 8 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-asmem/v2/nn1-w8" + }, + { + "template": "dmma-astream", + "vector_width": 2, + "block": [ + 64, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 2, + "warps": 2 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + 
"field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-astream/v2/nn2-w2" + }, + { + "template": "dmma-astream", + "vector_width": 2, + "block": [ + 32, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 4, + "warps": 1 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-astream/v2/nn4-w1" + }, + { + "template": "dmma-asmem", + "vector_width": 2, + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 2, + "warps": 4 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-asmem/v2/nn2-w4" + }, + { + "template": "dmma-astream", + "vector_width": 2, + "block": [ + 256, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 1, + "warps": 8 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-astream/v2/nn1-w8" + }, + { + "template": "dmma-asmem", + "vector_width": 2, + "block": [ + 256, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 2, + "warps": 8 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-asmem/v2/nn2-w8" + }, + { + "template": "dmma-asmem", + "vector_width": 2, + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 4, + "warps": 4 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-asmem/v2/nn4-w4" + }, + { + "template": 
"dmma-astream", + "vector_width": 2, + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 1, + "warps": 4 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-astream/v2/nn1-w4" + } + ] +} diff --git a/gimmik/kernels/ptx/dmma-asmem-v1.mako b/gimmik/kernels/ptx/dmma-asmem-v1.mako new file mode 100644 index 0000000..6938881 --- /dev/null +++ b/gimmik/kernels/ptx/dmma-asmem-v1.mako @@ -0,0 +1,288 @@ +<%inherit file='base'/> + +<% +# Cooperative-copy params (gA-only) +blockx = 32 * warps_per_cta +a_elems = m_tiles*k_tiles*32 +copy_v1_iters = (a_elems + blockx - 1) // blockx +bs = bool(block_stealing) +%> + +% if bs: +.shared .align 8 .b64 ${kname}_mbar; +.shared .align 16 .b8 ${kname}_workid[16]; +% endif +.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { + ${', '.join(a_u64)} +}; +.shared .align 16 .b64 ${kname}_As[${32 * m_tiles * k_tiles}]; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 as_thr_base, b_thr_base, c_thr_base; + .reg .pred pwarp_exit; + .reg .${pftype} a_frag; +% if bs: + .reg .u32 ctaid; + .reg .u32 mbar_a, work_a; + .reg .pred p_root, p_done, p_have; +% endif +% for nt in range(nn): + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; +% if not n_col_aligned: + .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; +% endif + .reg .${pftype} b_frag_${nt}; + .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; +% endfor + + ld.param.u64 b_ptr, [_b]; + ld.param.u64 c_ptr, [_c]; + cvta.to.global.u64 b_ptr, b_ptr; + cvta.to.global.u64 c_ptr, c_ptr; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + shr.u32 r_div4, lane, 2; + and.b32 r_mod4, lane, 3; + +% if bs: + setp.eq.u32 p_root, tid, 0; + mov.u32 
mbar_a, ${kname}_mbar; + mov.u32 work_a, ${kname}_workid; + @p_root mbarrier.init.shared::cta.b64 [mbar_a], 1; + bar.sync 0; +% endif + + // Cooperative copy A from .global to .shared + { + .reg .u64 a_glb_base, a_smem_base; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, a_glb_base; + mov.u64 a_smem_base, ${kname}_As; +% for ci in range(copy_v1_iters): +<% + base_elem = ci * blockx + elems_this = min(blockx, a_elems - base_elem) +%> + { + .reg .u32 eidx; + .reg .u64 off64, gaddr, saddr; + .reg .${pftype} v; +% if loop.last and elems_this < blockx: + .reg .pred plast; + add.u32 eidx, tid, ${base_elem}; + setp.lt.u32 plast, eidx, ${a_elems}; + mul.wide.u32 off64, eidx, ${dwidth_i}; + add.u64 gaddr, a_glb_base, off64; + add.u64 saddr, a_smem_base, off64; + @plast ld.weak.global.cg.${pftype} v, [gaddr]; + @plast st.shared.${pftype} [saddr], v; +% else: + add.u32 eidx, tid, ${base_elem}; + mul.wide.u32 off64, eidx, ${dwidth_i}; + add.u64 gaddr, a_glb_base, off64; + add.u64 saddr, a_smem_base, off64; + ld.weak.global.cg.${pftype} v, [gaddr]; + st.shared.${pftype} [saddr], v; +% endif + } +% endfor + } + bar.sync 0; + + // Lane-only base; lifted out of the optional steal loop + { + .reg .u64 t64, a_smem_base, lane64; + mov.u64 a_smem_base, ${kname}_As; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 as_thr_base, a_smem_base, t64; + } + +% for mt in range(m_tiles): +% if pm_runtime(mt): + .reg .pred pm_${mt}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${8 * mt}; + setp.lt.u32 pm_${mt}, crow, ${m}; + } +% endif +% endfor + +% if bs: + mov.u32 ctaid, %ctaid.x; +$L_LOOP: +% endif + + { + .reg .u32 cta; +% if bs: + mov.u32 cta, ctaid; +% else: + mov.u32 cta, %ctaid.x; +% endif + mul.lo.u32 cta, cta, ${n_per_cta}; + mul.lo.u32 warp_n_base, warp, ${n_per_warp}; + add.u32 warp_n_base, warp_n_base, cta; + } + setp.ge.u32 pwarp_exit, warp_n_base, ${n}; +% if bs: + @pwarp_exit bra $L_STEAL; +% else: + @pwarp_exit bra $L_EXIT; +% endif + +% for 
nt in range(nn): + add.u32 b_col_${nt}, warp_n_base, ${8 * nt}; + add.u32 b_col_${nt}, b_col_${nt}, r_div4; + { + .reg .u32 t; + shl.b32 t, r_mod4, 1; + add.u32 c_col0_${nt}, warp_n_base, ${8 * nt}; + add.u32 c_col0_${nt}, c_col0_${nt}, t; + add.u32 c_col1_${nt}, c_col0_${nt}, 1; + } +% if not n_col_aligned: + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; + setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endif +% endfor + + { + .reg .u64 t64, bcol64; + mul.wide.u32 t64, r_mod4, ${ldb}; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; + } + + { + .reg .u64 t64, ccol64; + mul.wide.u32 t64, r_div4, ${ldc}; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; + } + +% for nt in range(nn): +% for mt in range(m_tiles): +% if beta_zero: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% else: +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None + needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; +% if needs_zero_init: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} + ${pred_emit(f'ld.weak.global.cg.{pftype} c1_{nt}_{mt}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} + } +% endif +% endfor +% endfor + +% for ki in range(k_tiles): +% for nt in range(nn): +<% + pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None + k_tail = (k_rem != 0 and loop.parent.last) + needs_zero = pvb is not None or k_tail + pbrow = 'pbrow' if k_tail else None +%> + { 
+ .reg .u64 baddr; + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.${pftype} b_frag_${nt}, ${fzero}; +% endif +% if k_tail: + .reg .pred pbrow; + { + .reg .u32 brow; + add.u32 brow, r_mod4, ${4 * ki}; + setp.lt.u32 pbrow, brow, ${k}; + } +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} + } +% endfor +% for mt in range(m_tiles): + ld.shared.${pftype} a_frag, [as_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; +% for nt in range(nn): + mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} + {c0_${nt}_${mt}, c1_${nt}_${mt}}, + {a_frag}, + {b_frag_${nt}}, + {c0_${nt}_${mt}, c1_${nt}_${mt}}; +% endfor +% endfor +% endfor + +% for nt in range(nn): +% for mt in range(m_tiles): +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.weak.global.{pftype} [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} + ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} + } +% endfor +% endfor + +% if bs: +$L_STEAL: + // Root issues async try_cancel + waits; bar.sync orders the workid load + @!p_root bra $L_AFTER_WAIT; + { + .reg .u64 state; + mbarrier.arrive.expect_tx.shared::cta.b64 state, [mbar_a], 16; + clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 [work_a], [mbar_a]; +$L_WAIT: + mbarrier.try_wait.shared::cta.b64 p_done, [mbar_a], state, 10000000; + @!p_done bra $L_WAIT; + } +$L_AFTER_WAIT: + bar.sync 0; + + { + .reg .b128 resp; + ld.shared::cta.b128 resp, [work_a]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_have, resp; + @!p_have bra $L_FIN; + // 1D grid: extract just x + 
clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 ctaid, resp; + } + bra.uni $L_LOOP; + +$L_FIN: + bar.sync 0; + @p_root mbarrier.inval.shared::cta.b64 [mbar_a]; +% endif + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dmma-asmem-v2.mako similarity index 89% rename from gimmik/kernels/ptx/dense-mma-smem-gA.mako rename to gimmik/kernels/ptx/dmma-asmem-v2.mako index 92265c8..51462cf 100644 --- a/gimmik/kernels/ptx/dense-mma-smem-gA.mako +++ b/gimmik/kernels/ptx/dmma-asmem-v2.mako @@ -3,8 +3,9 @@ <% # Cooperative-copy params (gA-only) blockx = 32 * warps_per_cta -a_pairs = m_tiles*k_tiles*32 // 2 -a_pairs_tail = m_tiles*k_tiles*32 % 2 +a_elems = m_tiles*k_tiles*32 +a_pairs = a_elems // 2 +a_pairs_tail = a_elems % 2 copy_v2_iters = (a_pairs + blockx - 1) // blockx bs = bool(block_stealing) %> @@ -60,7 +61,7 @@ bs = bool(block_stealing) bar.sync 0; % endif - // Cooperative copy A from .global to .shared via v2 loads + // Cooperative copy A from .global to .shared { .reg .u64 a_glb_base, a_smem_base; mov.u64 a_glb_base, ${kname}_Ag; @@ -196,9 +197,7 @@ $L_LOOP: % else: <% pm = f'pm_{mt}' if pm_runtime(mt) else None - pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None - pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None - needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None + needs_zero_init = pm is not None %> { .reg .u64 caddr; @@ -207,8 +206,7 @@ $L_LOOP: mov.${pftype} c0_${nt}_${mt}, ${fzero}; mov.${pftype} c1_${nt}_${mt}, ${fzero}; % endif - ${pred_emit(f'ld.weak.global.cg.{pftype} c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} - ${pred_emit(f'ld.weak.global.cg.{pftype} c1_{nt}_{mt}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} + ${pred_emit(f'ld.weak.global.cg.v2.{pftype} {{c0_{nt}_{mt}, c1_{nt}_{mt}}}, [caddr];', pm, pred_reg=f'p01_{nt}_{mt}')} } % endif % endfor @@ -255,14 +253,11 @@ $L_LOOP: % for mt in range(m_tiles): <% pm = 
f'pm_{mt}' if pm_runtime(mt) else None - pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None - pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None %> { .reg .u64 caddr; add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; - ${pred_emit(f'st.weak.global.{pftype} [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} - ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} + ${pred_emit(f'st.weak.global.v2.{pftype} [caddr], {{c0_{nt}_{mt}, c1_{nt}_{mt}}};', pm, pred_reg=f'p01s_{nt}_{mt}')} } % endfor % endfor diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dmma-astream-v1.mako similarity index 99% rename from gimmik/kernels/ptx/dense-mma-gAd.mako rename to gimmik/kernels/ptx/dmma-astream-v1.mako index 437f80b..48d41f3 100644 --- a/gimmik/kernels/ptx/dense-mma-gAd.mako +++ b/gimmik/kernels/ptx/dmma-astream-v1.mako @@ -162,8 +162,8 @@ % endfor % endfor -% for nt in range(nn): -% for mt in range(m_tiles): +% for mt in range(m_tiles): +% for nt in range(nn): <% pm = f'pm_{mt}' if pm_runtime(mt) else None pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None diff --git a/gimmik/kernels/ptx/dmma-astream-v2.mako b/gimmik/kernels/ptx/dmma-astream-v2.mako new file mode 100644 index 0000000..16711d5 --- /dev/null +++ b/gimmik/kernels/ptx/dmma-astream-v2.mako @@ -0,0 +1,177 @@ +<%inherit file='base'/> + +.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { + ${', '.join(a_u64)} +}; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 ag_thr_base, b_thr_base, c_thr_base; + .reg .pred pwarp_exit; + .reg .${pftype} a_frag; +% for nt in range(nn): + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; +% if not n_col_aligned: + .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; +% endif + .reg .${pftype} 
b_frag_${nt}; + .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; +% endfor + + ld.param.u64 b_ptr, [_b]; + ld.param.u64 c_ptr, [_c]; + cvta.to.global.u64 b_ptr, b_ptr; + cvta.to.global.u64 c_ptr, c_ptr; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + shr.u32 r_div4, lane, 2; + and.b32 r_mod4, lane, 3; + + { + .reg .u32 cta; + mov.u32 cta, %ctaid.x; + mul.lo.u32 cta, cta, ${n_per_cta}; + mul.lo.u32 warp_n_base, warp, ${n_per_warp}; + add.u32 warp_n_base, warp_n_base, cta; + } + setp.ge.u32 pwarp_exit, warp_n_base, ${n}; + @pwarp_exit bra $L_EXIT; + +% for nt in range(nn): + add.u32 b_col_${nt}, warp_n_base, ${8 * nt}; + add.u32 b_col_${nt}, b_col_${nt}, r_div4; + { + .reg .u32 t; + shl.b32 t, r_mod4, 1; + add.u32 c_col0_${nt}, warp_n_base, ${8 * nt}; + add.u32 c_col0_${nt}, c_col0_${nt}, t; + add.u32 c_col1_${nt}, c_col0_${nt}, 1; + } +% if not n_col_aligned: + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; + setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endif +% endfor + + // A thread base: &Ag[0] + lane*8 + { + .reg .u64 t64, a_glb_base, lane64; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, a_glb_base; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 ag_thr_base, a_glb_base, t64; + } + + { + .reg .u64 t64, bcol64; + mul.wide.u32 t64, r_mod4, ${ldb}; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; + } + + { + .reg .u64 t64, ccol64; + mul.wide.u32 t64, r_div4, ${ldc}; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; + } + +% for mt in range(m_tiles): +% if pm_runtime(mt): + .reg .pred pm_${mt}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${8 * mt}; + setp.lt.u32 pm_${mt}, crow, ${m}; + } +% endif +% endfor + +% for nt in range(nn): +% for mt in range(m_tiles): +% if beta_zero: + 
mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% else: +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + needs_zero_init = pm is not None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; +% if needs_zero_init: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.v2.{pftype} {{c0_{nt}_{mt}, c1_{nt}_{mt}}}, [caddr];', pm, pred_reg=f'p01_{nt}_{mt}')} + } +% endif +% endfor +% endfor + +% for ki in range(k_tiles): +% for nt in range(nn): +<% + pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None + k_tail = (k_rem != 0 and loop.parent.last) + needs_zero = pvb is not None or k_tail + pbrow = 'pbrow' if k_tail else None +%> + { + .reg .u64 baddr; + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.${pftype} b_frag_${nt}, ${fzero}; +% endif +% if k_tail: + .reg .pred pbrow; + { + .reg .u32 brow; + add.u32 brow, r_mod4, ${4 * ki}; + setp.lt.u32 pbrow, brow, ${k}; + } +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} + } +% endfor +% for mt in range(m_tiles): + ld.weak.global.${pftype} a_frag, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; +% for nt in range(nn): + mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} + {c0_${nt}_${mt}, c1_${nt}_${mt}}, + {a_frag}, + {b_frag_${nt}}, + {c0_${nt}_${mt}, c1_${nt}_${mt}}; +% endfor +% endfor +% endfor + +% for mt in range(m_tiles): +% for nt in range(nn): +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.weak.global.v2.{pftype} [caddr], {{c0_{nt}_{mt}, c1_{nt}_{mt}}};', pm, pred_reg=f'p01s_{nt}_{mt}')} + } +% endfor +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dense-mma-ws.mako 
b/gimmik/kernels/ptx/dmma-steal-ws.mako similarity index 100% rename from gimmik/kernels/ptx/dense-mma-ws.mako rename to gimmik/kernels/ptx/dmma-steal-ws.mako diff --git a/gimmik/ptx.py b/gimmik/ptx.py index bc260ee..e26560b 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -1,3 +1,6 @@ +import json +import pkgutil + import numpy as np from gimmik.base import MatMul @@ -16,6 +19,21 @@ class PTXMatMul(MatMul): PTX_SM = {(8, 0): (7, 0), (9, 0): (8, 0), (10, 0): (8, 7), (10, 3): (8, 7), (12, 0): (8, 7), (12, 1): (8, 7)} + PTX_TEMPLATE_FAMILY = { + 'cstream': 'sparse', + 'bstream': 'sparse', + 'bstream-msplit': 'sparse', + 'cstream-ksplit': 'sparse', + 'cstream-w2': 'sparse', + 'dmma-astream': 'dense', + 'dmma-asmem': 'dense', + 'dmma-steal-ws': 'dense', + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._config_cache = {} + @classmethod def is_sparse_suitable(cls, arr, cc): nnz = np.count_nonzero(arr) @@ -36,201 +54,216 @@ def is_suitable(cls, arr, cc): def _kernel_generators(self, dtype, dsize, *, compute_capability=None, smem_info=None): cc = compute_capability or (0, 0) - ptx = self.PTX_SM.get(cc, (0, 0)) smem_info = smem_info or (48*1024, 48*1024) + config = self._cc_config(cc) + + for kernel_cfg in config['kernels']: + if not self._usable_config(kernel_cfg, dtype, cc, smem_info): + continue + + prepared = self._get_render_args( + kernel_cfg, dtype, dsize, cc, smem_info, tuple(config['ptx']) + ) + if prepared is not None: + yield prepared + + def render_config(self, kernel_cfg, dtype, dsize, *, kname='gimmik_mm', + compute_capability=None, smem_info=None, config=None): + cc = compute_capability or (0, 0) + smem_info = smem_info or (48*1024, 48*1024) + config = config or self._cc_config(cc) + + if not self._usable_config(kernel_cfg, dtype, cc, smem_info): + return None + + prepared = self._get_render_args( + kernel_cfg, dtype, dsize, cc, smem_info, tuple(config['ptx']) + ) + if prepared is None: + return None + tplname, exargs, 
exmeta = prepared + + args = self._base_template_args(dtype, kname) | exargs + meta = self.basemeta | exmeta + meta['tplname'] = tplname + self._process_meta(meta) + src = self._render_kernel(dtype, tplname, args) + return src, args, meta + + def _cc_config(self, cc): + cc = cc or (0, 0) + if cc not in self._config_cache: + cfgname = f'sm{cc[0]}{cc[1]}.json' + paths = [f'kernels/ptx/config/{cfgname}', + 'kernels/ptx/config/default.json'] + + cfg = None + for path in paths: + try: + cfgdir = pkgutil.get_data('gimmik', path) + cfg = json.loads(cfgdir.decode('utf-8')) + break + except FileNotFoundError: + continue + except json.JSONDecodeError as e: + raise ValueError(f'{path}: invalid JSON: {e}') from e + + if cfg is None: + raise ValueError('PTX default kernel config is missing') + self._config_cache[cc] = cfg + return self._config_cache[cc] + + def _matmul_stats(self, dtype, cc, smem_info): + nnz = int(np.count_nonzero(self.A)) + return { + 'dtype': dtype, + 'm': self.m, + 'k': self.k, + 'n': self.n, + 'beta': self.beta, + 'beta_zero': self.beta == 0, + 'aligne': self.aligne, + 'nnz': nnz, + 'density': nnz / self.A.size, + 'unique_abs': int(len(np.unique(np.abs(self.A)))), + 'k_used': len(self.bix), + 'cc': list(cc), + 'smem_static': smem_info[0], + 'smem_dynamic': smem_info[1], + } + + def _eval_condition(self, condition, stats): + if 'all' in condition: + return all(self._eval_condition(c, stats) for c in condition['all']) + if 'any' in condition: + return any(self._eval_condition(c, stats) for c in condition['any']) + if 'not' in condition: + return not self._eval_condition(condition['not'], stats) + + value = stats[condition['field']] + op = next(k for k in condition if k != 'field') + expected = condition[op] + + return { + 'eq': lambda: value == expected, + 'ne': lambda: value != expected, + 'lt': lambda: value is not None and value < expected, + 'lte': lambda: value is not None and value <= expected, + 'gt': lambda: value is not None and value > expected,
+ 'gte': lambda: value is not None and value >= expected, + 'in': lambda: value in expected, + 'is_null': lambda: value is None, + 'is_not': lambda: value is not None, + 'divisible_by': lambda: value is not None and value % expected == 0, + 'is_null_or_divisible_by': lambda: (value is None + or value % expected == 0), + }[op]() + + def _usable_config(self, kernel_cfg, dtype, cc, smem_info): + tpl = kernel_cfg['template'] + family = self.PTX_TEMPLATE_FAMILY[tpl] + + if family == 'sparse' and not self.is_sparse_suitable(self.A, cc): + return False + elif (family == 'dense' + and (self.n is None or not self.is_dense_suitable(self.A, cc))): + return False + + condition = kernel_cfg.get('conditions') + if condition is None: + return True + else: + stats = self._matmul_stats(dtype, cc, smem_info) + return self._eval_condition(condition, stats) + + def _get_render_args(self, kernel_cfg, dtype, dsize, cc, smem_info, + ptx): + tpl = kernel_cfg['template'] + block = tuple(kernel_cfg['block']) + width = kernel_cfg['width'] + params = kernel_cfg.get('params', {}) base_args = { 'ptx': ptx, 'cc': cc, 'smem_info': smem_info, 'pred_emit': self._pred_emit, 'pftype': 'f32' if dtype == 'float' else 'f64', - 'dwidth_i': 4 if dtype == 'float' else 8, + 'dwidth_i': dsize, 'fzero': ('0f00000000' if dtype == 'float' else '0d0000000000000000'), 'beta_zero': self.beta == 0, 'mbar_maxwait': '0x989680', 'use_cpasync': cc >= (8, 0), + 'width': width, } - - if self.is_sparse_suitable(self.A, cc): - yield from self._sparse_kernel_generators(dtype, dsize, base_args) - - # Dense kernels bake n/ldb/ldc as compile-time constants - if self.n is not None and self.is_dense_suitable(self.A, cc): - yield from self._dense_kernel_generators(dtype, dsize, base_args) - - def _sparse_kernel_generators(self, dtype, dsize, base_args): - # Sparse-shared template constants - base_args = base_args | { - 'has_zero_rows': bool(self.has_zero_rows), - 'row_nz': [[(kx, self.A[j, kx]) for kx in range(self.k) - if 
self.A[j, kx] != 0] for j in range(self.m)], + base_meta = { + 'block': block, + 'width': width, + 'desc': kernel_cfg['descriptor'], } - # B loading, C streaming kernel - yield ('cstream', base_args, {'desc': 'cstream'}) - - # B streaming, C accumulation kernel - yield ('bstream', base_args, {'desc': 'bstream'}) - - # Four-way m-split B streaming, C accumulation kernel - ms, bsz, blkx = 4, 24, 32 - args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - meta = { - 'block': (blkx, ms, 1), - 'shared': 2*bsz*blkx*dsize, - 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}', - } - yield ('bstream-msplit', args, meta) - - # Single-warp LDGSTS variant for medium-M beta=0 large-K cases - if self.beta == 0 and self.m <= 320 and len(self.bix) >= 64: - ms, bsz, blkx = 1, 32, 64 - args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - meta = { - 'block': (blkx, ms, 1), - 'shared': 2*bsz*blkx*dsize, - 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}', - } - yield ('bstream-msplit', args, meta) - - # Two-way k-split B loading, C streaming kernel - ks, csz, blkx = 2, 24, 32 - args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} - meta = { - 'block': (blkx, ks, 1), - 'shared': (ks - 1)*csz*blkx*dsize, - 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}', - } - yield ('cstream-ksplit', args, meta) - - # Four-way k-split for large K - K_used = len(self.bix) - if K_used > 500: - ks, csz, blkx = 4, 20, 32 - args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} - meta = { - 'block': (blkx, ks, 1), - 'shared': (ks - 1)*csz*blkx*dsize, - 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}', - } - yield ('cstream-ksplit', args, meta) - - # Width-2 vector cstream for fp64 small-K - if (dtype == 'double' and self.n is not None and self.n % 2 == 0 - and K_used <= 100 - and (self.aligne is None or self.aligne % 2 == 0)): - blkx = 128 - args = base_args | {'blockx': blkx} - meta = { - 'block': (blkx, 1, 1), - 'width': 2, - 'desc': f'cstream-w2/x{blkx}', - } - yield 
('cstream-w2', args, meta) - - def _dense_kernel_generators(self, dtype, dsize, base_args): - cc = base_args['cc'] or (0, 0) - - # Block stealing requires sm_100+ - block_steal = cc >= (10, 0) - if block_steal: - dense_configs = [('dense-mma-smem-gA', 4, 4)] + if self.PTX_TEMPLATE_FAMILY[tpl] == 'sparse': + cfg = self._sparse_args(tpl, params, block, dtype, dsize, + base_args, base_meta) + elif self.PTX_TEMPLATE_FAMILY[tpl] == 'dense': + if tpl.endswith('ws'): + cfg = self._dense_ws_args(kernel_cfg, params, smem_info, + base_args, base_meta) + else: + cfg = self._dense_args(kernel_cfg, params, base_args, + base_meta) else: - dense_configs = [ - ('dense-mma-smem-gA', 1, 8), - ('dense-mma-smem-gA', 2, 4), - ('dense-mma-smem-gA', 4, 4), - ('dense-mma-gAd', 2, 2), - ('dense-mma-gAd', 4, 2), - ] - - for tpl, nn, w in dense_configs: - if (n_per_cta := 8 * nn * w) > self.n: - continue - setup = self._dense_mma_setup(nn, w, block_steal) - blkx = 32 * w - args = base_args | setup - meta = { - 'block': (blkx, 1, 1), - 'grid': (-(-self.n // n_per_cta), 1, 1), - 'desc': f'{tpl}/nn{nn}-w{w}{'-bs' if block_steal else ''}', - } - yield (tpl, args, meta) - - # Warp-specialised dense DMMA, required block stealing - if block_steal: - yield from self._dense_ws_kernel_generators(dtype, dsize, base_args) - - def _dense_ws_kernel_generators(self, dtype, dsize, base_args): - static_max, dynamic_max = base_args['smem_info'] - - # (nn, compute) -- block has compute + 2 warps (producer, stealer) - ws_configs = [(1, 4), (2, 4), (4, 4)] - for nn, w in ws_configs: - if (n_per_cta := 8 * nn * w) > self.n: - continue - - setup = self._dense_mma_setup(nn, w, True) - ws_setup = self._dense_ws_setup(setup, w) - - if ws_setup['dynm_total_bytes'] > dynamic_max: - continue - - blkx = 32 * (w + 2) - args = base_args | setup | ws_setup - meta = { - 'block': (blkx, 1, 1), - 'grid': (-(-self.n // n_per_cta), 1, 1), - 'desc': f'dense-mma-ws/nn{nn}-w{w}', - 'ws_b_tile': (n_per_cta, setup['k_pad']), - 
'dynamic_shared': ws_setup['dynm_total_bytes'], - } - if self.beta != 0: - meta |= {'ws_out_tile': (n_per_cta, setup['m_pad'])} - yield ('dense-mma-ws', args, meta) - - @staticmethod - def _dsmem_alloc(regions, mbars, align=16): - out, off = {}, 0 - for name, size in regions: - off = (off + align - 1) & ~(align - 1) - out[f'{name}_off'] = off - off += size - for name in mbars: - out[f'{name}_mbar_off'] = off - off += 8 - total = (off + align - 1) & ~(align - 1) - return out, total - - @classmethod - def _dense_ws_setup(cls, setup, n_comp_warps): - n_per_cta = setup['n_per_cta'] - b_tile_bytes = setup['k_pad'] * n_per_cta * 8 - c_tile_bytes = setup['m_pad'] * n_per_cta * 8 - a_bytes = setup['m_tiles'] * setup['k_tiles'] * 32 * 8 - - regions = [('b1', b_tile_bytes), ('b2', b_tile_bytes), - ('c', c_tile_bytes), ('a', a_bytes), ('wid', 16)] - mbars = ('tma', 'bready', 'cready', 'cstored', - 'steal', 'wid_new', 'wid_used') - offsets, dynm_total_bytes = cls._dsmem_alloc(regions, mbars) - - args = { - 'n_comp_warps': n_comp_warps, - 'blockx_total': 32 * (n_comp_warps + 2), - 'prod_warp': n_comp_warps, - 'steal_warp': n_comp_warps + 1, - 'comp_threads': 32 * n_comp_warps, - 'b_tile_bytes': b_tile_bytes, - 'c_mtile_smem_stride': 8 * n_per_cta * 8, - 'c_ntile_smem_stride': 8 * 8, - 'dynm_total_bytes': dynm_total_bytes, - } - - return offsets | args - - def _dense_mma_setup(self, nn, warps_per_cta, block_steal): + raise ValueError(f'Unknown PTX template family for {tpl}') + + return cfg + + def _sparse_args(self, tpl, params, block, dtype, dsize, args, + meta): + blockx = block[0] + args |= {'has_zero_rows': bool(self.has_zero_rows), + 'row_nz': [[(kx, self.A[j, kx]) for kx in range(self.k) + if self.A[j, kx] != 0] for j in range(self.m)], + } + + match tpl: + case 'cstream' | 'bstream': + pass + case 'bstream-msplit': + msplit = block[1] + bsz = params['bsz'] + args |= {'msplit': msplit, 'bsz': bsz, 'blockx': blockx} + meta['shared'] = 2*bsz*blockx*dsize + case 
'cstream-ksplit': + ksplit = block[1] + csz = params['csz'] + args |= {'ksplit': ksplit, 'csz': csz, 'blockx': blockx} + meta['shared'] = (ksplit - 1)*csz*blockx*dsize + case _: + args['blockx'] = blockx + return tpl, args, meta + + def _dense_args(self, kernel_cfg, params, args, meta): + nn = params['nn'] + warps = params['warps'] + n_per_cta = 8 * nn * warps + if n_per_cta > self.n: + return None + + vector_width = kernel_cfg['vector_width'] + if (vector_width == 2 + and (self.aligne is None or self.aligne % 2 + or self.n % (8 * nn))): + return None + + block_steal = bool(params.get('block_stealing', False)) + setup = self._dense_common(nn, warps, block_steal) + tpl = f"{kernel_cfg['template']}-v{vector_width}" + args |= setup + meta['grid'] = (-(-self.n // n_per_cta), 1, 1) + + return tpl, args, meta + + def _dense_common(self, nn, warps_per_cta, block_steal): a = self.A m, k = a.shape m_tiles = (m + 7) // 8 @@ -274,6 +307,65 @@ def pm_runtime(mt): 'block_stealing': block_steal, } + def _dense_ws_args(self, kernel_cfg, params, smem_info, args, meta): + dynamic_max = smem_info[1] + nn = params['nn'] + warp_map = kernel_cfg['warp_map'] + n_comp_warps = warp_map['compute_count'] + n_per_cta = 8 * nn * n_comp_warps + if n_per_cta > self.n: + return None + + setup = self._dense_common(nn, n_comp_warps, True) + + # Warp Specialism Setup + b_tile_bytes = setup['k_pad'] * n_per_cta * 8 + c_tile_bytes = setup['m_pad'] * n_per_cta * 8 + a_bytes = setup['m_tiles'] * setup['k_tiles'] * 32 * 8 + + regions = [('b1', b_tile_bytes), ('b2', b_tile_bytes), + ('c', c_tile_bytes), ('a', a_bytes), ('wid', 16)] + mbars = ('tma', 'bready', 'cready', 'cstored', + 'steal', 'wid_new', 'wid_used') + offsets, dynm_total_bytes = self._dsmem_alloc(regions, mbars) + ws_setup = { + 'n_comp_warps': n_comp_warps, + 'blockx_total': 32 * (n_comp_warps + 2), + 'prod_warp': warp_map['producer'], + 'steal_warp': warp_map['stealer'], + 'comp_threads': 32 * n_comp_warps, + 'b_tile_bytes': b_tile_bytes, 
+ 'c_mtile_smem_stride': 8 * n_per_cta * 8, + 'c_ntile_smem_stride': 8 * 8, + 'dynm_total_bytes': dynm_total_bytes, + } + + if ws_setup['dynm_total_bytes'] > dynamic_max: + return None + + args |= setup | ws_setup | offsets + meta |= { + 'grid': (-(-self.n // n_per_cta), 1, 1), + 'ws_b_tile': (n_per_cta, setup['k_pad']), + 'dynamic_shared': ws_setup['dynm_total_bytes'], + } + if self.beta != 0: + meta['ws_out_tile'] = (n_per_cta, setup['m_pad']) + return kernel_cfg['template'], args, meta + + @staticmethod + def _dsmem_alloc(regions, mbars, align=16): + out, off = {}, 0 + for name, size in regions: + off = (off + align - 1) & ~(align - 1) + out[f'{name}_off'] = off + off += size + for name in mbars: + out[f'{name}_mbar_off'] = off + off += 8 + total = (off + align - 1) & ~(align - 1) + return out, total + @staticmethod def _pred_emit(instr, *preds, pred_reg=None, indent=' ' * 8): actual = [p for p in preds if p is not None] diff --git a/setup.py b/setup.py index f1e94ff..2d4545b 100755 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ # Data package_data = { - 'gimmik': ['kernels/*/*.mako'], + 'gimmik': ['kernels/*/*.mako', 'kernels/ptx/config/*.json'], } # Hard dependencies From e7d42524b4f05eaccb0a845d3b659d052b3d61a4 Mon Sep 17 00:00:00 2001 From: Will Trojak Date: Tue, 2 Jun 2026 17:21:03 +0000 Subject: [PATCH 14/21] Added grid strided kernel and cleanup --- gimmik/kernels/ptx/config/sm90.json | 132 ++++---- gimmik/kernels/ptx/dmma-asmem-v1.mako | 70 ++--- gimmik/kernels/ptx/dmma-asmem-v2.mako | 74 ++--- gimmik/kernels/ptx/dmma-astream-v1.mako | 54 ++-- gimmik/kernels/ptx/dmma-astream-v2.mako | 54 ++-- gimmik/kernels/ptx/dmma-steal-ws.mako | 40 +-- gimmik/kernels/ptx/dmma-stride-ws.mako | 395 ++++++++++++++++++++++++ gimmik/ptx.py | 65 +++- 8 files changed, 655 insertions(+), 229 deletions(-) create mode 100644 gimmik/kernels/ptx/dmma-stride-ws.mako diff --git a/gimmik/kernels/ptx/config/sm90.json b/gimmik/kernels/ptx/config/sm90.json index 3b3c64d..ac1a0a6
100644 --- a/gimmik/kernels/ptx/config/sm90.json +++ b/gimmik/kernels/ptx/config/sm90.json @@ -6,7 +6,7 @@ ], "ptx": [ 8, - 0 + 6 ], "kernels": [ { @@ -87,7 +87,7 @@ { "template": "bstream-msplit", "block": [ - 32, + 64, 1, 1 ], @@ -95,10 +95,10 @@ "params": { "bsz": 24 }, - "descriptor": "bstream-msplit/m1-b24-x32" + "descriptor": "bstream-msplit/m1-b24-x64" }, { - "template": "cstream-ksplit", + "template": "bstream-msplit", "block": [ 32, 4, @@ -106,9 +106,9 @@ ], "width": 1, "params": { - "csz": 16 + "bsz": 32 }, - "descriptor": "cstream-ksplit/k4-c16-x32" + "descriptor": "bstream-msplit/m4-b32-x32" }, { "template": "dmma-asmem", @@ -173,48 +173,40 @@ "descriptor": "dmma-astream/v2/nn2-w2" }, { - "template": "dmma-astream", - "vector_width": 2, + "template": "dmma-stride-ws", "block": [ - 32, + 160, 1, 1 ], "width": 1, "params": { - "nn": 4, - "warps": 1 + "nn": 1, + "iters": 8 + }, + "warp_map": { + "compute_start": 0, + "compute_count": 4, + "producer": 4 }, "conditions": { - "all": [ - { - "field": "n", - "is_not": null - }, - { - "field": "aligne", - "is_not": null - }, - { - "field": "aligne", - "divisible_by": 2 - } - ] + "field": "n", + "is_not": null }, - "descriptor": "dmma-astream/v2/nn4-w1" + "descriptor": "dmma-stride-ws/nn1-w4-i8" }, { - "template": "dmma-asmem", + "template": "dmma-astream", "vector_width": 2, "block": [ - 128, + 32, 1, 1 ], "width": 1, "params": { - "nn": 2, - "warps": 4 + "nn": 4, + "warps": 1 }, "conditions": { "all": [ @@ -232,20 +224,20 @@ } ] }, - "descriptor": "dmma-asmem/v2/nn2-w4" + "descriptor": "dmma-astream/v2/nn4-w1" }, { - "template": "dmma-astream", + "template": "dmma-asmem", "vector_width": 2, "block": [ - 256, + 128, 1, 1 ], "width": 1, "params": { - "nn": 1, - "warps": 8 + "nn": 2, + "warps": 4 }, "conditions": { "all": [ @@ -263,41 +255,33 @@ } ] }, - "descriptor": "dmma-astream/v2/nn1-w8" + "descriptor": "dmma-asmem/v2/nn2-w4" }, { - "template": "dmma-asmem", - "vector_width": 2, + "template": 
"dmma-stride-ws", "block": [ - 256, + 160, 1, 1 ], "width": 1, "params": { "nn": 2, - "warps": 8 + "iters": 2 + }, + "warp_map": { + "compute_start": 0, + "compute_count": 4, + "producer": 4 }, "conditions": { - "all": [ - { - "field": "n", - "is_not": null - }, - { - "field": "aligne", - "is_not": null - }, - { - "field": "aligne", - "divisible_by": 2 - } - ] + "field": "n", + "is_not": null }, - "descriptor": "dmma-asmem/v2/nn2-w8" + "descriptor": "dmma-stride-ws/nn2-w4-i2" }, { - "template": "dmma-asmem", + "template": "dmma-astream", "vector_width": 2, "block": [ 128, @@ -306,7 +290,7 @@ ], "width": 1, "params": { - "nn": 4, + "nn": 1, "warps": 4 }, "conditions": { @@ -325,38 +309,30 @@ } ] }, - "descriptor": "dmma-asmem/v2/nn4-w4" + "descriptor": "dmma-astream/v2/nn1-w4" }, { - "template": "dmma-astream", - "vector_width": 2, + "template": "dmma-stride-ws", "block": [ - 128, + 160, 1, 1 ], "width": 1, "params": { - "nn": 1, - "warps": 4 + "nn": 2, + "iters": 8 + }, + "warp_map": { + "compute_start": 0, + "compute_count": 4, + "producer": 4 }, "conditions": { - "all": [ - { - "field": "n", - "is_not": null - }, - { - "field": "aligne", - "is_not": null - }, - { - "field": "aligne", - "divisible_by": 2 - } - ] + "field": "n", + "is_not": null }, - "descriptor": "dmma-astream/v2/nn1-w4" + "descriptor": "dmma-stride-ws/nn2-w4-i8" } ] } diff --git a/gimmik/kernels/ptx/dmma-asmem-v1.mako b/gimmik/kernels/ptx/dmma-asmem-v1.mako index 6938881..e34ef28 100644 --- a/gimmik/kernels/ptx/dmma-asmem-v1.mako +++ b/gimmik/kernels/ptx/dmma-asmem-v1.mako @@ -20,24 +20,24 @@ bs = bool(block_stealing) .visible .entry ${kname}(.param .u64 _b, .param .u64 _c) { - .reg .u32 tid, warp, lane, r_mod4, r_div4; - .reg .u64 b_ptr, c_ptr; - .reg .u32 warp_n_base; - .reg .u64 as_thr_base, b_thr_base, c_thr_base; + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 as_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; - .reg 
.${pftype} a_frag; + .reg .${pftype} a_frag; % if bs: - .reg .u32 ctaid; - .reg .u32 mbar_a, work_a; + .reg .u32 ctaid; + .reg .u32 mbar_a, work_a; .reg .pred p_root, p_done, p_have; % endif % for nt in range(nn): - .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; % if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; % endif - .reg .${pftype} b_frag_${nt}; - .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; + .reg .${pftype} b_frag_${nt}; + .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; % endfor ld.param.u64 b_ptr, [_b]; @@ -62,7 +62,7 @@ bs = bool(block_stealing) // Cooperative copy A from .global to .shared { .reg .u64 a_glb_base, a_smem_base; - mov.u64 a_glb_base, ${kname}_Ag; + mov.u64 a_glb_base, ${kname}_Ag; cvta.to.global.u64 a_glb_base, a_glb_base; mov.u64 a_smem_base, ${kname}_As; % for ci in range(copy_v1_iters): @@ -79,17 +79,17 @@ bs = bool(block_stealing) add.u32 eidx, tid, ${base_elem}; setp.lt.u32 plast, eidx, ${a_elems}; mul.wide.u32 off64, eidx, ${dwidth_i}; - add.u64 gaddr, a_glb_base, off64; + add.u64 gaddr, a_glb_base, off64; add.u64 saddr, a_smem_base, off64; @plast ld.weak.global.cg.${pftype} v, [gaddr]; - @plast st.shared.${pftype} [saddr], v; + @plast st.shared.${pftype} [saddr], v; % else: add.u32 eidx, tid, ${base_elem}; mul.wide.u32 off64, eidx, ${dwidth_i}; - add.u64 gaddr, a_glb_base, off64; + add.u64 gaddr, a_glb_base, off64; add.u64 saddr, a_smem_base, off64; ld.weak.global.cg.${pftype} v, [gaddr]; - st.shared.${pftype} [saddr], v; + st.shared.${pftype} [saddr], v; % endif } % endfor @@ -99,10 +99,10 @@ bs = bool(block_stealing) // Lane-only base; lifted out of the optional steal loop { .reg .u64 t64, a_smem_base, lane64; - mov.u64 a_smem_base, ${kname}_As; - cvt.u64.u32 lane64, lane; - shl.b64 t64, lane64, 3; - add.u64 as_thr_base, a_smem_base, t64; + mov.u64 a_smem_base, ${kname}_As; + cvt.u64.u32 lane64, lane; + 
shl.b64 t64, lane64, 3; + add.u64 as_thr_base, a_smem_base, t64; } % for mt in range(m_tiles): @@ -124,13 +124,13 @@ $L_LOOP: { .reg .u32 cta; % if bs: - mov.u32 cta, ctaid; + mov.u32 cta, ctaid; % else: - mov.u32 cta, %ctaid.x; + mov.u32 cta, %ctaid.x; % endif mul.lo.u32 cta, cta, ${n_per_cta}; mul.lo.u32 warp_n_base, warp, ${n_per_warp}; - add.u32 warp_n_base, warp_n_base, cta; + add.u32 warp_n_base, warp_n_base, cta; } setp.ge.u32 pwarp_exit, warp_n_base, ${n}; % if bs: @@ -150,7 +150,7 @@ $L_LOOP: add.u32 c_col1_${nt}, c_col0_${nt}, 1; } % if not n_col_aligned: - setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; % endif @@ -159,19 +159,19 @@ $L_LOOP: { .reg .u64 t64, bcol64; mul.wide.u32 t64, r_mod4, ${ldb}; - cvt.u64.u32 bcol64, b_col_0; - add.u64 t64, t64, bcol64; - shl.b64 t64, t64, 3; - add.u64 b_thr_base, b_ptr, t64; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; } { .reg .u64 t64, ccol64; mul.wide.u32 t64, r_div4, ${ldc}; - cvt.u64.u32 ccol64, c_col0_0; - add.u64 t64, t64, ccol64; - shl.b64 t64, t64, 3; - add.u64 c_thr_base, c_ptr, t64; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; } % for nt in range(nn): @@ -188,10 +188,10 @@ $L_LOOP: %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; % if needs_zero_init: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; % endif ${pred_emit(f'ld.weak.global.cg.{pftype} c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} ${pred_emit(f'ld.weak.global.cg.{pftype} c1_{nt}_{mt}, [caddr + 
{dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} @@ -246,7 +246,7 @@ $L_LOOP: %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; ${pred_emit(f'st.weak.global.{pftype} [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} } diff --git a/gimmik/kernels/ptx/dmma-asmem-v2.mako b/gimmik/kernels/ptx/dmma-asmem-v2.mako index 51462cf..98f491d 100644 --- a/gimmik/kernels/ptx/dmma-asmem-v2.mako +++ b/gimmik/kernels/ptx/dmma-asmem-v2.mako @@ -22,24 +22,24 @@ bs = bool(block_stealing) .visible .entry ${kname}(.param .u64 _b, .param .u64 _c) { - .reg .u32 tid, warp, lane, r_mod4, r_div4; - .reg .u64 b_ptr, c_ptr; - .reg .u32 warp_n_base; - .reg .u64 as_thr_base, b_thr_base, c_thr_base; + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 as_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; - .reg .${pftype} a_frag; + .reg .${pftype} a_frag; % if bs: - .reg .u32 ctaid; - .reg .u32 mbar_a, work_a; + .reg .u32 ctaid; + .reg .u32 mbar_a, work_a; .reg .pred p_root, p_done, p_have; % endif % for nt in range(nn): - .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; % if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; % endif - .reg .${pftype} b_frag_${nt}; - .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; + .reg .${pftype} b_frag_${nt}; + .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; % endfor ld.param.u64 b_ptr, [_b]; @@ -64,7 +64,7 @@ bs = bool(block_stealing) // Cooperative copy A from .global to .shared { .reg .u64 a_glb_base, a_smem_base; - mov.u64 a_glb_base, ${kname}_Ag; + mov.u64 a_glb_base, ${kname}_Ag; cvta.to.global.u64 a_glb_base, a_glb_base; 
mov.u64 a_smem_base, ${kname}_As; % for ci in range(copy_v2_iters): @@ -81,17 +81,17 @@ bs = bool(block_stealing) add.u32 pidx, tid, ${base_pair}; setp.lt.u32 plast, pidx, ${a_pairs}; mul.wide.u32 off64, pidx, ${2 * dwidth_i}; - add.u64 gaddr, a_glb_base, off64; + add.u64 gaddr, a_glb_base, off64; add.u64 saddr, a_smem_base, off64; @plast ld.weak.global.cg.v2.${pftype} {v0, v1}, [gaddr]; - @plast st.shared.v2.${pftype} [saddr], {v0, v1}; + @plast st.shared.v2.${pftype} [saddr], {v0, v1}; % else: add.u32 pidx, tid, ${base_pair}; mul.wide.u32 off64, pidx, ${2 * dwidth_i}; - add.u64 gaddr, a_glb_base, off64; + add.u64 gaddr, a_glb_base, off64; add.u64 saddr, a_smem_base, off64; ld.weak.global.cg.v2.${pftype} {v0, v1}, [gaddr]; - st.shared.v2.${pftype} [saddr], {v0, v1}; + st.shared.v2.${pftype} [saddr], {v0, v1}; % endif } % endfor @@ -102,10 +102,10 @@ bs = bool(block_stealing) .reg .u64 gaddr, saddr; .reg .${pftype} v; setp.eq.u32 plast, tid, 0; - add.u64 gaddr, a_glb_base, ${(32 * m_tiles * k_tiles - 1) * dwidth_i}; + add.u64 gaddr, a_glb_base, ${(32 * m_tiles * k_tiles - 1) * dwidth_i}; add.u64 saddr, a_smem_base, ${(32 * m_tiles * k_tiles - 1) * dwidth_i}; @plast ld.weak.global.cg.${pftype} v, [gaddr]; - @plast st.shared.${pftype} [saddr], v; + @plast st.shared.${pftype} [saddr], v; } % endif } @@ -114,10 +114,10 @@ bs = bool(block_stealing) // Lane-only base; lifted out of the optional steal loop { .reg .u64 t64, a_smem_base, lane64; - mov.u64 a_smem_base, ${kname}_As; - cvt.u64.u32 lane64, lane; - shl.b64 t64, lane64, 3; - add.u64 as_thr_base, a_smem_base, t64; + mov.u64 a_smem_base, ${kname}_As; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 as_thr_base, a_smem_base, t64; } % for mt in range(m_tiles): @@ -139,13 +139,13 @@ $L_LOOP: { .reg .u32 cta; % if bs: - mov.u32 cta, ctaid; + mov.u32 cta, ctaid; % else: - mov.u32 cta, %ctaid.x; + mov.u32 cta, %ctaid.x; % endif mul.lo.u32 cta, cta, ${n_per_cta}; mul.lo.u32 warp_n_base, warp, ${n_per_warp}; 
- add.u32 warp_n_base, warp_n_base, cta; + add.u32 warp_n_base, warp_n_base, cta; } setp.ge.u32 pwarp_exit, warp_n_base, ${n}; % if bs: @@ -165,7 +165,7 @@ $L_LOOP: add.u32 c_col1_${nt}, c_col0_${nt}, 1; } % if not n_col_aligned: - setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; % endif @@ -174,19 +174,19 @@ $L_LOOP: { .reg .u64 t64, bcol64; mul.wide.u32 t64, r_mod4, ${ldb}; - cvt.u64.u32 bcol64, b_col_0; - add.u64 t64, t64, bcol64; - shl.b64 t64, t64, 3; - add.u64 b_thr_base, b_ptr, t64; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; } { .reg .u64 t64, ccol64; mul.wide.u32 t64, r_div4, ${ldc}; - cvt.u64.u32 ccol64, c_col0_0; - add.u64 t64, t64, ccol64; - shl.b64 t64, t64, 3; - add.u64 c_thr_base, c_ptr, t64; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; } % for nt in range(nn): @@ -201,10 +201,10 @@ $L_LOOP: %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; % if needs_zero_init: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; % endif ${pred_emit(f'ld.weak.global.cg.v2.{pftype} {{c0_{nt}_{mt}, c1_{nt}_{mt}}}, [caddr];', pm, pred_reg=f'p01_{nt}_{mt}')} } @@ -256,7 +256,7 @@ $L_LOOP: %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; ${pred_emit(f'st.weak.global.v2.{pftype} [caddr], {{c0_{nt}_{mt}, c1_{nt}_{mt}}};', pm, pred_reg=f'p01s_{nt}_{mt}')} } % endfor diff --git a/gimmik/kernels/ptx/dmma-astream-v1.mako 
b/gimmik/kernels/ptx/dmma-astream-v1.mako index 48d41f3..455496a 100644 --- a/gimmik/kernels/ptx/dmma-astream-v1.mako +++ b/gimmik/kernels/ptx/dmma-astream-v1.mako @@ -7,19 +7,19 @@ .visible .entry ${kname}(.param .u64 _b, .param .u64 _c) { - .reg .u32 tid, warp, lane, r_mod4, r_div4; - .reg .u64 b_ptr, c_ptr; - .reg .u32 warp_n_base; - .reg .u64 ag_thr_base, b_thr_base, c_thr_base; + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 ag_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; - .reg .${pftype} a_frag; + .reg .${pftype} a_frag; % for nt in range(nn): - .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; % if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; % endif - .reg .${pftype} b_frag_${nt}; - .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; + .reg .${pftype} b_frag_${nt}; + .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; % endfor ld.param.u64 b_ptr, [_b]; @@ -35,10 +35,10 @@ { .reg .u32 cta; - mov.u32 cta, %ctaid.x; + mov.u32 cta, %ctaid.x; mul.lo.u32 cta, cta, ${n_per_cta}; mul.lo.u32 warp_n_base, warp, ${n_per_warp}; - add.u32 warp_n_base, warp_n_base, cta; + add.u32 warp_n_base, warp_n_base, cta; } setp.ge.u32 pwarp_exit, warp_n_base, ${n}; @pwarp_exit bra $L_EXIT; @@ -54,7 +54,7 @@ add.u32 c_col1_${nt}, c_col0_${nt}, 1; } % if not n_col_aligned: - setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; % endif @@ -63,29 +63,29 @@ // A thread base: &Ag[0] + lane*8 { .reg .u64 t64, a_glb_base, lane64; - mov.u64 a_glb_base, ${kname}_Ag; + mov.u64 a_glb_base, ${kname}_Ag; cvta.to.global.u64 a_glb_base, a_glb_base; - cvt.u64.u32 lane64, lane; - shl.b64 t64, lane64, 3; - add.u64 ag_thr_base, a_glb_base, t64; + cvt.u64.u32 
lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 ag_thr_base, a_glb_base, t64; } { .reg .u64 t64, bcol64; mul.wide.u32 t64, r_mod4, ${ldb}; - cvt.u64.u32 bcol64, b_col_0; - add.u64 t64, t64, bcol64; - shl.b64 t64, t64, 3; - add.u64 b_thr_base, b_ptr, t64; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; } { .reg .u64 t64, ccol64; mul.wide.u32 t64, r_div4, ${ldc}; - cvt.u64.u32 ccol64, c_col0_0; - add.u64 t64, t64, ccol64; - shl.b64 t64, t64, 3; - add.u64 c_thr_base, c_ptr, t64; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; } % for mt in range(m_tiles): @@ -113,10 +113,10 @@ %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; % if needs_zero_init: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; % endif ${pred_emit(f'ld.weak.global.cg.{pftype} c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} ${pred_emit(f'ld.weak.global.cg.{pftype} c1_{nt}_{mt}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} @@ -171,7 +171,7 @@ %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; ${pred_emit(f'st.weak.global.{pftype} [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} } diff --git a/gimmik/kernels/ptx/dmma-astream-v2.mako b/gimmik/kernels/ptx/dmma-astream-v2.mako index 16711d5..4700fd0 100644 --- a/gimmik/kernels/ptx/dmma-astream-v2.mako +++ b/gimmik/kernels/ptx/dmma-astream-v2.mako @@ -7,19 +7,19 @@ .visible .entry ${kname}(.param .u64 _b, 
.param .u64 _c) { - .reg .u32 tid, warp, lane, r_mod4, r_div4; - .reg .u64 b_ptr, c_ptr; - .reg .u32 warp_n_base; - .reg .u64 ag_thr_base, b_thr_base, c_thr_base; + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 ag_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; - .reg .${pftype} a_frag; + .reg .${pftype} a_frag; % for nt in range(nn): - .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; % if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; % endif - .reg .${pftype} b_frag_${nt}; - .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; + .reg .${pftype} b_frag_${nt}; + .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; % endfor ld.param.u64 b_ptr, [_b]; @@ -35,10 +35,10 @@ { .reg .u32 cta; - mov.u32 cta, %ctaid.x; + mov.u32 cta, %ctaid.x; mul.lo.u32 cta, cta, ${n_per_cta}; mul.lo.u32 warp_n_base, warp, ${n_per_warp}; - add.u32 warp_n_base, warp_n_base, cta; + add.u32 warp_n_base, warp_n_base, cta; } setp.ge.u32 pwarp_exit, warp_n_base, ${n}; @pwarp_exit bra $L_EXIT; @@ -54,7 +54,7 @@ add.u32 c_col1_${nt}, c_col0_${nt}, 1; } % if not n_col_aligned: - setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; % endif @@ -63,29 +63,29 @@ // A thread base: &Ag[0] + lane*8 { .reg .u64 t64, a_glb_base, lane64; - mov.u64 a_glb_base, ${kname}_Ag; + mov.u64 a_glb_base, ${kname}_Ag; cvta.to.global.u64 a_glb_base, a_glb_base; - cvt.u64.u32 lane64, lane; - shl.b64 t64, lane64, 3; - add.u64 ag_thr_base, a_glb_base, t64; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 ag_thr_base, a_glb_base, t64; } { .reg .u64 t64, bcol64; mul.wide.u32 t64, r_mod4, ${ldb}; - cvt.u64.u32 bcol64, b_col_0; - add.u64 t64, t64, bcol64; - shl.b64 t64, t64, 3; - 
add.u64 b_thr_base, b_ptr, t64; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; } { .reg .u64 t64, ccol64; mul.wide.u32 t64, r_div4, ${ldc}; - cvt.u64.u32 ccol64, c_col0_0; - add.u64 t64, t64, ccol64; - shl.b64 t64, t64, 3; - add.u64 c_thr_base, c_ptr, t64; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; } % for mt in range(m_tiles): @@ -111,10 +111,10 @@ %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; % if needs_zero_init: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; % endif ${pred_emit(f'ld.weak.global.cg.v2.{pftype} {{c0_{nt}_{mt}, c1_{nt}_{mt}}}, [caddr];', pm, pred_reg=f'p01_{nt}_{mt}')} } @@ -166,7 +166,7 @@ %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; ${pred_emit(f'st.weak.global.v2.{pftype} [caddr], {{c0_{nt}_{mt}, c1_{nt}_{mt}}};', pm, pred_reg=f'p01s_{nt}_{mt}')} } % endfor diff --git a/gimmik/kernels/ptx/dmma-steal-ws.mako b/gimmik/kernels/ptx/dmma-steal-ws.mako index 982b1b8..21ad486 100644 --- a/gimmik/kernels/ptx/dmma-steal-ws.mako +++ b/gimmik/kernels/ptx/dmma-steal-ws.mako @@ -7,7 +7,7 @@ .reg .b32 n_start0; .reg .u64 a_glb; mul.lo.u32 n_start0, ctaid_x, ${n_per_cta}; - mov.u64 a_glb, ${kname}_Ag; + mov.u64 a_glb, ${kname}_Ag; cvta.to.global.u64 a_glb, a_glb; @p_warp_lead cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [a_smem], [a_glb], ${8 * 32 * m_tiles * k_tiles}, [tma_mbar]; @@ -57,9 +57,9 @@ $L_WAIT_BRDY: .reg .b32 b_thr_a_${nt}; { .reg .b32 bcol_g, t_off; - add.u32 bcol_g, base_bcol, ${8 * nt}; - shl.b32 t_off, bcol_g, 3; - 
add.u32 b_thr_a_${nt}, b_sm_a, t_off; + add.u32 bcol_g, base_bcol, ${8 * nt}; + shl.b32 t_off, bcol_g, 3; + add.u32 b_thr_a_${nt}, b_sm_a, t_off; } % endfor @@ -73,10 +73,10 @@ $L_WAIT_BRDY: .reg .b32 c_thr_smem; { .reg .b32 t1, ccol_b; - mul.lo.u32 t1, base_crow, ${n_per_cta * dwidth_i}; - shl.b32 ccol_b, base_ccol, 3; - add.u32 c_thr_smem, c_smem, t1; - add.u32 c_thr_smem, c_thr_smem, ccol_b; + mul.lo.u32 t1, base_crow, ${n_per_cta * dwidth_i}; + shl.b32 ccol_b, base_ccol, 3; + add.u32 c_thr_smem, c_smem, t1; + add.u32 c_thr_smem, c_thr_smem, ccol_b; } % endif @@ -97,13 +97,13 @@ $L_WAIT_BRDY: %> { .reg .b32 a_a; - add.u32 a_a, a_thr_a, ${(32 * kt + 32 * mt * k_tiles) * dwidth_i}; + add.u32 a_a, a_thr_a, ${(32 * kt + 32 * mt * k_tiles) * dwidth_i}; ld.shared.${pftype} a_f, [a_a]; % if k_tail: .reg .pred pbrow_${mt}_${kt}; { .reg .b32 brow; - add.u32 brow, base_brow, ${4 * kt}; + add.u32 brow, base_brow, ${4 * kt}; setp.lt.u32 pbrow_${mt}_${kt}, brow, ${k}; } % endif @@ -111,9 +111,9 @@ $L_WAIT_BRDY: { .reg .b32 b_a, b_row; .reg .${pftype} b_f; - add.u32 b_row, base_brow, ${4 * kt}; - mul.lo.u32 b_row, b_row, ${n_per_cta * dwidth_i}; - add.u32 b_a, b_thr_a_${nt}, b_row; + add.u32 b_row, base_brow, ${4 * kt}; + mul.lo.u32 b_row, b_row, ${n_per_cta * dwidth_i}; + add.u32 b_a, b_thr_a_${nt}, b_row; % if k_tail: mov.${pftype} b_f, ${fzero}; @pbrow_${mt}_${kt} ld.shared.${pftype} b_f, [b_a]; @@ -369,8 +369,8 @@ $L_AFTER_CTRL: setp.eq.u32 p_tid0, tid, 0; setp.lt.u32 p_compute, warp, ${n_comp_warps}; - setp.eq.u32 p_prod, warp, ${prod_warp}; - setp.eq.u32 p_steal, warp, ${steal_warp}; + setp.eq.u32 p_prod, warp, ${prod_warp}; + setp.eq.u32 p_steal, warp, ${steal_warp}; { .reg .b32 _elect_lane; @@ -398,12 +398,12 @@ $L_AFTER_CTRL: // Compute-warp lane geometry { .reg .b32 t, w_n_base; - and.b32 base_brow, lane, 3; - shr.u32 base_crow, lane, 2; + and.b32 base_brow, lane, 3; + shr.u32 base_crow, lane, 2; mul.lo.u32 w_n_base, warp, ${n_per_warp}; - add.u32 base_bcol, 
base_crow, w_n_base; - shl.b32 t, base_brow, 1; - add.u32 base_ccol, t, w_n_base; + add.u32 base_bcol, base_crow, w_n_base; + shl.b32 t, base_brow, 1; + add.u32 base_ccol, t, w_n_base; } ${producer_init_setup()} diff --git a/gimmik/kernels/ptx/dmma-stride-ws.mako b/gimmik/kernels/ptx/dmma-stride-ws.mako new file mode 100644 index 0000000..0a34b44 --- /dev/null +++ b/gimmik/kernels/ptx/dmma-stride-ws.mako @@ -0,0 +1,395 @@ +<%inherit file='base'/> + +<%def name="producer_init_setup()"> + // Producer warp: initial A bulk-copy + first B load + @!p_prod bra.uni $L_AFTER_INIT_B; + { + .reg .b32 n_start0; + .reg .u64 a_glb; + mul.lo.u32 n_start0, ctaid_x, ${n_per_cta}; + mov.u64 a_glb, ${kname}_Ag; + cvta.to.global.u64 a_glb, a_glb; + @p_warp_lead cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes + [a_smem], [a_glb], ${8 * 32 * m_tiles * k_tiles}, [tma_mbar]; + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b1_smem], [bdesc_addr, {n_start0, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes + 8 * 32 * m_tiles * k_tiles}; + bar.warp.sync 0xffffffff; + .reg .b64 state; + .reg .pred p1; + mbarrier.arrive.shared::cta.b64 state, [tma_mbar]; +$L_TMA_INIT_W: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], state, ${mbar_maxwait}; + @!p1 bra.uni $L_TMA_INIT_W; + .reg .b64 _state2; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state2, [bready_mbar]; + } +$L_AFTER_INIT_B: + + +<%def name="compute_warp_body()"> + // --- Compute Warps + @!p_compute bra.uni $L_AFTER_COMPUTE; + + // Wait on the current B tile. 
+ { + .reg .pred p1; +$L_WAIT_BRDY: + mbarrier.try_wait.parity.shared::cta.b64 p1, [bready_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_BRDY; + } + + // MMA + { + .reg .b32 b_sm_a; + .reg .pred p_ph; + setp.ne.u32 p_ph, phase, 0; + selp.b32 b_sm_a, b2_smem, b1_smem, p_ph; + + .reg .b32 a_thr_a; + { + .reg .b32 t; + shl.b32 t, lane, 3; + add.u32 a_thr_a, a_smem, t; + } +% for nt in range(nn): + .reg .b32 b_thr_a_${nt}; + { + .reg .b32 bcol_g, t_off; + add.u32 bcol_g, base_bcol, ${8 * nt}; + shl.b32 t_off, bcol_g, 3; + add.u32 b_thr_a_${nt}, b_sm_a, t_off; + } +% endfor + +% if beta_zero: + // beta=0: skip shared-staging entirely; compute warps store MMA + // outputs straight to global C with N-tail predication. + .reg .u64 c_glob_addr; + ld.param.u64 c_glob_addr, [c_desc]; + cvta.to.global.u64 c_glob_addr, c_glob_addr; +% else: + .reg .b32 c_thr_smem; + { + .reg .b32 t1, ccol_b; + mul.lo.u32 t1, base_crow, ${n_per_cta * dwidth_i}; + shl.b32 ccol_b, base_ccol, 3; + add.u32 c_thr_smem, c_smem, t1; + add.u32 c_thr_smem, c_thr_smem, ccol_b; + } +% endif + + // Zero accumulators +% for mt in range(m_tiles): +% for nt in range(nn): + .reg .${pftype} d_x_${mt}_${nt}, d_y_${mt}_${nt}; + mov.${pftype} d_x_${mt}_${nt}, ${fzero}; + mov.${pftype} d_y_${mt}_${nt}, ${fzero}; +% endfor +% endfor + + .reg .${pftype} a_f; +% for mt in range(m_tiles): +% for kt in range(k_tiles): +<% + k_tail = (k_rem != 0 and loop.last) +%> + { + .reg .b32 a_a; + add.u32 a_a, a_thr_a, ${(32 * kt + 32 * mt * k_tiles) * dwidth_i}; + ld.shared.${pftype} a_f, [a_a]; +% if k_tail: + .reg .pred pbrow_${mt}_${kt}; + { + .reg .b32 brow; + add.u32 brow, base_brow, ${4 * kt}; + setp.lt.u32 pbrow_${mt}_${kt}, brow, ${k}; + } +% endif +% for nt in range(nn): + { + .reg .b32 b_a, b_row; + .reg .${pftype} b_f; + add.u32 b_row, base_brow, ${4 * kt}; + mul.lo.u32 b_row, b_row, ${n_per_cta * dwidth_i}; + add.u32 b_a, b_thr_a_${nt}, b_row; +% if k_tail: + mov.${pftype} b_f, ${fzero}; + @pbrow_${mt}_${kt} 
ld.shared.${pftype} b_f, [b_a]; +% else: + ld.shared.${pftype} b_f, [b_a]; +% endif + mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} + {d_x_${mt}_${nt}, d_y_${mt}_${nt}}, {a_f}, {b_f}, + {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor + } +% endfor +% endfor + +% if beta_zero: + .reg .u64 c_thr_glob_base; + { + .reg .u32 thr_col_off, thr_addr_off_lo; + add.u32 thr_col_off, base_ccol, n_start_curr; + mad.lo.u32 thr_addr_off_lo, base_crow, ${ldc}, thr_col_off; + .reg .u64 thr_byte_off; + mul.wide.u32 thr_byte_off, thr_addr_off_lo, ${dwidth_i}; + add.u64 c_thr_glob_base, c_glob_addr, thr_byte_off; + } +% for mt in range(m_tiles): +<% + row_tail = (m_pad > m) and ((mt + 1) * 8 > m) +%> +% if row_tail: + .reg .pred p_row_${mt}; + { + .reg .b32 crow; + add.u32 crow, base_crow, ${8 * mt}; + setp.lt.u32 p_row_${mt}, crow, ${m}; + } +% endif +% for nt in range(nn): + { + .reg .pred p_st; + .reg .u32 g_ccol; + add.u32 g_ccol, base_ccol, ${8 * nt}; + add.u32 g_ccol, g_ccol, n_start_curr; + setp.lt.u32 p_st, g_ccol, ${n}; +% if row_tail: + and.pred p_st, p_st, p_row_${mt}; +% endif + .reg .u64 _c_addr; + add.u64 _c_addr, c_thr_glob_base, ${(8 * mt * ldc + 8 * nt) * dwidth_i}; + @p_st st.weak.global.v2.${pftype} [_c_addr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor +% endfor +% else: + // Wait until producer's prev-iter TMA-store of C has drained. + { + .reg .pred p1; +$L_WAIT_CSTORE: + mbarrier.try_wait.parity.shared::cta.b64 p1, [cstored_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_CSTORE; + } + + // Vector-store {d_x, d_y} pairs to csmem. M-tail / N-tail OOB rows + // are dropped by the C tensor map. 
+% for mt in range(m_tiles): +% for nt in range(nn): + { + .reg .b32 csaddr; + add.u32 csaddr, c_thr_smem, ${mt * c_mtile_smem_stride + nt * c_ntile_smem_stride}; + st.shared.v2.${pftype} [csaddr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor +% endfor +% endif + +% if not beta_zero: + bar.sync 1, ${comp_threads}; + fence.proxy.async.shared::cta; + { + .reg .b64 _state; + @p_tid0 mbarrier.arrive.shared::cta.b64 _state, [cready_mbar]; + } +% endif + + // Match the stealing kernel's "work used" point: retire the current + // phase only after compute-side work for the tile is complete. + { + .reg .b64 _bconsumed_state; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _bconsumed_state, [bconsumed_mbar]; + } + + } +$L_AFTER_COMPUTE: + + +<%def name="data_warp_body()"> + // --- Data Movement Warp + @!p_prod bra.uni $L_AFTER_DATA; + { + .reg .b32 n_c_store; + mul.lo.u32 n_c_store, block_idx_x, ${n_per_cta}; + + .reg .pred p_next_work; + setp.ne.u32 p_next_work, next_work, 0; + + // Issue the next B load into the alternate buffer before retiring + // the current phase. + .reg .b64 b_state; + @!p_next_work bra.uni $L_SKIP_NEXT_B_ISSUE; + { + mul.lo.u32 n_start_next, next_block_idx_x, ${n_per_cta}; + .reg .b32 b_next; + .reg .pred p_ph; + setp.ne.u32 p_ph, phase, 0; + selp.b32 b_next, b1_smem, b2_smem, p_ph; + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b_next], [bdesc_addr, {n_start_next, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes}; + @p_warp_lead cp.async.bulk.commit_group; + bar.warp.sync 0xffffffff; + mbarrier.arrive.shared::cta.b64 b_state, [tma_mbar]; + } +$L_SKIP_NEXT_B_ISSUE: + +% if not beta_zero: + // TMA reduce+store of C (beta=1 only; beta=0 uses direct global + // stores from compute warps, so the producer does no C work). 
+ { + .reg .pred p1; + .reg .b64 _c_state; +$L_WAIT_CRDY: + mbarrier.try_wait.parity.shared::cta.b64 p1, [cready_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_CRDY; + @p_warp_lead cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.tile.bulk_group + [cdesc_addr, {n_c_store, 0}], [c_smem]; + @p_warp_lead cp.async.bulk.commit_group; + @p_warp_lead cp.async.bulk.wait_group 0; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _c_state, [cstored_mbar]; + } +% endif + + // Wait for the next B load to complete, then mark it ready for the + // next compute iteration. + @!p_next_work bra.uni $L_SKIP_NEXT_B_READY; + { + .reg .b64 _bready_state; + .reg .pred p1; +$L_WAIT_TMA: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], b_state, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_TMA; + +$L_WAIT_BCONSUMED: + mbarrier.try_wait.parity.shared::cta.b64 p1, [bconsumed_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_BCONSUMED; + + @p_warp_lead mbarrier.arrive.shared::cta.b64 _bready_state, [bready_mbar]; + } +$L_SKIP_NEXT_B_READY: + } +$L_AFTER_DATA: + + +.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { + ${', '.join(a_u64)} +}; +.extern .shared .align 128 .b8 ${kname}_dynm[]; + +.visible .entry ${kname}(.param .u64 b_desc, + .param .u64 c_desc) +.maxntid ${blockx_total}, 1, 1 +{ + .reg .b32 tid, warp, lane, phase, ctaid_x; + .reg .b32 base_brow, base_bcol, base_crow, base_ccol; + .reg .b32 work, next_work, iter, next_iter; + .reg .b32 block_idx_x, next_block_idx_x, n_start_curr, n_start_next; + .reg .u64 bdesc_addr, cdesc_addr; + .reg .b32 a_smem, b1_smem, b2_smem, c_smem; + .reg .b32 tma_mbar, bready_mbar, bconsumed_mbar, cready_mbar, cstored_mbar; + .reg .pred p_compute, p_prod; + .reg .pred p_warp_lead; + .reg .pred p_done; + .reg .pred p_tid0; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + mov.u32 ctaid_x, %ctaid.x; + + .reg .b32 dynm_base; + mov.u32 dynm_base, ${kname}_dynm; + add.u32 b1_smem, dynm_base, ${b1_off}; + 
add.u32 b2_smem, dynm_base, ${b2_off}; + add.u32 c_smem, dynm_base, ${c_off}; + add.u32 a_smem, dynm_base, ${a_off}; + + add.u32 tma_mbar, dynm_base, ${tma_mbar_off}; + add.u32 bready_mbar, dynm_base, ${bready_mbar_off}; + add.u32 bconsumed_mbar, dynm_base, ${bconsumed_mbar_off}; + add.u32 cready_mbar, dynm_base, ${cready_mbar_off}; + add.u32 cstored_mbar, dynm_base, ${cstored_mbar_off}; + + ld.param.u64 bdesc_addr, [b_desc]; + ld.param.u64 cdesc_addr, [c_desc]; + + setp.eq.u32 p_tid0, tid, 0; + + setp.lt.u32 p_compute, warp, ${n_comp_warps}; + setp.eq.u32 p_prod, warp, ${prod_warp}; + + { + .reg .b32 _elect_lane; + elect.sync _elect_lane|p_warp_lead, 0xffffffff; + } + + // mbarrier init (tid 0 only); pre-arrive cstored so compute iter 0 + // can write csmem immediately when beta != 0. + { + .reg .pred p_init; + setp.eq.u32 p_init, tid, 0; + .reg .b64 _state; + @p_init mbarrier.init.shared::cta.b64 [tma_mbar], 32; + @p_init mbarrier.init.shared::cta.b64 [bready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [bconsumed_mbar], ${n_comp_warps}; + @p_init mbarrier.init.shared::cta.b64 [cready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cstored_mbar], 1; + @p_init mbarrier.arrive.shared::cta.b64 _state, [cstored_mbar]; + @p_init fence.proxy.async.shared::cta; + } + bar.sync 0; + + // Compute-warp lane geometry + { + .reg .b32 t, w_n_base; + and.b32 base_brow, lane, 3; + shr.u32 base_crow, lane, 2; + mul.lo.u32 w_n_base, warp, ${n_per_warp}; + add.u32 base_bcol, base_crow, w_n_base; + shl.b32 t, base_brow, 1; + add.u32 base_ccol, t, w_n_base; + } + + ${producer_init_setup()} + + mov.u32 block_idx_x, ctaid_x; + mov.u32 work, 1; + mov.u32 iter, 0; + mov.u32 phase, 0; + + // Bounded grid-stride loop: block_idx_x = ctaid.x + iter*grid_stride. 
+$L_LOOP: + setp.eq.u32 p_done, work, 0; + @p_done bra.uni $L_EXIT; + + mul.lo.u32 n_start_curr, block_idx_x, ${n_per_cta}; + + { + .reg .pred p_iter, p_block, p_next; + add.u32 next_iter, iter, 1; + add.u32 next_block_idx_x, block_idx_x, ${grid_stride}; + setp.lt.u32 p_iter, next_iter, ${stride_iters}; + setp.lt.u32 p_block, next_block_idx_x, ${work_blocks}; + and.pred p_next, p_iter, p_block; + selp.b32 next_work, 1, 0, p_next; + } + + ${compute_warp_body()} + + ${data_warp_body()} + + mov.u32 block_idx_x, next_block_idx_x; + mov.u32 work, next_work; + mov.u32 iter, next_iter; + xor.b32 phase, phase, 1; + bra.uni $L_LOOP; + +$L_EXIT: + ret; +} diff --git a/gimmik/ptx.py b/gimmik/ptx.py index e26560b..00fc16a 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -16,7 +16,7 @@ class PTXMatMul(MatMul): } # Map Supported CC -> Minimum PTX version - PTX_SM = {(8, 0): (7, 0), (9, 0): (8, 0), (10, 0): (8, 7), (10, 3): (8, 7), + PTX_SM = {(8, 0): (7, 0), (9, 0): (8, 6), (10, 0): (8, 7), (10, 3): (8, 7), (12, 0): (8, 7), (12, 1): (8, 7)} PTX_TEMPLATE_FAMILY = { @@ -28,6 +28,7 @@ class PTXMatMul(MatMul): 'dmma-astream': 'dense', 'dmma-asmem': 'dense', 'dmma-steal-ws': 'dense', + 'dmma-stride-ws': 'dense', } def __init__(self, *args, **kwargs): @@ -206,9 +207,12 @@ def _get_render_args(self, kernel_cfg, dtype, dsize, cc, smem_info, cfg = self._sparse_args(tpl, params, block, dtype, dsize, base_args, base_meta) elif self.PTX_TEMPLATE_FAMILY[tpl] == 'dense': - if tpl.endswith('ws'): - cfg = self._dense_ws_args(kernel_cfg, params, smem_info, - base_args, base_meta) + if tpl == 'dmma-steal-ws': + cfg = self._dense_steal_ws_args(kernel_cfg, params, smem_info, + base_args, base_meta) + elif tpl == 'dmma-stride-ws': + cfg = self._dense_stride_ws_args(kernel_cfg, params, smem_info, + base_args, base_meta) else: cfg = self._dense_args(kernel_cfg, params, base_args, base_meta) @@ -307,7 +311,7 @@ def pm_runtime(mt): 'block_stealing': block_steal, } - def _dense_ws_args(self, kernel_cfg, 
params, smem_info, args, meta): + def _dense_steal_ws_args(self, kernel_cfg, params, smem_info, args, meta): dynamic_max = smem_info[1] nn = params['nn'] warp_map = kernel_cfg['warp_map'] @@ -353,6 +357,57 @@ def _dense_ws_args(self, kernel_cfg, params, smem_info, args, meta): meta['ws_out_tile'] = (n_per_cta, setup['m_pad']) return kernel_cfg['template'], args, meta + def _dense_stride_ws_args(self, kernel_cfg, params, smem_info, args, meta): + dynamic_max = smem_info[1] + nn = params['nn'] + stride_iters = params['iters'] + warp_map = kernel_cfg['warp_map'] + n_comp_warps = warp_map['compute_count'] + n_per_cta = 8 * nn * n_comp_warps + if n_per_cta > self.n: + return None + + setup = self._dense_common(nn, n_comp_warps, False) + + # Warp Specialism Setup + b_tile_bytes = setup['k_pad'] * n_per_cta * 8 + c_tile_bytes = setup['m_pad'] * n_per_cta * 8 + a_bytes = setup['m_tiles'] * setup['k_tiles'] * 32 * 8 + + regions = [('b1', b_tile_bytes), ('b2', b_tile_bytes), + ('c', c_tile_bytes), ('a', a_bytes)] + mbars = ('tma', 'bready', 'bconsumed', 'cready', 'cstored') + offsets, dynm_total_bytes = self._dsmem_alloc(regions, mbars) + + work_blocks = -(-self.n // n_per_cta) + grid_stride = -(-work_blocks // stride_iters) + ws_setup = { + 'n_comp_warps': n_comp_warps, + 'blockx_total': 32 * (n_comp_warps + 1), + 'prod_warp': warp_map['producer'], + 'comp_threads': 32 * n_comp_warps, + 'b_tile_bytes': b_tile_bytes, + 'c_mtile_smem_stride': 8 * n_per_cta * 8, + 'c_ntile_smem_stride': 8 * 8, + 'stride_iters': stride_iters, + 'grid_stride': grid_stride, + 'work_blocks': work_blocks, + 'dynm_total_bytes': dynm_total_bytes, + } + + if ws_setup['dynm_total_bytes'] > dynamic_max: + return None + + args |= setup | ws_setup | offsets + meta |= { + 'grid': (grid_stride, 1, 1), + 'ws_b_tile': (n_per_cta, setup['k_pad']), + 'dynamic_shared': ws_setup['dynm_total_bytes'], + } + if self.beta != 0: + meta['ws_out_tile'] = (n_per_cta, setup['m_pad']) + return kernel_cfg['template'], args, 
meta + @staticmethod def _dsmem_alloc(regions, mbars, align=16): out, off = {}, 0 From d4e12160134926d2a13de2d5afcdd3c473c25b62 Mon Sep 17 00:00:00 2001 From: Will Trojak Date: Wed, 3 Jun 2026 10:25:15 +0000 Subject: [PATCH 15/21] msplit dmma --- .../kernels/ptx/dmma-astream-msplit-v1.mako | 248 ++++++++++++++++++ .../kernels/ptx/dmma-astream-msplit-v2.mako | 242 +++++++++++++++++ gimmik/ptx.py | 63 +++++ 3 files changed, 553 insertions(+) create mode 100644 gimmik/kernels/ptx/dmma-astream-msplit-v1.mako create mode 100644 gimmik/kernels/ptx/dmma-astream-msplit-v2.mako diff --git a/gimmik/kernels/ptx/dmma-astream-msplit-v1.mako b/gimmik/kernels/ptx/dmma-astream-msplit-v1.mako new file mode 100644 index 0000000..1158d48 --- /dev/null +++ b/gimmik/kernels/ptx/dmma-astream-msplit-v1.mako @@ -0,0 +1,248 @@ +<%inherit file='base'/> + +.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { + ${', '.join(a_u64)} +}; +.extern .shared .align 128 .b8 ${kname}_dynm[]; + +.visible .entry ${kname}(.param .u64 b_desc, + .param .u64 _c) +.maxntid ${blockx_total}, 1, 1 +{ + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u32 ctaid_x, n_start_cta, warp_n, warp_m, warp_n_base; + .reg .u64 bdesc_addr, c_ptr; + .reg .u64 ag_thr_base, c_thr_base; + .reg .u32 b_smem, b_thr_base, tma_mbar; + .reg .pred p_tid0, pwarp_exit, p_load_warp, p_warp_lead; +% for nt in range(nn): + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; +% if not n_col_aligned: + .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; +% endif +% endfor + + ld.param.u64 bdesc_addr, [b_desc]; + ld.param.u64 c_ptr, [_c]; + cvta.to.global.u64 c_ptr, c_ptr; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + shr.u32 r_div4, lane, 2; + and.b32 r_mod4, lane, 3; + mov.u32 ctaid_x, %ctaid.x; + mul.lo.u32 n_start_cta, ctaid_x, ${n_per_cta}; + + { + .reg .u32 t; + div.u32 warp_n, warp, ${msplit}; + mad.lo.u32 t, warp_n, ${msplit}, 0; + sub.u32 warp_m, warp, t; + } + + { + .reg 
.u32 dynm_base; + mov.u32 dynm_base, ${kname}_dynm; + add.u32 b_smem, dynm_base, ${b_off}; + add.u32 tma_mbar, dynm_base, ${tma_mbar_off}; + } + + setp.eq.u32 p_tid0, tid, 0; + setp.eq.u32 p_load_warp, warp, 0; + { + .reg .b32 _elect_lane; + elect.sync _elect_lane|p_warp_lead, 0xffffffff; + } + + @p_tid0 mbarrier.init.shared::cta.b64 [tma_mbar], 32; + @p_tid0 fence.proxy.async.shared::cta; + bar.sync 0; + + @!p_load_warp bra $L_AFTER_B_TMA; + { + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b_smem], [bdesc_addr, {n_start_cta, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes}; + bar.warp.sync 0xffffffff; + .reg .b64 state; + .reg .pred p1; + mbarrier.arrive.shared::cta.b64 state, [tma_mbar]; +$L_TMA_WAIT: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], state, ${mbar_maxwait}; + @!p1 bra.uni $L_TMA_WAIT; + } +$L_AFTER_B_TMA: + bar.sync 0; + + { + .reg .u32 t; + mul.lo.u32 t, warp_n, ${n_per_warp}; + add.u32 warp_n_base, n_start_cta, t; + } + setp.ge.u32 pwarp_exit, warp_n_base, ${n}; + @pwarp_exit bra $L_EXIT; + +% for nt in range(nn): + add.u32 b_col_${nt}, warp_n_base, ${8 * nt}; + add.u32 b_col_${nt}, b_col_${nt}, r_div4; + { + .reg .u32 t; + shl.b32 t, r_mod4, 1; + add.u32 c_col0_${nt}, warp_n_base, ${8 * nt}; + add.u32 c_col0_${nt}, c_col0_${nt}, t; + add.u32 c_col1_${nt}, c_col0_${nt}, 1; + } +% if not n_col_aligned: + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; + setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endif +% endfor + + // A thread base: &Ag[0] + lane*8 + { + .reg .u64 t64, a_glb_base, lane64; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, a_glb_base; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 ag_thr_base, a_glb_base, t64; + } + + { + .reg .u32 bcol_local, t, row_off; + mad.lo.u32 bcol_local, warp_n, ${n_per_warp}, r_div4; + 
mul.lo.u32 t, bcol_local, ${dwidth_i}; + mul.lo.u32 row_off, r_mod4, ${n_per_cta * dwidth_i}; + add.u32 t, t, row_off; + add.u32 b_thr_base, b_smem, t; + } + + { + .reg .u64 t64, ccol64; + mul.wide.u32 t64, r_div4, ${ldc}; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; + } + +% for wm in range(msplit): +<% + owned_mts = [mt for mt in range(m_tiles) if mt % msplit == wm] +%> +% if owned_mts: + { + .reg .pred p_this_msplit; + setp.ne.u32 p_this_msplit, warp_m, ${wm}; + @p_this_msplit bra $L_SKIP_MS_${wm}; + } + { + .reg .${pftype} a_frag; +% for nt in range(nn): + .reg .${pftype} b_frag_${nt}; +% endfor +% for nt in range(nn): +% for mt in owned_mts: + .reg .${pftype} c0_${nt}_${mt}, c1_${nt}_${mt}; +% endfor +% endfor +% for mt in owned_mts: +% if pm_runtime(mt): + .reg .pred pm_${mt}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${8 * mt}; + setp.lt.u32 pm_${mt}, crow, ${m}; + } +% endif +% endfor + +% for nt in range(nn): +% for mt in owned_mts: +% if beta_zero: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% else: +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None + needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; +% if needs_zero_init: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{wm}_{nt}_{mt}', indent=' ' * 12)} + ${pred_emit(f'ld.weak.global.cg.{pftype} c1_{nt}_{mt}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{wm}_{nt}_{mt}', indent=' ' * 12)} + } +% endif +% endfor +% endfor + +% for ki in range(k_tiles): +% for nt in range(nn): +<% + pvb = f'pvalid_bcol_{nt}' if 
not n_col_aligned else None + k_tail = (k_rem != 0 and loop.parent.last) + needs_zero = pvb is not None or k_tail + pbrow = 'pbrow' if k_tail else None +%> + { + .reg .u32 baddr; + add.u32 baddr, b_thr_base, ${ki * b_smem_kiter_stride + nt * b_smem_ntile_stride}; +% if needs_zero: + mov.${pftype} b_frag_${nt}, ${fzero}; +% endif +% if k_tail: + .reg .pred pbrow; + { + .reg .u32 brow; + add.u32 brow, r_mod4, ${4 * ki}; + setp.lt.u32 pbrow, brow, ${k}; + } +% endif + ${pred_emit(f'ld.shared.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{wm}_{ki}_{nt}', indent=' ' * 12)} + } +% endfor +% for mt in owned_mts: + ld.weak.global.${pftype} a_frag, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; +% for nt in range(nn): + mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} + {c0_${nt}_${mt}, c1_${nt}_${mt}}, + {a_frag}, + {b_frag_${nt}}, + {c0_${nt}_${mt}, c1_${nt}_${mt}}; +% endfor +% endfor +% endfor + +% for mt in owned_mts: +% for nt in range(nn): +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.weak.global.{pftype} [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{wm}_{nt}_{mt}', indent=' ' * 12)} + ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{wm}_{nt}_{mt}', indent=' ' * 12)} + } +% endfor +% endfor + } +$L_SKIP_MS_${wm}: +% endif +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dmma-astream-msplit-v2.mako b/gimmik/kernels/ptx/dmma-astream-msplit-v2.mako new file mode 100644 index 0000000..4b41f67 --- /dev/null +++ b/gimmik/kernels/ptx/dmma-astream-msplit-v2.mako @@ -0,0 +1,242 @@ +<%inherit file='base'/> + +.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { + ${', '.join(a_u64)} +}; +.extern 
.shared .align 128 .b8 ${kname}_dynm[]; + +.visible .entry ${kname}(.param .u64 b_desc, + .param .u64 _c) +.maxntid ${blockx_total}, 1, 1 +{ + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u32 ctaid_x, n_start_cta, warp_n, warp_m, warp_n_base; + .reg .u64 bdesc_addr, c_ptr; + .reg .u64 ag_thr_base, c_thr_base; + .reg .u32 b_smem, b_thr_base, tma_mbar; + .reg .pred p_tid0, pwarp_exit, p_load_warp, p_warp_lead; +% for nt in range(nn): + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; +% if not n_col_aligned: + .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; +% endif +% endfor + + ld.param.u64 bdesc_addr, [b_desc]; + ld.param.u64 c_ptr, [_c]; + cvta.to.global.u64 c_ptr, c_ptr; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + shr.u32 r_div4, lane, 2; + and.b32 r_mod4, lane, 3; + mov.u32 ctaid_x, %ctaid.x; + mul.lo.u32 n_start_cta, ctaid_x, ${n_per_cta}; + + { + .reg .u32 t; + div.u32 warp_n, warp, ${msplit}; + mad.lo.u32 t, warp_n, ${msplit}, 0; + sub.u32 warp_m, warp, t; + } + + { + .reg .u32 dynm_base; + mov.u32 dynm_base, ${kname}_dynm; + add.u32 b_smem, dynm_base, ${b_off}; + add.u32 tma_mbar, dynm_base, ${tma_mbar_off}; + } + + setp.eq.u32 p_tid0, tid, 0; + setp.eq.u32 p_load_warp, warp, 0; + { + .reg .b32 _elect_lane; + elect.sync _elect_lane|p_warp_lead, 0xffffffff; + } + + @p_tid0 mbarrier.init.shared::cta.b64 [tma_mbar], 32; + @p_tid0 fence.proxy.async.shared::cta; + bar.sync 0; + + @!p_load_warp bra $L_AFTER_B_TMA; + { + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b_smem], [bdesc_addr, {n_start_cta, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes}; + bar.warp.sync 0xffffffff; + .reg .b64 state; + .reg .pred p1; + mbarrier.arrive.shared::cta.b64 state, [tma_mbar]; +$L_TMA_WAIT: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], state, ${mbar_maxwait}; + @!p1 bra.uni $L_TMA_WAIT; + } 
+$L_AFTER_B_TMA: + bar.sync 0; + + { + .reg .u32 t; + mul.lo.u32 t, warp_n, ${n_per_warp}; + add.u32 warp_n_base, n_start_cta, t; + } + setp.ge.u32 pwarp_exit, warp_n_base, ${n}; + @pwarp_exit bra $L_EXIT; + +% for nt in range(nn): + add.u32 b_col_${nt}, warp_n_base, ${8 * nt}; + add.u32 b_col_${nt}, b_col_${nt}, r_div4; + { + .reg .u32 t; + shl.b32 t, r_mod4, 1; + add.u32 c_col0_${nt}, warp_n_base, ${8 * nt}; + add.u32 c_col0_${nt}, c_col0_${nt}, t; + add.u32 c_col1_${nt}, c_col0_${nt}, 1; + } +% if not n_col_aligned: + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; + setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endif +% endfor + + // A thread base: &Ag[0] + lane*8 + { + .reg .u64 t64, a_glb_base, lane64; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, a_glb_base; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 ag_thr_base, a_glb_base, t64; + } + + { + .reg .u32 bcol_local, t, row_off; + mad.lo.u32 bcol_local, warp_n, ${n_per_warp}, r_div4; + mul.lo.u32 t, bcol_local, ${dwidth_i}; + mul.lo.u32 row_off, r_mod4, ${n_per_cta * dwidth_i}; + add.u32 t, t, row_off; + add.u32 b_thr_base, b_smem, t; + } + + { + .reg .u64 t64, ccol64; + mul.wide.u32 t64, r_div4, ${ldc}; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; + } + +% for wm in range(msplit): +<% + owned_mts = [mt for mt in range(m_tiles) if mt % msplit == wm] +%> +% if owned_mts: + { + .reg .pred p_this_msplit; + setp.ne.u32 p_this_msplit, warp_m, ${wm}; + @p_this_msplit bra $L_SKIP_MS_${wm}; + } + { + .reg .${pftype} a_frag; +% for nt in range(nn): + .reg .${pftype} b_frag_${nt}; +% endfor +% for nt in range(nn): +% for mt in owned_mts: + .reg .${pftype} c0_${nt}_${mt}, c1_${nt}_${mt}; +% endfor +% endfor +% for mt in owned_mts: +% if pm_runtime(mt): + .reg .pred pm_${mt}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${8 * mt}; + setp.lt.u32 
pm_${mt}, crow, ${m}; + } +% endif +% endfor + +% for nt in range(nn): +% for mt in owned_mts: +% if beta_zero: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% else: +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + needs_zero_init = pm is not None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; +% if needs_zero_init: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.v2.{pftype} {{c0_{nt}_{mt}, c1_{nt}_{mt}}}, [caddr];', pm, pred_reg=f'p01_{wm}_{nt}_{mt}', indent=' ' * 12)} + } +% endif +% endfor +% endfor + +% for ki in range(k_tiles): +% for nt in range(nn): +<% + pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None + k_tail = (k_rem != 0 and loop.parent.last) + needs_zero = pvb is not None or k_tail + pbrow = 'pbrow' if k_tail else None +%> + { + .reg .u32 baddr; + add.u32 baddr, b_thr_base, ${ki * b_smem_kiter_stride + nt * b_smem_ntile_stride}; +% if needs_zero: + mov.${pftype} b_frag_${nt}, ${fzero}; +% endif +% if k_tail: + .reg .pred pbrow; + { + .reg .u32 brow; + add.u32 brow, r_mod4, ${4 * ki}; + setp.lt.u32 pbrow, brow, ${k}; + } +% endif + ${pred_emit(f'ld.shared.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{wm}_{ki}_{nt}', indent=' ' * 12)} + } +% endfor +% for mt in owned_mts: + ld.weak.global.${pftype} a_frag, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; +% for nt in range(nn): + mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} + {c0_${nt}_${mt}, c1_${nt}_${mt}}, + {a_frag}, + {b_frag_${nt}}, + {c0_${nt}_${mt}, c1_${nt}_${mt}}; +% endfor +% endfor +% endfor + +% for mt in owned_mts: +% for nt in range(nn): +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None +%> + { + .reg .u64 caddr; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.weak.global.v2.{pftype} [caddr], {{c0_{nt}_{mt}, 
c1_{nt}_{mt}}};', pm, pred_reg=f'p01s_{wm}_{nt}_{mt}', indent=' ' * 12)} + } +% endfor +% endfor + } +$L_SKIP_MS_${wm}: +% endif +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/ptx.py b/gimmik/ptx.py index 00fc16a..50d06ed 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -26,6 +26,7 @@ class PTXMatMul(MatMul): 'cstream-ksplit': 'sparse', 'cstream-w2': 'sparse', 'dmma-astream': 'dense', + 'dmma-astream-msplit': 'dense', 'dmma-asmem': 'dense', 'dmma-steal-ws': 'dense', 'dmma-stride-ws': 'dense', @@ -213,6 +214,10 @@ def _get_render_args(self, kernel_cfg, dtype, dsize, cc, smem_info, elif tpl == 'dmma-stride-ws': cfg = self._dense_stride_ws_args(kernel_cfg, params, smem_info, base_args, base_meta) + elif tpl == 'dmma-astream-msplit': + cfg = self._dense_astream_msplit_args(kernel_cfg, params, + smem_info, base_args, + base_meta) else: cfg = self._dense_args(kernel_cfg, params, base_args, base_meta) @@ -408,6 +413,64 @@ def _dense_stride_ws_args(self, kernel_cfg, params, smem_info, args, meta): meta['ws_out_tile'] = (n_per_cta, setup['m_pad']) return kernel_cfg['template'], args, meta + def _dense_astream_msplit_args(self, kernel_cfg, params, smem_info, args, + meta): + dynamic_max = smem_info[1] + nn = params.get('nn') + warps = params.get('warps') + msplit = params.get('msplit') + vector_width = kernel_cfg.get('vector_width') + block = tuple(kernel_cfg['block']) + width = kernel_cfg['width'] + + for name, val in (('nn', nn), ('warps', warps), ('msplit', msplit)): + if not isinstance(val, int) or val <= 0: + raise ValueError(f'dmma-astream-msplit params.{name} must be ' + 'a positive integer') + + if block != (32 * warps * msplit, 1, 1) or width != 1: + raise ValueError('dmma-astream-msplit block/width mismatch') + + n_per_cta = 8 * nn * warps + if n_per_cta > self.n: + return None + + if vector_width not in {1, 2}: + raise ValueError('dmma-astream-msplit vector_width must be 1 or 2') + + if (vector_width == 2 + and (self.aligne is None or self.aligne % 2 + or 
self.n % (8 * nn))): + return None + + setup = self._dense_common(nn, warps, False) + + b_tile_bytes = setup['k_pad'] * n_per_cta * args['dwidth_i'] + regions = [('b', b_tile_bytes)] + offsets, dynm_total_bytes = self._dsmem_alloc(regions, ('tma',)) + msplit_setup = { + 'msplit': msplit, + 'm_tiles_per_group': -(-setup['m_tiles'] // msplit), + 'b_tile_bytes': b_tile_bytes, + 'b_smem_kiter_stride': 4 * n_per_cta * args['dwidth_i'], + 'b_smem_ntile_stride': 8 * args['dwidth_i'], + 'blockx_total': 32 * warps * msplit, + 'dynm_total_bytes': dynm_total_bytes, + } + + if msplit_setup['dynm_total_bytes'] > dynamic_max: + return None + + args |= setup | msplit_setup | offsets + meta |= { + 'grid': (-(-self.n // n_per_cta), 1, 1), + 'ws_b_tile': (n_per_cta, setup['k_pad']), + 'dynamic_shared': msplit_setup['dynm_total_bytes'], + } + + tpl = f"{kernel_cfg['template']}-v{vector_width}" + return tpl, args, meta + @staticmethod def _dsmem_alloc(regions, mbars, align=16): out, off = {}, 0 From b9ac47cb65387fc8386053e70bce7fd8fff1bba2 Mon Sep 17 00:00:00 2001 From: Will Trojak Date: Wed, 3 Jun 2026 16:43:09 +0000 Subject: [PATCH 16/21] Refactored arg setup --- gimmik/ptx.py | 179 +++++++++++++++++++------------------------------- 1 file changed, 67 insertions(+), 112 deletions(-) diff --git a/gimmik/ptx.py b/gimmik/ptx.py index 50d06ed..75f9697 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -208,12 +208,9 @@ def _get_render_args(self, kernel_cfg, dtype, dsize, cc, smem_info, cfg = self._sparse_args(tpl, params, block, dtype, dsize, base_args, base_meta) elif self.PTX_TEMPLATE_FAMILY[tpl] == 'dense': - if tpl == 'dmma-steal-ws': - cfg = self._dense_steal_ws_args(kernel_cfg, params, smem_info, - base_args, base_meta) - elif tpl == 'dmma-stride-ws': - cfg = self._dense_stride_ws_args(kernel_cfg, params, smem_info, - base_args, base_meta) + if tpl in {'dmma-steal-ws', 'dmma-stride-ws'}: + cfg = self._dense_ws_args(kernel_cfg, params, smem_info, + base_args, base_meta) elif tpl == 
'dmma-astream-msplit': cfg = self._dense_astream_msplit_args(kernel_cfg, params, smem_info, base_args, @@ -254,30 +251,38 @@ def _sparse_args(self, tpl, params, block, dtype, dsize, args, def _dense_args(self, kernel_cfg, params, args, meta): nn = params['nn'] warps = params['warps'] - n_per_cta = 8 * nn * warps - if n_per_cta > self.n: - return None - vector_width = kernel_cfg['vector_width'] - if (vector_width == 2 - and (self.aligne is None or self.aligne % 2 - or self.n % (8 * nn))): + + setup = self._dense_common( + nn, warps, bool(params.get('block_stealing', False)), + vector_width + ) + if setup is None: return None - block_steal = bool(params.get('block_stealing', False)) - setup = self._dense_common(nn, warps, block_steal) tpl = f"{kernel_cfg['template']}-v{vector_width}" args |= setup - meta['grid'] = (-(-self.n // n_per_cta), 1, 1) + meta['grid'] = (-(-self.n // setup['n_per_cta']), 1, 1) return tpl, args, meta - def _dense_common(self, nn, warps_per_cta, block_steal): + def _dense_common(self, nn, warps_per_cta, block_steal, + vector_width=None): a = self.A m, k = a.shape m_tiles = (m + 7) // 8 k_tiles = (k + 3) // 4 k_rem = k % 4 + n_per_warp = 8 * nn + n_per_cta = warps_per_cta * n_per_warp + + if n_per_cta > self.n: + return None + + if (vector_width == 2 + and (self.aligne is None or self.aligne % 2 + or self.n % n_per_warp)): + return None # A in DMMA-fragment layout: lane l -> A[mt*8 + l//4][kt*4 + l%4] # i.e. 
an (m_tiles, k_tiles) grid of row-major 8x4 tiles, packed as @@ -287,9 +292,6 @@ def _dense_common(self, nn, warps_per_cta, block_steal): tiles = a_pad.reshape(m_tiles, 8, k_tiles, 4).swapaxes(1, 2) a_u64 = [f'0x{u:016x}' for u in tiles.view(np.uint64).ravel()] - n_per_warp = 8 * nn - n_per_cta = warps_per_cta * n_per_warp - # Predicate-elision flags n_col_aligned = (self.n is not None and self.n % n_per_warp == 0) def pm_runtime(mt): @@ -316,102 +318,75 @@ def pm_runtime(mt): 'block_stealing': block_steal, } - def _dense_steal_ws_args(self, kernel_cfg, params, smem_info, args, meta): + def _dense_ws_args(self, kernel_cfg, params, smem_info, args, meta): dynamic_max = smem_info[1] + tpl = kernel_cfg['template'] nn = params['nn'] warp_map = kernel_cfg['warp_map'] - n_comp_warps = warp_map['compute_count'] - n_per_cta = 8 * nn * n_comp_warps - if n_per_cta > self.n: - return None - - setup = self._dense_common(nn, n_comp_warps, True) - - # Warp Specialism Setup - b_tile_bytes = setup['k_pad'] * n_per_cta * 8 - c_tile_bytes = setup['m_pad'] * n_per_cta * 8 - a_bytes = setup['m_tiles'] * setup['k_tiles'] * 32 * 8 - regions = [('b1', b_tile_bytes), ('b2', b_tile_bytes), - ('c', c_tile_bytes), ('a', a_bytes), ('wid', 16)] - mbars = ('tma', 'bready', 'cready', 'cstored', - 'steal', 'wid_new', 'wid_used') - offsets, dynm_total_bytes = self._dsmem_alloc(regions, mbars) - ws_setup = { - 'n_comp_warps': n_comp_warps, - 'blockx_total': 32 * (n_comp_warps + 2), - 'prod_warp': warp_map['producer'], - 'steal_warp': warp_map['stealer'], - 'comp_threads': 32 * n_comp_warps, - 'b_tile_bytes': b_tile_bytes, - 'c_mtile_smem_stride': 8 * n_per_cta * 8, - 'c_ntile_smem_stride': 8 * 8, - 'dynm_total_bytes': dynm_total_bytes, - } - - if ws_setup['dynm_total_bytes'] > dynamic_max: - return None - - args |= setup | ws_setup | offsets - meta |= { - 'grid': (-(-self.n // n_per_cta), 1, 1), - 'ws_b_tile': (n_per_cta, setup['k_pad']), - 'dynamic_shared': ws_setup['dynm_total_bytes'], - } - if 
self.beta != 0: - meta['ws_out_tile'] = (n_per_cta, setup['m_pad']) - return kernel_cfg['template'], args, meta + match tpl: + case 'dmma-steal-ws': + block_steal, service_warps = True, 2 + case 'dmma-stride-ws': + block_steal, service_warps = False, 1 + case _: + raise ValueError(f'Unknown dense warp-specialized template ' + f'{tpl}') - def _dense_stride_ws_args(self, kernel_cfg, params, smem_info, args, meta): - dynamic_max = smem_info[1] - nn = params['nn'] - stride_iters = params['iters'] - warp_map = kernel_cfg['warp_map'] n_comp_warps = warp_map['compute_count'] - n_per_cta = 8 * nn * n_comp_warps - if n_per_cta > self.n: + setup = self._dense_common(nn, n_comp_warps, block_steal) + if setup is None: return None - setup = self._dense_common(nn, n_comp_warps, False) - - # Warp Specialism Setup + n_per_cta = setup['n_per_cta'] b_tile_bytes = setup['k_pad'] * n_per_cta * 8 c_tile_bytes = setup['m_pad'] * n_per_cta * 8 a_bytes = setup['m_tiles'] * setup['k_tiles'] * 32 * 8 - regions = [('b1', b_tile_bytes), ('b2', b_tile_bytes), ('c', c_tile_bytes), ('a', a_bytes)] - mbars = ('tma', 'bready', 'bconsumed', 'cready', 'cstored') - offsets, dynm_total_bytes = self._dsmem_alloc(regions, mbars) - - work_blocks = -(-self.n // n_per_cta) - grid_stride = -(-work_blocks // stride_iters) ws_setup = { 'n_comp_warps': n_comp_warps, - 'blockx_total': 32 * (n_comp_warps + 1), + 'blockx_total': 32 * (n_comp_warps + service_warps), 'prod_warp': warp_map['producer'], 'comp_threads': 32 * n_comp_warps, 'b_tile_bytes': b_tile_bytes, 'c_mtile_smem_stride': 8 * n_per_cta * 8, 'c_ntile_smem_stride': 8 * 8, - 'stride_iters': stride_iters, - 'grid_stride': grid_stride, - 'work_blocks': work_blocks, - 'dynm_total_bytes': dynm_total_bytes, } - if ws_setup['dynm_total_bytes'] > dynamic_max: + match tpl: + case 'dmma-steal-ws': + regions.append(('wid', 16)) + mbars = ('tma', 'bready', 'cready', 'cstored', + 'steal', 'wid_new', 'wid_used') + grid = (-(-self.n // n_per_cta), 1, 1) + 
ws_setup['steal_warp'] = warp_map['stealer'] + case 'dmma-stride-ws': + stride_iters = params['iters'] + work_blocks = -(-self.n // n_per_cta) + grid_stride = -(-work_blocks // stride_iters) + mbars = ('tma', 'bready', 'bconsumed', 'cready', 'cstored') + grid = (grid_stride, 1, 1) + ws_setup |= { + 'stride_iters': stride_iters, + 'grid_stride': grid_stride, + 'work_blocks': work_blocks, + } + + offsets, dynm_total_bytes = self._dsmem_alloc(regions, mbars) + if dynm_total_bytes > dynamic_max: return None args |= setup | ws_setup | offsets meta |= { - 'grid': (grid_stride, 1, 1), + 'grid': grid, 'ws_b_tile': (n_per_cta, setup['k_pad']), - 'dynamic_shared': ws_setup['dynm_total_bytes'], + 'dynamic_shared': dynm_total_bytes, } if self.beta != 0: meta['ws_out_tile'] = (n_per_cta, setup['m_pad']) - return kernel_cfg['template'], args, meta + + return tpl, args, meta def _dense_astream_msplit_args(self, kernel_cfg, params, smem_info, args, meta): @@ -420,34 +395,14 @@ def _dense_astream_msplit_args(self, kernel_cfg, params, smem_info, args, warps = params.get('warps') msplit = params.get('msplit') vector_width = kernel_cfg.get('vector_width') - block = tuple(kernel_cfg['block']) - width = kernel_cfg['width'] - - for name, val in (('nn', nn), ('warps', warps), ('msplit', msplit)): - if not isinstance(val, int) or val <= 0: - raise ValueError(f'dmma-astream-msplit params.{name} must be ' - 'a positive integer') - - if block != (32 * warps * msplit, 1, 1) or width != 1: - raise ValueError('dmma-astream-msplit block/width mismatch') - - n_per_cta = 8 * nn * warps - if n_per_cta > self.n: + setup = self._dense_common(nn, warps, False, vector_width) + if setup is None: return None - if vector_width not in {1, 2}: - raise ValueError('dmma-astream-msplit vector_width must be 1 or 2') - - if (vector_width == 2 - and (self.aligne is None or self.aligne % 2 - or self.n % (8 * nn))): - return None - - setup = self._dense_common(nn, warps, False) + n_per_cta = setup['n_per_cta'] 
b_tile_bytes = setup['k_pad'] * n_per_cta * args['dwidth_i'] regions = [('b', b_tile_bytes)] - offsets, dynm_total_bytes = self._dsmem_alloc(regions, ('tma',)) msplit_setup = { 'msplit': msplit, 'm_tiles_per_group': -(-setup['m_tiles'] // msplit), @@ -455,17 +410,17 @@ def _dense_astream_msplit_args(self, kernel_cfg, params, smem_info, args, 'b_smem_kiter_stride': 4 * n_per_cta * args['dwidth_i'], 'b_smem_ntile_stride': 8 * args['dwidth_i'], 'blockx_total': 32 * warps * msplit, - 'dynm_total_bytes': dynm_total_bytes, } - if msplit_setup['dynm_total_bytes'] > dynamic_max: + offsets, dynm_total_bytes = self._dsmem_alloc(regions, ('tma',)) + if dynm_total_bytes > dynamic_max: return None args |= setup | msplit_setup | offsets meta |= { 'grid': (-(-self.n // n_per_cta), 1, 1), 'ws_b_tile': (n_per_cta, setup['k_pad']), - 'dynamic_shared': msplit_setup['dynm_total_bytes'], + 'dynamic_shared': dynm_total_bytes, } tpl = f"{kernel_cfg['template']}-v{vector_width}" From 66e3796128b78051cd1e05a3b65a3521a61d117f Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Wed, 3 Jun 2026 07:34:21 -0700 Subject: [PATCH 17/21] Updated Blackwell profile --- gimmik/kernels/ptx/config/sm100.json | 244 +++++++++++++++++---------- 1 file changed, 156 insertions(+), 88 deletions(-) diff --git a/gimmik/kernels/ptx/config/sm100.json b/gimmik/kernels/ptx/config/sm100.json index 04eb4d7..3f97a6f 100644 --- a/gimmik/kernels/ptx/config/sm100.json +++ b/gimmik/kernels/ptx/config/sm100.json @@ -10,14 +10,30 @@ ], "kernels": [ { - "template": "cstream", + "template": "cstream-ksplit", "block": [ - 128, - 1, + 32, + 2, 1 ], "width": 1, - "descriptor": "cstream" + "params": { + "csz": 20 + }, + "descriptor": "cstream-ksplit/k2-c20-x32" + }, + { + "template": "bstream-msplit", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "descriptor": "bstream-msplit/m4-b32-x32" }, { "template": "bstream", @@ -27,10 +43,10 @@ 1 ], "width": 1, - "descriptor": "bstream" + "descriptor": 
"bstream/x128" }, { - "template": "bstream-msplit", + "template": "cstream-ksplit", "block": [ 32, 4, @@ -38,136 +54,180 @@ ], "width": 1, "params": { - "bsz": 24 + "csz": 20 }, - "descriptor": "bstream-msplit/m4-b24-x32" + "descriptor": "cstream-ksplit/k4-c20-x32" }, { "template": "bstream-msplit", "block": [ - 64, - 1, + 32, + 8, 1 ], "width": 1, "params": { "bsz": 32 }, - "conditions": { - "all": [ - { - "field": "beta_zero", - "eq": true - }, - { - "field": "m", - "lte": 320 - }, - { - "field": "k_used", - "gte": 64 - } - ] - }, - "descriptor": "bstream-msplit/m1-b32-x64" + "descriptor": "bstream-msplit/m8-b32-x32" }, { - "template": "cstream-ksplit", + "template": "bstream-msplit", "block": [ 32, + 1, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "descriptor": "bstream-msplit/m1-b32-x32" + }, + { + "template": "bstream-msplit", + "block": [ + 64, 2, 1 ], "width": 1, "params": { - "csz": 24 + "bsz": 32 }, - "descriptor": "cstream-ksplit/k2-c24-x32" + "descriptor": "bstream-msplit/m2-b32-x64" }, { - "template": "cstream-ksplit", + "template": "cstream", "block": [ - 32, - 4, + 256, + 1, + 1 + ], + "width": 1, + "descriptor": "cstream/x256" + }, + { + "template": "dmma-steal-ws", + "block": [ + 192, + 1, 1 ], "width": 1, "params": { - "csz": 20 + "nn": 2 + }, + "warp_map": { + "compute_start": 0, + "compute_count": 4, + "producer": 4, + "stealer": 5 }, "conditions": { - "field": "k_used", - "gt": 500 + "field": "n", + "is_not": null }, - "descriptor": "cstream-ksplit/k4-c20-x32" + "descriptor": "dmma-steal-ws/nn2-w4" }, { - "template": "cstream-w2", + "template": "dmma-steal-ws", "block": [ - 128, + 192, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 1 + }, + "warp_map": { + "compute_start": 0, + "compute_count": 4, + "producer": 4, + "stealer": 5 + }, + "conditions": { + "field": "n", + "is_not": null + }, + "descriptor": "dmma-steal-ws/nn1-w4" + }, + { + "template": "dmma-astream", + "vector_width": 2, + "block": [ + 32, 1, 1 ], - "width": 2, + 
"width": 1, + "params": { + "nn": 4, + "warps": 1 + }, "conditions": { "all": [ - { - "field": "dtype", - "eq": "double" - }, { "field": "n", "is_not": null }, { - "field": "n", - "divisible_by": 2 - }, - { - "field": "k_used", - "lte": 100 + "field": "aligne", + "is_not": null }, { "field": "aligne", - "is_null_or_divisible_by": 2 + "divisible_by": 2 } ] }, - "descriptor": "cstream-w2/x128" + "descriptor": "dmma-astream/v2/nn4-w1" }, { "template": "dmma-asmem", - "vector_width": 1, + "vector_width": 2, "block": [ - 128, + 256, 1, 1 ], "width": 1, "params": { - "nn": 4, - "warps": 4, + "nn": 1, + "warps": 8, "block_stealing": true }, - "descriptor": "dmma-asmem/v1/nn4-w4-bs", "conditions": { - "field": "n", - "is_not": null - } + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-asmem/v2/nn1-w8-bs" }, { - "template": "dmma-asmem", + "template": "dmma-astream", "vector_width": 2, "block": [ - 128, + 64, 1, 1 ], "width": 1, "params": { - "nn": 4, - "warps": 4, - "block_stealing": true + "nn": 2, + "warps": 2 }, "conditions": { "all": [ @@ -185,53 +245,61 @@ } ] }, - "descriptor": "dmma-asmem/v2/nn4-w4-bs" + "descriptor": "dmma-astream/v2/nn2-w2" }, { "template": "dmma-steal-ws", "block": [ - 192, + 128, 1, 1 ], "width": 1, "params": { - "nn": 1 + "nn": 2 }, "warp_map": { "compute_start": 0, - "compute_count": 4, - "producer": 4, - "stealer": 5 + "compute_count": 2, + "producer": 2, + "stealer": 3 }, - "descriptor": "dmma-steal-ws/nn1-w4", "conditions": { "field": "n", "is_not": null - } + }, + "descriptor": "dmma-steal-ws/nn2-w2" }, { - "template": "dmma-steal-ws", + "template": "dmma-asmem", + "vector_width": 2, "block": [ - 192, + 256, 1, 1 ], "width": 1, "params": { - "nn": 2 - }, - "warp_map": { - "compute_start": 0, - "compute_count": 4, - "producer": 4, - "stealer": 5 + "nn": 1, + "warps": 8 }, - "descriptor": "dmma-steal-ws/nn2-w4", 
"conditions": { - "field": "n", - "is_not": null - } + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-asmem/v2/nn1-w8" }, { "template": "dmma-steal-ws", @@ -250,11 +318,11 @@ "producer": 4, "stealer": 5 }, - "descriptor": "dmma-steal-ws/nn4-w4", "conditions": { "field": "n", "is_not": null - } + }, + "descriptor": "dmma-steal-ws/nn4-w4" } ] } From 010d13cc28d8e06d19c755e36c2a52454b899d7a Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Wed, 3 Jun 2026 09:43:30 -0700 Subject: [PATCH 18/21] updated sm100 config --- gimmik/kernels/ptx/config/sm100.json | 136 ++++++++++++++------------- 1 file changed, 73 insertions(+), 63 deletions(-) diff --git a/gimmik/kernels/ptx/config/sm100.json b/gimmik/kernels/ptx/config/sm100.json index 3f97a6f..51a1d4d 100644 --- a/gimmik/kernels/ptx/config/sm100.json +++ b/gimmik/kernels/ptx/config/sm100.json @@ -38,12 +38,12 @@ { "template": "bstream", "block": [ - 128, + 64, 1, 1 ], "width": 1, - "descriptor": "bstream/x128" + "descriptor": "bstream/x64" }, { "template": "cstream-ksplit", @@ -100,12 +100,12 @@ { "template": "cstream", "block": [ - 256, + 128, 1, 1 ], "width": 1, - "descriptor": "cstream/x256" + "descriptor": "cstream/x128" }, { "template": "dmma-steal-ws", @@ -130,6 +130,38 @@ }, "descriptor": "dmma-steal-ws/nn2-w4" }, + { + "template": "dmma-astream-msplit", + "vector_width": 2, + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "params": { + "nn": 2, + "warps": 2, + "msplit": 2 + }, + "conditions": { + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "dmma-astream-msplit/v2/nn2-w2-m2" + }, { "template": "dmma-steal-ws", "block": [ @@ -185,39 +217,7 @@ "descriptor": "dmma-astream/v2/nn4-w1" }, { - "template": "dmma-asmem", - "vector_width": 2, - "block": [ - 256, - 1, - 1 - 
], - "width": 1, - "params": { - "nn": 1, - "warps": 8, - "block_stealing": true - }, - "conditions": { - "all": [ - { - "field": "n", - "is_not": null - }, - { - "field": "aligne", - "is_not": null - }, - { - "field": "aligne", - "divisible_by": 2 - } - ] - }, - "descriptor": "dmma-asmem/v2/nn1-w8-bs" - }, - { - "template": "dmma-astream", + "template": "dmma-astream-msplit", "vector_width": 2, "block": [ 64, @@ -226,8 +226,9 @@ ], "width": 1, "params": { - "nn": 2, - "warps": 2 + "nn": 4, + "warps": 1, + "msplit": 2 }, "conditions": { "all": [ @@ -245,43 +246,44 @@ } ] }, - "descriptor": "dmma-astream/v2/nn2-w2" + "descriptor": "dmma-astream-msplit/v2/nn4-w1-m2" }, { "template": "dmma-steal-ws", "block": [ - 128, + 192, 1, 1 ], "width": 1, "params": { - "nn": 2 + "nn": 4 }, "warp_map": { "compute_start": 0, - "compute_count": 2, - "producer": 2, - "stealer": 3 + "compute_count": 4, + "producer": 4, + "stealer": 5 }, "conditions": { "field": "n", "is_not": null }, - "descriptor": "dmma-steal-ws/nn2-w2" + "descriptor": "dmma-steal-ws/nn4-w4" }, { - "template": "dmma-asmem", + "template": "dmma-astream-msplit", "vector_width": 2, "block": [ - 256, + 96, 1, 1 ], "width": 1, "params": { - "nn": 1, - "warps": 8 + "nn": 4, + "warps": 1, + "msplit": 3 }, "conditions": { "all": [ @@ -299,30 +301,38 @@ } ] }, - "descriptor": "dmma-asmem/v2/nn1-w8" + "descriptor": "dmma-astream-msplit/v2/nn4-w1-m3" }, { - "template": "dmma-steal-ws", + "template": "dmma-astream", + "vector_width": 2, "block": [ - 192, + 256, 1, 1 ], "width": 1, "params": { - "nn": 4 - }, - "warp_map": { - "compute_start": 0, - "compute_count": 4, - "producer": 4, - "stealer": 5 + "nn": 2, + "warps": 8 }, "conditions": { - "field": "n", - "is_not": null + "all": [ + { + "field": "n", + "is_not": null + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] }, - "descriptor": "dmma-steal-ws/nn4-w4" + "descriptor": "dmma-astream/v2/nn2-w8" } ] } From 
1e6e55d002bc7ceadc4f15fbb15a3c2b4b0328dd Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Mon, 15 Jun 2026 05:35:44 -0700 Subject: [PATCH 19/21] Cleanups and added default config. --- gimmik/kernels/ptx/config/default.json | 41 +- gimmik/kernels/ptx/config/sm100.json | 88 +++- gimmik/kernels/ptx/config/sm80.json | 7 + gimmik/kernels/ptx/config/sm90.json | 88 +++- gimmik/kernels/ptx/dmma-asmem-v1.mako | 114 +++-- gimmik/kernels/ptx/dmma-asmem-v2.mako | 118 ++--- .../kernels/ptx/dmma-astream-msplit-v1.mako | 103 +++-- .../kernels/ptx/dmma-astream-msplit-v2.mako | 102 ++-- gimmik/kernels/ptx/dmma-astream-v1.mako | 113 +++-- gimmik/kernels/ptx/dmma-astream-v2.mako | 112 +++-- gimmik/kernels/ptx/dmma-steal-ws.mako | 2 +- gimmik/kernels/ptx/dmma-stride-ws.mako | 114 ++--- gimmik/ptx.py | 437 +++++++++--------- setup.py | 6 +- 14 files changed, 865 insertions(+), 580 deletions(-) diff --git a/gimmik/kernels/ptx/config/default.json b/gimmik/kernels/ptx/config/default.json index 13cd31d..846a84f 100644 --- a/gimmik/kernels/ptx/config/default.json +++ b/gimmik/kernels/ptx/config/default.json @@ -1,6 +1,41 @@ { "schema": 1, - "cc": [0, 0], - "ptx": [0, 0], - "kernels": [] + "cc": [7, 0], + "ptx": [7, 0], + "kernels": [ + { + "template": "cstream", + "family": "sparse", + "block": [128, 1, 1], + "width": 1, + "descriptor": "cstream/x128" + }, + { + "template": "bstream", + "family": "sparse", + "block": [128, 1, 1], + "width": 1, + "descriptor": "bstream/x128" + }, + { + "template": "bstream-msplit", + "family": "sparse", + "block": [32, 4, 1], + "width": 1, + "params": { + "bsz": 24 + }, + "descriptor": "bstream-msplit/m4-b24-x32" + }, + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [32, 2, 1], + "width": 1, + "params": { + "csz": 24 + }, + "descriptor": "cstream-ksplit/k2-c24-x32" + } + ] } diff --git a/gimmik/kernels/ptx/config/sm100.json b/gimmik/kernels/ptx/config/sm100.json index 51a1d4d..d7dfe1b 100644 --- a/gimmik/kernels/ptx/config/sm100.json +++ 
b/gimmik/kernels/ptx/config/sm100.json @@ -20,7 +20,8 @@ "params": { "csz": 20 }, - "descriptor": "cstream-ksplit/k2-c20-x32" + "descriptor": "cstream-ksplit/k2-c20-x32", + "family": "sparse" }, { "template": "bstream-msplit", @@ -33,7 +34,8 @@ "params": { "bsz": 32 }, - "descriptor": "bstream-msplit/m4-b32-x32" + "descriptor": "bstream-msplit/m4-b32-x32", + "family": "sparse" }, { "template": "bstream", @@ -43,7 +45,8 @@ 1 ], "width": 1, - "descriptor": "bstream/x64" + "descriptor": "bstream/x64", + "family": "sparse" }, { "template": "cstream-ksplit", @@ -56,7 +59,8 @@ "params": { "csz": 20 }, - "descriptor": "cstream-ksplit/k4-c20-x32" + "descriptor": "cstream-ksplit/k4-c20-x32", + "family": "sparse" }, { "template": "bstream-msplit", @@ -69,7 +73,8 @@ "params": { "bsz": 32 }, - "descriptor": "bstream-msplit/m8-b32-x32" + "descriptor": "bstream-msplit/m8-b32-x32", + "family": "sparse" }, { "template": "bstream-msplit", @@ -82,7 +87,8 @@ "params": { "bsz": 32 }, - "descriptor": "bstream-msplit/m1-b32-x32" + "descriptor": "bstream-msplit/m1-b32-x32", + "family": "sparse" }, { "template": "bstream-msplit", @@ -95,7 +101,8 @@ "params": { "bsz": 32 }, - "descriptor": "bstream-msplit/m2-b32-x64" + "descriptor": "bstream-msplit/m2-b32-x64", + "family": "sparse" }, { "template": "cstream", @@ -105,7 +112,8 @@ 1 ], "width": 1, - "descriptor": "cstream/x128" + "descriptor": "cstream/x128", + "family": "sparse" }, { "template": "dmma-steal-ws", @@ -128,7 +136,13 @@ "field": "n", "is_not": null }, - "descriptor": "dmma-steal-ws/nn2-w4" + "descriptor": "dmma-steal-ws/nn2-w4", + "family": "dense-ws", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } }, { "template": "dmma-astream-msplit", @@ -160,7 +174,13 @@ } ] }, - "descriptor": "dmma-astream-msplit/v2/nn2-w2-m2" + "descriptor": "dmma-astream-msplit/v2/nn2-w2-m2", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } }, { "template": "dmma-steal-ws", @@ -183,7 +203,13 @@ "field": "n", "is_not": null }, - 
"descriptor": "dmma-steal-ws/nn1-w4" + "descriptor": "dmma-steal-ws/nn1-w4", + "family": "dense-ws", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } }, { "template": "dmma-astream", @@ -214,7 +240,13 @@ } ] }, - "descriptor": "dmma-astream/v2/nn4-w1" + "descriptor": "dmma-astream/v2/nn4-w1", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } }, { "template": "dmma-astream-msplit", @@ -246,7 +278,13 @@ } ] }, - "descriptor": "dmma-astream-msplit/v2/nn4-w1-m2" + "descriptor": "dmma-astream-msplit/v2/nn4-w1-m2", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } }, { "template": "dmma-steal-ws", @@ -269,7 +307,13 @@ "field": "n", "is_not": null }, - "descriptor": "dmma-steal-ws/nn4-w4" + "descriptor": "dmma-steal-ws/nn4-w4", + "family": "dense-ws", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } }, { "template": "dmma-astream-msplit", @@ -301,7 +345,13 @@ } ] }, - "descriptor": "dmma-astream-msplit/v2/nn4-w1-m3" + "descriptor": "dmma-astream-msplit/v2/nn4-w1-m3", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } }, { "template": "dmma-astream", @@ -332,7 +382,13 @@ } ] }, - "descriptor": "dmma-astream/v2/nn2-w8" + "descriptor": "dmma-astream/v2/nn2-w8", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } } ] } diff --git a/gimmik/kernels/ptx/config/sm80.json b/gimmik/kernels/ptx/config/sm80.json index 8c607ef..08464d3 100644 --- a/gimmik/kernels/ptx/config/sm80.json +++ b/gimmik/kernels/ptx/config/sm80.json @@ -5,18 +5,21 @@ "kernels": [ { "template": "cstream", + "family": "sparse", "block": [128, 1, 1], "width": 1, "descriptor": "cstream" }, { "template": "bstream", + "family": "sparse", "block": [128, 1, 1], "width": 1, "descriptor": "bstream" }, { "template": "bstream-msplit", + "family": "sparse", "block": [32, 4, 1], "width": 1, "params": { @@ -26,6 +29,7 @@ }, { "template": "bstream-msplit", + "family": "sparse", "block": [64, 1, 1], "width": 1, "params": { @@ -42,6 +46,7 @@ }, { "template": 
"cstream-ksplit", + "family": "sparse", "block": [32, 2, 1], "width": 1, "params": { @@ -51,6 +56,7 @@ }, { "template": "cstream-ksplit", + "family": "sparse", "block": [32, 4, 1], "width": 1, "params": { @@ -64,6 +70,7 @@ }, { "template": "cstream-w2", + "family": "sparse", "block": [128, 1, 1], "width": 2, "conditions": { diff --git a/gimmik/kernels/ptx/config/sm90.json b/gimmik/kernels/ptx/config/sm90.json index ac1a0a6..2bfe5eb 100644 --- a/gimmik/kernels/ptx/config/sm90.json +++ b/gimmik/kernels/ptx/config/sm90.json @@ -20,7 +20,8 @@ "params": { "csz": 20 }, - "descriptor": "cstream-ksplit/k2-c20-x32" + "descriptor": "cstream-ksplit/k2-c20-x32", + "family": "sparse" }, { "template": "bstream-msplit", @@ -33,7 +34,8 @@ "params": { "bsz": 24 }, - "descriptor": "bstream-msplit/m8-b24-x32" + "descriptor": "bstream-msplit/m8-b24-x32", + "family": "sparse" }, { "template": "cstream-ksplit", @@ -46,7 +48,8 @@ "params": { "csz": 24 }, - "descriptor": "cstream-ksplit/k4-c24-x32" + "descriptor": "cstream-ksplit/k4-c24-x32", + "family": "sparse" }, { "template": "bstream-msplit", @@ -59,7 +62,8 @@ "params": { "bsz": 32 }, - "descriptor": "bstream-msplit/m2-b32-x64" + "descriptor": "bstream-msplit/m2-b32-x64", + "family": "sparse" }, { "template": "bstream", @@ -69,7 +73,8 @@ 1 ], "width": 1, - "descriptor": "bstream/x64" + "descriptor": "bstream/x64", + "family": "sparse" }, { "template": "bstream-msplit", @@ -82,7 +87,8 @@ "params": { "bsz": 32 }, - "descriptor": "bstream-msplit/m2-b32-x32" + "descriptor": "bstream-msplit/m2-b32-x32", + "family": "sparse" }, { "template": "bstream-msplit", @@ -95,7 +101,8 @@ "params": { "bsz": 24 }, - "descriptor": "bstream-msplit/m1-b24-x64" + "descriptor": "bstream-msplit/m1-b24-x64", + "family": "sparse" }, { "template": "bstream-msplit", @@ -108,7 +115,8 @@ "params": { "bsz": 32 }, - "descriptor": "bstream-msplit/m4-b32-x32" + "descriptor": "bstream-msplit/m4-b32-x32", + "family": "sparse" }, { "template": "dmma-asmem", @@ -139,7 
+147,13 @@ } ] }, - "descriptor": "dmma-asmem/v2/nn1-w8" + "descriptor": "dmma-asmem/v2/nn1-w8", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } }, { "template": "dmma-astream", @@ -170,7 +184,13 @@ } ] }, - "descriptor": "dmma-astream/v2/nn2-w2" + "descriptor": "dmma-astream/v2/nn2-w2", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } }, { "template": "dmma-stride-ws", @@ -193,7 +213,13 @@ "field": "n", "is_not": null }, - "descriptor": "dmma-stride-ws/nn1-w4-i8" + "descriptor": "dmma-stride-ws/nn1-w4-i8", + "family": "dense-ws", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } }, { "template": "dmma-astream", @@ -224,7 +250,13 @@ } ] }, - "descriptor": "dmma-astream/v2/nn4-w1" + "descriptor": "dmma-astream/v2/nn4-w1", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } }, { "template": "dmma-asmem", @@ -255,7 +287,13 @@ } ] }, - "descriptor": "dmma-asmem/v2/nn2-w4" + "descriptor": "dmma-asmem/v2/nn2-w4", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } }, { "template": "dmma-stride-ws", @@ -278,7 +316,13 @@ "field": "n", "is_not": null }, - "descriptor": "dmma-stride-ws/nn2-w4-i2" + "descriptor": "dmma-stride-ws/nn2-w4-i2", + "family": "dense-ws", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } }, { "template": "dmma-astream", @@ -309,7 +353,13 @@ } ] }, - "descriptor": "dmma-astream/v2/nn1-w4" + "descriptor": "dmma-astream/v2/nn1-w4", + "family": "dense", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } }, { "template": "dmma-stride-ws", @@ -332,7 +382,13 @@ "field": "n", "is_not": null }, - "descriptor": "dmma-stride-ws/nn2-w4-i8" + "descriptor": "dmma-stride-ws/nn2-w4-i8", + "family": "dense-ws", + "tile": { + "m": 8, + "n": 8, + "k": 4 + } } ] } diff --git a/gimmik/kernels/ptx/dmma-asmem-v1.mako b/gimmik/kernels/ptx/dmma-asmem-v1.mako index e34ef28..b163a42 100644 --- a/gimmik/kernels/ptx/dmma-asmem-v1.mako +++ b/gimmik/kernels/ptx/dmma-asmem-v1.mako @@ -1,9 +1,7 @@ <%inherit file='base'/> <% -# 
Cooperative-copy params (gA-only) -blockx = 32 * warps_per_cta -a_elems = m_tiles*k_tiles*32 +blockx = a_copy_threads copy_v1_iters = (a_elems + blockx - 1) // blockx bs = bool(block_stealing) %> @@ -12,10 +10,10 @@ bs = bool(block_stealing) .shared .align 8 .b64 ${kname}_mbar; .shared .align 16 .b8 ${kname}_workid[16]; % endif -.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { ${', '.join(a_u64)} }; -.shared .align 16 .b64 ${kname}_As[${32 * m_tiles * k_tiles}]; +.shared .align 16 .b64 ${kname}_As[${a_elems}]; .visible .entry ${kname}(.param .u64 _b, .param .u64 _c) @@ -25,7 +23,7 @@ bs = bool(block_stealing) .reg .u32 warp_n_base; .reg .u64 as_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; - .reg .${pftype} a_frag; + .reg .${pftype} a_frag_<${a_regs}>; % if bs: .reg .u32 ctaid; .reg .u32 mbar_a, work_a; @@ -36,8 +34,10 @@ bs = bool(block_stealing) % if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; % endif - .reg .${pftype} b_frag_${nt}; - .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; + .reg .${pftype} b_frag_${nt}_<${b_regs}>; +% for mt in range(m_tiles): + .reg .${pftype} c_${nt}_${mt}_<${c_regs}>; +% endfor % endfor ld.param.u64 b_ptr, [_b]; @@ -106,14 +106,16 @@ bs = bool(block_stealing) } % for mt in range(m_tiles): -% if pm_runtime(mt): - .reg .pred pm_${mt}; +% for mg in range(m_groups): +% if pm_runtime(mt, mg): + .reg .pred pm_${mt}_${mg}; { .reg .u32 crow; - add.u32 crow, r_div4, ${8 * mt}; - setp.lt.u32 pm_${mt}, crow, ${m}; + add.u32 crow, r_div4, ${tile_m * mt + 8 * mg}; + setp.lt.u32 pm_${mt}_${mg}, crow, ${m}; } -% endif +% endif +% endfor % endfor % if bs: @@ -140,12 +142,12 @@ $L_LOOP: % endif % for nt in range(nn): - add.u32 b_col_${nt}, warp_n_base, ${8 * nt}; + add.u32 b_col_${nt}, warp_n_base, ${tile_n * nt}; add.u32 b_col_${nt}, b_col_${nt}, r_div4; { .reg .u32 t; shl.b32 t, r_mod4, 1; - add.u32 
c_col0_${nt}, warp_n_base, ${8 * nt}; + add.u32 c_col0_${nt}, warp_n_base, ${tile_n * nt}; add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } @@ -177,79 +179,92 @@ $L_LOOP: % for nt in range(nn): % for mt in range(m_tiles): % if beta_zero: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% for ci in range(c_regs): + mov.${pftype} c_${nt}_${mt}_${ci}, ${fzero}; +% endfor % else: +% for mg in range(m_groups): <% - pm = f'pm_{mt}' if pm_runtime(mt) else None + pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None + c0 = f'c_{nt}_{mt}_{2*mg}' + c1 = f'c_{nt}_{mt}_{2*mg + 1}' %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; -% if needs_zero_init: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; -% endif - ${pred_emit(f'ld.weak.global.cg.{pftype} c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} - ${pred_emit(f'ld.weak.global.cg.{pftype} c1_{nt}_{mt}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride}; +% if needs_zero_init: + mov.${pftype} ${c0}, ${fzero}; + mov.${pftype} ${c1}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} {c0}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}_{mg}')} + ${pred_emit(f'ld.weak.global.cg.{pftype} {c1}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}_{mg}')} } +% endfor % endif % endfor % endfor % for ki in range(k_tiles): % for nt in range(nn): +% for kg in range(k_groups): <% pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None - k_tail = (k_rem != 0 and loop.parent.last) + k_tail = (k_rem != 0 and loop.parent.parent.last) needs_zero = pvb is not None or k_tail - 
pbrow = 'pbrow' if k_tail else None + pbrow = f'pbrow_{kg}' if k_tail else None %> { .reg .u64 baddr; - add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; -% if needs_zero: - mov.${pftype} b_frag_${nt}, ${fzero}; -% endif -% if k_tail: - .reg .pred pbrow; + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + kg * b_kgroup_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.${pftype} b_frag_${nt}_${kg}, ${fzero}; +% endif +% if k_tail: + .reg .pred ${pbrow}; { .reg .u32 brow; - add.u32 brow, r_mod4, ${4 * ki}; - setp.lt.u32 pbrow, brow, ${k}; + add.u32 brow, r_mod4, ${tile_k * ki + 4 * kg}; + setp.lt.u32 ${pbrow}, brow, ${k}; } -% endif - ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}_{kg}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}_{kg}')} } +% endfor % endfor % for mt in range(m_tiles): - ld.shared.${pftype} a_frag, [as_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; +% for ai in range(a_regs): + ld.shared.${pftype} a_frag_${ai}, [as_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes + 32 * ai * dwidth_i}]; +% endfor % for nt in range(nn): - mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} - {c0_${nt}_${mt}, c1_${nt}_${mt}}, - {a_frag}, - {b_frag_${nt}}, - {c0_${nt}_${mt}, c1_${nt}_${mt}}; + mma.sync.aligned.${ptx_mma_shape}.row.col.${pftype}.${pftype}.${pftype}.${pftype} + ${reg_list(f'c_{nt}_{mt}', c_regs)}, + ${reg_list('a_frag', a_regs)}, + ${reg_list(f'b_frag_{nt}', b_regs)}, + ${reg_list(f'c_{nt}_{mt}', c_regs)}; % endfor % endfor % endfor -% for nt in range(nn): -% for mt in range(m_tiles): +% for mt in range(m_tiles): +% for nt in range(nn): +% for mg in range(m_groups): <% - pm = f'pm_{mt}' if pm_runtime(mt) else None + pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None pvc1 = f'pvalid_c1col_{nt}' if not 
n_col_aligned else None + c0 = f'c_{nt}_{mt}_{2*mg}' + c1 = f'c_{nt}_{mt}_{2*mg + 1}' %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; - ${pred_emit(f'st.weak.global.{pftype} [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} - ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.weak.global.{pftype} [caddr], {c0};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}_{mg}')} + ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], {c1};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}_{mg}')} } +% endfor % endfor % endfor @@ -273,7 +288,6 @@ $L_AFTER_WAIT: ld.shared::cta.b128 resp, [work_a]; clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_have, resp; @!p_have bra $L_FIN; - // 1D grid: extract just x clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 ctaid, resp; } bra.uni $L_LOOP; diff --git a/gimmik/kernels/ptx/dmma-asmem-v2.mako b/gimmik/kernels/ptx/dmma-asmem-v2.mako index 98f491d..9ecfbef 100644 --- a/gimmik/kernels/ptx/dmma-asmem-v2.mako +++ b/gimmik/kernels/ptx/dmma-asmem-v2.mako @@ -1,9 +1,7 @@ <%inherit file='base'/> <% -# Cooperative-copy params (gA-only) -blockx = 32 * warps_per_cta -a_elems = m_tiles*k_tiles*32 +blockx = a_copy_threads a_pairs = a_elems // 2 a_pairs_tail = a_elems % 2 copy_v2_iters = (a_pairs + blockx - 1) // blockx @@ -14,10 +12,10 @@ bs = bool(block_stealing) .shared .align 8 .b64 ${kname}_mbar; .shared .align 16 .b8 ${kname}_workid[16]; % endif -.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { ${', '.join(a_u64)} }; -.shared .align 16 .b64 ${kname}_As[${32 * m_tiles * k_tiles}]; +.shared .align 16 .b64 ${kname}_As[${a_elems}]; .visible .entry ${kname}(.param .u64 _b, .param .u64 _c) @@ -27,7 +25,7 @@ bs = bool(block_stealing) .reg .u32 
warp_n_base; .reg .u64 as_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; - .reg .${pftype} a_frag; + .reg .${pftype} a_frag_<${a_regs}>; % if bs: .reg .u32 ctaid; .reg .u32 mbar_a, work_a; @@ -38,8 +36,10 @@ bs = bool(block_stealing) % if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; % endif - .reg .${pftype} b_frag_${nt}; - .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; + .reg .${pftype} b_frag_${nt}_<${b_regs}>; +% for mt in range(m_tiles): + .reg .${pftype} c_${nt}_${mt}_<${c_regs}>; +% endfor % endfor ld.param.u64 b_ptr, [_b]; @@ -96,14 +96,13 @@ bs = bool(block_stealing) } % endfor % if a_pairs_tail: - // Tail element (only when m_tiles*k_tiles*32 is odd) { .reg .pred plast; .reg .u64 gaddr, saddr; .reg .${pftype} v; setp.eq.u32 plast, tid, 0; - add.u64 gaddr, a_glb_base, ${(32 * m_tiles * k_tiles - 1) * dwidth_i}; - add.u64 saddr, a_smem_base, ${(32 * m_tiles * k_tiles - 1) * dwidth_i}; + add.u64 gaddr, a_glb_base, ${(a_elems - 1) * dwidth_i}; + add.u64 saddr, a_smem_base, ${(a_elems - 1) * dwidth_i}; @plast ld.weak.global.cg.${pftype} v, [gaddr]; @plast st.shared.${pftype} [saddr], v; } @@ -121,14 +120,16 @@ bs = bool(block_stealing) } % for mt in range(m_tiles): -% if pm_runtime(mt): - .reg .pred pm_${mt}; +% for mg in range(m_groups): +% if pm_runtime(mt, mg): + .reg .pred pm_${mt}_${mg}; { .reg .u32 crow; - add.u32 crow, r_div4, ${8 * mt}; - setp.lt.u32 pm_${mt}, crow, ${m}; + add.u32 crow, r_div4, ${tile_m * mt + 8 * mg}; + setp.lt.u32 pm_${mt}_${mg}, crow, ${m}; } -% endif +% endif +% endfor % endfor % if bs: @@ -155,12 +156,12 @@ $L_LOOP: % endif % for nt in range(nn): - add.u32 b_col_${nt}, warp_n_base, ${8 * nt}; + add.u32 b_col_${nt}, warp_n_base, ${tile_n * nt}; add.u32 b_col_${nt}, b_col_${nt}, r_div4; { .reg .u32 t; shl.b32 t, r_mod4, 1; - add.u32 c_col0_${nt}, warp_n_base, ${8 * nt}; + add.u32 c_col0_${nt}, warp_n_base, ${tile_n * nt}; add.u32 c_col0_${nt}, c_col0_${nt}, t; 
add.u32 c_col1_${nt}, c_col0_${nt}, 1; } @@ -192,73 +193,87 @@ $L_LOOP: % for nt in range(nn): % for mt in range(m_tiles): % if beta_zero: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% for ci in range(c_regs): + mov.${pftype} c_${nt}_${mt}_${ci}, ${fzero}; +% endfor % else: +% for mg in range(m_groups): <% - pm = f'pm_{mt}' if pm_runtime(mt) else None - needs_zero_init = pm is not None + pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None + c0 = f'c_{nt}_{mt}_{2*mg}' + c1 = f'c_{nt}_{mt}_{2*mg + 1}' + cpair = f'{c0}, {c1}' %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; -% if needs_zero_init: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; -% endif - ${pred_emit(f'ld.weak.global.cg.v2.{pftype} {{c0_{nt}_{mt}, c1_{nt}_{mt}}}, [caddr];', pm, pred_reg=f'p01_{nt}_{mt}')} + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride}; +% if pm is not None: + mov.${pftype} ${c0}, ${fzero}; + mov.${pftype} ${c1}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.v2.{pftype} {{{cpair}}}, [caddr];', pm, pred_reg=f'p01_{nt}_{mt}_{mg}')} } +% endfor % endif % endfor % endfor % for ki in range(k_tiles): % for nt in range(nn): +% for kg in range(k_groups): <% pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None - k_tail = (k_rem != 0 and loop.parent.last) + k_tail = (k_rem != 0 and loop.parent.parent.last) needs_zero = pvb is not None or k_tail - pbrow = 'pbrow' if k_tail else None + pbrow = f'pbrow_{kg}' if k_tail else None %> { .reg .u64 baddr; - add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; -% if needs_zero: - mov.${pftype} b_frag_${nt}, ${fzero}; -% endif -% if k_tail: - .reg .pred pbrow; + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + kg * b_kgroup_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.${pftype} b_frag_${nt}_${kg}, ${fzero}; +% endif +% if k_tail: + .reg .pred 
${pbrow}; { .reg .u32 brow; - add.u32 brow, r_mod4, ${4 * ki}; - setp.lt.u32 pbrow, brow, ${k}; + add.u32 brow, r_mod4, ${tile_k * ki + 4 * kg}; + setp.lt.u32 ${pbrow}, brow, ${k}; } -% endif - ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}_{kg}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}_{kg}')} } +% endfor % endfor % for mt in range(m_tiles): - ld.shared.${pftype} a_frag, [as_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; +% for ai in range(a_regs): + ld.shared.${pftype} a_frag_${ai}, [as_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes + 32 * ai * dwidth_i}]; +% endfor % for nt in range(nn): - mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} - {c0_${nt}_${mt}, c1_${nt}_${mt}}, - {a_frag}, - {b_frag_${nt}}, - {c0_${nt}_${mt}, c1_${nt}_${mt}}; + mma.sync.aligned.${ptx_mma_shape}.row.col.${pftype}.${pftype}.${pftype}.${pftype} + ${reg_list(f'c_{nt}_{mt}', c_regs)}, + ${reg_list('a_frag', a_regs)}, + ${reg_list(f'b_frag_{nt}', b_regs)}, + ${reg_list(f'c_{nt}_{mt}', c_regs)}; % endfor % endfor % endfor -% for nt in range(nn): -% for mt in range(m_tiles): +% for mt in range(m_tiles): +% for nt in range(nn): +% for mg in range(m_groups): <% - pm = f'pm_{mt}' if pm_runtime(mt) else None + pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None + c0 = f'c_{nt}_{mt}_{2*mg}' + c1 = f'c_{nt}_{mt}_{2*mg + 1}' + cpair = f'{c0}, {c1}' %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; - ${pred_emit(f'st.weak.global.v2.{pftype} [caddr], {{c0_{nt}_{mt}, c1_{nt}_{mt}}};', pm, pred_reg=f'p01s_{nt}_{mt}')} + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.weak.global.v2.{pftype} [caddr], {{{cpair}}};', pm, pred_reg=f'p01s_{nt}_{mt}_{mg}')} } +% endfor % endfor % endfor @@ -282,7 +297,6 @@ $L_AFTER_WAIT: 
ld.shared::cta.b128 resp, [work_a]; clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_have, resp; @!p_have bra $L_FIN; - // 1D grid: extract just x clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 ctaid, resp; } bra.uni $L_LOOP; diff --git a/gimmik/kernels/ptx/dmma-astream-msplit-v1.mako b/gimmik/kernels/ptx/dmma-astream-msplit-v1.mako index 1158d48..50e684e 100644 --- a/gimmik/kernels/ptx/dmma-astream-msplit-v1.mako +++ b/gimmik/kernels/ptx/dmma-astream-msplit-v1.mako @@ -1,6 +1,6 @@ <%inherit file='base'/> -.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { ${', '.join(a_u64)} }; .extern .shared .align 128 .b8 ${kname}_dynm[]; @@ -85,12 +85,12 @@ $L_AFTER_B_TMA: @pwarp_exit bra $L_EXIT; % for nt in range(nn): - add.u32 b_col_${nt}, warp_n_base, ${8 * nt}; + add.u32 b_col_${nt}, warp_n_base, ${tile_n * nt}; add.u32 b_col_${nt}, b_col_${nt}, r_div4; { .reg .u32 t; shl.b32 t, r_mod4, 1; - add.u32 c_col0_${nt}, warp_n_base, ${8 * nt}; + add.u32 c_col0_${nt}, warp_n_base, ${tile_n * nt}; add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } @@ -101,7 +101,7 @@ $L_AFTER_B_TMA: % endif % endfor - // A thread base: &Ag[0] + lane*8 + // A thread base: &Ag[0] + lane*sizeof(f64) { .reg .u64 t64, a_glb_base, lane64; mov.u64 a_glb_base, ${kname}_Ag; @@ -140,102 +140,117 @@ $L_AFTER_B_TMA: @p_this_msplit bra $L_SKIP_MS_${wm}; } { - .reg .${pftype} a_frag; + .reg .${pftype} a_frag_<${a_regs}>; % for nt in range(nn): - .reg .${pftype} b_frag_${nt}; + .reg .${pftype} b_frag_${nt}_<${b_regs}>; % endfor % for nt in range(nn): % for mt in owned_mts: - .reg .${pftype} c0_${nt}_${mt}, c1_${nt}_${mt}; + .reg .${pftype} c_${nt}_${mt}_<${c_regs}>; % endfor % endfor % for mt in owned_mts: -% if pm_runtime(mt): - .reg .pred pm_${mt}; +% for mg in range(m_groups): +% if pm_runtime(mt, mg): + .reg .pred pm_${mt}_${mg}; { .reg .u32 crow; - add.u32 crow, r_div4, ${8 * mt}; - 
setp.lt.u32 pm_${mt}, crow, ${m}; + add.u32 crow, r_div4, ${tile_m * mt + 8 * mg}; + setp.lt.u32 pm_${mt}_${mg}, crow, ${m}; } -% endif +% endif +% endfor % endfor % for nt in range(nn): % for mt in owned_mts: % if beta_zero: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% for ci in range(c_regs): + mov.${pftype} c_${nt}_${mt}_${ci}, ${fzero}; +% endfor % else: +% for mg in range(m_groups): <% - pm = f'pm_{mt}' if pm_runtime(mt) else None + pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None + c0 = f'c_{nt}_{mt}_{2*mg}' + c1 = f'c_{nt}_{mt}_{2*mg + 1}' %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; -% if needs_zero_init: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; -% endif - ${pred_emit(f'ld.weak.global.cg.{pftype} c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{wm}_{nt}_{mt}', indent=' ' * 12)} - ${pred_emit(f'ld.weak.global.cg.{pftype} c1_{nt}_{mt}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{wm}_{nt}_{mt}', indent=' ' * 12)} + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride}; +% if needs_zero_init: + mov.${pftype} ${c0}, ${fzero}; + mov.${pftype} ${c1}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} {c0}, [caddr];', pm, pvc0, pred_reg=f'p0_{wm}_{nt}_{mt}_{mg}', indent=' ' * 12)} + ${pred_emit(f'ld.weak.global.cg.{pftype} {c1}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{wm}_{nt}_{mt}_{mg}', indent=' ' * 12)} } +% endfor % endif % endfor % endfor % for ki in range(k_tiles): % for nt in range(nn): +% for kg in range(k_groups): <% pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None - k_tail = (k_rem != 0 and loop.parent.last) + k_tail = (k_rem != 0 and 
loop.parent.parent.last) needs_zero = pvb is not None or k_tail - pbrow = 'pbrow' if k_tail else None + pbrow = f'pbrow_{kg}' if k_tail else None %> { .reg .u32 baddr; - add.u32 baddr, b_thr_base, ${ki * b_smem_kiter_stride + nt * b_smem_ntile_stride}; -% if needs_zero: - mov.${pftype} b_frag_${nt}, ${fzero}; -% endif -% if k_tail: - .reg .pred pbrow; + add.u32 baddr, b_thr_base, ${ki * b_smem_kiter_stride + kg * b_smem_kgroup_stride + nt * b_smem_ntile_stride}; +% if needs_zero: + mov.${pftype} b_frag_${nt}_${kg}, ${fzero}; +% endif +% if k_tail: + .reg .pred ${pbrow}; { .reg .u32 brow; - add.u32 brow, r_mod4, ${4 * ki}; - setp.lt.u32 pbrow, brow, ${k}; + add.u32 brow, r_mod4, ${tile_k * ki + 4 * kg}; + setp.lt.u32 ${pbrow}, brow, ${k}; } -% endif - ${pred_emit(f'ld.shared.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{wm}_{ki}_{nt}', indent=' ' * 12)} +% endif + ${pred_emit(f'ld.shared.{pftype} b_frag_{nt}_{kg}, [baddr];', pbrow, pvb, pred_reg=f'pb_{wm}_{ki}_{nt}_{kg}', indent=' ' * 12)} } +% endfor % endfor % for mt in owned_mts: - ld.weak.global.${pftype} a_frag, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; +% for ai in range(a_regs): + ld.weak.global.${pftype} a_frag_${ai}, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes + 32 * ai * dwidth_i}]; +% endfor % for nt in range(nn): - mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} - {c0_${nt}_${mt}, c1_${nt}_${mt}}, - {a_frag}, - {b_frag_${nt}}, - {c0_${nt}_${mt}, c1_${nt}_${mt}}; + mma.sync.aligned.${ptx_mma_shape}.row.col.${pftype}.${pftype}.${pftype}.${pftype} + ${reg_list(f'c_{nt}_{mt}', c_regs)}, + ${reg_list('a_frag', a_regs)}, + ${reg_list(f'b_frag_{nt}', b_regs)}, + ${reg_list(f'c_{nt}_{mt}', c_regs)}; % endfor % endfor % endfor % for mt in owned_mts: % for nt in range(nn): +% for mg in range(m_groups): <% - pm = f'pm_{mt}' if pm_runtime(mt) else None + pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None pvc0 = f'pvalid_c0col_{nt}' if not 
n_col_aligned else None pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None + c0 = f'c_{nt}_{mt}_{2*mg}' + c1 = f'c_{nt}_{mt}_{2*mg + 1}' %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; - ${pred_emit(f'st.weak.global.{pftype} [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{wm}_{nt}_{mt}', indent=' ' * 12)} - ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{wm}_{nt}_{mt}', indent=' ' * 12)} + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.weak.global.{pftype} [caddr], {c0};', pm, pvc0, pred_reg=f'p0s_{wm}_{nt}_{mt}_{mg}', indent=' ' * 12)} + ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], {c1};', pm, pvc1, pred_reg=f'p1s_{wm}_{nt}_{mt}_{mg}', indent=' ' * 12)} } +% endfor % endfor % endfor } diff --git a/gimmik/kernels/ptx/dmma-astream-msplit-v2.mako b/gimmik/kernels/ptx/dmma-astream-msplit-v2.mako index 4b41f67..d22704f 100644 --- a/gimmik/kernels/ptx/dmma-astream-msplit-v2.mako +++ b/gimmik/kernels/ptx/dmma-astream-msplit-v2.mako @@ -1,6 +1,6 @@ <%inherit file='base'/> -.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { ${', '.join(a_u64)} }; .extern .shared .align 128 .b8 ${kname}_dynm[]; @@ -85,12 +85,12 @@ $L_AFTER_B_TMA: @pwarp_exit bra $L_EXIT; % for nt in range(nn): - add.u32 b_col_${nt}, warp_n_base, ${8 * nt}; + add.u32 b_col_${nt}, warp_n_base, ${tile_n * nt}; add.u32 b_col_${nt}, b_col_${nt}, r_div4; { .reg .u32 t; shl.b32 t, r_mod4, 1; - add.u32 c_col0_${nt}, warp_n_base, ${8 * nt}; + add.u32 c_col0_${nt}, warp_n_base, ${tile_n * nt}; add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } @@ -101,7 +101,7 @@ $L_AFTER_B_TMA: % endif % endfor - // A thread base: &Ag[0] + lane*8 + // A thread base: &Ag[0] + lane*sizeof(f64) { .reg .u64 t64, a_glb_base, lane64; mov.u64 
a_glb_base, ${kname}_Ag; @@ -140,96 +140,112 @@ $L_AFTER_B_TMA: @p_this_msplit bra $L_SKIP_MS_${wm}; } { - .reg .${pftype} a_frag; + .reg .${pftype} a_frag_<${a_regs}>; % for nt in range(nn): - .reg .${pftype} b_frag_${nt}; + .reg .${pftype} b_frag_${nt}_<${b_regs}>; % endfor % for nt in range(nn): % for mt in owned_mts: - .reg .${pftype} c0_${nt}_${mt}, c1_${nt}_${mt}; + .reg .${pftype} c_${nt}_${mt}_<${c_regs}>; % endfor % endfor % for mt in owned_mts: -% if pm_runtime(mt): - .reg .pred pm_${mt}; +% for mg in range(m_groups): +% if pm_runtime(mt, mg): + .reg .pred pm_${mt}_${mg}; { .reg .u32 crow; - add.u32 crow, r_div4, ${8 * mt}; - setp.lt.u32 pm_${mt}, crow, ${m}; + add.u32 crow, r_div4, ${tile_m * mt + 8 * mg}; + setp.lt.u32 pm_${mt}_${mg}, crow, ${m}; } -% endif +% endif +% endfor % endfor % for nt in range(nn): % for mt in owned_mts: % if beta_zero: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% for ci in range(c_regs): + mov.${pftype} c_${nt}_${mt}_${ci}, ${fzero}; +% endfor % else: +% for mg in range(m_groups): <% - pm = f'pm_{mt}' if pm_runtime(mt) else None - needs_zero_init = pm is not None + pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None + c0 = f'c_{nt}_{mt}_{2*mg}' + c1 = f'c_{nt}_{mt}_{2*mg + 1}' + cpair = f'{c0}, {c1}' %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; -% if needs_zero_init: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; -% endif - ${pred_emit(f'ld.weak.global.cg.v2.{pftype} {{c0_{nt}_{mt}, c1_{nt}_{mt}}}, [caddr];', pm, pred_reg=f'p01_{wm}_{nt}_{mt}', indent=' ' * 12)} + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride}; +% if pm is not None: + mov.${pftype} ${c0}, ${fzero}; + mov.${pftype} ${c1}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.v2.{pftype} {{{cpair}}}, [caddr];', pm, pred_reg=f'p01_{wm}_{nt}_{mt}_{mg}', indent=' ' * 12)} } +% endfor % 
endif % endfor % endfor % for ki in range(k_tiles): % for nt in range(nn): +% for kg in range(k_groups): <% pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None - k_tail = (k_rem != 0 and loop.parent.last) + k_tail = (k_rem != 0 and loop.parent.parent.last) needs_zero = pvb is not None or k_tail - pbrow = 'pbrow' if k_tail else None + pbrow = f'pbrow_{kg}' if k_tail else None %> { .reg .u32 baddr; - add.u32 baddr, b_thr_base, ${ki * b_smem_kiter_stride + nt * b_smem_ntile_stride}; -% if needs_zero: - mov.${pftype} b_frag_${nt}, ${fzero}; -% endif -% if k_tail: - .reg .pred pbrow; + add.u32 baddr, b_thr_base, ${ki * b_smem_kiter_stride + kg * b_smem_kgroup_stride + nt * b_smem_ntile_stride}; +% if needs_zero: + mov.${pftype} b_frag_${nt}_${kg}, ${fzero}; +% endif +% if k_tail: + .reg .pred ${pbrow}; { .reg .u32 brow; - add.u32 brow, r_mod4, ${4 * ki}; - setp.lt.u32 pbrow, brow, ${k}; + add.u32 brow, r_mod4, ${tile_k * ki + 4 * kg}; + setp.lt.u32 ${pbrow}, brow, ${k}; } -% endif - ${pred_emit(f'ld.shared.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{wm}_{ki}_{nt}', indent=' ' * 12)} +% endif + ${pred_emit(f'ld.shared.{pftype} b_frag_{nt}_{kg}, [baddr];', pbrow, pvb, pred_reg=f'pb_{wm}_{ki}_{nt}_{kg}', indent=' ' * 12)} } +% endfor % endfor % for mt in owned_mts: - ld.weak.global.${pftype} a_frag, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; +% for ai in range(a_regs): + ld.weak.global.${pftype} a_frag_${ai}, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes + 32 * ai * dwidth_i}]; +% endfor % for nt in range(nn): - mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} - {c0_${nt}_${mt}, c1_${nt}_${mt}}, - {a_frag}, - {b_frag_${nt}}, - {c0_${nt}_${mt}, c1_${nt}_${mt}}; + mma.sync.aligned.${ptx_mma_shape}.row.col.${pftype}.${pftype}.${pftype}.${pftype} + ${reg_list(f'c_{nt}_{mt}', c_regs)}, + ${reg_list('a_frag', a_regs)}, + ${reg_list(f'b_frag_{nt}', b_regs)}, + ${reg_list(f'c_{nt}_{mt}', c_regs)}; % endfor % 
endfor % endfor % for mt in owned_mts: % for nt in range(nn): +% for mg in range(m_groups): <% - pm = f'pm_{mt}' if pm_runtime(mt) else None + pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None + c0 = f'c_{nt}_{mt}_{2*mg}' + c1 = f'c_{nt}_{mt}_{2*mg + 1}' + cpair = f'{c0}, {c1}' %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; - ${pred_emit(f'st.weak.global.v2.{pftype} [caddr], {{c0_{nt}_{mt}, c1_{nt}_{mt}}};', pm, pred_reg=f'p01s_{wm}_{nt}_{mt}', indent=' ' * 12)} + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.weak.global.v2.{pftype} [caddr], {{{cpair}}};', pm, pred_reg=f'p01s_{wm}_{nt}_{mt}_{mg}', indent=' ' * 12)} } +% endfor % endfor % endfor } diff --git a/gimmik/kernels/ptx/dmma-astream-v1.mako b/gimmik/kernels/ptx/dmma-astream-v1.mako index 455496a..4ab05a9 100644 --- a/gimmik/kernels/ptx/dmma-astream-v1.mako +++ b/gimmik/kernels/ptx/dmma-astream-v1.mako @@ -1,6 +1,14 @@ <%inherit file='base'/> +/* + dmma-astream-v1 -.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { + Dense FP64 kernel using configurable warp-level DMMA tiles. The tiles of A + are precomputed and put in global memory within this compilation unit. Then + tiles of B and A are streamed from global into registers. This kernel uses + scalar loads/stores for C. 
+ */ + +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { ${', '.join(a_u64)} }; @@ -12,14 +20,16 @@ .reg .u32 warp_n_base; .reg .u64 ag_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; - .reg .${pftype} a_frag; + .reg .${pftype} a_frag_<${a_regs}>; % for nt in range(nn): .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; % if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; % endif - .reg .${pftype} b_frag_${nt}; - .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; + .reg .${pftype} b_frag_${nt}_<${b_regs}>; +% for mt in range(m_tiles): + .reg .${pftype} c_${nt}_${mt}_<${c_regs}>; +% endfor % endfor ld.param.u64 b_ptr, [_b]; @@ -44,12 +54,12 @@ @pwarp_exit bra $L_EXIT; % for nt in range(nn): - add.u32 b_col_${nt}, warp_n_base, ${8 * nt}; + add.u32 b_col_${nt}, warp_n_base, ${tile_n * nt}; add.u32 b_col_${nt}, b_col_${nt}, r_div4; { .reg .u32 t; shl.b32 t, r_mod4, 1; - add.u32 c_col0_${nt}, warp_n_base, ${8 * nt}; + add.u32 c_col0_${nt}, warp_n_base, ${tile_n * nt}; add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } @@ -60,7 +70,7 @@ % endif % endfor - // A thread base: &Ag[0] + lane*8 + // A thread base: &Ag[0] + lane*sizeof(f64) { .reg .u64 t64, a_glb_base, lane64; mov.u64 a_glb_base, ${kname}_Ag; @@ -89,92 +99,107 @@ } % for mt in range(m_tiles): -% if pm_runtime(mt): - .reg .pred pm_${mt}; +% for mg in range(m_groups): +% if pm_runtime(mt, mg): + .reg .pred pm_${mt}_${mg}; { .reg .u32 crow; - add.u32 crow, r_div4, ${8 * mt}; - setp.lt.u32 pm_${mt}, crow, ${m}; + add.u32 crow, r_div4, ${tile_m * mt + 8 * mg}; + setp.lt.u32 pm_${mt}_${mg}, crow, ${m}; } -% endif +% endif +% endfor % endfor % for nt in range(nn): % for mt in range(m_tiles): % if beta_zero: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% for ci in range(c_regs): + mov.${pftype} c_${nt}_${mt}_${ci}, ${fzero}; +% endfor % else: +% for mg in range(m_groups): <% - pm = 
f'pm_{mt}' if pm_runtime(mt) else None + pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None + c0 = f'c_{nt}_{mt}_{2*mg}' + c1 = f'c_{nt}_{mt}_{2*mg + 1}' %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; -% if needs_zero_init: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; -% endif - ${pred_emit(f'ld.weak.global.cg.{pftype} c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} - ${pred_emit(f'ld.weak.global.cg.{pftype} c1_{nt}_{mt}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride}; +% if needs_zero_init: + mov.${pftype} ${c0}, ${fzero}; + mov.${pftype} ${c1}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} {c0}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}_{mg}')} + ${pred_emit(f'ld.weak.global.cg.{pftype} {c1}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}_{mg}')} } +% endfor % endif % endfor % endfor % for ki in range(k_tiles): % for nt in range(nn): +% for kg in range(k_groups): <% pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None - k_tail = (k_rem != 0 and loop.parent.last) + k_tail = (k_rem != 0 and loop.parent.parent.last) needs_zero = pvb is not None or k_tail - pbrow = 'pbrow' if k_tail else None + pbrow = f'pbrow_{kg}' if k_tail else None %> { .reg .u64 baddr; - add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; -% if needs_zero: - mov.${pftype} b_frag_${nt}, ${fzero}; -% endif -% if k_tail: - .reg .pred pbrow; + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + kg * b_kgroup_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.${pftype} b_frag_${nt}_${kg}, ${fzero}; +% endif +% if k_tail: + .reg .pred ${pbrow}; 
{ .reg .u32 brow; - add.u32 brow, r_mod4, ${4 * ki}; - setp.lt.u32 pbrow, brow, ${k}; + add.u32 brow, r_mod4, ${tile_k * ki + 4 * kg}; + setp.lt.u32 ${pbrow}, brow, ${k}; } -% endif - ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}_{kg}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}_{kg}')} } +% endfor % endfor % for mt in range(m_tiles): - ld.weak.global.${pftype} a_frag, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; +% for ai in range(a_regs): + ld.weak.global.${pftype} a_frag_${ai}, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes + 32 * ai * dwidth_i}]; +% endfor % for nt in range(nn): - mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} - {c0_${nt}_${mt}, c1_${nt}_${mt}}, - {a_frag}, - {b_frag_${nt}}, - {c0_${nt}_${mt}, c1_${nt}_${mt}}; + mma.sync.aligned.${ptx_mma_shape}.row.col.${pftype}.${pftype}.${pftype}.${pftype} + ${reg_list(f'c_{nt}_{mt}', c_regs)}, + ${reg_list('a_frag', a_regs)}, + ${reg_list(f'b_frag_{nt}', b_regs)}, + ${reg_list(f'c_{nt}_{mt}', c_regs)}; % endfor % endfor % endfor % for mt in range(m_tiles): % for nt in range(nn): +% for mg in range(m_groups): <% - pm = f'pm_{mt}' if pm_runtime(mt) else None + pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None + c0 = f'c_{nt}_{mt}_{2*mg}' + c1 = f'c_{nt}_{mt}_{2*mg + 1}' %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; - ${pred_emit(f'st.weak.global.{pftype} [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} - ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.weak.global.{pftype} 
[caddr], {c0};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}_{mg}')} + ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], {c1};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}_{mg}')} } +% endfor % endfor % endfor diff --git a/gimmik/kernels/ptx/dmma-astream-v2.mako b/gimmik/kernels/ptx/dmma-astream-v2.mako index 4700fd0..395640e 100644 --- a/gimmik/kernels/ptx/dmma-astream-v2.mako +++ b/gimmik/kernels/ptx/dmma-astream-v2.mako @@ -1,6 +1,14 @@ <%inherit file='base'/> +/* + dmma-astream-v2 -.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { + Dense FP64 kernel using configurable warp-level DMMA tiles. The tiles of A + are precomputed and put in global memory within this compilation unit. Then + tiles of B and A are streamed from global into registers. This kernel uses + 128-bit loads/stores for C. + */ + +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { ${', '.join(a_u64)} }; @@ -12,14 +20,16 @@ .reg .u32 warp_n_base; .reg .u64 ag_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; - .reg .${pftype} a_frag; + .reg .${pftype} a_frag_<${a_regs}>; % for nt in range(nn): .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; % if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; % endif - .reg .${pftype} b_frag_${nt}; - .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; + .reg .${pftype} b_frag_${nt}_<${b_regs}>; +% for mt in range(m_tiles): + .reg .${pftype} c_${nt}_${mt}_<${c_regs}>; +% endfor % endfor ld.param.u64 b_ptr, [_b]; @@ -44,12 +54,12 @@ @pwarp_exit bra $L_EXIT; % for nt in range(nn): - add.u32 b_col_${nt}, warp_n_base, ${8 * nt}; + add.u32 b_col_${nt}, warp_n_base, ${tile_n * nt}; add.u32 b_col_${nt}, b_col_${nt}, r_div4; { .reg .u32 t; shl.b32 t, r_mod4, 1; - add.u32 c_col0_${nt}, warp_n_base, ${8 * nt}; + add.u32 c_col0_${nt}, warp_n_base, ${tile_n * nt}; add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } @@ -60,7 +70,7 @@ % endif % endfor - // A thread base: &Ag[0] + 
lane*8 + // A thread base: &Ag[0] + lane*sizeof(f64) { .reg .u64 t64, a_glb_base, lane64; mov.u64 a_glb_base, ${kname}_Ag; @@ -89,86 +99,102 @@ } % for mt in range(m_tiles): -% if pm_runtime(mt): - .reg .pred pm_${mt}; +% for mg in range(m_groups): +% if pm_runtime(mt, mg): + .reg .pred pm_${mt}_${mg}; { .reg .u32 crow; - add.u32 crow, r_div4, ${8 * mt}; - setp.lt.u32 pm_${mt}, crow, ${m}; + add.u32 crow, r_div4, ${tile_m * mt + 8 * mg}; + setp.lt.u32 pm_${mt}_${mg}, crow, ${m}; } -% endif +% endif +% endfor % endfor % for nt in range(nn): % for mt in range(m_tiles): % if beta_zero: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% for ci in range(c_regs): + mov.${pftype} c_${nt}_${mt}_${ci}, ${fzero}; +% endfor % else: +% for mg in range(m_groups): <% - pm = f'pm_{mt}' if pm_runtime(mt) else None - needs_zero_init = pm is not None + pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None + c0 = f'c_{nt}_{mt}_{2*mg}' + c1 = f'c_{nt}_{mt}_{2*mg + 1}' + cpair = f'{c0}, {c1}' %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; -% if needs_zero_init: - mov.${pftype} c0_${nt}_${mt}, ${fzero}; - mov.${pftype} c1_${nt}_${mt}, ${fzero}; -% endif - ${pred_emit(f'ld.weak.global.cg.v2.{pftype} {{c0_{nt}_{mt}, c1_{nt}_{mt}}}, [caddr];', pm, pred_reg=f'p01_{nt}_{mt}')} + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride}; +% if pm is not None: + mov.${pftype} ${c0}, ${fzero}; + mov.${pftype} ${c1}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.v2.{pftype} {{{cpair}}}, [caddr];', pm, pred_reg=f'p01_{nt}_{mt}_{mg}')} } +% endfor % endif % endfor % endfor % for ki in range(k_tiles): % for nt in range(nn): +% for kg in range(k_groups): <% pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None - k_tail = (k_rem != 0 and loop.parent.last) + k_tail = (k_rem != 0 and loop.parent.parent.last) needs_zero = pvb is not None or k_tail - pbrow = 'pbrow' if 
k_tail else None + pbrow = f'pbrow_{kg}' if k_tail else None %> { .reg .u64 baddr; - add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; -% if needs_zero: - mov.${pftype} b_frag_${nt}, ${fzero}; -% endif -% if k_tail: - .reg .pred pbrow; + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + kg * b_kgroup_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.${pftype} b_frag_${nt}_${kg}, ${fzero}; +% endif +% if k_tail: + .reg .pred ${pbrow}; { .reg .u32 brow; - add.u32 brow, r_mod4, ${4 * ki}; - setp.lt.u32 pbrow, brow, ${k}; + add.u32 brow, r_mod4, ${tile_k * ki + 4 * kg}; + setp.lt.u32 ${pbrow}, brow, ${k}; } -% endif - ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}_{kg}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}_{kg}')} } +% endfor % endfor % for mt in range(m_tiles): - ld.weak.global.${pftype} a_frag, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; +% for ai in range(a_regs): + ld.weak.global.${pftype} a_frag_${ai}, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes + 32 * ai * dwidth_i}]; +% endfor % for nt in range(nn): - mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} - {c0_${nt}_${mt}, c1_${nt}_${mt}}, - {a_frag}, - {b_frag_${nt}}, - {c0_${nt}_${mt}, c1_${nt}_${mt}}; + mma.sync.aligned.${ptx_mma_shape}.row.col.${pftype}.${pftype}.${pftype}.${pftype} + ${reg_list(f'c_{nt}_{mt}', c_regs)}, + ${reg_list('a_frag', a_regs)}, + ${reg_list(f'b_frag_{nt}', b_regs)}, + ${reg_list(f'c_{nt}_{mt}', c_regs)}; % endfor % endfor % endfor % for mt in range(m_tiles): % for nt in range(nn): +% for mg in range(m_groups): <% - pm = f'pm_{mt}' if pm_runtime(mt) else None + pm = f'pm_{mt}_{mg}' if pm_runtime(mt, mg) else None + c0 = f'c_{nt}_{mt}_{2*mg}' + c1 = f'c_{nt}_{mt}_{2*mg + 1}' + cpair = f'{c0}, {c1}' %> { .reg .u64 caddr; - add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * 
c_ntile_stride}; - ${pred_emit(f'st.weak.global.v2.{pftype} [caddr], {{c0_{nt}_{mt}, c1_{nt}_{mt}}};', pm, pred_reg=f'p01s_{nt}_{mt}')} + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + mg * c_mgroup_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.weak.global.v2.{pftype} [caddr], {{{cpair}}};', pm, pred_reg=f'p01s_{nt}_{mt}_{mg}')} } +% endfor % endfor % endfor diff --git a/gimmik/kernels/ptx/dmma-steal-ws.mako b/gimmik/kernels/ptx/dmma-steal-ws.mako index 21ad486..0e97cd5 100644 --- a/gimmik/kernels/ptx/dmma-steal-ws.mako +++ b/gimmik/kernels/ptx/dmma-steal-ws.mako @@ -141,7 +141,7 @@ $L_WAIT_BRDY: } % for mt in range(m_tiles): <% - row_tail = (m_pad > m) and ((mt + 1) * 8 > m) + row_tail = pm_runtime(mt) %> % if row_tail: .reg .pred p_row_${mt}; diff --git a/gimmik/kernels/ptx/dmma-stride-ws.mako b/gimmik/kernels/ptx/dmma-stride-ws.mako index 0a34b44..87b87d6 100644 --- a/gimmik/kernels/ptx/dmma-stride-ws.mako +++ b/gimmik/kernels/ptx/dmma-stride-ws.mako @@ -10,11 +10,11 @@ mov.u64 a_glb, ${kname}_Ag; cvta.to.global.u64 a_glb, a_glb; @p_warp_lead cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes - [a_smem], [a_glb], ${8 * 32 * m_tiles * k_tiles}, [tma_mbar]; + [a_smem], [a_glb], ${a_elems * dwidth_i}, [tma_mbar]; @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes [b1_smem], [bdesc_addr, {n_start0, 0}], [tma_mbar]; @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 - [tma_mbar], ${b_tile_bytes + 8 * 32 * m_tiles * k_tiles}; + [tma_mbar], ${b_tile_bytes + a_elems * dwidth_i}; bar.warp.sync 0xffffffff; .reg .b64 state; .reg .pred p1; @@ -57,7 +57,7 @@ $L_WAIT_BRDY: .reg .b32 b_thr_a_${nt}; { .reg .b32 bcol_g, t_off; - add.u32 bcol_g, base_bcol, ${8 * nt}; + add.u32 bcol_g, base_bcol, ${tile_n * nt}; shl.b32 t_off, bcol_g, 3; add.u32 b_thr_a_${nt}, b_sm_a, t_off; } @@ -83,49 +83,51 @@ $L_WAIT_BRDY: // Zero accumulators % for mt in range(m_tiles): % for nt in range(nn): - .reg .${pftype} d_x_${mt}_${nt}, 
d_y_${mt}_${nt}; - mov.${pftype} d_x_${mt}_${nt}, ${fzero}; - mov.${pftype} d_y_${mt}_${nt}, ${fzero}; + .reg .${pftype} d_${mt}_${nt}_<${c_regs}>; +% for ci in range(c_regs): + mov.${pftype} d_${mt}_${nt}_${ci}, ${fzero}; +% endfor % endfor % endfor - .reg .${pftype} a_f; -% for mt in range(m_tiles): -% for kt in range(k_tiles): + .reg .${pftype} a_frag_<${a_regs}>; +% for nt in range(nn): + .reg .${pftype} b_frag_${nt}_<${b_regs}>; +% endfor +% for kt in range(k_tiles): +% for nt in range(nn): +% for kg in range(k_groups): <% - k_tail = (k_rem != 0 and loop.last) + k_tail = (k_rem != 0 and loop.parent.parent.last) + pbrow = f'pbrow_{nt}_{kg}' if k_tail else None %> { - .reg .b32 a_a; - add.u32 a_a, a_thr_a, ${(32 * kt + 32 * mt * k_tiles) * dwidth_i}; - ld.shared.${pftype} a_f, [a_a]; -% if k_tail: - .reg .pred pbrow_${mt}_${kt}; - { - .reg .b32 brow; - add.u32 brow, base_brow, ${4 * kt}; - setp.lt.u32 pbrow_${mt}_${kt}, brow, ${k}; - } -% endif -% for nt in range(nn): - { - .reg .b32 b_a, b_row; - .reg .${pftype} b_f; - add.u32 b_row, base_brow, ${4 * kt}; - mul.lo.u32 b_row, b_row, ${n_per_cta * dwidth_i}; - add.u32 b_a, b_thr_a_${nt}, b_row; + .reg .b32 b_a, b_row, b_off; + add.u32 b_row, base_brow, ${tile_k * kt + 4 * kg}; + mul.lo.u32 b_off, b_row, ${n_per_cta * dwidth_i}; + add.u32 b_a, b_thr_a_${nt}, b_off; % if k_tail: - mov.${pftype} b_f, ${fzero}; - @pbrow_${mt}_${kt} ld.shared.${pftype} b_f, [b_a]; + .reg .pred ${pbrow}; + setp.lt.u32 ${pbrow}, b_row, ${k}; + mov.${pftype} b_frag_${nt}_${kg}, ${fzero}; + @${pbrow} ld.shared.${pftype} b_frag_${nt}_${kg}, [b_a]; % else: - ld.shared.${pftype} b_f, [b_a]; + ld.shared.${pftype} b_frag_${nt}_${kg}, [b_a]; % endif - mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} - {d_x_${mt}_${nt}, d_y_${mt}_${nt}}, {a_f}, {b_f}, - {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; - } -% endfor } +% endfor +% endfor +% for mt in range(m_tiles): +% for ai in range(a_regs): + ld.shared.${pftype} a_frag_${ai}, 
[a_thr_a + ${(mt * k_tiles + kt) * frag_stride_bytes + 32 * ai * dwidth_i}]; +% endfor +% for nt in range(nn): + mma.sync.aligned.${ptx_mma_shape}.row.col.${pftype}.${pftype}.${pftype}.${pftype} + ${reg_list(f'd_{mt}_{nt}', c_regs)}, + ${reg_list('a_frag', a_regs)}, + ${reg_list(f'b_frag_{nt}', b_regs)}, + ${reg_list(f'd_{mt}_{nt}', c_regs)}; +% endfor % endfor % endfor @@ -140,31 +142,33 @@ $L_WAIT_BRDY: add.u64 c_thr_glob_base, c_glob_addr, thr_byte_off; } % for mt in range(m_tiles): +% for mg in range(m_groups): <% - row_tail = (m_pad > m) and ((mt + 1) * 8 > m) + row_tail = pm_runtime(mt, mg) %> -% if row_tail: - .reg .pred p_row_${mt}; +% if row_tail: + .reg .pred p_row_${mt}_${mg}; { .reg .b32 crow; - add.u32 crow, base_crow, ${8 * mt}; - setp.lt.u32 p_row_${mt}, crow, ${m}; + add.u32 crow, base_crow, ${tile_m * mt + 8 * mg}; + setp.lt.u32 p_row_${mt}_${mg}, crow, ${m}; } -% endif -% for nt in range(nn): +% endif +% for nt in range(nn): { .reg .pred p_st; .reg .u32 g_ccol; - add.u32 g_ccol, base_ccol, ${8 * nt}; + add.u32 g_ccol, base_ccol, ${tile_n * nt}; add.u32 g_ccol, g_ccol, n_start_curr; setp.lt.u32 p_st, g_ccol, ${n}; -% if row_tail: - and.pred p_st, p_st, p_row_${mt}; -% endif +% if row_tail: + and.pred p_st, p_st, p_row_${mt}_${mg}; +% endif .reg .u64 _c_addr; - add.u64 _c_addr, c_thr_glob_base, ${(8 * mt * ldc + 8 * nt) * dwidth_i}; - @p_st st.weak.global.v2.${pftype} [_c_addr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + add.u64 _c_addr, c_thr_glob_base, ${((tile_m * mt + 8 * mg) * ldc + tile_n * nt) * dwidth_i}; + @p_st st.weak.global.v2.${pftype} [_c_addr], {d_${mt}_${nt}_${2*mg}, d_${mt}_${nt}_${2*mg + 1}}; } +% endfor % endfor % endfor % else: @@ -176,15 +180,17 @@ $L_WAIT_CSTORE: @!p1 bra.uni $L_WAIT_CSTORE; } - // Vector-store {d_x, d_y} pairs to csmem. M-tail / N-tail OOB rows + // Vector-store accumulator pairs to csmem. M-tail / N-tail OOB rows // are dropped by the C tensor map. 
% for mt in range(m_tiles): -% for nt in range(nn): +% for mg in range(m_groups): +% for nt in range(nn): { .reg .b32 csaddr; - add.u32 csaddr, c_thr_smem, ${mt * c_mtile_smem_stride + nt * c_ntile_smem_stride}; - st.shared.v2.${pftype} [csaddr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + add.u32 csaddr, c_thr_smem, ${mt * c_mtile_smem_stride + mg * c_mgroup_smem_stride + nt * c_ntile_smem_stride}; + st.shared.v2.${pftype} [csaddr], {d_${mt}_${nt}_${2*mg}, d_${mt}_${nt}_${2*mg + 1}}; } +% endfor % endfor % endfor % endif @@ -277,7 +283,7 @@ $L_SKIP_NEXT_B_READY: $L_AFTER_DATA: -.global .align 16 .b64 ${kname}_Ag[${32 * m_tiles * k_tiles}] = { +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { ${', '.join(a_u64)} }; .extern .shared .align 128 .b8 ${kname}_dynm[]; diff --git a/gimmik/ptx.py b/gimmik/ptx.py index 75f9697..9e480cb 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -15,22 +15,13 @@ class PTXMatMul(MatMul): 'dynamic_shared': 0 } - # Map Supported CC -> Minimum PTX version + # Map explicitly supported CC to minimum PTX version PTX_SM = {(8, 0): (7, 0), (9, 0): (8, 6), (10, 0): (8, 7), (10, 3): (8, 7), (12, 0): (8, 7), (12, 1): (8, 7)} - PTX_TEMPLATE_FAMILY = { - 'cstream': 'sparse', - 'bstream': 'sparse', - 'bstream-msplit': 'sparse', - 'cstream-ksplit': 'sparse', - 'cstream-w2': 'sparse', - 'dmma-astream': 'dense', - 'dmma-astream-msplit': 'dense', - 'dmma-asmem': 'dense', - 'dmma-steal-ws': 'dense', - 'dmma-stride-ws': 'dense', - } + DEFAULT_CFG = 'kernels/ptx/config/default.json' + FZERO = {'float': '0f00000000', 'double': '0d0000000000000000'} + PFTYPE = {'float': 'f32', 'double': 'f64'} def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -38,10 +29,11 @@ def __init__(self, *args, **kwargs): @classmethod def is_sparse_suitable(cls, arr, cc): + cc = cc or (0, 0) nnz = np.count_nonzero(arr) nuq = len(np.unique(np.abs(arr))) density = nnz / arr.size - return ((nuq <= 28) or (density <= 0.15)) and cc in cls.PTX_SM + return ((nuq <= 28) or 
(density <= 0.15)) and cc >= (7, 0) @classmethod def is_dense_suitable(cls, arr, cc): @@ -58,129 +50,47 @@ def _kernel_generators(self, dtype, dsize, *, compute_capability=None, cc = compute_capability or (0, 0) smem_info = smem_info or (48*1024, 48*1024) config = self._cc_config(cc) + if cc in self.PTX_SM: + target_cc = cc + ptx = self.PTX_SM[cc] + else: + target_cc = tuple(config['cc']) + ptx = tuple(config['ptx']) - for kernel_cfg in config['kernels']: - if not self._usable_config(kernel_cfg, dtype, cc, smem_info): - continue + cfgs = config['kernels'] + cfg = [k for k in cfgs if self._usable_config(k, dtype, cc, smem_info)] - prepared = self._get_render_args( - kernel_cfg, dtype, dsize, cc, smem_info, tuple(config['ptx']) - ) - if prepared is not None: + for k in cfg: + if prepared := self._get_render_args( + k, dtype, dsize, target_cc, smem_info, ptx + ): yield prepared - def render_config(self, kernel_cfg, dtype, dsize, *, kname='gimmik_mm', - compute_capability=None, smem_info=None, config=None): - cc = compute_capability or (0, 0) + def render_config(self, kernel_cfg, dtype, dsize, cc, ptx, *, + kname='gimmik_mm', smem_info=None): smem_info = smem_info or (48*1024, 48*1024) - config = config or self._cc_config(cc) if not self._usable_config(kernel_cfg, dtype, cc, smem_info): return None prepared = self._get_render_args( - kernel_cfg, dtype, dsize, cc, smem_info, tuple(config['ptx']) + kernel_cfg, dtype, dsize, cc, smem_info, ptx ) if prepared is None: return None - tplname, exargs, exmeta = prepared + + tpl, exargs, exmeta = prepared args = self._base_template_args(dtype, kname) | exargs meta = self.basemeta | exmeta - meta['tplname'] = tplname + meta['tplname'] = tpl self._process_meta(meta) - src = self._render_kernel(dtype, tplname, args) + src = self._render_kernel(dtype, tpl, args) return src, args, meta - def _cc_config(self, cc): - cc = cc or (0, 0) - if cc not in self._config_cache: - cfgname = f'sm{cc[0]}{cc[1]}.json' - paths = 
[f'kernels/ptx/config/{cfgname}', - 'kernels/ptx/config/default.json'] - - cfg = None - for path in paths: - try: - cfgdir = pkgutil.get_data('gimmik', path) - cfg = json.loads(cfgdir.decode('utf-8')) - break - except FileNotFoundError: - continue - except json.JSONDecodeError as e: - raise ValueError(f'{path}: invalid JSON: {exc}') from e - - if cfg is None: - raise ValueError('PTX default kernel config is missing') - self._config_cache[cc] = cfg - return self._config_cache[cc] - - def _matmul_stats(self, dtype, cc, smem_info): - nnz = int(np.count_nonzero(self.A)) - return { - 'dtype': dtype, - 'm': self.m, - 'k': self.k, - 'n': self.n, - 'beta': self.beta, - 'beta_zero': self.beta == 0, - 'aligne': self.aligne, - 'nnz': nnz, - 'density': nnz / self.A.size, - 'unique_abs': int(len(np.unique(np.abs(self.A)))), - 'k_used': len(self.bix), - 'cc': list(cc), - 'smem_static': smem_info[0], - 'smem_dynamic': smem_info[1], - } - - def _eval_condition(self, condition, stats): - if 'all' in condition: - return all(self._eval_condition(c, stats) for c in condition['all']) - if 'any' in condition: - return any(self._eval_condition(c, stats) for c in condition['any']) - if 'not' in condition: - return not self._eval_condition(condition['not'], stats) - - value = stats[condition['field']] - op = next(k for k in condition if k != 'field') - expected = condition[op] - - return { - 'eq': lambda: value == expected, - 'ne': lambda: value != expected, - 'lt': lambda: value is not None and value < expected, - 'lte': lambda: value is not None and value <= expected, - 'gt': lambda: value is not None and value > expected, - 'gte': lambda: value is not None and value >= expected, - 'in': lambda: value in expected, - 'is_null': lambda: value is None, - 'is_not': lambda: value is not None, - 'divisible_by': lambda: value is not None and value % expected == 0, - 'is_null_or_divisible_by': lambda: (value is None - or value % expected == 0), - }[op]() - - def _usable_config(self, kernel_cfg, 
dtype, cc, smem_info): - tpl = kernel_cfg['template'] - family = self.PTX_TEMPLATE_FAMILY[tpl] - - if family == 'sparse' and not self.is_sparse_suitable(self.A, cc): - return False - elif (family == 'dense' - and (self.n is None or not self.is_dense_suitable(self.A, cc))): - return False - - condition = kernel_cfg.get('conditions') - if condition is None: - return True - else: - stats = self._matmul_stats(dtype, cc, smem_info) - return self._eval_condition(condition, stats) - - def _get_render_args(self, kernel_cfg, dtype, dsize, cc, smem_info, - ptx): + def _get_render_args(self, kernel_cfg, dtype, dsize, cc, smem_info, ptx): tpl = kernel_cfg['template'] + family = kernel_cfg['family'] block = tuple(kernel_cfg['block']) width = kernel_cfg['width'] params = kernel_cfg.get('params', {}) @@ -189,14 +99,14 @@ def _get_render_args(self, kernel_cfg, dtype, dsize, cc, smem_info, 'cc': cc, 'smem_info': smem_info, 'pred_emit': self._pred_emit, - 'pftype': 'f32' if dtype == 'float' else 'f64', + 'pftype': self.PFTYPE[dtype], 'dwidth_i': dsize, - 'fzero': ('0f00000000' if dtype == 'float' - else '0d0000000000000000'), + 'fzero': self.FZERO[dtype], 'beta_zero': self.beta == 0, - 'mbar_maxwait': '0x989680', + 'mbar_maxwait': hex(10000000), 'use_cpasync': cc >= (8, 0), 'width': width, + 'reg_list': self._reg_list, } base_meta = { 'block': block, @@ -204,27 +114,22 @@ def _get_render_args(self, kernel_cfg, dtype, dsize, cc, smem_info, 'desc': kernel_cfg['descriptor'], } - if self.PTX_TEMPLATE_FAMILY[tpl] == 'sparse': - cfg = self._sparse_args(tpl, params, block, dtype, dsize, - base_args, base_meta) - elif self.PTX_TEMPLATE_FAMILY[tpl] == 'dense': - if tpl in {'dmma-steal-ws', 'dmma-stride-ws'}: - cfg = self._dense_ws_args(kernel_cfg, params, smem_info, + match family: + case 'sparse': + cfg = self._sparse_args(tpl, params, block, dtype, dsize, + base_args, base_meta) + case 'dense': + cfg = self._dense_args(kernel_cfg, params, cc, smem_info, + base_args, base_meta) + case 
'dense-ws': + cfg = self._dense_ws_args(kernel_cfg, params, cc, smem_info, base_args, base_meta) - elif tpl == 'dmma-astream-msplit': - cfg = self._dense_astream_msplit_args(kernel_cfg, params, - smem_info, base_args, - base_meta) - else: - cfg = self._dense_args(kernel_cfg, params, base_args, - base_meta) - else: - raise ValueError(f'Unknown PTX template family for {tpl}') + case _: + raise ValueError(f'Unknown PTX template family for {tpl}') return cfg - def _sparse_args(self, tpl, params, block, dtype, dsize, args, - meta): + def _sparse_args(self, tpl, params, block, dtype, dsize, args, meta): blockx = block[0] args |= {'has_zero_rows': bool(self.has_zero_rows), 'row_nz': [[(kx, self.A[j, kx]) for kx in range(self.k) @@ -248,32 +153,69 @@ def _sparse_args(self, tpl, params, block, dtype, dsize, args, args['blockx'] = blockx return tpl, args, meta - def _dense_args(self, kernel_cfg, params, args, meta): + def _dense_args(self, kernel_cfg, params, cc, smem_info, args, meta): + base_tpl = kernel_cfg['template'] nn = params['nn'] warps = params['warps'] + tile = kernel_cfg['tile'] vector_width = kernel_cfg['vector_width'] - setup = self._dense_common( - nn, warps, bool(params.get('block_stealing', False)), - vector_width - ) + setup = self._dense_common(nn, warps, tile, cc, vector_width) if setup is None: return None - tpl = f"{kernel_cfg['template']}-v{vector_width}" + tpl = f'{base_tpl}-v{vector_width}' args |= setup + if base_tpl == 'dmma-asmem': + args |= { + 'a_copy_threads': 32 * warps, + 'block_stealing': bool(params.get('block_stealing', False)), + } meta['grid'] = (-(-self.n // setup['n_per_cta']), 1, 1) + if (msplit := params.get('msplit')) is None: + return tpl, args, meta + + n_per_cta = setup['n_per_cta'] + k_pad = setup['k_tiles'] * setup['tile_k'] + b_tile_bytes = k_pad * n_per_cta * args['dwidth_i'] + + offsets, dynm_total_bytes = self._dsmem_alloc( + [('b', b_tile_bytes)], ('tma',) + ) + if dynm_total_bytes > smem_info[1]: + return None + + args |= 
{ + 'msplit': msplit, + 'b_tile_bytes': b_tile_bytes, + 'b_smem_kiter_stride': (setup['tile_k'] * n_per_cta + * args['dwidth_i']), + 'b_smem_kgroup_stride': 4 * n_per_cta * args['dwidth_i'], + 'b_smem_ntile_stride': setup['tile_n'] * args['dwidth_i'], + 'blockx_total': 32 * warps * msplit, + } | offsets + meta |= { + 'ws_b_tile': (n_per_cta, k_pad), + 'dynamic_shared': dynm_total_bytes, + } + return tpl, args, meta - def _dense_common(self, nn, warps_per_cta, block_steal, - vector_width=None): + def _dense_common(self, nn, warps_per_cta, tile, cc, vector_width=None): + tile_m, tile_n, tile_k = tile['m'], tile['n'], tile['k'] + ptx_shape = f'm{tile_m}n{tile_n}k{tile_k}' + + m_groups, k_groups = tile_m // 8, tile_k // 4 + a_regs = m_groups * k_groups + b_regs = k_groups + c_regs = 2 * m_groups + a = self.A m, k = a.shape - m_tiles = (m + 7) // 8 - k_tiles = (k + 3) // 4 - k_rem = k % 4 - n_per_warp = 8 * nn + m_tiles, k_tiles = -(-m // tile_m), -(-k // tile_k) + k_rem = k % tile_k + n_per_warp = tile_n * nn n_per_cta = warps_per_cta * n_per_warp if n_per_cta > self.n: @@ -284,64 +226,78 @@ def _dense_common(self, nn, warps_per_cta, block_steal, or self.n % n_per_warp)): return None - # A in DMMA-fragment layout: lane l -> A[mt*8 + l//4][kt*4 + l%4] - # i.e. an (m_tiles, k_tiles) grid of row-major 8x4 tiles, packed as - # uint64 - a_pad = np.zeros((m_tiles*8, k_tiles*4)) + # A in DMMA-fragment layout, packed in PTX A-operand register order. + # This will handle 8x8x4 as well as additional sm90 sizes. 
+ a_pad = np.zeros((m_tiles*tile_m, k_tiles*tile_k), dtype=a.dtype) a_pad[:m, :k] = a - tiles = a_pad.reshape(m_tiles, 8, k_tiles, 4).swapaxes(1, 2) - a_u64 = [f'0x{u:016x}' for u in tiles.view(np.uint64).ravel()] + tile_shape = m_tiles, m_groups, 8, k_tiles, k_groups, 4 + tile_order = 0, 3, 4, 1, 2, 5 + a_tiles = a_pad.reshape(*tile_shape).transpose(*tile_order).ravel() + a_u64 = [f'0x{u:016x}' for u in a_tiles.view(np.uint64)] # Predicate-elision flags n_col_aligned = (self.n is not None and self.n % n_per_warp == 0) - def pm_runtime(mt): - return (mt + 1) * 8 > m + def pm_runtime(mt, mg=0): + return mt*tile_m + 8*(mg + 1) > m return { - 'warps_per_cta': warps_per_cta, + 'tile_m': tile_m, + 'tile_n': tile_n, + 'tile_k': tile_k, + 'ptx_mma_shape': ptx_shape, + 'm_groups': m_groups, + 'k_groups': k_groups, + 'a_regs': a_regs, + 'b_regs': b_regs, + 'c_regs': c_regs, + 'a_elems': m_tiles * k_tiles * tile_m * tile_k, 'nn': nn, 'm_tiles': m_tiles, 'k_tiles': k_tiles, 'k_rem': k_rem, - 'm_pad': m_tiles * 8, - 'k_pad': k_tiles * 4, 'a_u64': a_u64, 'n_per_warp': n_per_warp, 'n_per_cta': n_per_cta, - 'frag_stride_bytes': 32 * 8, - 'b_kiter_stride': 4 * (self.ldb or 0) * 8, - 'b_ntile_stride': 8 * 8, - 'c_mtile_stride': 8 * (self.ldc or 0) * 8, - 'c_ntile_stride': 8 * 8, + 'frag_stride_bytes': 8 * tile_m * tile_k, + 'b_kiter_stride': 8 * tile_k * (self.ldb or 0), + 'b_kgroup_stride': 32 * (self.ldb or 0), + 'b_ntile_stride': 8 * tile_n, + 'c_mtile_stride': 8 * tile_m * (self.ldc or 0), + 'c_mgroup_stride': 64 * (self.ldc or 0), + 'c_ntile_stride': 8 * tile_n, 'n_col_aligned': n_col_aligned, 'pm_runtime': pm_runtime, - 'block_stealing': block_steal, } - def _dense_ws_args(self, kernel_cfg, params, smem_info, args, meta): + def _dense_ws_args(self, kernel_cfg, params, cc, smem_info, args, meta): dynamic_max = smem_info[1] tpl = kernel_cfg['template'] nn = params['nn'] + tile = kernel_cfg['tile'] warp_map = kernel_cfg['warp_map'] match tpl: case 'dmma-steal-ws': - block_steal, 
service_warps = True, 2 + if (tile['m'], tile['n'], tile['k']) != (8, 8, 4): + return None + service_warps = 2 case 'dmma-stride-ws': - block_steal, service_warps = False, 1 + service_warps = 1 case _: - raise ValueError(f'Unknown dense warp-specialized template ' + raise ValueError('Unknown dense warp-specialized template ' f'{tpl}') n_comp_warps = warp_map['compute_count'] - setup = self._dense_common(nn, n_comp_warps, block_steal) + setup = self._dense_common(nn, n_comp_warps, tile, cc) if setup is None: return None n_per_cta = setup['n_per_cta'] - b_tile_bytes = setup['k_pad'] * n_per_cta * 8 - c_tile_bytes = setup['m_pad'] * n_per_cta * 8 - a_bytes = setup['m_tiles'] * setup['k_tiles'] * 32 * 8 + m_pad = setup['m_tiles'] * setup['tile_m'] + k_pad = setup['k_tiles'] * setup['tile_k'] + b_tile_bytes = 8 * k_pad * n_per_cta + c_tile_bytes = 8 * m_pad * n_per_cta + a_bytes = 8 * setup['a_elems'] regions = [('b1', b_tile_bytes), ('b2', b_tile_bytes), ('c', c_tile_bytes), ('a', a_bytes)] ws_setup = { @@ -350,8 +306,9 @@ def _dense_ws_args(self, kernel_cfg, params, smem_info, args, meta): 'prod_warp': warp_map['producer'], 'comp_threads': 32 * n_comp_warps, 'b_tile_bytes': b_tile_bytes, - 'c_mtile_smem_stride': 8 * n_per_cta * 8, - 'c_ntile_smem_stride': 8 * 8, + 'c_mtile_smem_stride': 8 * setup['tile_m'] * n_per_cta, + 'c_mgroup_smem_stride': 64 * n_per_cta, + 'c_ntile_smem_stride': 8 * setup['tile_n'], } match tpl: @@ -380,54 +337,106 @@ def _dense_ws_args(self, kernel_cfg, params, smem_info, args, meta): args |= setup | ws_setup | offsets meta |= { 'grid': grid, - 'ws_b_tile': (n_per_cta, setup['k_pad']), + 'ws_b_tile': (n_per_cta, k_pad), 'dynamic_shared': dynm_total_bytes, } if self.beta != 0: - meta['ws_out_tile'] = (n_per_cta, setup['m_pad']) + meta['ws_out_tile'] = (n_per_cta, m_pad) return tpl, args, meta - def _dense_astream_msplit_args(self, kernel_cfg, params, smem_info, args, - meta): - dynamic_max = smem_info[1] - nn = params.get('nn') - warps = 
params.get('warps') - msplit = params.get('msplit') - vector_width = kernel_cfg.get('vector_width') - setup = self._dense_common(nn, warps, False, vector_width) - if setup is None: - return None + def _usable_config(self, kernel_cfg, dtype, cc, smem_info): + family = kernel_cfg['family'] - n_per_cta = setup['n_per_cta'] + if family == 'sparse' and not self.is_sparse_suitable(self.A, cc): + return False + elif (family in {'dense', 'dense-ws'} + and (dtype != 'double' or self.n is None + or not self.is_dense_suitable(self.A, cc))): + return False - b_tile_bytes = setup['k_pad'] * n_per_cta * args['dwidth_i'] - regions = [('b', b_tile_bytes)] - msplit_setup = { - 'msplit': msplit, - 'm_tiles_per_group': -(-setup['m_tiles'] // msplit), - 'b_tile_bytes': b_tile_bytes, - 'b_smem_kiter_stride': 4 * n_per_cta * args['dwidth_i'], - 'b_smem_ntile_stride': 8 * args['dwidth_i'], - 'blockx_total': 32 * warps * msplit, - } + condition = kernel_cfg.get('conditions') + if condition is None: + return True + else: + stats = self._matmul_stats(dtype, cc, smem_info) + return self._eval_condition(condition, stats) - offsets, dynm_total_bytes = self._dsmem_alloc(regions, ('tma',)) - if dynm_total_bytes > dynamic_max: - return None + def _cc_config(self, cc): + cc = cc or (0, 0) + if cc not in self._config_cache: + path = f'kernels/ptx/config/sm{cc[0]}{cc[1]}.json' - args |= setup | msplit_setup | offsets - meta |= { - 'grid': (-(-self.n // n_per_cta), 1, 1), - 'ws_b_tile': (n_per_cta, setup['k_pad']), - 'dynamic_shared': dynm_total_bytes, + try: + cfgdir = pkgutil.get_data('gimmik', path) + except FileNotFoundError: + cfgdir = pkgutil.get_data('gimmik', self.DEFAULT_CFG) + cfg = json.loads(cfgdir.decode('utf-8')) + + self._config_cache[cc] = cfg + return self._config_cache[cc] + + def _matmul_stats(self, dtype, cc, smem_info): + nnz = np.count_nonzero(self.A) + return { + 'dtype': dtype, + 'm': self.m, + 'k': self.k, + 'n': self.n, + 'beta': self.beta, + 'beta_zero': self.beta == 0, + 
'aligne': self.aligne, + 'nnz': nnz, + 'density': nnz / self.A.size, + 'unique_abs': len(np.unique(np.abs(self.A))), + 'k_used': len(self.bix), + 'cc': list(cc), + 'smem_static': smem_info[0], + 'smem_dynamic': smem_info[1], } - tpl = f"{kernel_cfg['template']}-v{vector_width}" - return tpl, args, meta + def _eval_condition(self, condition, stats): + if 'all' in condition: + return all(self._eval_condition(c, stats) for c in condition['all']) + if 'any' in condition: + return any(self._eval_condition(c, stats) for c in condition['any']) + if 'not' in condition: + return not self._eval_condition(condition['not'], stats) + + value = stats[condition['field']] + op = next(k for k in condition if k != 'field') + expected = condition[op] + + match op: + case 'eq': + return value == expected + case 'ne': + return value != expected + case 'lt': + return value is not None and value < expected + case 'lte': + return value is not None and value <= expected + case 'gt': + return value is not None and value > expected + case 'gte': + return value is not None and value >= expected + case 'in': + return value in expected + case 'is_null': + return value is None + case 'is_not': + return value is not None + case 'divisible_by': + return value is not None and value % expected == 0 + case 'is_null_or_divisible_by': + return (value is None or value % expected == 0) + case _: + raise ValueError(f'op `{op}` not supported') @staticmethod def _dsmem_alloc(regions, mbars, align=16): + # For a set of regions and mbars and there sizes, work out dynamic + # shared memory pointers offset for a given alignemnt. 
out, off = {}, 0 for name, size in regions: off = (off + align - 1) & ~(align - 1) @@ -440,7 +449,13 @@ def _dsmem_alloc(regions, mbars, align=16): return out, total @staticmethod - def _pred_emit(instr, *preds, pred_reg=None, indent=' ' * 8): + def _reg_list(prefix, n): + regs = ', '.join(f'{prefix}_{i}' for i in range(n)) + return f'{{{regs}}}' + + @staticmethod + def _pred_emit(instr, *preds, pred_reg=None, indent=8 * ' '): + # Handle whether an instruction needs a predicate or not actual = [p for p in preds if p is not None] if not actual: return instr diff --git a/setup.py b/setup.py index 2d4545b..3e41c60 100755 --- a/setup.py +++ b/setup.py @@ -7,8 +7,8 @@ # Python version -if sys.version_info[:2] < (3, 9): - print('GiMMiK requires Python 3.9 or newer') +if sys.version_info[:2] < (3, 10): + print('GiMMiK requires Python 3.10 or newer') sys.exit(-1) # GiMMiK version @@ -34,7 +34,6 @@ # Info classifiers = [ 'License :: OSI Approved :: BSD License', - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Topic :: Scientific/Engineering' @@ -56,6 +55,7 @@ setup(name='gimmik', version=version, + python_requires='>=3.10', # Packages packages=['gimmik'], From 1acb3b58662b691f9f5758553e4e691cd40c699e Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Wed, 24 Jun 2026 02:22:43 -0700 Subject: [PATCH 20/21] FP32 configs and kernels --- gimmik/kernels/ptx/bstream-msplit-v2.mako | 180 +++++++++++++ gimmik/kernels/ptx/config/sm100_fp32.json | 249 ++++++++++++++++++ .../config/{sm100.json => sm100_fp64.json} | 111 ++++++++ gimmik/kernels/ptx/config/sm80.json | 88 ------- gimmik/kernels/ptx/config/sm80_fp32.json | 222 ++++++++++++++++ gimmik/kernels/ptx/config/sm80_fp64.json | 222 ++++++++++++++++ gimmik/kernels/ptx/config/sm90_fp32.json | 233 ++++++++++++++++ .../ptx/config/{sm90.json => sm90_fp64.json} | 111 ++++++++ gimmik/kernels/ptx/cstream-ksplit-v2.mako | 127 +++++++++ 
gimmik/kernels/ptx/cstream-v2.mako | 83 ++++++ gimmik/kernels/ptx/cstream-w2.mako | 83 ------ gimmik/ptx.py | 52 +++- 12 files changed, 1580 insertions(+), 181 deletions(-) create mode 100644 gimmik/kernels/ptx/bstream-msplit-v2.mako create mode 100644 gimmik/kernels/ptx/config/sm100_fp32.json rename gimmik/kernels/ptx/config/{sm100.json => sm100_fp64.json} (76%) delete mode 100644 gimmik/kernels/ptx/config/sm80.json create mode 100644 gimmik/kernels/ptx/config/sm80_fp32.json create mode 100644 gimmik/kernels/ptx/config/sm80_fp64.json create mode 100644 gimmik/kernels/ptx/config/sm90_fp32.json rename gimmik/kernels/ptx/config/{sm90.json => sm90_fp64.json} (76%) create mode 100644 gimmik/kernels/ptx/cstream-ksplit-v2.mako create mode 100644 gimmik/kernels/ptx/cstream-v2.mako delete mode 100644 gimmik/kernels/ptx/cstream-w2.mako diff --git a/gimmik/kernels/ptx/bstream-msplit-v2.mako b/gimmik/kernels/ptx/bstream-msplit-v2.mako new file mode 100644 index 0000000..326fcce --- /dev/null +++ b/gimmik/kernels/ptx/bstream-msplit-v2.mako @@ -0,0 +1,180 @@ +<%inherit file='base'/> + +<% +mx = partition(A, into=msplit, by='rows') +bchunks = chunk(bix, bsz) +m_per_group = max(len(mcx) for mcx in mx) +bsub_bytes = 2 * bsz * blockx * 2 * dwidth_i +def bsub_off(buf, idx): + return (buf * bsz + idx) * blockx * 2 * dwidth_i +%> + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 n, id, tid_x, tid_y; + .reg .u64 b, c, b_base, c_base, bsub_thread; +% if use_cpasync: + .reg .u32 bsub_sm_thread; +% endif + .reg .${pftype} bv_a, bv_b, csub_a<${m_per_group}>, csub_b<${m_per_group}>; + .reg .pred p1, p_skip; + .shared .align 16 .b8 _bsub[${bsub_bytes}]; + + mov.u32 n, ${-(-n // 2)}; + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + 
cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${2*dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${2*dwidth_i}, c; + } + + { + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${2*dwidth_i}; + mov.u64 bsub_thread, _bsub; + add.u64 bsub_thread, bsub_thread, _tx_off; + } +% if use_cpasync: + { + .reg .u64 _sm64; + cvta.to.shared.u64 _sm64, bsub_thread; + cvt.u32.u64 bsub_sm_thread, _sm64; + } +% endif + +% for cid, mcx in enumerate(mx): +## cid = ${cid}, rows ${mcx} + setp.ne.u32 p_skip, tid_y, ${cid}; + @p_skip bra $L_END_CID_${cid}; + +## Zero accumulators +% for j, row_j in enumerate(mcx): +% if afix[row_j] != -1: + mov.${pftype} csub_a${j}, ${fzero}; + mov.${pftype} csub_b${j}, ${fzero}; +% endif +% endfor + +## Pre-fill double buffer +% if use_cpasync: +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[0]) if i % msplit == cid]: + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${2*dwidth_i}; +% endfor + cp.async.commit_group; + cp.async.wait_all; + bar.sync 0; +% else: +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[0]) if i % msplit == cid]: + { + .reg .${pftype} _bva, _bvb; + ld.weak.global.cg.v2.${pftype} {_bva, _bvb}, [b_base + ${ldb*kx*dwidth_i}]; + st.shared.v2.${pftype} [bsub_thread + ${bsub_off(0, idx)}], {_bva, _bvb}; + } +% endfor + bar.sync 0; +% endif + +## Main loop over B-chunks (double-buffered) +% for bb in range(len(bchunks)): +<% + buf_cur = bb % 2 + buf_next = (bb + 1) % 2 +%> +% if not loop.last: +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[bb + 1]) if i % msplit == cid]: +% if use_cpasync: + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${2*dwidth_i}; +% else: + { + .reg .${pftype} _bva, _bvb; + ld.weak.global.cg.v2.${pftype} {_bva, _bvb}, [b_base + ${ldb*kx*dwidth_i}]; + st.shared.v2.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], {_bva, 
_bvb}; + } +% endif +% endfor +% if use_cpasync: + cp.async.commit_group; +% endif +% endif + +% for idx, kx in enumerate(bchunks[bb]): +% if any(A[row_j, kx] for row_j in mcx): + ld.shared.v2.${pftype} {bv_a, bv_b}, [bsub_thread + ${bsub_off(buf_cur, idx)}]; +% endif +% for j, row_j in enumerate(mcx): +% if A[row_j, kx] != 0: + fma.rn.${pftype} csub_a${j}, bv_a, ${A[row_j, kx]}, csub_a${j}; + fma.rn.${pftype} csub_b${j}, bv_b, ${A[row_j, kx]}, csub_b${j}; +% endif +% endfor +% for j, row_j in enumerate(mcx): +% if kx == alix[row_j]: +% if beta_zero: + st.weak.global.cg.v2.${pftype} [c_base + ${ldc*row_j*dwidth_i}], {csub_a${j}, csub_b${j}}; +% else: + { + .reg .${pftype} _ca, _cb; + ld.weak.global.cg.v2.${pftype} {_ca, _cb}, [c_base + ${ldc*row_j*dwidth_i}]; + fma.rn.${pftype} _ca, _ca, ${float(beta)}, csub_a${j}; + fma.rn.${pftype} _cb, _cb, ${float(beta)}, csub_b${j}; + st.weak.global.v2.${pftype} [c_base + ${ldc*row_j*dwidth_i}], {_ca, _cb}; + } +% endif +% endif +% endfor +% endfor +% if use_cpasync: +% if not loop.last: + cp.async.wait_all; +% endif +% endif + bar.sync 0; +% endfor + +## Handle zero rows in this cid's group +% if has_zero_rows: +% for row_j in mcx: +% if afix[row_j] == -1: +% if beta_zero: + { + .reg .${pftype} _z; + mov.${pftype} _z, ${fzero}; + st.weak.global.cg.v2.${pftype} [c_base + ${ldc*row_j*dwidth_i}], {_z, _z}; + } +% elif beta != 1: + { + .reg .${pftype} _ca, _cb; + ld.weak.global.cg.v2.${pftype} {_ca, _cb}, [c_base + ${ldc*row_j*dwidth_i}]; + mul.${pftype} _ca, _ca, ${float(beta)}; + mul.${pftype} _cb, _cb, ${float(beta)}; + st.weak.global.v2.${pftype} [c_base + ${ldc*row_j*dwidth_i}], {_ca, _cb}; + } +% endif +% endif +% endfor +% endif + +$L_END_CID_${cid}: +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/config/sm100_fp32.json b/gimmik/kernels/ptx/config/sm100_fp32.json new file mode 100644 index 0000000..b388544 --- /dev/null +++ b/gimmik/kernels/ptx/config/sm100_fp32.json @@ -0,0 +1,249 @@ +{ + "schema": 1, + 
"cc": [ + 10, + 0 + ], + "ptx": [ + 8, + 7 + ], + "kernels": [ + { + "template": "bstream", + "family": "sparse", + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "descriptor": "bstream/x128" + }, + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 24 + }, + "descriptor": "cstream-ksplit/k2-c24-x64" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 32, + 4, + 1 + ], + "width": 2, + "params": { + "csz": 20 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/k4-c20-x32" + }, + { + "template": "cstream", + "family": "sparse", + "block": [ + 256, + 1, + 1 + ], + "width": 1, + "descriptor": "cstream/x256" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 16 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/k2-c16-x64" + }, + { + "template": "bstream-msplit-v2", + "family": "sparse", + "block": [ + 64, + 1, + 1 + ], + "width": 2, + "params": { + "bsz": 32 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "bstream-msplit-v2/m1-b32-x64" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 64, + 4, + 
1 + ], + "width": 2, + "params": { + "csz": 24 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/k4-c24-x64" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 32 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/k2-c32-x64" + } + ] +} diff --git a/gimmik/kernels/ptx/config/sm100.json b/gimmik/kernels/ptx/config/sm100_fp64.json similarity index 76% rename from gimmik/kernels/ptx/config/sm100.json rename to gimmik/kernels/ptx/config/sm100_fp64.json index d7dfe1b..8b74674 100644 --- a/gimmik/kernels/ptx/config/sm100.json +++ b/gimmik/kernels/ptx/config/sm100_fp64.json @@ -48,6 +48,82 @@ "descriptor": "bstream/x64", "family": "sparse" }, + { + "template": "bstream-msplit-v2", + "block": [ + 32, + 4, + 1 + ], + "width": 2, + "params": { + "bsz": 16 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "bstream-msplit-v2/m4-b16-x32", + "family": "sparse" + }, + { + "template": "cstream-ksplit-v2", + "block": [ + 32, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 24 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + 
"field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/k2-c24-x32", + "family": "sparse" + }, { "template": "cstream-ksplit", "block": [ @@ -115,6 +191,41 @@ "descriptor": "cstream/x128", "family": "sparse" }, + { + "template": "cstream-v2", + "block": [ + 128, + 1, + 1 + ], + "width": 2, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "k_used", + "lte": 100 + }, + { + "field": "aligne", + "is_null_or_divisible_by": 2 + } + ] + }, + "descriptor": "cstream-v2/x128", + "family": "sparse" + }, { "template": "dmma-steal-ws", "block": [ diff --git a/gimmik/kernels/ptx/config/sm80.json b/gimmik/kernels/ptx/config/sm80.json deleted file mode 100644 index 08464d3..0000000 --- a/gimmik/kernels/ptx/config/sm80.json +++ /dev/null @@ -1,88 +0,0 @@ -{ - "schema": 1, - "cc": [8, 0], - "ptx": [7, 0], - "kernels": [ - { - "template": "cstream", - "family": "sparse", - "block": [128, 1, 1], - "width": 1, - "descriptor": "cstream" - }, - { - "template": "bstream", - "family": "sparse", - "block": [128, 1, 1], - "width": 1, - "descriptor": "bstream" - }, - { - "template": "bstream-msplit", - "family": "sparse", - "block": [32, 4, 1], - "width": 1, - "params": { - "bsz": 24 - }, - "descriptor": "bstream-msplit/m4-b24-x32" - }, - { - "template": "bstream-msplit", - "family": "sparse", - "block": [64, 1, 1], - "width": 1, - "params": { - "bsz": 32 - }, - "conditions": { - "all": [ - {"field": "beta_zero", "eq": true}, - {"field": "m", "lte": 320}, - {"field": "k_used", "gte": 64} - ] - }, - "descriptor": "bstream-msplit/m1-b32-x64" - }, - { - "template": "cstream-ksplit", - "family": "sparse", - "block": [32, 2, 1], - "width": 1, - "params": { - "csz": 24 - }, - "descriptor": "cstream-ksplit/k2-c24-x32" - }, - { - "template": "cstream-ksplit", - "family": "sparse", - "block": [32, 
4, 1], - "width": 1, - "params": { - "csz": 20 - }, - "conditions": { - "field": "k_used", - "gt": 500 - }, - "descriptor": "cstream-ksplit/k4-c20-x32" - }, - { - "template": "cstream-w2", - "family": "sparse", - "block": [128, 1, 1], - "width": 2, - "conditions": { - "all": [ - {"field": "dtype", "eq": "double"}, - {"field": "n", "is_not": null}, - {"field": "n", "divisible_by": 2}, - {"field": "k_used", "lte": 100}, - {"field": "aligne", "is_null_or_divisible_by": 2} - ] - }, - "descriptor": "cstream-w2/x128" - } - ] -} diff --git a/gimmik/kernels/ptx/config/sm80_fp32.json b/gimmik/kernels/ptx/config/sm80_fp32.json new file mode 100644 index 0000000..4846d3f --- /dev/null +++ b/gimmik/kernels/ptx/config/sm80_fp32.json @@ -0,0 +1,222 @@ +{ + "schema": 1, + "cc": [ + 8, + 0 + ], + "ptx": [ + 7, + 0 + ], + "kernels": [ + { + "template": "cstream", + "family": "sparse", + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "descriptor": "cstream" + }, + { + "template": "bstream", + "family": "sparse", + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "descriptor": "bstream" + }, + { + "template": "bstream-msplit", + "family": "sparse", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "bsz": 24 + }, + "descriptor": "bstream-msplit/m4-b24-x32" + }, + { + "template": "bstream-msplit-v2", + "family": "sparse", + "block": [ + 32, + 4, + 1 + ], + "width": 2, + "params": { + "bsz": 16 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "float" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "bstream-msplit-v2/m4-b16-x32" + }, + { + "template": "bstream-msplit", + "family": "sparse", + "block": [ + 64, + 1, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "conditions": { + "all": [ + { + "field": "beta_zero", + "eq": true + }, + { + "field": "m", + "lte": 320 + }, + { + "field": 
"k_used", + "gte": 64 + } + ] + }, + "descriptor": "bstream-msplit/m1-b32-x64" + }, + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 24 + }, + "descriptor": "cstream-ksplit/k2-c24-x32" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 32, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 24 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "float" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/k2-c24-x32" + }, + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "csz": 20 + }, + "conditions": { + "field": "k_used", + "gt": 500 + }, + "descriptor": "cstream-ksplit/k4-c20-x32" + }, + { + "template": "cstream-v2", + "family": "sparse", + "block": [ + 128, + 1, + 1 + ], + "width": 2, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "float" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "k_used", + "lte": 100 + }, + { + "field": "aligne", + "is_null_or_divisible_by": 2 + } + ] + }, + "descriptor": "cstream-v2/x128" + } + ] +} diff --git a/gimmik/kernels/ptx/config/sm80_fp64.json b/gimmik/kernels/ptx/config/sm80_fp64.json new file mode 100644 index 0000000..83d9cd9 --- /dev/null +++ b/gimmik/kernels/ptx/config/sm80_fp64.json @@ -0,0 +1,222 @@ +{ + "schema": 1, + "cc": [ + 8, + 0 + ], + "ptx": [ + 7, + 0 + ], + "kernels": [ + { + "template": "cstream", + "family": "sparse", + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "descriptor": "cstream" + }, + { + "template": "bstream", + "family": "sparse", + "block": [ + 128, + 1, + 1 + ], + "width": 1, + "descriptor": "bstream" + }, + { + "template": "bstream-msplit", + "family": 
"sparse", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "bsz": 24 + }, + "descriptor": "bstream-msplit/m4-b24-x32" + }, + { + "template": "bstream-msplit-v2", + "family": "sparse", + "block": [ + 32, + 4, + 1 + ], + "width": 2, + "params": { + "bsz": 16 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "bstream-msplit-v2/m4-b16-x32" + }, + { + "template": "bstream-msplit", + "family": "sparse", + "block": [ + 64, + 1, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "conditions": { + "all": [ + { + "field": "beta_zero", + "eq": true + }, + { + "field": "m", + "lte": 320 + }, + { + "field": "k_used", + "gte": 64 + } + ] + }, + "descriptor": "bstream-msplit/m1-b32-x64" + }, + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 24 + }, + "descriptor": "cstream-ksplit/k2-c24-x32" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 32, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 24 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/k2-c24-x32" + }, + { + "template": "cstream-ksplit", + "family": "sparse", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "csz": 20 + }, + "conditions": { + "field": "k_used", + "gt": 500 + }, + "descriptor": "cstream-ksplit/k4-c20-x32" + }, + { + "template": "cstream-v2", + "family": "sparse", + "block": [ + 128, + 1, + 1 + ], + "width": 2, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" 
+ }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "k_used", + "lte": 100 + }, + { + "field": "aligne", + "is_null_or_divisible_by": 2 + } + ] + }, + "descriptor": "cstream-v2/x128" + } + ] +} diff --git a/gimmik/kernels/ptx/config/sm90_fp32.json b/gimmik/kernels/ptx/config/sm90_fp32.json new file mode 100644 index 0000000..3417d2f --- /dev/null +++ b/gimmik/kernels/ptx/config/sm90_fp32.json @@ -0,0 +1,233 @@ +{ + "schema": 1, + "cc": [ + 9, + 0 + ], + "ptx": [ + 8, + 6 + ], + "kernels": [ + { + "template": "cstream-ksplit", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "csz": 20 + }, + "descriptor": "cstream-ksplit/k2-c20-x32", + "family": "sparse" + }, + { + "template": "bstream-msplit", + "block": [ + 32, + 8, + 1 + ], + "width": 1, + "params": { + "bsz": 24 + }, + "descriptor": "bstream-msplit/m8-b24-x32", + "family": "sparse" + }, + { + "template": "cstream-ksplit", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "csz": 24 + }, + "descriptor": "cstream-ksplit/k4-c24-x32", + "family": "sparse" + }, + { + "template": "bstream-msplit", + "block": [ + 64, + 2, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "descriptor": "bstream-msplit/m2-b32-x64", + "family": "sparse" + }, + { + "template": "bstream", + "block": [ + 64, + 1, + 1 + ], + "width": 1, + "descriptor": "bstream/x64", + "family": "sparse" + }, + { + "template": "bstream-msplit-v2", + "block": [ + 32, + 4, + 1 + ], + "width": 2, + "params": { + "bsz": 16 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "float" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "bstream-msplit-v2/m4-b16-x32", + "family": "sparse" + }, + { + "template": "cstream-ksplit-v2", + "block": [ + 32, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 24 + }, + 
"conditions": { + "all": [ + { + "field": "dtype", + "eq": "float" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/k2-c24-x32", + "family": "sparse" + }, + { + "template": "bstream-msplit", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "descriptor": "bstream-msplit/m2-b32-x32", + "family": "sparse" + }, + { + "template": "bstream-msplit", + "block": [ + 64, + 1, + 1 + ], + "width": 1, + "params": { + "bsz": 24 + }, + "descriptor": "bstream-msplit/m1-b24-x64", + "family": "sparse" + }, + { + "template": "bstream-msplit", + "block": [ + 32, + 4, + 1 + ], + "width": 1, + "params": { + "bsz": 32 + }, + "descriptor": "bstream-msplit/m4-b32-x32", + "family": "sparse" + }, + { + "template": "cstream-v2", + "block": [ + 128, + 1, + 1 + ], + "width": 2, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "float" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "k_used", + "lte": 100 + }, + { + "field": "aligne", + "is_null_or_divisible_by": 2 + } + ] + }, + "descriptor": "cstream-v2/x128", + "family": "sparse" + } + ] +} diff --git a/gimmik/kernels/ptx/config/sm90.json b/gimmik/kernels/ptx/config/sm90_fp64.json similarity index 76% rename from gimmik/kernels/ptx/config/sm90.json rename to gimmik/kernels/ptx/config/sm90_fp64.json index 2bfe5eb..a61a051 100644 --- a/gimmik/kernels/ptx/config/sm90.json +++ b/gimmik/kernels/ptx/config/sm90_fp64.json @@ -76,6 +76,82 @@ "descriptor": "bstream/x64", "family": "sparse" }, + { + "template": "bstream-msplit-v2", + "block": [ + 32, + 4, + 1 + ], + "width": 2, + "params": { + "bsz": 16 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + 
"field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "bstream-msplit-v2/m4-b16-x32", + "family": "sparse" + }, + { + "template": "cstream-ksplit-v2", + "block": [ + 32, + 2, + 1 + ], + "width": 2, + "params": { + "csz": 24 + }, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] + }, + "descriptor": "cstream-ksplit-v2/k2-c24-x32", + "family": "sparse" + }, { "template": "bstream-msplit", "block": [ @@ -118,6 +194,41 @@ "descriptor": "bstream-msplit/m4-b32-x32", "family": "sparse" }, + { + "template": "cstream-v2", + "block": [ + 128, + 1, + 1 + ], + "width": 2, + "conditions": { + "all": [ + { + "field": "dtype", + "eq": "double" + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "k_used", + "lte": 100 + }, + { + "field": "aligne", + "is_null_or_divisible_by": 2 + } + ] + }, + "descriptor": "cstream-v2/x128", + "family": "sparse" + }, { "template": "dmma-asmem", "vector_width": 2, diff --git a/gimmik/kernels/ptx/cstream-ksplit-v2.mako b/gimmik/kernels/ptx/cstream-ksplit-v2.mako new file mode 100644 index 0000000..0abe8ba --- /dev/null +++ b/gimmik/kernels/ptx/cstream-ksplit-v2.mako @@ -0,0 +1,127 @@ +<%inherit file='base'/> + +<% +kparts = partition(A, ksplit, by='cols') +cchunks = chunk(list(range(m)), csz) +cv_per_thread = -(-csz // ksplit) +bv_per_thread = max(len(kbx) for kbx in kparts) +csub_bytes = (ksplit - 1) * csz * blockx * 2 * dwidth_i +%> + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 n, id, tid_x, tid_y; + .reg .u64 b, c, b_base, c_base, csub_thread; + .reg .${pftype} bv_a<${bv_per_thread}>, bv_b<${bv_per_thread}>; + .reg .${pftype} cv_a<${cv_per_thread}>, cv_b<${cv_per_thread}>; + .reg .${pftype} 
dotp_a, dotp_b; + .reg .pred p1, p_skip; + .shared .align 16 .b8 _csub[${csub_bytes}]; + + mov.u32 n, ${-(-n // 2)}; + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${2*dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${2*dwidth_i}, c; + } + + { + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${2*dwidth_i}; + mov.u64 csub_thread, _csub; + add.u64 csub_thread, csub_thread, _tx_off; + } + +% for bid, kbx in enumerate(kparts): +## bid = ${bid}: ${len(kbx)} B columns, ksplit=${ksplit} + setp.ne.u32 p_skip, tid_y, ${bid}; + @p_skip bra $L_END_BID_${bid}; + +<% loaded = set() %> + +% for cchunk_i, cchunk in enumerate(cchunks): +## Chunk ${cchunk_i}: partial dot-product +% for row_idx, j in enumerate(cchunk): +<% owner_bid = row_idx % ksplit %> +% for kxi, kx in enumerate(kbx): +% if A[j, kx] != 0 and kx not in loaded: + ld.weak.global.cg.v2.${pftype} {bv_a${kxi}, bv_b${kxi}}, [b_base + ${ldb*kx*dwidth_i}]; +<% loaded.add(kx) %> +% endif +% endfor + mov.${pftype} dotp_a, ${fzero}; + mov.${pftype} dotp_b, ${fzero}; +% for kxi, kx in enumerate(kbx): +% if A[j, kx] != 0: + fma.rn.${pftype} dotp_a, bv_a${kxi}, ${A[j, kx]}, dotp_a; + fma.rn.${pftype} dotp_b, bv_b${kxi}, ${A[j, kx]}, dotp_b; +% endif +% endfor +% if owner_bid == bid: + mov.${pftype} cv_a${row_idx // ksplit}, dotp_a; + mov.${pftype} cv_b${row_idx // ksplit}, dotp_b; +% else: +<% csub_idx = bid - (1 if bid > owner_bid else 0) %> + st.shared.v2.${pftype} [csub_thread + ${(csub_idx * csz + row_idx) * blockx * 2 * dwidth_i}], {dotp_a, dotp_b}; +% endif +% endfor + bar.sync 0; + +## Combine phase (owned rows only) +% for row_idx, j in enumerate(cchunk): +% if row_idx % ksplit == bid: + 
mov.${pftype} dotp_a, cv_a${row_idx // ksplit}; + mov.${pftype} dotp_b, cv_b${row_idx // ksplit}; +% for other_bid in range(ksplit): +% if other_bid != bid: +<% csub_idx = other_bid - (1 if other_bid > (row_idx % ksplit) else 0) %> + { + .reg .${pftype} _ta, _tb; + ld.shared.v2.${pftype} {_ta, _tb}, [csub_thread + ${(csub_idx * csz + row_idx) * blockx * 2 * dwidth_i}]; + add.${pftype} dotp_a, dotp_a, _ta; + add.${pftype} dotp_b, dotp_b, _tb; + } +% endif +% endfor +% if beta_zero: + st.weak.global.cg.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; +% else: + { + .reg .${pftype} _ca, _cb; + ld.weak.global.cg.v2.${pftype} {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ca, _ca, ${float(beta)}, dotp_a; + fma.rn.${pftype} _cb, _cb, ${float(beta)}, dotp_b; + st.weak.global.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% endif + +% endif +% endfor + bar.sync 0; +% endfor + +$L_END_BID_${bid}: +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream-v2.mako b/gimmik/kernels/ptx/cstream-v2.mako new file mode 100644 index 0000000..9949045 --- /dev/null +++ b/gimmik/kernels/ptx/cstream-v2.mako @@ -0,0 +1,83 @@ +<%inherit file='base'/> + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .${pftype} bv_a<${len(bix)}>, bv_b<${len(bix)}>, dotp_a, dotp_b; + .reg .pred p1; + + mov.u32 n, ${-(-n // 2)}; + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x, _tid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 _tid_x, %tid.x; + mad.lo.u32 id, _ctaid_x, ${blockx}, _tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${2*dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${2*dwidth_i}, c; + } + +## Batch-load B column pairs +% for i, kx in enumerate(bix): + ld.weak.global.cg.v2.${pftype} {bv_a${i}, bv_b${i}}, 
[b_base + ${ldb*kx*dwidth_i}]; +% endfor + +## Main compute: two parallel dot-product streams per thread +% for j in range(m): +% if row_nz[j]: + mov.${pftype} dotp_a, ${fzero}; + mov.${pftype} dotp_b, ${fzero}; +% for kx, jx in row_nz[j]: + fma.rn.${pftype} dotp_a, bv_a${bix[kx]}, ${jx}, dotp_a; + fma.rn.${pftype} dotp_b, bv_b${bix[kx]}, ${jx}, dotp_b; +% endfor +% if beta_zero: + st.weak.global.cg.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; +% else: + { + .reg .${pftype} _ca, _cb; + ld.weak.global.cg.v2.${pftype} {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ca, _ca, ${float(beta)}, dotp_a; + fma.rn.${pftype} _cb, _cb, ${float(beta)}, dotp_b; + st.weak.global.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% endif + +% else: +## Zero row of A +% if beta_zero: + { + .reg .${pftype} _z; + mov.${pftype} _z, ${fzero}; + st.weak.global.cg.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {_z, _z}; + } +% elif beta != 1: + { + .reg .${pftype} _ca, _cb; + ld.weak.global.cg.v2.${pftype} {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _ca, _ca, ${float(beta)}; + mul.${pftype} _cb, _cb, ${float(beta)}; + st.weak.global.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% endif +% endif +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream-w2.mako b/gimmik/kernels/ptx/cstream-w2.mako deleted file mode 100644 index fbb2a0d..0000000 --- a/gimmik/kernels/ptx/cstream-w2.mako +++ /dev/null @@ -1,83 +0,0 @@ -<%inherit file='base'/> - -.visible .entry ${kname}(.param .u64 _b, - .param .u64 _c) -{ - .reg .u32 n, id; - .reg .u64 b, c, b_base, c_base; - .reg .f64 bv_a<${len(bix)}>, bv_b<${len(bix)}>, dotp_a, dotp_b; - .reg .pred p1; - - mov.u32 n, ${-(-n // 2)}; - ld.param.u64 b, [_b]; - ld.param.u64 c, [_c]; - - { - .reg .u32 _ctaid_x, _tid_x; - mov.u32 _ctaid_x, %ctaid.x; - mov.u32 _tid_x, %tid.x; - mad.lo.u32 id, _ctaid_x, ${blockx}, _tid_x; - } - - setp.ge.u32 p1, id, n; - @p1 bra $L_EXIT; - - 
cvta.to.global.u64 b, b; - cvta.to.global.u64 c, c; - - { - .reg .u64 _id64; - cvt.u64.u32 _id64, id; - mad.lo.u64 b_base, _id64, 16, b; - mad.lo.u64 c_base, _id64, 16, c; - } - -## Batch-load B column pairs -% for i, kx in enumerate(bix): - ld.weak.global.cg.v2.f64 {bv_a${i}, bv_b${i}}, [b_base + ${ldb*kx*dwidth_i}]; -% endfor - -## Main compute: two parallel dot-product streams per thread -% for j in range(m): -% if row_nz[j]: - mov.f64 dotp_a, ${fzero}; - mov.f64 dotp_b, ${fzero}; -% for kx, jx in row_nz[j]: - fma.rn.f64 dotp_a, bv_a${bix[kx]}, ${jx}, dotp_a; - fma.rn.f64 dotp_b, bv_b${bix[kx]}, ${jx}, dotp_b; -% endfor -% if beta_zero: - st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; -% else: - { - .reg .f64 _ca, _cb; - ld.weak.global.cg.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; - fma.rn.f64 _ca, _ca, ${float(beta)}, dotp_a; - fma.rn.f64 _cb, _cb, ${float(beta)}, dotp_b; - st.weak.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; - } -% endif - -% else: -## Zero row of A -% if beta_zero: - { - .reg .f64 _z; - mov.f64 _z, ${fzero}; - st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_z, _z}; - } -% elif beta != 1: - { - .reg .f64 _ca, _cb; - ld.weak.global.cg.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; - mul.f64 _ca, _ca, ${float(beta)}; - mul.f64 _cb, _cb, ${float(beta)}; - st.weak.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; - } -% endif -% endif -% endfor - -$L_EXIT: - ret; -} diff --git a/gimmik/ptx.py b/gimmik/ptx.py index 9e480cb..6e18283 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -49,7 +49,7 @@ def _kernel_generators(self, dtype, dsize, *, compute_capability=None, smem_info=None): cc = compute_capability or (0, 0) smem_info = smem_info or (48*1024, 48*1024) - config = self._cc_config(cc) + config = self._cc_config(cc, dtype) if cc in self.PTX_SM: target_cc = cc ptx = self.PTX_SM[cc] @@ -144,11 +144,21 @@ def _sparse_args(self, tpl, params, block, dtype, dsize, args, meta): bsz = 
params['bsz'] args |= {'msplit': msplit, 'bsz': bsz, 'blockx': blockx} meta['shared'] = 2*bsz*blockx*dsize + case 'bstream-msplit-v2': + msplit = block[1] + bsz = params['bsz'] + args |= {'msplit': msplit, 'bsz': bsz, 'blockx': blockx} + meta['shared'] = 2*bsz*blockx*2*dsize case 'cstream-ksplit': ksplit = block[1] csz = params['csz'] args |= {'ksplit': ksplit, 'csz': csz, 'blockx': blockx} meta['shared'] = (ksplit - 1)*csz*blockx*dsize + case 'cstream-ksplit-v2': + ksplit = block[1] + csz = params['csz'] + args |= {'ksplit': ksplit, 'csz': csz, 'blockx': blockx} + meta['shared'] = (ksplit - 1)*csz*blockx*2*dsize case _: args['blockx'] = blockx return tpl, args, meta @@ -362,19 +372,41 @@ def _usable_config(self, kernel_cfg, dtype, cc, smem_info): stats = self._matmul_stats(dtype, cc, smem_info) return self._eval_condition(condition, stats) - def _cc_config(self, cc): + @staticmethod + def _dtype_config_suffix(dtype): + if dtype is None: + raise ValueError('PTX config dtype is required') + + dtype_name = getattr(dtype, 'name', dtype) + if dtype_name in {'float', 'float32', 'single'}: + return 'fp32' + elif dtype_name in {'double', 'float64'}: + return 'fp64' + + raise ValueError(f'Unsupported PTX config dtype {dtype_name!r}') + + def _cc_config(self, cc, dtype): cc = cc or (0, 0) - if cc not in self._config_cache: - path = f'kernels/ptx/config/sm{cc[0]}{cc[1]}.json' + suffix = self._dtype_config_suffix(dtype) + key = (cc, suffix) + if key not in self._config_cache: + base = f'kernels/ptx/config/sm{cc[0]}{cc[1]}' + paths = [f'{base}_{suffix}.json', self.DEFAULT_CFG] + + cfgdir = None + for path in paths: + try: + cfgdir = pkgutil.get_data('gimmik', path) + except FileNotFoundError: + continue + + if cfgdir is not None: + break - try: - cfgdir = pkgutil.get_data('gimmik', path) - except FileNotFoundError: - cfgdir = pkgutil.get_data('gimmik', self.DEFAULT_CFG) cfg = json.loads(cfgdir.decode('utf-8')) - self._config_cache[cc] = cfg - return self._config_cache[cc] + 
self._config_cache[key] = cfg + return self._config_cache[key] def _matmul_stats(self, dtype, cc, smem_info): nnz = np.count_nonzero(self.A) From 4bf8c911df052d985de9516d505bc650d367046e Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Thu, 25 Jun 2026 04:04:15 -0700 Subject: [PATCH 21/21] Refactoring and cleanup --- gimmik/base.py | 54 ++++ gimmik/kernels/ptx/bstream-msplit-v2.mako | 13 + gimmik/kernels/ptx/bstream-msplit.mako | 28 ++ .../{sm100_fp64.json => sm100_double.json} | 295 +++++++----------- .../{sm100_fp32.json => sm100_float.json} | 94 +++--- .../{sm80_fp64.json => sm80_double.json} | 0 .../{sm80_fp32.json => sm80_float.json} | 149 ++++----- .../{sm90_fp64.json => sm90_double.json} | 25 +- .../{sm90_fp32.json => sm90_float.json} | 184 ++++++----- gimmik/kernels/ptx/cstream-ksplit-v2.mako | 20 ++ gimmik/kernels/ptx/cstream-ksplit.mako | 42 +++ .../{dmma-asmem-v1.mako => dmma-asmem.mako} | 0 ...split-v1.mako => dmma-astream-msplit.mako} | 0 ...dmma-astream-v1.mako => dmma-astream.mako} | 0 gimmik/ptx.py | 126 ++------ 15 files changed, 499 insertions(+), 531 deletions(-) rename gimmik/kernels/ptx/config/{sm100_fp64.json => sm100_double.json} (64%) rename gimmik/kernels/ptx/config/{sm100_fp32.json => sm100_float.json} (75%) rename gimmik/kernels/ptx/config/{sm80_fp64.json => sm80_double.json} (100%) rename gimmik/kernels/ptx/config/{sm80_fp32.json => sm80_float.json} (63%) rename gimmik/kernels/ptx/config/{sm90_fp64.json => sm90_double.json} (95%) rename gimmik/kernels/ptx/config/{sm90_fp32.json => sm90_float.json} (55%) rename gimmik/kernels/ptx/{dmma-asmem-v1.mako => dmma-asmem.mako} (100%) rename gimmik/kernels/ptx/{dmma-astream-msplit-v1.mako => dmma-astream-msplit.mako} (100%) rename gimmik/kernels/ptx/{dmma-astream-v1.mako => dmma-astream.mako} (100%) diff --git a/gimmik/base.py b/gimmik/base.py index 9db8ee9..f806b05 100644 --- a/gimmik/base.py +++ b/gimmik/base.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import itertools as it +import json import 
pkgutil import re @@ -90,6 +91,9 @@ def __init__(self, A, beta=0.0, aligne=None, n=None, ldb=None, ldc=None): self.bix = np.nonzero(np.any(A != 0, axis=0))[0] self.bix = {kx: k for k, kx in enumerate(self.bix)} + # Create config cache + self._config_cache = {} + def kernels(self, dtype, kname='gimmik_mm', **kwargs): basemeta = self.basemeta @@ -142,6 +146,56 @@ def _base_template_args(self, dtype, kname): def _process_meta(self, meta): pass + def _get_config(self, key): + if key not in self._config_cache: + cfgdir = f'kernels/{self.platform}/config' + path = f'{cfgdir}/{key}.json' + default_path = f'{cfgdir}/default.json' + try: + cfgdata = pkgutil.get_data('gimmik', path) + except FileNotFoundError: + cfgdata = pkgutil.get_data('gimmik', default_path) + self._config_cache[key] = json.loads(cfgdata.decode('utf-8')) + return self._config_cache[key] + + def _eval_condition(self, condition, stats): + if 'all' in condition: + return all(self._eval_condition(c, stats) for c in condition['all']) + if 'any' in condition: + return any(self._eval_condition(c, stats) for c in condition['any']) + if 'not' in condition: + return not self._eval_condition(condition['not'], stats) + + value = stats[condition['field']] + op = next(k for k in condition if k != 'field') + expected = condition[op] + + match op: + case 'eq': + return value == expected + case 'ne': + return value != expected + case 'lt': + return value is not None and value < expected + case 'lte': + return value is not None and value <= expected + case 'gt': + return value is not None and value > expected + case 'gte': + return value is not None and value >= expected + case 'in': + return value in expected + case 'is_null': + return value is None + case 'is_not': + return value is not None + case 'divisible_by': + return value is not None and value % expected == 0 + case 'is_null_or_divisible_by': + return (value is None or value % expected == 0) + case _: + raise ValueError(f'op `{op}` not supported') + def 
_render_kernel(self, dtype, tplname, tplargs): tpl = _PlatformTemplateLookup(self.platform).get_template(tplname) src = tpl.render(**tplargs) diff --git a/gimmik/kernels/ptx/bstream-msplit-v2.mako b/gimmik/kernels/ptx/bstream-msplit-v2.mako index 326fcce..bc8dd38 100644 --- a/gimmik/kernels/ptx/bstream-msplit-v2.mako +++ b/gimmik/kernels/ptx/bstream-msplit-v2.mako @@ -65,6 +65,7 @@ def bsub_off(buf, idx): setp.ne.u32 p_skip, tid_y, ${cid}; @p_skip bra $L_END_CID_${cid}; +% if beta_zero or not preload_c: ## Zero accumulators % for j, row_j in enumerate(mcx): % if afix[row_j] != -1: @@ -72,6 +73,16 @@ def bsub_off(buf, idx): mov.${pftype} csub_b${j}, ${fzero}; % endif % endfor +% else: +## Pre-load C and scale by beta so per-row completion is a plain store +% for j, row_j in enumerate(mcx): +% if afix[row_j] != -1: + ld.weak.global.cg.v2.${pftype} {csub_a${j}, csub_b${j}}, [c_base + ${ldc*row_j*dwidth_i}]; + mul.${pftype} csub_a${j}, csub_a${j}, ${float(beta)}; + mul.${pftype} csub_b${j}, csub_b${j}, ${float(beta)}; +% endif +% endfor +% endif ## Pre-fill double buffer % if use_cpasync: @@ -129,6 +140,8 @@ def bsub_off(buf, idx): % if kx == alix[row_j]: % if beta_zero: st.weak.global.cg.v2.${pftype} [c_base + ${ldc*row_j*dwidth_i}], {csub_a${j}, csub_b${j}}; +% elif preload_c: + st.weak.global.v2.${pftype} [c_base + ${ldc*row_j*dwidth_i}], {csub_a${j}, csub_b${j}}; % else: { .reg .${pftype} _ca, _cb; diff --git a/gimmik/kernels/ptx/bstream-msplit.mako b/gimmik/kernels/ptx/bstream-msplit.mako index c98d357..989735b 100644 --- a/gimmik/kernels/ptx/bstream-msplit.mako +++ b/gimmik/kernels/ptx/bstream-msplit.mako @@ -81,12 +81,30 @@ def bsub_off(buf, idx): setp.ne.u32 p_skip, tid_y, ${cid}; @p_skip bra $L_END_CID_${cid}; +% if beta_zero or not preload_c: ## Zero accumulators % for j, row_j in enumerate(mcx): % if afix[row_j] != -1: mov.${pftype} csub${j}, ${fzero}; % endif % endfor +% else: +## Pre-load C and scale by beta so per-row completion is a plain store +% for j, 
row_j in enumerate(mcx): +% if afix[row_j] != -1: +% if n is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${row_j * dwidth_i}, c_base; + ld.weak.global.cg.${pftype} csub${j}, [_cptr]; + } +% else: + ld.weak.global.cg.${pftype} csub${j}, [c_base + ${ldc*row_j*dwidth_i}]; +% endif + mul.${pftype} csub${j}, csub${j}, ${float(beta)}; +% endif +% endfor +% endif ## Pre-fill double buffer % if use_cpasync: @@ -181,6 +199,16 @@ def bsub_off(buf, idx): % else: st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], csub${j}; % endif +% elif preload_c: +% if n is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${row_j * dwidth_i}, c_base; + st.weak.global.${pftype} [_cptr], csub${j}; + } +% else: + st.weak.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], csub${j}; +% endif % else: { .reg .${pftype} _ctmp; diff --git a/gimmik/kernels/ptx/config/sm100_fp64.json b/gimmik/kernels/ptx/config/sm100_double.json similarity index 64% rename from gimmik/kernels/ptx/config/sm100_fp64.json rename to gimmik/kernels/ptx/config/sm100_double.json index 8b74674..c84ae65 100644 --- a/gimmik/kernels/ptx/config/sm100_fp64.json +++ b/gimmik/kernels/ptx/config/sm100_double.json @@ -11,6 +11,7 @@ "kernels": [ { "template": "cstream-ksplit", + "family": "sparse", "block": [ 32, 2, @@ -18,13 +19,29 @@ ], "width": 1, "params": { - "csz": 20 + "csz": 20, + "preload_c": true }, - "descriptor": "cstream-ksplit/k2-c20-x32", - "family": "sparse" + "descriptor": "cstream-ksplit/preload-c/k2-c20-x32" }, { "template": "bstream-msplit", + "family": "sparse", + "block": [ + 32, + 2, + 1 + ], + "width": 1, + "params": { + "bsz": 32, + "preload_c": true + }, + "descriptor": "bstream-msplit/preload-c/m2-b32-x32" + }, + { + "template": "bstream-msplit", + "family": "sparse", "block": [ 32, 4, @@ -32,76 +49,60 @@ ], "width": 1, "params": { - "bsz": 32 + "bsz": 32, + "preload_c": true }, - "descriptor": "bstream-msplit/m4-b32-x32", - "family": "sparse" + "descriptor": 
"bstream-msplit/preload-c/m4-b32-x32" }, { - "template": "bstream", + "template": "bstream-msplit", + "family": "sparse", "block": [ - 64, + 32, 1, 1 ], "width": 1, - "descriptor": "bstream/x64", - "family": "sparse" + "params": { + "bsz": 32 + }, + "descriptor": "bstream-msplit/m1-b32-x32" }, { - "template": "bstream-msplit-v2", + "template": "cstream-ksplit", + "family": "sparse", "block": [ 32, - 4, + 2, 1 ], - "width": 2, + "width": 1, "params": { - "bsz": 16 - }, - "conditions": { - "all": [ - { - "field": "dtype", - "eq": "double" - }, - { - "field": "n", - "is_not": null - }, - { - "field": "n", - "divisible_by": 2 - }, - { - "field": "aligne", - "is_not": null - }, - { - "field": "aligne", - "divisible_by": 2 - } - ] + "csz": 16, + "preload_c": true }, - "descriptor": "bstream-msplit-v2/m4-b16-x32", - "family": "sparse" + "descriptor": "cstream-ksplit/preload-c/k2-c16-x32" }, { "template": "cstream-ksplit-v2", + "family": "sparse", "block": [ - 32, + 64, 2, 1 ], "width": 2, "params": { - "csz": 24 + "csz": 32, + "preload_c": true }, "conditions": { "all": [ { "field": "dtype", - "eq": "double" + "in": [ + "double" + ] }, { "field": "n", @@ -121,113 +122,42 @@ } ] }, - "descriptor": "cstream-ksplit-v2/k2-c24-x32", - "family": "sparse" + "descriptor": "cstream-ksplit-v2/preload-c/k2-c32-x64" }, { - "template": "cstream-ksplit", - "block": [ - 32, - 4, - 1 - ], - "width": 1, - "params": { - "csz": 20 - }, - "descriptor": "cstream-ksplit/k4-c20-x32", - "family": "sparse" - }, - { - "template": "bstream-msplit", - "block": [ - 32, - 8, - 1 - ], - "width": 1, - "params": { - "bsz": 32 - }, - "descriptor": "bstream-msplit/m8-b32-x32", - "family": "sparse" - }, - { - "template": "bstream-msplit", + "template": "cstream", + "family": "sparse", "block": [ - 32, + 256, 1, 1 ], "width": 1, - "params": { - "bsz": 32 - }, - "descriptor": "bstream-msplit/m1-b32-x32", - "family": "sparse" + "descriptor": "cstream/x256" }, { - "template": "bstream-msplit", + "template": 
"cstream-ksplit", + "family": "sparse", "block": [ - 64, + 32, 2, 1 ], "width": 1, "params": { - "bsz": 32 + "csz": 32, + "preload_c": true }, - "descriptor": "bstream-msplit/m2-b32-x64", - "family": "sparse" - }, - { - "template": "cstream", - "block": [ - 128, - 1, - 1 - ], - "width": 1, - "descriptor": "cstream/x128", - "family": "sparse" - }, - { - "template": "cstream-v2", - "block": [ - 128, - 1, - 1 - ], - "width": 2, - "conditions": { - "all": [ - { - "field": "dtype", - "eq": "double" - }, - { - "field": "n", - "is_not": null - }, - { - "field": "n", - "divisible_by": 2 - }, - { - "field": "k_used", - "lte": 100 - }, - { - "field": "aligne", - "is_null_or_divisible_by": 2 - } - ] - }, - "descriptor": "cstream-v2/x128", - "family": "sparse" + "descriptor": "cstream-ksplit/preload-c/k2-c32-x32" }, { "template": "dmma-steal-ws", + "family": "dense-ws", + "tile": { + "m": 8, + "n": 8, + "k": 4 + }, "block": [ 192, 1, @@ -247,23 +177,22 @@ "field": "n", "is_not": null }, - "descriptor": "dmma-steal-ws/nn2-w4", - "family": "dense-ws", + "descriptor": "dmma-steal-ws/nn2-w4" + }, + { + "template": "dmma-astream-msplit-v2", + "family": "dense", "tile": { "m": 8, "n": 8, "k": 4 - } - }, - { - "template": "dmma-astream-msplit", - "vector_width": 2, + }, "block": [ 128, 1, 1 ], - "width": 1, + "width": 2, "params": { "nn": 2, "warps": 2, @@ -285,16 +214,16 @@ } ] }, - "descriptor": "dmma-astream-msplit/v2/nn2-w2-m2", - "family": "dense", + "descriptor": "dmma-astream-msplit/v2/nn2-w2-m2" + }, + { + "template": "dmma-steal-ws", + "family": "dense-ws", "tile": { "m": 8, "n": 8, "k": 4 - } - }, - { - "template": "dmma-steal-ws", + }, "block": [ 192, 1, @@ -314,23 +243,22 @@ "field": "n", "is_not": null }, - "descriptor": "dmma-steal-ws/nn1-w4", - "family": "dense-ws", + "descriptor": "dmma-steal-ws/nn1-w4" + }, + { + "template": "dmma-astream-v2", + "family": "dense", "tile": { "m": 8, "n": 8, "k": 4 - } - }, - { - "template": "dmma-astream", - "vector_width": 2, + }, 
"block": [ 32, 1, 1 ], - "width": 1, + "width": 2, "params": { "nn": 4, "warps": 1 @@ -351,23 +279,22 @@ } ] }, - "descriptor": "dmma-astream/v2/nn4-w1", + "descriptor": "dmma-astream/v2/nn4-w1" + }, + { + "template": "dmma-astream-msplit-v2", "family": "dense", "tile": { "m": 8, "n": 8, "k": 4 - } - }, - { - "template": "dmma-astream-msplit", - "vector_width": 2, + }, "block": [ 64, 1, 1 ], - "width": 1, + "width": 2, "params": { "nn": 4, "warps": 1, @@ -389,16 +316,16 @@ } ] }, - "descriptor": "dmma-astream-msplit/v2/nn4-w1-m2", - "family": "dense", + "descriptor": "dmma-astream-msplit/v2/nn4-w1-m2" + }, + { + "template": "dmma-steal-ws", + "family": "dense-ws", "tile": { "m": 8, "n": 8, "k": 4 - } - }, - { - "template": "dmma-steal-ws", + }, "block": [ 192, 1, @@ -418,23 +345,22 @@ "field": "n", "is_not": null }, - "descriptor": "dmma-steal-ws/nn4-w4", - "family": "dense-ws", + "descriptor": "dmma-steal-ws/nn4-w4" + }, + { + "template": "dmma-astream-msplit-v2", + "family": "dense", "tile": { "m": 8, "n": 8, "k": 4 - } - }, - { - "template": "dmma-astream-msplit", - "vector_width": 2, + }, "block": [ 96, 1, 1 ], - "width": 1, + "width": 2, "params": { "nn": 4, "warps": 1, @@ -456,23 +382,22 @@ } ] }, - "descriptor": "dmma-astream-msplit/v2/nn4-w1-m3", + "descriptor": "dmma-astream-msplit/v2/nn4-w1-m3" + }, + { + "template": "dmma-astream-v2", "family": "dense", "tile": { "m": 8, "n": 8, "k": 4 - } - }, - { - "template": "dmma-astream", - "vector_width": 2, + }, "block": [ 256, 1, 1 ], - "width": 1, + "width": 2, "params": { "nn": 2, "warps": 8 @@ -493,13 +418,7 @@ } ] }, - "descriptor": "dmma-astream/v2/nn2-w8", - "family": "dense", - "tile": { - "m": 8, - "n": 8, - "k": 4 - } + "descriptor": "dmma-astream/v2/nn2-w8" } ] } diff --git a/gimmik/kernels/ptx/config/sm100_fp32.json b/gimmik/kernels/ptx/config/sm100_float.json similarity index 75% rename from gimmik/kernels/ptx/config/sm100_fp32.json rename to gimmik/kernels/ptx/config/sm100_float.json index 
b388544..daf15df 100644 --- a/gimmik/kernels/ptx/config/sm100_fp32.json +++ b/gimmik/kernels/ptx/config/sm100_float.json @@ -9,17 +9,6 @@ 7 ], "kernels": [ - { - "template": "bstream", - "family": "sparse", - "block": [ - 128, - 1, - 1 - ], - "width": 1, - "descriptor": "bstream/x128" - }, { "template": "cstream-ksplit", "family": "sparse", @@ -30,21 +19,23 @@ ], "width": 1, "params": { - "csz": 24 + "csz": 32, + "preload_c": true }, - "descriptor": "cstream-ksplit/k2-c24-x64" + "descriptor": "cstream-ksplit/preload-c/k2-c32-x64" }, { "template": "cstream-ksplit-v2", "family": "sparse", "block": [ - 32, - 4, + 64, + 2, 1 ], "width": 2, "params": { - "csz": 20 + "csz": 16, + "preload_c": true }, "conditions": { "all": [ @@ -72,18 +63,18 @@ } ] }, - "descriptor": "cstream-ksplit-v2/k4-c20-x32" + "descriptor": "cstream-ksplit-v2/preload-c/k2-c16-x64" }, { "template": "cstream", "family": "sparse", "block": [ - 256, + 128, 1, 1 ], "width": 1, - "descriptor": "cstream/x256" + "descriptor": "cstream/x128" }, { "template": "cstream-ksplit-v2", @@ -95,7 +86,8 @@ ], "width": 2, "params": { - "csz": 16 + "csz": 24, + "preload_c": true }, "conditions": { "all": [ @@ -123,19 +115,31 @@ } ] }, - "descriptor": "cstream-ksplit-v2/k2-c16-x64" + "descriptor": "cstream-ksplit-v2/preload-c/k2-c24-x64" + }, + { + "template": "bstream", + "family": "sparse", + "block": [ + 256, + 1, + 1 + ], + "width": 1, + "descriptor": "bstream/x256" }, { "template": "bstream-msplit-v2", "family": "sparse", "block": [ 64, - 1, + 2, 1 ], "width": 2, "params": { - "bsz": 32 + "bsz": 32, + "preload_c": true }, "conditions": { "all": [ @@ -163,47 +167,22 @@ } ] }, - "descriptor": "bstream-msplit-v2/m1-b32-x64" + "descriptor": "bstream-msplit-v2/preload-c/m2-b32-x64" }, { - "template": "cstream-ksplit-v2", + "template": "cstream-ksplit", "family": "sparse", "block": [ 64, - 4, + 2, 1 ], - "width": 2, + "width": 1, "params": { - "csz": 24 - }, - "conditions": { - "all": [ - { - "field": "dtype", - "in": [ 
- "float" - ] - }, - { - "field": "n", - "is_not": null - }, - { - "field": "n", - "divisible_by": 2 - }, - { - "field": "aligne", - "is_not": null - }, - { - "field": "aligne", - "divisible_by": 2 - } - ] + "csz": 24, + "preload_c": true }, - "descriptor": "cstream-ksplit-v2/k4-c24-x64" + "descriptor": "cstream-ksplit/preload-c/k2-c24-x64" }, { "template": "cstream-ksplit-v2", @@ -215,7 +194,8 @@ ], "width": 2, "params": { - "csz": 32 + "csz": 32, + "preload_c": true }, "conditions": { "all": [ @@ -243,7 +223,7 @@ } ] }, - "descriptor": "cstream-ksplit-v2/k2-c32-x64" + "descriptor": "cstream-ksplit-v2/preload-c/k2-c32-x64" } ] } diff --git a/gimmik/kernels/ptx/config/sm80_fp64.json b/gimmik/kernels/ptx/config/sm80_double.json similarity index 100% rename from gimmik/kernels/ptx/config/sm80_fp64.json rename to gimmik/kernels/ptx/config/sm80_double.json diff --git a/gimmik/kernels/ptx/config/sm80_fp32.json b/gimmik/kernels/ptx/config/sm80_float.json similarity index 63% rename from gimmik/kernels/ptx/config/sm80_fp32.json rename to gimmik/kernels/ptx/config/sm80_float.json index 4846d3f..23656b2 100644 --- a/gimmik/kernels/ptx/config/sm80_fp32.json +++ b/gimmik/kernels/ptx/config/sm80_float.json @@ -10,58 +10,40 @@ ], "kernels": [ { - "template": "cstream", - "family": "sparse", - "block": [ - 128, - 1, - 1 - ], - "width": 1, - "descriptor": "cstream" - }, - { - "template": "bstream", - "family": "sparse", - "block": [ - 128, - 1, - 1 - ], - "width": 1, - "descriptor": "bstream" - }, - { - "template": "bstream-msplit", + "template": "cstream-ksplit", "family": "sparse", "block": [ - 32, - 4, + 64, + 2, 1 ], "width": 1, "params": { - "bsz": 24 + "csz": 32, + "preload_c": true }, - "descriptor": "bstream-msplit/m4-b24-x32" + "descriptor": "cstream-ksplit/preload-c/k2-c32-x64" }, { - "template": "bstream-msplit-v2", + "template": "cstream-ksplit-v2", "family": "sparse", "block": [ - 32, - 4, + 64, + 2, 1 ], "width": 2, "params": { - "bsz": 16 + "csz": 16, + 
"preload_c": true }, "conditions": { "all": [ { "field": "dtype", - "eq": "float" + "in": [ + "float" + ] }, { "field": "n", @@ -81,69 +63,91 @@ } ] }, - "descriptor": "bstream-msplit-v2/m4-b16-x32" + "descriptor": "cstream-ksplit-v2/preload-c/k2-c16-x64" }, { - "template": "bstream-msplit", + "template": "cstream", "family": "sparse", "block": [ - 64, + 128, 1, 1 ], "width": 1, + "descriptor": "cstream/x128" + }, + { + "template": "cstream-ksplit-v2", + "family": "sparse", + "block": [ + 64, + 2, + 1 + ], + "width": 2, "params": { - "bsz": 32 + "csz": 24, + "preload_c": true }, "conditions": { "all": [ { - "field": "beta_zero", - "eq": true + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null }, { - "field": "m", - "lte": 320 + "field": "n", + "divisible_by": 2 }, { - "field": "k_used", - "gte": 64 + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 } ] }, - "descriptor": "bstream-msplit/m1-b32-x64" + "descriptor": "cstream-ksplit-v2/preload-c/k2-c24-x64" }, { - "template": "cstream-ksplit", + "template": "bstream", "family": "sparse", "block": [ - 32, - 2, + 256, + 1, 1 ], "width": 1, - "params": { - "csz": 24 - }, - "descriptor": "cstream-ksplit/k2-c24-x32" + "descriptor": "bstream/x256" }, { - "template": "cstream-ksplit-v2", + "template": "bstream-msplit-v2", "family": "sparse", "block": [ - 32, + 64, 2, 1 ], "width": 2, "params": { - "csz": 24 + "bsz": 32, + "preload_c": true }, "conditions": { "all": [ { "field": "dtype", - "eq": "float" + "in": [ + "float" + ] }, { "field": "n", @@ -163,40 +167,43 @@ } ] }, - "descriptor": "cstream-ksplit-v2/k2-c24-x32" + "descriptor": "bstream-msplit-v2/preload-c/m2-b32-x64" }, { "template": "cstream-ksplit", "family": "sparse", "block": [ - 32, - 4, + 64, + 2, 1 ], "width": 1, "params": { - "csz": 20 + "csz": 24, + "preload_c": true }, - "conditions": { - "field": "k_used", - "gt": 500 - }, - "descriptor": "cstream-ksplit/k4-c20-x32" + "descriptor": 
"cstream-ksplit/preload-c/k2-c24-x64" }, { - "template": "cstream-v2", + "template": "cstream-ksplit-v2", "family": "sparse", "block": [ - 128, - 1, + 64, + 2, 1 ], "width": 2, + "params": { + "csz": 32, + "preload_c": true + }, "conditions": { "all": [ { "field": "dtype", - "eq": "float" + "in": [ + "float" + ] }, { "field": "n", @@ -207,16 +214,16 @@ "divisible_by": 2 }, { - "field": "k_used", - "lte": 100 + "field": "aligne", + "is_not": null }, { "field": "aligne", - "is_null_or_divisible_by": 2 + "divisible_by": 2 } ] }, - "descriptor": "cstream-v2/x128" + "descriptor": "cstream-ksplit-v2/preload-c/k2-c32-x64" } ] } diff --git a/gimmik/kernels/ptx/config/sm90_fp64.json b/gimmik/kernels/ptx/config/sm90_double.json similarity index 95% rename from gimmik/kernels/ptx/config/sm90_fp64.json rename to gimmik/kernels/ptx/config/sm90_double.json index a61a051..bd135e2 100644 --- a/gimmik/kernels/ptx/config/sm90_fp64.json +++ b/gimmik/kernels/ptx/config/sm90_double.json @@ -230,14 +230,13 @@ "family": "sparse" }, { - "template": "dmma-asmem", - "vector_width": 2, + "template": "dmma-asmem-v2", "block": [ 256, 1, 1 ], - "width": 1, + "width": 2, "params": { "nn": 1, "warps": 8 @@ -267,14 +266,13 @@ } }, { - "template": "dmma-astream", - "vector_width": 2, + "template": "dmma-astream-v2", "block": [ 64, 1, 1 ], - "width": 1, + "width": 2, "params": { "nn": 2, "warps": 2 @@ -333,14 +331,13 @@ } }, { - "template": "dmma-astream", - "vector_width": 2, + "template": "dmma-astream-v2", "block": [ 32, 1, 1 ], - "width": 1, + "width": 2, "params": { "nn": 4, "warps": 1 @@ -370,14 +367,13 @@ } }, { - "template": "dmma-asmem", - "vector_width": 2, + "template": "dmma-asmem-v2", "block": [ 128, 1, 1 ], - "width": 1, + "width": 2, "params": { "nn": 2, "warps": 4 @@ -436,14 +432,13 @@ } }, { - "template": "dmma-astream", - "vector_width": 2, + "template": "dmma-astream-v2", "block": [ 128, 1, 1 ], - "width": 1, + "width": 2, "params": { "nn": 1, "warps": 4 diff --git 
a/gimmik/kernels/ptx/config/sm90_fp32.json b/gimmik/kernels/ptx/config/sm90_float.json similarity index 55% rename from gimmik/kernels/ptx/config/sm90_fp32.json rename to gimmik/kernels/ptx/config/sm90_float.json index 3417d2f..ebfd5f3 100644 --- a/gimmik/kernels/ptx/config/sm90_fp32.json +++ b/gimmik/kernels/ptx/config/sm90_float.json @@ -11,87 +11,91 @@ "kernels": [ { "template": "cstream-ksplit", + "family": "sparse", "block": [ - 32, + 64, 2, 1 ], "width": 1, "params": { - "csz": 20 - }, - "descriptor": "cstream-ksplit/k2-c20-x32", - "family": "sparse" - }, - { - "template": "bstream-msplit", - "block": [ - 32, - 8, - 1 - ], - "width": 1, - "params": { - "bsz": 24 - }, - "descriptor": "bstream-msplit/m8-b24-x32", - "family": "sparse" - }, - { - "template": "cstream-ksplit", - "block": [ - 32, - 4, - 1 - ], - "width": 1, - "params": { - "csz": 24 + "csz": 32, + "preload_c": true }, - "descriptor": "cstream-ksplit/k4-c24-x32", - "family": "sparse" + "descriptor": "cstream-ksplit/preload-c/k2-c32-x64" }, { - "template": "bstream-msplit", + "template": "cstream-ksplit-v2", + "family": "sparse", "block": [ 64, 2, 1 ], - "width": 1, + "width": 2, "params": { - "bsz": 32 + "csz": 16, + "preload_c": true + }, + "conditions": { + "all": [ + { + "field": "dtype", + "in": [ + "float" + ] + }, + { + "field": "n", + "is_not": null + }, + { + "field": "n", + "divisible_by": 2 + }, + { + "field": "aligne", + "is_not": null + }, + { + "field": "aligne", + "divisible_by": 2 + } + ] }, - "descriptor": "bstream-msplit/m2-b32-x64", - "family": "sparse" + "descriptor": "cstream-ksplit-v2/preload-c/k2-c16-x64" }, { - "template": "bstream", + "template": "cstream", + "family": "sparse", "block": [ - 64, + 128, 1, 1 ], "width": 1, - "descriptor": "bstream/x64", - "family": "sparse" + "descriptor": "cstream/x128" }, { - "template": "bstream-msplit-v2", + "template": "cstream-ksplit-v2", + "family": "sparse", "block": [ - 32, - 4, + 64, + 2, 1 ], "width": 2, "params": { - "bsz": 16 + 
"csz": 24, + "preload_c": true }, "conditions": { "all": [ { "field": "dtype", - "eq": "float" + "in": [ + "float" + ] }, { "field": "n", @@ -111,25 +115,39 @@ } ] }, - "descriptor": "bstream-msplit-v2/m4-b16-x32", - "family": "sparse" + "descriptor": "cstream-ksplit-v2/preload-c/k2-c24-x64" }, { - "template": "cstream-ksplit-v2", + "template": "bstream", + "family": "sparse", + "block": [ + 256, + 1, + 1 + ], + "width": 1, + "descriptor": "bstream/x256" + }, + { + "template": "bstream-msplit-v2", + "family": "sparse", "block": [ - 32, + 64, 2, 1 ], "width": 2, "params": { - "csz": 24 + "bsz": 32, + "preload_c": true }, "conditions": { "all": [ { "field": "dtype", - "eq": "float" + "in": [ + "float" + ] }, { "field": "n", @@ -149,64 +167,43 @@ } ] }, - "descriptor": "cstream-ksplit-v2/k2-c24-x32", - "family": "sparse" + "descriptor": "bstream-msplit-v2/preload-c/m2-b32-x64" }, { - "template": "bstream-msplit", + "template": "cstream-ksplit", + "family": "sparse", "block": [ - 32, + 64, 2, 1 ], "width": 1, "params": { - "bsz": 32 + "csz": 24, + "preload_c": true }, - "descriptor": "bstream-msplit/m2-b32-x32", - "family": "sparse" + "descriptor": "cstream-ksplit/preload-c/k2-c24-x64" }, { - "template": "bstream-msplit", + "template": "cstream-ksplit-v2", + "family": "sparse", "block": [ 64, - 1, - 1 - ], - "width": 1, - "params": { - "bsz": 24 - }, - "descriptor": "bstream-msplit/m1-b24-x64", - "family": "sparse" - }, - { - "template": "bstream-msplit", - "block": [ - 32, - 4, + 2, 1 ], - "width": 1, + "width": 2, "params": { - "bsz": 32 + "csz": 32, + "preload_c": true }, - "descriptor": "bstream-msplit/m4-b32-x32", - "family": "sparse" - }, - { - "template": "cstream-v2", - "block": [ - 128, - 1, - 1 - ], - "width": 2, "conditions": { "all": [ { "field": "dtype", - "eq": "float" + "in": [ + "float" + ] }, { "field": "n", @@ -217,17 +214,16 @@ "divisible_by": 2 }, { - "field": "k_used", - "lte": 100 + "field": "aligne", + "is_not": null }, { "field": "aligne", - 
"is_null_or_divisible_by": 2 + "divisible_by": 2 } ] }, - "descriptor": "cstream-v2/x128", - "family": "sparse" + "descriptor": "cstream-ksplit-v2/preload-c/k2-c32-x64" } ] } diff --git a/gimmik/kernels/ptx/cstream-ksplit-v2.mako b/gimmik/kernels/ptx/cstream-ksplit-v2.mako index 0abe8ba..b10dc3a 100644 --- a/gimmik/kernels/ptx/cstream-ksplit-v2.mako +++ b/gimmik/kernels/ptx/cstream-ksplit-v2.mako @@ -62,6 +62,7 @@ csub_bytes = (ksplit - 1) * csz * blockx * 2 * dwidth_i ## Chunk ${cchunk_i}: partial dot-product % for row_idx, j in enumerate(cchunk): <% owner_bid = row_idx % ksplit %> +<% has_dotp = bool(A[j].any()) %> % for kxi, kx in enumerate(kbx): % if A[j, kx] != 0 and kx not in loaded: ld.weak.global.cg.v2.${pftype} {bv_a${kxi}, bv_b${kxi}}, [b_base + ${ldb*kx*dwidth_i}]; @@ -77,8 +78,14 @@ csub_bytes = (ksplit - 1) * csz * blockx * 2 * dwidth_i % endif % endfor % if owner_bid == bid: +% if beta_zero or not preload_c: mov.${pftype} cv_a${row_idx // ksplit}, dotp_a; mov.${pftype} cv_b${row_idx // ksplit}, dotp_b; +% elif has_dotp: + ld.weak.global.cg.v2.${pftype} {cv_a${row_idx // ksplit}, cv_b${row_idx // ksplit}}, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} cv_a${row_idx // ksplit}, cv_a${row_idx // ksplit}, ${float(beta)}, dotp_a; + fma.rn.${pftype} cv_b${row_idx // ksplit}, cv_b${row_idx // ksplit}, ${float(beta)}, dotp_b; +% endif % else: <% csub_idx = bid - (1 if bid > owner_bid else 0) %> st.shared.v2.${pftype} [csub_thread + ${(csub_idx * csz + row_idx) * blockx * 2 * dwidth_i}], {dotp_a, dotp_b}; @@ -89,6 +96,8 @@ csub_bytes = (ksplit - 1) * csz * blockx * 2 * dwidth_i ## Combine phase (owned rows only) % for row_idx, j in enumerate(cchunk): % if row_idx % ksplit == bid: +<% has_dotp = bool(A[j].any()) %> +% if not preload_c or beta_zero or has_dotp: mov.${pftype} dotp_a, cv_a${row_idx // ksplit}; mov.${pftype} dotp_b, cv_b${row_idx // ksplit}; % for other_bid in range(ksplit): @@ -102,8 +111,19 @@ csub_bytes = (ksplit - 1) * csz * blockx * 2 * 
dwidth_i } % endif % endfor +% endif % if beta_zero: st.weak.global.cg.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; +% elif preload_c and has_dotp: + st.weak.global.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; +% elif preload_c and beta != 1: + { + .reg .${pftype} _ca, _cb; + ld.weak.global.cg.v2.${pftype} {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _ca, _ca, ${float(beta)}; + mul.${pftype} _cb, _cb, ${float(beta)}; + st.weak.global.v2.${pftype} [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } % else: { .reg .${pftype} _ca, _cb; diff --git a/gimmik/kernels/ptx/cstream-ksplit.mako b/gimmik/kernels/ptx/cstream-ksplit.mako index 8ec1bd2..700d6a3 100644 --- a/gimmik/kernels/ptx/cstream-ksplit.mako +++ b/gimmik/kernels/ptx/cstream-ksplit.mako @@ -76,6 +76,7 @@ csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i ## Chunk ${cchunk_i}: partial dot-product % for row_idx, j in enumerate(cchunk): <% owner_bid = row_idx % ksplit %> +<% has_dotp = bool(A[j].any()) %> % for kxi, kx in enumerate(kbx): % if A[j, kx] != 0 and kx not in loaded: % if n is None: @@ -97,7 +98,20 @@ csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i % endif % endfor % if owner_bid == bid: +% if beta_zero or not preload_c: mov.${pftype} cv${row_idx // ksplit}, dotp; +% elif has_dotp: +% if n is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + ld.weak.global.cg.${pftype} cv${row_idx // ksplit}, [_cptr]; + } +% else: + ld.weak.global.cg.${pftype} cv${row_idx // ksplit}, [c_base + ${ldc*j*dwidth_i}]; +% endif + fma.rn.${pftype} cv${row_idx // ksplit}, cv${row_idx // ksplit}, ${float(beta)}, dotp; +% endif % else: <% csub_idx = bid - (1 if bid > owner_bid else 0) %> st.shared.${pftype} [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}], dotp; @@ -108,6 +122,8 @@ csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i ## Combine phase (owned rows only) % for row_idx, j in enumerate(cchunk): % if row_idx % ksplit == bid: 
+<% has_dotp = bool(A[j].any()) %> +% if not preload_c or beta_zero or has_dotp: mov.${pftype} dotp, cv${row_idx // ksplit}; % for other_bid in range(ksplit): % if other_bid != bid: @@ -119,6 +135,7 @@ csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i } % endif % endfor +% endif % if beta_zero: % if n is None: { @@ -129,6 +146,31 @@ csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i % else: st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; % endif +% elif preload_c and has_dotp: +% if n is None: + { + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + st.weak.global.${pftype} [_cptr], dotp; + } +% else: + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; +% endif +% elif preload_c and beta != 1: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u64 _cptr; + mad.wide.u32 _cptr, ldc, ${j * dwidth_i}, c_base; + ld.weak.global.cg.${pftype} _ctmp, [_cptr]; + mul.${pftype} _ctmp, _ctmp, ${float(beta)}; + st.weak.global.${pftype} [_cptr], _ctmp; +% else: + ld.weak.global.cg.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _ctmp, _ctmp, ${float(beta)}; + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif + } % else: { .reg .${pftype} _ctmp; diff --git a/gimmik/kernels/ptx/dmma-asmem-v1.mako b/gimmik/kernels/ptx/dmma-asmem.mako similarity index 100% rename from gimmik/kernels/ptx/dmma-asmem-v1.mako rename to gimmik/kernels/ptx/dmma-asmem.mako diff --git a/gimmik/kernels/ptx/dmma-astream-msplit-v1.mako b/gimmik/kernels/ptx/dmma-astream-msplit.mako similarity index 100% rename from gimmik/kernels/ptx/dmma-astream-msplit-v1.mako rename to gimmik/kernels/ptx/dmma-astream-msplit.mako diff --git a/gimmik/kernels/ptx/dmma-astream-v1.mako b/gimmik/kernels/ptx/dmma-astream.mako similarity index 100% rename from gimmik/kernels/ptx/dmma-astream-v1.mako rename to gimmik/kernels/ptx/dmma-astream.mako diff --git a/gimmik/ptx.py b/gimmik/ptx.py index 6e18283..69e01c6 100644 --- a/gimmik/ptx.py +++ 
b/gimmik/ptx.py @@ -1,6 +1,3 @@ -import json -import pkgutil - import numpy as np from gimmik.base import MatMul @@ -19,14 +16,9 @@ class PTXMatMul(MatMul): PTX_SM = {(8, 0): (7, 0), (9, 0): (8, 6), (10, 0): (8, 7), (10, 3): (8, 7), (12, 0): (8, 7), (12, 1): (8, 7)} - DEFAULT_CFG = 'kernels/ptx/config/default.json' FZERO = {'float': '0f00000000', 'double': '0d0000000000000000'} PFTYPE = {'float': 'f32', 'double': 'f64'} - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._config_cache = {} - @classmethod def is_sparse_suitable(cls, arr, cc): cc = cc or (0, 0) @@ -49,7 +41,10 @@ def _kernel_generators(self, dtype, dsize, *, compute_capability=None, smem_info=None): cc = compute_capability or (0, 0) smem_info = smem_info or (48*1024, 48*1024) - config = self._cc_config(cc, dtype) + config = self._platform_config(dtype, cc) + + # When we know the PTX version but there isn't an SM specific config, + # we can overide the default PTX version if cc in self.PTX_SM: target_cc = cc ptx = self.PTX_SM[cc] @@ -134,49 +129,37 @@ def _sparse_args(self, tpl, params, block, dtype, dsize, args, meta): args |= {'has_zero_rows': bool(self.has_zero_rows), 'row_nz': [[(kx, self.A[j, kx]) for kx in range(self.k) if self.A[j, kx] != 0] for j in range(self.m)], + 'preload_c': bool(params.get('preload_c', False)), } match tpl: case 'cstream' | 'bstream': pass - case 'bstream-msplit': - msplit = block[1] + case 'bstream-msplit' | 'bstream-msplit-v2': bsz = params['bsz'] - args |= {'msplit': msplit, 'bsz': bsz, 'blockx': blockx} - meta['shared'] = 2*bsz*blockx*dsize - case 'bstream-msplit-v2': - msplit = block[1] - bsz = params['bsz'] - args |= {'msplit': msplit, 'bsz': bsz, 'blockx': blockx} - meta['shared'] = 2*bsz*blockx*2*dsize - case 'cstream-ksplit': - ksplit = block[1] - csz = params['csz'] - args |= {'ksplit': ksplit, 'csz': csz, 'blockx': blockx} - meta['shared'] = (ksplit - 1)*csz*blockx*dsize - case 'cstream-ksplit-v2': - ksplit = block[1] + args |= 
{'msplit': block[1], 'bsz': bsz, 'blockx': blockx} + meta['shared'] = 2*bsz*blockx*dsize*args['width'] + case 'cstream-ksplit' | 'cstream-ksplit-v2': csz = params['csz'] - args |= {'ksplit': ksplit, 'csz': csz, 'blockx': blockx} - meta['shared'] = (ksplit - 1)*csz*blockx*2*dsize + args |= {'ksplit': block[1], 'csz': csz, 'blockx': blockx} + meta['shared'] = (block[1] - 1)*csz*blockx*dsize*args['width'] case _: args['blockx'] = blockx return tpl, args, meta def _dense_args(self, kernel_cfg, params, cc, smem_info, args, meta): - base_tpl = kernel_cfg['template'] + tpl = kernel_cfg['template'] nn = params['nn'] warps = params['warps'] tile = kernel_cfg['tile'] - vector_width = kernel_cfg['vector_width'] + width = kernel_cfg['width'] - setup = self._dense_common(nn, warps, tile, cc, vector_width) + setup = self._dense_common(nn, warps, tile, cc, width) if setup is None: return None - tpl = f'{base_tpl}-v{vector_width}' args |= setup - if base_tpl == 'dmma-asmem': + if tpl.startswith('dmma-asmem'): args |= { 'a_copy_threads': 32 * warps, 'block_stealing': bool(params.get('block_stealing', False)), @@ -212,7 +195,7 @@ def _dense_args(self, kernel_cfg, params, cc, smem_info, args, meta): return tpl, args, meta - def _dense_common(self, nn, warps_per_cta, tile, cc, vector_width=None): + def _dense_common(self, nn, warps_per_cta, tile, cc, width=None): tile_m, tile_n, tile_k = tile['m'], tile['n'], tile['k'] ptx_shape = f'm{tile_m}n{tile_n}k{tile_k}' @@ -231,7 +214,7 @@ def _dense_common(self, nn, warps_per_cta, tile, cc, vector_width=None): if n_per_cta > self.n: return None - if (vector_width == 2 + if (width == 2 and (self.aligne is None or self.aligne % 2 or self.n % n_per_warp)): return None @@ -372,41 +355,10 @@ def _usable_config(self, kernel_cfg, dtype, cc, smem_info): stats = self._matmul_stats(dtype, cc, smem_info) return self._eval_condition(condition, stats) - @staticmethod - def _dtype_config_suffix(dtype): - if dtype is None: - raise ValueError('PTX config 
dtype is required') - - dtype_name = getattr(dtype, 'name', dtype) - if dtype_name in {'float', 'float32', 'single'}: - return 'fp32' - elif dtype_name in {'double', 'float64'}: - return 'fp64' - - raise ValueError(f'Unsupported PTX config dtype {dtype_name!r}') - - def _cc_config(self, cc, dtype): + def _platform_config(self, dtype, cc): cc = cc or (0, 0) - suffix = self._dtype_config_suffix(dtype) - key = (cc, suffix) - if key not in self._config_cache: - base = f'kernels/ptx/config/sm{cc[0]}{cc[1]}' - paths = [f'{base}_{suffix}.json', self.DEFAULT_CFG] - - cfgdir = None - for path in paths: - try: - cfgdir = pkgutil.get_data('gimmik', path) - except FileNotFoundError: - continue - - if cfgdir is not None: - break - - cfg = json.loads(cfgdir.decode('utf-8')) - - self._config_cache[key] = cfg - return self._config_cache[key] + key = f'sm{cc[0]}{cc[1]}_{dtype}' + return self._get_config(key) def _matmul_stats(self, dtype, cc, smem_info): nnz = np.count_nonzero(self.A) @@ -427,44 +379,6 @@ def _matmul_stats(self, dtype, cc, smem_info): 'smem_dynamic': smem_info[1], } - def _eval_condition(self, condition, stats): - if 'all' in condition: - return all(self._eval_condition(c, stats) for c in condition['all']) - if 'any' in condition: - return any(self._eval_condition(c, stats) for c in condition['any']) - if 'not' in condition: - return not self._eval_condition(condition['not'], stats) - - value = stats[condition['field']] - op = next(k for k in condition if k != 'field') - expected = condition[op] - - match op: - case 'eq': - return value == expected - case 'ne': - return value != expected - case 'lt': - return value is not None and value < expected - case 'lte': - return value is not None and value <= expected - case 'gt': - return value is not None and value > expected - case 'gte': - return value is not None and value >= expected - case 'in': - return value in expected - case 'is_null': - return value is None - case 'is_not': - return value is not None - case 
'divisible_by': - return value is not None and value % expected == 0 - case 'is_null_or_divisible_by': - return (value is None or value % expected == 0) - case _: - raise ValueError(f'op `{op}` not supported') - @staticmethod def _dsmem_alloc(regions, mbars, align=16): # For a set of regions and mbars and there sizes, work out dynamic