diff --git a/gimmik/hip.py b/gimmik/hip.py index a58c8fa..142a799 100644 --- a/gimmik/hip.py +++ b/gimmik/hip.py @@ -8,23 +8,84 @@ class HIPMatMul(MatMul): basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0} def _kernel_generators(self, dtype, dsize, *, gcn_arch=None, warp_size=64): - # B loading, C streaming kernel - yield ('cstream', {}, {}) + max_block_threads = 1024 + max_shared = 64*1024 - # B streaming, C accumulation kernel - yield ('bstream', {}, {}) + def emit(name, args, meta): + block = meta.get('block', self.basemeta['block']) + shared = meta.get('shared', self.basemeta['shared']) + threads = block[0]*block[1]*block[2] + + if threads <= max_block_threads and shared <= max_shared: + yield (name, args, meta) + + def emit_preload(name, args, meta): + yield from emit(name, args | {'preload': True}, meta) - # Four-way m-split B streaming, C accumulation kernel ms, bsz, blkx = 4, 24, 64 args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize} - yield ('bstream-msplit', args, meta) + meta = { + 'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}' + } + yield from emit('bstream-msplit', args, meta) - # Two-way k-split B loading, C streaming kernel ks, csz, blkx = 2, 24, 64 args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} - meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize} - yield ('cstream-ksplit', args, meta) + meta = { + 'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}' + } + yield from emit('cstream-ksplit', args, meta) + + # Tuned HIP variants + msplits, ksplits = [8, 4], [4, 2] + bsz, csz, blkx = 8, 8, 64 + widths = [1] + if self.aligne is not None and self.aligne % 2 == 0: + widths.insert(0, 2) + + for width in widths: + wargs = ({'dtype': f'{dtype}{width}', 'width': width} + if width > 1 else {}) + wmeta = {'width': width} if width > 1 else {} + wpfx = f'w{width}-' if width > 1 else '' + + for ms in msplits: + # m-split B streaming, C accumulation kernel + args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} | wargs + shared = 2*bsz*blkx*dsize*width + meta = { + 'block': (blkx, ms, 1), 'shared': shared, + 'desc': f'bstream-msplit/{wpfx}m{ms}-b{bsz}-x{blkx}' + } | wmeta + yield from emit('bstream-msplit', args, meta) + + for ms in msplits: + # m-split B streaming, C preloading, C accumulation kernel + args = {'msplit': ms, 'bsz': bsz, 'blockx': blkx} | wargs + shared = 2*bsz*blkx*dsize*width + meta = { + 'block': (blkx, ms, 1), 'shared': shared, + 'desc': ( + f'bstream-msplit-preload-c/' + f'{wpfx}m{ms}-b{bsz}-x{blkx}' + ) + } | wmeta + yield from emit_preload('bstream-msplit', args, meta) + + for ks in ksplits: + # k-split B loading, C preloading, C streaming kernel + args = {'ksplit': ks, 'csz': csz, 'blockx': blkx} | wargs + shared = (ks - 1)*csz*blkx*dsize*width + meta = { + 'block': (blkx, ks, 1), 'shared': shared, + 'desc': ( + f'cstream-ksplit-preload-c/' + f'{wpfx}k{ks}-c{csz}-x{blkx}' + ) + } | wmeta + yield from emit_preload('cstream-ksplit', args, meta) def _process_meta(self, meta): if self.n is not None: diff --git a/gimmik/kernels/hip/base.mako b/gimmik/kernels/hip/base.mako index 874fbbd..d67ee25 100644 --- a/gimmik/kernels/hip/base.mako +++ b/gimmik/kernels/hip/base.mako @@ -1,12 +1,74 @@ % if dtype.endswith('4'): -static inline __device__ ${dtype} make_zero() +inline __device__ ${dtype} operator+(${dtype} a, ${dtype} b) +{ return make_${dtype}(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } + +inline __device__ ${dtype} operator*(${dtype[:-1]} a, ${dtype} b) +{ return make_${dtype}(a*b.x, a*b.y, a*b.z, a*b.w); } + +inline __device__ ${dtype} make_zero() { return make_${dtype}(0, 0, 0, 0); } % elif dtype.endswith('2'): -static inline __device__ ${dtype} make_zero() +inline __device__ ${dtype} operator+(${dtype} a, ${dtype} b) +{ return make_${dtype}(a.x + b.x, a.y + b.y); } + +inline __device__ ${dtype} operator*(${dtype[:-1]} a, ${dtype} b) +{ return make_${dtype}(a*b.x, a*b.y); } + +inline __device__ ${dtype} make_zero() { return make_${dtype}(0, 0); } % else: -static inline __device__ ${dtype} make_zero() +inline __device__ ${dtype} make_zero() { return 0; } % endif +static inline __device__ void +nt_store(${dtype}* p, ${dtype} v) +{ +% if dtype.endswith('4'): + __builtin_nontemporal_store(v.x, &p->x); + __builtin_nontemporal_store(v.y, &p->y); + __builtin_nontemporal_store(v.z, &p->z); + __builtin_nontemporal_store(v.w, &p->w); +% elif dtype.endswith('2'): + __builtin_nontemporal_store(v.x, &p->x); + __builtin_nontemporal_store(v.y, &p->y); +% else: + __builtin_nontemporal_store(v, p); +% endif +} + +static inline __device__ ${dtype} +nt_load(const ${dtype}* p) +{ +% if dtype.endswith('4'): + return make_${dtype}(__builtin_nontemporal_load(&p->x), + __builtin_nontemporal_load(&p->y), + __builtin_nontemporal_load(&p->z), + __builtin_nontemporal_load(&p->w)); +% elif dtype.endswith('2'): + return make_${dtype}(__builtin_nontemporal_load(&p->x), + __builtin_nontemporal_load(&p->y)); +% else: + return __builtin_nontemporal_load(p); +% endif +} + +static inline __device__ void +store_c(${dtype}* p, ${dtype} v) +{ + nt_store(p, v); +} + +static inline __device__ ${dtype} +load_c(const ${dtype}* p) +{ + return nt_load(p); +} + +static inline __device__ ${dtype} +load_b(const ${dtype}* p) +{ + return nt_load(p); +} + ${next.body()} diff --git a/gimmik/kernels/hip/bstream-msplit.mako b/gimmik/kernels/hip/bstream-msplit.mako index 6359ca1..52853f4 100644 --- a/gimmik/kernels/hip/bstream-msplit.mako +++ b/gimmik/kernels/hip/bstream-msplit.mako @@ -3,6 +3,7 @@ <% mx = partition(A, into=msplit, by='rows') bchunks = chunk(bix, bsz) +preload = context.get('preload', False) %> __global__ __launch_bounds__(${blockx*msplit}) void @@ -12,7 +13,7 @@ ${kname}(int n, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -34,9 +35,22 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) { % for kx in bchunks[0]: % if loop.index % msplit == cid: - bsub[0][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; + bsub[0][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]); % endif % endfor + + % if preload and beta != 0: + ## Preload C values for active rows owned by this m-split lane + % for j, jx in enumerate(mx[cid]): + % if afix[jx] != -1: + % if beta == 1: + csub[${j}] = load_c(&c[i + ${jx}*ldc]); + % else: + csub[${j}] = ${beta}*load_c(&c[i + ${jx}*ldc]); + % endif + % endif + % endfor + % endif } % endfor __syncthreads(); @@ -51,7 +65,7 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if not loop.parent.last: % for kx in bchunks[bb + 1]: % if loop.index % msplit == cid: - bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = b[i + ${kx}*ldb]; + bsub[${(bb + 1) % 2}][${loop.index}][threadIdx.x] = load_b(&b[i + ${kx}*ldb]); % endif % endfor % endif @@ -59,18 +73,22 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % for kx in bchunks[bb]: bv = bsub[${bb % 2}][${loop.index}][threadIdx.x]; % for j, jx in enumerate(A[mcx, kx]): - % if jx != 0 and kx == afix[mcx[j]]: + % if preload and beta != 0 and jx != 0: + csub[${j}] += ${jx}*bv; + % elif jx != 0 and kx == afix[mcx[j]]: csub[${j}] = ${jx}*bv; % elif jx != 0: csub[${j}] += ${jx}*bv; % endif ## If we're done with this dot product then store to global - % if kx == alix[mcx[j]] and beta == 0: - c[i + ${mcx[j]}*ldc] = csub[${j}]; + % if preload and kx == alix[mcx[j]]: + store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); + % elif kx == alix[mcx[j]] and beta == 0: + store_c(&c[i + ${mcx[j]}*ldc], csub[${j}]); % elif kx == alix[mcx[j]] and beta == 1: - c[i + ${mcx[j]}*ldc] += csub[${j}]; + store_c(&c[i + ${mcx[j]}*ldc], load_c(&c[i + ${mcx[j]}*ldc]) + csub[${j}]); % elif kx == alix[mcx[j]]: - c[i + ${mcx[j]}*ldc] = csub[${j}] + ${beta}*c[i + ${mcx[j]}*ldc]; + store_c(&c[i + ${mcx[j]}*ldc], csub[${j}] + ${beta}*load_c(&c[i + ${mcx[j]}*ldc])); % endif % endfor % endfor @@ -78,9 +96,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) % if loop.parent.last: % for j, jx in enumerate(afix): % if jx == -1 and j % msplit == cid and beta == 0: - c[i + ${j}*ldc] = make_zero(); + store_c(&c[i + ${j}*ldc], make_zero()); % elif jx == -1 and j % msplit == cid and beta != 1: - c[i + ${j}*ldc] *= ${beta}; + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endfor % endif diff --git a/gimmik/kernels/hip/bstream.mako b/gimmik/kernels/hip/bstream.mako index 2f6dc62..1e7a70b 100644 --- a/gimmik/kernels/hip/bstream.mako +++ b/gimmik/kernels/hip/bstream.mako @@ -1,13 +1,15 @@ <%inherit file='base'/> -__global__ __launch_bounds__(128) void +<% preload = context.get('preload', False) %> + +__global__ __launch_bounds__(${blockx}) void % if n is None: ${kname}(int n, const ${dtype}* __restrict__ b, int ldb, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -24,22 +26,39 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) { ${dtype} bv, csub[${m}]; -## Iterare through the used rows of B +% if preload and beta != 0: +## Preload C values for rows which will receive a non-zero dot product +% for j, jx in enumerate(afix): + % if jx != -1: + % if beta == 1: + csub[${j}] = load_c(&c[i + ${j}*ldc]); + % else: + csub[${j}] = ${beta}*load_c(&c[i + ${j}*ldc]); + % endif + % endif +% endfor +% endif + +## Iterate through the used rows of B % for kx in bix: - bv = b[i + ${kx}*ldb]; + bv = load_b(&b[i + ${kx}*ldb]); % for j, jx in enumerate(A[:, kx]): - % if jx != 0 and kx == afix[j]: + % if preload and beta != 0 and jx != 0: + csub[${j}] += ${jx}*bv; + % elif jx != 0 and kx == afix[j]: csub[${j}] = ${jx}*bv; % elif jx != 0: csub[${j}] += ${jx}*bv; % endif ## - % if kx == alix[j] and beta == 0: - c[i + ${j}*ldc] = csub[${j}]; + % if preload and kx == alix[j]: + store_c(&c[i + ${j}*ldc], csub[${j}]); + % elif kx == alix[j] and beta == 0: + store_c(&c[i + ${j}*ldc], csub[${j}]); % elif kx == alix[j] and beta == 1: - c[i + ${j}*ldc] += csub[${j}]; + store_c(&c[i + ${j}*ldc], load_c(&c[i + ${j}*ldc]) + csub[${j}]); % elif kx == alix[j]: - c[i + ${j}*ldc] = csub[${j}] + ${beta}*c[i + ${j}*ldc]; + store_c(&c[i + ${j}*ldc], csub[${j}] + ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endfor % endfor @@ -47,9 +66,9 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) ## Handle rows of A which are all zero % for j, jx in enumerate(afix): % if jx == -1 and beta == 0: - c[i + ${j}*ldc] = make_zero(); + store_c(&c[i + ${j}*ldc], make_zero()); % elif jx == -1 and beta != 1: - c[i + ${j}*ldc] *= ${beta}; + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endfor } diff --git a/gimmik/kernels/hip/cstream-ksplit.mako b/gimmik/kernels/hip/cstream-ksplit.mako index bae2d2a..12c59ba 100644 --- a/gimmik/kernels/hip/cstream-ksplit.mako +++ b/gimmik/kernels/hip/cstream-ksplit.mako @@ -4,6 +4,7 @@ kparts = partition(A, ksplit, by='cols') cchunks = chunk(range(m), csz) loaded = set() +preload = context.get('preload', False) %> __global__ __launch_bounds__(${blockx*ksplit}) void @@ -13,7 +14,7 @@ ${kname}(int n, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -43,14 +44,31 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) bv[${loop.index}] = b[i + ${kx}*ldb]; <% loaded.add(kx) %> % endif % endfor - % if (dotex := dot(lambda kx: f'bv[{kx}]', A[j, kbx])) != '0.0': + <% + nzixs = [(l_idx, kbx[l_idx]) for l_idx in A[j, kbx].nonzero()[0]] + has_dotp = A[j].any() + if nzixs: + first_l_idx, first_kx = nzixs[0] + dotex = f"{A[j, first_kx]}*bv[{first_l_idx}]" + for l_idx, kx in nzixs[1:]: + dotex = f"{dotex} + {A[j, kx]}*bv[{l_idx}]" + else: + dotex = 'make_zero()' + %> dotp = ${dotex}; - % else: - dotp = make_zero(); - % endif ## Save to a register % if loop.index % ksplit == bid: + % if preload and beta == 0: cv[${loop.index // ksplit}] = dotp; + % elif preload and beta == 1 and has_dotp: + cv[${loop.index // ksplit}] = load_c(&c[i + ${j}*ldc]); + cv[${loop.index // ksplit}] += dotp; + % elif preload and has_dotp: + cv[${loop.index // ksplit}] = ${beta}*load_c(&c[i + ${j}*ldc]); + cv[${loop.index // ksplit}] += dotp; + % elif not preload: + cv[${loop.index // ksplit}] = dotp; + % endif ## Save to shared memory % else: csub[${bid - (bid > loop.index % ksplit)}][${loop.index}][threadIdx.x] = dotp; @@ -66,14 +84,32 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) ## Sum and output the final set of dot products % for j in cchunk: % if loop.index % ksplit == bid: - dotp = cv[${loop.index // ksplit}] + ${' + '.join(f'csub[{i}][{loop.index}][threadIdx.x]' - for i in range(ksplit - 1))}; - % if beta == 0: - c[i + ${j}*ldc] = dotp; + <% has_dotp = A[j].any() %> + <% + sum_expr = f"cv[{loop.index // ksplit}]" + for s_idx in range(ksplit - 1): + sum_expr = f"{sum_expr} + csub[{s_idx}][{loop.index}][threadIdx.x]" + %> + % if preload and beta == 0: + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp); + % elif preload and beta == 1 and has_dotp: + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp); + % elif preload and beta != 1 and has_dotp: + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp); + % elif preload and beta != 1: + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); + % elif beta == 0: + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp); % elif beta == 1: - c[i + ${j}*ldc] += dotp; + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], load_c(&c[i + ${j}*ldc]) + dotp); % else: - c[i + ${j}*ldc] = dotp + ${beta}*c[i + ${j}*ldc]; + dotp = ${sum_expr}; + store_c(&c[i + ${j}*ldc], dotp + ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endif % endfor diff --git a/gimmik/kernels/hip/cstream.mako b/gimmik/kernels/hip/cstream.mako index f75301d..2ee9574 100644 --- a/gimmik/kernels/hip/cstream.mako +++ b/gimmik/kernels/hip/cstream.mako @@ -1,15 +1,17 @@ <%inherit file='base'/> -<% ksplit = 2 if m < 36 else 1 %> +<% +preload = context.get('preload', False) +%> -__global__ __launch_bounds__(128) void +__global__ __launch_bounds__(${blockx}) void % if n is None: ${kname}(int n, const ${dtype}* __restrict__ b, int ldb, ${dtype}* __restrict__ c, int ldc) { % if width > 1: - n = ((n + ${width} - 1) / ${width}) * ${width}; + n = (n + ${width} - 1) / ${width}; ldb /= ${width}; ldc /= ${width}; % endif @@ -26,17 +28,39 @@ ${kname}(const ${dtype}* __restrict__ b, ${dtype}* __restrict__ c) if (i < n) { % for j, jx in enumerate(A): - % if (dotex := dot(lambda kx: f'b[i + {kx}*ldb]', jx, maxsplit=ksplit)) != '0.0': + <% + nzixs = [kx for kx, val in enumerate(jx) if val != 0] + if nzixs: + first_kx = nzixs[0] + dotex = f"{jx[first_kx]}*b[i + {first_kx}*ldb]" + for kx in nzixs[1:]: + dotex = f"{dotex} + {jx[kx]}*b[i + {kx}*ldb]" + else: + dotex = 'make_zero()' + %> dotp = ${dotex}; + % if preload and nzixs: + % if beta == 0: + store_c(&c[i + ${j}*ldc], dotp); + % elif beta == 1: + dotp = load_c(&c[i + ${j}*ldc]) + dotp; + store_c(&c[i + ${j}*ldc], dotp); + % else: + dotp = ${beta}*load_c(&c[i + ${j}*ldc]) + dotp; + store_c(&c[i + ${j}*ldc], dotp); + % endif + % elif preload: + % if beta == 0: + store_c(&c[i + ${j}*ldc], make_zero()); + % elif beta != 1: + store_c(&c[i + ${j}*ldc], ${beta}*load_c(&c[i + ${j}*ldc])); + % endif + % elif beta == 0: + store_c(&c[i + ${j}*ldc], dotp); + % elif beta == 1 and nzixs: + store_c(&c[i + ${j}*ldc], load_c(&c[i + ${j}*ldc]) + dotp); % else: - dotp = make_zero(); - % endif - % if beta == 0: - c[i + ${j}*ldc] = dotp; - % elif beta == 1 and dotex != '0.0': - c[i + ${j}*ldc] += dotp; - % else: - c[i + ${j}*ldc] = dotp + ${beta}*c[i + ${j}*ldc]; + store_c(&c[i + ${j}*ldc], dotp + ${beta}*load_c(&c[i + ${j}*ldc])); % endif % endfor }